#!/usr/bin/env python3


"""
Methods and a script for training PID calibration weights.

Sample usage:
    $ python pidTrainWeights.py data/ models/net.pt -n 100

Use `python pidTrainWeights.py -h` to see all command-line options.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

from os import makedirs
from os.path import join, dirname
from tqdm.auto import tqdm

def _make_const_lists():
    """Moving this code into a function to avoid a top-level ROOT import."""
    import ROOT.Belle2

    PARTICLES, PDG_CODES = [], []
    for i in range(len(ROOT.Belle2.Const.chargedStableSet)):
        particle = ROOT.Belle2.Const.chargedStableSet.at(i)
        name = (particle.__repr__()[7:-1]
                .replace("-", "")
                .replace("+", "")
                .replace("euteron", ""))
        PARTICLES.append(name)
        PDG_CODES.append(particle.getPDGCode())
    # PARTICLES = ["e", "mu", "pi", "K", "p", "d"]
    # PDG_CODES = [11, 13, 211, 321, 2212, 1000010020]

    DETECTORS = []
    for det in ROOT.Belle2.Const.PIDDetectors.set():
        DETECTORS.append(ROOT.Belle2.Const.parseDetectors(det))
    # DETECTORS = ["SVD", "CDC", "TOP", "ARICH", "ECL", "KLM"]

    return PARTICLES, PDG_CODES, DETECTORS


# PARTICLES, PDG_CODES, DETECTORS = _make_const_lists()
PARTICLES = ["e", "mu", "pi", "K", "p", "d"]
PDG_CODES = [11, 13, 211, 321, 2212, 1000010020]
DETECTORS = ["SVD", "CDC", "TOP", "ARICH", "ECL", "KLM"]
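
# Note on the input layout assumed throughout this script: a flattened input
# row holds the detector log-likelihoods grouped by particle type, so columns
# [6 * i, 6 * i + 6) hold the SVD..KLM log-likelihoods for PARTICLES[i]
# (this is an illustrative reading of the slicing done in WeightNet.forward,
# assuming the data columns follow the DETECTORS ordering above).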


class WeightNet(nn.Module):
    """PyTorch architecture for training calibration weights."""

    def __init__(self, n_class=6, n_detector=6, const_init=1):
        """Initialize the network for training.

        Args:
            n_class (int, optional): Number of classification classes (particle
                types). Defaults to 6.
            n_detector (int, optional): Number of detectors. Defaults to 6.
            const_init (int, optional): Constant value to initialize all
                weights. If None, PyTorch's default random initialization is
                used instead. Defaults to 1.
        """
        super().__init__()

        #: number of particle types
        self.n_class = n_class

        #: number of detectors
        self.n_detector = n_detector

        #: linear layers, one for each particle type
        self.fcs = nn.ModuleList(
            [nn.Linear(self.n_detector, 1, bias=False) for _ in range(self.n_class)]
        )

        if const_init is not None:
            self.const_init(const_init)

    def forward(self, x):
        """The network's forward method. Sums the weighted detector
        log-likelihoods for each particle type, then computes the likelihood
        ratios via a softmax.

        Args:
            x (torch.Tensor): Input detector log-likelihood data. Should be of
                shape (N, n_detector * n_class), where N is the number of samples.

        Returns:
            torch.Tensor: Weighted likelihood ratios.
        """
        n = self.n_detector
        outs = [self.fcs[i](x[:, i * n: (i + 1) * n]) for i in range(self.n_class)]
        out = torch.cat(outs, dim=1)
        return F.softmax(out, dim=1)

    def get_weights(self, to_numpy=True):
        """Returns the weights as an (n_class, n_detector) array or tensor
        (six-by-six with the default configuration).

        Args:
            to_numpy (bool, optional): Whether to return the weights as a numpy
                array (True) or torch tensor (False). Defaults to True.

        Returns:
            np.array or torch.Tensor: The matrix containing the weights.
        """
        weights = torch.cat([fc.weight.detach() for fc in self.fcs])
        if to_numpy:
            return weights.cpu().numpy()
        else:
            return weights
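
    # Reading the matrix returned by get_weights (illustrative): row i
    # corresponds to PARTICLES[i] and column j to DETECTORS[j], matching the
    # column grouping that forward() assumes for its input.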

    def const_init(self, const):
        """Fill all the weights with the given value.

        Args:
            const (float): Constant value to fill all weights with.
        """
        with torch.no_grad():
            for fc in self.fcs:
                fc.weight.fill_(const)

    def random_init(self, mean=1.0, std=0.5):
        """Fill all the weights with values sampled from a Normal distribution
        with the given mean and standard deviation.

        Args:
            mean (float, optional): The mean of the Normal distribution.
                Defaults to 1.0.
            std (float, optional): The standard deviation of the Normal
                distribution. Defaults to 0.5.
        """
        with torch.no_grad():
            for fc in self.fcs:
                fc.weight.fill_(0)
                fc.weight.add_(torch.normal(mean=mean, std=std, size=fc.weight.size()))

    def kill_unused(self, only):
        """Freezes the weights corresponding to unused particle types.

        Args:
            only (list(int) or None): List of allowed particle types, given by
                PDG code. The weights corresponding to any particle types that
                are _not_ in this list will be fixed to one (i.e. left as unit
                weights) and frozen (their gradients will not be computed or
                updated).
        """
        if only is not None:
            # for the particle types that are not being trained,
            # set the weights to one and freeze them
            for i, pdg in enumerate(PDG_CODES):
                if pdg in only:
                    continue
                self.fcs[i].weight.requires_grad = False
                self.fcs[i].weight.fill_(1)
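

# Minimal usage sketch for WeightNet (illustrative only; not executed by this
# script). With the defaults, six particle types times six detectors gives 36
# input columns:
#   net = WeightNet()
#   x = torch.randn(8, 36)   # detector log-likelihoods for 8 candidates
#   probs = net(x)           # shape (8, 6); each row sums to 1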


def load_training_data(directory, p_lims=None, theta_lims=None, device=None):
    """Loads training and validation data within the given momentum and theta
    limits (if given).

    Args:
        directory (str): Directory containing the train and validation sets.
        p_lims (tuple(float), optional): Minimum and maximum momentum. Defaults
            to None.
        theta_lims (tuple(float), optional): Minimum and maximum theta in
            degrees. Defaults to None.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.

    Returns:
        torch.Tensor: Training log-likelihood data.
        torch.Tensor: Training labels.
        torch.Tensor: Validation log-likelihood data.
        torch.Tensor: Validation labels.
    """
    p_lo, p_hi = p_lims if p_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = theta_lims if theta_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = np.radians(t_lo), np.radians(t_hi)

    def _load(filename):
        data = np.load(filename)
        X, y, p, t = data["X"], data["y"], data["p"], data["theta"]
        mask = np.logical_and.reduce([p >= p_lo, p <= p_hi, t >= t_lo, t <= t_hi])
        X = torch.tensor(X[mask]).to(device=device, dtype=torch.float)
        y = torch.tensor(y[mask]).to(device=device, dtype=torch.long)
        return X, y

    X_tr, y_tr = _load(join(directory, "train.npz"))
    X_va, y_va = _load(join(directory, "val.npz"))
    return X_tr, y_tr, X_va, y_va
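

# Expected data format (as read by _load above): each .npz file provides the
# arrays "X" (N, 36) with detector log-likelihoods, integer labels "y" (N,),
# momenta "p" (N,) in GeV, and polar angles "theta" (N,) in radians. An
# illustrative way to produce such a file:
#   np.savez("data/train.npz", X=X, y=y, p=p, theta=theta)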


def load_checkpoint(filename, device=None, only=None):
    """Loads training state from a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.
        only (list(int), optional): List of allowed particle types, given by
            PDG code. Defaults to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: Epoch number.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from every
            tenth epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from every
            tenth epoch.
    """
    checkpoint = torch.load(filename, map_location=torch.device("cpu"))
    net = WeightNet()
    net.load_state_dict(checkpoint["model_state_dict"])
    net.kill_unused(only)
    net.to(device=device)

    opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)
    opt.load_state_dict(checkpoint["optimizer_state_dict"])

    return (
        net,
        opt,
        checkpoint["epoch"],
        checkpoint["loss_t"],
        checkpoint["loss_v"],
        checkpoint["accu_t"],
        checkpoint["accu_v"],
    )


def save_checkpoint(filename, net, opt, epoch, loss_t, loss_v, accu_t, accu_v):
    """Saves training state to a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        net (WeightNet): The network.
        opt (optim.Optimizer): The optimizer.
        epoch (int): The current epoch number.
        loss_t (dict(str -> list(float))): Training losses (diag, pion, sum)
            from each epoch.
        loss_v (dict(str -> list(float))): Validation losses (diag, pion, sum)
            from each epoch.
        accu_t (dict(str -> list(float))): Training accuracies (net, pion) from
            every tenth epoch.
        accu_v (dict(str -> list(float))): Validation accuracies (net, pion)
            from every tenth epoch.
    """
    net.cpu()
    # dirname() is empty when saving to the current directory
    makedirs(dirname(filename) or ".", exist_ok=True)
    torch.save(
        {
            "model_state_dict": net.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
            "epoch": epoch,
            "loss_t": loss_t,
            "loss_v": loss_v,
            "accu_t": accu_t,
            "accu_v": accu_v,
        },
        filename,
    )
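

# Round-trip sketch (illustrative): a checkpoint written by save_checkpoint
# can be restored with load_checkpoint, picking up the epoch counter and the
# loss/accuracy histories, e.g.
#   save_checkpoint("models/net.pt", net, opt, 100, l_t, l_v, a_t, a_v)
#   net, opt, epoch, l_t, l_v, a_t, a_v = load_checkpoint("models/net.pt")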


def initialize(args, device=None):
    """Initializes training from the parsed command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: Epoch number.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from every
            tenth epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from every
            tenth epoch.
    """
    if args.resume is not None:
        net, opt, epochs_0, l_t, l_v, a_t, a_v = load_checkpoint(
            args.resume, device=device, only=args.only
        )

    else:
        net = WeightNet()
        if args.random:
            net.random_init()
        net.kill_unused(args.only)
        net.to(device=device)
        opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)
        epochs_0 = 0
        l_t = {"diag": [], "pion": [], "sum": []}
        l_v = {"diag": [], "pion": [], "sum": []}
        a_t = {"net": [], "pion": []}
        a_v = {"net": [], "pion": []}

    return net, opt, epochs_0, l_t, l_v, a_t, a_v


def train_model(args, use_tqdm=True):
    """Trains and saves a model according to the command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        use_tqdm (bool, optional): Use TQDM to track progress. Defaults to True.
    """
    print("Reading data.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("...and moving to", device)
    X_tr, y_tr, X_va, y_va = load_training_data(
        args.input, device=device, p_lims=args.p_lims, theta_lims=args.theta_lims
    )

    if len(y_tr) < 10 or len(y_va) < 10:
        print("Not enough data. Aborting...")
        return

    print(f"{len(y_tr)} train events, {len(y_va)} val events")

    print("Initializing network.")
    net, opt, epochs_0, loss_t, loss_v, accu_t, accu_v = initialize(args, device=device)

    diag_lfn = nn.CrossEntropyLoss()
    pion_lfn = nn.BCELoss()
    pion_wgt = args.beta

    def compute_accuracies(out, y):
        output = out.detach().cpu().numpy()
        target = y.detach().cpu().numpy()
        pred = np.squeeze(np.argmax(output, axis=1))
        accu = np.count_nonzero(pred == target) / len(pred)
        pi_out = output[(target == 2), 2]
        pi_pred = (pi_out > 0.5).astype(float)
        pi_accu = pi_pred.sum() / len(pi_pred)
        return accu, pi_accu

    def lfn(output, target):
        diag = diag_lfn(output, target)
        pi_mask = target == 2
        pi_out = output[pi_mask, 2]
        pi_y = (target[pi_mask] == 2).float()
        pion = pion_lfn(pi_out, pi_y)
        return diag + pion_wgt * pion, diag, pion
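
    # The total loss combines the six-class cross entropy ("diag") with a
    # beta-scaled binary cross entropy on the pion output only ("pion"); for
    # true pions (label 2) the BCE target is always 1, so the extra term
    # rewards confident pion identification.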

    print(f"Training network for {args.n_epochs} epochs.")

    iterator = range(args.n_epochs)
    if use_tqdm:
        iterator = tqdm(iterator)

    for epoch in iterator:
        # train step
        net.train()
        opt.zero_grad()
        out = net(X_tr)
        loss, diag, pion = lfn(out, y_tr)
        loss.backward()
        opt.step()

        # record training data
        loss_t["diag"].append(diag.item())
        loss_t["pion"].append(pion.item())
        loss_t["sum"].append(loss.item())

        if epoch % 10 == 0:
            accu, pi_accu = compute_accuracies(out, y_tr)
            accu_t["net"].append(accu)
            accu_t["pion"].append(pi_accu)

        # val step
        net.eval()
        with torch.no_grad():
            out = net(X_va)
            loss, diag, pion = lfn(out, y_va)

            loss_v["diag"].append(diag.item())
            loss_v["pion"].append(pion.item())
            loss_v["sum"].append(loss.item())

            if epoch % 10 == 0:
                accu, pi_accu = compute_accuracies(out, y_va)
                accu_v["net"].append(accu)
                accu_v["pion"].append(pi_accu)

    print("Training complete.")

    net.cpu()
    save_checkpoint(
        args.output, net, opt, epochs_0 + args.n_epochs, loss_t, loss_v, accu_t, accu_v,
    )

    wgt = net.get_weights(to_numpy=True)
    np.save(args.output.replace(".pt", "_wgt.npy"), wgt)

    print(f"Model saved to {args.output}.")
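

# Programmatic-use sketch (illustrative): train_model only needs an object
# carrying the attributes that the parser below defines, so it can also be
# driven without the command line, e.g.
#   from argparse import Namespace
#   args = Namespace(input="data/", output="models/net.pt", n_epochs=100,
#                    p_lims=None, theta_lims=None, resume=None, only=None,
#                    random=False, beta=0.1)
#   train_model(args)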


def get_parser():
    """Handles the command-line argument parsing.

    Returns:
        argparse.ArgumentParser: The argument parser.
    """
    from argparse import ArgumentParser

    ap = ArgumentParser(description="Train PID calibration weights.")
    ap.add_argument(
        "input", type=str, help="Path to folder with training files (in .npz format).",
    )
    ap.add_argument(
        "output",
        type=str,
        help="Output filename where the model will be saved (should end in .pt).",
    )
    ap.add_argument(
        "-n",
        "--n_epochs",
        type=int,
        default=500,
        help="Number of epochs to train the network. Defaults to 500.",
    )
    ap.add_argument(
        "--p_lims",
        nargs=2,
        type=float,
        default=[-float('inf'), +float('inf')],
        help=(
            "Lower and upper limits for momentum in GeV. The lower limit "
            "should be given first. Default values are -inf, +inf."
        ),
    )
    ap.add_argument(
        "--theta_lims",
        nargs=2,
        type=float,
        default=[-float('inf'), +float('inf')],
        help=(
            "Lower and upper limits for theta in degrees. The lower limit "
            "should be given first. Default values are -inf, +inf."
        ),
    )
    ap.add_argument(
        "-R",
        "--resume",
        nargs="?",
        type=str,
        const="use_output",
        default=None,
        help=(
            "Load a pre-existing model and resume training instead of "
            "starting a new one. If '--resume' is used and no file is "
            "specified, the output filepath will be loaded, and the "
            "training will overwrite that file. Alternatively, if a "
            "filepath is given, training will begin from the state "
            "saved in that file and will be saved to the output filepath."
        ),
    )
    ap.add_argument(
        "--only",
        type=int,
        nargs="*",
        help=(
            "Use only log-likelihood data from a subset of particle "
            "types specified by PDG code. If left unspecified, all "
            "particle types will be used."
        ),
    )
    ap.add_argument(
        "--random",
        action="store_true",
        help=(
            "Initialize network weights to random values, normally "
            "distributed with a mean of 1 and a width of 0.5. Has no "
            "effect if '--resume' is used."
        ),
    )
    ap.add_argument(
        "--beta",
        type=float,
        default=0.1,
        help=(
            "Scaling factor for the pion binary cross entropy term in "
            "the loss function. Defaults to 0.1."
        ),
    )
    return ap
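

# Example invocations (illustrative), exercising the options defined above:
#   python pidTrainWeights.py data/ models/net.pt -n 200 --only 211 321
#   python pidTrainWeights.py data/ models/net.pt -R --beta 0.05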


def validate_args(args):
    """Validates and preprocesses the parsed command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.

    Returns:
        argparse.Namespace: The validated arguments.
    """
    assert args.n_epochs > 0, "Number of epochs must be larger than 0."
    assert args.p_lims[0] < args.p_lims[1], "p_lims: lower limit must be < upper limit."
    assert (
        args.theta_lims[0] < args.theta_lims[1]
    ), "theta_lims: lower limit must be < upper limit."
    if args.only is not None:
        for pdg in args.only:
            assert pdg in PDG_CODES, f"Given PDG code {pdg} not understood."

    if args.resume == "use_output":
        args.resume = args.output

    return args


def main():
    """Main network training logic."""

    args = get_parser().parse_args()
    args = validate_args(args)

    print("Welcome to the network trainer.")
    print(f"Data will be read from {args.input}.")
    if args.resume is not None:
        print(f"Model training will be continued from the saved state at {args.resume}.")
    print(f"The trained model will be saved to {args.output}.")

    if not (args.p_lims[0] == -np.inf and args.p_lims[1] == +np.inf):
        print(
            f"Model will be trained on data with momentum in "
            f"[{args.p_lims[0]}, {args.p_lims[1]}] GeV."
        )

    if not (args.theta_lims[0] == -np.inf and args.theta_lims[1] == +np.inf):
        print(
            f"Model will be trained on data with theta in "
            f"[{args.theta_lims[0]}, {args.theta_lims[1]}] degrees."
        )

    print(f"The model will be trained for {args.n_epochs} epochs.")
    print(f"Training will use a pion scaling factor of beta = {args.beta}.")
    print("---")

    train_model(args)
    print("\nFinished!")


if __name__ == "__main__":
    main()