Belle II Software development
torch.py
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4
11
12import tempfile
13import pathlib
14import pickle
15import numpy as np
16import torch as pytorch
17from basf2 import B2INFO
18
19
class State(object):
    """
    torch state

    Holds the torch model together with any extra attributes that have to be
    stored alongside it.
    """

    def __init__(self, model=None, **kwargs):
        """
        Constructor of the state object

        param: model the torch model held by this state
        param: kwargs extra attributes to attach to the state; every name is
               also recorded in collection_keys
        """
        # torch model
        self.model = model

        # list of keys to save; has to exist before the loop below appends to it
        self.collection_keys = []

        # other possible things to save that are needed by the model
        for key, value in kwargs.items():
            self.collection_keys.append(key)
            setattr(self, key, value)
37
38
class PickleModule:
    """
    Custom PickleModule with a custom Unpickler

    to be passed via the `pickle_module` argument in `torch.load`
    """
    class Unpickler(pickle.Unpickler):
        """
        Custom Unpickler that tries to find missing classes in the current global namespace.

        This is needed since the move to per-python-mva-method module instances in which the classes live now.
        """
        def find_class(self, module, name):
            """
            If class can't be retrieved the regular way,
            try to take from global namespace

            param: module module name recorded in the pickle stream
            param: name name of the class looked up in that module
            Re-raises the original lookup error if the name is not in the
            global namespace of this module either.
            """
            try:
                return super().find_class(module, name)
            except (ModuleNotFoundError, AttributeError):
                B2INFO(f"Missing class: {module}.{name}")
                if name in globals():
                    B2INFO(f"Using `{name}` from global namespace")
                    return globals()[name]
                else:
                    # nothing to fall back to: propagate the lookup error unchanged
                    raise
65
66
def feature_importance(state):
    """
    Report the feature importances for the given state.
    Torch has no notion of feature importances, so the result is always an empty list.
    """
    importances = []
    return importances
73
74
class myModel(pytorch.nn.Module):
    """
    Fully connected feed-forward network with a single sigmoid output
    """

    def __init__(self, number_of_features):
        """
        Build the layer stack
        param: number_of_features number of input variables
        """
        super().__init__()

        hidden_width = 128
        layers = [
            pytorch.nn.Linear(number_of_features, hidden_width),
            pytorch.nn.ReLU(),
            pytorch.nn.Linear(hidden_width, hidden_width),
            pytorch.nn.ReLU(),
            pytorch.nn.Linear(hidden_width, 1),
            pytorch.nn.Sigmoid(),
        ]
        # dense stack: two ReLU layers of 128 units followed by one sigmoid output unit
        self.network = pytorch.nn.Sequential(*layers)

    def forward(self, x):
        """
        Evaluate the network on x and return its output
        """
        return self.network(x)
103
104
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build the default torch state.

    The state holds a myModel network placed on the CPU, an SGD optimizer and
    the BCELoss class. All arguments are handed on to State as keyword arguments.
    The learning rate is read from parameters['learning_rate'] and defaults to 1e-3.
    """
    network = myModel(number_of_features).to("cpu")

    # Note: if you override this function, need to pass all arguments of
    # get_model that you need explicitly as kwargs to State!
    state = State(
        network,
        number_of_features=number_of_features,
        number_of_spectators=number_of_spectators,
        number_of_events=number_of_events,
        training_fraction=training_fraction,
        parameters=parameters,
    )
    print(state.model)

    # no parameters given: fall back to the defaults below
    options = {} if parameters is None else parameters
    learning_rate = options.get('learning_rate', 1e-3)

    state.optimizer = pytorch.optim.SGD(state.model.parameters(), learning_rate)

    # keep the class, not an instance: the loss function is recreated on each
    # batch so that the event weights can be passed in
    state.loss_fn = pytorch.nn.BCELoss

    return state
132
133
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Keep the validation sample on the state for later use.

    Xtest, ytest and wtest are converted to torch tensors on the CPU and stored
    as state.Xtest, state.ytest and state.wtest. Stest and nBatches are not used.
    Returns the state.
    """
    device = "cpu"
    validation_sample = (("Xtest", Xtest), ("ytest", ytest), ("wtest", wtest))
    for attribute, array in validation_sample:
        # transform to torch tensor
        setattr(state, attribute, pytorch.from_numpy(array).to(device))
    return state
