Belle II Software  release-08-01-10
xgboost.py
#!/usr/bin/env python3


import numpy as np

try:
    import xgboost as xgb
except ImportError:
    print("Please install xgboost: pip3 install xgboost")
    import sys
    sys.exit(1)

import os
import tempfile
import collections.abc

class State:
    """
    XGBoost state
    """

    def __init__(self, num_round=0, parameters=None):
        """ Constructor of the state object """

        #: Parameters passed to xgboost model
        self.parameters = parameters

        #: Number of boosting rounds used in xgboost training
        self.num_round = num_round

        #: XGBoost estimator
        self.estimator = None

def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default xgboost model
    """
    param = {'bst:max_depth': 2, 'bst:eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
    nTrees = 100
    if isinstance(parameters, collections.abc.Mapping):
        if 'nTrees' in parameters:
            nTrees = parameters.pop('nTrees')
        param.update(parameters)
    return State(nTrees, param)

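# --- Illustrative usage sketch, not part of the original file ---
# Shows how get_model() merges user-supplied parameters into the defaults;
# the argument values below are assumptions chosen for demonstration.
def _example_get_model():
    state = get_model(number_of_features=5, number_of_spectators=0,
                      number_of_events=1000, training_fraction=0.5,
                      parameters={'nTrees': 50, 'bst:max_depth': 3})
    # 'nTrees' is popped and becomes the number of boosting rounds,
    # the remaining keys override the default parameter dictionary
    assert state.num_round == 50
    assert state.parameters['bst:max_depth'] == 3
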
def feature_importance(state):
    """
    Return a list containing the feature importances
    """
    return []

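# --- Illustrative sketch, not part of the original file ---
# The interface above intentionally returns an empty list. If importances
# were wanted, one possible alternative (a hypothetical helper using
# xgboost's Booster.get_score; importance type and ordering are assumptions)
# could look like this:
def _example_feature_importance(state, number_of_features):
    scores = state.estimator.get_score(importance_type='weight')
    # get_score() keys are 'f0', 'f1', ...; features never used in a split
    # are missing from the dictionary, so fill them with zero
    return [scores.get(f'f{i}', 0.0) for i in range(number_of_features)]
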
def load(obj):
    """
    Load XGBoost estimator into state
    """
    state = State()
    f = tempfile.NamedTemporaryFile(delete=False)
    f.write(obj)
    f.close()
    state.estimator = xgb.Booster({})
    state.estimator.load_model(f.name)
    os.unlink(f.name)
    return state

def apply(state, X):
    """
    Apply estimator to passed data.
    """
    data = xgb.DMatrix(X)
    result = state.estimator.predict(data)
    return np.require(result, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])

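# --- Illustrative usage sketch, not part of the original file ---
# Round trip through load() and apply(): a serialized booster (the bytes
# returned by end_fit below) is restored and evaluated on a feature matrix.
# The random test data is an assumption chosen for demonstration.
def _example_load_and_apply(model_bytes):
    state = load(model_bytes)
    X = np.random.rand(10, 5).astype(np.float32)  # 10 events, 5 features
    probabilities = apply(state, X)               # one score per event
    return probabilities
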
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Initialize lists which will store the received data
    """
    state.X = []
    state.y = []
    state.w = []
    state.Xtest = Xtest
    state.ytest = ytest.flatten()
    state.wtest = wtest.flatten()
    return state

def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Stores received training data.
    XGBoost is usually not able to perform a partial fit.
    """
    state.X.append(X)
    state.y.append(y.flatten())
    state.w.append(w.flatten())
    return True

def end_fit(state):
    """
    Merge received data together and fit estimator
    """
    dtrain = xgb.DMatrix(np.vstack(state.X), label=np.hstack(state.y).astype(int), weight=np.hstack(state.w))

    if len(state.Xtest) > 0:
        dtest = xgb.DMatrix(state.Xtest, label=state.ytest.astype(int), weight=state.wtest)
        evallist = [(dtest, 'eval'), (dtrain, 'train')]
    else:
        evallist = [(dtrain, 'train')]

    state.estimator = xgb.train(state.parameters, dtrain, state.num_round, evallist)
    f = tempfile.NamedTemporaryFile(delete=False)
    f.close()
    state.estimator.save_model(f.name)
    with open(f.name, 'rb') as f2:
        content = f2.read()
    os.unlink(f.name)
    return content
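

# --- Illustrative end-to-end sketch, not part of the original file ---
# Mimics a plausible calling order of these hooks during training:
# get_model -> begin_fit -> partial_fit (per batch) -> end_fit, followed by
# load/apply on the serialized result. The toy data and hyperparameter
# values are assumptions chosen for demonstration.
def _example_training_flow():
    rng = np.random.RandomState(0)
    X = rng.rand(200, 5).astype(np.float32)
    y = (X[:, 0] > 0.5).astype(np.float32).reshape(-1, 1)
    w = np.ones_like(y)

    state = get_model(number_of_features=5, number_of_spectators=0,
                      number_of_events=200, training_fraction=0.5,
                      parameters={'nTrees': 10})
    state = begin_fit(state, X[:50], [], y[:50], w[:50], nBatches=1)
    partial_fit(state, X[50:], [], y[50:], w[50:], epoch=0, batch=0)
    model_bytes = end_fit(state)          # serialized booster returned as bytes
    return apply(load(model_bytes), X)    # classifier scores for all events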