Belle II Software  release-08-01-10
pidTrainWeights.py
#!/usr/bin/env python3

"""
Methods and a script for training PID calibration weights.

Sample usage:
    $ python pidTrainWeights.py data/ models/net.pt -n 100

Use `python pidTrainWeights.py -h` to see all command-line options.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

from os import makedirs
from os.path import join, dirname
from tqdm.auto import tqdm


def _make_const_lists():
    """Moving this code into a function to avoid a top-level ROOT import."""
    import ROOT.Belle2

    PARTICLES, PDG_CODES = [], []
    for i in range(len(ROOT.Belle2.Const.chargedStableSet)):
        particle = ROOT.Belle2.Const.chargedStableSet.at(i)
        name = (particle.__repr__()[7:-1]
                .replace("-", "")
                .replace("+", "")
                .replace("euteron", ""))
        PARTICLES.append(name)
        PDG_CODES.append(particle.getPDGCode())
    # PARTICLES = ["e", "mu", "pi", "K", "p", "d"]
    # PDG_CODES = [11, 13, 211, 321, 2212, 1000010020]

    DETECTORS = []
    for det in ROOT.Belle2.Const.PIDDetectors.set():
        DETECTORS.append(ROOT.Belle2.Const.parseDetectors(det))
    # DETECTORS = ["SVD", "CDC", "TOP", "ARICH", "ECL", "KLM"]

    return PARTICLES, PDG_CODES, DETECTORS


# PARTICLES, PDG_CODES, DETECTORS = _make_const_lists()
PARTICLES = ["e", "mu", "pi", "K", "p", "d"]
PDG_CODES = [11, 13, 211, 321, 2212, 1000010020]
DETECTORS = ["SVD", "CDC", "TOP", "ARICH", "ECL", "KLM"]
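
# Note on the input layout (inferred from WeightNet.forward below; the
# detector ordering within each block is an assumption based on DETECTORS):
# each input row holds n_class * n_detector log-likelihoods grouped by
# particle type, e.g. with the defaults
#   [e: SVD..KLM, mu: SVD..KLM, pi: SVD..KLM, K: SVD..KLM, p: SVD..KLM, d: SVD..KLM]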


class WeightNet(nn.Module):
    """PyTorch architecture for training calibration weights."""

    def __init__(self, n_class=6, n_detector=6, const_init=1):
        """Initialize the network for training.

        Args:
            n_class (int, optional): Number of classification classes (particle
                types). Defaults to 6.
            n_detector (int, optional): Number of detectors. Defaults to 6.
            const_init (int, optional): Constant value to initialize all
                weights. If None, PyTorch's default random initialization is
                used instead. Defaults to 1.
        """
        super().__init__()

        self.n_class = n_class
        self.n_detector = n_detector

        self.fcs = nn.ModuleList(
            [nn.Linear(self.n_detector, 1, bias=False) for _ in range(self.n_class)]
        )

        if const_init is not None:
            self.const_init(const_init)

    def forward(self, x):
        """Network's forward method. Computes the weighted sum of the detector
        log-likelihoods for each particle type, then softmaxes across particle
        types to obtain the likelihood ratios.

        Args:
            x (torch.Tensor): Input detector log-likelihood data. Should be of
                shape (N, n_detector * n_class), where N is the number of samples.

        Returns:
            torch.Tensor: Weighted likelihood ratios.
        """
        n = self.n_detector
        outs = [self.fcs[i](x[:, i * n: (i + 1) * n]) for i in range(self.n_class)]
        out = torch.cat(outs, dim=1)
        return F.softmax(out, dim=1)

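    # In effect, with weights w[i, d] the forward pass computes
    #   softmax_i( sum_d w[i, d] * logL_d(i) ),
    # i.e. weighted likelihood ratios across the n_class hypotheses.
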
    def get_weights(self, to_numpy=True):
        """Returns the weights as a six-by-six array or tensor.

        Args:
            to_numpy (bool, optional): Whether to return the weights as a numpy
                array (True) or torch tensor (False). Defaults to True.

        Returns:
            np.array or torch.Tensor: The six-by-six matrix containing the
                weights.
        """
        weights = torch.cat([fc.weight.detach() for fc in self.fcs])
        if to_numpy:
            return weights.cpu().numpy()
        else:
            return weights

    def const_init(self, const):
        """Fill all the weights with the given value.

        Args:
            const (float): Constant value to fill all weights with.
        """
        with torch.no_grad():
            for fc in self.fcs:
                fc.weight.fill_(const)

    def random_init(self, mean=1.0, std=0.5):
        """Fill all the weights with values sampled from a Normal distribution
        with given mean and standard deviation.

        Args:
            mean (float, optional): The mean of the Normal distribution.
                Defaults to 1.0.
            std (float, optional): The standard deviation of the Normal
                distribution. Defaults to 0.5.
        """
        with torch.no_grad():
            for fc in self.fcs:
                fc.weight.fill_(0)
                fc.weight.add_(torch.normal(mean=mean, std=std, size=fc.weight.size()))

    def kill_unused(self, only):
        """Kills weights corresponding to unused particle types.

        Args:
            only (list(int) or None): List of allowed particle types, given by
                PDG code. The weights corresponding to any particle types that
                are _not_ in this list will be fixed at one (leaving those
                log-likelihoods unweighted) and frozen (i.e. gradients will not
                be computed/updated).
        """
        if only is not None:
            # particle types that are not being trained...
            # fix their weights at one and freeze them
            for i, pdg in enumerate(PDG_CODES):
                if pdg in only:
                    continue
                self.fcs[i].weight.requires_grad = False
                self.fcs[i].weight.fill_(1)
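

# Minimal usage sketch for WeightNet (hypothetical values, not part of the
# training flow below):
#
#   net = WeightNet()            # 6 classes x 6 detectors, all weights = 1
#   x = torch.zeros(5, 36)       # batch of 5 tracks, 36 log-likelihoods each
#   ratios = net(x)              # shape (5, 6); each row sums to 1
#   weights = net.get_weights()  # (6, 6) numpy array
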
def load_training_data(directory, p_lims=None, theta_lims=None, device=None):
    """Loads training and validation data within the given momentum and theta
    limits (if given).

    Args:
        directory (str): Directory containing the train and validation sets.
        p_lims (tuple(float), optional): Minimum and maximum momentum. Defaults
            to None.
        theta_lims (tuple(float), optional): Minimum and maximum theta in
            degrees. Defaults to None.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.

    Returns:
        torch.Tensor: Training log-likelihood data.
        torch.Tensor: Training labels.
        torch.Tensor: Validation log-likelihood data.
        torch.Tensor: Validation labels.
    """
    p_lo, p_hi = p_lims if p_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = theta_lims if theta_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = np.radians(t_lo), np.radians(t_hi)

    def _load(filename):
        data = np.load(filename)
        X, y, p, t = data["X"], data["y"], data["p"], data["theta"]
        mask = np.logical_and.reduce([p >= p_lo, p <= p_hi, t >= t_lo, t <= t_hi])
        X = torch.tensor(X[mask]).to(device=device, dtype=torch.float)
        y = torch.tensor(y[mask]).to(device=device, dtype=torch.long)
        return X, y

    X_tr, y_tr = _load(join(directory, "train.npz"))
    X_va, y_va = _load(join(directory, "val.npz"))
    return X_tr, y_tr, X_va, y_va
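

# The .npz files are expected (as read by _load above) to contain arrays "X"
# (N x 36 log-likelihoods), "y" (integer labels indexing PARTICLES, so 2 is
# the pion class), "p" (momentum in GeV), and "theta" (polar angle in
# radians). A hypothetical writer:
#   np.savez("data/train.npz", X=X, y=y, p=p, theta=theta)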


def load_checkpoint(filename, device=None, only=None):
    """Loads training from a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.
        only (list(int), optional): List of allowed particle types, given by
            PDG code. Defaults to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: Epoch number.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from every
            tenth epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from every
            tenth epoch.
    """
    checkpoint = torch.load(filename, map_location=torch.device("cpu"))
    net = WeightNet()
    net.load_state_dict(checkpoint["model_state_dict"])
    net.kill_unused(only)
    net.to(device=device)

    opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)
    opt.load_state_dict(checkpoint["optimizer_state_dict"])

    return (
        net,
        opt,
        checkpoint["epoch"],
        checkpoint["loss_t"],
        checkpoint["loss_v"],
        checkpoint["accu_t"],
        checkpoint["accu_v"],
    )


def save_checkpoint(filename, net, opt, epoch, loss_t, loss_v, accu_t, accu_v):
    """Saves training to a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        net (WeightNet): The network.
        opt (optim.Optimizer): The optimizer.
        epoch (int): The current epoch number.
        loss_t (dict(str -> list(float))): Training losses (diag, pion, sum)
            from each epoch.
        loss_v (dict(str -> list(float))): Validation losses (diag, pion, sum)
            from each epoch.
        accu_t (dict(str -> list(float))): Training accuracies (net, pion) from
            every tenth epoch.
        accu_v (dict(str -> list(float))): Validation accuracies (net, pion)
            from every tenth epoch.
    """
    net.cpu()
    makedirs(dirname(filename), exist_ok=True)
    torch.save(
        {
            "model_state_dict": net.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
            "epoch": epoch,
            "loss_t": loss_t,
            "loss_v": loss_v,
            "accu_t": accu_t,
            "accu_v": accu_v,
        },
        filename,
    )
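

# Round-trip sketch: a checkpoint written by save_checkpoint can be restored
# with load_checkpoint (the path here is hypothetical):
#   net, opt, epoch, l_t, l_v, a_t, a_v = load_checkpoint("models/net.pt")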


def initialize(args, device=None):
    """Initializes training from the parsed command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: Epoch number.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from every
            tenth epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from every
            tenth epoch.
    """
    if args.resume is not None:
        net, opt, epochs_0, l_t, l_v, a_t, a_v = load_checkpoint(args.resume, device=device, only=args.only)

    else:
        net = WeightNet()
        if args.random:
            net.random_init()
        net.kill_unused(args.only)
        net.to(device=device)
        opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)
        epochs_0 = 0
        l_t = {"diag": [], "pion": [], "sum": []}
        l_v = {"diag": [], "pion": [], "sum": []}
        a_t = {"net": [], "pion": []}
        a_v = {"net": [], "pion": []}

    return net, opt, epochs_0, l_t, l_v, a_t, a_v


def train_model(args, use_tqdm=True):
    """Trains and saves a model according to the command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        use_tqdm (bool, optional): Use TQDM to track progress. Defaults to True.
    """
    print("Reading data.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("...and moving to", device)
    X_tr, y_tr, X_va, y_va = load_training_data(
        args.input, device=device, p_lims=args.p_lims, theta_lims=args.theta_lims
    )

    if len(y_tr) < 10 or len(y_va) < 10:
        print("Not enough data. Aborting...")
        return

    print(f"{len(y_tr)} train events, {len(y_va)} val events")

    print("Initializing network.")
    net, opt, epochs_0, loss_t, loss_v, accu_t, accu_v = initialize(args, device=device)

    diag_lfn = nn.CrossEntropyLoss()
    pion_lfn = nn.BCELoss()
    pion_wgt = args.beta

    def compute_accuracies(out, y):
        output = out.detach().cpu().numpy()
        target = y.detach().cpu().numpy()
        pred = np.squeeze(np.argmax(output, axis=1))
        accu = np.count_nonzero(pred == target) / len(pred)
        pi_out = output[(target == 2), 2]
        pi_pred = (pi_out > 0.5).astype(float)
        pi_accu = pi_pred.sum() / len(pi_pred)
        return accu, pi_accu

    def lfn(output, target):
        diag = diag_lfn(output, target)
        pi_mask = target == 2
        pi_out = output[pi_mask, 2]
        pi_y = (target[pi_mask] == 2).float()
        pion = pion_lfn(pi_out, pi_y)
        return diag + pion_wgt * pion, diag, pion

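    # Total loss = CrossEntropy(out, y) + beta * BCE(pion outputs, 1), the BCE
    # taken over true pions only; beta (args.beta) sets how strongly pion
    # efficiency is emphasized relative to overall accuracy.
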
    print(f"Training network for {args.n_epochs} epochs.")

    iterator = range(args.n_epochs)
    if use_tqdm:
        iterator = tqdm(iterator)

    for epoch in iterator:
        # train step
        net.train()
        opt.zero_grad()
        out = net(X_tr)
        loss, diag, pion = lfn(out, y_tr)
        loss.backward()
        opt.step()

        # record training data
        loss_t["diag"].append(diag.item())
        loss_t["pion"].append(pion.item())
        loss_t["sum"].append(loss.item())

        if epoch % 10 == 0:
            accu, pi_accu = compute_accuracies(out, y_tr)
            accu_t["net"].append(accu)
            accu_t["pion"].append(pi_accu)

        # val step
        net.eval()
        with torch.no_grad():
            out = net(X_va)
            loss, diag, pion = lfn(out, y_va)

            loss_v["diag"].append(diag.item())
            loss_v["pion"].append(pion.item())
            loss_v["sum"].append(loss.item())

            if epoch % 10 == 0:
                accu, pi_accu = compute_accuracies(out, y_va)
                accu_v["net"].append(accu)
                accu_v["pion"].append(pi_accu)

    print("Training complete.")

    net.cpu()
    save_checkpoint(
        args.output, net, opt, epochs_0 + args.n_epochs, loss_t, loss_v, accu_t, accu_v,
    )

    wgt = net.get_weights(to_numpy=True)
    np.save(args.output.replace(".pt", "_wgt.npy"), wgt)

    print(f"Model saved to {args.output}.")


def get_parser():
    """Handles the command-line argument parsing.

    Returns:
        argparse.ArgumentParser: The argument parser.
    """
    from argparse import ArgumentParser

    ap = ArgumentParser(description="", epilog="")
    ap.add_argument(
        "input", type=str, help="Path to folder with training files (in .npz format).",
    )
    ap.add_argument(
        "output",
        type=str,
        help="Output filename where model will be saved (should end in .pt).",
    )
    ap.add_argument(
        "-n",
        "--n_epochs",
        type=int,
        default=500,
        help="Number of epochs to train the network. Defaults to 500.",
    )
    ap.add_argument(
        "--p_lims",
        nargs=2,
        type=float,
        default=[-float('inf'), +float('inf')],
        help=(
            "Lower and upper limits for momentum in GeV. Lower limit "
            "should be given first. Default values are -inf, +inf."
        ),
    )
    ap.add_argument(
        "--theta_lims",
        nargs=2,
        type=float,
        default=[-float('inf'), +float('inf')],
        help=(
            "Lower and upper limits for theta in degrees. Lower limit "
            "should be given first. Default values are -inf, +inf."
        ),
    )
    ap.add_argument(
        "-R",
        "--resume",
        nargs="?",
        type=str,
        const="use_output",
        default=None,
        help=(
            "Load a pre-existing model and resume training instead of "
            "starting a new one. If '--resume' is used and no file is "
            "specified, the output filepath will be loaded, and the "
            "training will overwrite that file. Alternatively, if a "
            "filepath is given, training will begin from the state "
            "saved in that file and will be saved to the output filepath."
        ),
    )
    ap.add_argument(
        "--only",
        type=int,
        nargs="*",
        help=(
            "Use only log-likelihood data from a subset of particle "
            "types specified by PDG code. If left unspecified, all "
            "particle types will be used."
        ),
    )
    ap.add_argument(
        "--random",
        action="store_true",
        help=(
            "Initialize network weights to random values, normally "
            "distributed with mean of 1 and width of 0.5. Has no effect "
            "if 'resume' is used."
        ),
    )
    ap.add_argument(
        "--beta",
        type=float,
        default=0.1,
        help=(
            "Scaling factor for the pion binary cross entropy term in "
            "the loss function. Defaults to 0.1."
        ),
    )
    return ap
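

# Example invocation combining the optional flags (paths and values are
# hypothetical):
#   $ python pidTrainWeights.py data/ models/net.pt -n 200 \
#         --only 211 321 --p_lims 0.5 3.0 --theta_lims 30 120 --beta 0.2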


def validate_args(args):
    # validate some values
    assert args.n_epochs > 0, "Number of epochs must be larger than 0."
    assert args.p_lims[0] < args.p_lims[1], "p_lims: lower limit must be < upper limit."
    assert (
        args.theta_lims[0] < args.theta_lims[1]
    ), "theta_lims: lower limit must be < upper limit."
    if args.only is not None:
        for pdg in args.only:
            assert pdg in PDG_CODES, f"Given PDG code {pdg} not understood."

    if args.resume == "use_output":
        args.resume = args.output

    return args


def main():
    """Main network training logic."""

    args = get_parser().parse_args()
    args = validate_args(args)

    print("Welcome to the network trainer.")
    print(f"Data will be read from {args.input}.")
    if args.resume is not None:
        print(f"Model training will be continued from the saved state at {args.resume}.")
    print(f"The trained model will be saved to {args.output}.")

    if not (args.p_lims[0] == -np.inf and args.p_lims[1] == +np.inf):
        print(
            f"Model will be trained on data with momentum in "
            f"[{args.p_lims[0]}, {args.p_lims[1]}] GeV."
        )

    if not (args.theta_lims[0] == -np.inf and args.theta_lims[1] == +np.inf):
        print(
            f"Model will be trained on data with theta in "
            f"[{args.theta_lims[0]}, {args.theta_lims[1]}] degrees."
        )

    print(f"The model will be trained for {args.n_epochs} epochs.")
    print(f"Training will use a pion scaling factor of beta = {args.beta}.")
    print("---")

    train_model(args)
    print("\nFinished!")


if __name__ == "__main__":
    main()