"""Methods and a script for training PID calibration weights.

Sample usage:

    $ python pidTrainWeights.py data/ models/net.pt -n 100

Use `python pidTrainWeights.py -h` to see all command-line options.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

from os import makedirs
from os.path import join, dirname
from tqdm.auto import tqdm


def _make_const_lists():
    """Moving this code into a function to avoid a top-level ROOT import."""
    import ROOT.Belle2

    PARTICLES, PDG_CODES = [], []
    for i in range(len(ROOT.Belle2.Const.chargedStableSet)):
        particle = ROOT.Belle2.Const.chargedStableSet.at(i)
        name = (particle.__repr__()[7:-1]
                .replace("euteron", ""))
        PARTICLES.append(name)
        PDG_CODES.append(particle.getPDGCode())

    DETECTORS = []
    for det in ROOT.Belle2.Const.PIDDetectors.set():
        DETECTORS.append(ROOT.Belle2.Const.parseDetectors(det))

    return PARTICLES, PDG_CODES, DETECTORS
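

# The hard-coded constants below match the output of _make_const_lists();
# they are inlined so that importing this module does not require ROOT.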
PARTICLES = ["e", "mu", "pi", "K", "p", "d"]
PDG_CODES = [11, 13, 211, 321, 2212, 1000010020]
DETECTORS = ["SVD", "CDC", "TOP", "ARICH", "ECL", "KLM"]
64 """PyTorch architecture for training calibration weights."""
66 def __init__(self, n_class=6, n_detector=6, const_init=1):
67 """Initialize the network for training.
70 n_class (int, optional): Number of classification classes (particle
71 types). Defaults to 6.
72 n_detector (int, optional): Number of detectors. Defaults to 6.
73 const_init (int, optional): Constant value to initialize all
74 weights. If None, PyTorch
's default random initialization is
75 used instead. Defaults to 1.
86 self.fcs = nn.ModuleList(
90 if const_init
is not None:
94 """Network's forward methods. Sums the detector log-likelihoods for each particle
95 type, then computes the likelihood ratios. Uses the weights.
98 x (torch.Tensor): Input detector log-likelihood data. Should be of
99 shape (N, n_detector * n_class), where N is the number of samples.
102 torch.Tensor: Weighted likelihood ratios.
105 outs = [self.fcs[i](x[:, i * n: (i + 1) * n]) for i
in range(self.
n_class)]
106 out = torch.cat(outs, dim=1)
107 return F.softmax(out, dim=1)
110 """Returns the weights as a six-by-six array or tensor.
113 to_numpy (bool, optional): Whether to return the weights
as a numpy
114 array (
True)
or torch tensor (
False). Defaults to
True.
117 np.array
or torch.Tensor: The six-by-six matrix containing the
120 weights = torch.cat([fc.weight.detach() for fc
in self.
fcs])
122 return weights.cpu().numpy()
127 """Fill all the weights with the given value.
130 const (float): Constant value to fill all weights with.
132 with torch.no_grad():
134 fc.weight.fill_(const)
137 """Fill all the weights with values sampled from a Normal distribution
138 with given mean
and standard deviation.
141 mean (float, optional): The mean of the Normal distribution.
143 std (float, optional): The standard deviation of the Normal
144 distribution. Defaults to 0.5.
146 with torch.no_grad():
149 fc.weight.add_(torch.normal(mean=mean, std=std, size=fc.weight.size()))
152 """Kills weights corresponding to unused particle types.
155 only (list(str) or None): List of allowed particle types. The
156 weights corresponding to any particle types that are _not_
in
157 this list will be filled
with zero
and be frozen (e.g. gradients
158 will
not be computed/updated).
163 for i, pdg
in enumerate(PDG_CODES):
166 self.
fcs[i].weight.requires_grad =
False
167 self.
fcs[i].weight.fill_(1)
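

# A minimal usage sketch of WeightNet in isolation (illustrative only; the
# random input below stands in for real detector log-likelihood data):
#
#   net = WeightNet()            # 6 particle types x 6 detectors
#   x = torch.randn(8, 36)       # 8 samples, n_class * n_detector columns
#   ratios = net(x)              # shape (8, 6); each row sums to 1
#   wgt = net.get_weights()      # 6x6 numpy array of calibration weights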


def load_training_data(directory, p_lims=None, theta_lims=None, device=None):
    """Loads training and validation data within the given momentum and theta
    limits (if given).

    Args:
        directory (str): Directory containing the train and validation sets.
        p_lims (tuple(float), optional): Minimum and maximum momentum.
            Defaults to None.
        theta_lims (tuple(float), optional): Minimum and maximum theta in
            degrees. Defaults to None.
        device (torch.device, optional): Device to move the data onto.
            Defaults to None.

    Returns:
        torch.Tensor: Training log-likelihood data.
        torch.Tensor: Training labels.
        torch.Tensor: Validation log-likelihood data.
        torch.Tensor: Validation labels.
    """
    p_lo, p_hi = p_lims if p_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = theta_lims if theta_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = np.radians(t_lo), np.radians(t_hi)

    def _load(filename):
        data = np.load(filename)
        X, y, p, t = data["X"], data["y"], data["p"], data["theta"]
        mask = np.logical_and.reduce([p >= p_lo, p <= p_hi, t >= t_lo, t <= t_hi])
        X = torch.tensor(X[mask]).to(device=device, dtype=torch.float)
        y = torch.tensor(y[mask]).to(device=device, dtype=torch.long)
        return X, y

    X_tr, y_tr = _load(join(directory, "train.npz"))
    X_va, y_va = _load(join(directory, "val.npz"))
    return X_tr, y_tr, X_va, y_va
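

# Example (a sketch; "data/" is illustrative and must contain train.npz and
# val.npz with "X", "y", "p", and "theta" arrays):
#
#   X_tr, y_tr, X_va, y_va = load_training_data(
#       "data/", p_lims=(0.5, 3.0), theta_lims=(30, 120)
#   )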


def load_checkpoint(filename, device=None, only=None):
    """Loads training from a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        device (torch.device, optional): Device to move the data onto.
            Defaults to None.
        only (list(int), optional): List of allowed particle types, given by
            PDG code. Defaults to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: The epoch number.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from
            each epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from every
            tenth epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from every
            tenth epoch.
    """
    checkpoint = torch.load(filename, map_location=torch.device("cpu"))
    net = WeightNet()
    net.load_state_dict(checkpoint["model_state_dict"])
    net.kill_unused(only)
    net.to(device=device)

    opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)
    opt.load_state_dict(checkpoint["optimizer_state_dict"])

    return (
        net,
        opt,
        checkpoint["epoch"],
        checkpoint["loss_t"],
        checkpoint["loss_v"],
        checkpoint["accu_t"],
        checkpoint["accu_v"],
    )


def save_checkpoint(filename, net, opt, epoch, loss_t, loss_v, accu_t, accu_v):
    """Saves training to a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        net (WeightNet): The network.
        opt (optim.Optimizer): The optimizer.
        epoch (int): The current epoch number.
        loss_t (dict(str -> list(float))): Training losses (diag, pion, sum)
            from each epoch.
        loss_v (dict(str -> list(float))): Validation losses (diag, pion, sum)
            from each epoch.
        accu_t (dict(str -> list(float))): Training accuracies (net, pion)
            from every tenth epoch.
        accu_v (dict(str -> list(float))): Validation accuracies (net, pion)
            from every tenth epoch.
    """
    makedirs(dirname(filename), exist_ok=True)
    torch.save(
        {
            "model_state_dict": net.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
            "epoch": epoch,
            "loss_t": loss_t,
            "loss_v": loss_v,
            "accu_t": accu_t,
            "accu_v": accu_v,
        },
        filename,
    )


def initialize(args, device=None):
    """Initializes training from the parsed command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        device (torch.device, optional): Device to move the data onto.
            Defaults to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: The starting epoch number.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from
            each epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from every
            tenth epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from every
            tenth epoch.
    """
    if args.resume is not None:
        net, opt, epochs_0, l_t, l_v, a_t, a_v = load_checkpoint(args.resume, device=device, only=args.only)
    else:
        net = WeightNet()
        if args.random:
            net.random_init()
        net.kill_unused(args.only)
        net.to(device=device)
        opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)

        epochs_0 = 0
        l_t = {"diag": [], "pion": [], "sum": []}
        l_v = {"diag": [], "pion": [], "sum": []}
        a_t = {"net": [], "pion": []}
        a_v = {"net": [], "pion": []}

    return net, opt, epochs_0, l_t, l_v, a_t, a_v
323 """Trains and saves a model according to the command-line arguments.
326 args (argparse.Namespace): Parsed command-line arguments.
327 use_tqdm (bool, optional): Use TQDM to track progress. Defaults to True.
329 print("Reading data.")
330 device = torch.device(
"cuda" if torch.cuda.is_available()
else "cpu")
331 print(
"...and moving to", device)
332 X_tr, y_tr, X_va, y_va = load_training_data(
333 args.input, device=device, p_lims=args.p_lims, theta_lims=args.theta_lims
336 if len(y_tr) < 10
or len(y_va) < 10:
337 print(
"Not enough data. Aborting...")
340 print(f
"{len(y_tr)} train events, {len(y_va)} val events")
342 print(
"Initializing network.")
343 net, opt, epochs_0, loss_t, loss_v, accu_t, accu_v = initialize(args, device=device)

    diag_lfn = nn.CrossEntropyLoss()
    pion_lfn = nn.BCELoss()
    pion_wgt = args.beta

    def compute_accuracies(out, y):
        output = out.detach().cpu().numpy()
        target = y.detach().cpu().numpy()
        pred = np.squeeze(np.argmax(output, axis=1))
        accu = np.count_nonzero(pred == target) / len(pred)
        pi_out = output[(target == 2), 2]
        pi_pred = (pi_out > 0.5).astype(float)
        pi_accu = pi_pred.sum() / len(pi_pred)
        return accu, pi_accu

    def lfn(output, target):
        diag = diag_lfn(output, target)
        pi_mask = target == 2
        pi_out = output[pi_mask, 2]
        pi_y = (target[pi_mask] == 2).float()
        pion = pion_lfn(pi_out, pi_y)
        return diag + pion_wgt * pion, diag, pion
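
    # Combined objective: multi-class cross entropy ("diag") plus a pion
    # binary cross entropy scaled by beta. The BCE term is evaluated only on
    # true-pion rows and pushes the pion likelihood ratio (column 2 of the
    # softmax output) toward one.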

    print(f"Training network for {args.n_epochs} epochs.")

    iterator = range(args.n_epochs)
    if use_tqdm:
        iterator = tqdm(iterator)

    for epoch in iterator:
        # training step
        net.train()
        opt.zero_grad()
        out = net(X_tr)
        loss, diag, pion = lfn(out, y_tr)
        loss.backward()
        opt.step()

        # record training losses every epoch
        loss_t["diag"].append(diag.item())
        loss_t["pion"].append(pion.item())
        loss_t["sum"].append(loss.item())

        # record training accuracies every tenth epoch
        if epoch % 10 == 0:
            accu, pi_accu = compute_accuracies(out, y_tr)
            accu_t["net"].append(accu)
            accu_t["pion"].append(pi_accu)

        # validation step
        net.eval()
        with torch.no_grad():
            out = net(X_va)
            loss, diag, pion = lfn(out, y_va)

            # record validation losses every epoch
            loss_v["diag"].append(diag.item())
            loss_v["pion"].append(pion.item())
            loss_v["sum"].append(loss.item())

            # record validation accuracies every tenth epoch
            if epoch % 10 == 0:
                accu, pi_accu = compute_accuracies(out, y_va)
                accu_v["net"].append(accu)
                accu_v["pion"].append(pi_accu)

    print("Training complete.")

    save_checkpoint(
        args.output, net, opt, epochs_0 + args.n_epochs, loss_t, loss_v, accu_t, accu_v,
    )

    wgt = net.get_weights(to_numpy=True)
    np.save(args.output.replace(".pt", "_wgt.npy"), wgt)

    print(f"Model saved to {args.output}.")
421 """Handles the command-line argument parsing.
424 argparse.Namespace: The parsed arguments.
426 from argparse
import ArgumentParser
428 ap = ArgumentParser(description=
"", epilog=
"")
430 "input", type=str, help=
"Path to folder with training files (in .npz format).",
435 help=
"Output filename where model will be saved (should end in .pt).",
442 help=
"Number of epochs to train the network. Defaults to 500.",
447 default=[-float(
'inf'), +float(
'inf')],
449 "Lower and upper limits for momentum in GeV. Lower limit "
450 "should be given first. Default values are -inf, +inf."
456 default=[-float(
'inf'), +float(
'inf')],
458 "Lower and upper limits for theta in degrees. Lower limit "
459 "should be given first. Default values are -inf, +inf."
470 "Load a pre-existing model and resume training instead of "
471 "starting a new one. If '--resume' is used and no file is "
472 "specified, the output filepath will be loaded, and the "
473 "training will overwrite that file. Alternatively, if a "
474 "filepath is given, training will begin from the state "
475 "saved in that file and will be saved to the output filepath."
483 "Use only log-likelihood data from a subset of particle "
484 "types specified by PDG code. If left unspecified, all "
485 "particle types will be used."
492 "Initialize network weights to random values, normally "
493 "distributed with mean of 1 and width of 0.5. Has no effect "
494 "if 'resume' is used."
502 "Scaling factor for the pion binary cross entropy term in "
503 "the loss function. Defaults to 0.1."


def validate_args(args):
    """Validates the parsed command-line arguments, modifying them where
    necessary."""
    assert args.n_epochs > 0, "Number of epochs must be larger than 0."
    assert args.p_lims[0] < args.p_lims[1], "p_lims: lower limit must be < upper limit."
    assert (
        args.theta_lims[0] < args.theta_lims[1]
    ), "theta_lims: lower limit must be < upper limit."
    if args.only is not None:
        for pdg in args.only:
            assert pdg in PDG_CODES, f"Given PDG code {pdg} not understood."

    if args.resume == "use_output":
        args.resume = args.output

    return args
527 """Main network training logic."""
529 args = get_parser().parse_args()
530 args = validate_args(args)
532 print(
"Welcome to the network trainer.")
533 print(f
"Data will be read from {args.input}.")
534 if args.resume
is not None:
535 print(f
"Model training will be continued from the saved state at {args.resume}")
536 print(f
"The trained model will be saved to {args.output}.")
538 if not (args.p_lims[0] == -np.inf
and args.p_lims[0] == +np.inf):
540 f
"Model will be trained on data with momentum in "
541 f
"[{args.p_lims[0]}, {args.p_lims[1]}] GeV."
544 if not (args.theta_lims[0] == -np.inf
and args.theta_lims[0] == +np.inf):
546 f
"Model will be trained on data with theta in "
547 f
"[{args.theta_lims[0]}, {args.theta_lims[1]}] degrees."
550 print(f
"The model will be trained for {args.n_epochs} epochs.")
551 print(f
"Training will use a pion scaling factor of beta = {args.beta}.")
558if __name__ ==
"__main__":