"""Methods and a script for training PID calibration weights.

Sample usage:

    $ python pidTrainWeights.py data/ models/net.pt -n 100

Use `python pidTrainWeights.py -h` to see all command-line options.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from os import makedirs
from os.path import join, dirname
from tqdm.auto import tqdm
 
def _make_const_lists():
    """Moving this code into a function to avoid a top-level ROOT import."""
    from ROOT import Belle2  # noqa: makes the Belle2 namespace available
    import ROOT.Belle2

    PARTICLES, PDG_CODES = [], []
    for i in range(len(ROOT.Belle2.Const.chargedStableSet)):
        particle = ROOT.Belle2.Const.chargedStableSet.at(i)
        # extract a short particle name from the repr string,
        # e.g. 'deuteron' becomes 'd'
        name = (particle.__repr__()[7:-1]
                .split(":")[1]
                .replace(" ", "")
                .replace("+", "")
                .replace("-", "")
                .replace("euteron", ""))
        PARTICLES.append(name)
        PDG_CODES.append(particle.getPDGCode())

    DETECTORS = []
    for det in ROOT.Belle2.Const.PIDDetectors.set():
        DETECTORS.append(ROOT.Belle2.Const.parseDetectors(det))

    return PARTICLES, PDG_CODES, DETECTORS
 
PARTICLES = ["e", "mu", "pi", "K", "p", "d"]
PDG_CODES = [11, 13, 211, 321, 2212, 1000010020]
DETECTORS = ["SVD", "CDC", "TOP", "ARICH", "ECL", "KLM"]
 
class WeightNet(nn.Module):
    """PyTorch architecture for training calibration weights."""

    def __init__(self, n_class=6, n_detector=6, const_init=1):
        """Initialize the network for training.

        Args:
            n_class (int, optional): Number of classification classes (particle
                types). Defaults to 6.
            n_detector (int, optional): Number of detectors. Defaults to 6.
            const_init (int, optional): Constant value to initialize all
                weights. If None, PyTorch's default random initialization is
                used instead. Defaults to 1.
        """
        super().__init__()

        #: number of particle types
        self.n_class = n_class

        #: number of detectors
        self.n_detector = n_detector

        #: linear layers for each particle type
        self.fcs = nn.ModuleList(
            [nn.Linear(self.n_detector, 1, bias=False) for _ in range(self.n_class)]
        )

        if const_init is not None:
            self.const_init(const_init)
 
    def forward(self, x):
        """Network's forward method. Sums the detector log-likelihoods for each
        particle type, then computes the likelihood ratios using the weights.

        Args:
            x (torch.Tensor): Input detector log-likelihood data. Should be of
                shape (N, n_detector * n_class), where N is the number of samples.

        Returns:
            torch.Tensor: Weighted likelihood ratios.
        """
        n = self.n_detector
        outs = [self.fcs[i](x[:, i * n: (i + 1) * n]) for i in range(self.n_class)]
        out = torch.cat(outs, dim=1)
        return F.softmax(out, dim=1)
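    # A minimal sanity check of the forward pass (illustrative, assuming the
    # default 6 particle types and 6 detectors, i.e. 36 input columns):
    #
    #     net = WeightNet()
    #     x = torch.randn(5, 36)   # 5 samples of detector log-likelihoods
    #     out = net(x)             # shape (5, 6); each row sums to 1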
 
 
    def get_weights(self, to_numpy=True):
        """Returns the weights as a six-by-six array or tensor.

        Args:
            to_numpy (bool, optional): Whether to return the weights as a numpy
                array (True) or torch tensor (False). Defaults to True.

        Returns:
            np.array or torch.Tensor: The six-by-six matrix containing the
                weights.
        """
        weights = torch.cat([fc.weight.detach() for fc in self.fcs])
        if to_numpy:
            return weights.cpu().numpy()
        else:
            return weights
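    # Note: with the defaults this is a (6, 6) matrix in which row i holds the
    # detector weights for PARTICLES[i]; assuming the input columns follow the
    # DETECTORS ordering, column j corresponds to DETECTORS[j].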
 
 
    def const_init(self, const):
        """Fill all the weights with the given value.

        Args:
            const (float): Constant value to fill all weights with.
        """
        with torch.no_grad():
            for fc in self.fcs:
                fc.weight.fill_(const)
 
 
    def random_init(self, mean=1.0, std=0.5):
        """Fill all the weights with values sampled from a Normal distribution
        with the given mean and standard deviation.

        Args:
            mean (float, optional): The mean of the Normal distribution.
                Defaults to 1.0.
            std (float, optional): The standard deviation of the Normal
                distribution. Defaults to 0.5.
        """
        with torch.no_grad():
            for fc in self.fcs:
                fc.weight.fill_(0)
                fc.weight.add_(torch.normal(mean=mean, std=std, size=fc.weight.size()))
 
 
    def kill_unused(self, only):
        """Kills weights corresponding to unused particle types.

        Args:
            only (list(int) or None): List of allowed particle types, given by
                PDG code. The weights corresponding to any particle types that
                are _not_ in this list will be frozen (i.e. gradients will not
                be computed/updated) and fixed at a value of one.
        """
        if only is not None:
            for i, pdg in enumerate(PDG_CODES):
                if pdg in only:
                    continue
                self.fcs[i].weight.requires_grad = False
                self.fcs[i].weight.fill_(1)
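    # For example (hypothetical), training only the pion and kaon weights:
    #
    #     net = WeightNet()
    #     net.kill_unused([211, 321])  # e, mu, p, d weights frozen at 1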
 
 
 
def load_training_data(directory, p_lims=None, theta_lims=None, device=None):
    """Loads training and validation data within the given momentum and theta
    limits (if given).

    Args:
        directory (str): Directory containing the train and validation sets.
        p_lims (tuple(float), optional): Minimum and maximum momentum. Defaults
            to None.
        theta_lims (tuple(float), optional): Minimum and maximum theta in
            degrees. Defaults to None.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.

    Returns:
        torch.Tensor: Training log-likelihood data.
        torch.Tensor: Training labels.
        torch.Tensor: Validation log-likelihood data.
        torch.Tensor: Validation labels.
    """
    p_lo, p_hi = p_lims if p_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = theta_lims if theta_lims is not None else (-np.inf, +np.inf)
    t_lo, t_hi = np.radians(t_lo), np.radians(t_hi)

    def _load(filename):
        data = np.load(filename)
        X, y, p, t = data["X"], data["y"], data["p"], data["theta"]
        mask = np.logical_and.reduce([p >= p_lo, p <= p_hi, t >= t_lo, t <= t_hi])
        X = torch.tensor(X[mask]).to(device=device, dtype=torch.float)
        y = torch.tensor(y[mask]).to(device=device, dtype=torch.long)
        return X, y

    X_tr, y_tr = _load(join(directory, "train.npz"))
    X_va, y_va = _load(join(directory, "val.npz"))
    return X_tr, y_tr, X_va, y_va
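# A sketch of the assumed .npz layout (not a spec): each of train.npz and
# val.npz provides "X" of shape (N, n_detector * n_class) with the detector
# log-likelihoods, "y" with integer labels indexing PARTICLES, and "p" and
# "theta" with the per-event momentum (GeV) and polar angle (radians) used
# for the cuts above.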
 
def load_checkpoint(filename, device=None, only=None):
    """Loads training from a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.
        only (list(int), optional): List of allowed particle types. Defaults to
            None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: The epoch number to resume from.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from every
            epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from each
            tenth epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from every
            tenth epoch.
    """
    checkpoint = torch.load(filename, map_location=torch.device("cpu"))
    net = WeightNet()
    net.load_state_dict(checkpoint["model_state_dict"])
    net.kill_unused(only)
    net.to(device=device)

    opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)
    opt.load_state_dict(checkpoint["optimizer_state_dict"])

    return (
        net,
        opt,
        checkpoint["epoch"] + 1,
        checkpoint["loss_t"],
        checkpoint["loss_v"],
        checkpoint["accu_t"],
        checkpoint["accu_v"],
    )
 
def save_checkpoint(filename, net, opt, epoch, loss_t, loss_v, accu_t, accu_v):
    """Saves training to a checkpoint.

    Args:
        filename (str): Checkpoint filename.
        net (WeightNet): The network.
        opt (optim.Optimizer): The optimizer.
        epoch (int): The current epoch number.
        loss_t (dict(str -> list(float))): Training losses (diag, pion, sum)
            from each epoch.
        loss_v (dict(str -> list(float))): Validation losses (diag, pion, sum)
            from every epoch.
        accu_t (dict(str -> list(float))): Training accuracies (net, pion) from
            each tenth epoch.
        accu_v (dict(str -> list(float))): Validation accuracies (net, pion)
            from every tenth epoch.
    """
    # guard against a bare filename with no directory component
    if dirname(filename):
        makedirs(dirname(filename), exist_ok=True)
    torch.save(
        {
            "model_state_dict": net.state_dict(),
            "optimizer_state_dict": opt.state_dict(),
            "epoch": epoch,
            "loss_t": loss_t,
            "loss_v": loss_v,
            "accu_t": accu_t,
            "accu_v": accu_v,
        },
        filename,
    )
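# A checkpoint written by save_checkpoint can be restored with load_checkpoint,
# e.g. (illustrative filename):
#
#     net, opt, epoch, l_t, l_v, a_t, a_v = load_checkpoint("models/net.pt")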
 
def initialize(args, device=None):
    """Initializes training from the parsed command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        device (torch.device, optional): Device to move the data onto. Defaults
            to None.

    Returns:
        WeightNet: The network.
        optim.Optimizer: The optimizer.
        int: The epoch number to start counting from.
        dict(str -> list(float)): Training losses (diag, pion, sum) from each
            epoch.
        dict(str -> list(float)): Validation losses (diag, pion, sum) from every
            epoch.
        dict(str -> list(float)): Training accuracies (net, pion) from each
            tenth epoch.
        dict(str -> list(float)): Validation accuracies (net, pion) from every
            tenth epoch.
    """
    if args.resume is not None:
        net, opt, epochs_0, l_t, l_v, a_t, a_v = load_checkpoint(args.resume, device=device, only=args.only)
    else:
        net = WeightNet()
        if args.random_init:
            net.random_init()
        net.kill_unused(args.only)
        net.to(device=device)
        opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-4)

        epochs_0 = 0
        l_t = {"diag": [], "pion": [], "sum": []}
        l_v = {"diag": [], "pion": [], "sum": []}
        a_t = {"net": [], "pion": []}
        a_v = {"net": [], "pion": []}

    return net, opt, epochs_0, l_t, l_v, a_t, a_v
 
def train_model(args, use_tqdm=True):
    """Trains and saves a model according to the command-line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        use_tqdm (bool, optional): Use TQDM to track progress. Defaults to True.
    """
    print("Reading data.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("...and moving to", device)
    X_tr, y_tr, X_va, y_va = load_training_data(
        args.input, device=device, p_lims=args.p_lims, theta_lims=args.theta_lims
    )

    if len(y_tr) < 10 or len(y_va) < 10:
        print("Not enough data. Aborting...")
        return

    print(f"{len(y_tr)} train events, {len(y_va)} val events")

    print("Initializing network.")
    net, opt, epochs_0, loss_t, loss_v, accu_t, accu_v = initialize(args, device=device)
 
    diag_lfn = nn.CrossEntropyLoss()
    pion_lfn = nn.BCELoss()
    pion_wgt = args.beta

    def compute_accuracies(out, y):
        output = out.detach().cpu().numpy()
        target = y.detach().cpu().numpy()
        pred = np.squeeze(np.argmax(output, axis=1))
        accu = np.count_nonzero(pred == target) / len(pred)
        # index 2 is the pion slot in PARTICLES/PDG_CODES
        pi_out = output[(target == 2), 2]
        pi_pred = (pi_out > 0.5).astype(float)
        pi_accu = pi_pred.sum() / len(pi_pred)
        return accu, pi_accu

    def lfn(output, target):
        diag = diag_lfn(output, target)
        pi_mask = target == 2
        pi_out = output[pi_mask, 2]
        pi_y = (target[pi_mask] == 2).float()
        pion = pion_lfn(pi_out, pi_y)
        return diag + pion_wgt * pion, diag, pion
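    # The combined objective returned by lfn is
    #     loss = diag + beta * pion,
    # e.g. with the default beta = 0.1, diag = 1.20 and pion = 0.50 give
    # loss = 1.20 + 0.1 * 0.50 = 1.25.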
 
    print(f"Training network for {args.n_epochs} epochs.")

    iterator = range(args.n_epochs)
    if use_tqdm:
        iterator = tqdm(iterator)

    for epoch in iterator:
 
        # train step
        opt.zero_grad()
        out = net(X_tr)
        loss, diag, pion = lfn(out, y_tr)
        loss.backward()
        opt.step()

        # record training losses
        loss_t["diag"].append(diag.item())
        loss_t["pion"].append(pion.item())
        loss_t["sum"].append(loss.item())

        # record training accuracies every tenth epoch
        if epoch % 10 == 0:
            accu, pi_accu = compute_accuracies(out, y_tr)
            accu_t["net"].append(accu)
            accu_t["pion"].append(pi_accu)
        # validation step
        with torch.no_grad():
            out = net(X_va)
            loss, diag, pion = lfn(out, y_va)

            # record validation losses
            loss_v["diag"].append(diag.item())
            loss_v["pion"].append(pion.item())
            loss_v["sum"].append(loss.item())

            # record validation accuracies every tenth epoch
            if epoch % 10 == 0:
                accu, pi_accu = compute_accuracies(out, y_va)
                accu_v["net"].append(accu)
                accu_v["pion"].append(pi_accu)
 
    print("Training complete.")

    save_checkpoint(
        args.output, net, opt, epochs_0 + args.n_epochs, loss_t, loss_v, accu_t, accu_v,
    )

    wgt = net.get_weights(to_numpy=True)
    np.save(args.output.replace(".pt", "_wgt.npy"), wgt)

    print(f"Model saved to {args.output}.")
 
def get_parser():
    """Handles the command-line argument parsing.

    Returns:
        argparse.ArgumentParser: The argument parser.
    """
    from argparse import ArgumentParser

    ap = ArgumentParser(description="", epilog="")
    ap.add_argument(
        "input", type=str, help="Path to folder with training files (in .npz format).",
    )
    ap.add_argument(
        "output", type=str,
        help="Output filename where model will be saved (should end in .pt).",
    )
    ap.add_argument(
        "-n", "--n_epochs", type=int, default=500,
        help="Number of epochs to train the network. Defaults to 500.",
    )
    ap.add_argument(
        "--p_lims", nargs=2, type=float, default=[-float('inf'), +float('inf')],
        help=(
            "Lower and upper limits for momentum in GeV. Lower limit "
            "should be given first. Default values are -inf, +inf."
        ),
    )
    ap.add_argument(
        "--theta_lims", nargs=2, type=float, default=[-float('inf'), +float('inf')],
        help=(
            "Lower and upper limits for theta in degrees. Lower limit "
            "should be given first. Default values are -inf, +inf."
        ),
    )
    ap.add_argument(
        "--resume", nargs="?", type=str, const="use_output", default=None,
        help=(
            "Load a pre-existing model and resume training instead of "
            "starting a new one. If '--resume' is used and no file is "
            "specified, the output filepath will be loaded, and the "
            "training will overwrite that file. Alternatively, if a "
            "filepath is given, training will begin from the state "
            "saved in that file and will be saved to the output filepath."
        ),
    )
    ap.add_argument(
        "--only", type=int, nargs="*", default=None,
        help=(
            "Use only log-likelihood data from a subset of particle "
            "types specified by PDG code. If left unspecified, all "
            "particle types will be used."
        ),
    )
    ap.add_argument(
        "--random_init", action="store_true",
        help=(
            "Initialize network weights to random values, normally "
            "distributed with mean of 1 and width of 0.5. Has no effect "
            "if 'resume' is used."
        ),
    )
    ap.add_argument(
        "--beta", type=float, default=0.1,
        help=(
            "Scaling factor for the pion binary cross entropy term in "
            "the loss function. Defaults to 0.1."
        ),
    )
    return ap
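# Example invocations (paths are illustrative; flag spellings as defined above):
#
#     $ python pidTrainWeights.py data/ models/net.pt -n 100
#     $ python pidTrainWeights.py data/ models/net.pt --only 211 321
#     $ python pidTrainWeights.py data/ models/net.pt --resume models/old.pt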
def validate_args(args):
    """Validates the parsed command-line arguments and resolves '--resume'."""
    assert args.n_epochs > 0, "Number of epochs must be larger than 0."
    assert args.p_lims[0] < args.p_lims[1], "p_lims: lower limit must be < upper limit."
    assert (
        args.theta_lims[0] < args.theta_lims[1]
    ), "theta_lims: lower limit must be < upper limit."
    if args.only is not None:
        for pdg in args.only:
            assert pdg in PDG_CODES, f"Given PDG code {pdg} not understood."

    # '--resume' with no argument means: resume from (and overwrite) the output file
    if args.resume == "use_output":
        args.resume = args.output

    return args
 
def main():
    """Main network training logic."""
    args = get_parser().parse_args()
    args = validate_args(args)

    print("Welcome to the network trainer.")
    print(f"Data will be read from {args.input}.")
    if args.resume is not None:
        print(f"Model training will be continued from the saved state at {args.resume}.")
    print(f"The trained model will be saved to {args.output}.")

    if not (args.p_lims[0] == -np.inf and args.p_lims[1] == +np.inf):
        print(
            f"Model will be trained on data with momentum in "
            f"[{args.p_lims[0]}, {args.p_lims[1]}] GeV."
        )

    if not (args.theta_lims[0] == -np.inf and args.theta_lims[1] == +np.inf):
        print(
            f"Model will be trained on data with theta in "
            f"[{args.theta_lims[0]}, {args.theta_lims[1]}] degrees."
        )

    print(f"The model will be trained for {args.n_epochs} epochs.")
    print(f"Training will use a pion scaling factor of beta = {args.beta}.")

    train_model(args)


if __name__ == "__main__":
    main()
 