144
145
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Run one training step of the torch model on the received batch.

    The epochs and batching are handled by the mva package.
    If you prefer to do this yourself set
    specific_options.m_nIterations = 1
    specific_options.m_mini_batch_size = 0
    which will pass all training data as a single batch once.
    This can then be loaded into torch in any way you want.

    The per-batch costs of the current epoch are collected in state.avg_costs.
    When a new epoch starts, a summary of the previous one is printed, including
    the validation cost and accuracy if a validation sample is stored on the state.
    Returns False once epoch equals 100000 and True otherwise. S is not used.
    """
    # transform to torch tensor
    device = "cpu"
    batch_x = pytorch.from_numpy(X).to(device)
    batch_y = pytorch.from_numpy(y).to(device).type(pytorch.float)
    batch_w = pytorch.from_numpy(w).to(device)

    # Compute prediction and loss; the loss function is built per batch to carry the weights
    criterion = state.loss_fn(weight=batch_w)
    loss = criterion(state.model(batch_x), batch_y)

    # Backpropagation
    state.optimizer.zero_grad()
    loss.backward()
    state.optimizer.step()

    batch_cost = loss.detach().numpy()
    first_call = batch == 0 and epoch == 0

    if not first_call and epoch == state.epoch:
        # still inside the current epoch: just collect the cost
        state.avg_costs.append(batch_cost)
    else:
        if not first_call:
            # we are at the start of a new epoch, print out details of the last epoch
            if len(state.ytest) > 0:
                # run the validation set:
                state.model.eval()
                with pytorch.no_grad():
                    validation_pred = state.model(state.Xtest)
                    validation_criterion = state.loss_fn(weight=state.wtest)
                    test_loss = validation_criterion(validation_pred, state.ytest).item()
                    test_correct = (validation_pred.round() == state.ytest).type(pytorch.float).sum().item()

                print(f"Epoch: {epoch-1:04d},\t Training Cost: {np.mean((state.avg_costs)):.4f},"
                      f"\t Testing Cost: {test_loss:.4f}, \t Testing Accuracy: {test_correct/len(state.ytest)}")
                state.model.train()
            else:
                print(f"Epoch: {epoch-1:04d},\t Training Cost: {np.mean((state.avg_costs)):.4f}")

        # start collecting the costs of the new epoch
        state.avg_costs = [batch_cost]
        state.epoch = epoch

    return not (epoch == 100000)
201
202
def apply(state, X):
    """
    Apply estimator to passed data.

    param: state object whose model attribute is called on the input tensor
    param: X numpy array of input data, converted with torch.from_numpy
    Returns the model output as an aligned, writeable, C-contiguous float32
    numpy array that owns its data. An output with a second axis of length 1
    is reduced along that axis; any other output shape is returned unchanged.
    """
    with pytorch.no_grad():
        r = state.model(pytorch.from_numpy(X)).detach().numpy()
    # cannot use squeeze because we might have output of shape [1,X classes];
    # the ndim check keeps an already one-dimensional output from raising IndexError
    if r.ndim > 1 and r.shape[1] == 1:
        r = r[:, 0]
    return np.require(r, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
212
213
def load(obj):
    """
    Load the trained torch model into state.

    param: obj sequence of (file_names, files, collection_keys, collections_to_store),
           i.e. the layout returned by end_fit
    The stored files are written to a temporary directory and the first one is
    loaded as a weights-only state dict into a model created by get_model.
    If that fails with pickle.UnpicklingError the file is treated as a legacy
    model stored via pickle. The model is switched to eval mode on the CPU and
    all stored collections are set as attributes on the returned state.
    """
    file_names, files, collection_keys, collections_to_store = obj
    collections = dict(zip(collection_keys, collections_to_store))

    with tempfile.TemporaryDirectory() as temp_path:
        temp_path = pathlib.Path(temp_path)

        for file_index, file_name in enumerate(file_names):
            path = temp_path.joinpath(pathlib.Path(file_name))
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_bytes(bytes(files[file_index]))

        model_path = temp_path.joinpath(file_names[0])

        # load model
        try:
            weights = pytorch.load(model_path, weights_only=True)
            state = get_model(
                **{
                    k: collections.get(k, None)
                    for k in [
                        "number_of_features",
                        "number_of_spectators",
                        "number_of_events",
                        "training_fraction",
                        "parameters",
                    ]
                }
            )
            state.model.load_state_dict(weights)
            model = state.model
        except pickle.UnpicklingError:
            # in this case weights_only failed and we may have a legacy model stored via pickle
            # -> try to unpickle with custom Unpickler
            # weights_only is passed explicitly so that this path does not depend on the
            # default of torch.load, which differs between torch versions.
            # NOTE(security): full unpickling can execute arbitrary code, only use
            # weight files from trusted sources.
            model = pytorch.load(model_path, pickle_module=PickleModule, weights_only=False)
            state = State(model)

    model.eval()
    device = "cpu"
    model.to(device)

    # load everything else we saved
    for key, value in collections.items():
        setattr(state, key, value)
    return state
261
262
def end_fit(state):
    """
    Store torch model

    The state dict of state.model is saved to a temporary weights.pt file and
    read back as bytes. Returns the list
    [file_names, files, collection_keys, collections_to_store], where
    collections_to_store holds the state attribute for each entry of
    state.collection_keys.
    """
    file_names = ['weights.pt']

    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch = pathlib.Path(scratch_dir)

        # this creates:
        # path/weights.pt
        pytorch.save(state.model.state_dict(), scratch / "weights.pt")

        # read the stored files back while the temporary directory still exists
        files = [scratch.joinpath(name).read_bytes() for name in file_names]

    collection_keys = state.collection_keys
    collections_to_store = [getattr(state, key) for key in collection_keys]

    return [file_names, files, collection_keys, collections_to_store]
list collection_keys
list of keys to save.
Definition torch.py:31
__init__(self, model=None, **kwargs)
Definition torch.py:25
network
a dense model with two hidden layers
Definition torch.py:88
__init__(self, number_of_features)
Definition torch.py:80