"""Methods and a script for training PID calibration weights.

Sample usage:

    $ python pidTrainWeights.py data/ models/net.pt -n 100

Use `python pidTrainWeights.py -h` to see all command-line options.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from os import makedirs
from os.path import join, dirname
from tqdm.auto import tqdm


def _make_const_lists():
    """Moving this code into a function to avoid a top-level ROOT import."""
    import ROOT  # local import: ROOT is only needed when this is called

    PARTICLES, PDG_CODES = [], []
    for i in range(len(ROOT.Belle2.Const.chargedStableSet)):
        particle = ROOT.Belle2.Const.chargedStableSet.at(i)
        name = (particle.__repr__()[7:-1]
                .replace("+", "")
                .replace("-", "")
                .replace("euteron", ""))
        PARTICLES.append(name)
        PDG_CODES.append(particle.getPDGCode())

    DETECTORS = []
    for det in ROOT.Belle2.Const.PIDDetectors.set():
        DETECTORS.append(ROOT.Belle2.Const.parseDetectors(det))

    return PARTICLES, PDG_CODES, DETECTORS
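

# The hardcoded constants below match what _make_const_lists() returns, so the
# module can be imported without ROOT. To derive them from ROOT instead, one
# could use (a sketch):
#
#     PARTICLES, PDG_CODES, DETECTORS = _make_const_lists()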
PARTICLES = ["e", "mu", "pi", "K", "p", "d"]
PDG_CODES = [11, 13, 211, 321, 2212, 1000010020]
DETECTORS = ["SVD", "CDC", "TOP", "ARICH", "ECL", "KLM"]
63 """PyTorch architecture for training calibration weights."""
65 def __init__(self, n_class=6, n_detector=6, const_init=1):
66 """Initialize the network for training.
69 n_class (int, optional): Number of classification classes (particle
70 types). Defaults to 6.
71 n_detector (int, optional): Number of detectors. Defaults to 6.
72 const_init (int, optional): Constant value to initialize all
73 weights. If None, PyTorch
's default random initialization is
74 used instead. Defaults to 1.
85 self.fcs = nn.ModuleList(
89 if const_init
is not None:
93 """Network's forward methods. Sums the detector log-likelihoods for each particle
94 type, then computes the likelihood ratios. Uses the weights.
97 x (torch.Tensor): Input detector log-likelihood data. Should be of
98 shape (N, n_detector * n_class), where N is the number of samples.
101 torch.Tensor: Weighted likelihood ratios.
104 outs = [self.fcs[i](x[:, i * n: (i + 1) * n]) for i
in range(self.
n_class)]
105 out = torch.cat(outs, dim=1)
106 return F.softmax(out, dim=1)
109 """Returns the weights as a six-by-six array or tensor.
112 to_numpy (bool, optional): Whether to return the weights
as a numpy
113 array (
True)
or torch tensor (
False). Defaults to
True.
116 np.array
or torch.Tensor: The six-by-six matrix containing the
119 weights = torch.cat([fc.weight.detach() for fc
in self.
fcs])
121 return weights.cpu().numpy()
126 """Fill all the weights with the given value.
129 const (float): Constant value to fill all weights with.
131 with torch.no_grad():
133 fc.weight.fill_(const)
136 """Fill all the weights with values sampled from a Normal distribution
137 with given mean
and standard deviation.
140 mean (float, optional): The mean of the Normal distribution.
142 std (float, optional): The standard deviation of the Normal
143 distribution. Defaults to 0.5.
145 with torch.no_grad():
148 fc.weight.add_(torch.normal(mean=mean, std=std, size=fc.weight.size()))
151 """Kills weights corresponding to unused particle types.
154 only (list(str) or None): List of allowed particle types. The
155 weights corresponding to any particle types that are _not_
in
156 this list will be filled
with zero
and be frozen (e.g. gradients
157 will
not be computed/updated).
162 for i, pdg
in enumerate(PDG_CODES):
165 self.
fcs[i].weight.requires_grad =
False
166 self.
fcs[i].weight.fill_(1)
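

# A minimal usage sketch of WeightNet (illustrative only; not executed by this
# script). With the default six particle types and six detectors, the network
# maps a batch of flattened log-likelihoods of shape (N, 36) to softmax
# likelihood ratios of shape (N, 6); the batch size 8 below is arbitrary.
#
#     net = WeightNet()
#     x = torch.rand(8, 6 * 6)              # (N, n_detector * n_class)
#     ratios = net(x)                       # shape (8, 6); rows sum to 1
#     wgt = net.get_weights(to_numpy=True)  # (6, 6) matrix of weights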


def load_training_data(directory, p_lims=None, theta_lims=None, device=None):
    """Loads training and validation data within the given momentum and theta
    limits (if given).

    Args:
        directory (str): Directory containing the train and validation sets.
        p_lims (tuple(float), optional): Minimum and maximum momentum. Defaults
            to None.
        theta_lims (tuple(float), optional): Minimum and maximum theta in
            degrees. Defaults to None.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.

    Returns:
        torch.Tensor: Training log-likelihood data.
        torch.Tensor: Training labels.
        torch.Tensor: Validation log-likelihood data.
        torch.Tensor: Validation labels.
    """
    p_lo, p_hi = p_lims if p_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = theta_lims if theta_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = np.radians(t_lo), np.radians(t_hi)

    def _load(filename):
        data = np.load(filename)
        X, y, p, t = data["X"], data["y"], data["p"], data["theta"]
        mask = np.logical_and.reduce([p >= p_lo, p <= p_hi, t >= t_lo, t <= t_hi])
        X = torch.tensor(X[mask]).to(device=device, dtype=torch.float)
        y = torch.tensor(y[mask]).to(device=device, dtype=torch.long)
        return X, y

    X_tr, y_tr = _load(join(directory, "train.npz"))
    X_va, y_va = _load(join(directory, "val.npz"))
    return X_tr, y_tr, X_va, y_va
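

# load_training_data expects `train.npz` and `val.npz` files with the keys
# "X", "y", "p", and "theta" (theta stored in radians, since the limits are
# converted with np.radians before masking). An illustrative sketch of a
# compatible file, with made-up shapes and values:
#
#     rng = np.random.default_rng()
#     np.savez(
#         "train.npz",
#         X=rng.normal(size=(1000, 36)),           # detector log-likelihoods
#         y=rng.integers(0, 6, size=1000),         # particle-type labels
#         p=rng.uniform(0, 5, size=1000),          # momentum in GeV
#         theta=rng.uniform(0.3, 2.6, size=1000),  # polar angle in radians
#     )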


def load_checkpoint(filename, device=None, only=None):
    """Loads training from a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.
        only (list(int), optional): List of allowed particle types, given by
            PDG code. Defaults to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: The epoch number to resume training from.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from
            every tenth epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from each
            epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from
            every tenth epoch.
    """
    checkpoint = torch.load(filename, map_location=torch.device("cpu"))
    net = WeightNet()
    net.load_state_dict(checkpoint["model_state_dict"])
    net.kill_unused(only)
    net.to(device=device)

    opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)
    opt.load_state_dict(checkpoint["optimizer_state_dict"])

    return (
        net,
        opt,
        checkpoint["epoch"],
        checkpoint["loss_t"],
        checkpoint["loss_v"],
        checkpoint["accu_t"],
        checkpoint["accu_v"],
    )


def save_checkpoint(filename, net, opt, epoch, loss_t, loss_v, accu_t, accu_v):
    """Saves training to a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        net (WeightNet): The network.
        opt (optim.Optimizer): The optimizer.
        epoch (int): The current epoch number.
        loss_t (dict(str -> list(float))): Training losses (diag, pion, sum)
            from each epoch.
        loss_v (dict(str -> list(float))): Validation losses (diag, pion, sum)
            from every tenth epoch.
        accu_t (dict(str -> list(float))): Training accuracies (net, pion)
            from each epoch.
        accu_v (dict(str -> list(float))): Validation accuracies (net, pion)
            from every tenth epoch.
    """
    net.cpu()
    makedirs(dirname(filename), exist_ok=True)
    torch.save(
        {
            "model_state_dict": net.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
            "epoch": epoch,
            "loss_t": loss_t,
            "loss_v": loss_v,
            "accu_t": accu_t,
            "accu_v": accu_v,
        },
        filename,
    )


def initialize(args, device=None):
    """Initializes training from the parsed command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: The epoch number to start training at (nonzero when resuming).
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from
            every tenth epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from each
            epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from
            every tenth epoch.
    """
    if args.resume is not None:
        net, opt, epochs_0, l_t, l_v, a_t, a_v = load_checkpoint(args.resume, device=device, only=args.only)
    else:
        net = WeightNet()
        if args.random:
            net.random_init()
        net.kill_unused(args.only)
        net.to(device=device)
        opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)

        epochs_0 = 0
        l_t = {"diag": [], "pion": [], "sum": []}
        l_v = {"diag": [], "pion": [], "sum": []}
        a_t = {"net": [], "pion": []}
        a_v = {"net": [], "pion": []}

    return net, opt, epochs_0, l_t, l_v, a_t, a_v
322 """Trains and saves a model according to the command-line arguments.
325 args (argparse.Namespace): Parsed command-line arguments.
326 use_tqdm (bool, optional): Use TQDM to track progress. Defaults to True.
328 print("Reading data.")
329 device = torch.device(
"cuda" if torch.cuda.is_available()
else "cpu")
330 print(
"...and moving to", device)
331 X_tr, y_tr, X_va, y_va = load_training_data(
332 args.input, device=device, p_lims=args.p_lims, theta_lims=args.theta_lims
335 if len(y_tr) < 10
or len(y_va) < 10:
336 print(
"Not enough data. Aborting...")
339 print(f
"{len(y_tr)} train events, {len(y_va)} val events")
341 print(
"Initializing network.")
342 net, opt, epochs_0, loss_t, loss_v, accu_t, accu_v = initialize(args, device=device)
344 diag_lfn = nn.CrossEntropyLoss()
345 pion_lfn = nn.BCELoss()

    def compute_accuracies(out, y):
        output = out.detach().cpu().numpy()
        target = y.detach().cpu().numpy()
        pred = np.squeeze(np.argmax(output, axis=1))
        accu = np.count_nonzero(pred == target) / len(pred)
        pi_out = output[(target == 2), 2]
        pi_pred = (pi_out > 0.5).astype(float)
        pi_accu = pi_pred.sum() / len(pi_pred)
        return accu, pi_accu

    def lfn(output, target):
        diag = diag_lfn(output, target)
        pi_mask = target == 2
        pi_out = output[pi_mask, 2]
        pi_y = (target[pi_mask] == 2).float()
        pion = pion_lfn(pi_out, pi_y)
        return diag + pion_wgt * pion, diag, pion
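
    # In other words, the loss used below is
    #     L = CrossEntropy(out, y) + beta * BCE(out[pi], 1),
    # i.e. the multi-class cross entropy over all six particle types plus a
    # pion-only binary cross entropy term scaled by args.beta, which gives
    # extra weight to correct pion identification.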

    print(f"Training network for {args.n_epochs} epochs.")

    iterator = range(args.n_epochs)
    if use_tqdm:
        iterator = tqdm(iterator)

    for epoch in iterator:
        # train step
        opt.zero_grad()
        out = net(X_tr)
        loss, diag, pion = lfn(out, y_tr)
        loss.backward()
        opt.step()

        # record training losses
        loss_t["diag"].append(diag.item())
        loss_t["pion"].append(pion.item())
        loss_t["sum"].append(loss.item())

        # record training accuracies
        accu, pi_accu = compute_accuracies(out, y_tr)
        accu_t["net"].append(accu)
        accu_t["pion"].append(pi_accu)

        # validation step (every tenth epoch)
        if epoch % 10 == 0:
            with torch.no_grad():
                out = net(X_va)
                loss, diag, pion = lfn(out, y_va)

                # record validation losses
                loss_v["diag"].append(diag.item())
                loss_v["pion"].append(pion.item())
                loss_v["sum"].append(loss.item())

                # record validation accuracies
                accu, pi_accu = compute_accuracies(out, y_va)
                accu_v["net"].append(accu)
                accu_v["pion"].append(pi_accu)

    print("Training complete.")

    save_checkpoint(
        args.output, net, opt, epochs_0 + args.n_epochs, loss_t, loss_v, accu_t, accu_v,
    )

    wgt = net.get_weights(to_numpy=True)
    np.save(args.output.replace(".pt", "_wgt.npy"), wgt)

    print(f"Model saved to {args.output}.")
420 """Handles the command-line argument parsing.
423 argparse.Namespace: The parsed arguments.
425 from argparse
import ArgumentParser
427 ap = ArgumentParser(description=
"", epilog=
"")
429 "input", type=str, help=
"Path to folder with training files (in .npz format).",
434 help=
"Output filename where model will be saved (should end in .pt).",
441 help=
"Number of epochs to train the network. Defaults to 500.",
446 default=[-float(
'inf'), +float(
'inf')],
448 "Lower and upper limits for momentum in GeV. Lower limit "
449 "should be given first. Default values are -inf, +inf."
455 default=[-float(
'inf'), +float(
'inf')],
457 "Lower and upper limits for theta in degrees. Lower limit "
458 "should be given first. Default values are -inf, +inf."
469 "Load a pre-existing model and resume training instead of "
470 "starting a new one. If '--resume' is used and no file is "
471 "specified, the output filepath will be loaded, and the "
472 "training will overwrite that file. Alternatively, if a "
473 "filepath is given, training will begin from the state "
474 "saved in that file and will be saved to the output filepath."
482 "Use only log-likelihood data from a subset of particle "
483 "types specified by PDG code. If left unspecified, all "
484 "particle types will be used."
491 "Initialize network weights to random values, normally "
492 "distributed with mean of 1 and width of 0.5. Has no effect "
493 "if 'resume' is used."
501 "Scaling factor for the pion binary cross entropy term in "
502 "the loss function. Defaults to 0.1."


def validate_args(args):
    """Validates and processes the parsed command-line arguments."""
    assert args.n_epochs > 0, "Number of epochs must be larger than 0."
    assert args.p_lims[0] < args.p_lims[1], "p_lims: lower limit must be < upper limit."
    assert (
        args.theta_lims[0] < args.theta_lims[1]
    ), "theta_lims: lower limit must be < upper limit."
    if args.only is not None:
        for pdg in args.only:
            assert pdg in PDG_CODES, f"Given PDG code {pdg} not understood."

    if args.resume == "use_output":
        args.resume = args.output

    return args
526 """Main network training logic."""
528 args = get_parser().parse_args()
529 args = validate_args(args)
531 print(
"Welcome to the network trainer.")
532 print(f
"Data will be read from {args.input}.")
533 if args.resume
is not None:
534 print(f
"Model training will be continued from the saved state at {args.resume}")
535 print(f
"The trained model will be saved to {args.output}.")
537 if not (args.p_lims[0] == -np.inf
and args.p_lims[0] == +np.inf):
539 f
"Model will be trained on data with momentum in "
540 f
"[{args.p_lims[0]}, {args.p_lims[1]}] GeV."
543 if not (args.theta_lims[0] == -np.inf
and args.theta_lims[0] == +np.inf):
545 f
"Model will be trained on data with theta in "
546 f
"[{args.theta_lims[0]}, {args.theta_lims[1]}] degrees."
549 print(f
"The model will be trained for {args.n_epochs} epochs.")
550 print(f
"Training will use a pion scaling factor of beta = {args.beta}.")


if __name__ == "__main__":
    main()