Belle II Software development
training.py
1
8
9import os
10import os.path
11import shutil
12import subprocess
13
14from tracking.run.mixins import BrowseTFileOnTerminateRunMixin
15from tracking.run.mixins import PostProcessingRunMixin
16from tracking.run.utilities import NonstrictChoices
17
18# @cond internal_test
19
20
class TrainingRunMixin(BrowseTFileOnTerminateRunMixin, PostProcessingRunMixin):
    """Prepare and execute a basf2 job to train neural network, postprocess, and inspect"""

    ## Task to perform in the post-processing: "train" runs the teacher,
    ## any other value runs the classification overview
    task = "train"

    ## Variable(s) passed to the teacher (--variables) or the overview (-v);
    ## a single string or a sequence of strings. None or empty omits the option.
    variables = None

    ## Variable(s) passed with -g to the classification overview;
    ## a single string or a sequence of strings. None omits the option.
    groupby = None

    ## Auxiliary variable(s) passed with -a to the classification overview;
    ## a single string or a sequence of strings. None omits the option.
    auxiliaries = None

    ## Name of the truth variable, passed as --truth
    truth = "truth"

    @property
    def _name_stem(self):
        """Class name with the last len("TrainingRun") characters removed (the mandatory postfix)"""
        return self.__class__.__name__[:-len("TrainingRun")]

    @property
    def identifier(self):
        """Database identifier of the filter being trained"""
        return "trackfindingcdc_" + self._name_stem

    @property
    def sample_file_name(self):
        """File name of the recorded sample to be trained on

        Defaults to the class name minus the mandatory TrainingRun postfix,
        followed by '_<task>.root'.
        """
        return self._name_stem + '_' + self.task + '.root'

    def create_argument_parser(self, **kwds):
        """Create argument parser, adding the --task option to the base parser"""
        argument_parser = super().create_argument_parser(**kwds)

        argument_parser.add_argument(
            "--task",
            choices=NonstrictChoices(["train", "eval", "explore", ]),
            default=self.task,
            dest="task",
            help=("Select a prepared recording task")
        )

        return argument_parser

    @staticmethod
    def _as_argument_list(values):
        """Return *values* as a list of command line arguments; a single string is one argument"""
        if isinstance(values, str):
            return [values]
        return list(values)

    @staticmethod
    def _run_command(cmd):
        """Print and execute *cmd*, report a non-zero exit status and return it"""
        print("Running", cmd)
        return_code = subprocess.call(cmd)
        if return_code != 0:
            # Do not silently ignore failures of the external tools
            print("Command", cmd, "failed with return code", return_code)
        return return_code

    def _run_teacher(self):
        """Train the filter with trackfindingcdc_teacher on the recorded sample"""
        cmd = ["trackfindingcdc_teacher"]

        if self.variables:
            cmd += ["--variables"]
            cmd += self._as_argument_list(self.variables)

        cmd += [
            "--identifier=" + self.identifier,
            "--truth=" + self.truth,
            self.sample_file_name,
        ]
        return_code = self._run_command(cmd)

        # Move training file to the right location.
        # Skip the copy if the training failed, so a stale weight file is not installed.
        if self.identifier.endswith(".xml"):
            if return_code != 0:
                print("Training failed, not copying", self.identifier)
                return
            tracking_data_dir_path = os.path.join(os.environ["BELLE2_LOCAL_DIR"], "tracking", "data")
            shutil.copy(self.identifier, tracking_data_dir_path)

    def _run_overview(self):
        """Run trackfindingcdc-classification-overview on the recorded sample"""
        cmd = ["trackfindingcdc-classification-overview"]

        if self.variables:
            cmd += ["-v"]
            cmd += self._as_argument_list(self.variables)

        if self.groupby is not None:
            cmd += ["-g"]
            cmd += self._as_argument_list(self.groupby)

        if self.auxiliaries is not None:
            cmd += ["-a"]
            cmd += self._as_argument_list(self.auxiliaries)

        cmd += [
            "--truth=" + self.truth,
            self.sample_file_name,
        ]
        self._run_command(cmd)

        ## Set file name for the TBrowser to show if demanded
        self.output_file_name = os.path.splitext(self.sample_file_name)[0] + ".overview.root"

    def postprocess(self):
        """Run the training or the classification overview as post-processing job

        For task "train" the filter is trained with trackfindingcdc_teacher.
        Any other task (e.g. "eval", "explore") runs
        trackfindingcdc-classification-overview and selects its
        .overview.root output as the file to browse. Afterwards the
        post-processing of the base classes is run.

        To run only this post-processing step, invoke with --postprocess-only.
        """
        # Dispatch on the task the user selected (train, eval, explore)
        if self.task == "train":
            self._run_teacher()
        else:
            self._run_overview()

        super().postprocess()
127
128# @endcond