Belle II Software  light-2212-foldex
sklearn.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 import numpy as np
13 from basf2 import B2WARNING
14 
15 try:
16  import sklearn # noqa
17 except ImportError:
18  print("Please install sklearn: pip3 install sklearn")
19  import sys
20  sys.exit(1)
21 
22 import collections
23 
24 
class State(object):
    """
    SKLearn state.

    Thin container which carries the sklearn estimator between the
    interface functions of this module (get_model, load, apply, end_fit, ...).
    """

    def __init__(self, estimator=None):
        """
        Constructor of the state object.

        @param estimator sklearn estimator to wrap (None if not yet available)
        """

        #: Picklable sklearn estimator; all module functions access it as ``state.estimator``
        self.estimator = estimator
34 
35 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Create SKLearn classifier and store it in a State object.

    The first four arguments are part of the common interface and are not used here.

    @param parameters configuration of the GradientBoostingClassifier:
                      a mapping is passed as keyword arguments, a sequence as
                      positional arguments, anything else selects the defaults
    @return State object wrapping the (unfitted) classifier
    """
    # The ABC aliases in the top-level ``collections`` module were removed in
    # Python 3.10, so they have to be taken from ``collections.abc``.
    from collections.abc import Mapping, Sequence
    from sklearn.ensemble import GradientBoostingClassifier

    if isinstance(parameters, Mapping):
        clf = GradientBoostingClassifier(**parameters)
    elif isinstance(parameters, Sequence):
        clf = GradientBoostingClassifier(*parameters)
    else:
        clf = GradientBoostingClassifier()
    return State(clf)
48 
49 
def feature_importance(state):
    """
    Return the feature importances of the wrapped estimator as a list.

    Only a GradientBoostingClassifier is queried; for any other estimator
    an empty list is returned.
    """
    from sklearn.ensemble import GradientBoostingClassifier

    estimator = state.estimator
    if not isinstance(estimator, GradientBoostingClassifier):
        return []
    return list(estimator.feature_importances_)
58 
59 
def load(obj):
    """
    Wrap an already existing sklearn estimator in a fresh State object.

    @param obj the estimator to wrap
    @return State object holding obj as its estimator
    """
    restored = State(estimator=obj)
    return restored
65 
66 
def apply(state, X):
    """
    Apply estimator to passed data.

    If the estimator has a predict_proba it is called, otherwise call just predict.
    When predict_proba yields exactly two columns, only the second column
    (class index 1) is returned; any other shape is passed through unchanged.

    @param state State object holding the estimator
    @param X data the estimator is applied to
    @return float32 numpy array which is aligned, writeable, C-contiguous and owns its data
    """
    estimator = state.estimator
    if hasattr(estimator, 'predict_proba'):
        x = estimator.predict_proba(X)
        if x.shape[1] == 2:
            # Slice the result we already have instead of running the
            # (potentially expensive) prediction a second time.
            x = x[:, 1]
    else:
        x = estimator.predict(X)
    # np.require copies if necessary, so the sliced view above ends up as an
    # owning, contiguous array.
    return np.require(x, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
79 
80 
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Prepare the state for a new training.

    Attaches three empty, independent lists (X, y, w) to the state; they are
    filled batch by batch in partial_fit. The test data and the number of
    batches are not used.

    @return the passed state object
    """
    state.X, state.y, state.w = [], [], []
    return state
89 
90 
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Collect one batch of training data in the state.

    No fitting happens here: the batch is only appended to the lists created
    in begin_fit; targets and weights are flattened before they are stored.
    A warning is issued for every call with epoch > 0.

    @return always True
    """
    if epoch > 0:
        message = ("The sklearn training interface has been called with specific_options.m_nIterations > 1."
                   " This means duplicates of the training sample will be used during training.")
        B2WARNING(message)

    state.X.append(X)
    targets = y.flatten()
    state.y.append(targets)
    weights = w.flatten()
    state.w.append(weights)
    return True
104 
105 
def end_fit(state):
    """
    Fit the estimator on all data collected so far.

    The stored feature batches are stacked vertically, the stored targets and
    weights are concatenated, and the three arrays are passed positionally to
    the estimator's fit method. The fit result replaces state.estimator.

    @return the estimator stored in the state after fitting
    """
    features = np.vstack(state.X)
    targets = np.hstack(state.y)
    weights = np.hstack(state.w)
    fitted = state.estimator.fit(features, targets, weights)
    state.estimator = fitted
    return state.estimator
estimator
Picklable sklearn estimator.
Definition: sklearn.py:33
def __init__(self, estimator=None)
Definition: sklearn.py:30