Belle II Software prerelease-11-00-00a
combined_quality_estimator_teacher.py
#!/usr/bin/env python3

"""
combined_module_quality_estimator_teacher
-----------------------------------------

Information on the MVA Track Quality Indicator / Estimator can be found
on `XWiki
<https://xwiki.desy.de/xwiki/rest/p/0d3f4>`_.

Purpose of this script
~~~~~~~~~~~~~~~~~~~~~~

This python script is used for the combined training and validation of three
classifiers, the actual final MVA track quality estimator and the two quality
estimators for the intermediate standalone track finders that it depends on.

    - Final MVA track quality estimator:
      The final quality estimator for fully merged and fitted tracks (RecoTracks).
      Its classifier uses features from the track fitting, merger, hit pattern, ...
      But it also uses the outputs from respective intermediate quality
      estimators for the VXD and the CDC track finding as inputs. It provides
      the final quality indicator (QI) exported to the track objects.

    - VXDTF2 track quality estimator:
      MVA quality estimator for the VXD standalone track finding.

    - CDC track quality estimator:
      MVA quality estimator for the CDC standalone track finding.

Each classifier requires its own training data set and needs to be validated
on a separate testing data set. Further, the final quality estimator can only
be trained when the trained weights for the intermediate quality estimators are
available. If the final estimator is to be trained without one or both of the
previous estimators, the requirements have to be commented out in the
__init__.py file of tracking.
For all estimators, a list of variables to be ignored is specified in the MasterTask.
The current choice is mainly based on poor data-MC agreement in these quantities or
on outdated implementations. It was decided to leave them in the hardcoded "ugly" way
in here to remind future generations that they exist in principle and that they should
and could be added to the estimator once their modelling improves or an alternative
implementation is programmed.
To avoid mistakes, b2luigi is used to create a task chain for a combined training and
validation of all classifiers.

b2luigi: Understanding the steering file
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

All trainings and validations are done in the correct order in this steering
file. For the purpose of creating a dependency graph, the `b2luigi
<https://b2luigi.readthedocs.io>`_ python package is used, which extends the
`luigi <https://luigi.readthedocs.io>`_ package developed by Spotify.

Each task that has to be done is represented by a special class, which defines
parameters, output files and which other tasks with which parameters it depends
on. For example, a teacher task, which runs ``basf2_mva_teacher.py`` to train
the classifier, depends on a data collection task which runs a reconstruction
and writes out track-wise variables into a ROOT file for training. An
evaluation/validation task for testing the classifier requires both the teacher
task, as it needs the weightfile to be present, and also a data collection
task, because it needs a dataset for testing the classifier.
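
The dependency structure described above can be sketched in plain Python,
without b2luigi. This is only an illustration: the class names ``ToyTask``,
``ToyDataCollection``, ``ToyTeacher`` and ``ToyValidation`` and the helper
``run_in_order`` are hypothetical and exist neither in this steering file nor
in b2luigi.

```python
class ToyTask:
    # Hypothetical stand-in for a task class; not part of b2luigi.
    def __init__(self, name):
        self.name = name

    def requires(self):
        # Tasks without dependencies return an empty list.
        return []


class ToyDataCollection(ToyTask):
    pass


class ToyTeacher(ToyTask):
    def requires(self):
        # Training needs the collected training records.
        return [ToyDataCollection("collect_train")]


class ToyValidation(ToyTask):
    def requires(self):
        # Validation needs the weightfile and an independent test data set.
        return [ToyTeacher("teacher"), ToyDataCollection("collect_test")]


def run_in_order(task, done=None):
    # Depth-first resolution: every requirement runs before the task itself.
    if done is None:
        done = []
    for requirement in task.requires():
        run_in_order(requirement, done)
    if task.name not in done:
        done.append(task.name)
    return done


execution_order = run_in_order(ToyValidation("validation"))
print(execution_order)
```

Requesting only the validation task is enough to pull in the data collection
and the training in the right order. The ``MasterTask`` plays this role for the
real task classes.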

The final task that defines which tasks need to be done for the steering file to
finish is the ``MasterTask``. When you only want to run parts of the
training/validation pipeline, you can comment out requirements in the
``MasterTask`` or replace them by lower-level tasks during debugging.

Requirements
~~~~~~~~~~~~

This steering file relies on b2luigi_ for task scheduling and `uncertain_panda
<https://github.com/nils-braun/uncertain_panda>`_ for uncertainty calculations.
uncertain_panda is not in the externals, and b2luigi is not included in the
externals up to v01-07-01. Both can be installed via pip::

    python3 -m pip install [--user] b2luigi uncertain_panda

Use the ``--user`` option if you do not have the rights to install python
packages into your externals (e.g. because you are using cvmfs) and install them
in ``$HOME/.local`` instead.

Configuration
~~~~~~~~~~~~~

Instead of command line arguments, the b2luigi script is configured via a
``settings.json`` file. Open it in your favorite text editor and modify it to
fit your requirements.
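
As a minimal sketch of the file format, the following snippet writes a
``settings.json`` into a temporary directory and reads it back with the
standard ``json`` module. The only key shown is ``result_path``, which is
mentioned in this docstring. Its value is a made-up example, and the full set
of keys expected by this steering file is not listed here.

```python
import json
import tempfile
from pathlib import Path

# Assumption: only the "result_path" key is shown; the value is a made-up example.
example_settings = {"result_path": "my_qe_results"}

with tempfile.TemporaryDirectory() as tmp_dir:
    settings_file = Path(tmp_dir) / "settings.json"
    settings_file.write_text(json.dumps(example_settings, indent=4))
    # Reading the file back confirms that it is valid JSON.
    loaded_settings = json.loads(settings_file.read_text())

print(loaded_settings["result_path"])
```

Running a new ``settings.json`` through a JSON parser like this before
starting a long task chain catches syntax errors such as trailing commas
early.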

Usage
~~~~~

You can test the b2luigi task definitions without running them via::

    python3 combined_quality_estimator_teacher.py --dry-run
    python3 combined_quality_estimator_teacher.py --show-output

This will show the outputs and show potential errors in the definitions of the
luigi task dependencies. To run the steering file in normal (local) mode,
run::

    python3 combined_quality_estimator_teacher.py

I usually use the interactive luigi web interface via the central scheduler
which visualizes the task graph while it is running. For this, the scheduler
daemon ``luigid`` has to run in the background. It is located in
``~/.local/bin/luigid`` in case b2luigi was installed with ``--user``. For
example, run::

    luigid --port 8886

Then, execute your steering file (e.g. in another terminal) with::

    python3 combined_quality_estimator_teacher.py --scheduler-port 8886

To view the web interface, open your web browser and enter into the URL bar::

    localhost:8886

If you don't run the steering file on the same machine on which you run your web
browser, you have two options:

    1. Run both the steering file and ``luigid`` remotely and use
       ssh-port-forwarding to your local host. To do so, run on your local
       machine::

           ssh -N -f -L 8886:localhost:8886 <remote_user>@<remote_host>

    2. Run the ``luigid`` scheduler locally and use the ``--scheduler-host <your
       local host>`` argument when calling the steering file

Accessing the results / output files
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

All output files are stored in a directory structure in the ``result_path``. The
directory tree encodes the used b2luigi parameters. This ensures reproducibility
and makes parameter searches easy. Sometimes, it is hard to find the relevant
output files. You can view the whole directory structure by running ``tree
<result_path>``. Use the Unix ``find`` command to find the files that interest
you, e.g.::

    find <result_path> -name "*.pdf" # find all validation plot files
    find <result_path> -name "*.root" # find all ROOT files
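
The same search can be done from Python with ``pathlib``. The following is a
small sketch, not part of this steering file: the helper name
``find_output_files`` and the directory and file names in the demonstration
are made up.

```python
import tempfile
from pathlib import Path


def find_output_files(result_path, pattern):
    # Recursively collect all files below result_path that match pattern.
    return sorted(path for path in Path(result_path).rglob(pattern) if path.is_file())


# Demonstration on a temporary directory tree with made-up names.
with tempfile.TemporaryDirectory() as tmp_dir:
    nested_dir = Path(tmp_dir) / "n_events=100" / "random_seed=test"
    nested_dir.mkdir(parents=True)
    (nested_dir / "validation_plots.pdf").touch()
    (nested_dir / "records.root").touch()
    pdf_names = [path.name for path in find_output_files(tmp_dir, "*.pdf")]
    root_names = [path.name for path in find_output_files(tmp_dir, "*.root")]

print(pdf_names, root_names)
```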
"""

import itertools
import os
from pathlib import Path
import shutil
import subprocess
import textwrap
from datetime import datetime
from typing import Iterable

import matplotlib.pyplot as plt
import numpy as np
import uproot
from matplotlib.backends.backend_pdf import PdfPages

import basf2
import basf2_mva
from packaging import version
import background
import simulation
import tracking
import tracking.root_utils as root_utils
from tracking.harvesting_validation.combined_module import CombinedTrackingValidationModule
from tracking_mva_filter_payloads.write_tracking_mva_filter_payloads_to_db import write_tracking_mva_filter_payloads_to_db
from tracking_mva_filter_payloads.write_tracking_mva_filter_payloads_to_db import write_mva_weightfile_content_to_db

# @cond internal_test

# wrap python modules that are used here but not in the externals into a try except block
install_helpstring_formatter = ("\nCould not find {module} python module. Try installing it via\n"
                                " python3 -m pip install [--user] {module}\n")
try:
    import b2luigi
    from b2luigi.core.utils import get_serialized_parameters, get_log_file_dir, create_output_dirs
    from b2luigi.basf2_helper import Basf2PathTask, Basf2Task
    from b2luigi.core.task import Task, ExternalTask
    from b2luigi.basf2_helper.utils import get_basf2_git_hash
except ModuleNotFoundError:
    print(install_helpstring_formatter.format(module="b2luigi"))
    raise
try:
    from uncertain_panda import pandas as upd
except ModuleNotFoundError:
    print(install_helpstring_formatter.format(module="uncertain_panda"))
    raise

# If b2luigi version 0.3.2 or older, it relies on $BELLE2_RELEASE being "head",
# which is not the case in the new externals. A fix has been merged into b2luigi
# via https://github.com/nils-braun/b2luigi/pull/17 and thus should be available
# in future releases.
if (
    version.parse(b2luigi.__version__) <= version.parse("0.3.2") and
    get_basf2_git_hash() is None and
    os.getenv("BELLE2_LOCAL_DIR") is not None
):
    print(f"b2luigi version could not obtain git hash because of a bug not yet fixed in version {b2luigi.__version__}\n"
          "Please install the latest version of b2luigi from github via\n\n"
          " python3 -m pip install --upgrade [--user] git+https://github.com/nils-braun/b2luigi.git\n")
    raise ImportError

# Utility functions


def create_fbdt_option_string(fast_bdt_option):
    """
    Return a readable string created from the fast_bdt_option array
    [nTrees, nCuts, nLevels, shrinkage].
    """
    return "_nTrees" + str(fast_bdt_option[0]) + "_nCuts" + str(fast_bdt_option[1]) + "_nLevels" + \
        str(fast_bdt_option[2]) + "_shrin" + str(int(round(100*fast_bdt_option[3], 0)))


def createV0momenta(x, mu, beta):
    """
    Copied from Bianca's K_S0 particle gun code: Returns a realistic V0 momentum distribution
    when running over x. The function has the form of a Gumbel probability density, where mu
    defines the position of the peak and beta the width and tails.
    Used for the particle gun simulation code for K_S0 and Lambda_0
    """
    return (1/beta)*np.exp(-(x - mu)/beta) * np.exp(-np.exp(-(x - mu) / beta))


def my_basf2_mva_teacher(
    records_files,
    tree_name,
    weightfile_identifier,
    target_variable="truth",
    exclude_variables=None,
    fast_bdt_option=[200, 8, 3, 0.1]  # nTrees, nCuts, nLevels, shrinkage
):
    """
    My custom wrapper for basf2 mva teacher. Adapted from code in ``trackfindingcdc_teacher``.

    :param records_files: List of files with collected ("recorded") variables to use as training data for the MVA.
    :param tree_name: Name of the TTree in the ROOT file from the ``data_collection_task``
           that contains the training data for the MVA teacher.
    :param weightfile_identifier: Name of the weightfile that is created.
           Should either end in ".xml" for local weightfiles or in ".root", when
           the weightfile needs later to be uploaded as a payload to the conditions
           database.
    :param target_variable: Feature/variable to use as truth label in the quality estimator MVA classifier.
    :param exclude_variables: List of collected variables to not use in the training of the QE MVA classifier.
           In addition to variables containing the "truth" substring, which are excluded by default.
    :param fast_bdt_option: specified fast BDT options, default: [200, 8, 3, 0.1] [nTrees, nCuts, nLevels, shrinkage]
    """
    if exclude_variables is None:
        exclude_variables = []

    weightfile_extension = Path(weightfile_identifier).suffix
    if weightfile_extension not in {".xml", ".root"}:
        raise ValueError(f"Weightfile Identifier should end in .xml or .root, but ends in {weightfile_extension}")

    # extract names of all variables from one record file
    with root_utils.root_open(records_files[0]) as records_tfile:
        input_tree = records_tfile.Get(tree_name)
        feature_names = [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]

    # get list of variables to use for training without MC truth
    truth_free_variable_names = [
        name
        for name in feature_names
        if (
            ("truth" not in name) and
            (name != target_variable) and
            (name not in exclude_variables)
        )
    ]
    if "weight" in truth_free_variable_names:
        truth_free_variable_names.remove("weight")
        weight_variable = "weight"
    elif "__weight__" in truth_free_variable_names:
        truth_free_variable_names.remove("__weight__")
        weight_variable = "__weight__"
    else:
        weight_variable = ""

    # Set options for MVA training
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(*records_files)
    general_options.m_treename = tree_name
    general_options.m_weight_variable = weight_variable
    general_options.m_identifier = weightfile_identifier
    general_options.m_variables = basf2_mva.vector(*truth_free_variable_names)
    general_options.m_target_variable = target_variable
    fastbdt_options = basf2_mva.FastBDTOptions()

    fastbdt_options.m_nTrees = fast_bdt_option[0]
    fastbdt_options.m_nCuts = fast_bdt_option[1]
    fastbdt_options.m_nLevels = fast_bdt_option[2]
    fastbdt_options.m_shrinkage = fast_bdt_option[3]
    # Train a MVA method and store the weightfile (MVAFastBDT.root) locally.
    basf2_mva.teacher(general_options, fastbdt_options)


def _my_uncertain_mean(series: upd.Series):
    """
    Temporary workaround for a bug in ``uncertain_panda`` where a ``ValueError`` is
    thrown for ``Series.unc.mean`` if the series is empty. Can be replaced by
    .unc.mean when the issue is fixed.
    https://github.com/nils-braun/uncertain_panda/issues/2
    """
    try:
        return series.unc.mean()
    except ValueError:
        if series.empty:
            return np.nan
        else:
            raise


def get_uncertain_means_for_qi_cuts(df: upd.DataFrame, column: str, qi_cuts: Iterable[float]):
    """
    Return a pandas series with the mean of the dataframe column and its
    uncertainty for each quality indicator cut.

    :param df: Pandas dataframe with at least ``quality_indicator``
        and another numeric ``column``.
    :param column: Column of which we want to aggregate the means
        and uncertainties for different QI cuts
    :param qi_cuts: Iterable of quality indicator minimal thresholds.
    :returns: Series of means and uncertainties with ``qi_cuts`` as index
    """

    uncertain_means = (_my_uncertain_mean(df.query(f"quality_indicator > {qi_cut}")[column])
                       for qi_cut in qi_cuts)
    uncertain_means_series = upd.Series(data=uncertain_means, index=qi_cuts)
    return uncertain_means_series


def plot_with_errobands(uncertain_series,
                        error_band_alpha=0.3,
                        plot_kwargs=None,
                        fill_between_kwargs=None,
                        ax=None):
    """
    Plot an uncertain series with error bands for y-errors
    """
    # use None as default to avoid sharing mutable default dictionaries between calls
    if plot_kwargs is None:
        plot_kwargs = {}
    if fill_between_kwargs is None:
        fill_between_kwargs = {}
    if ax is None:
        ax = plt.gca()
    uncertain_series = uncertain_series.dropna()
    ax.plot(uncertain_series.index.values, uncertain_series.nominal_value, **plot_kwargs)
    ax.fill_between(x=uncertain_series.index,
                    y1=uncertain_series.nominal_value - uncertain_series.std_dev,
                    y2=uncertain_series.nominal_value + uncertain_series.std_dev,
                    alpha=error_band_alpha,
                    **fill_between_kwargs)


def format_dictionary(adict, width=80, bullet="•"):
    """
    Helper function to format dictionary to string as a wrapped key-value bullet
    list. Useful to print metadata from dictionaries.

    :param adict: Dictionary to format
    :param width: Characters after which to wrap a key-value line
    :param bullet: Character to begin a key-value line with, e.g. ``-`` for a
        yaml-like string
    """
    # It might be possible to replace this function with yaml.dump, but the current
    # version in the externals does not allow to disable the sorting of the
    # dictionary yet and also I am not sure if it is wrappable
    return "\n".join(textwrap.fill(f"{bullet} {key}: {value}", width=width)
                     for (key, value) in adict.items())

# Begin definitions of b2luigi task classes


class GenerateSimTask(Basf2PathTask):
    """
    Generate simulated Monte Carlo with background overlay.

    Make sure to use different ``random_seed`` parameters for the training data
    for the classifier trainings and for the test data for the respective
    evaluation/validation tasks.
    """

    #: Number of events to generate.
    n_events = b2luigi.IntParameter()
    #: Experiment number used for the simulation.
    experiment_number = b2luigi.IntParameter()
    #: Random seed string, which also encodes the production mode (e.g. ``BBBAR``).
    random_seed = b2luigi.Parameter()
    #: Directory with the background overlay files.
    bkgfiles_dir = b2luigi.Parameter(
        hashed=True
    )
    #: Batch system queue for this task.
    queue = 'l'

    def output_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        return "generated_mc_N" + str(n_events) + "_" + random_seed + ".root"

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.output_file_name())

    def create_path(self):
        """
        Create basf2 path to process with event generation and simulation.
        """
        basf2.set_random_seed(self.random_seed)
        path = basf2.create_path()
        if self.experiment_number in [0, 1002, 1003]:
            runNo = 0
        else:
            raise ValueError(
                f"Simulating events with experiment number {self.experiment_number} is not implemented yet.")
        path.add_module(
            "EventInfoSetter", evtNumList=[self.n_events], runList=[runNo], expList=[self.experiment_number]
        )
        # "V0BBBAR" contains "BBBAR" as a substring, so it has to be checked first
        if "V0BBBAR" in self.random_seed:
            path.add_module("EvtGenInput")
            path.add_module("InclusiveParticleChecker", particles=[310, 3122], includeConjugates=True)
        elif "BBBAR" in self.random_seed:
            path.add_module("EvtGenInput")
        else:
            import generators as ge
            # WARNING: There are a few differences in the production of MC13a and b like the following lines
            # as well as ActivatePXD.. and the beamparams for bhabha... I use these from MC13b, not a... :/
            # import beamparameters as bp
            # beamparameters = bp.add_beamparameters(path, "Y4S")
            # beamparameters.param("covVertex", [(14.8e-4)**2, (1.5e-4)**2, (360e-4)**2])
            if "V0STUDY" in self.random_seed:
                if "V0STUDYKS" in self.random_seed:
                    # Bianca looked at the Ks dists and extracted these values:
                    mu = 0.5
                    beta = 0.2
                    pdgs = [310]  # Ks (has no antiparticle, Klong is different)
                elif "V0STUDYL0" in self.random_seed:
                    # I just made the lambda values up, such that they peak at 0.35 and are slightly shifted to lower values
                    mu = 0.35
                    beta = 0.15  # if this is chosen higher, one needs to make sure not to get values >0 for 0
                    pdgs = [3122, -3122]  # Lambda0
                else:
                    # also these values are made up
                    mu = 0.43
                    beta = 0.18
                    pdgs = [310, 3122, -3122]  # Ks and Lambda0
                # create realistic momentum distribution
                myx = [i*0.01 for i in range(321)]
                myy = []
                for x in myx:
                    y = createV0momenta(x, mu, beta)
                    myy.append(y)
                polParams = myx + myy
                # define particles that are produced
                pdg_list = pdgs

                particlegun = basf2.register_module('ParticleGun')
                particlegun.param('pdgCodes', pdg_list)
                particlegun.param('nTracks', 8)  # number of particles (not tracks!) that is created in each event
                particlegun.param('momentumGeneration', 'polyline')
                particlegun.param('momentumParams', polParams)
                particlegun.param('thetaGeneration', 'uniformCos')
                particlegun.param('thetaParams', [17, 150])  # [0, 180]) #[17, 150]
                particlegun.param('phiGeneration', 'uniform')
                particlegun.param('phiParams', [0, 360])
                particlegun.param('vertexGeneration', 'fixed')
                particlegun.param('xVertexParams', [0])
                particlegun.param('yVertexParams', [0])
                particlegun.param('zVertexParams', [0])
                path.add_module(particlegun)
            if "BHABHA" in self.random_seed:
                ge.add_babayaganlo_generator(path=path, finalstate='ee', minenergy=0.15, minangle=10.0)
            elif "MUMU" in self.random_seed:
                ge.add_kkmc_generator(path=path, finalstate='mu+mu-')
            elif "YY" in self.random_seed:
                babayaganlo = basf2.register_module('BabayagaNLOInput')
                babayaganlo.param('FinalState', 'gg')
                babayaganlo.param('MaxAcollinearity', 180.0)
                babayaganlo.param('ScatteringAngleRange', [0., 180.])
                babayaganlo.param('FMax', 75000)
                babayaganlo.param('MinEnergy', 0.01)
                babayaganlo.param('Order', 'exp')
                babayaganlo.param('DebugEnergySpread', 0.01)
                babayaganlo.param('Epsilon', 0.00005)
                path.add_module(babayaganlo)
                generatorpreselection = basf2.register_module('GeneratorPreselection')
                generatorpreselection.param('nChargedMin', 0)
                generatorpreselection.param('nChargedMax', 999)
                generatorpreselection.param('MinChargedPt', 0.15)
                generatorpreselection.param('MinChargedTheta', 17.)
                generatorpreselection.param('MaxChargedTheta', 150.)
                generatorpreselection.param('nPhotonMin', 1)
                generatorpreselection.param('MinPhotonEnergy', 1.5)
                generatorpreselection.param('MinPhotonTheta', 15.0)
                generatorpreselection.param('MaxPhotonTheta', 165.0)
                generatorpreselection.param('applyInCMS', True)
                path.add_module(generatorpreselection)
                empty = basf2.create_path()
                generatorpreselection.if_value('!=11', empty)
            elif "EEEE" in self.random_seed:
                ge.add_aafh_generator(path=path, finalstate='e+e-e+e-', preselection=False)
            elif "EEMUMU" in self.random_seed:
                ge.add_aafh_generator(path=path, finalstate='e+e-mu+mu-', preselection=False)
            elif "TAUPAIR" in self.random_seed:
                ge.add_kkmc_generator(path, finalstate='tau+tau-')
            elif "DDBAR" in self.random_seed:
                ge.add_continuum_generator(path, finalstate='ddbar')
            elif "UUBAR" in self.random_seed:
                ge.add_continuum_generator(path, finalstate='uubar')
            elif "SSBAR" in self.random_seed:
                ge.add_continuum_generator(path, finalstate='ssbar')
            elif "CCBAR" in self.random_seed:
                ge.add_continuum_generator(path, finalstate='ccbar')
        # activate simulation of dead/masked pixel and reproduce detector gain, which will be
        # applied at reconstruction level when the data GT is present in the DB chain
        # path.add_module("ActivatePXDPixelMasker")
        # path.add_module("ActivatePXDGainCalibrator")
        bkg_files = background.get_background_files(self.bkgfiles_dir)
        # \cond suppress doxygen warning
        if self.experiment_number == 1002:
            # remove KLM because of bug in background files with release 4
            components = ['PXD', 'SVD', 'CDC', 'ECL', 'TOP', 'ARICH', 'TRG']
        else:
            components = None
        # \endcond
        simulation.add_simulation(path, bkgfiles=bkg_files, bkgOverlay=True, components=components)  # , usePXDDataReduction=False)

        path.add_module(
            "RootOutput",
            outputFileName=self.get_output_file_name(self.output_file_name()),
        )
        return path


# I don't use the default MergeTask or similar because they only work if every input file is called the same.
# Additionally, I want to add more features like deleting the original input to save storage space.
class SplitNMergeSimTask(Basf2Task):
    """
    Split the simulation of ``n_events`` events into several ``GenerateSimTask``
    jobs and merge their output files into a single file.

    Make sure to use different ``random_seed`` parameters for the training data
    for the classifier trainings and for the test data for the respective
    evaluation/validation tasks.
    """

    #: Total number of events to generate.
    n_events = b2luigi.IntParameter()
    #: Experiment number used for the simulation.
    experiment_number = b2luigi.IntParameter()
    #: Random seed string, which also encodes the production mode (e.g. ``BBBAR``).
    random_seed = b2luigi.Parameter()
    #: Directory with the background overlay files.
    bkgfiles_dir = b2luigi.Parameter(
        hashed=True
    )
    #: Batch system queue for this task.
    queue = 'sx'

    def output_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        return "generated_mc_N" + str(n_events) + "_" + random_seed + ".root"

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.output_file_name())

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        n_events_per_task = MasterTask.n_events_per_task
        quotient, remainder = divmod(self.n_events, n_events_per_task)
        for i in range(quotient):
            yield GenerateSimTask(
                bkgfiles_dir=self.bkgfiles_dir,
                num_processes=MasterTask.num_processes,
                random_seed=self.random_seed + '_' + str(i).zfill(3),
                n_events=n_events_per_task,
                experiment_number=self.experiment_number,
            )
        if remainder > 0:
            yield GenerateSimTask(
                bkgfiles_dir=self.bkgfiles_dir,
                num_processes=MasterTask.num_processes,
                random_seed=self.random_seed + '_' + str(quotient).zfill(3),
                n_events=remainder,
                experiment_number=self.experiment_number,
            )

    @b2luigi.on_temporary_files
    def process(self):
        """
        When all GenerateSimTasks finished, merge the output.
        """
        create_output_dirs(self)

        file_list = list(self.get_all_input_file_names())
        print("Merge the following files:")
        print(file_list)
        cmd = ["b2file-merge", "-f"]
        args = cmd + [self.get_output_file_name(self.output_file_name())] + file_list
        subprocess.check_call(args)
        print("Finished merging. Now remove the input files to save space.")
        cmd2 = ["rm", "-f"]
        for tempfile in file_list:
            args = cmd2 + [tempfile]
            subprocess.check_call(args)


class CheckExistingFile(ExternalTask):
    """
    Task to check if the given file really exists.
    """

    #: Name of the file whose existence is checked.
    filename = b2luigi.Parameter()

    def output(self):
        """
        Specify the output to be the file that was just checked.
        """
        from luigi import LocalTarget
        return LocalTarget(self.filename)


class VXDQEDataCollectionTask(Basf2PathTask):
    """
    Collect variables/features from VXDTF2 tracking and write them to a ROOT
    file.

    These variables are to be used as labelled training data for the MVA
    classifier which is the VXD track quality estimator
    """

    #: Number of events to process.
    n_events = b2luigi.IntParameter()
    #: Experiment number used for the simulation.
    experiment_number = b2luigi.IntParameter()
    #: Random seed string, which also encodes the production mode (e.g. ``BBBAR``).
    random_seed = b2luigi.Parameter()
    #: Batch system queue for this task.
    queue = 'l'

    def get_records_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if 'vxd' not in random_seed:
            random_seed += '_vxd'
        if 'DATA' in random_seed:
            return 'qe_records_DATA_vxd.root'
        else:
            if 'USESIMBB' in random_seed:
                random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
            elif 'USESIMEE' in random_seed:
                random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
            return 'qe_records_N' + str(n_events) + '_' + random_seed + '.root'

    def get_input_files(self, n_events=None, random_seed=None):
        """
        Get input file names depending on the use case: If they already exist, search in
        the corresponding folders, for data check the specified list and if they are created
        in the same run, check for the task that produced them.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if "USESIM" in random_seed:
            if 'USESIMBB' in random_seed:
                random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
            elif 'USESIMEE' in random_seed:
                random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
            return ['datafiles/' + GenerateSimTask.output_file_name(GenerateSimTask,
                                                                    n_events=n_events, random_seed=random_seed)]
        elif "DATA" in random_seed:
            return MasterTask.datafiles
        else:
            return self.get_input_file_names(GenerateSimTask.output_file_name(
                GenerateSimTask, n_events=n_events, random_seed=random_seed))

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        if "USESIM" in self.random_seed or "DATA" in self.random_seed:
            for filename in self.get_input_files():
                yield CheckExistingFile(
                    filename=filename,
                )
        else:
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.random_seed,
                n_events=self.n_events,
                experiment_number=self.experiment_number,
            )

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.get_records_file_name())

    def create_path(self):
        """
        Create basf2 path with VXDTF2 tracking and VXD QE data collection.
        """
        path = basf2.create_path()
        inputFileNames = self.get_input_files()
        path.add_module(
            "RootInput",
            inputFileNames=inputFileNames,
        )
        path.add_module("Gearbox")
        tracking.add_geometry_modules(path)
        if 'DATA' in self.random_seed:
            from rawdata import add_unpackers
            add_unpackers(path, components=['SVD', 'PXD'])
        tracking.add_hit_preparation_modules(path)
        tracking.add_vxd_track_finding_vxdtf2(
            path, components=["SVD"], add_mva_quality_indicator=False
        )
        if 'DATA' in self.random_seed:
            path.add_module(
                "VXDQETrainingDataCollector",
                TrainingDataOutputName=self.get_output_file_name(self.get_records_file_name()),
                SpacePointTrackCandsStoreArrayName="SPTrackCands",
                EstimationMethod="tripletFit",
                UseTimingInfo=False,
                ClusterInformation="Average",
                MCStrictQualityEstimator=False,
                mva_target=False,
                MCInfo=False,
            )
        else:
            path.add_module(
                "TrackFinderMCTruthRecoTracks",
                RecoTracksStoreArrayName="MCRecoTracks",
                WhichParticles=[],
                UsePXDHits=False,
                UseSVDHits=True,
                UseCDCHits=False,
            )
            path.add_module(
                "VXDQETrainingDataCollector",
                TrainingDataOutputName=self.get_output_file_name(self.get_records_file_name()),
                SpacePointTrackCandsStoreArrayName="SPTrackCands",
                EstimationMethod="tripletFit",
                UseTimingInfo=False,
                ClusterInformation="Average",
                MCStrictQualityEstimator=True,
                mva_target=False,
            )
        return path


795class CDCQEDataCollectionTask(Basf2PathTask):
796 """
797 Collect variables/features from CDC tracking and write them to a ROOT file.
798
799 These variables are to be used as labelled training data for the MVA
800 classifier which is the CDC track quality estimator
801 """
802
803 n_events = b2luigi.IntParameter()
804
805 experiment_number = b2luigi.IntParameter()
806
808 random_seed = b2luigi.Parameter()
809
810 queue = 'l'
811
812
813 def get_records_file_name(self, n_events=None, random_seed=None):
814 """
815 Create output file name depending on number of events and production
816 mode that is specified in the random_seed string.
817 """
818 if n_events is None:
819 n_events = self.n_events
820 if random_seed is None:
821 random_seed = self.random_seed
822 if 'cdc' not in random_seed:
823 random_seed += '_cdc'
824 if 'DATA' in random_seed:
825 return 'qe_records_DATA_cdc.root'
826 else:
827 if 'USESIMBB' in random_seed:
828 random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
829 elif 'USESIMEE' in random_seed:
830 random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
831 return 'qe_records_N' + str(n_events) + '_' + random_seed + '.root'
832
833 def get_input_files(self, n_events=None, random_seed=None):
834 """
835 Get input file names depending on the use case: If they already exist, search in
836 the corresponding folders, for data check the specified list and if they are created
837 in the same run, check for the task that produced them.
838 """
839 if n_events is None:
840 n_events = self.n_events
841 if random_seed is None:
842 random_seed = self.random_seed
843 if "USESIM" in random_seed:
844 if 'USESIMBB' in random_seed:
845 random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
846 elif 'USESIMEE' in random_seed:
847 random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
848 return ['datafiles/' + GenerateSimTask.output_file_name(GenerateSimTask,
849 n_events=n_events, random_seed=random_seed)]
850 elif "DATA" in random_seed:
851 return MasterTask.datafiles
852 else:
853 return self.get_input_file_names(GenerateSimTask.output_file_name(
854 GenerateSimTask, n_events=n_events, random_seed=random_seed))
855
856 def requires(self):
857 """
858 Generate list of luigi Tasks that this Task depends on.
859 """
860 if "USESIM" in self.random_seed or "DATA" in self.random_seed:
861 for filename in self.get_input_files():
862 yield CheckExistingFile(
863 filename=filename,
864 )
865 else:
866 yield SplitNMergeSimTask(
867 bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
868 random_seed=self.random_seed,
869 n_events=self.n_events,
870 experiment_number=self.experiment_number,
871 )
872
873 def output(self):
874 """
875 Generate list of output files that the task should produce.
876 The task is considered finished if and only if the outputs all exist.
877 """
878 yield self.add_to_output(self.get_records_file_name())
879
880 def create_path(self):
881 """
882 Create basf2 path with CDC standalone tracking and CDC QE with recording filter for MVA feature collection.
883 """
884 path = basf2.create_path()
885 inputFileNames = self.get_input_files()
886 path.add_module(
887 "RootInput",
888 inputFileNames=inputFileNames,
889 )
890 path.add_module("Gearbox")
891 tracking.add_geometry_modules(path)
892 if 'DATA' in self.random_seed:
893 filter_choice = "recording_data"
894 from rawdata import add_unpackers
895 add_unpackers(path, components=['CDC'])
896 else:
897 filter_choice = "recording"
898 # tracking.add_hit_preparation_modules(path) # only needed for SVD and
899 # PXD hit preparation. Does not change the CDC output.
900 tracking.add_cdc_track_finding(path, add_mva_quality_indicator=True)
901
902 basf2.set_module_parameters(
903 path,
904 name="TFCDC_TrackQualityEstimator",
905 filter=filter_choice,
906 filterParameters={
907 "rootFileName": self.get_output_file_name(self.get_records_file_name())
908 },
909 deactivateIfDeadBoard=False # original behavior before deactivateIfDeadBoard was introduced
910 )
911 return path
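
# Illustrative sketch (not part of the original steering file): the naming
# scheme of ``CDCQEDataCollectionTask.get_records_file_name()`` rewritten as a
# free function, so it can be tried without basf2 or b2luigi. The function name
# ``cdc_records_file_name`` is an assumption made for this example only.

```python
def cdc_records_file_name(n_events: int, random_seed: str) -> str:
    """Mirror the CDC records file naming shown above (hypothetical helper)."""
    # The seed string doubles as a production-mode tag; make sure it is
    # marked as belonging to the CDC sample.
    if 'cdc' not in random_seed:
        random_seed += '_cdc'
    # Real data always maps to one fixed file name.
    if 'DATA' in random_seed:
        return 'qe_records_DATA_cdc.root'
    # Pre-existing simulation: replace the USESIM prefix by the process name.
    if 'USESIMBB' in random_seed:
        random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
    elif 'USESIMEE' in random_seed:
        random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
    return 'qe_records_N' + str(n_events) + '_' + random_seed + '.root'


if __name__ == "__main__":
    print(cdc_records_file_name(1000, "USESIMBB_train_cdc"))
```

# With this scheme a ``USESIMBB_...`` seed and a ``BBBAR_...`` seed resolve to
# the same records file name, so existing and freshly produced samples are
# interchangeable downstream.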
912
913
914class RecoTrackQEDataCollectionTask(Basf2PathTask):
915 """
916 Collect variables/features from the reco track reconstruction including the
917 fit and write them to a ROOT file.
918
919 These variables are to be used as labelled training data for the MVA
920 classifier which is the MVA track quality estimator. The collected
921 variables include the classifier outputs from the VXD and CDC quality
922 estimators, namely the CDC and VXD quality indicators, combined with fit,
923 merger, timing, energy loss information etc. This task requires the
924 subdetector quality estimators to be trained.
925 """
926
927
928 n_events = b2luigi.IntParameter()
929
930 experiment_number = b2luigi.IntParameter()
931
933 random_seed = b2luigi.Parameter()
934
935 cdc_training_target = b2luigi.Parameter()
936
939 recotrack_option = b2luigi.Parameter(
940
941 default='deleteCDCQI080'
942
943 )
944
945 fast_bdt_option = b2luigi.ListParameter(
946
947 hashed=True, default=[200, 8, 3, 0.1]
948
949 )
950
951 queue = 'l'
952
953
954 def get_records_file_name(self, n_events=None, random_seed=None, recotrack_option=None):
955 """
956 Create output file name depending on number of events and production
957 mode that is specified in the random_seed string.
958 """
959 if n_events is None:
960 n_events = self.n_events
961 if random_seed is None:
962 random_seed = self.random_seed
963 if recotrack_option is None:
964 if isinstance(self.recotrack_option, str):
965 recotrack_option = self.recotrack_option
966 else:
967 recotrack_option = self.recotrack_option._default
968 if not isinstance(recotrack_option, str):
969 recotrack_option = recotrack_option._default
970 if 'rec' not in random_seed:
971 random_seed += '_rec'
972 if 'DATA' in random_seed:
973 return 'qe_records_DATA_rec.root'
974 else:
975 if 'USESIMBB' in random_seed:
976 random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
977 elif 'USESIMEE' in random_seed:
978 random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
979 return 'qe_records_N' + str(n_events) + '_' + random_seed + '_' + recotrack_option + '.root'
980
981 def get_input_files(self, n_events=None, random_seed=None):
982 """
983 Get input file names depending on the use case: existing simulation files are
984 looked up in ``datafiles/``, real data is taken from ``MasterTask.datafiles``, and
985 files created in the same run are taken from the task that produced them.
986 """
987 if n_events is None:
988 n_events = self.n_events
989 if random_seed is None:
990 random_seed = self.random_seed
991 if "USESIM" in random_seed:
992 if 'USESIMBB' in random_seed:
993 random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
994 elif 'USESIMEE' in random_seed:
995 random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
996 return ['datafiles/' + GenerateSimTask.output_file_name(GenerateSimTask,
997 n_events=n_events, random_seed=random_seed)]
998 elif "DATA" in random_seed:
999 return MasterTask.datafiles
1000 else:
1001 return self.get_input_file_names(GenerateSimTask.output_file_name(
1002 GenerateSimTask, n_events=n_events, random_seed=random_seed))
1003
1004 def requires(self):
1005 """
1006 Generate list of luigi Tasks that this Task depends on.
1007 """
1008 if "USESIM" in self.random_seed or "DATA" in self.random_seed:
1009 for filename in self.get_input_files():
1010 yield CheckExistingFile(
1011 filename=filename,
1012 )
1013 else:
1014 yield SplitNMergeSimTask(
1015 bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
1016 random_seed=self.random_seed,
1017 n_events=self.n_events,
1018 experiment_number=self.experiment_number,
1019 )
1020 if "DATA" not in self.random_seed:
1021 if 'useCDC' not in self.recotrack_option and 'noCDC' not in self.recotrack_option:
1022 yield CDCQETeacherTask(
1023 n_events_training=MasterTask.n_events_training,
1024 experiment_number=self.experiment_number,
1025 training_target=self.cdc_training_target,
1026 process_type=self.random_seed.split("_", 1)[0],
1027 exclude_variables=MasterTask.exclude_variables_cdc,
1028 fast_bdt_option=self.fast_bdt_option,
1029 )
1030 if 'useVXD' not in self.recotrack_option and 'noVXD' not in self.recotrack_option:
1031 yield VXDQETeacherTask(
1032 n_events_training=MasterTask.n_events_training,
1033 experiment_number=self.experiment_number,
1034 process_type=self.random_seed.split("_", 1)[0],
1035 exclude_variables=MasterTask.exclude_variables_vxd,
1036 fast_bdt_option=self.fast_bdt_option,
1037 )
1038
1039 def output(self):
1040 """
1041 Generate list of output files that the task should produce.
1042 The task is considered finished if and only if the outputs all exist.
1043 """
1044 yield self.add_to_output(self.get_records_file_name())
1045
1046 def create_path(self):
1047 """
1048 Create basf2 reconstruction path that should mirror the default path
1049 from ``add_tracking_reconstruction()``, but with modules for the VXD QE
1050 and CDC QE application and for collection of variables for the reco
1051 track quality estimator.
1052 """
1053 path = basf2.create_path()
1054 inputFileNames = self.get_input_files()
1055 path.add_module(
1056 "RootInput",
1057 inputFileNames=inputFileNames,
1058 )
1059 path.add_module("Gearbox")
1060
1061 # First add tracking reconstruction with default quality estimation modules
1062 mvaCDC = True
1063 mvaVXD = True
1064 if 'noCDC' in self.recotrack_option:
1065 mvaCDC = False
1066 if 'noVXD' in self.recotrack_option:
1067 mvaVXD = False
1068 if 'DATA' in self.random_seed:
1069 from rawdata import add_unpackers
1070 add_unpackers(path)
1071 tracking.add_tracking_reconstruction(path, add_cdcTrack_QI=mvaCDC, add_vxdTrack_QI=mvaVXD, add_recoTrack_QI=True)
1072
1073 cdc_identifier = ""
1074 # If data is processed, check whether newly trained MVA weightfiles are available; otherwise use the defaults (CDB payloads).
1075 # If useCDC/useVXD is specified, use the weightfile located in datafiles/. Otherwise, replace the default weightfile
1076 # identifiers (CDB payloads) with the new weightfiles created by this b2luigi script.
1077 if ('DATA' in self.random_seed or 'useCDC' in self.recotrack_option) and 'noCDC' not in self.recotrack_option:
1078 cdc_identifier = 'datafiles/' + \
1079 CDCQETeacherTask.get_weightfile_xml_identifier(CDCQETeacherTask, fast_bdt_option=self.fast_bdt_option)
1080 if os.path.exists(cdc_identifier):
1081 replace_cdc_qi = True
1082 elif 'useCDC' in self.recotrack_option:
1083 raise ValueError(f"CDC QI Identifier not found: {cdc_identifier}")
1084 else:
1085 replace_cdc_qi = False
1086 elif 'noCDC' in self.recotrack_option:
1087 replace_cdc_qi = False
1088 else:
1089 cdc_identifier = self.get_input_file_names(
1090 CDCQETeacherTask.get_weightfile_xml_identifier(
1091 CDCQETeacherTask, fast_bdt_option=self.fast_bdt_option))[0]
1092 replace_cdc_qi = True
1093 if ('DATA' in self.random_seed or 'useVXD' in self.recotrack_option) and 'noVXD' not in self.recotrack_option:
1094 vxd_identifier = 'datafiles/' + \
1095 VXDQETeacherTask.get_weightfile_xml_identifier(VXDQETeacherTask, fast_bdt_option=self.fast_bdt_option)
1096 if os.path.exists(vxd_identifier):
1097 replace_vxd_qi = True
1098 elif 'useVXD' in self.recotrack_option:
1099 raise ValueError(f"VXD QI Identifier not found: {vxd_identifier}")
1100 else:
1101 replace_vxd_qi = False
1102 elif 'noVXD' in self.recotrack_option:
1103 replace_vxd_qi = False
1104 else:
1105 vxd_identifier = self.get_input_file_names(
1106 VXDQETeacherTask.get_weightfile_xml_identifier(
1107 VXDQETeacherTask, fast_bdt_option=self.fast_bdt_option))[0]
1108 replace_vxd_qi = True
1109
1110 cdc_qe_mva_filter_parameters = None
1111 # If tracks below a certain CDC QI value shall be deleted online, this needs to be specified in the filter parameters.
1112 # This is also possible with the default (CDB) payloads.
1113 cut = 0
1114 if 'deleteCDCQI' in self.recotrack_option:
1115 cut_index = self.recotrack_option.find('deleteCDCQI') + len('deleteCDCQI')
1116 cut = int(self.recotrack_option[cut_index:cut_index+3])/100.
1117 if replace_cdc_qi:
1118 cdc_qe_mva_filter_parameters = {
1119 "identifier": cdc_identifier, "cut": cut}
1120 else:
1121 cdc_qe_mva_filter_parameters = {
1122 "cut": cut}
1123 elif replace_cdc_qi:
1124 cdc_qe_mva_filter_parameters = {
1125 "identifier": cdc_identifier}
1126 # change weightfile of quality estimator to the one produced by this training script
1127 basf2.conditions.prepend_testing_payloads("localdb/database.txt")
1128
1129 if cdc_qe_mva_filter_parameters is not None and cdc_identifier is not None:
1130 name = 'TrackingMVAFilterParameters'
1132 dbobj_name=name,
1133 iovList=(0, 0, 0, -1),
1134 weightfile_identifier=cdc_identifier,
1135 cut_value=cut)
1136 cdc_qe_mva_filter_parameters = {'DBPayloadName': name}
1137 if replace_vxd_qi and vxd_identifier is not None:
1138 vxd_name = 'VXDQualityEstimatorMVAWeightFileIdentifier'
1139 with open(vxd_identifier) as f:
1140 weight_file_content = f.read()
1141 vxd_name = write_mva_weightfile_content_to_db(vxd_name, weight_file_content, (0, 0, 0, -1))
1142 if cdc_qe_mva_filter_parameters is not None:
1143 # if no cut is specified, the default value is at zero and nothing is deleted.
1144 basf2.set_module_parameters(
1145 path,
1146 name="TFCDC_TrackQualityEstimator",
1147 filterParameters=cdc_qe_mva_filter_parameters,
1148 deleteTracks=True,
1149 resetTakenFlag=True,
1150 deactivateIfDeadBoard=False, # original behavior before deactivateIfDeadBoard was introduced
1151 )
1152 if replace_vxd_qi:
1153 basf2.set_module_parameters(
1154 path,
1155 name="VXDQualityEstimatorMVA",
1156 WeightFileIdentifier=vxd_identifier)
1157
1158 # Replace final quality estimator module by training data collector module
1159 track_qe_module_name = "TrackQualityEstimatorMVA"
1160 mc_track_matcher_module_name = "MCRecoTracksMatcher"
1161 mc_matcher_module_found = False
1162 qe_module_found = False
1163 new_path = basf2.create_path()
1164 for module in path.modules():
1165 if module.name() == track_qe_module_name:
1166 # the TrackCreator needs to run before the Collector so that
1167 # MDSTTracks are related to RecoTracks and d0 and z0 can be read out
1168 new_path.add_module(
1169 'TrackCreator',
1170 pdgCodes=[
1171 211,
1172 321,
1173 2212],
1174 recoTrackColName='RecoTracks',
1175 trackColName='MDSTTracks') # , useClosestHitToIP=True, useBFieldAtHit=True)
1176 qe_module_found = True
1177 elif module.name() == mc_track_matcher_module_name:
1178 new_path.add_module(module)
1179 # move TrackQETrainingDataCollector module after the MCRecoTracksMatcher module
1180 new_path.add_module(
1181 "TrackQETrainingDataCollector",
1182 TrainingDataOutputName=self.get_output_file_name(self.get_records_file_name()),
1183 collectEventFeatures=True
1184 )
1185 mc_matcher_module_found = True
1186 else:
1187 new_path.add_module(module)
1188 if not qe_module_found:
1189 raise KeyError(f"No module {track_qe_module_name} found in path")
1190 if not mc_matcher_module_found:
1191 raise KeyError(f"No module {mc_track_matcher_module_name} found in path")
1192 path = new_path
1193 return path
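
# Illustrative sketch (not part of the original steering file): how a
# ``recotrack_option`` such as 'deleteCDCQI080' is turned into a CDC QI cut in
# ``create_path()`` above. The helper name ``parse_cdc_qi_cut`` is an assumption
# for this example; the slicing and the division by 100 follow the listing.

```python
def parse_cdc_qi_cut(recotrack_option: str) -> float:
    """Return the CDC QI cut encoded in the option string (hypothetical helper)."""
    key = 'deleteCDCQI'
    # Without the keyword no cut is requested; the listing uses 0 as default,
    # which means that no track is deleted.
    if key not in recotrack_option:
        return 0.0
    # The three characters after the keyword encode the cut in percent.
    cut_index = recotrack_option.find(key) + len(key)
    return int(recotrack_option[cut_index:cut_index + 3]) / 100.


if __name__ == "__main__":
    print(parse_cdc_qi_cut('deleteCDCQI080'))
```

# The default option 'deleteCDCQI080' therefore corresponds to a cut value of
# 0.8 on the CDC quality indicator.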
1194
1195
1196class TrackQETeacherBaseTask(Basf2Task):
1197 """
1198 A teacher task runs the basf2 mva teacher on the training data provided by a
1199 data collection task.
1200
1201 Since teacher tasks are needed for all quality estimators covered by this
1202 steering file and the only thing that changes is the required data
1203 collection task and some training parameters, I decided to use inheritance
1204 and have the basic functionality in this base class/interface and have the
1205 specific teacher tasks inherit from it.
1206 """
1207
1208 n_events_training = b2luigi.IntParameter()
1209
1210 experiment_number = b2luigi.IntParameter()
1211
1214 process_type = b2luigi.Parameter(
1215
1216 default="BBBAR"
1217
1218 )
1219
1220 training_target = b2luigi.Parameter(
1221
1222 default="truth"
1223
1224 )
1225
1227 exclude_variables = b2luigi.ListParameter(
1228
1229 hashed=True, default=[]
1230
1231 )
1232
1233 fast_bdt_option = b2luigi.ListParameter(
1234
1235 hashed=True, default=[200, 8, 3, 0.1]
1236
1237 )
1238
1239 @property
1240 def weightfile_identifier_basename(self):
1241 """
1242 Property defining the basename for the .xml and .root weightfiles that are created.
1243 Has to be implemented by the inheriting teacher task class.
1244 """
1245 raise NotImplementedError(
1246 "Teacher Task must define a static weightfile_identifier_basename"
1247 )
1248
1249 def get_weightfile_xml_identifier(self, fast_bdt_option=None, recotrack_option=None):
1250 """
1251 Name of the xml weightfile that is created by the teacher task.
1252 It is subsequently used as a local weightfile in the following validation tasks.
1253 """
1254 if fast_bdt_option is None:
1255 fast_bdt_option = self.fast_bdt_option
1256 if recotrack_option is None and hasattr(self, 'recotrack_option'):
1257 if isinstance(self.recotrack_option, str):
1258 recotrack_option = self.recotrack_option
1259 else:
1260 recotrack_option = self.recotrack_option._default
1261 else:
1262 recotrack_option = ''
1263 weightfile_details = create_fbdt_option_string(fast_bdt_option)
1264 weightfile_name = self.weightfile_identifier_basename + weightfile_details
1265 if recotrack_option != '':
1266 weightfile_name = weightfile_name + '_' + recotrack_option
1267 return weightfile_name + "_weights.xml"
1268
1269 @property
1270 def tree_name(self):
1271 """
1272 Property defining the name of the tree in the ROOT file from the
1273 ``data_collection_task`` that contains the recorded training data. Must be
1274 implemented by the inheriting specific teacher task class.
1275 """
1276 raise NotImplementedError("Teacher Task must define a static tree_name")
1277
1278 @property
1279 def random_seed(self):
1280 """
1281 Property defining random seed to be used by the ``GenerateSimTask``.
1282 Should differ from the random seed in the test data samples. Must be
1283 implemented by the inheriting specific teacher task class.
1284 """
1285 raise NotImplementedError("Teacher Task must define a static random seed")
1286
1287 @property
1288 def data_collection_task(self) -> Basf2PathTask:
1289 """
1290 Property defining the specific ``DataCollectionTask`` to require. Must be
1291 implemented by the inheriting specific teacher task class.
1292 """
1293 raise NotImplementedError(
1294 "Teacher Task must define a data collection task to require"
1295 )
1296
1297 def requires(self):
1298 """
1299 Generate list of luigi Tasks that this Task depends on.
1300 """
1301 if 'USEREC' in self.process_type:
1302 if 'USERECBB' in self.process_type:
1303 process = 'BBBAR'
1304 elif 'USERECEE' in self.process_type:
1305 process = 'BHABHA'
1306 yield CheckExistingFile(
1307 filename='datafiles/qe_records_N' + str(self.n_events_training) + '_' + process + '_' + self.random_seed + '.root',
1308 )
1309 else:
1310 yield self.data_collection_task(
1311 num_processes=MasterTask.num_processes,
1312 n_events=self.n_events_training,
1313 experiment_number=self.experiment_number,
1314 random_seed=self.process_type + '_' + self.random_seed,
1315 )
1316
1317 def output(self):
1318 """
1319 Generate list of output files that the task should produce.
1320 The task is considered finished if and only if the outputs all exist.
1321 """
1322 yield self.add_to_output(self.get_weightfile_xml_identifier())
1323
1324 def process(self):
1325 """
1326 Use basf2_mva teacher to create MVA weightfile from collected training
1327 data variables.
1328
1329 This is the main process that is dispatched by the ``run`` method that
1330 is inherited from ``Basf2Task``.
1331 """
1332 if 'USEREC' in self.process_type:
1333 if 'USERECBB' in self.process_type:
1334 process = 'BBBAR'
1335 elif 'USERECEE' in self.process_type:
1336 process = 'BHABHA'
1337 records_files = ['datafiles/qe_records_N' + str(self.n_events_training) +
1338 '_' + process + '_' + self.random_seed + '.root']
1339 else:
1340 if hasattr(self, 'recotrack_option'):
1341 records_files = self.get_input_file_names(
1342 self.data_collection_task.get_records_file_name(
1343 self.data_collection_task,
1344 n_events=self.n_events_training,
1345 random_seed=self.process_type + '_' + self.random_seed,
1346 recotrack_option=self.recotrack_option))
1347 else:
1348 records_files = self.get_input_file_names(
1349 self.data_collection_task.get_records_file_name(
1350 self.data_collection_task,
1351 n_events=self.n_events_training,
1352 random_seed=self.process_type + '_' + self.random_seed))
1353
1354 my_basf2_mva_teacher(
1355 records_files=records_files,
1356 tree_name=self.tree_name,
1357 weightfile_identifier=self.get_output_file_name(self.get_weightfile_xml_identifier()),
1358 target_variable=self.training_target,
1359 exclude_variables=self.exclude_variables,
1360 fast_bdt_option=self.fast_bdt_option,
1361 )
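
# Illustrative sketch (not part of the original steering file): how the xml
# weightfile name is assembled from the basename, the FastBDT options and the
# optional ``recotrack_option``. ``fbdt_option_string`` is a HYPOTHETICAL
# stand-in for ``create_fbdt_option_string()``, whose real output format is not
# shown in this part of the listing and may differ. Unlike the method above,
# this simplified version always honours an explicitly passed option.

```python
def fbdt_option_string(fast_bdt_option):
    """HYPOTHETICAL stand-in: join the FastBDT options with underscores."""
    return "_" + "_".join(str(value) for value in fast_bdt_option)


def weightfile_xml_identifier(basename, fast_bdt_option, recotrack_option=''):
    """Build '<basename><fbdt details>[_<recotrack_option>]_weights.xml'."""
    weightfile_name = basename + fbdt_option_string(fast_bdt_option)
    # Only the final (reco track) estimator carries a recotrack option.
    if recotrack_option != '':
        weightfile_name = weightfile_name + '_' + recotrack_option
    return weightfile_name + "_weights.xml"


if __name__ == "__main__":
    print(weightfile_xml_identifier("cdc_mva_qe", [200, 8, 3, 0.1]))
```

# Encoding the training options in the file name lets weightfiles from
# trainings with different FastBDT settings coexist in one output directory.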
1362
1363
1364class VXDQETeacherTask(TrackQETeacherBaseTask):
1365 """
1366 Task to run basf2 mva teacher on collected data for VXDTF2 track quality estimator
1367 """
1368
1369 weightfile_identifier_basename = "vxdtf2_mva_qe"
1370
1372 tree_name = "tree"
1373
1374 random_seed = "train_vxd"
1375
1377 data_collection_task = VXDQEDataCollectionTask
1378
1379
1380class CDCQETeacherTask(TrackQETeacherBaseTask):
1381 """
1382 Task to run basf2 mva teacher on collected data for CDC track quality estimator
1383 """
1384
1385 weightfile_identifier_basename = "cdc_mva_qe"
1386
1388 tree_name = "records"
1389
1390 random_seed = "train_cdc"
1391
1393 data_collection_task = CDCQEDataCollectionTask
1394
1395
1396class RecoTrackQETeacherTask(TrackQETeacherBaseTask):
1397 """
1398 Task to run basf2 mva teacher on collected data for the final, combined
1399 track quality estimator
1400 """
1401
1404 recotrack_option = b2luigi.Parameter(
1405
1406 default='deleteCDCQI080'
1407
1408 )
1409
1410
1411 weightfile_identifier_basename = "recotrack_mva_qe"
1412
1414 tree_name = "tree"
1415
1416 random_seed = "train_rec"
1417
1419 data_collection_task = RecoTrackQEDataCollectionTask
1420
1421 cdc_training_target = b2luigi.Parameter()
1422
1423 def requires(self):
1424 """
1425 Generate list of luigi Tasks that this Task depends on.
1426 """
1427 if 'USEREC' in self.process_type:
1428 if 'USERECBB' in self.process_type:
1429 process = 'BBBAR'
1430 elif 'USERECEE' in self.process_type:
1431 process = 'BHABHA'
1432 yield CheckExistingFile(
1433 filename='datafiles/qe_records_N' + str(self.n_events_training) + '_' + process + '_' + self.random_seed + '.root',
1434 )
1435 else:
1436 yield self.data_collection_task(
1437 cdc_training_target=self.cdc_training_target,
1438 num_processes=MasterTask.num_processes,
1439 n_events=self.n_events_training,
1440 experiment_number=self.experiment_number,
1441 random_seed=self.process_type + '_' + self.random_seed,
1442 recotrack_option=self.recotrack_option,
1443 fast_bdt_option=self.fast_bdt_option,
1444 )
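
# Illustrative sketch (not part of the original steering file): the
# inheritance pattern used by the teacher tasks above, reduced to plain Python
# without b2luigi. The base class declares properties that raise
# ``NotImplementedError``; subclasses shadow them with plain class attributes.
# All class and method names below are assumptions made for this example.

```python
class TeacherBaseSketch:
    """Base class: every 'static' setting must be provided by a subclass."""

    @property
    def weightfile_identifier_basename(self):
        raise NotImplementedError("subclass must define weightfile_identifier_basename")

    @property
    def tree_name(self):
        raise NotImplementedError("subclass must define tree_name")

    def weightfile_name(self):
        # Works on instances and, for subclasses, also when the class itself
        # is passed as ``self``, because the attribute is a plain class attribute.
        return self.weightfile_identifier_basename + "_weights.xml"


class CDCTeacherSketch(TeacherBaseSketch):
    # Plain class attributes shadow the properties of the base class.
    weightfile_identifier_basename = "cdc_mva_qe"
    tree_name = "records"


if __name__ == "__main__":
    print(CDCTeacherSketch().tree_name)
    # Calling the method with the class as ``self`` mirrors calls such as
    # ``CDCQETeacherTask.get_weightfile_xml_identifier(CDCQETeacherTask, ...)``.
    print(CDCTeacherSketch.weightfile_name(CDCTeacherSketch))
```

# Because the subclass values are class attributes rather than instance state,
# other tasks can query file names from a teacher class without instantiating it.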
1445
1446
1447class HarvestingValidationBaseTask(Basf2PathTask):
1448 """
1449 Run track reconstruction with MVA quality estimator and write out
1450 (="harvest") a root file with variables useful for the validation.
1451 """
1452
1453
1454 n_events_testing = b2luigi.IntParameter()
1455
1456 n_events_training = b2luigi.IntParameter()
1457
1458 experiment_number = b2luigi.IntParameter()
1459
1462 process_type = b2luigi.Parameter(
1463
1464 default="BBBAR"
1465
1466 )
1467
1469 exclude_variables = b2luigi.ListParameter(
1470
1471 hashed=True
1472
1473 )
1474
1475 fast_bdt_option = b2luigi.ListParameter(
1476
1477 hashed=True, default=[200, 8, 3, 0.1]
1478
1479 )
1480
1481 validation_output_file_name = "harvesting_validation.root"
1482
1483 reco_output_file_name = "reconstruction.root"
1484
1485 components = None
1486
1487 @property
1488 def teacher_task(self) -> TrackQETeacherBaseTask:
1489 """
1490 Teacher task to require to provide a quality estimator weightfile for ``add_tracking_with_quality_estimation``
1491 """
1492 raise NotImplementedError()
1493
1494 def add_tracking_with_quality_estimation(self, path: basf2.Path) -> None:
1495 """
1496 Add modules for track reconstruction to basf2 path that are to be
1497 validated. Besides track finding it should include MC matching, fitted
1498 track creation and a quality estimator module.
1499 """
1500 raise NotImplementedError()
1501
1502 def requires(self):
1503 """
1504 Generate list of luigi Tasks that this Task depends on.
1505 """
1506 yield self.teacher_task(
1507 n_events_training=self.n_events_training,
1508 experiment_number=self.experiment_number,
1509 process_type=self.process_type,
1510 exclude_variables=self.exclude_variables,
1511 fast_bdt_option=self.fast_bdt_option,
1512 )
1513 if 'USE' in self.process_type: # USESIM and USEREC
1514 if 'BB' in self.process_type:
1515 process = 'BBBAR'
1516 elif 'EE' in self.process_type:
1517 process = 'BHABHA'
1518 yield CheckExistingFile(
1519 filename='datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root'
1520 )
1521 else:
1522 yield SplitNMergeSimTask(
1523 bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
1524 random_seed=self.process_type + '_test',
1525 n_events=self.n_events_testing,
1526 experiment_number=self.experiment_number,
1527 )
1528
1529 def output(self):
1530 """
1531 Generate list of output files that the task should produce.
1532 The task is considered finished if and only if the outputs all exist.
1533 """
1534 yield self.add_to_output(self.validation_output_file_name)
1535 yield self.add_to_output(self.reco_output_file_name)
1536
1537 def create_path(self):
1538 """
1539 Create a basf2 path that uses ``add_tracking_with_quality_estimation()``
1540 and adds the ``CombinedTrackingValidationModule`` to write out variables
1541 for validation.
1542 """
1543 # prepare track finding
1544 path = basf2.create_path()
1545 if 'USE' in self.process_type:
1546 if 'BB' in self.process_type:
1547 process = 'BBBAR'
1548 elif 'EE' in self.process_type:
1549 process = 'BHABHA'
1550 inputFileNames = ['datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root']
1551 else:
1552 inputFileNames = self.get_input_file_names(GenerateSimTask.output_file_name(
1553 GenerateSimTask, n_events=self.n_events_testing, random_seed=self.process_type + '_test'))
1554 path.add_module(
1555 "RootInput",
1556 inputFileNames=inputFileNames,
1557 )
1558 path.add_module("Gearbox")
1559 tracking.add_geometry_modules(path)
1560 tracking.add_hit_preparation_modules(path) # only needed for simulated hits
1561 # add track finding module that needs to be validated
1562 self.add_tracking_with_quality_estimation(path)
1563 # add modules for validation
1564 path.add_module(
1565 CombinedTrackingValidationModule(
1566 name=None,
1567 contact=None,
1568 expert_level=200,
1569 output_file_name=self.get_output_file_name(
1570 self.validation_output_file_name
1571 ),
1572 )
1573 )
1574 path.add_module(
1575 "RootOutput",
1576 outputFileName=self.get_output_file_name(self.reco_output_file_name),
1577 )
1578 return path
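
# Illustrative sketch (not part of the original steering file): how the
# validation tasks above pick a pre-existing test sample from ``process_type``.
# The helper name ``existing_test_sample`` and the explicit ``ValueError`` are
# assumptions for this example; in the listing, ``process`` simply stays
# undefined when a 'USE...' type contains neither 'BB' nor 'EE'.

```python
def existing_test_sample(process_type: str, n_events_testing: int):
    """Return the expected test file in datafiles/, or None if it is simulated."""
    # Only USESIM... and USEREC... process types refer to existing files.
    if 'USE' not in process_type:
        return None
    if 'BB' in process_type:
        process = 'BBBAR'
    elif 'EE' in process_type:
        process = 'BHABHA'
    else:
        # Assumption: fail early instead of hitting an unbound variable later.
        raise ValueError(f"Cannot derive process from {process_type!r}")
    return 'datafiles/generated_mc_N' + str(n_events_testing) + '_' + process + '_test.root'


if __name__ == "__main__":
    print(existing_test_sample('USESIMBB', 1000))
```

# A return value of ``None`` corresponds to the branch in which the test sample
# is produced within the same run by ``SplitNMergeSimTask``.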
1579
1580
1581class VXDQEHarvestingValidationTask(HarvestingValidationBaseTask):
1582 """
1583 Run VXDTF2 track reconstruction and write out (="harvest") a root file with
1584 variables useful for validation of the VXD Quality Estimator.
1585 """
1586
1587
1588 validation_output_file_name = "vxd_qe_harvesting_validation.root"
1589
1590 reco_output_file_name = "vxd_qe_reconstruction.root"
1591
1592 teacher_task = VXDQETeacherTask
1593
1594 def add_tracking_with_quality_estimation(self, path):
1595 """
1596 Add modules for VXDTF2 tracking with VXD quality estimator to basf2 path.
1597 """
1598 tracking.add_vxd_track_finding_vxdtf2(
1599 path,
1600 components=["SVD"],
1601 reco_tracks="RecoTracks",
1602 add_mva_quality_indicator=True,
1603 )
1604 # Replace the weightfile of the quality estimator module with the one
1605 # produced in this training by b2luigi
1606 vxd_identifier = self.get_input_file_names(
1607 self.teacher_task.get_weightfile_xml_identifier(self.teacher_task, fast_bdt_option=self.fast_bdt_option)
1608 )[0]
1609 with open(vxd_identifier) as f:
1610 weight_file_content = f.read()
1611 vxd_name = write_mva_weightfile_content_to_db(
1612 dbobj_name='VXDQualityEstimatorMVAWeightFileIdentifier',
1613 content=weight_file_content,
1614 iovList=(0, 0, 0, -1)
1615 )
1616 basf2.set_module_parameters(
1617 path,
1618 name="VXDQualityEstimatorMVA",
1619 WeightFileIdentifier=vxd_name,
1620 )
1621 tracking.add_mc_matcher(path, components=["SVD"])
1622 tracking.add_track_fit_and_track_creator(path, components=["SVD"])
1623
1624
1625class CDCQEHarvestingValidationTask(HarvestingValidationBaseTask):
1626 """
1627 Run CDC reconstruction and write out (="harvest") a root file with variables
1628 useful for validation of the CDC Quality Estimator.
1629 """
1630
1631 training_target = b2luigi.Parameter()
1632
1633 validation_output_file_name = "cdc_qe_harvesting_validation.root"
1634
1635 reco_output_file_name = "cdc_qe_reconstruction.root"
1636
1637 teacher_task = CDCQETeacherTask
1638
1639 # override needed due to specific training target
1640 def requires(self):
1641 """
1642 Generate list of luigi Tasks that this Task depends on.
1643 """
1644 yield self.teacher_task(
1645 n_events_training=self.n_events_training,
1646 experiment_number=self.experiment_number,
1647 process_type=self.process_type,
1648 training_target=self.training_target,
1649 exclude_variables=self.exclude_variables,
1650 fast_bdt_option=self.fast_bdt_option,
1651 )
1652 if 'USE' in self.process_type: # USESIM and USEREC
1653 if 'BB' in self.process_type:
1654 process = 'BBBAR'
1655 elif 'EE' in self.process_type:
1656 process = 'BHABHA'
1657 yield CheckExistingFile(
1658 filename='datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root'
1659 )
1660 else:
1661 yield SplitNMergeSimTask(
1662 bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
1663 random_seed=self.process_type + '_test',
1664 n_events=self.n_events_testing,
1665 experiment_number=self.experiment_number,
1666 )
1667
1668 def add_tracking_with_quality_estimation(self, path):
1669 """
1670 Add modules for CDC standalone tracking with CDC quality estimator to basf2 path.
1671 """
1672 tracking.add_cdc_track_finding(
1673 path,
1674 output_reco_tracks="RecoTracks",
1675 add_mva_quality_indicator=True,
1676 )
1677 # change weightfile of quality estimator to the one produced by this training script
1678 basf2.conditions.prepend_testing_payloads("localdb/database.txt")
1679 cdc_qe_mva_filter_parameters = {
1680 "identifier": self.get_input_file_names(
1681 CDCQETeacherTask.get_weightfile_xml_identifier(
1682 CDCQETeacherTask,
1683 fast_bdt_option=self.fast_bdt_option))[0]}
1684
1685 # fbdt_string = create_fbdt_option_string(self.fast_bdt_option)
1686 name = 'TrackingMVAFilterParameters'
1688 dbobj_name=name,
1689 iovList=(0, 0, 0, -1),
1690 weightfile_identifier=self.get_input_file_names(
1691 CDCQETeacherTask.get_weightfile_xml_identifier(
1692 CDCQETeacherTask,
1693 fast_bdt_option=self.fast_bdt_option))[0],
1694 cut_value=0)
1695 cdc_qe_mva_filter_parameters = {'DBPayloadName': name} # 'identifier': 'trackfindingcdc_TrackQualityIndicator',
1696 basf2.set_module_parameters(
1697 path,
1698 name="TFCDC_TrackQualityEstimator",
1699 filterParameters=cdc_qe_mva_filter_parameters,
1700 deactivateIfDeadBoard=False, # original behavior before deactivateIfDeadBoard was introduced
1701 )
1702 tracking.add_track_fit_and_track_creator(path, components=["CDC"])
1703 tracking.add_mc_matcher(path, components=["CDC"])
1704
1705
1706class RecoTrackQEHarvestingValidationTask(HarvestingValidationBaseTask):
1707 """
1708 Run track reconstruction and write out (="harvest") a root file with variables
1709 useful for validation of the MVA track Quality Estimator.
1710 """
1711
1712 cdc_training_target = b2luigi.Parameter()
1713
1714 validation_output_file_name = "reco_qe_harvesting_validation.root"
1715
1716 reco_output_file_name = "reco_qe_reconstruction.root"
1717
1718 teacher_task = RecoTrackQETeacherTask
1719
1720 def requires(self):
1721 """
1722 Generate list of luigi Tasks that this Task depends on.
1723 """
1724 yield CDCQETeacherTask(
1725 n_events_training=self.n_events_training,
1726 experiment_number=self.experiment_number,
1727 process_type=self.process_type,
1728 training_target=self.cdc_training_target,
1729 exclude_variables=MasterTask.exclude_variables_cdc,
1730 fast_bdt_option=self.fast_bdt_option,
1731 )
1732 yield VXDQETeacherTask(
1733 n_events_training=self.n_events_training,
1734 experiment_number=self.experiment_number,
1735 process_type=self.process_type,
1736 exclude_variables=MasterTask.exclude_variables_vxd,
1737 fast_bdt_option=self.fast_bdt_option,
1738 )
1739
1740 yield self.teacher_task(
1741 n_events_training=self.n_events_training,
1742 experiment_number=self.experiment_number,
1743 process_type=self.process_type,
1744 exclude_variables=self.exclude_variables,
1745 cdc_training_target=self.cdc_training_target,
1746 fast_bdt_option=self.fast_bdt_option,
1747 )
1748 if 'USE' in self.process_type: # USESIM and USEREC
1749 if 'BB' in self.process_type:
1750 process = 'BBBAR'
1751 elif 'EE' in self.process_type:
1752 process = 'BHABHA'
1753 yield CheckExistingFile(
1754 filename='datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root'
1755 )
1756 else:
1757 yield SplitNMergeSimTask(
1758 bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
1759 random_seed=self.process_type + '_test',
1760 n_events=self.n_events_testing,
1761 experiment_number=self.experiment_number,
1762 )
1763
1764 def add_tracking_with_quality_estimation(self, path):
1765 """
1766 Add modules for reco tracking with all track quality estimators to basf2 path.
1767 """
1768
1769 # add tracking reconstruction with quality estimator modules added
1770 tracking.add_tracking_reconstruction(
1771 path,
1772 add_cdcTrack_QI=True,
1773 add_vxdTrack_QI=True,
1774 add_recoTrack_QI=True,
1775 skipGeometryAdding=True,
1776 skipHitPreparerAdding=True,
1777 )
1778
1779 # Replace the weightfiles of all quality estimator modules by those
1780 # produced in the training by b2luigi
1781 basf2.conditions.prepend_testing_payloads("localdb/database.txt")
1782 cdc_qe_mva_filter_parameters = {
1783 "identifier": self.get_input_file_names(
1784 CDCQETeacherTask.get_weightfile_xml_identifier(
1785 CDCQETeacherTask,
1786 fast_bdt_option=self.fast_bdt_option))[0]}
1787
1788 # fbdt_string = create_fbdt_option_string(self.fast_bdt_option)
1789 name = 'TrackingMVAFilterParameters'
1791 dbobj_name=name,
1792 iovList=(0, 0, 0, -1),
1793 weightfile_identifier=self.get_input_file_names(
1794 CDCQETeacherTask.get_weightfile_xml_identifier(
1795 CDCQETeacherTask,
1796 fast_bdt_option=self.fast_bdt_option))[0],
1797 cut_value=0)
1798 cdc_qe_mva_filter_parameters = {'DBPayloadName': name}
1799 basf2.set_module_parameters(
1800 path,
1801 name="TFCDC_TrackQualityEstimator",
1802 filterParameters=cdc_qe_mva_filter_parameters,
1803 deactivateIfDeadBoard=False, # original behavior before deactivateIfDeadBoard was introduced
1804 )
1805 vxd_identifier = self.get_input_file_names(
1806 VXDQETeacherTask.get_weightfile_xml_identifier(VXDQETeacherTask, fast_bdt_option=self.fast_bdt_option)
1807 )[0]
1808 with open(vxd_identifier) as f:
1809 weight_file_content = f.read()
1810 vxd_name = write_mva_weightfile_content_to_db(
1811 dbobj_name='VXDQualityEstimatorMVAWeightFileIdentifier',
1812 content=weight_file_content,
1813 iovList=(0, 0, 0, -1)
1814 )
1815 basf2.set_module_parameters(
1816 path,
1817 name="VXDQualityEstimatorMVA",
1818 WeightFileIdentifier=vxd_name,
1819 )
1820 recotrack_identifier = self.get_input_file_names(
1821 RecoTrackQETeacherTask.get_weightfile_xml_identifier(RecoTrackQETeacherTask, fast_bdt_option=self.fast_bdt_option)
1822 )[0]
1823 with open(recotrack_identifier) as f:
1824 weight_file_content = f.read()
1825 recotrack_name = write_mva_weightfile_content_to_db(
1826 dbobj_name='RecoTrackQualityEstimatorMVAWeightFileIdentifier',
1827 content=weight_file_content,
1828 iovList=(0, 0, 0, -1)
1829 )
1830 basf2.set_module_parameters(
1831 path,
1832 name="TrackQualityEstimatorMVA",
1833 WeightFileIdentifier=recotrack_name,
1834 )
1835
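# --- Illustrative sketch (not part of the original steering file) ---
# The ``requires`` method above derives the sample name from ``process_type``
# with an if/elif chain. If ``process_type`` contains 'USE' but neither 'BB'
# nor 'EE', ``process`` is never assigned and Python raises an
# UnboundLocalError. The hypothetical helper ``resolve_test_sample_path``
# below shows one way to make that case explicit. Its name and the ValueError
# are assumptions; only the file name pattern is taken from the code above.
def resolve_test_sample_path(process_type, n_events_testing):
    """Return the expected path of a pre-generated test sample."""
    if 'BB' in process_type:
        process = 'BBBAR'
    elif 'EE' in process_type:
        process = 'BHABHA'
    else:
        raise ValueError(f"Cannot derive a sample name from process_type={process_type!r}")
    return f"datafiles/generated_mc_N{n_events_testing}_{process}_test.root"


# Usage (the process type string is only an example value):
# resolve_test_sample_path('USESIMBB', 5000)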
1836
1837class TrackQEEvaluationBaseTask(Task):
1838 """
1839 Base class for evaluating a quality estimator with ``basf2_mva_evaluate.py`` on a
1840 separate test data set.
1841
1842 Evaluation tasks for VXD, CDC and combined QE can inherit from it.
1843 """
1844
1845
1850 git_hash = b2luigi.Parameter(
1851
1852 default=get_basf2_git_hash()
1853
1854 )
1855
1856 n_events_testing = b2luigi.IntParameter()
1857
1858 n_events_training = b2luigi.IntParameter()
1859
1860 experiment_number = b2luigi.IntParameter()
1861
1864 process_type = b2luigi.Parameter(
1865
1866 default="BBBAR"
1867
1868 )
1869
1870 training_target = b2luigi.Parameter(
1871
1872 default="truth"
1873
1874 )
1875
1877 exclude_variables = b2luigi.ListParameter(
1878
1879 hashed=True
1880
1881 )
1882
1883 fast_bdt_option = b2luigi.ListParameter(
1884
1885 hashed=True, default=[200, 8, 3, 0.1]
1886
1887 )
1888
1889 @property
1890 def teacher_task(self) -> TrackQETeacherBaseTask:
1891 """
1892 Property defining specific teacher task to require.
1893 """
1894 raise NotImplementedError(
1895 "Evaluation Tasks must define a teacher task to require."
1896 )
1897
1898 @property
1899 def data_collection_task(self) -> Basf2PathTask:
1900 """
1901 Property defining the specific ``DataCollectionTask`` to require. Must
1902 be implemented by the inheriting evaluation task class.
1903 """
1904 raise NotImplementedError(
1905 "Evaluation Tasks must define a data collection task to require."
1906 )
1907
1908 @property
1909 def task_acronym(self):
1910 """
1911 Acronym to distinguish between cdc, vxd and rec(o) MVA
1912 """
1913 raise NotImplementedError(
1914 "Evaluation Tasks must define a task acronym."
1915 )
1916
1917 def requires(self):
1918 """
1919 Generate list of luigi Tasks that this Task depends on.
1920 """
1921 yield self.teacher_task(
1922 n_events_training=self.n_events_training,
1923 experiment_number=self.experiment_number,
1924 process_type=self.process_type,
1925 training_target=self.training_target,
1926 exclude_variables=self.exclude_variables,
1927 fast_bdt_option=self.fast_bdt_option,
1928 )
1929 if 'USEREC' in self.process_type:
1930 if 'USERECBB' in self.process_type:
1931 process = 'BBBAR'
1932 elif 'USERECEE' in self.process_type:
1933 process = 'BHABHA'
1934 yield CheckExistingFile(
1935 filename='datafiles/qe_records_N' + str(self.n_events_testing) + '_' + process + '_test_' +
1936 self.task_acronym + '.root'
1937 )
1938 else:
1939 yield self.data_collection_task(
1940 num_processes=MasterTask.num_processes,
1941 n_events=self.n_events_testing,
1942 experiment_number=self.experiment_number,
1943 random_seed=self.process_type + '_test',
1944 )
1945
1946 def output(self):
1947 """
1948 Generate list of output files that the task should produce.
1949 The task is considered finished if and only if the outputs all exist.
1950 """
1951 weightfile_details = create_fbdt_option_string(self.fast_bdt_option)
1952 evaluation_pdf_output = self.teacher_task.weightfile_identifier_basename + weightfile_details + ".zip"
1953 yield self.add_to_output(evaluation_pdf_output)
1954
1955 @b2luigi.on_temporary_files
1956 def run(self):
1957 """
1958 Run ``basf2_mva_evaluate.py`` subprocess to evaluate QE MVA.
1959
1960 The MVA weight file created from training on the training data set is
1961 evaluated on separate test data.
1962 """
1963 weightfile_details = create_fbdt_option_string(self.fast_bdt_option)
1964 evaluation_pdf_output_basename = self.teacher_task.weightfile_identifier_basename + weightfile_details + ".zip"
1965
1966 evaluation_pdf_output_path = self.get_output_file_name(evaluation_pdf_output_basename)
1967
1968 if 'USEREC' in self.process_type:
1969 if 'USERECBB' in self.process_type:
1970 process = 'BBBAR'
1971 elif 'USERECEE' in self.process_type:
1972 process = 'BHABHA'
1973 datafiles = 'datafiles/qe_records_N' + str(self.n_events_testing) + '_' + \
1974 process + '_test_' + self.task_acronym + '.root'
1975 else:
1976 datafiles = self.get_input_file_names(
1977 self.data_collection_task.get_records_file_name(
1978 self.data_collection_task,
1979 n_events=self.n_events_testing,
1980 random_seed=self.process_type + '_test_' +
1981 self.task_acronym))[0]
1982 teacher_task = None
1983 for req in b2luigi.task.flatten(self.requires()):
1984 if isinstance(req, self.teacher_task):
1985 teacher_task = req
1986 break
1987 if hasattr(teacher_task, 'recotrack_option'):
1988 records_files = teacher_task.get_input_file_names(
1989 self.data_collection_task.get_records_file_name(
1990 self.data_collection_task,
1991 n_events=self.n_events_training,
1992 random_seed=self.process_type + '_' + teacher_task.random_seed,
1993 recotrack_option=teacher_task.recotrack_option))
1994 else:
1995 records_files = teacher_task.get_input_file_names(
1996 self.data_collection_task.get_records_file_name(
1997 self.data_collection_task,
1998 n_events=self.n_events_training,
1999 random_seed=self.process_type + '_' + teacher_task.random_seed))
2000 cmd = [
2001 "basf2_mva_evaluate.py",
2002 "--identifiers",
2003 self.get_input_file_names(
2004 self.teacher_task.get_weightfile_xml_identifier(
2005 self.teacher_task,
2006 fast_bdt_option=self.fast_bdt_option))[0],
2007 "--train_datafiles",
2008 records_files[0],
2009 "--datafiles",
2010 datafiles,
2011 "--treename",
2012 self.teacher_task.tree_name,
2013 "--outputfile",
2014 evaluation_pdf_output_path,
2015 ]
2016
2017 # Prepare log files
2018 log_file_dir = get_log_file_dir(self)
2019 # Create the log directory if it does not exist yet. This seems to be necessary because this task
2020 # does not inherit from a b2luigi task class that would create it automatically.
2021 try:
2022 os.makedirs(log_file_dir, exist_ok=True)
2023 # With exist_ok=True, FileExistsError is only raised if the path exists but is not a directory.
2024 # Other errors such as PermissionError are not handled here.
2025 except FileExistsError:
2026 print('Directory ' + log_file_dir + ' already exists.')
2027 stderr_log_file_path = log_file_dir + "stderr"
2028 stdout_log_file_path = log_file_dir + "stdout"
2029 with open(stdout_log_file_path, "w") as stdout_file:
2030 stdout_file.write(f'stdout output of the command:\n{" ".join(cmd)}\n\n')
2031 if os.path.exists(stderr_log_file_path):
2032 # remove stderr file if it already exists because in the following it will be opened in appending mode
2033 os.remove(stderr_log_file_path)
2034
2035 # Run evaluation via subprocess and write output into logfiles
2036 with open(stdout_log_file_path, "a") as stdout_file:
2037 with open(stderr_log_file_path, "a") as stderr_file:
2038 try:
2039 subprocess.run(cmd, check=True, stdout=stdout_file, stderr=stderr_file)
2040 except subprocess.CalledProcessError as err:
2041 stderr_file.write(f"Evaluation failed with error:\n{err}")
2042 raise err
2043
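# --- Illustrative sketch (not part of the original steering file) ---
# ``run`` above writes the command line to a stdout log and then sends the
# output of the subprocess to separate stdout/stderr log files. The function
# below reproduces that pattern with the standard library only. The name
# ``run_and_log`` is hypothetical, and ``os.path.join`` is used instead of
# string concatenation so that the result does not depend on a trailing path
# separator in the log directory.
import os
import subprocess
import sys


def run_and_log(cmd, log_dir):
    """Run ``cmd`` and write its stdout and stderr into files in ``log_dir``."""
    os.makedirs(log_dir, exist_ok=True)
    stdout_path = os.path.join(log_dir, "stdout")
    stderr_path = os.path.join(log_dir, "stderr")
    # Opening in "w" mode truncates leftovers from a previous run
    with open(stdout_path, "w") as stdout_file, open(stderr_path, "w") as stderr_file:
        stdout_file.write(f'stdout output of the command:\n{" ".join(cmd)}\n\n')
        stdout_file.flush()  # keep the header in front of the subprocess output
        try:
            subprocess.run(cmd, check=True, stdout=stdout_file, stderr=stderr_file)
        except subprocess.CalledProcessError as err:
            stderr_file.write(f"Command failed with error:\n{err}")
            raise
    return stdout_path, stderr_path


# Usage (any command list works, the Python one-liner is only an example):
# run_and_log([sys.executable, "-c", "print('done')"], "logs")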
2044
2045class VXDTrackQEEvaluationTask(TrackQEEvaluationBaseTask):
2046 """
2047 Run ``basf2_mva_evaluate.py`` for the VXD quality estimator on separate test data
2048 """
2049
2051 teacher_task = VXDQETeacherTask
2052
2054 data_collection_task = VXDQEDataCollectionTask
2055
2057 task_acronym = 'vxd'
2058
2059
2060class CDCTrackQEEvaluationTask(TrackQEEvaluationBaseTask):
2061 """
2062 Run ``basf2_mva_evaluate.py`` for the CDC quality estimator on separate test data
2063 """
2064
2066 teacher_task = CDCQETeacherTask
2067
2069 data_collection_task = CDCQEDataCollectionTask
2070
2072 task_acronym = 'cdc'
2073
2074
2075class RecoTrackQEEvaluationTask(TrackQEEvaluationBaseTask):
2076 """
2077 Run ``basf2_mva_evaluate.py`` for the final, combined quality estimator on
2078 separate test data
2079 """
2080
2082 teacher_task = RecoTrackQETeacherTask
2083
2085 data_collection_task = RecoTrackQEDataCollectionTask
2086
2088 task_acronym = 'rec'
2089
2090 cdc_training_target = b2luigi.Parameter()
2091
2092 def requires(self):
2093 """
2094 Generate list of luigi Tasks that this Task depends on.
2095 """
2096 yield self.teacher_task(
2097 n_events_training=self.n_events_training,
2098 experiment_number=self.experiment_number,
2099 process_type=self.process_type,
2100 training_target=self.training_target,
2101 exclude_variables=self.exclude_variables,
2102 cdc_training_target=self.cdc_training_target,
2103 fast_bdt_option=self.fast_bdt_option,
2104 )
2105 if 'USEREC' in self.process_type:
2106 if 'USERECBB' in self.process_type:
2107 process = 'BBBAR'
2108 elif 'USERECEE' in self.process_type:
2109 process = 'BHABHA'
2110 yield CheckExistingFile(
2111 filename='datafiles/qe_records_N' + str(self.n_events_testing) + '_' + process + '_test_' +
2112 self.task_acronym + '.root'
2113 )
2114 else:
2115 yield self.data_collection_task(
2116 num_processes=MasterTask.num_processes,
2117 n_events=self.n_events_testing,
2118 experiment_number=self.experiment_number,
2119 random_seed=self.process_type + "_test",
2120 cdc_training_target=self.cdc_training_target,
2121 )
2122
2123
2124class PlotsFromHarvestingValidationBaseTask(Basf2Task):
2125 """
2126 Create a PDF file with validation plots for a quality estimator from the
2127 ROOT ntuples produced by a harvesting validation task.
2128 """
2129
2130 n_events_testing = b2luigi.IntParameter()
2131
2132 n_events_training = b2luigi.IntParameter()
2133
2134 experiment_number = b2luigi.IntParameter()
2135
2138 process_type = b2luigi.Parameter(
2139
2140 default="BBBAR"
2141
2142 )
2143
2145 exclude_variables = b2luigi.ListParameter(
2146
2147 hashed=True
2148
2149 )
2150
2151 fast_bdt_option = b2luigi.ListParameter(
2152
2153 hashed=True, default=[200, 8, 3, 0.1]
2154
2155 )
2156
2157 primaries_only = b2luigi.BoolParameter(
2158
2159 default=True
2160
2161 ) # normalize finding efficiencies to primary MC-tracks
2162
2163 @property
2164 def harvesting_validation_task_instance(self) -> HarvestingValidationBaseTask:
2165 """
2166 Specifies related harvesting validation task which produces the ROOT
2167 files with the data that is plotted by this task.
2168 """
2169 raise NotImplementedError("Must define a QI harvesting validation task for which to do the plots")
2170
2171 @property
2172 def output_pdf_file_basename(self):
2173 """
2174 Name of the output PDF file containing the validation plots
2175 """
2176 validation_harvest_basename = self.harvesting_validation_task_instance.validation_output_file_name
2177 return validation_harvest_basename.replace(".root", "_plots.pdf")
2178
2179 def requires(self):
2180 """
2181 Generate list of luigi Tasks that this Task depends on.
2182 """
2183 yield self.harvesting_validation_task_instance
2184
2185 def output(self):
2186 """
2187 Generate list of output files that the task should produce.
2188 The task is considered finished if and only if the outputs all exist.
2189 """
2190 yield self.add_to_output(self.output_pdf_file_basename)
2191
2192 @b2luigi.on_temporary_files
2193 def process(self):
2194 """
2195 Create the validation plots from the harvested validation data and save
2196 them into a single PDF file.
2197
2198 Main process that is dispatched by the ``run`` method that is inherited
2199 from ``Basf2Task``.
2200 """
2201 # get the validation "harvest", which is the ROOT file with ntuples for validation
2202 validation_harvest_basename = self.harvesting_validation_task_instance.validation_output_file_name
2203 validation_harvest_path = self.get_input_file_names(validation_harvest_basename)[0]
2204
2205 # Load "harvested" validation data from root files into dataframes (requires enough memory to hold data)
2206 pr_columns = [ # Restrict memory usage by only reading in columns that are used in the steering file
2207 'is_fake', 'is_clone', 'is_matched', 'quality_indicator',
2208 'experiment_number', 'run_number', 'event_number', 'pr_store_array_number',
2209 'pt_estimate', 'z0_estimate', 'd0_estimate', 'tan_lambda_estimate',
2210 'phi0_estimate', 'pt_truth', 'z0_truth', 'd0_truth', 'tan_lambda_truth',
2211 'phi0_truth',
2212 ]
2213 # In ``pr_df`` each row corresponds to a track from Pattern Recognition
2214 pr_df = uproot.open(validation_harvest_path)['pr_tree/pr_tree'].arrays(pr_columns, library='pd')
2215 mc_columns = [ # restrict mc_df to these columns
2216 'experiment_number',
2217 'run_number',
2218 'event_number',
2219 'pr_store_array_number',
2220 'is_missing',
2221 'is_primary',
2222 ]
2223 # In ``mc_df`` each row corresponds to an MC track
2224 mc_df = uproot.open(validation_harvest_path)['mc_tree/mc_tree'].arrays(mc_columns, library='pd')
2225 if self.primaries_only:
2226 mc_df = mc_df[mc_df.is_primary.eq(True)]
2227
2228 # Define QI thresholds for the FOM plots and the ROC curves
2229 qi_cuts = np.linspace(0., 1, 20, endpoint=False)
2230 # # Add more points at the very end between the previous maximum and 1
2231 # qi_cuts = np.append(qi_cuts, np.linspace(np.max(qi_cuts), 1, 20, endpoint=False))
2232
2233 # Create plots and append them to single output pdf
2234
2235 output_pdf_file_path = self.get_output_file_name(self.output_pdf_file_basename)
2236 with PdfPages(output_pdf_file_path, keep_empty=False) as pdf:
2237
2238 # Add a title page to validation plot PDF with some metadata
2239 # Remember that most metadata is in the xml file of the weightfile
2240 # and in the b2luigi directory structure
2241 titlepage_fig, titlepage_ax = plt.subplots()
2242 titlepage_ax.axis("off")
2243 title = f"Quality Estimator validation plots from {self.__class__.__name__}"
2244 titlepage_ax.set_title(title)
2245 teacher_task = self.harvesting_validation_task_instance.teacher_task
2246 weightfile_identifier = teacher_task.get_weightfile_xml_identifier(teacher_task, fast_bdt_option=self.fast_bdt_option)
2247 meta_data = {
2248 "Date": datetime.today().strftime("%Y-%m-%d %H:%M"),
2249 "Created by steering file": os.path.realpath(__file__),
2250 "Created from data in": validation_harvest_path,
2251 "Background directory": MasterTask.bkgfiles_by_exp[self.experiment_number],
2252 "Weight file": weightfile_identifier,
2253 }
2254 if hasattr(self, 'exclude_variables'):
2255 meta_data["Excluded variables"] = ", ".join(self.exclude_variables)
2256 meta_data_string = (format_dictionary(meta_data) +
2257 "\n\n(For all MVA training parameters look into the produced weight file)")
2258 luigi_params = get_serialized_parameters(self)
2259 luigi_param_string = (f"\n\nb2luigi parameters for {self.__class__.__name__}\n" +
2260 format_dictionary(luigi_params))
2261 title_page_text = meta_data_string + luigi_param_string
2262 titlepage_ax.text(0, 1, title_page_text, ha="left", va="top", wrap=True, fontsize=8)
2263 pdf.savefig(titlepage_fig)
2264 plt.close(titlepage_fig)
2265
2266 fake_rates = get_uncertain_means_for_qi_cuts(pr_df, "is_fake", qi_cuts)
2267 fake_fig, fake_ax = plt.subplots()
2268 fake_ax.set_title("Fake rate")
2269 plot_with_errobands(fake_rates, ax=fake_ax)
2270 fake_ax.set_ylabel("fake rate")
2271 fake_ax.set_xlabel("quality indicator requirement")
2272 pdf.savefig(fake_fig, bbox_inches="tight")
2273 plt.close(fake_fig)
2274
2275 # Plot clone rates
2276 clone_rates = get_uncertain_means_for_qi_cuts(pr_df, "is_clone", qi_cuts)
2277 clone_fig, clone_ax = plt.subplots()
2278 clone_ax.set_title("Clone rate")
2279 plot_with_errobands(clone_rates, ax=clone_ax)
2280 clone_ax.set_ylabel("clone rate")
2281 clone_ax.set_xlabel("quality indicator requirement")
2282 pdf.savefig(clone_fig, bbox_inches="tight")
2283 plt.close(clone_fig)
2284
2285 # Plot finding efficiency
2286
2287 # The Quality Indicator is only available in pr_tree and thus the
2288 # PR-track dataframe. To get the QI of the related PR track for an MC
2289 # track, merge the PR dataframe into the MC dataframe
2290 pr_track_identifiers = ['experiment_number', 'run_number', 'event_number', 'pr_store_array_number']
2291 mc_df = upd.merge(
2292 left=mc_df, right=pr_df[pr_track_identifiers + ['quality_indicator']],
2293 how='left',
2294 on=pr_track_identifiers
2295 )
2296
2297 missing_fractions = (
2298 _my_uncertain_mean(mc_df[
2299 mc_df.quality_indicator.isnull() | (mc_df.quality_indicator > qi_cut)]['is_missing'])
2300 for qi_cut in qi_cuts
2301 )
2302
2303 findeff_fig, findeff_ax = plt.subplots()
2304 findeff_ax.set_title("Finding efficiency")
2305 finding_efficiencies = 1.0 - upd.Series(data=missing_fractions, index=qi_cuts)
2306 plot_with_errobands(finding_efficiencies, ax=findeff_ax)
2307 findeff_ax.set_ylabel("finding efficiency")
2308 findeff_ax.set_xlabel("quality indicator requirement")
2309 pdf.savefig(findeff_fig, bbox_inches="tight")
2310 plt.close(findeff_fig)
2311
2312 # Plot ROC curves
2313
2314 # Fake rate vs. finding efficiency ROC curve
2315 fake_roc_fig, fake_roc_ax = plt.subplots()
2316 fake_roc_ax.set_title("Fake rate vs. finding efficiency ROC curve")
2317 fake_roc_ax.errorbar(x=finding_efficiencies.nominal_value, y=fake_rates.nominal_value,
2318 xerr=finding_efficiencies.std_dev, yerr=fake_rates.std_dev, elinewidth=0.8)
2319 fake_roc_ax.set_xlabel('finding efficiency')
2320 fake_roc_ax.set_ylabel('fake rate')
2321 pdf.savefig(fake_roc_fig, bbox_inches="tight")
2322 plt.close(fake_roc_fig)
2323
2324 # Clone rate vs. finding efficiency ROC curve
2325 clone_roc_fig, clone_roc_ax = plt.subplots()
2326 clone_roc_ax.set_title("Clone rate vs. finding efficiency ROC curve")
2327 clone_roc_ax.errorbar(x=finding_efficiencies.nominal_value, y=clone_rates.nominal_value,
2328 xerr=finding_efficiencies.std_dev, yerr=clone_rates.std_dev, elinewidth=0.8)
2329 clone_roc_ax.set_xlabel('finding efficiency')
2330 clone_roc_ax.set_ylabel('clone rate')
2331 pdf.savefig(clone_roc_fig, bbox_inches="tight")
2332 plt.close(clone_roc_fig)
2333
2334 # Plot kinematic distributions
2335
2336 # use fewer qi cuts as each cut will be its own subplot now and not a point
2337 kinematic_qi_cuts = [0, 0.5, 0.9]
2338
2339 # Define kinematic parameters which we want to histogram and define
2340 # dictionaries relating them to latex labels, units and binnings
2341 params = ['d0', 'z0', 'pt', 'tan_lambda', 'phi0']
2342 label_by_param = {
2343 "pt": "$p_T$",
2344 "z0": "$z_0$",
2345 "d0": "$d_0$",
2346 "tan_lambda": r"$\tan{\lambda}$",
2347 "phi0": r"$\phi_0$"
2348 }
2349 unit_by_param = {
2350 "pt": "GeV",
2351 "z0": "cm",
2352 "d0": "cm",
2353 "tan_lambda": "rad",
2354 "phi0": "rad"
2355 }
2356 n_kinematic_bins = 75 # number of bins per kinematic variable
2357 bins_by_param = {
2358 "pt": np.linspace(0, np.percentile(pr_df['pt_truth'].dropna(), 95), n_kinematic_bins),
2359 "z0": np.linspace(-0.1, 0.1, n_kinematic_bins),
2360 "d0": np.linspace(0, 0.01, n_kinematic_bins),
2361 "tan_lambda": np.linspace(-2, 3, n_kinematic_bins),
2362 "phi0": np.linspace(0, 2 * np.pi, n_kinematic_bins)
2363 }
2364
2365 # Iterate over each parameter and for each make stacked histograms for different QI cuts
2366 kinematic_qi_cuts = [0, 0.5, 0.8]
2367 blue, yellow, green = plt.get_cmap("tab10").colors[0:3]
2368 for param in params:
2369 fig, axarr = plt.subplots(ncols=len(kinematic_qi_cuts), sharey=True, sharex=True, figsize=(14, 6))
2370 fig.suptitle(f"{label_by_param[param]} distributions")
2371 for i, qi in enumerate(kinematic_qi_cuts):
2372 ax = axarr[i]
2373 ax.set_title(f"QI > {qi}")
2374 incut = pr_df[(pr_df['quality_indicator'] > qi)]
2375 incut_matched = incut[incut.is_matched.eq(True)]
2376 incut_clones = incut[incut.is_clone.eq(True)]
2377 incut_fake = incut[incut.is_fake.eq(True)]
2378
2379 # if any series is empty, skip this subplot and don't try to draw a stacked histogram
2380 if any(series.empty for series in (incut, incut_matched, incut_clones, incut_fake)):
2381 ax.text(0.5, 0.5, "Not enough data in bin", ha="center", va="center", transform=ax.transAxes)
2382 continue
2383
2384 bins = bins_by_param[param]
2385 stacked_histogram_series_tuple = (
2386 incut_matched[f'{param}_estimate'],
2387 incut_clones[f'{param}_estimate'],
2388 incut_fake[f'{param}_estimate'],
2389 )
2390 histvals, _, _ = ax.hist(stacked_histogram_series_tuple,
2391 stacked=True,
2392 bins=bins, range=(bins.min(), bins.max()),
2393 color=(blue, green, yellow),
2394 label=("matched", "clones", "fakes"))
2395 ax.set_xlabel(f'{label_by_param[param]} estimate / ({unit_by_param[param]})')
2396 ax.set_ylabel('# tracks')
2397 axarr[0].legend(loc="upper center", bbox_to_anchor=(0, -0.15))
2398 pdf.savefig(fig, bbox_inches="tight")
2399 plt.close(fig)
2400
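# --- Illustrative sketch (not part of the original steering file) ---
# The finding efficiency above is computed per QI cut from the MC tracks that
# either have no related PR track (quality indicator is null) or whose related
# PR track passes the cut. The toy function below mirrors that selection in
# plain Python without uncertainties. ``finding_efficiency`` and the
# (quality_indicator, is_missing) tuple layout are assumptions made for
# illustration; the real code works on dataframes.
def finding_efficiency(mc_tracks, qi_cut):
    """Return 1 - <is_missing> for MC tracks with no QI or with QI > qi_cut."""
    selected = [is_missing for qi, is_missing in mc_tracks if qi is None or qi > qi_cut]
    if not selected:
        return float("nan")  # no MC track survives the selection
    return 1.0 - sum(selected) / len(selected)


# Usage with four toy MC tracks, one of which was not found (QI is None):
# toy_tracks = [(0.9, False), (0.4, False), (None, True), (0.7, False)]
# efficiencies = [finding_efficiency(toy_tracks, cut) for cut in (0.0, 0.5, 0.8)]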
2401
2402class VXDQEValidationPlotsTask(PlotsFromHarvestingValidationBaseTask):
2403 """
2404 Create a PDF file with validation plots for the VXDTF2 track quality
2405 estimator from the ROOT ntuples produced by a VXDTF2 track QE
2406 harvesting validation task.
2407 """
2408
2409 @property
2410 def harvesting_validation_task_instance(self):
2411 """
2412 Harvesting validation task to require, which produces the ROOT files
2413 with variables to produce the VXD QE validation plots.
2414 """
2415 return VXDQEHarvestingValidationTask(
2416 n_events_testing=self.n_events_testing,
2417 n_events_training=self.n_events_training,
2418 process_type=self.process_type,
2419 experiment_number=self.experiment_number,
2420 exclude_variables=self.exclude_variables,
2421 num_processes=MasterTask.num_processes,
2422 fast_bdt_option=self.fast_bdt_option,
2423 )
2424
2425
2426class CDCQEValidationPlotsTask(PlotsFromHarvestingValidationBaseTask):
2427 """
2428 Create a PDF file with validation plots for the CDC track quality estimator
2429 from the ROOT ntuples produced by a CDC track QE harvesting
2430 validation task.
2431 """
2432
2433 training_target = b2luigi.Parameter()
2434
2435 @property
2436 def harvesting_validation_task_instance(self):
2437 """
2438 Harvesting validation task to require, which produces the ROOT files
2439 with variables to produce the CDC QE validation plots.
2440 """
2441 return CDCQEHarvestingValidationTask(
2442 n_events_testing=self.n_events_testing,
2443 n_events_training=self.n_events_training,
2444 process_type=self.process_type,
2445 experiment_number=self.experiment_number,
2446 training_target=self.training_target,
2447 exclude_variables=self.exclude_variables,
2448 num_processes=MasterTask.num_processes,
2449 fast_bdt_option=self.fast_bdt_option,
2450 )
2451
2452
2453class RecoTrackQEValidationPlotsTask(PlotsFromHarvestingValidationBaseTask):
2454 """
2455 Create a PDF file with validation plots for the reco MVA track quality
2456 estimator from the ROOT ntuples produced by a reco track QE
2457 harvesting validation task.
2458 """
2459
2460 cdc_training_target = b2luigi.Parameter()
2461
2462 @property
2463 def harvesting_validation_task_instance(self):
2464 """
2465 Harvesting validation task to require, which produces the ROOT files
2466 with variables to produce the final MVA track QE validation plots.
2467 """
2468 return RecoTrackQEHarvestingValidationTask(
2469 n_events_testing=self.n_events_testing,
2470 n_events_training=self.n_events_training,
2471 process_type=self.process_type,
2472 experiment_number=self.experiment_number,
2473 cdc_training_target=self.cdc_training_target,
2474 exclude_variables=self.exclude_variables,
2475 num_processes=MasterTask.num_processes,
2476 fast_bdt_option=self.fast_bdt_option,
2477 )
2478
2479
2480class QEWeightsLocalDBCreatorTask(Basf2Task):
2481 """
2482 Collect weightfile identifiers from different teacher tasks and merge them
2483 into a local database for testing.
2484 """
2485
2486 n_events_training = b2luigi.IntParameter()
2487
2488 experiment_number = b2luigi.IntParameter()
2489
2492 process_type = b2luigi.Parameter(
2493
2494 default="BBBAR"
2495
2496 )
2497
2498 cdc_training_target = b2luigi.Parameter()
2499
2500 fast_bdt_option = b2luigi.ListParameter(
2501
2502 hashed=True, default=[200, 8, 3, 0.1]
2503
2504 )
2505
2506 def requires(self):
2507 """
2508 Required teacher tasks
2509 """
2510 yield VXDQETeacherTask(
2511 n_events_training=self.n_events_training,
2512 process_type=self.process_type,
2513 experiment_number=self.experiment_number,
2514 exclude_variables=MasterTask.exclude_variables_vxd,
2515 fast_bdt_option=self.fast_bdt_option,
2516 )
2517 yield CDCQETeacherTask(
2518 n_events_training=self.n_events_training,
2519 process_type=self.process_type,
2520 experiment_number=self.experiment_number,
2521 training_target=self.cdc_training_target,
2522 exclude_variables=MasterTask.exclude_variables_cdc,
2523 fast_bdt_option=self.fast_bdt_option,
2524 )
2525 yield RecoTrackQETeacherTask(
2526 n_events_training=self.n_events_training,
2527 process_type=self.process_type,
2528 experiment_number=self.experiment_number,
2529 cdc_training_target=self.cdc_training_target,
2530 exclude_variables=MasterTask.exclude_variables_rec,
2531 fast_bdt_option=self.fast_bdt_option,
2532 )
2533
2534 def output(self):
2535 """
2536 Local database
2537 """
2538 yield self.add_to_output("localdb.tar")
2539
2540 def process(self):
2541 """
2542 Create local database
2543 """
2544 current_path = Path.cwd()
2545 localdb_archive_path = Path(self.get_output_file_name("localdb.tar")).absolute()
2546 output_dir = localdb_archive_path.parent
2547
2548 # remove existing local databases in output directories
2549 self._clean()
2550 # "Upload" the weightfiles of all 3 teacher tasks into the same localdb
2551 for task in (VXDQETeacherTask, CDCQETeacherTask, RecoTrackQETeacherTask):
2552 # Extract xml identifier input file name before switching working directories, as it returns relative paths
2553 weightfile_xml_identifier_path = os.path.abspath(self.get_input_file_names(
2554 task.get_weightfile_xml_identifier(task, fast_bdt_option=self.fast_bdt_option))[0])
2555 # As localdb is created in working directory, chdir into desired output path
2556 try:
2557 os.chdir(output_dir)
2558 # Same as basf2_mva_upload on the command line, creates localdb directory in current working dir
2559 basf2_mva.upload(
2560 weightfile_xml_identifier_path,
2561 task.weightfile_identifier_basename,
2562 self.experiment_number, 0,
2563 self.experiment_number, -1,
2564 )
2565 finally: # Switch back to working directory of b2luigi, even if upload failed
2566 os.chdir(current_path)
2567
2568 # Pack localdb into a tar archive, so that we have one single output file instead of a directory
2569 shutil.make_archive(
2570 base_name=localdb_archive_path.with_suffix('').as_posix(),
2571 format="tar",
2572 root_dir=output_dir,
2573 base_dir="localdb",
2574 verbose=True,
2575 )
2576
2577 def _clean(self):
2578 """
2579 Remove local database and tar archives in output directory
2580 """
2581 localdb_archive_path = Path(self.get_output_file_name("localdb.tar"))
2582 localdb_path = localdb_archive_path.parent / "localdb"
2583
2584 if localdb_path.exists():
2585 print(f"Deleting localdb\n{localdb_path}\nwith contents\n ",
2586 "\n ".join(f.name for f in localdb_path.iterdir()))
2587 shutil.rmtree(localdb_path, ignore_errors=False) # recursively delete localdb
2588
2589 if localdb_archive_path.is_file():
2590 print(f"Deleting {localdb_archive_path}")
2591 os.remove(localdb_archive_path)
2592
2593 def on_failure(self, exception):
2594 """
2595 Cleanup: Remove local database to prevent existing outputs when task did not finish successfully
2596 """
2597 self._clean()
2598 # Run existing on_failure from parent class
2599 super().on_failure(exception)
2600
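# --- Illustrative sketch (not part of the original steering file) ---
# ``process`` above packs the ``localdb`` directory into ``localdb.tar`` with
# ``shutil.make_archive``, which expects the archive name without its
# extension. The hypothetical helper ``pack_localdb`` below shows the same
# call on an arbitrary output directory. It derives the base name with
# ``Path.with_suffix('')`` so that dots elsewhere in the path are left alone.
import shutil
from pathlib import Path


def pack_localdb(output_dir):
    """Pack ``output_dir/localdb`` into ``output_dir/localdb.tar`` and return its path."""
    archive_path = Path(output_dir).absolute() / "localdb.tar"
    created = shutil.make_archive(
        base_name=archive_path.with_suffix("").as_posix(),
        format="tar",
        root_dir=output_dir,
        base_dir="localdb",
    )
    return Path(created)


# Usage (assumes that ``some_output_dir/localdb`` already exists):
# pack_localdb("some_output_dir")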
2601
2602class MasterTask(b2luigi.WrapperTask):
2603 """
2604 Wrapper task that needs to finish for b2luigi to finish running this steering file.
2605
2606 It is done if the outputs of all required subtasks exist. It is thus at the
2607 top of the luigi task graph. Edit the ``requires`` method to steer which
2608 tasks and with which parameters you want to run.
2609 """
2610
2613 process_type = b2luigi.get_setting(
2614
2615 "process_type", default='BBBAR'
2616
2617 )
2618
2619 n_events_training = b2luigi.get_setting(
2620
2621 "n_events_training", default=20000
2622
2623 )
2624
2625 n_events_testing = b2luigi.get_setting(
2626
2627 "n_events_testing", default=5000
2628
2629 )
2630
2631 n_events_per_task = b2luigi.get_setting(
2632
2633 "n_events_per_task", default=100
2634
2635 )
2636
2637 num_processes = b2luigi.get_setting(
2638
2639 "basf2_processes_per_worker", default=0
2640
2641 )
2642
2643 datafiles = b2luigi.get_setting("datafiles")
2644
2645 bkgfiles_by_exp = b2luigi.get_setting("bkgfiles_by_exp")
2646
2647 bkgfiles_by_exp = {int(key): val for (key, val) in bkgfiles_by_exp.items()}
2648
2649 exclude_variables_cdc = [
2650 "has_matching_segment",
2651 "size",
2652 "n_tracks", # not written out per default anyway
2653 "avg_hit_dist",
2654 "cont_layer_mean",
2655 "cont_layer_variance",
2656 "cont_layer_max",
2657 "cont_layer_min",
2658 "cont_layer_first",
2659 "cont_layer_last",
2660 "cont_layer_max_vs_last",
2661 "cont_layer_first_vs_min",
2662 "cont_layer_count",
2663 "cont_layer_occupancy",
2664 "super_layer_mean",
2665 "super_layer_variance",
2666 "super_layer_max_vs_last",
2667 "super_layer_first_vs_min",
2668 "super_layer_occupancy",
2669 "drift_length_mean",
2670 "drift_length_variance",
2671 "drift_length_max",
2672 "drift_length_min",
2673 "drift_length_sum",
2674 "norm_drift_length_mean",
2675 "norm_drift_length_variance",
2676 "norm_drift_length_max",
2677 "norm_drift_length_min",
2678 "norm_drift_length_sum",
2679 "adc_mean",
2680 "adc_variance",
2681 "adc_max",
2682 "adc_min",
2683 "adc_sum",
2684 "tot_mean",
2685 "tot_variance",
2686 "tot_max",
2687 "tot_min",
2688 "tot_sum",
2689 "empty_s_mean",
2690 "empty_s_variance",
2691 "empty_s_max"]
2692
    exclude_variables_vxd = [
        'energyLoss_max', 'energyLoss_min', 'energyLoss_mean', 'energyLoss_std', 'energyLoss_sum',
        'size_max', 'size_min', 'size_mean', 'size_std', 'size_sum',
        'seedCharge_max', 'seedCharge_min', 'seedCharge_mean', 'seedCharge_std', 'seedCharge_sum',
        'tripletFit_P_Mag', 'tripletFit_P_Eta', 'tripletFit_P_Phi', 'tripletFit_P_X', 'tripletFit_P_Y', 'tripletFit_P_Z']

    exclude_variables_rec = [
        'background',
        'ghost',
        'fake',
        'clone',
        '__experiment__',
        '__run__',
        '__event__',
        'N_RecoTracks',
        'N_PXDRecoTracks',
        'N_SVDRecoTracks',
        'N_CDCRecoTracks',
        'N_diff_PXD_SVD_RecoTracks',
        'N_diff_SVD_CDC_RecoTracks',
        'Fit_Successful',
        'Fit_NFailedPoints',
        'Fit_Chi2',
        'N_TrackPoints_without_KalmanFitterInfo',
        'N_Hits_without_TrackPoint',
        'SVD_CDC_CDCwall_Chi2',
        'SVD_CDC_CDCwall_Pos_diff_Z',
        'SVD_CDC_CDCwall_Pos_diff_Pt',
        'SVD_CDC_CDCwall_Pos_diff_Theta',
        'SVD_CDC_CDCwall_Pos_diff_Phi',
        'SVD_CDC_CDCwall_Pos_diff_Mag',
        'SVD_CDC_CDCwall_Pos_diff_Eta',
        'SVD_CDC_CDCwall_Mom_diff_Z',
        'SVD_CDC_CDCwall_Mom_diff_Pt',
        'SVD_CDC_CDCwall_Mom_diff_Theta',
        'SVD_CDC_CDCwall_Mom_diff_Phi',
        'SVD_CDC_CDCwall_Mom_diff_Mag',
        'SVD_CDC_CDCwall_Mom_diff_Eta',
        'SVD_CDC_POCA_Pos_diff_Z',
        'SVD_CDC_POCA_Pos_diff_Pt',
        'SVD_CDC_POCA_Pos_diff_Theta',
        'SVD_CDC_POCA_Pos_diff_Phi',
        'SVD_CDC_POCA_Pos_diff_Mag',
        'SVD_CDC_POCA_Pos_diff_Eta',
        'SVD_CDC_POCA_Mom_diff_Z',
        'SVD_CDC_POCA_Mom_diff_Pt',
        'SVD_CDC_POCA_Mom_diff_Theta',
        'SVD_CDC_POCA_Mom_diff_Phi',
        'SVD_CDC_POCA_Mom_diff_Mag',
        'SVD_CDC_POCA_Mom_diff_Eta',
        'POCA_Pos_Pt',
        'POCA_Pos_Mag',
        'POCA_Pos_Phi',
        'POCA_Pos_Z',
        'POCA_Pos_Theta',
        'PXD_QI',
        'SVD_FitSuccessful',
        'CDC_FitSuccessful',
        'pdg_ID',
        'pdg_ID_Mother',
        'is_Vzero_Daughter',
        'is_Primary',
        'z0',
        'd0',
        'seed_Charge',
        'Fit_Charge',
        'weight_max',
        'weight_min',
        'weight_mean',
        'weight_std',
        'weight_median',
        'weight_n_zeros',
        'weight_firstCDCHit',
        'weight_lastSVDHit',
        'smoothedChi2_max',
        'smoothedChi2_min',
        'smoothedChi2_mean',
        'smoothedChi2_std',
        'smoothedChi2_median',
        'smoothedChi2_n_zeros',
        'smoothedChi2_firstCDCHit',
        'smoothedChi2_lastSVDHit',
        'SVD_QI'] + \
        ["SVD_" + x for x in exclude_variables_vxd] + \
        ["SVDbefore_" + x for x in exclude_variables_vxd]

    def requires(self):
        """
        Generate the list of tasks that need to be done for luigi to finish running
        this steering file.
        """
        cdc_training_targets = [
            "truth",  # treats clones as background, only best matched CDC tracks are true
            # "truth_track_is_matched"  # treats clones as signal
        ]

        fast_bdt_options = []
        # possible to run over a chosen hyperparameter space if wanted
        # in principle this can be extended to specific options for the three different MVAs
        # for i in range(250, 400, 50):
        #     for j in range(6, 10, 2):
        #         for k in range(2, 6):
        #             for l in range(0, 5):
        #                 fast_bdt_options.append([100 + i, j, 3+k, 0.025+l*0.025])
        # fast_bdt_options.append([200, 8, 3, 0.1])  # default FastBDT option
        fast_bdt_options.append([350, 6, 5, 0.1])

        experiment_numbers = b2luigi.get_setting("experiment_numbers")

        # iterate over all possible combinations of parameters from the above defined parameter lists
        for experiment_number, cdc_training_target, fast_bdt_option in itertools.product(
            experiment_numbers, cdc_training_targets, fast_bdt_options
        ):
            # if test_selected_task is activated, only run the following tasks:
            if b2luigi.get_setting("test_selected_task", default=False):
                # for process_type in ['BHABHA', 'MUMU', 'TAUPAIR', 'YY', 'EEEE', 'EEMUMU', 'UUBAR', \
                #                      'DDBAR', 'CCBAR', 'SSBAR', 'BBBAR', 'V0BBBAR', 'V0STUDY']:
                for cut in ['000', '070', '090', '095']:
                    yield RecoTrackQEDataCollectionTask(
                        num_processes=self.num_processes,
                        n_events=self.n_events_testing,
                        experiment_number=experiment_number,
                        random_seed=self.process_type + '_test',
                        recotrack_option='useCDC_useVXD_deleteCDCQI'+cut,
                        cdc_training_target=cdc_training_target,
                        fast_bdt_option=fast_bdt_option,
                    )
                yield CDCQEDataCollectionTask(
                    num_processes=self.num_processes,
                    n_events=self.n_events_testing,
                    experiment_number=experiment_number,
                    random_seed=self.process_type + '_test',
                )
                yield CDCQETeacherTask(
                    n_events_training=self.n_events_training,
                    process_type=self.process_type,
                    experiment_number=experiment_number,
                    exclude_variables=self.exclude_variables_cdc,
                    training_target=cdc_training_target,
                    fast_bdt_option=fast_bdt_option,
                )
            else:
                # if data shall be processed, it can neither be trained nor evaluated
                if 'DATA' in self.process_type:
                    yield VXDQEDataCollectionTask(
                        num_processes=self.num_processes,
                        n_events=self.n_events_testing,
                        experiment_number=experiment_number,
                        random_seed=self.process_type + '_test',
                    )
                    yield CDCQEDataCollectionTask(
                        num_processes=self.num_processes,
                        n_events=self.n_events_testing,
                        experiment_number=experiment_number,
                        random_seed=self.process_type + '_test',
                    )
                    yield RecoTrackQEDataCollectionTask(
                        num_processes=self.num_processes,
                        n_events=self.n_events_testing,
                        experiment_number=experiment_number,
                        random_seed=self.process_type + '_test',
                        recotrack_option='deleteCDCQI080',
                        cdc_training_target=cdc_training_target,
                        fast_bdt_option=fast_bdt_option,
                    )
                else:
                    yield QEWeightsLocalDBCreatorTask(
                        n_events_training=self.n_events_training,
                        process_type=self.process_type,
                        experiment_number=experiment_number,
                        cdc_training_target=cdc_training_target,
                        fast_bdt_option=fast_bdt_option,
                    )

                    if b2luigi.get_setting("run_validation_tasks", default=True):
                        yield RecoTrackQEValidationPlotsTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            experiment_number=experiment_number,
                            cdc_training_target=cdc_training_target,
                            exclude_variables=self.exclude_variables_rec,
                            fast_bdt_option=fast_bdt_option,
                        )
                        yield CDCQEValidationPlotsTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            experiment_number=experiment_number,
                            exclude_variables=self.exclude_variables_cdc,
                            training_target=cdc_training_target,
                            fast_bdt_option=fast_bdt_option,
                        )
                        yield VXDQEValidationPlotsTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            exclude_variables=self.exclude_variables_vxd,
                            experiment_number=experiment_number,
                            fast_bdt_option=fast_bdt_option,
                        )

                    if b2luigi.get_setting("run_mva_evaluate", default=True):
                        # Evaluate trained weightfiles via basf2_mva_evaluate.py on separate testdatasets
                        # requires a latex installation to work
                        yield RecoTrackQEEvaluationTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            experiment_number=experiment_number,
                            cdc_training_target=cdc_training_target,
                            exclude_variables=self.exclude_variables_rec,
                            fast_bdt_option=fast_bdt_option,
                        )
                        yield CDCTrackQEEvaluationTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            experiment_number=experiment_number,
                            exclude_variables=self.exclude_variables_cdc,
                            fast_bdt_option=fast_bdt_option,
                            training_target=cdc_training_target,
                        )
                        yield VXDTrackQEEvaluationTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            experiment_number=experiment_number,
                            exclude_variables=self.exclude_variables_vxd,
                            fast_bdt_option=fast_bdt_option,
                        )


if __name__ == "__main__":
    # if n_events_test_on_data is set to a positive value in the settings (default: -1),
    # then stop after N events (mainly useful to test data reconstruction):
    nEventsTestOnData = b2luigi.get_setting("n_events_test_on_data", default=-1)
    if nEventsTestOnData > 0 and 'DATA' in b2luigi.get_setting("process_type", default="BBBAR"):
        from ROOT import Belle2
        environment = Belle2.Environment.Instance()
        environment.setNumberEventsOverride(nEventsTestOnData)
    # if global tags are specified in the settings, use them:
    # e.g. for data use ["data_reprocessing_prompt", "online"]. Make sure these are up to date.
    globaltags = b2luigi.get_setting("globaltags", default=[])
    if len(globaltags) > 0:
        basf2.conditions.reset()
        for gt in globaltags:
            basf2.conditions.prepend_globaltag(gt)
    workers = b2luigi.get_setting("workers", default=1)
    b2luigi.process(MasterTask(), workers=workers)

# @endcond
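
# Illustrative sketch only, not part of the original steering file: a settings
# file collecting the keys that this script reads via b2luigi.get_setting().
# It assumes the settings are provided as a settings.json file that b2luigi can
# find from the working directory. All values and paths are hypothetical
# placeholders; in particular the value types shown for "datafiles" and
# "bkgfiles_by_exp" are assumptions. Keys that have a default in the code may
# be omitted; "datafiles", "bkgfiles_by_exp" and "experiment_numbers" are read
# without a default.
#
# {
#     "process_type": "BBBAR",
#     "n_events_training": 20000,
#     "n_events_testing": 5000,
#     "n_events_per_task": 100,
#     "basf2_processes_per_worker": 0,
#     "experiment_numbers": [0],
#     "bkgfiles_by_exp": {"0": "/path/to/background/files"},
#     "datafiles": ["/path/to/data/file.root"],
#     "test_selected_task": false,
#     "run_validation_tasks": true,
#     "run_mva_evaluate": true,
#     "n_events_test_on_data": -1,
#     "globaltags": [],
#     "workers": 1
# }
#
# The experiment numbers in "bkgfiles_by_exp" are written as strings because
# JSON object keys must be strings; MasterTask converts them to int.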