13 from basf2
import B2WARNING
18 print(
"Please install sklearn: pip3 install sklearn")
31 """ Constructor of the state object """
36 def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
38 Create SKLearn classifier and store it in a State object
40 from sklearn.ensemble
import GradientBoostingClassifier
41 if isinstance(parameters, collections.Mapping):
42 clf = GradientBoostingClassifier(**parameters)
43 elif isinstance(parameters, collections.Sequence):
44 clf = GradientBoostingClassifier(*parameters)
46 clf = GradientBoostingClassifier()
50 def feature_importance(state):
52 Return a list containing the feature importances
54 from sklearn.ensemble
import GradientBoostingClassifier
55 if isinstance(state.estimator, GradientBoostingClassifier):
56 return [x
for x
in state.estimator.feature_importances_]
62 Load sklearn estimator into state
69 Apply estimator to passed data.
70 If the estimator has a predict_proba it is called, otherwise call just predict.
72 if hasattr(state.estimator,
'predict_proba'):
73 x = state.estimator.predict_proba(X)
75 x = state.estimator.predict_proba(X)[:, 1]
77 x = state.estimator.predict(X)
78 return np.require(x, dtype=np.float32, requirements=[
'A',
'W',
'C',
'O'])
81 def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
83 Initialize lists which will store the received data
91 def partial_fit(state, X, S, y, w, epoch, batch):
93 Stores received training data.
94 SKLearn is usually not able to perform a partial fit.
97 B2WARNING(
"The sklearn training interface has been called with specific_options.m_nIterations > 1."
98 " This means duplicates of the training sample will be used during training.")
101 state.y.append(y.flatten())
102 state.w.append(w.flatten())
108 Merge received data together and fit estimator
110 state.estimator = state.estimator.fit(np.vstack(state.X), np.hstack(state.y), np.hstack(state.w))
111 return state.estimator
estimator
Pickable sklearn estimator.
def __init__(self, estimator=None)