combined_module_quality_estimator_teacher
-----------------------------------------

Information on the MVA Track Quality Indicator / Estimator can be found at
`<https://xwiki.desy.de/xwiki/rest/p/0d3f4>`_.

This Python script is used for the combined training and validation of three
classifiers: the actual final MVA track quality estimator and the two quality
estimators for the intermediate standalone track finders that it depends on.

    - Final MVA track quality estimator:
      The final quality estimator for fully merged and fitted tracks (RecoTracks).
      Its classifier uses features from the track fitting, merger, hit pattern, ...
      It also uses the outputs of the respective intermediate quality
      estimators for the VXD and the CDC track finding as inputs. It provides
      the final quality indicator (QI) exported to the track objects.

    - VXDTF2 track quality estimator:
      MVA quality estimator for the VXD standalone track finding.

    - CDC track quality estimator:
      MVA quality estimator for the CDC standalone track finding.

Each classifier requires its own training data set and must be validated on a
separate testing data set. Further, the final quality estimator can only be
trained when the trained weights for the intermediate quality estimators are
available. If the final estimator shall be trained without one or both previous
estimators, the requirements have to be commented out in the ``__init__.py``
file of tracking.

For all estimators, a list of variables to be ignored is specified in the
``MasterTask``. The current choice is mainly based on poor data-MC agreement in
these quantities or on outdated implementations. It was decided to leave them in
the hardcoded "ugly" way in here to remind future generations that they exist in
principle and that they should and could be added to the estimator once their
modelling improves or an alternative implementation is programmed.

To avoid mistakes, b2luigi is used to create a task chain for a combined
training and validation of all classifiers.

b2luigi: Understanding the steering file
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

All trainings and validations are done in the correct order in this steering
file. For the purpose of creating a dependency graph, the `b2luigi
<https://b2luigi.readthedocs.io>`_ python package is used, which extends the
`luigi <https://luigi.readthedocs.io>`_ package developed by Spotify.

Each task that has to be done is represented by a special class, which defines
its parameters, its output files and which other tasks with which parameters it
depends on. For example, a teacher task, which runs ``basf2_mva_teacher.py`` to
train the classifier, depends on a data collection task, which runs a
reconstruction and writes out track-wise variables into a ROOT file for
training. An evaluation/validation task for testing the classifier requires both
the teacher task, as it needs the weightfile to be present, and a data
collection task, because it needs a dataset for testing the classifier.

The final task that defines which tasks need to be done for the steering file to
finish is the ``MasterTask``. When you only want to run parts of the
training/validation pipeline, you can comment out requirements in the
``MasterTask`` or replace them by lower-level tasks during debugging.

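The dependency idea described above can be illustrated without b2luigi at all.
The following is only a minimal sketch in plain Python: ``SketchTask`` and
``resolve_order`` are hypothetical names invented for this illustration and are
not part of the luigi or b2luigi API. It shows how a ``requires`` relation
determines the order in which data collection, teacher, evaluation and master
tasks would have to run.

```python
class SketchTask:
    """Hypothetical stand-in for a luigi task: a name plus the tasks it requires."""

    def __init__(self, name, requires=()):
        self.name = name
        self._requires = tuple(requires)

    def requires(self):
        return self._requires


def resolve_order(task, done=None):
    """Return task names in an order that satisfies all dependencies (depth-first)."""
    if done is None:
        done = []
    for dependency in task.requires():
        resolve_order(dependency, done)
    if task.name not in done:
        done.append(task.name)
    return done


# Toy task graph mirroring the example in the text (names are illustrative only).
collect_train = SketchTask("data_collection_train")
collect_test = SketchTask("data_collection_test")
teacher = SketchTask("teacher", requires=[collect_train])
evaluation = SketchTask("evaluation", requires=[teacher, collect_test])
master = SketchTask("master", requires=[evaluation])

order = resolve_order(master)
print(order)
```

In the real steering file, the scheduler performs this resolution and
additionally skips tasks whose output files already exist.
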
This steering file relies on b2luigi_ for task scheduling and on `uncertain_panda
<https://github.com/nils-braun/uncertain_panda>`_ for uncertainty calculations.
``uncertain_panda`` is not in the externals, and b2luigi is not included in the
externals up to v01-07-01. Both can be installed via pip::

    python3 -m pip install [--user] b2luigi uncertain_panda

Use the ``--user`` option if you do not have the rights to install python
packages into your externals (e.g. because you are using cvmfs) and install them
in ``$HOME/.local`` instead.

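Before starting a long task chain, it can be useful to check whether both
packages are importable in the current environment. The snippet below is a
small sketch using only the standard library; ``find_missing_modules`` is a
hypothetical helper name and not part of this steering file.

```python
import importlib.util


def find_missing_modules(module_names):
    """Return the names from ``module_names`` that cannot be imported here."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


missing = find_missing_modules(["b2luigi", "uncertain_panda"])
if missing:
    print("Missing python modules: " + ", ".join(missing))
else:
    print("All required python modules are available.")
```
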
Instead of command line arguments, the b2luigi script is configured via a
``settings.json`` file. Open it in your favorite text editor and modify it to
fit your requirements.

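As a rough illustration of what such a configuration file looks like to a
program, the sketch below writes and reads a tiny JSON file with the standard
library. Only the ``result_path`` key is mentioned in this documentation; its
value here is made up, and the way b2luigi itself looks up settings is not
shown.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical settings; only the key "result_path" is named in this documentation.
example_settings = {"result_path": "results"}

with tempfile.TemporaryDirectory() as tmp_dir:
    settings_file = Path(tmp_dir) / "settings.json"
    settings_file.write_text(json.dumps(example_settings, indent=4))

    # Read the file back the way any JSON consumer would.
    loaded_settings = json.loads(settings_file.read_text())

print(loaded_settings["result_path"])
```
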
You can test the b2luigi task graph without running it via::

    python3 combined_quality_estimator_teacher.py --dry-run
    python3 combined_quality_estimator_teacher.py --show-output

This will show the outputs and show potential errors in the definitions of the
luigi task dependencies. To run the steering file in normal (local) mode,
run::

    python3 combined_quality_estimator_teacher.py

I usually use the interactive luigi web interface via the central scheduler,
which visualizes the task graph while it is running. For this, the scheduler
daemon ``luigid`` has to run in the background. It is located in
``~/.local/bin/luigid`` in case b2luigi has been installed with ``--user``. For

Then, execute your steering file (e.g. in another terminal) with::

    python3 combined_quality_estimator_teacher.py --scheduler-port 8886

To view the web interface, open your web browser and enter into the URL bar::

If you don't run the steering file on the same machine on which you run your web
browser, you have two options:

    1. Run both the steering file and ``luigid`` remotely and use
       ssh-port-forwarding to your local host. To do so, run on your local
       host::

           ssh -N -f -L 8886:localhost:8886 <remote_user>@<remote_host>

    2. Run the ``luigid`` scheduler locally and use the ``--scheduler-host <your
       local host>`` argument when calling the steering file.

Accessing the results / output files
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

All output files are stored in a directory structure in the ``result_path``. The
directory tree encodes the used b2luigi parameters. This ensures reproducibility
and makes parameter searches easy. Sometimes, it is hard to find the relevant
output files. You can view the whole directory structure by running ``tree
<result_path>``. Use the unix ``find`` command to find the files that interest
you, e.g.::

    find <result_path> -name "*.pdf" # find all validation plot files
    find <result_path> -name "*.root" # find all ROOT files
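
The same search can be done from Python, which is convenient inside analysis
notebooks. This is only a sketch: ``find_by_suffix`` is a hypothetical helper,
and the directory and file names of the throw-away tree are invented for the
demonstration rather than taken from a real ``result_path``.

```python
import tempfile
from pathlib import Path


def find_by_suffix(result_path, suffix):
    """Recursively collect all files below ``result_path`` ending in ``suffix``."""
    return sorted(Path(result_path).rglob(f"*{suffix}"))


# Build a small throw-away directory tree; the directory and file names are made up.
with tempfile.TemporaryDirectory() as tmp_dir:
    task_dir = Path(tmp_dir) / "n_events=100" / "process_type=BBBAR"
    task_dir.mkdir(parents=True)
    (task_dir / "validation_plots.pdf").touch()
    (task_dir / "qe_records.root").touch()

    pdf_names = [path.name for path in find_by_suffix(tmp_dir, ".pdf")]
    root_names = [path.name for path in find_by_suffix(tmp_dir, ".root")]

print(pdf_names, root_names)
```
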
from pathlib import Path
from datetime import datetime
from typing import Iterable

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

from packaging import version

# template for the hint that is printed when an optional python module is missing
install_helpstring_formatter = (
    "\nCould not find {module} python module. Try installing it via\n"
    "  python3 -m pip install [--user] {module}\n")
try:
    import b2luigi
    from b2luigi.core.utils import get_serialized_parameters, get_log_file_dir, create_output_dirs
    from b2luigi.basf2_helper import Basf2PathTask, Basf2Task
    from b2luigi.core.task import Task, ExternalTask
    from b2luigi.basf2_helper.utils import get_basf2_git_hash
except ModuleNotFoundError:
    print(install_helpstring_formatter.format(module="b2luigi"))
    raise
try:
    from uncertain_panda import pandas as upd
except ModuleNotFoundError:
    print(install_helpstring_formatter.format(module="uncertain_panda"))
    raise

# warn about a known b2luigi bug that prevents obtaining the basf2 git hash
if (
    version.parse(b2luigi.__version__) <= version.parse("0.3.2") and
    get_basf2_git_hash() is None and
    os.getenv("BELLE2_LOCAL_DIR") is not None
):
    print(f"b2luigi version could not obtain git hash because of a bug not yet fixed in version {b2luigi.__version__}\n"
          "Please install the latest version of b2luigi from github via\n\n"
          "  python3 -m pip install --upgrade [--user] git+https://github.com/nils-braun/b2luigi.git\n")


def create_fbdt_option_string(fast_bdt_option):
    """
    Return a readable string created from the ``fast_bdt_option`` list
    ``[nTrees, nCuts, nLevels, shrinkage]``.
    """
    return "_nTrees" + str(fast_bdt_option[0]) + "_nCuts" + str(fast_bdt_option[1]) + "_nLevels" + \
        str(fast_bdt_option[2]) + "_shrin" + str(int(round(100*fast_bdt_option[3], 0)))


def createV0momenta(x, mu, beta):
    """
    Copied from Bianca's K_S0 particle gun code: Returns a realistic V0 momentum distribution
    when running over x. Mu and beta are properties of the function that define center and tails.
    Used for the particle gun simulation code for K_S0 and Lambda_0.
    """
    return (1/beta)*np.exp(-(x - mu)/beta) * np.exp(-np.exp(-(x - mu) / beta))


def my_basf2_mva_teacher(
    records_files,
    tree_name,
    weightfile_identifier,
    target_variable="truth",
    exclude_variables=None,
    fast_bdt_option=[200, 8, 3, 0.1]
):
    """
    My custom wrapper for basf2 mva teacher. Adapted from code in ``trackfindingcdc_teacher``.

    :param records_files: List of files with collected ("recorded") variables to use as training data for the MVA.
    :param tree_name: Name of the TTree in the ROOT file from the ``data_collection_task``
           that contains the training data for the MVA teacher.
    :param weightfile_identifier: Name of the weightfile that is created.
           Should either end in ".xml" for local weightfiles or in ".root", when
           the weightfile needs later to be uploaded as a payload to the conditions
           database.
    :param target_variable: Feature/variable to use as truth label in the quality estimator MVA classifier.
    :param exclude_variables: List of collected variables to not use in the training of the QE MVA classifier.
           In addition to variables containing the "truth" substring, which are excluded by default.
    :param fast_bdt_option: specified fast BDT options, default: [200, 8, 3, 0.1] [nTrees, nCuts, nLevels, shrinkage]
    """
    if exclude_variables is None:
        exclude_variables = []

    weightfile_extension = Path(weightfile_identifier).suffix
    if weightfile_extension not in {".xml", ".root"}:
        raise ValueError(f"Weightfile Identifier should end in .xml or .root, but ends in {weightfile_extension}")

    # read the names of all collected variables from the first records file
    with root_utils.root_open(records_files[0]) as records_tfile:
        input_tree = records_tfile.Get(tree_name)
        feature_names = [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]

    # keep only variables that are neither truth information, the target, nor explicitly excluded
    truth_free_variable_names = [
        name
        for name in feature_names
        if (
            ("truth" not in name) and
            (name != target_variable) and
            (name not in exclude_variables)
        )
    ]
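    # a weight branch is not a training feature, but is passed to the teacher separately
    if "weight" in truth_free_variable_names:
        truth_free_variable_names.remove("weight")
        weight_variable = "weight"
    elif "__weight__" in truth_free_variable_names:
        truth_free_variable_names.remove("__weight__")
        weight_variable = "__weight__"
    else:
        # no weight branch found; an empty name is assumed here to mean unweighted training
        weight_variable = ""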

    # set up the general options and the FastBDT-specific options, then run the teacher
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(*records_files)
    general_options.m_treename = tree_name
    general_options.m_weight_variable = weight_variable
    general_options.m_identifier = weightfile_identifier
    general_options.m_variables = basf2_mva.vector(*truth_free_variable_names)
    general_options.m_target_variable = target_variable
    fastbdt_options = basf2_mva.FastBDTOptions()
    fastbdt_options.m_nTrees = fast_bdt_option[0]
    fastbdt_options.m_nCuts = fast_bdt_option[1]
    fastbdt_options.m_nLevels = fast_bdt_option[2]
    fastbdt_options.m_shrinkage = fast_bdt_option[3]
    basf2_mva.teacher(general_options, fastbdt_options)


def _my_uncertain_mean(series: upd.Series):
    """
    Temporary workaround for a bug in ``uncertain_panda`` where a ``ValueError`` is
    thrown for ``Series.unc.mean`` if the series is empty. Can be replaced by
    ``.unc.mean`` when the issue is fixed.
    https://github.com/nils-braun/uncertain_panda/issues/2
    """
    return series.unc.mean()


def get_uncertain_means_for_qi_cuts(df: upd.DataFrame, column: str, qi_cuts: Iterable[float]):
    """
    Return a pandas series with a mean of the dataframe column and
    uncertainty for each quality indicator cut.

    :param df: Pandas dataframe with at least ``quality_indicator``
        and another numeric ``column``.
    :param column: Column of which we want to aggregate the means
        and uncertainties for different QI cuts
    :param qi_cuts: Iterable of quality indicator minimal thresholds.
    :returns: Series of means and uncertainties with ``qi_cuts`` as index
    """
    uncertain_means = (_my_uncertain_mean(df.query(f"quality_indicator > {qi_cut}")[column])
                       for qi_cut in qi_cuts)
    uncertain_means_series = upd.Series(data=uncertain_means, index=qi_cuts)
    return uncertain_means_series


def plot_with_errobands(uncertain_series,
                        error_band_alpha=0.3,
                        plot_kwargs={},
                        fill_between_kwargs={},
                        ax=None):
    """
    Plot an uncertain series with error bands for y-errors
    """
    if ax is None:
        ax = plt.gca()
    uncertain_series = uncertain_series.dropna()
    ax.plot(uncertain_series.index.values, uncertain_series.nominal_value, **plot_kwargs)
    ax.fill_between(x=uncertain_series.index,
                    y1=uncertain_series.nominal_value - uncertain_series.std_dev,
                    y2=uncertain_series.nominal_value + uncertain_series.std_dev,
                    alpha=error_band_alpha,
                    **fill_between_kwargs)


def format_dictionary(adict, width=80, bullet="•"):
    """
    Helper function to format dictionary to string as a wrapped key-value bullet
    list. Useful to print metadata from dictionaries.

    :param adict: Dictionary to format
    :param width: Characters after which to wrap a key-value line
    :param bullet: Character to begin a key-value line with, e.g. ``-`` for a
    """
    return "\n".join(textwrap.fill(f"{bullet} {key}: {value}", width=width)
                     for (key, value) in adict.items())


class GenerateSimTask(Basf2PathTask):
    """
    Generate simulated Monte Carlo with background overlay.

    Make sure to use different ``random_seed`` parameters for the training data
    for the classifier trainings and for the test data for the respective
    evaluation/validation tasks.
    """

    n_events = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    random_seed = b2luigi.Parameter()
    bkgfiles_dir = b2luigi.Parameter(

    def output_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        return "generated_mc_N" + str(n_events) + "_" + random_seed + ".root"

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.output_file_name())

    def create_path(self):
        """
        Create basf2 path to process with event generation and simulation.
        """
        basf2.set_random_seed(self.random_seed)
        path = basf2.create_path()
        if self.experiment_number in [0, 1002, 1003]:
                f"Simulating events with experiment number {self.experiment_number} is not implemented yet.")
        path.add_module(
            "EventInfoSetter", evtNumList=[self.n_events], runList=[runNo], expList=[self.experiment_number]
        )
        # "V0BBBAR" contains "BBBAR", so the more specific tag has to be checked first
        if "V0BBBAR" in self.random_seed:
            path.add_module("EvtGenInput")
            path.add_module("InclusiveParticleChecker", particles=[310, 3122], includeConjugates=True)
        elif "BBBAR" in self.random_seed:
            path.add_module("EvtGenInput")
        import generators as ge
        if "V0STUDY" in self.random_seed:
            if "V0STUDYKS" in self.random_seed:
            if "V0STUDYL0" in self.random_seed:
                pdgs = [310, 3122, -3122]
            myx = [i*0.01 for i in range(321)]
                y = createV0momenta(x, mu, beta)
            polParams = myx + myy
            particlegun = basf2.register_module('ParticleGun')
            particlegun.param('pdgCodes', pdg_list)
            particlegun.param('nTracks', 8)
            particlegun.param('momentumGeneration', 'polyline')
            particlegun.param('momentumParams', polParams)
            particlegun.param('thetaGeneration', 'uniformCos')
            particlegun.param('thetaParams', [17, 150])
            particlegun.param('phiGeneration', 'uniform')
            particlegun.param('phiParams', [0, 360])
            particlegun.param('vertexGeneration', 'fixed')
            particlegun.param('xVertexParams', [0])
            particlegun.param('yVertexParams', [0])
            particlegun.param('zVertexParams', [0])
            path.add_module(particlegun)
        if "BHABHA" in self.random_seed:
            ge.add_babayaganlo_generator(path=path, finalstate='ee', minenergy=0.15, minangle=10.0)
        elif "EEMUMU" in self.random_seed:
            # checked before "MUMU", which is a substring of "EEMUMU"
            ge.add_aafh_generator(path=path, finalstate='e+e-mu+mu-', preselection=False)
        elif "MUMU" in self.random_seed:
            ge.add_kkmc_generator(path=path, finalstate='mu+mu-')
        elif "YY" in self.random_seed:
            babayaganlo = basf2.register_module('BabayagaNLOInput')
            babayaganlo.param('FinalState', 'gg')
            babayaganlo.param('MaxAcollinearity', 180.0)
            babayaganlo.param('ScatteringAngleRange', [0., 180.])
            babayaganlo.param('FMax', 75000)
            babayaganlo.param('MinEnergy', 0.01)
            babayaganlo.param('Order', 'exp')
            babayaganlo.param('DebugEnergySpread', 0.01)
            babayaganlo.param('Epsilon', 0.00005)
            path.add_module(babayaganlo)
            generatorpreselection = basf2.register_module('GeneratorPreselection')
            generatorpreselection.param('nChargedMin', 0)
            generatorpreselection.param('nChargedMax', 999)
            generatorpreselection.param('MinChargedPt', 0.15)
            generatorpreselection.param('MinChargedTheta', 17.)
            generatorpreselection.param('MaxChargedTheta', 150.)
            generatorpreselection.param('nPhotonMin', 1)
            generatorpreselection.param('MinPhotonEnergy', 1.5)
            generatorpreselection.param('MinPhotonTheta', 15.0)
            generatorpreselection.param('MaxPhotonTheta', 165.0)
            generatorpreselection.param('applyInCMS', True)
            path.add_module(generatorpreselection)
            # events that do not pass the preselection continue on an empty path
            empty = basf2.create_path()
            generatorpreselection.if_value('!=11', empty)
        elif "EEEE" in self.random_seed:
            ge.add_aafh_generator(path=path, finalstate='e+e-e+e-', preselection=False)
        elif "TAUPAIR" in self.random_seed:
            ge.add_kkmc_generator(path, finalstate='tau+tau-')
        elif "DDBAR" in self.random_seed:
            ge.add_continuum_generator(path, finalstate='ddbar')
        elif "UUBAR" in self.random_seed:
            ge.add_continuum_generator(path, finalstate='uubar')
        elif "SSBAR" in self.random_seed:
            ge.add_continuum_generator(path, finalstate='ssbar')
        elif "CCBAR" in self.random_seed:
            ge.add_continuum_generator(path, finalstate='ccbar')
        if self.experiment_number == 1002:
            components = ['PXD', 'SVD', 'CDC', 'ECL', 'TOP', 'ARICH', 'TRG']
            outputFileName=self.get_output_file_name(self.output_file_name()),


class SplitNMergeSimTask(Basf2Task):
    """
    Generate simulated Monte Carlo with background overlay.

    Make sure to use different ``random_seed`` parameters for the training data
    for the classifier trainings and for the test data for the respective
    evaluation/validation tasks.
    """

    n_events = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    random_seed = b2luigi.Parameter()
    bkgfiles_dir = b2luigi.Parameter(

    def output_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        return "generated_mc_N" + str(n_events) + "_" + random_seed + ".root"

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.output_file_name())

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        n_events_per_task = MasterTask.n_events_per_task
        quotient, remainder = divmod(self.n_events, n_events_per_task)
        # one simulation task per full batch of events
        for i in range(quotient):
            yield GenerateSimTask(
                bkgfiles_dir=self.bkgfiles_dir,
                num_processes=MasterTask.num_processes,
                random_seed=self.random_seed + '_' + str(i).zfill(3),
                n_events=n_events_per_task,
                experiment_number=self.experiment_number,
            )
        # one additional task for the events that do not fill a whole batch
        if remainder > 0:
            yield GenerateSimTask(
                bkgfiles_dir=self.bkgfiles_dir,
                num_processes=MasterTask.num_processes,
                random_seed=self.random_seed + '_' + str(quotient).zfill(3),
                n_events=remainder,
                experiment_number=self.experiment_number,
            )

    @b2luigi.on_temporary_files
    def run(self):
        """
        When all GenerateSimTasks finished, merge the output.
        """
        create_output_dirs(self)

        file_list = []
        for _, file_name in self.get_input_file_names().items():
            file_list.append(*file_name)
        print("Merge the following files:")
        cmd = ["b2file-merge", "-f"]
        args = cmd + [self.get_output_file_name(self.output_file_name())] + file_list
        subprocess.check_call(args)
        print("Finished merging. Now remove the input files to save space.")
        for tempfile in file_list:
            args = cmd2 + [tempfile]
            subprocess.check_call(args)


class CheckExistingFile(ExternalTask):
    """
    Task to check if the given file really exists.
    """
    filename = b2luigi.Parameter()

    def output(self):
        """
        Specify the output to be the file that was just checked.
        """
        from luigi import LocalTarget
        return LocalTarget(self.filename)


class VXDQEDataCollectionTask(Basf2PathTask):
    """
    Collect variables/features from VXDTF2 tracking and write them to a ROOT
    file.

    These variables are to be used as labelled training data for the MVA
    classifier which is the VXD track quality estimator
    """
    n_events = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    random_seed = b2luigi.Parameter()

    def get_records_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if 'vxd' not in random_seed:
            random_seed += '_vxd'
        if 'DATA' in random_seed:
            return 'qe_records_DATA_vxd.root'
        if 'USESIMBB' in random_seed:
            random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
        elif 'USESIMEE' in random_seed:
            random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
        return 'qe_records_N' + str(n_events) + '_' + random_seed + '.root'

    def get_input_files(self, n_events=None, random_seed=None):
        """
        Get input file names depending on the use case: If they already exist, search in
        the corresponding folders, for data check the specified list and if they are created
        in the same run, check for the task that produced them.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if "USESIM" in random_seed:
            if 'USESIMBB' in random_seed:
                random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
            elif 'USESIMEE' in random_seed:
                random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
            return ['datafiles/' + GenerateSimTask.output_file_name(GenerateSimTask,
                                                                    n_events=n_events, random_seed=random_seed)]
        elif "DATA" in random_seed:
            return MasterTask.datafiles
        else:
            return self.get_input_file_names(GenerateSimTask.output_file_name(
                GenerateSimTask, n_events=n_events, random_seed=random_seed))

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        if "USESIM" in self.random_seed or "DATA" in self.random_seed:
            for filename in self.get_input_files():
                yield CheckExistingFile(
                    filename=filename,
                )
        else:
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.random_seed,
                n_events=self.n_events,
                experiment_number=self.experiment_number,
            )

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.get_records_file_name())

    def create_path(self):
        """
        Create basf2 path with VXDTF2 tracking and VXD QE data collection.
        """
        path = basf2.create_path()
        inputFileNames = self.get_input_files()
            inputFileNames=inputFileNames,
        path.add_module("Gearbox")
        tracking.add_geometry_modules(path)
        if 'DATA' in self.random_seed:
            from rawdata import add_unpackers
            add_unpackers(path, components=['SVD', 'PXD'])
        tracking.add_hit_preparation_modules(path)
        tracking.add_vxd_track_finding_vxdtf2(
            path, components=["SVD"], add_mva_quality_indicator=False
        )
        if 'DATA' in self.random_seed:
                "VXDQETrainingDataCollector",
                TrainingDataOutputName=self.get_output_file_name(self.get_records_file_name()),
                SpacePointTrackCandsStoreArrayName="SPTrackCands",
                EstimationMethod="tripletFit",
                ClusterInformation="Average",
                MCStrictQualityEstimator=False,
                "TrackFinderMCTruthRecoTracks",
                RecoTracksStoreArrayName="MCRecoTracks",
                "VXDQETrainingDataCollector",
                TrainingDataOutputName=self.get_output_file_name(self.get_records_file_name()),
                SpacePointTrackCandsStoreArrayName="SPTrackCands",
                EstimationMethod="tripletFit",
                ClusterInformation="Average",
                MCStrictQualityEstimator=True,


class CDCQEDataCollectionTask(Basf2PathTask):
    """
    Collect variables/features from CDC tracking and write them to a ROOT file.

    These variables are to be used as labelled training data for the MVA
    classifier which is the CDC track quality estimator
    """
    n_events = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    random_seed = b2luigi.Parameter()

    def get_records_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if 'cdc' not in random_seed:
            random_seed += '_cdc'
        if 'DATA' in random_seed:
            return 'qe_records_DATA_cdc.root'
        if 'USESIMBB' in random_seed:
            random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
        elif 'USESIMEE' in random_seed:
            random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
        return 'qe_records_N' + str(n_events) + '_' + random_seed + '.root'

    def get_input_files(self, n_events=None, random_seed=None):
        """
        Get input file names depending on the use case: If they already exist, search in
        the corresponding folders, for data check the specified list and if they are created
        in the same run, check for the task that produced them.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if "USESIM" in random_seed:
            if 'USESIMBB' in random_seed:
                random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
            elif 'USESIMEE' in random_seed:
                random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
            return ['datafiles/' + GenerateSimTask.output_file_name(GenerateSimTask,
                                                                    n_events=n_events, random_seed=random_seed)]
        elif "DATA" in random_seed:
            return MasterTask.datafiles
        else:
            return self.get_input_file_names(GenerateSimTask.output_file_name(
                GenerateSimTask, n_events=n_events, random_seed=random_seed))

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        if "USESIM" in self.random_seed or "DATA" in self.random_seed:
            for filename in self.get_input_files():
                yield CheckExistingFile(
                    filename=filename,
                )
        else:
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.random_seed,
                n_events=self.n_events,
                experiment_number=self.experiment_number,
            )

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.get_records_file_name())

    def create_path(self):
        """
        Create basf2 path with CDC standalone tracking and CDC QE with recording filter for MVA feature collection.
        """
        path = basf2.create_path()
        inputFileNames = self.get_input_files()
            inputFileNames=inputFileNames,
        path.add_module("Gearbox")
        tracking.add_geometry_modules(path)
        if 'DATA' in self.random_seed:
            filter_choice = "recording_data"
            from rawdata import add_unpackers
            add_unpackers(path, components=['CDC'])
        else:
            filter_choice = "recording"
        tracking.add_cdc_track_finding(path, with_cdc_cellular_automaton=False, add_mva_quality_indicator=True)
        basf2.set_module_parameters(
            name="TFCDC_TrackQualityEstimator",
            filter=filter_choice,
                "rootFileName": self.get_output_file_name(self.get_records_file_name())


class RecoTrackQEDataCollectionTask(Basf2PathTask):
    """
    Collect variables/features from the reco track reconstruction including the
    fit and write them to a ROOT file.

    These variables are to be used as labelled training data for the MVA
    classifier which is the MVA track quality estimator. The collected
    variables include the classifier outputs from the VXD and CDC quality
    estimators, namely the CDC and VXD quality indicators, combined with fit,
    merger, timing, energy loss information etc. This task requires the
    subdetector quality estimators to be trained.
    """

    n_events = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    random_seed = b2luigi.Parameter()
    cdc_training_target = b2luigi.Parameter()
    recotrack_option = b2luigi.Parameter(
        default='deleteCDCQI080'
    )
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]
    )

    def get_records_file_name(self, n_events=None, random_seed=None, recotrack_option=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if recotrack_option is None:
            recotrack_option = self.recotrack_option
        if 'rec' not in random_seed:
            random_seed += '_rec'
        if 'DATA' in random_seed:
            return 'qe_records_DATA_rec.root'
        if 'USESIMBB' in random_seed:
            random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
        elif 'USESIMEE' in random_seed:
            random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
        return 'qe_records_N' + str(n_events) + '_' + random_seed + '_' + recotrack_option + '.root'

    def get_input_files(self, n_events=None, random_seed=None):
        """
        Get input file names depending on the use case: If they already exist, search in
        the corresponding folders, for data check the specified list and if they are created
        in the same run, check for the task that produced them.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if "USESIM" in random_seed:
            if 'USESIMBB' in random_seed:
                random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
            elif 'USESIMEE' in random_seed:
                random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
            return ['datafiles/' + GenerateSimTask.output_file_name(GenerateSimTask,
                                                                    n_events=n_events, random_seed=random_seed)]
        elif "DATA" in random_seed:
            return MasterTask.datafiles
        else:
            return self.get_input_file_names(GenerateSimTask.output_file_name(
                GenerateSimTask, n_events=n_events, random_seed=random_seed))

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        if "USESIM" in self.random_seed or "DATA" in self.random_seed:
            for filename in self.get_input_files():
                yield CheckExistingFile(
                    filename=filename,
                )
        else:
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.random_seed,
                n_events=self.n_events,
                experiment_number=self.experiment_number,
            )
        # the subdetector quality estimators have to be trained unless existing weights are reused or disabled
        if "DATA" not in self.random_seed:
            if 'useCDC' not in self.recotrack_option and 'noCDC' not in self.recotrack_option:
                yield CDCQETeacherTask(
                    n_events_training=MasterTask.n_events_training,
                    experiment_number=self.experiment_number,
                    training_target=self.cdc_training_target,
                    process_type=self.random_seed.split("_", 1)[0],
                    exclude_variables=MasterTask.exclude_variables_cdc,
                    fast_bdt_option=self.fast_bdt_option,
                )
            if 'useVXD' not in self.recotrack_option and 'noVXD' not in self.recotrack_option:
                yield VXDQETeacherTask(
                    n_events_training=MasterTask.n_events_training,
                    experiment_number=self.experiment_number,
                    process_type=self.random_seed.split("_", 1)[0],
                    exclude_variables=MasterTask.exclude_variables_vxd,
                    fast_bdt_option=self.fast_bdt_option,
                )

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.get_records_file_name())

    def create_path(self):
        """
        Create basf2 reconstruction path that should mirror the default path
        from ``add_tracking_reconstruction()``, but with modules for the VXD QE
        and CDC QE application and for collection of variables for the reco
        track quality estimator.
        """
        path = basf2.create_path()
        inputFileNames = self.get_input_files()
            inputFileNames=inputFileNames,
        path.add_module("Gearbox")
        if 'noCDC' in self.recotrack_option:
        if 'noVXD' in self.recotrack_option:
        if 'DATA' in self.random_seed:
            from rawdata import add_unpackers
        tracking.add_tracking_reconstruction(path, add_cdcTrack_QI=mvaCDC, add_vxdTrack_QI=mvaVXD, add_recoTrack_QI=True)
        if ('DATA' in self.random_seed or 'useCDC' in self.recotrack_option) and 'noCDC' not in self.recotrack_option:
            cdc_identifier = 'datafiles/' + \
                CDCQETeacherTask.get_weightfile_xml_identifier(CDCQETeacherTask, fast_bdt_option=self.fast_bdt_option)
            if os.path.exists(cdc_identifier):
                replace_cdc_qi = True
            elif 'useCDC' in self.recotrack_option:
                raise ValueError(f"CDC QI Identifier not found: {cdc_identifier}")
            else:
                replace_cdc_qi = False
        elif 'noCDC' in self.recotrack_option:
            replace_cdc_qi = False
        else:
            cdc_identifier = self.get_input_file_names(
                CDCQETeacherTask.get_weightfile_xml_identifier(
                    CDCQETeacherTask, fast_bdt_option=self.fast_bdt_option))[0]
            replace_cdc_qi = True
        if ('DATA' in self.random_seed or 'useVXD' in self.recotrack_option) and 'noVXD' not in self.recotrack_option:
            vxd_identifier = 'datafiles/' + \
                VXDQETeacherTask.get_weightfile_xml_identifier(VXDQETeacherTask, fast_bdt_option=self.fast_bdt_option)
            if os.path.exists(vxd_identifier):
                replace_vxd_qi = True
            elif 'useVXD' in self.recotrack_option:
                raise ValueError(f"VXD QI Identifier not found: {vxd_identifier}")
            else:
                replace_vxd_qi = False
        elif 'noVXD' in self.recotrack_option:
            replace_vxd_qi = False
        else:
            vxd_identifier = self.get_input_file_names(
                VXDQETeacherTask.get_weightfile_xml_identifier(
                    VXDQETeacherTask, fast_bdt_option=self.fast_bdt_option))[0]
            replace_vxd_qi = True
        cdc_qe_mva_filter_parameters = None
        if 'deleteCDCQI' in self.recotrack_option:
            cut_index = self.recotrack_option.find('deleteCDCQI') + len('deleteCDCQI')
            cut = int(self.recotrack_option[cut_index:cut_index+3])/100.
            cdc_qe_mva_filter_parameters = {
                "identifier": cdc_identifier, "cut": cut}
            cdc_qe_mva_filter_parameters = {
        elif replace_cdc_qi:
            cdc_qe_mva_filter_parameters = {
                "identifier": cdc_identifier}
        if cdc_qe_mva_filter_parameters is not None:
            basf2.set_module_parameters(
                name="TFCDC_TrackQualityEstimator",
                filterParameters=cdc_qe_mva_filter_parameters,
            )
        basf2.set_module_parameters(
            name="VXDQualityEstimatorMVA",
            WeightFileIdentifier=vxd_identifier)
        track_qe_module_name = "TrackQualityEstimatorMVA"
        module_found = False
        new_path = basf2.create_path()
        for module in path.modules():
            if module.name() != track_qe_module_name:
                # name() is a method; comparing the bound method itself to a string never matches
                if module.name() != 'TrackCreator':
                    new_path.add_module(module)
                new_path.add_module(
                    recoTrackColName='RecoTracks',
                    trackColName='MDSTTracks')
                new_path.add_module(
                    "TrackQETrainingDataCollector",
                    TrainingDataOutputName=self.get_output_file_name(self.get_records_file_name()),
                    collectEventFeatures=True,
                    SVDPlusCDCStandaloneRecoTracksStoreArrayName="SVDPlusCDCStandaloneRecoTracks",
                )
        if not module_found:
            raise KeyError(f"No module {track_qe_module_name} found in path")


class TrackQETeacherBaseTask(Basf2Task):
    """
    A teacher task runs the basf2 mva teacher on the training data provided by a
    data collection task.

    Since teacher tasks are needed for all quality estimators covered by this
    steering file and the only thing that changes is the required data
    collection task and some training parameters, I decided to use inheritance
    and have the basic functionality in this base class/interface and have the
    specific teacher tasks inherit from it.
    """
    n_events_training = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    process_type = b2luigi.Parameter(
    training_target = b2luigi.Parameter(
    exclude_variables = b2luigi.ListParameter(
        hashed=True, default=[]
    )
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]
    )

    @property
    def weightfile_identifier_basename(self):
        """
        Property defining the basename for the .xml and .root weightfiles that are created.
        Has to be implemented by the inheriting teacher task class.
        """
        raise NotImplementedError(
            "Teacher Task must define a static weightfile_identifier"
        )

    def get_weightfile_xml_identifier(self, fast_bdt_option=None, recotrack_option=None):
        """
        Name of the xml weightfile that is created by the teacher task.
        It is subsequently used as a local weightfile in the following validation tasks.
        """
        if fast_bdt_option is None:
            fast_bdt_option = self.fast_bdt_option
        if recotrack_option is None and hasattr(self, 'recotrack_option'):
            recotrack_option = self.recotrack_option
        else:
            recotrack_option = ''
        weightfile_details = create_fbdt_option_string(fast_bdt_option)
        weightfile_name = self.weightfile_identifier_basename + weightfile_details
        if recotrack_option != '':
            weightfile_name = weightfile_name + '_' + recotrack_option
        return weightfile_name + ".weights.xml"

    @property
    def tree_name(self):
        """
        Property defining the name of the tree in the ROOT file from the
        ``data_collection_task`` that contains the recorded training data. Must
        be implemented by the inheriting specific teacher task class.
        """
        raise NotImplementedError("Teacher Task must define a static tree_name")

    @property
    def random_seed(self):
        """
        Property defining random seed to be used by the ``GenerateSimTask``.
        Should differ from the random seed in the test data samples. Must
        be implemented by the inheriting specific teacher task class.
        """
        raise NotImplementedError("Teacher Task must define a static random seed")

    @property
    def data_collection_task(self) -> Basf2PathTask:
        """
        Property defining the specific ``DataCollectionTask`` to require. Must
        be implemented by the inheriting specific teacher task class.
        """
        raise NotImplementedError(
            "Teacher Task must define a data collection task to require "
        )

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/qe_records_N' + str(self.n_events_training) + '_' + process + '_' + self.random_seed + '.root',
            )
        else:
            yield self.data_collection_task(
                num_processes=MasterTask.num_processes,
                n_events=self.n_events_training,
                experiment_number=self.experiment_number,
                random_seed=self.process_type + '_' + self.random_seed,
            )

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.get_weightfile_xml_identifier())

    def process(self):
        """
        Use basf2_mva teacher to create MVA weightfile from collected training

        This is the main process that is dispatched by the ``run`` method that
        is inherited from ``Basf2Task``.
        """
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            records_files = ['datafiles/qe_records_N' + str(self.n_events_training) +
                             '_' + process + '_' + self.random_seed + '.root']
        else:
            if hasattr(self, 'recotrack_option'):
                records_files = self.get_input_file_names(
                    self.data_collection_task.get_records_file_name(
                        self.data_collection_task,
                        n_events=self.n_events_training,
                        random_seed=self.process_type + '_' + self.random_seed,
                        recotrack_option=self.recotrack_option))
            else:
                records_files = self.get_input_file_names(
                    self.data_collection_task.get_records_file_name(
                        self.data_collection_task,
                        n_events=self.n_events_training,
                        random_seed=self.process_type + '_' + self.random_seed))
        my_basf2_mva_teacher(
            records_files=records_files,
            tree_name=self.tree_name,
            weightfile_identifier=self.get_output_file_name(self.get_weightfile_xml_identifier()),
            target_variable=self.training_target,
            exclude_variables=self.exclude_variables,
            fast_bdt_option=self.fast_bdt_option,
        )
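

# Illustration only, not part of the task chain: a minimal sketch of how
# ``get_weightfile_xml_identifier`` in the teacher base task composes a weight
# file name from a basename, the FastBDT options and an optional reco-track
# option. ``fbdt_option_suffix`` is a hypothetical stand-in for
# ``create_fbdt_option_string``, which is defined elsewhere in this file and
# may format the options differently; ``weightfile_xml_identifier`` is likewise
# a name made up for this sketch.

```python
def fbdt_option_suffix(fast_bdt_option):
    """Hypothetical stand-in for create_fbdt_option_string (assumed format)."""
    n_trees, n_cuts, n_levels, shrinkage = fast_bdt_option
    return f"_nTrees{n_trees}_nCuts{n_cuts}_nLevels{n_levels}_shrin{round(shrinkage * 100)}"


def weightfile_xml_identifier(basename, fast_bdt_option, recotrack_option=""):
    """Compose '<basename><fbdt suffix>[_<recotrack_option>].weights.xml'."""
    name = basename + fbdt_option_suffix(fast_bdt_option)
    if recotrack_option != "":
        name = name + "_" + recotrack_option
    return name + ".weights.xml"


if __name__ == "__main__":
    print(weightfile_xml_identifier("cdc_mva_qe", [200, 8, 3, 0.1]))
    print(weightfile_xml_identifier("recotrack_mva_qe", [200, 8, 3, 0.1], "deleteCDCQI080"))
```

# Encoding the training options in the file name keeps weight files from
# different FastBDT settings apart in the same output directory.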


class VXDQETeacherTask(TrackQETeacherBaseTask):
    """
    Task to run basf2 mva teacher on collected data for VXDTF2 track quality estimator
    """
    weightfile_identifier_basename = "vxdtf2_mva_qe"
    random_seed = "train_vxd"
    data_collection_task = VXDQEDataCollectionTask


class CDCQETeacherTask(TrackQETeacherBaseTask):
    """
    Task to run basf2 mva teacher on collected data for CDC track quality estimator
    """
    weightfile_identifier_basename = "cdc_mva_qe"
    tree_name = "records"
    random_seed = "train_cdc"
    data_collection_task = CDCQEDataCollectionTask


class RecoTrackQETeacherTask(TrackQETeacherBaseTask):
    """
    Task to run basf2 mva teacher on collected data for the final, combined
    track quality estimator
    """
    recotrack_option = b2luigi.Parameter(
        default='deleteCDCQI080'
    )
    weightfile_identifier_basename = "recotrack_mva_qe"
    random_seed = "train_rec"
    data_collection_task = RecoTrackQEDataCollectionTask
    cdc_training_target = b2luigi.Parameter()

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/qe_records_N' + str(self.n_events_training) + '_' + process + '_' + self.random_seed + '.root',
            )
        else:
            yield self.data_collection_task(
                cdc_training_target=self.cdc_training_target,
                num_processes=MasterTask.num_processes,
                n_events=self.n_events_training,
                experiment_number=self.experiment_number,
                random_seed=self.process_type + '_' + self.random_seed,
                recotrack_option=self.recotrack_option,
                fast_bdt_option=self.fast_bdt_option,
            )
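

# Illustration only: a minimal sketch of how an option string such as the
# default ``'deleteCDCQI080'`` encodes a CDC quality-indicator cut. It mirrors
# the parsing in ``create_path`` of the data collection task (three digits
# after the keyword, divided by 100). The function name
# ``parse_delete_cdc_qi_cut`` and the ``None`` return for a missing keyword
# are assumptions made for this sketch.

```python
def parse_delete_cdc_qi_cut(recotrack_option, keyword="deleteCDCQI"):
    """Return the cut encoded after `keyword` as a float, or None if absent."""
    position = recotrack_option.find(keyword)
    if position < 0:
        return None
    cut_index = position + len(keyword)
    # The three characters after the keyword are the cut in percent, e.g. "080".
    return int(recotrack_option[cut_index:cut_index + 3]) / 100.0


if __name__ == "__main__":
    print(parse_delete_cdc_qi_cut("deleteCDCQI080"))
    print(parse_delete_cdc_qi_cut("useCDC"))
```

# Because exactly three digits are read, a cut must be zero-padded in the
# option string ("005" for 0.05); shorter encodings would be misread.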


class HarvestingValidationBaseTask(Basf2PathTask):
    """
    Run track reconstruction with MVA quality estimator and write out
    (="harvest") a root file with variables useful for the validation.
    """
    n_events_testing = b2luigi.IntParameter()
    n_events_training = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    process_type = b2luigi.Parameter(
    exclude_variables = b2luigi.ListParameter(
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]
    )
    validation_output_file_name = "harvesting_validation.root"
    reco_output_file_name = "reconstruction.root"

    def teacher_task(self) -> TrackQETeacherBaseTask:
        """
        Teacher task to require to provide a quality estimator weightfile for ``add_tracking_with_quality_estimation``
        """
        raise NotImplementedError()

    def add_tracking_with_quality_estimation(self, path: basf2.Path) -> None:
        """
        Add modules for track reconstruction to basf2 path that are to be
        validated. Besides track finding it should include MC matching, fitted
        track creation and a quality estimator module.
        """
        raise NotImplementedError()

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        yield self.teacher_task(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            exclude_variables=self.exclude_variables,
            fast_bdt_option=self.fast_bdt_option,
        )
        if 'USE' in self.process_type:
            if 'BB' in self.process_type:
            elif 'EE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root'
            )
        else:
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.process_type + '_test',
                n_events=self.n_events_testing,
                experiment_number=self.experiment_number,
            )

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.validation_output_file_name)
        yield self.add_to_output(self.reco_output_file_name)

    def create_path(self):
        """
        Create a basf2 path that uses ``add_tracking_with_quality_estimation()``
        and adds the ``CombinedTrackingValidationModule`` to write out variables
        """
        path = basf2.create_path()
        if 'USE' in self.process_type:
            if 'BB' in self.process_type:
            elif 'EE' in self.process_type:
            inputFileNames = ['datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root']
        else:
            inputFileNames = self.get_input_file_names(GenerateSimTask.output_file_name(
                GenerateSimTask, n_events=self.n_events_testing, random_seed=self.process_type + '_test'))
            inputFileNames=inputFileNames,
        path.add_module("Gearbox")
        tracking.add_geometry_modules(path)
        tracking.add_hit_preparation_modules(path)
        self.add_tracking_with_quality_estimation(path)
            CombinedTrackingValidationModule(
                output_file_name=self.get_output_file_name(
                    self.validation_output_file_name
                ),
            )
            outputFileName=self.get_output_file_name(self.reco_output_file_name),


class VXDQEHarvestingValidationTask(HarvestingValidationBaseTask):
    """
    Run VXDTF2 track reconstruction and write out (="harvest") a root file with
    variables useful for validation of the VXD Quality Estimator.
    """
    validation_output_file_name = "vxd_qe_harvesting_validation.root"
    reco_output_file_name = "vxd_qe_reconstruction.root"
    teacher_task = VXDQETeacherTask

    def add_tracking_with_quality_estimation(self, path):
        """
        Add modules for VXDTF2 tracking with VXD quality estimator to basf2 path.
        """
        tracking.add_vxd_track_finding_vxdtf2(
            reco_tracks="RecoTracks",
            add_mva_quality_indicator=True,
        )
        basf2.set_module_parameters(
            name="VXDQualityEstimatorMVA",
            WeightFileIdentifier=self.get_input_file_names(
                self.teacher_task.get_weightfile_xml_identifier(self.teacher_task, fast_bdt_option=self.fast_bdt_option)
            )[0],
        )
        tracking.add_mc_matcher(path, components=["SVD"])
        tracking.add_track_fit_and_track_creator(path, components=["SVD"])


class CDCQEHarvestingValidationTask(HarvestingValidationBaseTask):
    """
    Run CDC reconstruction and write out (="harvest") a root file with variables
    useful for validation of the CDC Quality Estimator.
    """
    training_target = b2luigi.Parameter()
    validation_output_file_name = "cdc_qe_harvesting_validation.root"
    reco_output_file_name = "cdc_qe_reconstruction.root"
    teacher_task = CDCQETeacherTask

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        yield self.teacher_task(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            training_target=self.training_target,
            exclude_variables=self.exclude_variables,
            fast_bdt_option=self.fast_bdt_option,
        )
        if 'USE' in self.process_type:
            if 'BB' in self.process_type:
            elif 'EE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root'
            )
        else:
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.process_type + '_test',
                n_events=self.n_events_testing,
                experiment_number=self.experiment_number,
            )

    def add_tracking_with_quality_estimation(self, path):
        """
        Add modules for CDC standalone tracking with CDC quality estimator to basf2 path.
        """
        tracking.add_cdc_track_finding(
            output_reco_tracks="RecoTracks",
            add_mva_quality_indicator=True,
        )
        cdc_qe_mva_filter_parameters = {
            "identifier": self.get_input_file_names(
                CDCQETeacherTask.get_weightfile_xml_identifier(
                    fast_bdt_option=self.fast_bdt_option))[0]}
        basf2.set_module_parameters(
            name="TFCDC_TrackQualityEstimator",
            filterParameters=cdc_qe_mva_filter_parameters,
        )
        tracking.add_mc_matcher(path, components=["CDC"])
        tracking.add_track_fit_and_track_creator(path, components=["CDC"])


class RecoTrackQEHarvestingValidationTask(HarvestingValidationBaseTask):
    """
    Run track reconstruction and write out (="harvest") a root file with variables
    useful for validation of the MVA track Quality Estimator.
    """
    cdc_training_target = b2luigi.Parameter()
    validation_output_file_name = "reco_qe_harvesting_validation.root"
    reco_output_file_name = "reco_qe_reconstruction.root"
    teacher_task = RecoTrackQETeacherTask

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        yield CDCQETeacherTask(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            training_target=self.cdc_training_target,
            exclude_variables=MasterTask.exclude_variables_cdc,
            fast_bdt_option=self.fast_bdt_option,
        )
        yield VXDQETeacherTask(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            exclude_variables=MasterTask.exclude_variables_vxd,
            fast_bdt_option=self.fast_bdt_option,
        )
        yield self.teacher_task(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            exclude_variables=self.exclude_variables,
            cdc_training_target=self.cdc_training_target,
            fast_bdt_option=self.fast_bdt_option,
        )
        if 'USE' in self.process_type:
            if 'BB' in self.process_type:
            elif 'EE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root'
            )
        else:
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.process_type + '_test',
                n_events=self.n_events_testing,
                experiment_number=self.experiment_number,
            )

    def add_tracking_with_quality_estimation(self, path):
        """
        Add modules for reco tracking with all track quality estimators to basf2 path.
        """
        tracking.add_tracking_reconstruction(
            add_cdcTrack_QI=True,
            add_vxdTrack_QI=True,
            add_recoTrack_QI=True,
            skipGeometryAdding=True,
            skipHitPreparerAdding=False,
        )
        cdc_qe_mva_filter_parameters = {
            "identifier": self.get_input_file_names(
                CDCQETeacherTask.get_weightfile_xml_identifier(
                    fast_bdt_option=self.fast_bdt_option))[0]}
        basf2.set_module_parameters(
            name="TFCDC_TrackQualityEstimator",
            filterParameters=cdc_qe_mva_filter_parameters,
        )
        basf2.set_module_parameters(
            name="VXDQualityEstimatorMVA",
            WeightFileIdentifier=self.get_input_file_names(
                VXDQETeacherTask.get_weightfile_xml_identifier(VXDQETeacherTask, fast_bdt_option=self.fast_bdt_option)
            )[0],
        )
        basf2.set_module_parameters(
            name="TrackQualityEstimatorMVA",
            WeightFileIdentifier=self.get_input_file_names(
                RecoTrackQETeacherTask.get_weightfile_xml_identifier(RecoTrackQETeacherTask, fast_bdt_option=self.fast_bdt_option)
            )[0],
        )


class TrackQEEvaluationBaseTask(Task):
    """
    Base class for evaluating a quality estimator with ``basf2_mva_evaluate.py``
    on a separate test data set.

    Evaluation tasks for VXD, CDC and combined QE can inherit from it.
    """
    git_hash = b2luigi.Parameter(
        default=get_basf2_git_hash()
    )
    n_events_testing = b2luigi.IntParameter()
    n_events_training = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    process_type = b2luigi.Parameter(
    training_target = b2luigi.Parameter(
    exclude_variables = b2luigi.ListParameter(
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]
    )

    @property
    def teacher_task(self) -> TrackQETeacherBaseTask:
        """
        Property defining specific teacher task to require.
        """
        raise NotImplementedError(
            "Evaluation Tasks must define a teacher task to require "
        )

    @property
    def data_collection_task(self) -> Basf2PathTask:
        """
        Property defining the specific ``DataCollectionTask`` to require. Must
        be implemented by the inheriting specific evaluation task class.
        """
        raise NotImplementedError(
            "Evaluation Tasks must define a data collection task to require "
        )

    def task_acronym(self):
        """
        Acronym to distinguish between cdc, vxd and rec(o) MVA
        """
        raise NotImplementedError(
            "Evaluation Tasks must define a task acronym."
        )

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        yield self.teacher_task(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            training_target=self.training_target,
            exclude_variables=self.exclude_variables,
            fast_bdt_option=self.fast_bdt_option,
        )
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/qe_records_N' + str(self.n_events_testing) + '_' + process + '_test_' +
                self.task_acronym + '.root'
            )
        else:
            yield self.data_collection_task(
                num_processes=MasterTask.num_processes,
                n_events=self.n_events_testing,
                experiment_number=self.experiment_number,
                random_seed=self.process_type + '_test',
            )

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        weightfile_details = create_fbdt_option_string(self.fast_bdt_option)
        evaluation_pdf_output = self.teacher_task.weightfile_identifier_basename + weightfile_details + ".pdf"
        yield self.add_to_output(evaluation_pdf_output)

    @b2luigi.on_temporary_files
    def run(self):
        """
        Run ``basf2_mva_evaluate.py`` subprocess to evaluate QE MVA.

        The MVA weight file created from training on the training data set is
        evaluated on separate test data.
        """
        weightfile_details = create_fbdt_option_string(self.fast_bdt_option)
        evaluation_pdf_output_basename = self.teacher_task.weightfile_identifier_basename + weightfile_details + ".pdf"
        evaluation_pdf_output_path = self.get_output_file_name(evaluation_pdf_output_basename)
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            datafiles = 'datafiles/qe_records_N' + str(self.n_events_testing) + '_' + \
                process + '_test_' + self.task_acronym + '.root'
        else:
            datafiles = self.get_input_file_names(
                self.data_collection_task.get_records_file_name(
                    self.data_collection_task,
                    n_events=self.n_events_testing,
                    # this task has no attribute ``process``; the parameter is ``process_type``
                    random_seed=self.process_type + '_test_' +
                    self.task_acronym))[0]
            "basf2_mva_evaluate.py",
            self.get_input_file_names(
                self.teacher_task.get_weightfile_xml_identifier(
                    fast_bdt_option=self.fast_bdt_option))[0],
            self.teacher_task.tree_name,
            evaluation_pdf_output_path,
        log_file_dir = get_log_file_dir(self)
        try:
            os.makedirs(log_file_dir, exist_ok=True)
        except FileExistsError:
            print('Directory ' + log_file_dir + ' already exists.')
        stderr_log_file_path = log_file_dir + "stderr"
        stdout_log_file_path = log_file_dir + "stdout"
        with open(stdout_log_file_path, "w") as stdout_file:
            stdout_file.write(f'stdout output of the command:\n{" ".join(cmd)}\n\n')
        if os.path.exists(stderr_log_file_path):
            os.remove(stderr_log_file_path)
        with open(stdout_log_file_path, "a") as stdout_file:
            with open(stderr_log_file_path, "a") as stderr_file:
                try:
                    # capture the child's stdout in the log file (it was passed as stdin before)
                    subprocess.run(cmd, check=True, stdout=stdout_file, stderr=stderr_file)
                except subprocess.CalledProcessError as err:
                    stderr_file.write(f"Evaluation failed with error:\n{err}")


class VXDTrackQEEvaluationTask(TrackQEEvaluationBaseTask):
    """
    Run ``basf2_mva_evaluate.py`` for the VXD quality estimator on separate test data
    """
    teacher_task = VXDQETeacherTask
    data_collection_task = VXDQEDataCollectionTask
    task_acronym = 'vxd'


class CDCTrackQEEvaluationTask(TrackQEEvaluationBaseTask):
    """
    Run ``basf2_mva_evaluate.py`` for the CDC quality estimator on separate test data
    """
    teacher_task = CDCQETeacherTask
    data_collection_task = CDCQEDataCollectionTask
    task_acronym = 'cdc'


class RecoTrackQEEvaluationTask(TrackQEEvaluationBaseTask):
    """
    Run ``basf2_mva_evaluate.py`` for the final, combined quality estimator on
    separate test data
    """
    teacher_task = RecoTrackQETeacherTask
    data_collection_task = RecoTrackQEDataCollectionTask
    task_acronym = 'rec'
    cdc_training_target = b2luigi.Parameter()

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        yield self.teacher_task(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            training_target=self.training_target,
            exclude_variables=self.exclude_variables,
            cdc_training_target=self.cdc_training_target,
            fast_bdt_option=self.fast_bdt_option,
        )
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/qe_records_N' + str(self.n_events_testing) + '_' + process + '_test_' +
                self.task_acronym + '.root'
            )
        else:
            yield self.data_collection_task(
                num_processes=MasterTask.num_processes,
                n_events=self.n_events_testing,
                experiment_number=self.experiment_number,
                random_seed=self.process_type + "_test",
                cdc_training_target=self.cdc_training_target,
            )
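

# Illustration only: a minimal, self-contained sketch of the logging pattern
# used by the ``run`` method of the evaluation base task, i.e. writing the
# command line as a header and then sending the child's stdout and stderr to
# separate log files. The helper name ``run_with_logs`` and the use of
# ``os.path.join`` for the log paths are assumptions of this sketch, not the
# steering file's own code.

```python
import os
import subprocess
import sys
import tempfile


def run_with_logs(cmd, log_dir):
    """Run `cmd`, logging its stdout/stderr to files inside `log_dir`."""
    os.makedirs(log_dir, exist_ok=True)
    stdout_path = os.path.join(log_dir, "stdout")
    stderr_path = os.path.join(log_dir, "stderr")
    with open(stdout_path, "w") as stdout_file, open(stderr_path, "w") as stderr_file:
        stdout_file.write(f'stdout output of the command:\n{" ".join(cmd)}\n\n')
        # Flush the header so that it precedes whatever the child process writes.
        stdout_file.flush()
        try:
            subprocess.run(cmd, check=True, stdout=stdout_file, stderr=stderr_file)
        except subprocess.CalledProcessError as err:
            stderr_file.write(f"Evaluation failed with error:\n{err}")
            raise
    return stdout_path, stderr_path


if __name__ == "__main__":
    demo_dir = tempfile.mkdtemp()
    out_path, _ = run_with_logs([sys.executable, "-c", "print('hello')"], demo_dir)
    with open(out_path) as log:
        print(log.read())
```

# Re-raising after logging keeps a failed evaluation visible to the caller,
# so a workflow manager would not mark the task as successfully finished.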


class PlotsFromHarvestingValidationBaseTask(Basf2Task):
    """
    Create a PDF file with validation plots for a quality estimator produced
    from the ROOT ntuples produced by a harvesting validation task
    """
    n_events_testing = b2luigi.IntParameter()
    n_events_training = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    process_type = b2luigi.Parameter(
    exclude_variables = b2luigi.ListParameter(
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]
    )
    primaries_only = b2luigi.BoolParameter(

    @property
    def harvesting_validation_task_instance(self) -> HarvestingValidationBaseTask:
        """
        Specifies related harvesting validation task which produces the ROOT
        files with the data that is plotted by this task.
        """
        raise NotImplementedError("Must define a QI harvesting validation task for which to do the plots")

    @property
    def output_pdf_file_basename(self):
        """
        Name of the output PDF file containing the validation plots
        """
        validation_harvest_basename = self.harvesting_validation_task_instance.validation_output_file_name
        return validation_harvest_basename.replace(".root", "_plots.pdf")

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        yield self.harvesting_validation_task_instance

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.output_pdf_file_basename)

    @b2luigi.on_temporary_files
    def process(self):
        """
        Use basf2_mva teacher to create MVA weightfile from collected training

        Main process that is dispatched by the ``run`` method that is inherited
        from ``Basf2Task``.
        """
        validation_harvest_basename = self.harvesting_validation_task_instance.validation_output_file_name
        validation_harvest_path = self.get_input_file_names(validation_harvest_basename)[0]
            'is_fake', 'is_clone', 'is_matched', 'quality_indicator',
            'experiment_number', 'run_number', 'event_number', 'pr_store_array_number',
            'pt_estimate', 'z0_estimate', 'd0_estimate', 'tan_lambda_estimate',
            'phi0_estimate', 'pt_truth', 'z0_truth', 'd0_truth', 'tan_lambda_truth',
        pr_df = uproot.open(validation_harvest_path)['pr_tree/pr_tree'].arrays(pr_columns, library='pd')
            'experiment_number',
            'pr_store_array_number',
        mc_df = uproot.open(validation_harvest_path)['mc_tree/mc_tree'].arrays(mc_columns, library='pd')
        if self.primaries_only:
            mc_df = mc_df[mc_df.is_primary.eq(True)]
        qi_cuts = np.linspace(0., 1, 20, endpoint=False)
        output_pdf_file_path = self.get_output_file_name(self.output_pdf_file_basename)
        with PdfPages(output_pdf_file_path, keep_empty=False) as pdf:
            titlepage_fig, titlepage_ax = plt.subplots()
            titlepage_ax.axis("off")
            title = f"Quality Estimator validation plots from {self.__class__.__name__}"
            titlepage_ax.set_title(title)
            teacher_task = self.harvesting_validation_task_instance.teacher_task
            weightfile_identifier = teacher_task.get_weightfile_xml_identifier(teacher_task, fast_bdt_option=self.fast_bdt_option)
                "Date": datetime.today().strftime("%Y-%m-%d %H:%M"),
                "Created by steering file": os.path.realpath(__file__),
                "Created from data in": validation_harvest_path,
                "Background directory": MasterTask.bkgfiles_by_exp[self.experiment_number],
                "weight file": weightfile_identifier,
            if hasattr(self, 'exclude_variables'):
                meta_data["Excluded variables"] = ", ".join(self.exclude_variables)
            meta_data_string = (format_dictionary(meta_data) +
                                "\n\n(For all MVA training parameters look into the produced weight file)")
            luigi_params = get_serialized_parameters(self)
            luigi_param_string = (f"\n\nb2luigi parameters for {self.__class__.__name__}\n" +
                                  format_dictionary(luigi_params))
            title_page_text = meta_data_string + luigi_param_string
            titlepage_ax.text(0, 1, title_page_text, ha="left", va="top", wrap=True, fontsize=8)
            pdf.savefig(titlepage_fig)
            plt.close(titlepage_fig)
            fake_rates = get_uncertain_means_for_qi_cuts(pr_df, "is_fake", qi_cuts)
            fake_fig, fake_ax = plt.subplots()
            fake_ax.set_title("Fake rate")
            plot_with_errobands(fake_rates, ax=fake_ax)
            fake_ax.set_ylabel("fake rate")
            fake_ax.set_xlabel("quality indicator requirement")
            pdf.savefig(fake_fig, bbox_inches="tight")
            clone_rates = get_uncertain_means_for_qi_cuts(pr_df, "is_clone", qi_cuts)
            clone_fig, clone_ax = plt.subplots()
            clone_ax.set_title("Clone rate")
            plot_with_errobands(clone_rates, ax=clone_ax)
            clone_ax.set_ylabel("clone rate")
            clone_ax.set_xlabel("quality indicator requirement")
            pdf.savefig(clone_fig, bbox_inches="tight")
            plt.close(clone_fig)
            pr_track_identifiers = ['experiment_number', 'run_number', 'event_number', 'pr_store_array_number']
                left=mc_df, right=pr_df[pr_track_identifiers + ['quality_indicator']],
                on=pr_track_identifiers
            missing_fractions = (
                _my_uncertain_mean(mc_df[
                    mc_df.quality_indicator.isnull() | (mc_df.quality_indicator > qi_cut)]['is_missing'])
                for qi_cut in qi_cuts
            )
            findeff_fig, findeff_ax = plt.subplots()
            findeff_ax.set_title("Finding efficiency")
            finding_efficiencies = 1.0 - upd.Series(data=missing_fractions, index=qi_cuts)
            plot_with_errobands(finding_efficiencies, ax=findeff_ax)
            findeff_ax.set_ylabel("finding efficiency")
            findeff_ax.set_xlabel("quality indicator requirement")
            pdf.savefig(findeff_fig, bbox_inches="tight")
            plt.close(findeff_fig)
            fake_roc_fig, fake_roc_ax = plt.subplots()
            fake_roc_ax.set_title("Fake rate vs. finding efficiency ROC curve")
            fake_roc_ax.errorbar(x=finding_efficiencies.nominal_value, y=fake_rates.nominal_value,
                                 xerr=finding_efficiencies.std_dev, yerr=fake_rates.std_dev, elinewidth=0.8)
            fake_roc_ax.set_xlabel('finding efficiency')
            fake_roc_ax.set_ylabel('fake rate')
            pdf.savefig(fake_roc_fig, bbox_inches="tight")
            plt.close(fake_roc_fig)
            clone_roc_fig, clone_roc_ax = plt.subplots()
            clone_roc_ax.set_title("Clone rate vs. finding efficiency ROC curve")
            clone_roc_ax.errorbar(x=finding_efficiencies.nominal_value, y=clone_rates.nominal_value,
                                  xerr=finding_efficiencies.std_dev, yerr=clone_rates.std_dev, elinewidth=0.8)
            clone_roc_ax.set_xlabel('finding efficiency')
            clone_roc_ax.set_ylabel('clone rate')
            pdf.savefig(clone_roc_fig, bbox_inches="tight")
            plt.close(clone_roc_fig)

            kinematic_qi_cuts = [0, 0.5, 0.9]

            params = ['d0', 'z0', 'pt', 'tan_lambda', 'phi0']

                "tan_lambda": r"$\tan{\lambda}$",

                "tan_lambda": "rad",

            n_kinematic_bins = 75

                "pt": np.linspace(0, np.percentile(pr_df['pt_truth'].dropna(), 95), n_kinematic_bins),
                "z0": np.linspace(-0.1, 0.1, n_kinematic_bins),
                "d0": np.linspace(0, 0.01, n_kinematic_bins),
                "tan_lambda": np.linspace(-2, 3, n_kinematic_bins),
                "phi0": np.linspace(0, 2 * np.pi, n_kinematic_bins)

            kinematic_qi_cuts = [0, 0.5, 0.8]
            blue, yellow, green = plt.get_cmap("tab10").colors[0:3]
            for param in params:
                fig, axarr = plt.subplots(ncols=len(kinematic_qi_cuts), sharey=True, sharex=True, figsize=(14, 6))
                fig.suptitle(f"{label_by_param[param]} distributions")
                for i, qi in enumerate(kinematic_qi_cuts):
                    ax.set_title(f"QI > {qi}")
                    incut = pr_df[(pr_df['quality_indicator'] > qi)]
                    incut_matched = incut[incut.is_matched.eq(True)]
                    incut_clones = incut[incut.is_clone.eq(True)]
                    incut_fake = incut[incut.is_fake.eq(True)]

                    if any(series.empty for series in (incut, incut_matched, incut_clones, incut_fake)):
                        ax.text(0.5, 0.5, "Not enough data in bin", ha="center", va="center", transform=ax.transAxes)

                    bins = bins_by_param[param]
                    stacked_histogram_series_tuple = (
                        incut_matched[f'{param}_estimate'],
                        incut_clones[f'{param}_estimate'],
                        incut_fake[f'{param}_estimate'],
                    )
                    histvals, _, _ = ax.hist(stacked_histogram_series_tuple,
                                             bins=bins, range=(bins.min(), bins.max()),
                                             color=(blue, green, yellow),
                                             label=("matched", "clones", "fakes"))
                    ax.set_xlabel(f'{label_by_param[param]} estimate / ({unit_by_param[param]})')
                    ax.set_ylabel('# tracks')
                axarr[0].legend(loc="upper center", bbox_to_anchor=(0, -0.15))
                pdf.savefig(fig, bbox_inches="tight")
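
The rate plots above evaluate a boolean track label (``is_fake``, ``is_clone``) for all tracks passing a series of quality indicator requirements. The body of ``get_uncertain_means_for_qi_cuts`` is not part of this listing, so the following is only a minimal sketch of the idea in plain pandas, assuming a simple binomial uncertainty. The helper name ``rate_for_qi_cuts`` and the toy data frame are hypothetical and are not taken from the steering file.

```python
import numpy as np
import pandas as pd


def rate_for_qi_cuts(df, column, qi_cuts):
    """Mean of a boolean column for tracks with QI above each cut (hypothetical helper).

    The error is a plain binomial estimate sqrt(p * (1 - p) / n); this is an
    assumption for illustration, not necessarily what the steering file uses.
    """
    records = []
    for qi_cut in qi_cuts:
        # Keep only the tracks that pass the quality indicator requirement
        passed = df.loc[df["quality_indicator"] > qi_cut, column].astype(float)
        n_passed = len(passed)
        if n_passed == 0:
            # No track survives the cut, so the rate is undefined
            records.append((qi_cut, np.nan, np.nan))
            continue
        rate = passed.mean()
        error = np.sqrt(rate * (1.0 - rate) / n_passed)
        records.append((qi_cut, rate, error))
    return pd.DataFrame(records, columns=["qi_cut", "rate", "error"]).set_index("qi_cut")


# Toy pattern-recognition track table with made-up values
toy_df = pd.DataFrame({
    "quality_indicator": [0.1, 0.2, 0.6, 0.7, 0.95, 0.99],
    "is_clone": [True, True, False, True, False, False],
})
clone_rates = rate_for_qi_cuts(toy_df, "is_clone", qi_cuts=[0.0, 0.5, 0.9])
print(clone_rates)
```

In this toy table the clone rate falls as the requirement is tightened, which is the behaviour the validation plots are meant to make visible.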


class VXDQEValidationPlotsTask(PlotsFromHarvestingValidationBaseTask):
    """
    Create a PDF file with validation plots for the VXDTF2 track quality
    estimator produced from the ROOT ntuples produced by a VXDTF2 track QE
    harvesting validation task
    """

    def harvesting_validation_task_instance(self):
        """
        Harvesting validation task to require, which produces the ROOT files
        with variables to produce the VXD QE validation plots.
        """
        return VXDQEHarvestingValidationTask(
            n_events_testing=self.n_events_testing,
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            exclude_variables=self.exclude_variables,
            num_processes=MasterTask.num_processes,
            fast_bdt_option=self.fast_bdt_option,
        )


class CDCQEValidationPlotsTask(PlotsFromHarvestingValidationBaseTask):
    """
    Create a PDF file with validation plots for the CDC track quality estimator
    produced from the ROOT ntuples produced by a CDC track QE harvesting
    validation task
    """

    training_target = b2luigi.Parameter()

    def harvesting_validation_task_instance(self):
        """
        Harvesting validation task to require, which produces the ROOT files
        with variables to produce the CDC QE validation plots.
        """
        return CDCQEHarvestingValidationTask(
            n_events_testing=self.n_events_testing,
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            training_target=self.training_target,
            exclude_variables=self.exclude_variables,
            num_processes=MasterTask.num_processes,
            fast_bdt_option=self.fast_bdt_option,
        )


class RecoTrackQEValidationPlotsTask(PlotsFromHarvestingValidationBaseTask):
    """
    Create a PDF file with validation plots for the reco MVA track quality
    estimator produced from the ROOT ntuples produced by a reco track QE
    harvesting validation task
    """

    cdc_training_target = b2luigi.Parameter()

    def harvesting_validation_task_instance(self):
        """
        Harvesting validation task to require, which produces the ROOT files
        with variables to produce the final MVA track QE validation plots.
        """
        return RecoTrackQEHarvestingValidationTask(
            n_events_testing=self.n_events_testing,
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            cdc_training_target=self.cdc_training_target,
            exclude_variables=self.exclude_variables,
            num_processes=MasterTask.num_processes,
            fast_bdt_option=self.fast_bdt_option,
        )


class QEWeightsLocalDBCreatorTask(Basf2Task):
    """
    Collect weightfile identifiers from different teacher tasks and merge them
    into a local database for testing.
    """

    n_events_training = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    process_type = b2luigi.Parameter(
    cdc_training_target = b2luigi.Parameter()
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]
    )

        """
        Required teacher tasks
        """
        yield VXDQETeacherTask(
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            exclude_variables=MasterTask.exclude_variables_vxd,
            fast_bdt_option=self.fast_bdt_option,
        )
        yield CDCQETeacherTask(
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            training_target=self.cdc_training_target,
            exclude_variables=MasterTask.exclude_variables_cdc,
            fast_bdt_option=self.fast_bdt_option,
        )
        yield RecoTrackQETeacherTask(
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            cdc_training_target=self.cdc_training_target,
            exclude_variables=MasterTask.exclude_variables_rec,
            fast_bdt_option=self.fast_bdt_option,
        )

        yield self.add_to_output("localdb.tar")

        """
        Create local database
        """
        current_path = Path.cwd()
        localdb_archive_path = Path(self.get_output_file_name("localdb.tar")).absolute()
        output_dir = localdb_archive_path.parent

        for task in (VXDQETeacherTask, CDCQETeacherTask, RecoTrackQETeacherTask):
            weightfile_xml_identifier_path = os.path.abspath(self.get_input_file_names(
                task.get_weightfile_xml_identifier(task, fast_bdt_option=self.fast_bdt_option))[0])
            os.chdir(output_dir)
                weightfile_xml_identifier_path,
                task.weightfile_identifier_basename,
                self.experiment_number, 0,
                self.experiment_number, -1,
            os.chdir(current_path)

        shutil.make_archive(
            base_name=localdb_archive_path.as_posix().split('.')[0],
            root_dir=output_dir,

        """
        Remove local database and tar archives in output directory
        """
        localdb_archive_path = Path(self.get_output_file_name("localdb.tar"))
        localdb_path = localdb_archive_path.parent / "localdb"

        if localdb_path.exists():
            print(f"Deleting localdb\n{localdb_path}\nwith contents\n ",
                  "\n ".join(f.name for f in localdb_path.iterdir()))
            shutil.rmtree(localdb_path, ignore_errors=False)

        if localdb_archive_path.is_file():
            print(f"Deleting {localdb_archive_path}")
            os.remove(localdb_archive_path)

    def on_failure(self, exception):
        """
        Cleanup: Remove local database to prevent existing outputs when task did not finish successfully
        """
        super().on_failure(exception)
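
The local database directory above is packed into ``localdb.tar`` with ``shutil.make_archive``, which appends the format extension to ``base_name`` on its own. The following is a minimal, self-contained sketch of that packing step in a temporary directory. The ``format="tar"`` and ``base_dir="localdb"`` arguments and the placeholder file ``database.txt`` are assumptions made for illustration, since the corresponding lines are not visible in this listing. The sketch uses ``Path.with_suffix("")`` to strip the extension, which is an alternative to splitting the path at the first dot.

```python
import shutil
import tarfile
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    output_dir = Path(tmp)

    # Stand-in for the local database directory (placeholder content)
    localdb_path = output_dir / "localdb"
    localdb_path.mkdir()
    (localdb_path / "database.txt").write_text("placeholder payload list\n")

    localdb_archive_path = output_dir / "localdb.tar"

    # make_archive appends ".tar" itself, so pass the path without the suffix
    created_archive = shutil.make_archive(
        base_name=str(localdb_archive_path.with_suffix("")),
        format="tar",          # assumption: plain, uncompressed tar
        root_dir=output_dir,   # archive paths are relative to this directory
        base_dir="localdb",    # assumption: only the localdb directory is packed
    )

    # Read back the member names to check what ended up in the archive
    with tarfile.open(created_archive) as tar:
        member_names = sorted(tar.getnames())

print(Path(created_archive).name, member_names)
```

Stripping only the final suffix keeps the base name intact even if a directory name in the output path happens to contain a dot.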


class MasterTask(b2luigi.WrapperTask):
    """
    Wrapper task that needs to finish for b2luigi to finish running this steering file.

    It is done if the outputs of all required subtasks exist. It is thus at the
    top of the luigi task graph. Edit the ``requires`` method to steer which
    tasks and with which parameters you want to run.
    """

    process_type = b2luigi.get_setting(
        "process_type", default='BBBAR'
    )
    n_events_training = b2luigi.get_setting(
        "n_events_training", default=20000
    )
    n_events_testing = b2luigi.get_setting(
        "n_events_testing", default=5000
    )
    n_events_per_task = b2luigi.get_setting(
        "n_events_per_task", default=100
    )
    num_processes = b2luigi.get_setting(
        "basf2_processes_per_worker", default=0
    )
    datafiles = b2luigi.get_setting("datafiles")
    bkgfiles_by_exp = b2luigi.get_setting("bkgfiles_by_exp")
    bkgfiles_by_exp = {int(key): val for (key, val) in bkgfiles_by_exp.items()}

    exclude_variables_cdc = [
        "has_matching_segment",
        "cont_layer_variance",
        "cont_layer_max_vs_last",
        "cont_layer_first_vs_min",
        "cont_layer_occupancy",
        "super_layer_variance",
        "super_layer_max_vs_last",
        "super_layer_first_vs_min",
        "super_layer_occupancy",
        "drift_length_mean",
        "drift_length_variance",
        "norm_drift_length_mean",
        "norm_drift_length_variance",
        "norm_drift_length_max",
        "norm_drift_length_min",
        "norm_drift_length_sum",
    ]

    exclude_variables_vxd = [
        'energyLoss_max', 'energyLoss_min', 'energyLoss_mean', 'energyLoss_std', 'energyLoss_sum',
        'size_max', 'size_min', 'size_mean', 'size_std', 'size_sum',
        'seedCharge_max', 'seedCharge_min', 'seedCharge_mean', 'seedCharge_std', 'seedCharge_sum',
        'tripletFit_P_Mag', 'tripletFit_P_Eta', 'tripletFit_P_Phi', 'tripletFit_P_X', 'tripletFit_P_Y', 'tripletFit_P_Z']

    exclude_variables_rec = [
        'N_diff_PXD_SVD_RecoTracks',
        'N_diff_SVD_CDC_RecoTracks',
        'Fit_NFailedPoints',
        'N_TrackPoints_without_KalmanFitterInfo',
        'N_Hits_without_TrackPoint',
        'SVD_CDC_CDCwall_Chi2',
        'SVD_CDC_CDCwall_Pos_diff_Z',
        'SVD_CDC_CDCwall_Pos_diff_Pt',
        'SVD_CDC_CDCwall_Pos_diff_Theta',
        'SVD_CDC_CDCwall_Pos_diff_Phi',
        'SVD_CDC_CDCwall_Pos_diff_Mag',
        'SVD_CDC_CDCwall_Pos_diff_Eta',
        'SVD_CDC_CDCwall_Mom_diff_Z',
        'SVD_CDC_CDCwall_Mom_diff_Pt',
        'SVD_CDC_CDCwall_Mom_diff_Theta',
        'SVD_CDC_CDCwall_Mom_diff_Phi',
        'SVD_CDC_CDCwall_Mom_diff_Mag',
        'SVD_CDC_CDCwall_Mom_diff_Eta',
        'SVD_CDC_POCA_Pos_diff_Z',
        'SVD_CDC_POCA_Pos_diff_Pt',
        'SVD_CDC_POCA_Pos_diff_Theta',
        'SVD_CDC_POCA_Pos_diff_Phi',
        'SVD_CDC_POCA_Pos_diff_Mag',
        'SVD_CDC_POCA_Pos_diff_Eta',
        'SVD_CDC_POCA_Mom_diff_Z',
        'SVD_CDC_POCA_Mom_diff_Pt',
        'SVD_CDC_POCA_Mom_diff_Theta',
        'SVD_CDC_POCA_Mom_diff_Phi',
        'SVD_CDC_POCA_Mom_diff_Mag',
        'SVD_CDC_POCA_Mom_diff_Eta',
        'SVD_FitSuccessful',
        'CDC_FitSuccessful',
        'is_Vzero_Daughter',
        'weight_firstCDCHit',
        'weight_lastSVDHit',
        'smoothedChi2_mean',
        'smoothedChi2_median',
        'smoothedChi2_n_zeros',
        'smoothedChi2_firstCDCHit',
        'smoothedChi2_lastSVDHit']

        """
        Generate list of tasks that needs to be done for luigi to finish running
        """
        cdc_training_targets = [
        fast_bdt_options = []
        fast_bdt_options.append([350, 6, 5, 0.1])
        experiment_numbers = b2luigi.get_setting("experiment_numbers")

        for experiment_number, cdc_training_target, fast_bdt_option in itertools.product(
            experiment_numbers, cdc_training_targets, fast_bdt_options
        ):
            if b2luigi.get_setting("test_selected_task", default=False):
                for cut in ['000', '070', '090', '095']:
                    yield RecoTrackQEDataCollectionTask(
                        num_processes=self.num_processes,
                        n_events=self.n_events_testing,
                        experiment_number=experiment_number,
                        random_seed=self.process_type + '_test',
                        recotrack_option='useCDC_noVXD_deleteCDCQI' + cut,
                        cdc_training_target=cdc_training_target,
                        fast_bdt_option=fast_bdt_option,
                    )
                yield CDCQEDataCollectionTask(
                    num_processes=self.num_processes,
                    n_events=self.n_events_testing,
                    experiment_number=experiment_number,
                    random_seed=self.process_type + '_test',
                )
                yield CDCQETeacherTask(
                    n_events_training=self.n_events_training,
                    process_type=self.process_type,
                    experiment_number=experiment_number,
                    exclude_variables=self.exclude_variables_cdc,
                    training_target=cdc_training_target,
                    fast_bdt_option=fast_bdt_option,
                )
            if 'DATA' in self.process_type:
                yield VXDQEDataCollectionTask(
                    num_processes=self.num_processes,
                    n_events=self.n_events_testing,
                    experiment_number=experiment_number,
                    random_seed=self.process_type + '_test',
                )
                yield CDCQEDataCollectionTask(
                    num_processes=self.num_processes,
                    n_events=self.n_events_testing,
                    experiment_number=experiment_number,
                    random_seed=self.process_type + '_test',
                )
                yield RecoTrackQEDataCollectionTask(
                    num_processes=self.num_processes,
                    n_events=self.n_events_testing,
                    experiment_number=experiment_number,
                    random_seed=self.process_type + '_test',
                    recotrack_option='deleteCDCQI080',
                    cdc_training_target=cdc_training_target,
                    fast_bdt_option=fast_bdt_option,
                )
            yield QEWeightsLocalDBCreatorTask(
                n_events_training=self.n_events_training,
                process_type=self.process_type,
                experiment_number=experiment_number,
                cdc_training_target=cdc_training_target,
                fast_bdt_option=fast_bdt_option,
            )
            if b2luigi.get_setting("run_validation_tasks", default=True):
                yield RecoTrackQEValidationPlotsTask(
                    n_events_training=self.n_events_training,
                    n_events_testing=self.n_events_testing,
                    process_type=self.process_type,
                    experiment_number=experiment_number,
                    cdc_training_target=cdc_training_target,
                    exclude_variables=self.exclude_variables_rec,
                    fast_bdt_option=fast_bdt_option,
                )
                yield CDCQEValidationPlotsTask(
                    n_events_training=self.n_events_training,
                    n_events_testing=self.n_events_testing,
                    process_type=self.process_type,
                    experiment_number=experiment_number,
                    exclude_variables=self.exclude_variables_cdc,
                    training_target=cdc_training_target,
                    fast_bdt_option=fast_bdt_option,
                )
                yield VXDQEValidationPlotsTask(
                    n_events_training=self.n_events_training,
                    n_events_testing=self.n_events_testing,
                    process_type=self.process_type,
                    exclude_variables=self.exclude_variables_vxd,
                    experiment_number=experiment_number,
                    fast_bdt_option=fast_bdt_option,
                )
            if b2luigi.get_setting("run_mva_evaluate", default=True):
                yield RecoTrackQEEvaluationTask(
                    n_events_training=self.n_events_training,
                    n_events_testing=self.n_events_testing,
                    process_type=self.process_type,
                    experiment_number=experiment_number,
                    cdc_training_target=cdc_training_target,
                    exclude_variables=self.exclude_variables_rec,
                    fast_bdt_option=fast_bdt_option,
                )
                yield CDCTrackQEEvaluationTask(
                    n_events_training=self.n_events_training,
                    n_events_testing=self.n_events_testing,
                    process_type=self.process_type,
                    experiment_number=experiment_number,
                    exclude_variables=self.exclude_variables_cdc,
                    fast_bdt_option=fast_bdt_option,
                    training_target=cdc_training_target,
                )
                yield VXDTrackQEEvaluationTask(
                    n_events_training=self.n_events_training,
                    n_events_testing=self.n_events_testing,
                    process_type=self.process_type,
                    experiment_number=experiment_number,
                    exclude_variables=self.exclude_variables_vxd,
                    fast_bdt_option=fast_bdt_option,
                )
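
The task list of ``MasterTask`` is built by looping over the Cartesian product of experiment numbers, CDC training targets and FastBDT hyperparameter sets, so every added entry in one of these lists multiplies the number of task chains. The following minimal sketch shows the expansion on its own. The experiment numbers, the target name ``"truth"`` and the second hyperparameter set are hypothetical placeholders. In the steering file the experiment numbers come from the b2luigi setting ``experiment_numbers``.

```python
import itertools

# Hypothetical values for illustration only
experiment_numbers = [0, 1003]
cdc_training_targets = ["truth"]
fast_bdt_options = [[350, 6, 5, 0.1], [200, 8, 3, 0.1]]

# One tuple per (experiment, CDC target, hyperparameter set) combination;
# the last argument of product() varies fastest
combinations = list(itertools.product(
    experiment_numbers, cdc_training_targets, fast_bdt_options
))

for experiment_number, cdc_training_target, fast_bdt_option in combinations:
    print(experiment_number, cdc_training_target, fast_bdt_option)
```

With two experiments, one target and two hyperparameter sets this yields four combinations, each of which would get its own set of training, validation and evaluation tasks.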


if __name__ == "__main__":
    nEventsTestOnData = b2luigi.get_setting("n_events_test_on_data", default=-1)
    if nEventsTestOnData > 0 and 'DATA' in b2luigi.get_setting("process_type", default="BBBAR"):
        from ROOT import Belle2
        environment.setNumberEventsOverride(nEventsTestOnData)

    globaltags = b2luigi.get_setting("globaltags", default=[])
    if len(globaltags) > 0:
        basf2.conditions.reset()
        for gt in globaltags:
            basf2.conditions.prepend_globaltag(gt)
    workers = b2luigi.get_setting("workers", default=1)
    b2luigi.process(MasterTask(), workers=workers)