Belle II Software development
torch.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import tempfile
import pathlib
import numpy as np
import torch


class State(object):
    """
    torch state
    """

    def __init__(self, model=None, **kwargs):
        """ Constructor of the state object """

        self.model = model

        #: list of keys to save
        self.collection_keys = []

        # other possible things to save that are needed by the model
        for key, value in kwargs.items():
            self.collection_keys.append(key)
            setattr(self, key, value)
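
# Any extra keyword argument to State is recorded in collection_keys and thus
# survives the end_fit()/load() round trip defined below. A minimal sketch;
# the key name ``epoch_offset`` is made up for illustration:
#
#     state = State(myModel(10), epoch_offset=5)
#     assert state.collection_keys == ['epoch_offset']
#     assert state.epoch_offset == 5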


def feature_importance(state):
    """
    Return a list containing the feature importances.
    Torch does not provide feature importances so return an empty list.
    """
    return []


class myModel(torch.nn.Module):
    """
    My dense neural network
    """

    def __init__(self, number_of_features):
        """
        Init the network
        :param number_of_features: number of input variables
        """
        super(myModel, self).__init__()

        #: a dense model with two hidden layers
        self.network = torch.nn.Sequential(
            torch.nn.Linear(number_of_features, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """
        Run the network
        """
        prob = self.network(x)
        return prob


def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return the default torch model
    """

    state = State(myModel(number_of_features).to("cpu"))
    print(state.model)

    if parameters is None:
        parameters = {}

    state.optimizer = torch.optim.SGD(state.model.parameters(), parameters.get('learning_rate', 1e-3))

    # we store the loss function class (not an instance) and recreate it on
    # each batch so that we can pass in the per-event weights;
    # this is a quirk of how torch handles event weights
    state.loss_fn = torch.nn.BCELoss

    return state
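
# The ``parameters`` dict above is the hook for hyper parameters: a sketch,
# assuming (as in the mva python interface) that the JSON string passed via
# specific_options.m_config is decoded and forwarded here as ``parameters``:
#
#     import json
#     specific_options.m_config = json.dumps({'learning_rate': 1e-2})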


def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Store the validation sample for later use; it contains a fraction of
    the events if specific_options.m_training_fraction is set.
    """
    # transform to torch tensors and keep them for the end-of-epoch printout
    device = "cpu"
    state.Xtest = torch.from_numpy(Xtest).to(device)
    state.ytest = torch.from_numpy(ytest).to(device)
    state.wtest = torch.from_numpy(wtest).to(device)
    return state


def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Pass received data to the torch model and train.

    The epochs and batching are handled by the mva package.
    If you prefer to do this yourself, set
    specific_options.m_nIterations = 1
    specific_options.m_mini_batch_size = 0
    which will pass all training data as a single batch once;
    this can then be loaded into torch in any way you want
    (see the sketch after this function).
    """
    # transform to torch tensors
    device = "cpu"
    tensor_x = torch.from_numpy(X).to(device)
    tensor_y = torch.from_numpy(y).to(device).type(torch.float)
    tensor_w = torch.from_numpy(w).to(device)

    # compute prediction and loss; the loss function is instantiated here
    # so that the per-event weights of this batch can be passed in
    loss_fn = state.loss_fn(weight=tensor_w)
    pred = state.model(tensor_x)
    loss = loss_fn(pred, tensor_y)

    # backpropagation
    state.optimizer.zero_grad()
    loss.backward()
    state.optimizer.step()

    if batch == 0 and epoch == 0:
        state.avg_costs = [loss.detach().numpy()]
        state.epoch = epoch
    elif epoch != state.epoch:
        # we are at the start of a new epoch, print out details of the last epoch
        if len(state.ytest) > 0:
            # run the validation set
            state.model.eval()
            with torch.no_grad():
                test_pred = state.model(state.Xtest)
                test_loss_fn = state.loss_fn(weight=state.wtest)
                test_loss = test_loss_fn(test_pred, state.ytest).item()
                test_correct = (test_pred.round() == state.ytest).type(torch.float).sum().item()

            print(f"Epoch: {epoch-1:04d},\t Training Cost: {np.mean(state.avg_costs):.4f},"
                  f"\t Testing Cost: {test_loss:.4f},\t Testing Accuracy: {test_correct/len(state.ytest):.4f}")
            state.model.train()
        else:
            print(f"Epoch: {epoch-1:04d},\t Training Cost: {np.mean(state.avg_costs):.4f}")

        state.avg_costs = [loss.detach().numpy()]
        state.epoch = epoch
    else:
        state.avg_costs.append(loss.detach().numpy())

    # returning False stops the training
    if epoch == 100000:
        return False
    return True
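
# Sketch of the "do it yourself" mode described in the docstring above: with
# specific_options.m_nIterations = 1 and m_mini_batch_size = 0, partial_fit is
# called exactly once with the full training sample, so the batching can be
# handed to torch itself (illustrative only, batch size chosen arbitrarily):
#
#     dataset = torch.utils.data.TensorDataset(tensor_x, tensor_y, tensor_w)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
#     for batch_x, batch_y, batch_w in loader:
#         loss = state.loss_fn(weight=batch_w)(state.model(batch_x), batch_y)
#         state.optimizer.zero_grad()
#         loss.backward()
#         state.optimizer.step()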


def apply(state, X):
    """
    Apply estimator to passed data.
    """
    with torch.no_grad():
        r = state.model(torch.from_numpy(X)).detach().numpy()
    if r.shape[1] == 1:
        r = r[:, 0]  # cannot use squeeze because we might have output of shape [1, X classes]
    return np.require(r, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])


def load(obj):
    """
    Load the trained torch model into state.
    """
    # obj is the list produced by end_fit:
    # [file_names, file_contents, collection_keys, collection_values]
    with tempfile.TemporaryDirectory() as temp_path:
        temp_path = pathlib.Path(temp_path)

        # write the serialized files back to disk so torch can read them
        file_names = obj[0]
        for file_index, file_name in enumerate(file_names):
            path = temp_path.joinpath(pathlib.Path(file_name))
            path.parents[0].mkdir(parents=True, exist_ok=True)

            with open(path, 'w+b') as file:
                file.write(bytes(obj[1][file_index]))

        model = torch.load(temp_path.joinpath(file_names[0]))
        model.eval()  # sets dropout and batch norm layers to eval mode
        device = "cpu"
        model.to(device)
        state = State(model)

    # load everything else we saved
    for index, key in enumerate(obj[2]):
        setattr(state, key, obj[3][index])
    return state


def end_fit(state):
    """
    Store torch model
    """
    with tempfile.TemporaryDirectory() as temp_path:

        temp_path = pathlib.Path(temp_path)

        # this creates:
        # path/my_model.pt
        torch.save(state.model, temp_path.joinpath('my_model.pt'))

        file_names = ['my_model.pt']
        files = []
        for file_name in file_names:
            with open(temp_path.joinpath(file_name), 'rb') as file:
                files.append(file.read())

    collection_keys = state.collection_keys
    collections_to_store = []
    for key in state.collection_keys:
        collections_to_store.append(getattr(state, key))

    del state
    return [file_names, files, collection_keys, collections_to_store]
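

# ---------------------------------------------------------------------------
# Usage sketch (not part of the mva interface): how a training could be
# steered with this implementation via basf2_mva. This is a hedged example
# assuming a basf2 environment; the data file, tree and variable names below
# are placeholders.
if __name__ == "__main__":
    import basf2_mva

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector("train.root")  # placeholder
    general_options.m_treename = "tree"                           # placeholder
    general_options.m_identifier = "TorchModel.xml"
    general_options.m_variables = basf2_mva.vector("p", "pt")     # placeholders
    general_options.m_target_variable = "isSignal"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "torch"
    specific_options.m_steering_file = "torch.py"  # this file
    specific_options.m_nIterations = 10            # epochs, handled by the mva package
    specific_options.m_mini_batch_size = 256
    specific_options.m_training_fraction = 0.8     # rest becomes the validation sample

    basf2_mva.teacher(general_options, specific_options)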