21class TrainingRunMixin(BrowseTFileOnTerminateRunMixin, PostProcessingRunMixin):
22 """Prepare and execute a basf2 job to train neural network, postprocess, and inspect"""
41 """Database identifier of the filter being trained"""
42 return "trackfindingcdc_" + self.__class__.__name__[:-len(
"TrainingRun")]
def sample_file_name(self):
    """File name of the recorded sample to be trained on

    Defaults to the class name minus the mandatory TrainingRun postfix,
    followed by an underscore, the selected task and the ``.root`` extension.

    Returns:
        str: ``<class name without "TrainingRun">_<task>.root``
    """
    # The slice drops the last len("TrainingRun") characters unconditionally;
    # it does not check that the class name actually ends with that postfix.
    class_stem = self.__class__.__name__[:-len("TrainingRun")]

    # Keep the whole concatenation in one expression: a trailing `+` at the end
    # of a line outside brackets does not continue onto the next line.
    return class_stem + '_' + self.task + '.root'
52 def create_argument_parser(self, **kwds):
53 """Create argument parser"""
54 argument_parser = super().create_argument_parser(**kwds)
56 argument_parser.add_argument(
58 choices=NonstrictChoices(["train", "eval", "explore", ]),
61 help=("Select a prepared recording task")
64 return argument_parser
66 def postprocess(self):
67 """Run the training as post-processing job
69 To run only the training, run with --postprocess-only
71 ## Process each event according to the user's desired task (train, eval, explore)
72 if self.task == "train":
74 "trackfindingcdc_teacher",
78 cmd += ["--variables"]
82 "--identifier=" + self.identifier,
83 "--truth=" + self.truth,
84 self.sample_file_name,
89 # Move training file to the right location
90 if self.identifier.endswith(".xml"):
91 tracking_data_dir_path = os.path.join(os.environ["BELLE2_LOCAL_DIR"], "tracking", "data")
92 shutil.copy(self.identifier, tracking_data_dir_path)
96 "trackfindingcdc-classification-overview",
101 cmd += self.variables
103 if self.groupby is not None:
105 if isinstance(self.groupby, str):
106 cmd += [self.groupby]
110 if self.auxiliaries is not None:
112 if isinstance(self.auxiliaries, str):
113 cmd += [self.auxiliaries]
115 cmd += self.auxiliaries
118 "--truth=" + self.truth,
119 self.sample_file_name,
121 print("Running", cmd)
123 ## Set file name for the TBrowser to show if demanded
124 self.output_file_name = self.sample_file_name[:-len(".root")] + ".overview.root"
126 super().postprocess()