Belle II Software development
pidTrainWeights.py
#!/usr/bin/env python3

"""
Methods and a script for training PID calibration weights.

Sample usage:
    $ python pidTrainWeights.py data/ models/net.pt -n 100

Use `python pidTrainWeights.py -h` to see all command-line options.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

from os import makedirs
from os.path import join, dirname
from tqdm.auto import tqdm


def _make_const_lists():
    """Moving this code into a function to avoid a top-level ROOT import."""
    from ROOT import Belle2  # noqa: make Belle2 namespace available
    import ROOT.Belle2

    PARTICLES, PDG_CODES = [], []
    for i in range(len(ROOT.Belle2.Const.chargedStableSet)):
        particle = ROOT.Belle2.Const.chargedStableSet.at(i)
        name = (particle.__repr__()[7:-1]
                .replace("-", "")
                .replace("+", "")
                .replace("euteron", ""))
        PARTICLES.append(name)
        PDG_CODES.append(particle.getPDGCode())
    # PARTICLES = ["e", "mu", "pi", "K", "p", "d"]
    # PDG_CODES = [11, 13, 211, 321, 2212, 1000010020]

    DETECTORS = []
    for det in ROOT.Belle2.Const.PIDDetectors.set():
        DETECTORS.append(ROOT.Belle2.Const.parseDetectors(det))
    # DETECTORS = ["SVD", "CDC", "TOP", "ARICH", "ECL", "KLM"]

    return PARTICLES, PDG_CODES, DETECTORS


# PARTICLES, PDG_CODES, DETECTORS = _make_const_lists()
PARTICLES = ["e", "mu", "pi", "K", "p", "d"]
PDG_CODES = [11, 13, 211, 321, 2212, 1000010020]
DETECTORS = ["SVD", "CDC", "TOP", "ARICH", "ECL", "KLM"]
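
# Note: the three lists above are index-aligned, so PARTICLES[i] carries the
# PDG code PDG_CODES[i]. As WeightNet.forward below shows, the network input
# concatenates one block of len(DETECTORS) log-likelihood columns per particle
# type in this order: e.g. column 13 (= 2 * 6 + 1) holds the CDC
# log-likelihood for the pion hypothesis.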


class WeightNet(nn.Module):
    """PyTorch architecture for training calibration weights."""

    def __init__(self, n_class=6, n_detector=6, const_init=1):
        """Initialize the network for training.

        Args:
            n_class (int, optional): Number of classification classes (particle
                types). Defaults to 6.
            n_detector (int, optional): Number of detectors. Defaults to 6.
            const_init (int, optional): Constant value to initialize all
                weights. If None, PyTorch's default random initialization is
                used instead. Defaults to 1.
        """
        super().__init__()

        #: number of particle types
        self.n_class = n_class

        #: number of detectors
        self.n_detector = n_detector

        #: linear layers, one for each particle type
        self.fcs = nn.ModuleList(
            [nn.Linear(self.n_detector, 1, bias=False) for _ in range(self.n_class)]
        )

        if const_init is not None:
            self.const_init(const_init)

    def forward(self, x):
        """The network's forward method. Computes the weighted sum of the
        detector log-likelihoods for each particle type, then computes the
        likelihood ratios via a softmax.

        Args:
            x (torch.Tensor): Input detector log-likelihood data. Should be of
                shape (N, n_detector * n_class), where N is the number of samples.

        Returns:
            torch.Tensor: Weighted likelihood ratios.
        """
        n = self.n_detector
        outs = [self.fcs[i](x[:, i * n: (i + 1) * n]) for i in range(self.n_class)]
        out = torch.cat(outs, dim=1)
        return F.softmax(out, dim=1)
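
    # A minimal usage sketch (illustrative values, default 6x6 configuration):
    #
    #   net = WeightNet()
    #   x = torch.zeros(5, 36)   # 5 samples, 6 particle types x 6 detectors
    #   ratios = net(x)          # shape (5, 6); each row sums to 1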

    def get_weights(self, to_numpy=True):
        """Returns the weights as an (n_class, n_detector) array or tensor
        (six-by-six with the default configuration).

        Args:
            to_numpy (bool, optional): Whether to return the weights as a numpy
                array (True) or torch tensor (False). Defaults to True.

        Returns:
            np.array or torch.Tensor: The matrix containing the weights.
        """
        weights = torch.cat([fc.weight.detach() for fc in self.fcs])
        if to_numpy:
            return weights.cpu().numpy()
        else:
            return weights
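
    # e.g. WeightNet().get_weights().shape == (6, 6) with the defaults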

    def const_init(self, const):
        """Fill all the weights with the given value.

        Args:
            const (float): Constant value to fill all weights with.
        """
        with torch.no_grad():
            for fc in self.fcs:
                fc.weight.fill_(const)

    def random_init(self, mean=1.0, std=0.5):
        """Fill all the weights with values sampled from a Normal distribution
        with the given mean and standard deviation.

        Args:
            mean (float, optional): The mean of the Normal distribution.
                Defaults to 1.0.
            std (float, optional): The standard deviation of the Normal
                distribution. Defaults to 0.5.
        """
        with torch.no_grad():
            for fc in self.fcs:
                fc.weight.fill_(0)
                fc.weight.add_(torch.normal(mean=mean, std=std, size=fc.weight.size()))

    def kill_unused(self, only):
        """Kills weights corresponding to unused particle types.

        Args:
            only (list(int) or None): List of allowed particle types, given by
                PDG code. The weights corresponding to any particle types that
                are _not_ in this list will be fixed at one (leaving those
                log-likelihoods unweighted) and frozen (i.e. gradients will not
                be computed/updated).
        """
        if only is not None:
            # particle types that are not being trained...
            # fix at one and freeze
            for i, pdg in enumerate(PDG_CODES):
                if pdg in only:
                    continue
                self.fcs[i].weight.requires_grad = False
                self.fcs[i].weight.fill_(1)
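
    # For example (a hypothetical call): net.kill_unused([211]) trains only
    # the pion weights and freezes every other row of the weight matrix at 1.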


def load_training_data(directory, p_lims=None, theta_lims=None, device=None):
    """Loads training and validation data within the given momentum and theta
    limits (if given).

    Args:
        directory (str): Directory containing the train and validation sets.
        p_lims (tuple(float), optional): Minimum and maximum momentum. Defaults
            to None.
        theta_lims (tuple(float), optional): Minimum and maximum theta in
            degrees. Defaults to None.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.

    Returns:
        torch.Tensor: Training log-likelihood data.
        torch.Tensor: Training labels.
        torch.Tensor: Validation log-likelihood data.
        torch.Tensor: Validation labels.
    """
    p_lo, p_hi = p_lims if p_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = theta_lims if theta_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = np.radians(t_lo), np.radians(t_hi)

    def _load(filename):
        data = np.load(filename)
        X, y, p, t = data["X"], data["y"], data["p"], data["theta"]
        mask = np.logical_and.reduce([p >= p_lo, p <= p_hi, t >= t_lo, t <= t_hi])
        X = torch.tensor(X[mask]).to(device=device, dtype=torch.float)
        y = torch.tensor(y[mask]).to(device=device, dtype=torch.long)
        return X, y

    X_tr, y_tr = _load(join(directory, "train.npz"))
    X_va, y_va = _load(join(directory, "val.npz"))
    return X_tr, y_tr, X_va, y_va
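
# The loader above expects `train.npz` and `val.npz` files with the keys read
# in `_load`. A minimal sketch of a compatible file (placeholder values only):
#
#   np.savez(
#       "data/train.npz",
#       X=np.zeros((1000, 36)),            # detector log-likelihoods
#       y=np.zeros(1000, dtype=np.int64),  # labels: indices into PARTICLES
#       p=np.ones(1000),                   # momentum in GeV
#       theta=np.full(1000, np.pi / 2),    # polar angle in radians
#   )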


def load_checkpoint(filename, device=None, only=None):
    """Loads training from a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.
        only (list(int), optional): List of allowed particle types, given by
            PDG code. Defaults to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: Epoch number.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from every
            tenth epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from every
            tenth epoch.
    """
    checkpoint = torch.load(filename, map_location=torch.device("cpu"))
    net = WeightNet()
    net.load_state_dict(checkpoint["model_state_dict"])
    net.kill_unused(only)
    net.to(device=device)

    opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)
    opt.load_state_dict(checkpoint["optimizer_state_dict"])

    return (
        net,
        opt,
        checkpoint["epoch"],
        checkpoint["loss_t"],
        checkpoint["loss_v"],
        checkpoint["accu_t"],
        checkpoint["accu_v"],
    )


def save_checkpoint(filename, net, opt, epoch, loss_t, loss_v, accu_t, accu_v):
    """Saves training to a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        net (WeightNet): The network.
        opt (optim.Optimizer): The optimizer.
        epoch (int): The current epoch number.
        loss_t (dict(str -> list(float))): Training losses (diag, pion, sum)
            from each epoch.
        loss_v (dict(str -> list(float))): Validation losses (diag, pion, sum)
            from each epoch.
        accu_t (dict(str -> list(float))): Training accuracies (net, pion) from
            every tenth epoch.
        accu_v (dict(str -> list(float))): Validation accuracies (net, pion)
            from every tenth epoch.
    """
    net.cpu()
    makedirs(dirname(filename), exist_ok=True)
    torch.save(
        {
            "model_state_dict": net.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
            "epoch": epoch,
            "loss_t": loss_t,
            "loss_v": loss_v,
            "accu_t": accu_t,
            "accu_v": accu_v,
        },
        filename,
    )
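
# Checkpoints written by save_checkpoint can be reloaded with load_checkpoint
# above, e.g. (a hypothetical path):
#
#   net, opt, epoch, l_t, l_v, a_t, a_v = load_checkpoint("models/net.pt")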


def initialize(args, device=None):
    """Initializes training from the parsed command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: Epoch number.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from every
            tenth epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from every
            tenth epoch.
    """
    if args.resume is not None:
        net, opt, epochs_0, l_t, l_v, a_t, a_v = load_checkpoint(
            args.resume, device=device, only=args.only
        )

    else:
        net = WeightNet()
        if args.random:
            net.random_init()
        net.kill_unused(args.only)
        net.to(device=device)
        opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)
        epochs_0 = 0
        l_t = {"diag": [], "pion": [], "sum": []}
        l_v = {"diag": [], "pion": [], "sum": []}
        a_t = {"net": [], "pion": []}
        a_v = {"net": [], "pion": []}

    return net, opt, epochs_0, l_t, l_v, a_t, a_v


def train_model(args, use_tqdm=True):
    """Trains and saves a model according to the command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        use_tqdm (bool, optional): Use TQDM to track progress. Defaults to True.
    """
    print("Reading data.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("...and moving to", device)
    X_tr, y_tr, X_va, y_va = load_training_data(
        args.input, device=device, p_lims=args.p_lims, theta_lims=args.theta_lims
    )

    if len(y_tr) < 10 or len(y_va) < 10:
        print("Not enough data. Aborting...")
        return

    print(f"{len(y_tr)} train events, {len(y_va)} val events")

    print("Initializing network.")
    net, opt, epochs_0, loss_t, loss_v, accu_t, accu_v = initialize(args, device=device)

    diag_lfn = nn.CrossEntropyLoss()
    pion_lfn = nn.BCELoss()
    pion_wgt = args.beta

    def compute_accuracies(out, y):
        output = out.detach().cpu().numpy()
        target = y.detach().cpu().numpy()
        pred = np.squeeze(np.argmax(output, axis=1))
        accu = np.count_nonzero(pred == target) / len(pred)
        pi_out = output[(target == 2), 2]
        pi_pred = (pi_out > 0.5).astype(float)
        pi_accu = pi_pred.sum() / len(pi_pred)
        return accu, pi_accu

    def lfn(output, target):
        diag = diag_lfn(output, target)
        pi_mask = target == 2
        pi_out = output[pi_mask, 2]
        pi_y = (target[pi_mask] == 2).float()
        pion = pion_lfn(pi_out, pi_y)
        return diag + pion_wgt * pion, diag, pion
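
    # The total loss combines two terms: a cross-entropy loss over all
    # particle types ("diag") plus, scaled by --beta, a binary cross-entropy
    # term ("pion") that pushes the pion output for true pions toward 1.
    # Loosely: loss = CE(output, target) + beta * BCE(output[pions, 2], 1).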

    print(f"Training network for {args.n_epochs} epochs.")

    iterator = range(args.n_epochs)
    if use_tqdm:
        iterator = tqdm(iterator)

    for epoch in iterator:
        # train step
        net.train()
        opt.zero_grad()
        out = net(X_tr)
        loss, diag, pion = lfn(out, y_tr)
        loss.backward()
        opt.step()

        # record training data
        loss_t["diag"].append(diag.item())
        loss_t["pion"].append(pion.item())
        loss_t["sum"].append(loss.item())

        if epoch % 10 == 0:
            accu, pi_accu = compute_accuracies(out, y_tr)
            accu_t["net"].append(accu)
            accu_t["pion"].append(pi_accu)

        # val step
        net.eval()
        with torch.no_grad():
            out = net(X_va)
            loss, diag, pion = lfn(out, y_va)

            loss_v["diag"].append(diag.item())
            loss_v["pion"].append(pion.item())
            loss_v["sum"].append(loss.item())

            if epoch % 10 == 0:
                accu, pi_accu = compute_accuracies(out, y_va)
                accu_v["net"].append(accu)
                accu_v["pion"].append(pi_accu)

    print("Training complete.")

    net.cpu()
    save_checkpoint(
        args.output, net, opt, epochs_0 + args.n_epochs, loss_t, loss_v, accu_t, accu_v,
    )

    wgt = net.get_weights(to_numpy=True)
    np.save(args.output.replace(".pt", "_wgt.npy"), wgt)

    print(f"Model saved to {args.output}.")


def get_parser():
    """Handles the command-line argument parsing.

    Returns:
        argparse.ArgumentParser: The argument parser.
    """
    from argparse import ArgumentParser

    ap = ArgumentParser(description="", epilog="")
    ap.add_argument(
        "input", type=str, help="Path to folder with training files (in .npz format).",
    )
    ap.add_argument(
        "output",
        type=str,
        help="Output filename where model will be saved (should end in .pt).",
    )
    ap.add_argument(
        "-n",
        "--n_epochs",
        type=int,
        default=500,
        help="Number of epochs to train the network. Defaults to 500.",
    )
    ap.add_argument(
        "--p_lims",
        nargs=2,
        type=float,
        default=[-float('inf'), +float('inf')],
        help=(
            "Lower and upper limits for momentum in GeV. Lower limit "
            "should be given first. Default values are -inf, +inf."
        ),
    )
    ap.add_argument(
        "--theta_lims",
        nargs=2,
        type=float,
        default=[-float('inf'), +float('inf')],
        help=(
            "Lower and upper limits for theta in degrees. Lower limit "
            "should be given first. Default values are -inf, +inf."
        ),
    )
    ap.add_argument(
        "-R",
        "--resume",
        nargs="?",
        type=str,
        const="use_output",
        default=None,
        help=(
            "Load a pre-existing model and resume training instead of "
            "starting a new one. If '--resume' is used and no file is "
            "specified, the output filepath will be loaded, and the "
            "training will overwrite that file. Alternatively, if a "
            "filepath is given, training will begin from the state "
            "saved in that file and will be saved to the output filepath."
        ),
    )
    ap.add_argument(
        "--only",
        type=int,
        nargs="*",
        help=(
            "Use only log-likelihood data from a subset of particle "
            "types specified by PDG code. If left unspecified, all "
            "particle types will be used."
        ),
    )
    ap.add_argument(
        "--random",
        action="store_true",
        help=(
            "Initialize network weights to random values, normally "
            "distributed with mean of 1 and width of 0.5. Has no effect "
            "if 'resume' is used."
        ),
    )
    ap.add_argument(
        "--beta",
        type=float,
        default=0.1,
        help=(
            "Scaling factor for the pion binary cross entropy term in "
            "the loss function. Defaults to 0.1."
        ),
    )
    return ap


def validate_args(args):
    """Validates the parsed command-line arguments and resolves '--resume'."""
    assert args.n_epochs > 0, "Number of epochs must be larger than 0."
    assert args.p_lims[0] < args.p_lims[1], "p_lims: lower limit must be < upper limit."
    assert (
        args.theta_lims[0] < args.theta_lims[1]
    ), "theta_lims: lower limit must be < upper limit."
    if args.only is not None:
        for pdg in args.only:
            assert pdg in PDG_CODES, f"Given PDG code {pdg} not understood."

    if args.resume == "use_output":
        args.resume = args.output

    return args


def main():
    """Main network training logic."""

    args = get_parser().parse_args()
    args = validate_args(args)

    print("Welcome to the network trainer.")
    print(f"Data will be read from {args.input}.")
    if args.resume is not None:
        print(f"Model training will be continued from the saved state at {args.resume}.")
    print(f"The trained model will be saved to {args.output}.")

    if not (args.p_lims[0] == -np.inf and args.p_lims[1] == +np.inf):
        print(
            f"Model will be trained on data with momentum in "
            f"[{args.p_lims[0]}, {args.p_lims[1]}] GeV."
        )

    if not (args.theta_lims[0] == -np.inf and args.theta_lims[1] == +np.inf):
        print(
            f"Model will be trained on data with theta in "
            f"[{args.theta_lims[0]}, {args.theta_lims[1]}] degrees."
        )

    print(f"The model will be trained for {args.n_epochs} epochs.")
    print(f"Training will use a pion scaling factor of beta = {args.beta}.")
    print("---")

    train_model(args)
    print("\nFinished!")


if __name__ == "__main__":
    main()