16 print(
"Please install xgboost: pip3 install xgboost")
30 def __init__(self, num_round=0, parameters=None):
31 """ Constructor of the state object """
40 def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
42 Return default xgboost model
44 param = {
'bst:max_depth': 2,
'bst:eta': 1,
'silent': 1,
'objective':
'binary:logistic'}
46 if isinstance(parameters, collections.Mapping):
47 if 'nTrees' in parameters:
48 nTrees = parameters.pop(
'nTrees')
49 param.update(parameters)
50 return State(nTrees, param)
53 def feature_importance(state):
55 Return a list containing the feature importances
62 Load XGBoost estimator into state
65 f = tempfile.NamedTemporaryFile(delete=
False)
68 state.estimator = xgb.Booster({})
69 state.estimator.load_model(f.name)
76 Apply estimator to passed data.
79 result = state.estimator.predict(data)
80 return np.require(result, dtype=np.float32, requirements=[
'A',
'W',
'C',
'O'])
83 def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
85 Initialize lists which will store the received data
91 state.ytest = ytest.flatten()
92 state.wtest = wtest.flatten()
96 def partial_fit(state, X, S, y, w, epoch, batch):
98 Stores received training data.
99 XGBoost is usually not able to perform a partial fit.
102 state.y.append(y.flatten())
103 state.w.append(w.flatten())
109 Merge received data together and fit estimator
111 dtrain = xgb.DMatrix(np.vstack(state.X), label=np.hstack(state.y).astype(int), weight=np.hstack(state.w))
113 if len(state.Xtest) > 0:
114 dtest = xgb.DMatrix(state.Xtest, label=state.ytest.astype(int), weight=state.wtest)
115 evallist = [(dtest,
'eval'), (dtrain,
'train')]
117 evallist = [(dtrain,
'train')]
119 state.estimator = xgb.train(state.parameters, dtrain, state.num_round, evallist)
120 f = tempfile.NamedTemporaryFile(delete=
False)
122 state.estimator.save_model(f.name)
123 with open(f.name,
'rb')
as f2:
num_round
Number of boosting rounds used in xgboost training.
estimator
XGBoost estimator.
def __init__(self, num_round=0, parameters=None)
parameters
Parameters passed to xgboost model.