Belle II Software development
sklearn.py
1#!/usr/bin/env python3
2
3
10
11import numpy as np
12from basf2 import B2WARNING
13
try:
    import sklearn  # noqa
except ImportError:
    import sys
    # The PyPI distribution is named "scikit-learn"; the "sklearn" name on PyPI is a
    # deprecated placeholder package that no longer installs.
    print("Please install sklearn: pip3 install scikit-learn", file=sys.stderr)
    sys.exit(1)
20
21import collections
22
23
class State:
    """
    Holder object passed between the sklearn MVA interface functions.

    Besides the estimator, the training functions attach buffers (X, y, w) to it.
    """

    def __init__(self, estimator=None) -> None:
        """ Remember the given estimator (None if not yet created) """
        #: Picklable sklearn estimator.
        self.estimator = estimator
33
34
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Create an SKLearn GradientBoostingClassifier and store it in a State object.

    number_of_features, number_of_spectators, number_of_events and training_fraction
    belong to the common interface but are not used by this backend.

    :param parameters: hyper-parameters for the classifier:
        * a mapping is passed on as keyword arguments;
        * any other sequence is treated as positional values, matched to the
          constructor's parameter names in declaration order;
        * anything else (e.g. None) creates a classifier with default settings.
    :return: State wrapping the (not yet fitted) classifier
    :raises TypeError: if more positional values are given than the constructor accepts
    """
    # Import the submodule explicitly instead of relying on it having been
    # imported as a side effect elsewhere.
    import collections.abc
    import inspect
    from sklearn.ensemble import GradientBoostingClassifier

    if isinstance(parameters, collections.abc.Mapping):
        clf = GradientBoostingClassifier(**parameters)
    elif isinstance(parameters, collections.abc.Sequence):
        # Current scikit-learn releases declare every estimator constructor argument
        # keyword-only, so GradientBoostingClassifier(*parameters) raises TypeError.
        # Map the positional values onto the parameter names in declaration order,
        # which reproduces the historical positional behaviour.
        names = list(inspect.signature(GradientBoostingClassifier).parameters)
        if len(parameters) > len(names):
            raise TypeError(f"GradientBoostingClassifier accepts at most {len(names)} parameters, "
                            f"but {len(parameters)} were given")
        clf = GradientBoostingClassifier(**dict(zip(names, parameters)))
    else:
        clf = GradientBoostingClassifier()
    return State(clf)
47
48
def feature_importance(state):
    """
    Return the feature importances of the wrapped estimator as a plain list.

    Only a GradientBoostingClassifier is queried; for any other estimator an
    empty list is returned.
    """
    from sklearn.ensemble import GradientBoostingClassifier
    if not isinstance(state.estimator, GradientBoostingClassifier):
        return []
    return list(state.estimator.feature_importances_)
57
58
def load(obj):
    """
    Wrap a loaded sklearn estimator object in a fresh State.
    """
    loaded_state = State(estimator=obj)
    return loaded_state
64
65
def apply(state, X):
    """
    Apply the estimator to the passed data.

    If the estimator provides ``predict_proba``, it is called once. When the result
    has exactly two columns, only column 1 is kept. Otherwise ``predict`` is called.

    :param state: State holding a fitted estimator
    :param X: feature matrix handed to the estimator
    :return: float32 array that is aligned, writeable, C-contiguous and owns its data
    """
    estimator = state.estimator
    if hasattr(estimator, 'predict_proba'):
        x = estimator.predict_proba(X)
        if x.shape[1] == 2:
            # Reuse the probabilities computed above instead of running the
            # (potentially expensive) prediction a second time.
            x = x[:, 1]
    else:
        x = estimator.predict(X)
    # np.require copies when needed, e.g. for the non-contiguous column slice above.
    return np.require(x, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
78
79
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Prepare the state for a new training run.

    Creates empty buffers for features (X), targets (y) and weights (w).
    The Xtest, Stest, ytest, wtest and nBatches arguments are ignored.
    """
    for buffer_name in ('X', 'y', 'w'):
        setattr(state, buffer_name, [])
    return state
88
89
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Buffer one batch of training data.

    sklearn estimators usually cannot be trained incrementally, so the batch is
    only collected here; the actual fit happens in end_fit. The spectators (S)
    and the batch index are not used.

    :return: always True
    """
    repeated_epoch = epoch > 0
    if repeated_epoch:
        B2WARNING("The sklearn training interface has been called with specific_options.m_nIterations > 1."
                  " This means duplicates of the training sample will be used during training.")

    state.X.append(X)
    flat_targets = y.flatten()
    state.y.append(flat_targets)
    flat_weights = w.flatten()
    state.w.append(flat_weights)
    return True
103
104
def end_fit(state):
    """
    Concatenate all buffered batches and fit the estimator on them.

    The fitted estimator (the return value of ``fit``) replaces state.estimator
    and is also returned.
    """
    features = np.vstack(state.X)
    targets = np.hstack(state.y)
    weights = np.hstack(state.w)
    fitted = state.estimator.fit(features, targets, weights)
    state.estimator = fitted
    return fitted
estimator
Picklable sklearn estimator.
Definition: sklearn.py:32
def __init__(self, estimator=None)
Definition: sklearn.py:29