Belle II Software  release-06-02-00
sklearn.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 import numpy as np
13 
14 try:
15  import sklearn # noqa
16 except ImportError:
17  print("Please install sklearn: pip3 install sklearn")
18  import sys
19  sys.exit(1)
20 
21 import collections
22 
23 
class State(object):
    """
    SKLearn state

    Small container object which carries the estimator (and, during
    training, the collected data) between the hook functions of this module.
    """

    def __init__(self, estimator=None):
        """
        Constructor of the state object

        estimator: object stored unchanged on the state; defaults to None.
        """
        # The other functions in this module read and write `state.estimator`,
        # so the attribute has to carry exactly this name.
        self.estimator = estimator
33 
34 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Create SKLearn classifier and store it in a State object

    Only `parameters` is used here; the remaining arguments are accepted but ignored.

    parameters: a mapping is forwarded as keyword arguments, a sequence as
                positional arguments; anything else (e.g. None) yields a
                classifier with default settings.
    Returns a State wrapping a new GradientBoostingClassifier.
    """
    # The ABC aliases `collections.Mapping` / `collections.Sequence` were removed
    # in Python 3.10; the abstract base classes live in `collections.abc`.
    from collections.abc import Mapping, Sequence
    from sklearn.ensemble import GradientBoostingClassifier

    if isinstance(parameters, Mapping):
        clf = GradientBoostingClassifier(**parameters)
    elif isinstance(parameters, Sequence):
        clf = GradientBoostingClassifier(*parameters)
    else:
        clf = GradientBoostingClassifier()
    return State(clf)
47 
48 
def feature_importance(state):
    """
    Return the feature importances of the stored estimator as a list

    An empty list is returned when the estimator is not a
    GradientBoostingClassifier.
    """
    from sklearn.ensemble import GradientBoostingClassifier

    if not isinstance(state.estimator, GradientBoostingClassifier):
        return []
    return list(state.estimator.feature_importances_)
57 
58 
def load(obj):
    """
    Wrap the passed sklearn estimator in a new State object and return it
    """
    state = State(obj)
    return state
64 
65 
def apply(state, X):
    """
    Run the stored estimator on the passed data.
    predict_proba is preferred when the estimator offers it (its column 1 is
    returned); otherwise the plain predict output is used.
    The result is handed back as an aligned, writeable, C-contiguous, owned float32 array.
    """
    estimator = state.estimator
    if not hasattr(estimator, 'predict_proba'):
        response = estimator.predict(X)
    else:
        response = estimator.predict_proba(X)[:, 1]
    return np.require(response, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
76 
77 
def begin_fit(state, X, S, y, w):
    """
    Prepare the state for training: attach fresh, empty lists for the
    features (X), targets (y) and weights (w) that will be collected.
    The passed data arguments are not used here. Returns the state.
    """
    state.X, state.y, state.w = [], [], []
    return state
86 
87 
def partial_fit(state, X, S, y, w, epoch):
    """
    Collect one batch of training data on the state.
    Nothing is fitted here (SKLearn is usually not able to perform a partial
    fit); the features are stored as given, targets and weights flattened.
    Always returns True.
    """
    state.X.append(X)

    flat_targets = y.flatten()
    state.y.append(flat_targets)

    flat_weights = w.flatten()
    state.w.append(flat_weights)

    return True
97 
98 
def end_fit(state):
    """
    Combine all collected batches and fit the estimator on them.
    Features are stacked vertically, targets and weights horizontally; the
    weights are passed as third positional argument to fit.
    The fit result replaces state.estimator and is returned.
    """
    fit = state.estimator.fit
    features = np.vstack(state.X)
    targets = np.hstack(state.y)
    weights = np.hstack(state.w)
    state.estimator = fit(features, targets, weights)
    return state.estimator
estimator
Pickable sklearn estimator.
Definition: sklearn.py:32
def __init__(self, estimator=None)
Definition: sklearn.py:29