import numpy as np
import tempfile
import pathlib
import pickle

import torch as pytorch
from basf2 import B2INFO
26 """ Constructor of the state object """
34 for key, value
in kwargs.items():
36 setattr(self, key, value)


class PickleModule:
    """
    Custom PickleModule with a custom Unpickler,
    to be passed via the `pickle_module` argument in `torch.load`.
    """

    class Unpickler(pickle.Unpickler):
        """
        Custom Unpickler that tries to find missing classes in the current global namespace.
        This is needed since the move to per-python-mva-method module instances, in which the classes now live.
        """

        def find_class(self, module, name):
            """
            If the class can't be retrieved the regular way,
            try to take it from the global namespace.
            """
            try:
                return super().find_class(module, name)
            except (ModuleNotFoundError, AttributeError):
                B2INFO(f"Missing class: {module}.{name}")
                B2INFO(f"Using `{name}` from global namespace")
                return globals()[name]


def feature_importance(state):
    """
    Return a list containing the feature importances.
    Torch does not provide feature importances, so return an empty list.
    """
    return []


class myModel(pytorch.nn.Module):
    """
    My dense neural network
    """

    def __init__(self, number_of_features):
        """
        param: number_of_features  number of input variables
        """
        super().__init__()

        #: a dense model with one hidden layer
        self.network = pytorch.nn.Sequential(
            pytorch.nn.Linear(number_of_features, 128),
            pytorch.nn.ReLU(),  # assumed hidden activation
            pytorch.nn.Linear(128, 128),
            pytorch.nn.ReLU(),  # assumed hidden activation
            pytorch.nn.Linear(128, 1),
            pytorch.nn.Sigmoid(),  # the BCELoss used below expects probabilities in (0, 1)
        )

    def forward(self, x):
        """ Run the network """
        return self.network(x)
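
# Illustrative shape check (a sketch, not executed here): the network maps a
# batch of shape (N, number_of_features) to probabilities of shape (N, 1), e.g.
#
#   net = myModel(10)
#   out = net(pytorch.randn(5, 10))  # out.shape == (5, 1), values in (0, 1)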


def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Returns default torch model
    """
    state = State(
        myModel(number_of_features).to("cpu"),
        number_of_features=number_of_features,
        number_of_spectators=number_of_spectators,
        number_of_events=number_of_events,
        training_fraction=training_fraction,
        parameters=parameters,
    )

    if parameters is None:
        parameters = {}

    state.optimizer = pytorch.optim.SGD(state.model.parameters(), parameters.get('learning_rate', 1e-3))

    # store the class, not an instance: the loss function is instantiated
    # per batch in partial_fit so the event weights can be passed in
    state.loss_fn = pytorch.nn.BCELoss

    return state
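
# Hedged usage sketch (not executed here; the options are the standard basf2_mva
# teacher options, while all file, tree and variable names are made up for
# illustration): this module is meant to be used as the steering file of a
# basf2_mva python teacher, with `parameters` filled from the JSON in m_config:
#
#   import json
#   import basf2_mva
#   general_options = basf2_mva.GeneralOptions()
#   general_options.m_datafiles = basf2_mva.vector("train.root")
#   general_options.m_treename = "tree"
#   general_options.m_identifier = "MyTorchModel.xml"
#   general_options.m_variables = basf2_mva.vector("p", "pt")
#   general_options.m_target_variable = "isSignal"
#
#   specific_options = basf2_mva.PythonOptions()
#   specific_options.m_framework = "torch"
#   specific_options.m_steering_file = "torch_steering.py"  # this file
#   specific_options.m_config = json.dumps({"learning_rate": 1e-3})
#   basf2_mva.teacher(general_options, specific_options)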


def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Passes in a fraction of events if specific_options.m_training_fraction is set.
    """
    device = "cpu"
    # keep the validation sample as torch tensors for use at the end of each epoch
    state.Xtest = pytorch.from_numpy(Xtest).to(device)
    state.ytest = pytorch.from_numpy(ytest).to(device)
    state.wtest = pytorch.from_numpy(wtest).to(device)
    return state


def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Pass received data to the torch model and train.

    The epochs and batching are handled by the mva package.
    If you prefer to do this yourself, set
    specific_options.m_nIterations = 1
    specific_options.m_mini_batch_size = 0
    which will pass all training data as a single batch once.
    This can then be loaded into torch in any way you want
    (see the sketch after this function).
    """
    device = "cpu"
    tensor_x = pytorch.from_numpy(X).to(device)
    tensor_y = pytorch.from_numpy(y).to(device).type(pytorch.float)
    tensor_w = pytorch.from_numpy(w).to(device)

    # compute prediction and loss; the loss function is instantiated per batch
    # so the event weights of this batch can be passed in
    loss_fn = state.loss_fn(weight=tensor_w)
    pred = state.model(tensor_x)
    loss = loss_fn(pred, tensor_y)

    # backpropagation
    state.optimizer.zero_grad()
    loss.backward()
    state.optimizer.step()

    if batch == 0 and epoch == 0:
        state.avg_costs = [loss.detach().numpy()]
        state.epoch = epoch
    elif epoch != state.epoch:
        # a new epoch has started: print a summary of the previous one
        if len(state.ytest) > 0:
            # evaluate on the validation sample
            with pytorch.no_grad():
                test_pred = state.model(state.Xtest)
                test_loss_fn = state.loss_fn(weight=state.wtest)
                test_loss = test_loss_fn(test_pred, state.ytest).item()
                test_correct = (test_pred.round() == state.ytest).type(pytorch.float).sum().item()
            print(f"Epoch: {epoch-1:04d},\t Training Cost: {np.mean(state.avg_costs):.4f},"
                  f"\t Testing Cost: {test_loss:.4f},\t Testing Accuracy: {test_correct/len(state.ytest)}")
        else:
            print(f"Epoch: {epoch-1:04d},\t Training Cost: {np.mean(state.avg_costs):.4f}")

        state.avg_costs = [loss.detach().numpy()]
        state.epoch = epoch
    else:
        state.avg_costs.append(loss.detach().numpy())

    # returning True tells the mva package to continue the training
    return True
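
# Hedged sketch of doing your own batching (assumes specific_options.m_nIterations = 1
# and specific_options.m_mini_batch_size = 0, so the single partial_fit call above
# receives the full training sample; batch size and epoch count are made up):
#
#   dataset = pytorch.utils.data.TensorDataset(tensor_x, tensor_y, tensor_w)
#   loader = pytorch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
#   for epoch in range(100):
#       for batch_x, batch_y, batch_w in loader:
#           loss = state.loss_fn(weight=batch_w)(state.model(batch_x), batch_y)
#           state.optimizer.zero_grad()
#           loss.backward()
#           state.optimizer.step()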


def apply(state, X):
    """
    Apply estimator to passed data.
    """
    with pytorch.no_grad():
        r = state.model(pytorch.from_numpy(X)).detach().numpy()
    if r.shape[1] == 1:
        # cannot use squeeze here, as an output of shape [1, n_classes] must survive
        r = r[:, 0]
    return np.require(r, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])


def load(obj):
    """
    Load the trained torch model into state.
    """
    file_names, files, collection_keys, collections_to_store = obj
    collections = {key: value for key, value in zip(collection_keys, collections_to_store)}

    with tempfile.TemporaryDirectory() as temp_path:
        temp_path = pathlib.Path(temp_path)

        for file_index, file_name in enumerate(file_names):
            path = temp_path.joinpath(pathlib.Path(file_name))
            path.parents[0].mkdir(parents=True, exist_ok=True)

            with open(path, 'w+b') as file:
                file.write(bytes(files[file_index]))

        try:
            # current format: the file contains only the state dict
            weights = pytorch.load(temp_path.joinpath(file_names[0]), weights_only=True)
            state = get_model(**{k: collections.get(k, None) for k in [
                "number_of_features",
                "number_of_spectators",
                "number_of_events",
                "training_fraction",
                "parameters"]})
            state.model.load_state_dict(weights)
            state.model.eval()
        except pickle.UnpicklingError:
            # legacy format: the whole model object was pickled; the custom
            # PickleModule resolves classes that now live in other namespaces
            model = pytorch.load(temp_path.joinpath(file_names[0]), pickle_module=PickleModule)
            model.eval()  # sets dropout and batch norm layers to evaluation mode
            state = State(model)

    for key, value in collections.items():
        setattr(state, key, value)

    return state


def end_fit(state):
    """
    Store the trained torch model.
    """
    files = []
    with tempfile.TemporaryDirectory() as temp_path:
        temp_path = pathlib.Path(temp_path)

        # store only the state dict; load() reconstructs the model via get_model()
        pytorch.save(state.model.state_dict(), temp_path / "weights.pt")

        file_names = ['weights.pt']
        for file_name in file_names:
            with open(temp_path.joinpath(file_name), 'rb') as file:
                files.append(file.read())

    collection_keys = state.collection_keys
    collections_to_store = []
    for key in state.collection_keys:
        collections_to_store.append(getattr(state, key))

    return [file_names, files, collection_keys, collections_to_store]
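
# Hedged sketch of using the stored weightfile afterwards (assumes the
# basf2_mva_util.Method wrapper and a test file "test.root" containing the same
# variables; the identifier matches the one used during training):
#
#   import basf2_mva
#   import basf2_mva_util
#   method = basf2_mva_util.Method("MyTorchModel.xml")
#   probability, target = method.apply_expert(basf2_mva.vector("test.root"), "tree")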