12 from basf2
import B2WARNING
17 print(
"Please install sklearn: pip3 install sklearn")
30 """ Constructor of the state object """
35 def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
37 Create SKLearn classifier and store it in a State object
39 from sklearn.ensemble
import GradientBoostingClassifier
40 if isinstance(parameters, collections.Mapping):
41 clf = GradientBoostingClassifier(**parameters)
42 elif isinstance(parameters, collections.Sequence):
43 clf = GradientBoostingClassifier(*parameters)
45 clf = GradientBoostingClassifier()
49 def feature_importance(state):
51 Return a list containing the feature importances
53 from sklearn.ensemble
import GradientBoostingClassifier
54 if isinstance(state.estimator, GradientBoostingClassifier):
55 return [x
for x
in state.estimator.feature_importances_]
61 Load sklearn estimator into state
68 Apply estimator to passed data.
69 If the estimator has a predict_proba it is called, otherwise call just predict.
71 if hasattr(state.estimator,
'predict_proba'):
72 x = state.estimator.predict_proba(X)
74 x = state.estimator.predict_proba(X)[:, 1]
76 x = state.estimator.predict(X)
77 return np.require(x, dtype=np.float32, requirements=[
'A',
'W',
'C',
'O'])
80 def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
82 Initialize lists which will store the received data
90 def partial_fit(state, X, S, y, w, epoch, batch):
92 Stores received training data.
93 SKLearn is usually not able to perform a partial fit.
96 B2WARNING(
"The sklearn training interface has been called with specific_options.m_nIterations > 1."
97 " This means duplicates of the training sample will be used during training.")
100 state.y.append(y.flatten())
101 state.w.append(w.flatten())
107 Merge received data together and fit estimator
109 state.estimator = state.estimator.fit(np.vstack(state.X), np.hstack(state.y), np.hstack(state.w))
110 return state.estimator
estimator
Pickable sklearn estimator.
def __init__(self, estimator=None)