Belle II Software  release-05-02-19
xgboost.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 # Thomas Keck 2016
5 
6 import numpy as np
7 
8 try:
9  import xgboost as xgb
10 except ImportError:
11  print("Please install xgboost: pip3 install xgboost")
12  import sys
13  sys.exit(1)
14 
15 import os
16 import tempfile
17 import collections
18 
19 
class State:
    """
    XGBoost state: the training configuration and, once available,
    the booster itself.
    """

    def __init__(self, num_round=0, parameters=None):
        """
        Build an empty state.

        :param num_round: number of boosting rounds handed to ``xgb.train``
        :param parameters: parameter dictionary handed to ``xgb.train``
        """
        #: Parameters passed to xgboost model.
        self.parameters = parameters
        #: Number of boosting rounds used in xgboost training.
        self.num_round = num_round
        #: XGBoost estimator; stays None until a model is trained or loaded.
        self.estimator = None
32 
33 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default xgboost model, updated with user-supplied parameters.

    The defaults are ``bst:max_depth=2``, ``bst:eta=1``, ``silent=1`` and
    ``objective='binary:logistic'``, trained for 100 boosting rounds.

    :param number_of_features: unused
    :param number_of_spectators: unused
    :param number_of_events: unused
    :param training_fraction: unused
    :param parameters: optional mapping of user settings. The special key
        ``'nTrees'`` sets the number of boosting rounds; every other entry is
        merged into (and may override) the xgboost parameter dictionary.
        Values that are not a mapping (e.g. ``None``) are ignored.
    :return: a ``State`` holding the number of rounds and the parameters
    """
    # collections.Mapping was removed in Python 3.10; the ABC lives in
    # collections.abc.
    from collections.abc import Mapping

    # NOTE(review): 'bst:'-prefixed keys and 'silent' are legacy xgboost
    # parameter names; newer xgboost releases may warn that they are unused.
    param = {'bst:max_depth': 2, 'bst:eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
    nTrees = 100
    # Check the type first: membership tests on None (or other non-mappings)
    # would raise before any guard could help.
    if isinstance(parameters, Mapping):
        # Work on a copy so the caller's mapping is left untouched.
        user_parameters = dict(parameters)
        nTrees = user_parameters.pop('nTrees', nTrees)
        param.update(user_parameters)
    return State(nTrees, param)
46 
47 
def feature_importance(state):
    """
    Return a list containing the feature importances.

    This backend does not compute importances, so a fresh empty list is
    returned whatever ``state`` holds.
    """
    importances = []
    return importances
53 
54 
def load(obj):
    """
    Load XGBoost estimator into state.

    The serialized model bytes are written to a temporary file and read back
    with ``xgb.Booster.load_model``. The temporary file is always removed,
    including when writing or loading fails.

    :param obj: serialized model bytes (e.g. the bytes returned by ``end_fit``)
    :return: a new ``State`` whose ``estimator`` is the loaded booster
    """
    state = State()
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        # Close the file before loading so the written bytes are flushed and
        # the file can be reopened by name.
        with tmp:
            tmp.write(obj)
        state.estimator = xgb.Booster({})
        state.estimator.load_model(tmp.name)
    finally:
        os.unlink(tmp.name)
    return state
67 
68 
def apply(state, X):
    """
    Apply estimator to passed data.

    :param state: state whose ``estimator`` is a trained xgboost booster
    :param X: feature matrix to evaluate
    :return: the predictions as a float32 array that np.require guarantees to
             be aligned (A), writeable (W), C-contiguous (C) and owning its
             data (O), copying only if necessary
    """
    matrix = xgb.DMatrix(X)
    scores = state.estimator.predict(matrix)
    return np.require(scores, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
76 
77 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Prepare ``state`` for a new training run.

    Empty buffers are created for the training batches later delivered by
    ``partial_fit``. The validation sample is kept as-is; its targets and
    weights are flattened to 1-D. The spectators ``Stest`` are not used.

    :return: the same ``state`` object, modified in place
    """
    state.X, state.y, state.w = [], [], []
    state.Xtest = Xtest
    state.ytest = ytest.flatten()
    state.wtest = wtest.flatten()
    return state
89 
90 
def partial_fit(state, X, S, y, w, epoch):
    """
    Buffer one batch of training data.

    XGBoost is usually not able to perform a partial fit, so the batch is
    only appended to the buffers in ``state``; training happens in
    ``end_fit``. Targets and weights are flattened to 1-D. ``S`` and
    ``epoch`` are ignored.

    :return: always True
    """
    state.X.append(X)
    flat_targets = y.flatten()
    state.y.append(flat_targets)
    flat_weights = w.flatten()
    state.w.append(flat_weights)
    return True
100 
101 
def end_fit(state):
    """
    Merge received data together and fit estimator.

    All batches buffered by ``partial_fit`` are stacked into one training
    DMatrix. If a validation sample was given to ``begin_fit``, it is
    monitored as ``'eval'`` alongside the training sample during
    ``xgb.train``.

    :return: the trained model serialized via ``save_model``, as bytes
    """
    dtrain = xgb.DMatrix(np.vstack(state.X), label=np.hstack(state.y).astype(int), weight=np.hstack(state.w))

    if len(state.Xtest) > 0:
        dtest = xgb.DMatrix(state.Xtest, label=state.ytest.astype(int), weight=state.wtest)
        evallist = [(dtest, 'eval'), (dtrain, 'train')]
    else:
        evallist = [(dtrain, 'train')]

    # 'evals' is keyword-only in recent xgboost releases; passing it by name
    # works with old and new versions alike.
    state.estimator = xgb.train(state.parameters, dtrain, state.num_round, evals=evallist)

    # Serialize through a temporary file; it is removed even if saving or
    # reading fails.
    tmp = tempfile.NamedTemporaryFile(delete=False)
    tmp.close()
    try:
        state.estimator.save_model(tmp.name)
        with open(tmp.name, 'rb') as model_file:
            content = model_file.read()
    finally:
        os.unlink(tmp.name)
    return content
basf2_mva_python_interface.xgboost.State
Definition: xgboost.py:20
basf2_mva_python_interface.xgboost.State.__init__
def __init__(self, num_round=0, parameters=None)
Definition: xgboost.py:24
basf2_mva_python_interface.xgboost.State.parameters
parameters
Parameters passed to xgboost model.
Definition: xgboost.py:27
basf2_mva_python_interface.xgboost.State.estimator
estimator
XGBoost estimator.
Definition: xgboost.py:31
basf2_mva_python_interface.xgboost.State.num_round
num_round
Number of boosting rounds used in xgboost training.
Definition: xgboost.py:29