Belle II Software  release-05-01-25
sklearn.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 # Thomas Keck 2016
5 
6 import numpy as np
7 
8 try:
9  import sklearn
10 except ImportError:
11  print("Please install sklearn: pip3 install sklearn")
12  import sys
13  sys.exit(1)
14 
15 import collections
16 
17 
class State:
    """
    SKLearn state: holds the estimator shared by the functions of this module.
    """

    def __init__(self, estimator=None):
        """Create a state wrapping *estimator* (None when omitted)."""
        #: Picklable sklearn estimator
        self.estimator = estimator
26 
27 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Create SKLearn classifier and store it in a State object.

    The first four arguments are part of the hook signature but are not
    used here.

    :param parameters: a mapping is forwarded as keyword arguments to
        GradientBoostingClassifier, any other sequence as positional
        arguments; anything else (e.g. None) gives a default classifier.
    :return: State wrapping the (unfitted) GradientBoostingClassifier.
    """
    # collections.Mapping / collections.Sequence were removed in Python 3.10;
    # the abstract base classes live in collections.abc.
    from collections.abc import Mapping, Sequence
    from sklearn.ensemble import GradientBoostingClassifier
    if isinstance(parameters, Mapping):
        clf = GradientBoostingClassifier(**parameters)
    elif isinstance(parameters, Sequence):
        # NOTE(review): recent scikit-learn releases make estimator constructor
        # parameters keyword-only, so this positional form may raise TypeError there.
        clf = GradientBoostingClassifier(*parameters)
    else:
        clf = GradientBoostingClassifier()
    return State(clf)
40 
41 
def feature_importance(state):
    """
    Return the feature importances of the wrapped estimator as a list.

    Only GradientBoostingClassifier instances are queried (through their
    ``feature_importances_`` attribute); any other estimator gives [].
    """
    from sklearn.ensemble import GradientBoostingClassifier
    estimator = state.estimator
    if not isinstance(estimator, GradientBoostingClassifier):
        return []
    return list(estimator.feature_importances_)
50 
51 
def load(obj):
    """
    Wrap the estimator object *obj* in a new State and return it.
    """
    restored = State(estimator=obj)
    return restored
57 
58 
def apply(state, X):
    """
    Evaluate the wrapped estimator on X.

    Estimators offering ``predict_proba`` contribute column 1 of its
    output; all others fall back to ``predict``. The result is returned as
    an aligned, writeable, C-contiguous float32 array that owns its data.
    """
    estimator = state.estimator
    if not hasattr(estimator, 'predict_proba'):
        prediction = estimator.predict(X)
    else:
        prediction = estimator.predict_proba(X)[:, 1]
    return np.require(prediction, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
69 
70 
def begin_fit(state, X, S, y, w):
    """
    Prepare *state* for buffered training.

    Resets the feature (X), target (y) and weight (w) buffers to fresh,
    empty lists; the data arguments themselves are not used here (batches
    are collected by partial_fit). Returns the same state object.
    """
    state.X, state.y, state.w = [], [], []
    return state
79 
80 
def partial_fit(state, X, S, y, w, epoch):
    """
    Buffer one batch of training data on *state*.

    SKLearn is usually not able to perform a partial fit, so the batch is
    only stored here: X as given, y and w as flattened copies. S and epoch
    are not used. Always returns True.
    """
    state.X.append(X)
    labels = y.flatten()
    state.y.append(labels)
    weights = w.flatten()
    state.w.append(weights)
    return True
90 
91 
def end_fit(state):
    """
    Merge received data together and fit the estimator.

    Buffered feature batches are stacked row-wise (np.vstack); labels and
    weights are concatenated (np.hstack) and the weights are passed to
    ``fit`` as ``sample_weight``. The per-batch buffers are emptied before
    fitting so their memory can be reclaimed while the estimator trains.

    :return: the fitted estimator (also stored on ``state.estimator``)
    :raises ValueError: if no batch was buffered via partial_fit
    """
    if not state.X:
        # np.vstack([]) would fail with an unhelpful concatenate error.
        raise ValueError("end_fit called without training data: partial_fit received no batches")
    X = np.vstack(state.X)
    y = np.hstack(state.y)
    w = np.hstack(state.w)
    # Only the stacked arrays are needed from here on; drop the batch references.
    state.X, state.y, state.w = [], [], []
    # Pass weights by keyword (scikit-learn fit(X, y, sample_weight=None) convention)
    # rather than relying on the position of the third parameter.
    state.estimator = state.estimator.fit(X, y, sample_weight=w)
    return state.estimator
basf2_mva_python_interface.sklearn.State
Definition: sklearn.py:18
basf2_mva_python_interface.sklearn.State.estimator
estimator
Picklable sklearn estimator.
Definition: sklearn.py:25
basf2_mva_python_interface.sklearn.State.__init__
def __init__(self, estimator=None)
Definition: sklearn.py:22