Belle II Software light-2601-hyperion
trainer.py
1#!/usr/bin/env python3
2
3
10
11import os
12import ROOT
13import argparse
14
# Select the torch backend for keras. This must happen before keras is
# imported (the keras import lives inside the __main__ block below),
# otherwise keras initializes with its default backend instead.
15os.environ["KERAS_BACKEND"] = "torch"
16
17
if __name__ == "__main__":

    # Backend-sensitive imports are deferred to here so that the
    # KERAS_BACKEND environment variable (set at module level) is already
    # in place when keras is first imported.
    import torch
    import keras
    from fitter import fit
    import tflat.utils as utils
    from tflat.model import get_tflat_model
    from basf2_mva_util import create_onnx_mva_weightfile

    # parse cli arguments
    parser = argparse.ArgumentParser(description='Train TFlat')
    parser.add_argument(  # input parser
        '--train_input',
        metavar='train_input',
        dest='train_input',
        type=str,
        default="dummyin_train.parquet",
        help='Path to training parquet file'
    )
    parser.add_argument(  # input parser
        '--val_input',
        metavar='val_input',
        dest='val_input',
        type=str,
        default="dummyin_val.parquet",
        help='Path to validation parquet file'
    )
    parser.add_argument(  # input parser
        '--uniqueIdentifier',
        metavar='uniqueIdentifier',
        dest='uniqueIdentifier',
        type=str,
        default="TFlaT_MC16rd_light_2601_hyperion",
        help='Name of the config .yaml to be used and the produced weightfile'
    )
    parser.add_argument(  # checkpoint parser
        '--checkpoint',
        metavar='checkpoint',
        dest='checkpoint',
        type=str,
        # FIX: the original declared nargs='+', which turned any
        # user-supplied value into a list and broke every downstream use
        # (os.path.isfile, os.remove, keras.models.load_model, fit) that
        # expects a single path string; only the string default worked.
        default="./ckpt/checkpoint.model.keras",
        help='Path to checkpoints'
    )
    parser.add_argument(
        '--warmstart',
        help='Start from checkpoint',
        action=argparse.BooleanOptionalAction
    )
    args = parser.parse_args()

    train_file = args.train_input
    val_file = args.val_input
    checkpoint_filepath = args.checkpoint
    warmstart = args.warmstart
    uniqueIdentifier = args.uniqueIdentifier

    # Load the training configuration named after the unique identifier.
    config = utils.load_config(uniqueIdentifier)
    parameters = config['parameters']
    # Variable used to rank candidates within each particle list
    # (momentum, presumably — confirm against tflat.utils.get_variables).
    rank_variable = 'p'
    trk_variable_list = config['trk_variable_list']
    ecl_variable_list = config['ecl_variable_list']
    roe_variable_list = config['roe_variable_list']
    # Build the flat feature list: per-candidate variables for tracks, ECL
    # clusters and ROE tracks, repeated for the configured number of
    # candidates of each type.
    variables = utils.get_variables('pi+:tflat', rank_variable, trk_variable_list, particleNumber=parameters['num_trk'])
    variables += utils.get_variables('gamma:tflat', rank_variable, ecl_variable_list, particleNumber=parameters['num_ecl'])
    variables += utils.get_variables('pi+:tflat', rank_variable, roe_variable_list, particleNumber=parameters['num_roe'])

    if not warmstart:
        # Fresh training: remove any stale checkpoint so the checkpoint
        # callback inside fit() starts from scratch.
        if os.path.isfile(checkpoint_filepath):
            os.remove(checkpoint_filepath)

        model = get_tflat_model(parameters=parameters, number_of_features=len(variables))

        # configure the optimizer: cosine-annealed learning rate with AdamW
        cosine_decay_scheduler = keras.optimizers.schedules.CosineDecay(
            initial_learning_rate=config['initial_learning_rate'],
            decay_steps=config['decay_steps'],
            alpha=config['alpha']
        )

        optimizer = keras.optimizers.AdamW(
            learning_rate=cosine_decay_scheduler, weight_decay=config['weight_decay']
        )

        # compile the model
        model.compile(
            optimizer=optimizer,
            loss=keras.losses.binary_crossentropy,
            metrics=[
                'accuracy',
                keras.metrics.AUC()])
    else:
        # Warm start: resume training from the saved checkpoint.
        model = keras.models.load_model(checkpoint_filepath)

    model.summary()

    fit(
        model,
        train_file,
        val_file,
        "tflat_variables",
        variables,
        "qrCombined",
        config,
        checkpoint_filepath
    )

    # Export the trained model to ONNX, tracing with a dummy single-event
    # input of the full feature width.
    torch.onnx.export(
        model,
        (torch.randn(1, len(variables)),),
        "model.onnx",
        input_names=["input"],
        output_names=["output"],
    )

    # Wrap the ONNX model in an MVA weightfile and upload it to the
    # conditions database under the unique identifier.
    weightfile = create_onnx_mva_weightfile(
        "model.onnx",
        variables=variables,
        target_variable="qrCombined",
    )

    ROOT.Belle2.MVA.Weightfile.saveToDatabase(weightfile, uniqueIdentifier)