Belle II Software  release-06-02-00
xgboost.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 import numpy as np
13 
14 try:
15  import xgboost as xgb
16 except ImportError:
17  print("Please install xgboost: pip3 install xgboost")
18  import sys
19  sys.exit(1)
20 
21 import os
22 import tempfile
23 import collections
24 
25 
class State(object):
    """
    XGBoost state.

    Holds the training configuration and, once available, the xgboost booster.
    """

    def __init__(self, num_round=0, parameters=None):
        """
        Constructor of the state object.

        :param num_round: number of boosting rounds used in xgboost training
        :param parameters: parameters passed to the xgboost model
        """
        #: Parameters passed to xgboost model.
        self.parameters = parameters
        #: Number of boosting rounds used in xgboost training.
        self.num_round = num_round
        #: xgboost booster; None until a model is loaded or trained.
        self.estimator = None
38 
39 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default xgboost model.

    A default parameter dictionary is updated with the user ``parameters``.
    The special key ``'nTrees'`` is not forwarded to xgboost. Instead, it sets
    the number of boosting rounds (default 100).

    :param number_of_features: unused
    :param number_of_spectators: unused
    :param number_of_events: unused
    :param training_fraction: unused
    :param parameters: mapping of user parameters; anything that is not a
        mapping (e.g. None) leaves the defaults unchanged
    :return: State holding the number of boosting rounds and the merged parameters
    """
    # collections.Mapping was removed in Python 3.10; the ABC lives in collections.abc.
    from collections.abc import Mapping

    param = {'bst:max_depth': 2, 'bst:eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
    nTrees = 100
    if isinstance(parameters, Mapping):
        # Work on a copy so the caller's dictionary is not modified by removing 'nTrees'.
        user_param = dict(parameters)
        nTrees = user_param.pop('nTrees', nTrees)
        param.update(user_param)
    return State(nTrees, param)
52 
53 
def feature_importance(state):
    """
    Report the feature importances of the given state.

    No importances are computed for xgboost here, so the result is always empty.

    :param state: unused
    :return: a new, empty list
    """
    importances = []
    return importances
59 
60 
def load(obj):
    """
    Load XGBoost estimator into state.

    ``Booster.load_model`` is given a file name here. The serialized model is
    therefore written to a temporary file first, and that file is removed even
    if writing or loading fails.

    :param obj: serialized model bytes (as returned by ``end_fit``)
    :return: new State whose ``estimator`` is the loaded booster
    """
    state = State()
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        tmp.write(obj)
        tmp.close()
        state.estimator = xgb.Booster({})
        state.estimator.load_model(tmp.name)
    finally:
        # Close before unlinking (required on some platforms); closing twice is harmless.
        tmp.close()
        os.unlink(tmp.name)
    return state
73 
74 
def apply(state, X):
    """
    Evaluate the trained estimator on the given data.

    :param state: State whose ``estimator`` is a trained xgboost booster
    :param X: feature data, wrapped into an ``xgb.DMatrix``
    :return: predictions converted by ``np.require`` to a float32 array with
        the aligned, writeable, C-contiguous and owndata requirements
    """
    predictions = state.estimator.predict(xgb.DMatrix(X))
    flags = ['A', 'W', 'C', 'O']
    return np.require(predictions, dtype=np.float32, requirements=flags)
82 
83 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Prepare the state for a new training run.

    Creates empty buffers for the training batches passed to ``partial_fit``
    and stores the test sample, with labels and weights flattened.

    :param Stest: test-sample spectators; unused
    :return: the same state object
    """
    state.X, state.y, state.w = [], [], []
    state.Xtest = Xtest
    flat_labels = ytest.flatten()
    state.ytest = flat_labels
    flat_weights = wtest.flatten()
    state.wtest = flat_weights
    return state
95 
96 
def partial_fit(state, X, S, y, w, epoch):
    """
    Buffer one batch of training data.

    XGBoost is usually not able to perform a partial fit, so the batch is only
    appended to the buffers created by ``begin_fit``.

    :param S: spectators; unused
    :param epoch: unused
    :return: always True
    """
    state.X.append(X)
    flat_y = y.flatten()
    state.y.append(flat_y)
    flat_w = w.flatten()
    state.w.append(flat_w)
    return True
106 
107 
def end_fit(state):
    """
    Merge received data together, fit the estimator and serialize it.

    :param state: State filled by ``begin_fit`` and ``partial_fit``
    :return: the trained model as bytes, in the format written by
        ``Booster.save_model``
    """
    dtrain = xgb.DMatrix(np.vstack(state.X), label=np.hstack(state.y).astype(int), weight=np.hstack(state.w))

    if len(state.Xtest) > 0:
        dtest = xgb.DMatrix(state.Xtest, label=state.ytest.astype(int), weight=state.wtest)
        evallist = [(dtest, 'eval'), (dtrain, 'train')]
    else:
        evallist = [(dtrain, 'train')]

    state.estimator = xgb.train(state.parameters, dtrain, state.num_round, evallist)

    # save_model takes a file name, so round-trip through a temporary file.
    # The file is removed even if saving or reading fails.
    tmp = tempfile.NamedTemporaryFile(delete=False)
    tmp.close()
    try:
        state.estimator.save_model(tmp.name)
        with open(tmp.name, 'rb') as fh:
            content = fh.read()
    finally:
        os.unlink(tmp.name)
    return content
num_round
Number of boosting rounds used in xgboost training.
Definition: xgboost.py:35
def __init__(self, num_round=0, parameters=None)
Definition: xgboost.py:30
parameters
Parameters passed to xgboost model.
Definition: xgboost.py:33