# Select the PyTorch backend for Keras. Keras picks up KERAS_BACKEND from the
# environment when it is imported, so this assignment has to run before any
# `import keras` in the process.
os.environ["KERAS_BACKEND"] = "torch"
# NOTE(review): this text looks like a scraped rendering of a script rather
# than the script itself. Each statement starts with a number (presumably the
# original file's line number), statements are split across lines, and the
# numbering has gaps (e.g. 28 -> 31 -> 34, 110 -> 127), so a number of original
# lines are not visible here. The code below is left untouched; the comments
# describe only what the visible fragments show.
# TODO: restore this section from the real source file before running it.
#
# Script entry point. From the visible fragments it: parses command-line
# options, loads a config selected by `uniqueIdentifier`, assembles a variable
# list, builds a model (or reloads one from a checkpoint), and saves a
# weightfile to the database.
18if __name__ ==
"__main__":
# Imports performed inside the main guard. `fit` is imported here, but its
# call site is not among the visible lines.
22 from fitter
import fit
25 from basf2_mva_util
import create_onnx_mva_weightfile
# Command-line interface. The add_argument(...) calls themselves are not
# visible; only some of their keyword arguments are.
28 parser = argparse.ArgumentParser(description=
'Train TFlat')
# Keyword fragments for the training-input option (read below as
# args.train_input).
31 metavar=
'train_input',
34 default=
"dummyin_train.parquet",
35 help=
'Path to training parquet file'
# Keyword fragments that presumably belong to the validation-input option
# (read below as args.val_input) - its metavar/dest lines are not visible.
42 default=
"dummyin_val.parquet",
43 help=
'Path to validation parquet file'
# Keyword fragments for the uniqueIdentifier option; per the help text this
# value names both the config .yaml and the produced weightfile.
47 metavar=
'uniqueIdentifier',
48 dest=
'uniqueIdentifier',
50 default=
"TFlaT_MC16rd_light_2601_hyperion",
51 help=
'Name of the config .yaml to be used and the produced weightfile'
# Keyword fragments that presumably belong to the checkpoint-path option
# (read below as args.checkpoint).
59 default=
"./ckpt/checkpoint.model.keras",
60 help=
'Path to checkpoints'
# Keyword fragments that presumably belong to the warm-start flag (read below
# as args.warmstart). It uses argparse.BooleanOptionalAction.
64 help=
'Start from checkpoint',
65 action=argparse.BooleanOptionalAction
# Parse the command line and copy the options into local names.
67 args = parser.parse_args()
69 train_file = args.train_input
70 val_file = args.val_input
71 checkpoint_filepath = args.checkpoint
72 warmstart = args.warmstart
73 uniqueIdentifier = args.uniqueIdentifier
# Load the configuration selected by uniqueIdentifier and pull out its
# 'parameters' entry.
75 config = utils.load_config(uniqueIdentifier)
76 parameters = config[
'parameters']
# Per-category variable lists taken from the config.
78 trk_variable_list = config[
'trk_variable_list']
79 ecl_variable_list = config[
'ecl_variable_list']
80 roe_variable_list = config[
'roe_variable_list']
# Build one combined `variables` list from three utils.get_variables calls,
# one per category, each sized by the matching num_* parameter.
# NOTE(review): `rank_variable` is not defined in any visible line (number 77
# is missing) - confirm where it comes from.
# NOTE(review): the first and third calls both use the 'pi+:tflat' list name
# (with the trk and roe variable lists respectively) - confirm this is
# intended.
81 variables = utils.get_variables(
'pi+:tflat', rank_variable, trk_variable_list, particleNumber=parameters[
'num_trk'])
82 variables += utils.get_variables(
'gamma:tflat', rank_variable, ecl_variable_list, particleNumber=parameters[
'num_ecl'])
83 variables += utils.get_variables(
'pi+:tflat', rank_variable, roe_variable_list, particleNumber=parameters[
'num_roe'])
# Delete an existing checkpoint file at checkpoint_filepath.
# NOTE(review): numbers 84-85 are missing, so any enclosing condition is not
# visible. Presumably this runs only when not warm-starting, since the
# checkpoint is loaded from the same path further down - confirm.
86 if os.path.isfile(checkpoint_filepath):
87 os.remove(checkpoint_filepath)
# Build a new model; the number of input features is the length of `variables`.
89 model = get_tflat_model(parameters=parameters, number_of_features=len(variables))
# Cosine-decay learning-rate schedule configured from the config entries
# 'initial_learning_rate' and 'decay_steps' (remaining arguments and the
# closing parenthesis are not visible).
92 cosine_decay_scheduler = keras.optimizers.schedules.CosineDecay(
93 initial_learning_rate=config[
'initial_learning_rate'],
94 decay_steps=config[
'decay_steps'],
# AdamW optimizer driven by the schedule above, with weight decay taken from
# the config.
98 optimizer = keras.optimizers.AdamW(
99 learning_rate=cosine_decay_scheduler, weight_decay=config[
'weight_decay']
# Fragments that presumably belong to a model.compile(...) call (call line not
# visible): binary cross-entropy loss and an AUC metric.
105 loss=keras.losses.binary_crossentropy,
108 keras.metrics.AUC()])
# Reload a model from the checkpoint path.
# NOTE(review): the branch this sits in is not visible - presumably the
# warm-start path; confirm.
110 model = keras.models.load_model(checkpoint_filepath)
# Fragments that presumably belong to an ONNX export call (call line not
# visible): a random example input with one row and len(variables) columns,
# plus the names "input" and "output".
127 (torch.randn(1, len(variables)),),
129 input_names=[
"input"],
130 output_names=[
"output"],
# Create the MVA weightfile; only the target_variable argument ("qrCombined")
# is visible.
133 weightfile = create_onnx_mva_weightfile(
136 target_variable=
"qrCombined",
# Save the weightfile to the database under the uniqueIdentifier name.
139 ROOT.Belle2.MVA.Weightfile.saveToDatabase(weightfile, uniqueIdentifier)