Belle II Software  release-08-01-10
sklearn.py
1 #!/usr/bin/env python3
2 
3 
10 
11 import numpy as np
12 from basf2 import B2WARNING
13 
try:
    import sklearn  # noqa
except ImportError:
    # The distribution on PyPI is called "scikit-learn"; the "sklearn" name there is
    # only a deprecated placeholder, so point the user at the real package.
    print("Please install sklearn: pip3 install scikit-learn")
    import sys
    sys.exit(1)
20 
21 import collections
22 
23 
class State:
    """
    SKLearn state

    Thin container handed between the functions of this module; it carries the
    estimator and, during fitting, the buffered training data.
    """

    def __init__(self, estimator=None):
        """ Constructor of the state object """

        #: Picklable sklearn estimator (read by the other functions as ``state.estimator``)
        self.estimator = estimator
33 
34 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Create SKLearn classifier and store it in a State object

    The first four arguments are not used by this function.
    ``parameters`` selects how the GradientBoostingClassifier is configured:
    a mapping is passed as keyword arguments, any other sequence as positional
    arguments, and everything else (e.g. None) yields the default classifier.
    Returns a State wrapping the new, unfitted classifier.
    """
    # The ABC aliases in the top-level ``collections`` namespace were removed in
    # Python 3.10, so the abstract base classes are taken from collections.abc.
    from collections.abc import Mapping, Sequence
    from sklearn.ensemble import GradientBoostingClassifier
    if isinstance(parameters, Mapping):
        clf = GradientBoostingClassifier(**parameters)
    elif isinstance(parameters, Sequence):
        clf = GradientBoostingClassifier(*parameters)
    else:
        clf = GradientBoostingClassifier()
    return State(clf)
47 
48 
def feature_importance(state):
    """
    Return the feature importances of the stored estimator as a list.
    An empty list is returned if the estimator is not a GradientBoostingClassifier.
    """
    from sklearn.ensemble import GradientBoostingClassifier
    estimator = state.estimator
    if not isinstance(estimator, GradientBoostingClassifier):
        return []
    return list(estimator.feature_importances_)
57 
58 
def load(obj):
    """
    Wrap the given sklearn estimator in a new State object
    """
    restored = State(estimator=obj)
    return restored
64 
65 
def apply(state, X):
    """
    Apply estimator to passed data.
    If the estimator has a predict_proba it is called, otherwise call just predict.

    When predict_proba yields exactly two columns only the second column is
    returned; any other number of columns is returned unchanged.
    The result is converted to an aligned, writeable, C-contiguous float32
    array that owns its data.
    """
    if hasattr(state.estimator, 'predict_proba'):
        x = state.estimator.predict_proba(X)
        if x.shape[1] == 2:
            # Reuse the probabilities computed above instead of predicting a second time
            x = x[:, 1]
    else:
        x = state.estimator.predict(X)
    return np.require(x, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
78 
79 
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Attach three empty lists (X, y, w) to the state; they collect the data
    received by the following partial_fit calls. The state is returned.
    """
    state.X, state.y, state.w = [], [], []
    return state
88 
89 
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Buffer one batch of training data in the state.
    No fitting happens here: the batches are only collected, with targets and
    weights flattened. A warning is emitted for every batch of an epoch after
    the first one. Always returns True.
    """
    if epoch > 0:
        message = ("The sklearn training interface has been called with specific_options.m_nIterations > 1."
                   " This means duplicates of the training sample will be used during training.")
        B2WARNING(message)

    state.X.append(X)
    targets = y.flatten()
    state.y.append(targets)
    weights = w.flatten()
    state.w.append(weights)
    return True
103 
104 
def end_fit(state):
    """
    Concatenate all buffered batches and fit the estimator on them.
    Features are stacked vertically, targets and weights horizontally; the
    weights are passed as the third positional argument of fit. The fitted
    estimator is stored back in the state and returned.
    """
    features = np.vstack(state.X)
    targets = np.hstack(state.y)
    weights = np.hstack(state.w)
    fitted = state.estimator.fit(features, targets, weights)
    state.estimator = fitted
    return state.estimator
estimator
Picklable sklearn estimator.
Definition: sklearn.py:32
def __init__(self, estimator=None)
Definition: sklearn.py:29