Belle II Software development
xgboost.py
1#!/usr/bin/env python3
2
3
10
11import numpy as np
12
13try:
14 import xgboost as xgb
15except ImportError:
16 print("Please install xgboost: pip3 install xgboost")
17 import sys
18 sys.exit(1)
19
20import os
21import tempfile
22import collections
23
24
class State:
    """
    Container for everything the XGBoost backend keeps for one model.
    """

    def __init__(self, num_round=0, parameters=None):
        """
        Create a state that does not hold an estimator yet.

        num_round  -- number of boosting rounds used in xgboost training
        parameters -- parameters passed to the xgboost model
        """
        # Estimator slot; stays None until load() or end_fit() assigns a Booster.
        self.estimator = None
        # Number of boosting rounds used in xgboost training.
        self.num_round = num_round
        # Parameters passed to the xgboost model (stored as given, not copied).
        self.parameters = parameters
38
39
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default xgboost model

    The first four arguments are not used by this function.

    parameters -- optional mapping of user options. An 'nTrees' entry sets the
                  number of boosting rounds (default 100); all other entries are
                  merged over the default xgboost parameters. A value that is
                  not a mapping is ignored. The passed object is never modified.

    Returns a State holding the number of rounds and the merged parameter dict.
    """
    # NOTE(review): the 'bst:'-prefixed keys and 'silent' look like legacy xgboost
    # parameter names - confirm the installed xgboost version still honours them.
    param = {'bst:max_depth': 2, 'bst:eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
    nTrees = 100
    if isinstance(parameters, collections.abc.Mapping):
        # Work on a private copy: popping from the caller's object would change it
        # behind their back and fails outright for read-only mappings.
        options = dict(parameters)
        nTrees = options.pop('nTrees', nTrees)
        param.update(options)
    return State(nTrees, param)
51
52
def feature_importance(state):
    """
    Return the feature importances as a list.

    The state is not inspected; the returned list is always empty.
    """
    importances = []
    return importances
58
59
def load(obj):
    """
    Load XGBoost estimator into state

    obj -- bytes-like content of a saved model (presumably what end_fit
           returned - verify against caller). It is written to a temporary
           file so that Booster.load_model can read it from a path.

    Returns a new State whose estimator is the loaded Booster.
    """
    state = State()
    # delete=False: the file is closed before xgboost opens it by name,
    # so it has to be removed by hand afterwards.
    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        with f:
            f.write(obj)
        state.estimator = xgb.Booster({})
        state.estimator.load_model(f.name)
    finally:
        # Remove the temporary file even when writing or loading fails.
        os.unlink(f.name)
    return state
72
73
def apply(state, X):
    """
    Run the estimator stored in the state on X and return its predictions.

    The predictions are returned as a float32 array that fulfils the numpy
    requirements aligned, writeable, C-contiguous and owning its data.
    """
    predictions = state.estimator.predict(xgb.DMatrix(X))
    return np.require(predictions, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
81
82
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Prepare the state for collecting training data.

    Empty lists for the training features, targets and weights are created and
    the test sample is stored, with its targets and weights flattened.
    Stest and nBatches are not used. Returns the state.
    """
    # Buffers that partial_fit fills batch by batch
    for buffer_name in ('X', 'y', 'w'):
        setattr(state, buffer_name, [])

    # Test sample, read again in end_fit
    state.Xtest = Xtest
    flat_targets = ytest.flatten()
    state.ytest = flat_targets
    flat_weights = wtest.flatten()
    state.wtest = flat_weights
    return state
94
95
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Keep one batch of training data for later.
    Nothing is fitted here: the batch is only appended to the lists in the
    state, with targets and weights flattened. S, epoch and batch are not used.
    Always returns True.
    """
    state.X.append(X)
    targets = y.flatten()
    state.y.append(targets)
    weights = w.flatten()
    state.w.append(weights)
    return True
105
106
def end_fit(state):
    """
    Merge received data together and fit estimator

    The stored batches are stacked into one training DMatrix, with the labels
    cast to int. If a test sample was stored, it is passed to the training as
    an additional evaluation set. The trained Booster is kept in
    state.estimator.

    Returns the content of the saved model file as bytes.
    """
    dtrain = xgb.DMatrix(np.vstack(state.X), label=np.hstack(state.y).astype(int), weight=np.hstack(state.w))

    if len(state.Xtest) > 0:
        dtest = xgb.DMatrix(state.Xtest, label=state.ytest.astype(int), weight=state.wtest)
        evallist = [(dtest, 'eval'), (dtrain, 'train')]
    else:
        evallist = [(dtrain, 'train')]

    state.estimator = xgb.train(state.parameters, dtrain, state.num_round, evallist)

    # The model is saved to a path and read back as bytes, so a named
    # temporary file is used; delete=False because it is reopened by name.
    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        f.close()
        state.estimator.save_model(f.name)
        with open(f.name, 'rb') as f2:
            content = f2.read()
    finally:
        # Remove the temporary file even when saving or reading fails.
        os.unlink(f.name)
    return content
num_round
Number of boosting rounds used in xgboost training.
Definition: xgboost.py:35
def __init__(self, num_round=0, parameters=None)
Definition: xgboost.py:30
parameters
Parameters passed to xgboost model.
Definition: xgboost.py:33