9from ipython_tools
import handler
13from subprocess
import check_output, CalledProcessError, STDOUT
15from shutil
import copy
21from ROOT
import Belle2
27 Helper class to show a PDF file in a jupyter notebook.
30 def __init__(self, pdf, size=(600, 700)):
33 :param pdf: The filename of the PDF file.
34 :param size: The size to use.
41 def _repr_html_(self):
42 """HTML representation"""
43 return f
'<iframe src={self.pdf} width={self.size[0]} height={self.size[1]}></iframe>'
45 def _repr_latex_(self):
46 """LaTeX representation"""
47 return r'\includegraphics[width=1.0\textwidth]{{{0}}}'.format(self.pdf)
50class MVATeacherAndAnalyser:
52 Class for training and analysing a tracking module, which has a MVA filter in it.
54 Works best, if you are on a jupyter notebook.
56 You need to supply a run_class, which includes all needed settings, on how to
57 train and execute the module. This class will be mixed in with the normal trackfindingcdc
58 run classes, so you can add the setting (e.g. tracking_coverage etc.) as normal.
63 # This module will be trained
64 recording_module = "FilterBasedVXDCDCTrackMerger"
65 # This is the name of the parameter of this module, which will be set to "mva" etc.
66 recording_parameter = "filter"
68 # These mva cuts will be tested during evaluation.
69 evaluation_cuts = [0.1, 0.2, ...]
77 # Some options, which will control the run classes
79 generator_module = "EvtGenInput"
81 # This will be added to the "normal" path, to record the training data (you do not have to set the module to
82 # recording, as this is done automatically).
83 def add_recording_modules(self, path):
84 mctrackfinder = path.add_module('TrackFinderMCTruthRecoTracks',
85 RecoTracksStoreArrayName='MCRecoTracks',
88 path.add_module('MCRecoTracksMatcher', mcRecoTracksStoreArrayName="MCRecoTracks",
89 prRecoTracksStoreArrayName="CDCRecoTracks", UseCDCHits=True, UsePXDHits=False, UseSVDHits=False)
90 path.add_module('MCRecoTracksMatcher', mcRecoTracksStoreArrayName="MCRecoTracks",
91 prRecoTracksStoreArrayName="VXDRecoTracks", UseCDCHits=False, UsePXDHits=True, UseSVDHits=True)
93 # Merge CDC and VXD tracks
94 path.add_module('FilterBasedVXDCDCTrackMerger',
96 CDCRecoTrackStoreArrayName="CDCRecoTracks",
97 VXDRecoTrackStoreArrayName="VXDRecoTracks",
98 MergedRecoTrackStoreArrayName="RecoTracks")
102 # This will be added to the "normal" path, to evaluate the mva cuts. In most cases, this is the same as the
103 # add_recording_modules (as the module parameters will be set automatically), but maybe you need
105 def add_validation_modules(self, path):
106 mctrackfinder = path.add_module('TrackFinderMCTruthRecoTracks',
107 RecoTracksStoreArrayName='MCRecoTracks',
110 # Merge CDC and VXD tracks
111 path.add_module('FilterBasedVXDCDCTrackMerger',
113 CDCRecoTrackStoreArrayName="CDCRecoTracks",
114 VXDRecoTrackStoreArrayName="VXDRecoTracks",
115 MergedRecoTrackStoreArrayName="PrefitRecoTracks")
117 path.add_module("SetupGenfitExtrapolation")
119 path.add_module("DAFRecoFitter", recoTracksStoreArrayName="PrefitRecoTracks")
121 path.add_module("TrackCreator", recoTrackColName="PrefitRecoTracks")
123 path.add_module("FittedTracksStorer", inputRecoTracksStoreArrayName="PrefitRecoTracks",
124 outputRecoTracksStoreArrayName="RecoTracks")
126 # We need to include the matching ourselves, as we have already a matching algorithm in place
127 path.add_module('MCRecoTracksMatcher', mcRecoTracksStoreArrayName="MCRecoTracks",
128 prRecoTracksStoreArrayName="RecoTracks", UseCDCHits=True, UsePXDHits=True, UseSVDHits=True)
def __init__(self, run_class, use_jupyter=True):
    """
    Store the run class and derive all file names used by the other methods.

    :param run_class: Class providing the attributes ``recording_module`` and
        ``weight_data_location``, which are read here.
    :param use_jupyter: Flag stored on the instance as ``use_jupyter``.
    """
    ## cached copy of the run class
    self.run_class = run_class
    ## cached flag to use jupyter notebook
    self.use_jupyter = use_jupyter

    ## cached name of the output file
    self.recording_file_name = self.run_class.recording_module + ".root"

    # Split once; the extension is reused for every derived file name below.
    stem, suffix = os.path.splitext(self.recording_file_name)

    ## cached path without extension of the output file
    self.file_name_path = stem

    ## cached path with extension of the training-output file
    self.training_file_name = stem + "Training" + suffix
    ## cached path with extension of the testing-output file
    self.test_file_name = stem + "Testing" + suffix

    ## cached identifier name handed to the mva tools
    self.identifier_name = "FastBDT.weights.xml"
    ## cached name of the output PDF file
    self.evaluation_file_name = self.identifier_name + ".pdf"

    ## cached path with extension of the testing-export file
    self.expert_file_name = stem + "TestingExport" + suffix

    ## cached path of the weight input data
    relative_weight_path = os.path.join("tracking/data", self.run_class.weight_data_location)
    self.weight_data_location = Belle2.FileSystem.findFile(relative_weight_path)
165 """Record a training file, split it in two parts and call the training method of the mva package"""
166 if not os.path.exists(self.recording_file_name):
167 self._create_records_file()
169 if not os.path.exists(self.training_file_name) or not os.path.exists(self.test_file_name):
170 self._write_train_and_test_files()
172 self._call_training_routine()
174 def evaluate_tracking(self):
176 Use the trained weight file and call the path again using different mva cuts. Validation using the
177 normal tracking validation modules.
179 copy(self.identifier_name, self.weight_data_location)
183 except FileExistsError:
186 def create_path(mva_cut):
187 class ValidationRun(self.run_class, TrackingValidationRun):
189 def finder_module(self, path):
190 self.add_validation_modules(path)
193 adjust_module(path, self.recording_module,
194 **{self.recording_parameter + "Parameters": {"cut": mva_cut},
195 self.recording_parameter: "mva"})
197 adjust_module(path, self.recording_module, **{self.recording_parameter: "truth"})
199 output_file_name = f"results/validation_{mva_cut}.root"
201 run = ValidationRun()
203 if not os.path.exists(run.output_file_name):
204 return {"path": run.create_path()}
206 return {"path": None}
208 assert self.use_jupyter
210 calculations = handler.process_parameter_space(create_path, mva_cut=self.run_class.evaluation_cuts + [999])
212 calculations.wait_for_end()
216 def evaluate_classification(self):
218 Evaluate the classification power on the test data set and produce a PDF.
220 if not os.path.exists(self.expert_file_name) or not os.path.exists(self.evaluation_file_name):
221 self._call_evaluation_routine()
222 self._call_expert_routine()
224 df = uproot.concatenate(
225 self.expert_file_name,
234 from IPython.display import display
235 display(PDF(self.evaluation_file_name, size=(800, 800)))
def _call_training_routine(self):
    """
    Call the mva training routine in the train file.

    Runs the ``trackfindingcdc_teacher`` executable with ``self.training_file_name``
    as its only argument.

    :raises RuntimeError: If the executable exits with a non-zero return code. The
        captured output is the exception argument and the original
        ``CalledProcessError`` is attached as the cause.
    """
    try:
        check_output(["trackfindingcdc_teacher", self.training_file_name])
    except CalledProcessError as e:
        # Chain explicitly so the failed command and its return code stay in the traceback.
        raise RuntimeError(e.output) from e
def _write_train_and_test_files(self):
    """
    Split the recorded file randomly into a training and a test sample and write both.

    Each record goes to the training sample when a uniform random number drawn for
    it is below 0.5, otherwise to the test sample. Both samples are written under
    the key "records" into ``self.training_file_name`` and ``self.test_file_name``.
    """
    # TODO: This seems to reorder the columns...
    records = uproot.concatenate(self.recording_file_name, library='pd')
    in_training = np.random.rand(len(records)) < 0.5

    # Both samples are selected before any output file is opened.
    outputs = (
        (self.training_file_name, records[in_training]),
        (self.test_file_name, records[~in_training]),
    )
    for file_name, sample in outputs:
        with uproot.recreate(file_name) as outfile:
            outfile["records"] = sample
259 def _create_records_file(self):
261 Create a path using the settings of the run_class and process it.
262 This will create a ROOT file with the recorded data.
264 recording_file_name = self.recording_file_name
266 class RecordRun(self.run_class, ReadOrGenerateEventsRun):
268 def create_path(self):
269 path = ReadOrGenerateEventsRun.create_path(self)
271 self.add_recording_modules(path)
273 adjust_module(path, self.recording_module,
274 **{self.recording_parameter + "Parameters": {"rootFileName": recording_file_name},
275 self.recording_parameter: "recording"})
280 path = run.create_path()
283 calculation = handler.process(path)
285 calculation.wait_for_end()
def _call_expert_routine(self):
    """
    Call the mva expert.

    Runs ``basf2_mva_expert`` with ``self.identifier_name`` and
    ``self.weight_data_location`` as identifiers, ``self.test_file_name`` as data
    file, ``self.expert_file_name`` as output file and "records" as tree name.

    :raises RuntimeError: If the executable exits with a non-zero return code. The
        captured output is the exception argument and the original
        ``CalledProcessError`` is attached as the cause.
    """
    command = [
        "basf2_mva_expert",
        "--identifiers", self.identifier_name, self.weight_data_location,
        "--datafiles", self.test_file_name,
        "--outputfile", self.expert_file_name,
        "--treename", "records",
    ]
    try:
        check_output(command)
    except CalledProcessError as e:
        # Chain explicitly so the failed command and its return code stay in the traceback.
        raise RuntimeError(e.output) from e
302 def _call_evaluation_routine(self):
303 """Call the mva evaluation routine"""
305 check_output(["basf2_mva_evaluate.py",
306 "--identifiers", self.identifier_name, self.weight_data_location,
307 "--train_datafiles", self.training_file_name,
308 "--datafiles", self.test_file_name,
309 "--treename", "records",
310 "--outputfile", self.evaluation_file_name],
312 except CalledProcessError as e:
313 raise RuntimeError(e.output)