Belle II Software development
trainer.py
1#!/usr/bin/env python3
2
3
10
11import os
12import ROOT
13import argparse
14
15os.environ["KERAS_BACKEND"] = "torch"
16
17
if __name__ == "__main__":

    # Heavy imports are deferred until we know we are actually running as a
    # script (keras picks up KERAS_BACKEND="torch" set at module import time).
    import torch
    import keras
    from fitter import fit
    import tflat.utils as utils
    from tflat.model import get_tflat_model
    from basf2_mva_util import create_onnx_mva_weightfile

    # --- command-line interface -----------------------------------------
    parser = argparse.ArgumentParser(description='Train TFlat')
    parser.add_argument(
        '--train_input',
        dest='train_input',
        type=str,
        default="dummyin_train.parquet",
        help='Path to training parquet file'
    )
    parser.add_argument(
        '--val_input',
        dest='val_input',
        type=str,
        default="dummyin_val.parquet",
        help='Path to validation parquet file'
    )
    parser.add_argument(
        '--uniqueIdentifier',
        dest='uniqueIdentifier',
        type=str,
        default="TFlaT_MC16rd_light_2601_hyperion",
        help='Name of the config .yaml to be used and the produced weightfile'
    )
    # BUGFIX: the original declared nargs='+' here, which turns
    # args.checkpoint into a *list* whenever the option is passed on the
    # command line, breaking os.path.isfile()/os.remove() and
    # keras.models.load_model() below (they all expect a single path, and
    # the default is a plain string). A single string argument is what the
    # rest of the script actually uses.
    parser.add_argument(
        '--checkpoint',
        dest='checkpoint',
        type=str,
        default="./ckpt/checkpoint.model.keras",
        help='Path to the checkpoint file'
    )
    parser.add_argument(
        '--warmstart',
        help='Start from checkpoint',
        action=argparse.BooleanOptionalAction
    )
    args = parser.parse_args()

    train_file = args.train_input
    val_file = args.val_input
    checkpoint_filepath = args.checkpoint
    warmstart = args.warmstart
    uniqueIdentifier = args.uniqueIdentifier

    # Load the YAML config named by uniqueIdentifier and assemble the full
    # flattened feature list: per-particle variables for tracks, ECL
    # clusters and ROE tracks, each ranked by momentum ('p') and truncated
    # to the configured particle multiplicities.
    config = utils.load_config(uniqueIdentifier)
    parameters = config['parameters']
    rank_variable = 'p'
    trk_variable_list = config['trk_variable_list']
    ecl_variable_list = config['ecl_variable_list']
    roe_variable_list = config['roe_variable_list']
    variables = utils.get_variables('pi+:tflat', rank_variable, trk_variable_list, particleNumber=parameters['num_trk'])
    variables += utils.get_variables('gamma:tflat', rank_variable, ecl_variable_list, particleNumber=parameters['num_ecl'])
    variables += utils.get_variables('pi+:tflat', rank_variable, roe_variable_list, particleNumber=parameters['num_roe'])

    if not warmstart:
        # Fresh training run: discard any stale checkpoint so the callback
        # in fit() starts clean, then build and compile a new model.
        if os.path.isfile(checkpoint_filepath):
            os.remove(checkpoint_filepath)

        model = get_tflat_model(parameters=parameters, number_of_features=len(variables))

        # Cosine-annealed learning-rate schedule driving an AdamW optimizer.
        cosine_decay_scheduler = keras.optimizers.schedules.CosineDecay(
            initial_learning_rate=config['initial_learning_rate'],
            decay_steps=config['decay_steps'],
            alpha=config['alpha']
        )

        optimizer = keras.optimizers.AdamW(
            learning_rate=cosine_decay_scheduler, weight_decay=config['weight_decay']
        )

        # Binary target ("qrCombined") -> binary cross-entropy loss;
        # accuracy/AUC/MSE are tracked for monitoring only.
        model.compile(
            optimizer=optimizer,
            loss=keras.losses.binary_crossentropy,
            metrics=[
                'accuracy',
                keras.metrics.AUC(),
                keras.metrics.MeanSquaredError()])
    else:
        # Warm start: resume training from the saved checkpoint (the
        # compiled optimizer state is restored by load_model).
        model = keras.models.load_model(checkpoint_filepath)

    model.summary()

    fit(
        model,
        train_file,
        val_file,
        "tflat_variables",
        variables,
        "qrCombined",
        config,
        checkpoint_filepath
    )

    # Export the trained model to ONNX. With KERAS_BACKEND="torch" (set at
    # the top of this file) the keras model is traceable by torch.onnx; a
    # dummy (1, n_features) input fixes the graph's input shape.
    torch.onnx.export(
        model,
        (torch.randn(1, len(variables)),),
        "model.onnx",
        input_names=["input"],
        output_names=["output"],
    )

    # Wrap the ONNX file into a basf2 MVA weightfile and upload it to the
    # conditions database under the unique identifier.
    weightfile = create_onnx_mva_weightfile(
        "model.onnx",
        variables=variables,
        target_variable="qrCombined",
    )

    ROOT.Belle2.MVA.Weightfile.saveToDatabase(weightfile, uniqueIdentifier)