combined_module_quality_estimator_teacher
-----------------------------------------

Information on the MVA Track Quality Indicator / Estimator can be found
<https://xwiki.desy.de/xwiki/rest/p/0d3f4>`_.

This python script is used for the combined training and validation of three
classifiers: the final MVA track quality estimator and the two quality
estimators for the intermediate standalone track finders on which it depends.

- Final MVA track quality estimator:
  The final quality estimator for fully merged and fitted tracks (RecoTracks).
  Its classifier uses features from the track fit, the track merger, the hit
  pattern, etc. In addition, it uses the outputs of the intermediate quality
  estimators for the VXD and the CDC track finding as inputs. It provides
  the final quality indicator (QI) that is exported to the track objects.

- VXDTF2 track quality estimator:
  MVA quality estimator for the VXD standalone track finding.

- CDC track quality estimator:
  MVA quality estimator for the CDC standalone track finding.
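
The dependency between the three classifiers fixes a training order: both
intermediate estimators have to be trained before the final one. A minimal
sketch of that ordering with the standard library ``graphlib`` module (Python
3.9 or newer is assumed); the labels ``vxd_qe``, ``cdc_qe`` and ``final_qe``
are hypothetical and not task names used in this steering file.

```python
from graphlib import TopologicalSorter

# Hypothetical labels: each estimator maps to the estimators whose trained
# weights it needs as input features.
ESTIMATOR_DEPENDENCIES = {
    "vxd_qe": set(),
    "cdc_qe": set(),
    "final_qe": {"vxd_qe", "cdc_qe"},
}


def training_order(dependencies):
    # static_order() yields every node only after all of its predecessors.
    return list(TopologicalSorter(dependencies).static_order())


print(training_order(ESTIMATOR_DEPENDENCIES))
```

The relative order of the two intermediate estimators is arbitrary, since they
do not depend on each other.
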

Each classifier requires its own training data set and needs to be validated
on a separate testing data set. Furthermore, the final quality estimator can
only be trained when the trained weights of the intermediate quality
estimators are available. If the final estimator is to be trained without one
or both of the intermediate estimators, the corresponding requirements have to
be commented out in the ``__init__.py`` file of tracking.
For all estimators, a list of variables to be ignored is specified in the
``MasterTask``. The current choice is mainly based on poor data/MC agreement in
these quantities or on outdated implementations. It was decided to leave them
in here in this hardcoded "ugly" way to remind future generations that they
exist in principle and that they should and could be added to the estimator
once their modelling improves or an alternative implementation is programmed.
To avoid mistakes, b2luigi is used to create a task chain for a combined
training and validation of all classifiers.
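
Excluding variables boils down to filtering the list of collected column names
before training, similar to what ``my_basf2_mva_teacher`` further below does. A
simplified sketch that ignores the special handling of weight columns; the
feature names in the usage line are made-up examples, not the actual exclusion
lists of the ``MasterTask``.

```python
def select_training_variables(feature_names, target_variable="truth", exclude_variables=None):
    # Drop every column that carries truth information, the target itself,
    # and everything on the explicit exclusion list.
    if exclude_variables is None:
        exclude_variables = []
    return [
        name for name in feature_names
        if "truth" not in name
        and name != target_variable
        and name not in exclude_variables
    ]


features = ["truth", "truth_pt", "pt", "n_hits", "chi2"]
print(select_training_variables(features, exclude_variables=["chi2"]))
# prints ['pt', 'n_hits']
```
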

b2luigi: Understanding the steering file
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

All trainings and validations are done in the correct order in this steering
file. For the purpose of creating a dependency graph, the `b2luigi
<https://b2luigi.readthedocs.io>`_ python package is used, which extends the
`luigi <https://luigi.readthedocs.io>`_ package developed by Spotify.

Each task that has to be done is represented by a special class, which defines
its parameters, its output files and which other tasks, with which parameters,
it depends on. For example, a teacher task, which runs
``basf2_mva_teacher.py`` to train the classifier, depends on a data collection
task which runs a reconstruction and writes out track-wise variables into a ROOT
file for training. An evaluation/validation task for testing the classifier
requires both the teacher task, as it needs the weightfile to be present, and
a data collection task, because it needs a dataset for testing the classifier.

The final task that defines which tasks need to be done for the steering file to
finish is the ``MasterTask``. When you only want to run parts of the
training/validation pipeline, you can comment out requirements in the
``MasterTask`` or replace them by lower-level tasks during debugging.
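
The mechanism can be illustrated without b2luigi by a toy dependency resolver
in plain Python. This is a hypothetical sketch of the idea only (``requires``,
``output`` and a task counting as complete once all its outputs exist), not
the b2luigi or luigi API; the ``Toy*`` classes merely mimic the data
collection, teacher, evaluation and master tasks described above, and outputs
are modelled as names in a set instead of files on disk.

```python
class ToyTask:
    # Hypothetical base class: a task is complete once all its outputs exist.
    def requires(self):
        return []

    def output(self):
        return []

    def complete(self, produced):
        return all(name in produced for name in self.output())

    def run(self, produced):
        produced.update(self.output())


class ToyDataCollection(ToyTask):
    def __init__(self, sample):
        self.sample = sample

    def output(self):
        return [f"records_{self.sample}.root"]


class ToyTeacher(ToyTask):
    def requires(self):
        return [ToyDataCollection("train")]

    def output(self):
        return ["weightfile.xml"]


class ToyEvaluation(ToyTask):
    # Needs the trained weights and an independent test sample.
    def requires(self):
        return [ToyTeacher(), ToyDataCollection("test")]

    def output(self):
        return ["validation.pdf"]


class ToyMaster(ToyTask):
    # Only collects requirements; drop one to run a partial pipeline.
    def requires(self):
        return [ToyEvaluation()]


def build(task, produced, log):
    # Depth-first: satisfy all requirements before the task itself runs,
    # and skip tasks whose outputs already exist.
    for dependency in task.requires():
        build(dependency, produced, log)
    if not task.complete(produced):
        task.run(produced)
        log.append(type(task).__name__)


produced, log = set(), []
build(ToyMaster(), produced, log)
print(log)
# prints ['ToyDataCollection', 'ToyTeacher', 'ToyDataCollection', 'ToyEvaluation']
```

Running ``build`` a second time does nothing, because every task is already
complete; this is the same reason a finished b2luigi pipeline is not rerun.
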

This steering file relies on b2luigi_ for task scheduling and `uncertain_panda
<https://github.com/nils-braun/uncertain_panda>`_ for uncertainty calculations.
uncertain_panda is not in the externals, and b2luigi is not included in the
externals up to v01-07-01. Both can be installed via pip::

    python3 -m pip install [--user] b2luigi uncertain_panda

Use the ``--user`` option if you do not have the rights to install python
packages into your externals (e.g. because you are using cvmfs) and install
them in ``$HOME/.local`` instead.
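
Whether both packages are already importable can be checked before starting
the pipeline. A minimal sketch using only the standard library;
``missing_modules`` is a hypothetical helper, not part of the steering file.

```python
import importlib.util


def missing_modules(names):
    # find_spec() returns None if a top-level module cannot be found.
    return [name for name in names if importlib.util.find_spec(name) is None]


for module in missing_modules(["b2luigi", "uncertain_panda"]):
    print(f"Could not find {module}. Try: python3 -m pip install [--user] {module}")
```
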

Instead of command line arguments, the b2luigi script is configured via a
``settings.json`` file. Open it in your favorite text editor and modify it to
fit your requirements.
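
Since ``settings.json`` is plain JSON, reading it needs nothing beyond the
standard library. The sketch below only illustrates overlaying the values from
the file on defaults; the keys ``result_path`` and ``n_events_training`` and
their default values are hypothetical examples, the settings that are actually
used are the ones queried inside this script.

```python
import json
from pathlib import Path

# Hypothetical defaults, overridden by whatever the file provides.
DEFAULT_SETTINGS = {"result_path": "results", "n_events_training": 1000}


def load_settings(path="settings.json"):
    settings = dict(DEFAULT_SETTINGS)  # copy, so the defaults stay untouched
    settings_file = Path(path)
    if settings_file.exists():
        settings.update(json.loads(settings_file.read_text()))
    return settings
```

If the file does not exist, ``load_settings`` simply returns a copy of the
defaults.
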

You can test the b2luigi steering file without running it via::

    python3 combined_quality_estimator_teacher.py --dry-run
    python3 combined_quality_estimator_teacher.py --show-output

This will show the outputs and potential errors in the definitions of the
luigi task dependencies. To run the steering file in normal (local) mode,
run::

    python3 combined_quality_estimator_teacher.py

I usually use the interactive luigi web interface via the central scheduler,
which visualizes the task graph while it is running. Therefore, the scheduler
daemon ``luigid`` has to run in the background, which is located in
``~/.local/bin/luigid`` in case b2luigi had been installed with ``--user``. For
example, run::

    luigid --port 8886

Then, execute your steering (e.g. in another terminal) with::

    python3 combined_quality_estimator_teacher.py --scheduler-port 8886

To view the web interface, open your web browser and enter into the URL bar::

    localhost:8886

If you don't run the steering file on the same machine on which you run your web
browser, you have two options:

1. Run both the steering file and ``luigid`` remotely and use
   ssh port forwarding to your local host. To do so, run on your local
   machine::

       ssh -N -f -L 8886:localhost:8886 <remote_user>@<remote_host>

2. Run the ``luigid`` scheduler locally and use the
   ``--scheduler-host <your local host>`` argument when calling the steering
   file.
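
For option 1, the forwarding command can also be assembled in a small wrapper
script. A sketch with a hypothetical helper that only builds the argument list
(e.g. for ``subprocess.run``) and does not execute it; user and host names are
placeholders.

```python
import shlex


def port_forward_command(remote_user, remote_host, port=8886):
    # -N: run no remote command, -f: go to the background after
    # authentication, -L: forward the local port to the same port on
    # localhost as seen from the remote machine.
    return [
        "ssh", "-N", "-f",
        "-L", f"{port}:localhost:{port}",
        f"{remote_user}@{remote_host}",
    ]


print(shlex.join(port_forward_command("remote_user", "remote_host")))
# prints ssh -N -f -L 8886:localhost:8886 remote_user@remote_host
```
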

Accessing the results / output files
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

All output files are stored in a directory structure in the ``result_path``. The
directory tree encodes the used b2luigi parameters. This ensures reproducibility
and makes parameter searches easy. Sometimes, it is hard to find the relevant
output files. You can view the whole directory structure by running ``tree
<result_path>``. Use the unix ``find`` command to find the files that interest
you, e.g.::

    find <result_path> -name "*.pdf"  # find all validation plot files
    find <result_path> -name "*.root"  # find all ROOT files
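
The same search works from Python with ``pathlib``, which is convenient when
post-processing the outputs. The sketch also turns the directories between
``result_path`` and a file back into parameters, assuming that each parameter
is encoded as a directory named ``<parameter>=<value>``. This layout is an
assumption here; adapt it to what ``tree <result_path>`` actually shows.

```python
from pathlib import Path


def find_outputs(result_path, pattern):
    # Recursive equivalent of: find <result_path> -name "<pattern>"
    return sorted(Path(result_path).rglob(pattern))


def parameters_from_path(result_path, output_file):
    # Interpret every "name=value" directory between result_path and the
    # file as a task parameter (assumed layout); other directories are skipped.
    directories = Path(output_file).relative_to(result_path).parts[:-1]
    return dict(part.split("=", 1) for part in directories if "=" in part)


print(parameters_from_path("results", "results/n_events=1000/random_seed=train/records.root"))
# prints {'n_events': '1000', 'random_seed': 'train'}
```
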
from pathlib import Path
from datetime import datetime
from typing import Iterable
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from packaging import version
from tracking_mva_filter_payloads.write_tracking_mva_filter_payloads_to_db import write_tracking_mva_filter_payloads_to_db
from tracking_mva_filter_payloads.write_tracking_mva_filter_payloads_to_db import write_mva_weightfile_content_to_db

install_helpstring_formatter = ("\nCould not find {module} python module. Try installing it via\n"
                                "  python3 -m pip install [--user] {module}\n")
try:
    from b2luigi.core.utils import get_serialized_parameters, get_log_file_dir, create_output_dirs
    from b2luigi.basf2_helper import Basf2PathTask, Basf2Task
    from b2luigi.core.task import Task, ExternalTask
    from b2luigi.basf2_helper.utils import get_basf2_git_hash
except ModuleNotFoundError:
    print(install_helpstring_formatter.format(module="b2luigi"))

try:
    from uncertain_panda import pandas as upd
except ModuleNotFoundError:
    print(install_helpstring_formatter.format(module="uncertain_panda"))

if (
    version.parse(b2luigi.__version__) <= version.parse("0.3.2") and
    get_basf2_git_hash() is None and
    os.getenv("BELLE2_LOCAL_DIR") is not None
):
    print(f"b2luigi could not obtain the git hash because of a bug not yet fixed in version {b2luigi.__version__}\n"
          "Please install the latest version of b2luigi from github via\n\n"
          "  python3 -m pip install --upgrade [--user] git+https://github.com/nils-braun/b2luigi.git\n")


def create_fbdt_option_string(fast_bdt_option):
    """
    Return a readable string created from the fast_bdt_option array
    [nTrees, nCuts, nLevels, shrinkage].
    """
    return "_nTrees" + str(fast_bdt_option[0]) + "_nCuts" + str(fast_bdt_option[1]) + "_nLevels" + \
        str(fast_bdt_option[2]) + "_shrin" + str(int(round(100*fast_bdt_option[3], 0)))


def createV0momenta(x, mu, beta):
    """
    Copied from Bianca's K_S0 particle gun code: returns a realistic V0 momentum
    distribution (the probability density of a Gumbel distribution) when evaluated
    at x. mu and beta are parameters of the function that define its center and tails.
    Used for the particle gun simulation code for K_S0 and Lambda_0.
    """
    return (1/beta)*np.exp(-(x - mu)/beta) * np.exp(-np.exp(-(x - mu) / beta))


def my_basf2_mva_teacher(
    records_files,
    tree_name,
    weightfile_identifier,
    target_variable="truth",
    exclude_variables=None,
    fast_bdt_option=[200, 8, 3, 0.1]
):
    """
    My custom wrapper for basf2 mva teacher. Adapted from code in ``trackfindingcdc_teacher``.

    :param records_files: List of files with collected ("recorded") variables to use as training data for the MVA.
    :param tree_name: Name of the TTree in the ROOT file from the ``data_collection_task``
        that contains the training data for the MVA teacher.
    :param weightfile_identifier: Name of the weightfile that is created.
        Should either end in ".xml" for local weightfiles or in ".root", when
        the weightfile later needs to be uploaded as a payload to the conditions
        database.
    :param target_variable: Feature/variable to use as truth label in the quality estimator MVA classifier.
    :param exclude_variables: List of collected variables to not use in the training of the QE MVA classifier,
        in addition to variables containing the "truth" substring, which are excluded by default.
    :param fast_bdt_option: specified fast BDT options, default: [200, 8, 3, 0.1] [nTrees, nCuts, nLevels, shrinkage]
    """
    if exclude_variables is None:
        exclude_variables = []

    weightfile_extension = Path(weightfile_identifier).suffix
    if weightfile_extension not in {".xml", ".root"}:
        raise ValueError(f"Weightfile Identifier should end in .xml or .root, but ends in {weightfile_extension}")

    with root_utils.root_open(records_files[0]) as records_tfile:
        input_tree = records_tfile.Get(tree_name)
        feature_names = [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]

    truth_free_variable_names = [
        name
        for name in feature_names
        if (
            ("truth" not in name) and
            (name != target_variable) and
            (name not in exclude_variables)
        )
    ]
    if "weight" in truth_free_variable_names:
        truth_free_variable_names.remove("weight")
        weight_variable = "weight"
    elif "__weight__" in truth_free_variable_names:
        truth_free_variable_names.remove("__weight__")
        weight_variable = "__weight__"

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(*records_files)
    general_options.m_treename = tree_name
    general_options.m_weight_variable = weight_variable
    general_options.m_identifier = weightfile_identifier
    general_options.m_variables = basf2_mva.vector(*truth_free_variable_names)
    general_options.m_target_variable = target_variable
    fastbdt_options = basf2_mva.FastBDTOptions()

    fastbdt_options.m_nTrees = fast_bdt_option[0]
    fastbdt_options.m_nCuts = fast_bdt_option[1]
    fastbdt_options.m_nLevels = fast_bdt_option[2]
    fastbdt_options.m_shrinkage = fast_bdt_option[3]

    basf2_mva.teacher(general_options, fastbdt_options)


def _my_uncertain_mean(series: upd.Series):
    """
    Temporary workaround for a bug in ``uncertain_panda`` where a ``ValueError`` is
    thrown for ``Series.unc.mean`` if the series is empty. Can be replaced by
    .unc.mean when the issue is fixed.
    https://github.com/nils-braun/uncertain_panda/issues/2
    """
    try:
        return series.unc.mean()
    except ValueError:
        # only swallow the error for the known empty-series case
        if series.empty:
            return np.nan
        raise


def get_uncertain_means_for_qi_cuts(df: upd.DataFrame, column: str, qi_cuts: Iterable[float]):
    """
    Return a pandas series with the mean of the dataframe column and its
    uncertainty for each quality indicator cut.

    :param df: Pandas dataframe with at least ``quality_indicator``
        and another numeric ``column``.
    :param column: Column of which we want to aggregate the means
        and uncertainties for different QI cuts
    :param qi_cuts: Iterable of quality indicator minimal thresholds.
    :returns: Series of means and uncertainties with ``qi_cuts`` as index
    """

    uncertain_means = (_my_uncertain_mean(df.query(f"quality_indicator > {qi_cut}")[column])
                       for qi_cut in qi_cuts)
    uncertain_means_series = upd.Series(data=uncertain_means, index=qi_cuts)
    return uncertain_means_series


def plot_with_errobands(uncertain_series,
                        error_band_alpha=0.3,
                        fill_between_kwargs={},
    """
    Plot an uncertain series with error bands for y-errors.
    """
    uncertain_series = uncertain_series.dropna()
    ax.plot(uncertain_series.index.values, uncertain_series.nominal_value, **plot_kwargs)
    ax.fill_between(x=uncertain_series.index,
                    y1=uncertain_series.nominal_value - uncertain_series.std_dev,
                    y2=uncertain_series.nominal_value + uncertain_series.std_dev,
                    alpha=error_band_alpha,
                    **fill_between_kwargs)


def format_dictionary(adict, width=80, bullet="•"):
    """
    Helper function to format a dictionary to a string as a wrapped key-value
    bullet list. Useful to print metadata from dictionaries.

    :param adict: Dictionary to format
    :param width: Characters after which to wrap a key-value line
    :param bullet: Character to begin a key-value line with, e.g. ``-`` for a
    """
    return "\n".join(textwrap.fill(f"{bullet} {key}: {value}", width=width)
                     for (key, value) in adict.items())


class GenerateSimTask(Basf2PathTask):
    """
    Generate simulated Monte Carlo with background overlay.

    Make sure to use different ``random_seed`` parameters for the training data
    of the classifier trainings and for the test data for the respective
    evaluation/validation tasks.
    """

    n_events = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    random_seed = b2luigi.Parameter()
    bkgfiles_dir = b2luigi.Parameter(

    def output_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        return "generated_mc_N" + str(n_events) + "_" + random_seed + ".root"

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.output_file_name())

    def create_path(self):
        """
        Create basf2 path to process with event generation and simulation.
        """
        basf2.set_random_seed(self.random_seed)
        path = basf2.create_path()
        if self.experiment_number in [0, 1002, 1003]:
                f"Simulating events with experiment number {self.experiment_number} is not implemented yet.")
            "EventInfoSetter", evtNumList=[self.n_events], runList=[runNo], expList=[self.experiment_number]
        # "V0BBBAR" has to be checked before "BBBAR", which it contains as a substring
        if "V0BBBAR" in self.random_seed:
            path.add_module("EvtGenInput")
            path.add_module("InclusiveParticleChecker", particles=[310, 3122], includeConjugates=True)
        elif "BBBAR" in self.random_seed:
            path.add_module("EvtGenInput")
        else:
            import generators as ge
            if "V0STUDY" in self.random_seed:
                if "V0STUDYKS" in self.random_seed:
                if "V0STUDYL0" in self.random_seed:
                    pdgs = [310, 3122, -3122]
                myx = [i*0.01 for i in range(321)]
                    y = createV0momenta(x, mu, beta)
                polParams = myx + myy
                particlegun = basf2.register_module('ParticleGun')
                particlegun.param('pdgCodes', pdg_list)
                particlegun.param('nTracks', 8)
                particlegun.param('momentumGeneration', 'polyline')
                particlegun.param('momentumParams', polParams)
                particlegun.param('thetaGeneration', 'uniformCos')
                particlegun.param('thetaParams', [17, 150])
                particlegun.param('phiGeneration', 'uniform')
                particlegun.param('phiParams', [0, 360])
                particlegun.param('vertexGeneration', 'fixed')
                particlegun.param('xVertexParams', [0])
                particlegun.param('yVertexParams', [0])
                particlegun.param('zVertexParams', [0])
                path.add_module(particlegun)
            if "BHABHA" in self.random_seed:
                ge.add_babayaganlo_generator(path=path, finalstate='ee', minenergy=0.15, minangle=10.0)
            # "EEMUMU" has to be checked before "MUMU", which it contains as a substring
            elif "EEMUMU" in self.random_seed:
                ge.add_aafh_generator(path=path, finalstate='e+e-mu+mu-', preselection=False)
            elif "MUMU" in self.random_seed:
                ge.add_kkmc_generator(path=path, finalstate='mu+mu-')
            elif "YY" in self.random_seed:
                babayaganlo = basf2.register_module('BabayagaNLOInput')
                babayaganlo.param('FinalState', 'gg')
                babayaganlo.param('MaxAcollinearity', 180.0)
                babayaganlo.param('ScatteringAngleRange', [0., 180.])
                babayaganlo.param('FMax', 75000)
                babayaganlo.param('MinEnergy', 0.01)
                babayaganlo.param('Order', 'exp')
                babayaganlo.param('DebugEnergySpread', 0.01)
                babayaganlo.param('Epsilon', 0.00005)
                path.add_module(babayaganlo)
                generatorpreselection = basf2.register_module('GeneratorPreselection')
                generatorpreselection.param('nChargedMin', 0)
                generatorpreselection.param('nChargedMax', 999)
                generatorpreselection.param('MinChargedPt', 0.15)
                generatorpreselection.param('MinChargedTheta', 17.)
                generatorpreselection.param('MaxChargedTheta', 150.)
                generatorpreselection.param('nPhotonMin', 1)
                generatorpreselection.param('MinPhotonEnergy', 1.5)
                generatorpreselection.param('MinPhotonTheta', 15.0)
                generatorpreselection.param('MaxPhotonTheta', 165.0)
                generatorpreselection.param('applyInCMS', True)
                path.add_module(generatorpreselection)
                empty = basf2.create_path()
                generatorpreselection.if_value('!=11', empty)
            elif "EEEE" in self.random_seed:
                ge.add_aafh_generator(path=path, finalstate='e+e-e+e-', preselection=False)
            elif "TAUPAIR" in self.random_seed:
                ge.add_kkmc_generator(path, finalstate='tau+tau-')
            elif "DDBAR" in self.random_seed:
                ge.add_continuum_generator(path, finalstate='ddbar')
            elif "UUBAR" in self.random_seed:
                ge.add_continuum_generator(path, finalstate='uubar')
            elif "SSBAR" in self.random_seed:
                ge.add_continuum_generator(path, finalstate='ssbar')
            elif "CCBAR" in self.random_seed:
                ge.add_continuum_generator(path, finalstate='ccbar')
        if self.experiment_number == 1002:
            components = ['PXD', 'SVD', 'CDC', 'ECL', 'TOP', 'ARICH', 'TRG']
            outputFileName=self.get_output_file_name(self.output_file_name()),


class SplitNMergeSimTask(Basf2Task):
    """
    Generate simulated Monte Carlo with background overlay.

    Make sure to use different ``random_seed`` parameters for the training data
    of the classifier trainings and for the test data for the respective
    evaluation/validation tasks.
    """

    n_events = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    random_seed = b2luigi.Parameter()
    bkgfiles_dir = b2luigi.Parameter(

    def output_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        return "generated_mc_N" + str(n_events) + "_" + random_seed + ".root"

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.output_file_name())

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        n_events_per_task = MasterTask.n_events_per_task
        quotient, remainder = divmod(self.n_events, n_events_per_task)
        for i in range(quotient):
            yield GenerateSimTask(
                bkgfiles_dir=self.bkgfiles_dir,
                num_processes=MasterTask.num_processes,
                random_seed=self.random_seed + '_' + str(i).zfill(3),
                n_events=n_events_per_task,
                experiment_number=self.experiment_number,
        yield GenerateSimTask(
            bkgfiles_dir=self.bkgfiles_dir,
            num_processes=MasterTask.num_processes,
            random_seed=self.random_seed + '_' + str(quotient).zfill(3),
            experiment_number=self.experiment_number,

    @b2luigi.on_temporary_files
    def run(self):
        """
        When all GenerateSimTasks have finished, merge the output.
        """
        create_output_dirs(self)

        file_list = list(self.get_all_input_file_names())
        print("Merge the following files:")
        cmd = ["b2file-merge", "-f"]
        args = cmd + [self.get_output_file_name(self.output_file_name())] + file_list
        subprocess.check_call(args)
        print("Finished merging. Now remove the input files to save space.")
        for tempfile in file_list:
            args = cmd2 + [tempfile]
            subprocess.check_call(args)


class CheckExistingFile(ExternalTask):
    """
    Task to check if the given file really exists.
    """

    filename = b2luigi.Parameter()

    def output(self):
        """
        Specify the output to be the file that was just checked.
        """
        from luigi import LocalTarget
        return LocalTarget(self.filename)


class VXDQEDataCollectionTask(Basf2PathTask):
    """
    Collect variables/features from VXDTF2 tracking and write them to a ROOT
    file.

    These variables are to be used as labelled training data for the MVA
    classifier which is the VXD track quality estimator
    """

    n_events = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    random_seed = b2luigi.Parameter()

    def get_records_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if 'vxd' not in random_seed:
            random_seed += '_vxd'
        if 'DATA' in random_seed:
            return 'qe_records_DATA_vxd.root'

        if 'USESIMBB' in random_seed:
            random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
        elif 'USESIMEE' in random_seed:
            random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
        return 'qe_records_N' + str(n_events) + '_' + random_seed + '.root'

    def get_input_files(self, n_events=None, random_seed=None):
        """
        Get input file names depending on the use case: If they already exist, search in
        the corresponding folders, for data check the specified list and if they are created
        in the same run, check for the task that produced them.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if "USESIM" in random_seed:
            if 'USESIMBB' in random_seed:
                random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
            elif 'USESIMEE' in random_seed:
                random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
            return ['datafiles/' + GenerateSimTask.output_file_name(GenerateSimTask,
                                                                   n_events=n_events, random_seed=random_seed)]
        elif "DATA" in random_seed:
            return MasterTask.datafiles

        return self.get_input_file_names(GenerateSimTask.output_file_name(
            GenerateSimTask, n_events=n_events, random_seed=random_seed))

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        if "USESIM" in self.random_seed or "DATA" in self.random_seed:
            for filename in self.get_input_files():
                yield CheckExistingFile(
        else:
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.random_seed,
                n_events=self.n_events,
                experiment_number=self.experiment_number,

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.get_records_file_name())

    def create_path(self):
        """
        Create basf2 path with VXDTF2 tracking and VXD QE data collection.
        """
        path = basf2.create_path()
        inputFileNames = self.get_input_files()
            inputFileNames=inputFileNames,
        path.add_module("Gearbox")
        tracking.add_geometry_modules(path)
        if 'DATA' in self.random_seed:
            from rawdata import add_unpackers
            add_unpackers(path, components=['SVD', 'PXD'])
        tracking.add_hit_preparation_modules(path)
        tracking.add_vxd_track_finding_vxdtf2(
            path, components=["SVD"], add_mva_quality_indicator=False
        if 'DATA' in self.random_seed:
                "VXDQETrainingDataCollector",
                TrainingDataOutputName=self.get_output_file_name(self.get_records_file_name()),
                SpacePointTrackCandsStoreArrayName="SPTrackCands",
                EstimationMethod="tripletFit",
                ClusterInformation="Average",
                MCStrictQualityEstimator=False,
                "TrackFinderMCTruthRecoTracks",
                RecoTracksStoreArrayName="MCRecoTracks",
                "VXDQETrainingDataCollector",
                TrainingDataOutputName=self.get_output_file_name(self.get_records_file_name()),
                SpacePointTrackCandsStoreArrayName="SPTrackCands",
                EstimationMethod="tripletFit",
                ClusterInformation="Average",
                MCStrictQualityEstimator=True,


class CDCQEDataCollectionTask(Basf2PathTask):
    """
    Collect variables/features from CDC tracking and write them to a ROOT file.

    These variables are to be used as labelled training data for the MVA
    classifier which is the CDC track quality estimator
    """

    n_events = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    random_seed = b2luigi.Parameter()

    def get_records_file_name(self, n_events=None, random_seed=None):
        """
        Create output file name depending on number of events and production
        mode that is specified in the random_seed string.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if 'cdc' not in random_seed:
            random_seed += '_cdc'
        if 'DATA' in random_seed:
            return 'qe_records_DATA_cdc.root'

        if 'USESIMBB' in random_seed:
            random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
        elif 'USESIMEE' in random_seed:
            random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
        return 'qe_records_N' + str(n_events) + '_' + random_seed + '.root'

    def get_input_files(self, n_events=None, random_seed=None):
        """
        Get input file names depending on the use case: If they already exist, search in
        the corresponding folders, for data check the specified list and if they are created
        in the same run, check for the task that produced them.
        """
        if n_events is None:
            n_events = self.n_events
        if random_seed is None:
            random_seed = self.random_seed
        if "USESIM" in random_seed:
            if 'USESIMBB' in random_seed:
                random_seed = 'BBBAR_' + random_seed.split("_", 1)[1]
            elif 'USESIMEE' in random_seed:
                random_seed = 'BHABHA_' + random_seed.split("_", 1)[1]
            return ['datafiles/' + GenerateSimTask.output_file_name(GenerateSimTask,
                                                                   n_events=n_events, random_seed=random_seed)]
        elif "DATA" in random_seed:
            return MasterTask.datafiles

        return self.get_input_file_names(GenerateSimTask.output_file_name(
            GenerateSimTask, n_events=n_events, random_seed=random_seed))

    def requires(self):
        """
        Generate list of luigi Tasks that this Task depends on.
        """
        if "USESIM" in self.random_seed or "DATA" in self.random_seed:
            for filename in self.get_input_files():
                yield CheckExistingFile(
        else:
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.random_seed,
                n_events=self.n_events,
                experiment_number=self.experiment_number,

    def output(self):
        """
        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        """
        yield self.add_to_output(self.get_records_file_name())

    def create_path(self):
        """
        Create basf2 path with CDC standalone tracking and CDC QE with recording filter for MVA feature collection.
        """
        path = basf2.create_path()
        inputFileNames = self.get_input_files()
            inputFileNames=inputFileNames,
        path.add_module("Gearbox")
        tracking.add_geometry_modules(path)
        if 'DATA' in self.random_seed:
            filter_choice = "recording_data"
            from rawdata import add_unpackers
            add_unpackers(path, components=['CDC'])
        else:
            filter_choice = "recording"

        tracking.add_cdc_track_finding(path, add_mva_quality_indicator=True)

        basf2.set_module_parameters(
            name="TFCDC_TrackQualityEstimator",
            filter=filter_choice,
                "rootFileName": self.get_output_file_name(self.get_records_file_name())
            deactivateIfDeadBoard=False
914class RecoTrackQEDataCollectionTask(Basf2PathTask):
916 Collect variables/features from the reco track reconstruction including the
917 fit and write them to a ROOT file.
919 These variables are to be used as labelled training data for the MVA
920 classifier which is the MVA track quality estimator. The collected
921 variables include the classifier outputs from the VXD and CDC quality
922 estimators, namely the CDC and VXD quality indicators, combined with fit,
923 merger, timing, energy loss information etc. This task requires the
924 subdetector quality estimators to be trained.
928 n_events = b2luigi.IntParameter()
930 experiment_number = b2luigi.IntParameter()
933 random_seed = b2luigi.Parameter()
935 cdc_training_target = b2luigi.Parameter()
939 recotrack_option = b2luigi.Parameter(
941 default=
'deleteCDCQI080'
945 fast_bdt_option = b2luigi.ListParameter(
947 hashed=
True, default=[200, 8, 3, 0.1]
954 def get_records_file_name(self, n_events=None, random_seed=None, recotrack_option=None):
956 Create output file name depending on number of events and production
957 mode that is specified in the random_seed string.
960 n_events = self.n_events
961 if random_seed
is None:
962 random_seed = self.random_seed
963 if recotrack_option
is None:
964 if isinstance(self.recotrack_option, str):
965 recotrack_option = self.recotrack_option
967 recotrack_option = self.recotrack_option._default
968 if not isinstance(recotrack_option, str):
969 recotrack_option = recotrack_option._default
970 if 'rec' not in random_seed:
971 random_seed +=
'_rec'
972 if 'DATA' in random_seed:
973 return 'qe_records_DATA_rec.root'
975 if 'USESIMBB' in random_seed:
976 random_seed =
'BBBAR_' + random_seed.split(
"_", 1)[1]
977 elif 'USESIMEE' in random_seed:
978 random_seed =
'BHABHA_' + random_seed.split(
"_", 1)[1]
979 return 'qe_records_N' + str(n_events) +
'_' + random_seed +
'_' + recotrack_option +
'.root'
981 def get_input_files(self, n_events=None, random_seed=None):
983 Get input file names depending on the use case: If they already exist, search in
984 the corresponding folders, for data check the specified list and if they are created
985 in the same run, check for the task that produced them.
988 n_events = self.n_events
989 if random_seed
is None:
990 random_seed = self.random_seed
991 if "USESIM" in random_seed:
992 if 'USESIMBB' in random_seed:
993 random_seed =
'BBBAR_' + random_seed.split(
"_", 1)[1]
994 elif 'USESIMEE' in random_seed:
995 random_seed =
'BHABHA_' + random_seed.split(
"_", 1)[1]
996 return [
'datafiles/' + GenerateSimTask.output_file_name(GenerateSimTask,
997 n_events=n_events, random_seed=random_seed)]
998 elif "DATA" in random_seed:
999 return MasterTask.datafiles
1001 return self.get_input_file_names(GenerateSimTask.output_file_name(
1002 GenerateSimTask, n_events=n_events, random_seed=random_seed))

        Generate list of luigi Tasks that this Task depends on.
        if "USESIM" in self.random_seed or "DATA" in self.random_seed:
            for filename in self.get_input_files():
                yield CheckExistingFile(
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.random_seed,
                n_events=self.n_events,
                experiment_number=self.experiment_number,
        if "DATA" not in self.random_seed:
            if 'useCDC' not in self.recotrack_option and 'noCDC' not in self.recotrack_option:
                yield CDCQETeacherTask(
                    n_events_training=MasterTask.n_events_training,
                    experiment_number=self.experiment_number,
                    training_target=self.cdc_training_target,
                    process_type=self.random_seed.split("_", 1)[0],
                    exclude_variables=MasterTask.exclude_variables_cdc,
                    fast_bdt_option=self.fast_bdt_option,
            if 'useVXD' not in self.recotrack_option and 'noVXD' not in self.recotrack_option:
                yield VXDQETeacherTask(
                    n_events_training=MasterTask.n_events_training,
                    experiment_number=self.experiment_number,
                    process_type=self.random_seed.split("_", 1)[0],
                    exclude_variables=MasterTask.exclude_variables_vxd,
                    fast_bdt_option=self.fast_bdt_option,
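# The dependencies above pull in the intermediate CDC and VXD teachers only
# when the recotrack option neither reuses ("useCDC"/"useVXD") nor disables
# ("noCDC"/"noVXD") them, and never for "DATA" seeds; the teachers receive
# the part of the seed before the first underscore as process type. A
# hypothetical stand-alone helper reproducing this selection with plain
# strings (task names are returned as strings, not luigi tasks):
```python
from typing import List


def required_intermediate_teachers(random_seed: str, recotrack_option: str) -> List[str]:
    """List the teacher task names the data collection task would yield."""
    if "DATA" in random_seed:
        return []
    required = []
    if "useCDC" not in recotrack_option and "noCDC" not in recotrack_option:
        required.append("CDCQETeacherTask")
    if "useVXD" not in recotrack_option and "noVXD" not in recotrack_option:
        required.append("VXDQETeacherTask")
    return required


def teacher_process_type(random_seed: str) -> str:
    """Process type handed to the teachers: the seed up to the first '_'."""
    return random_seed.split("_", 1)[0]


print(required_intermediate_teachers("BBBAR_train_rec", "deleteCDCQI080"))
print(teacher_process_type("BBBAR_train_rec"))  # BBBAR
```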

        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        yield self.add_to_output(self.get_records_file_name())

    def create_path(self):
        Create basf2 reconstruction path that should mirror the default path
        from ``add_tracking_reconstruction()``, but with modules for the VXD QE
        and CDC QE application and for collection of variables for the reco
        track quality estimator.
        path = basf2.create_path()
        inputFileNames = self.get_input_files()
            inputFileNames=inputFileNames,
        path.add_module("Gearbox")
        if 'noCDC' in self.recotrack_option:
        if 'noVXD' in self.recotrack_option:
        if 'DATA' in self.random_seed:
            from rawdata import add_unpackers
        tracking.add_tracking_reconstruction(path, add_cdcTrack_QI=mvaCDC, add_vxdTrack_QI=mvaVXD, add_recoTrack_QI=True)
        if ('DATA' in self.random_seed or 'useCDC' in self.recotrack_option) and 'noCDC' not in self.recotrack_option:
            cdc_identifier = 'datafiles/' + \
                CDCQETeacherTask.get_weightfile_xml_identifier(CDCQETeacherTask, fast_bdt_option=self.fast_bdt_option)
            if os.path.exists(cdc_identifier):
                replace_cdc_qi = True
            elif 'useCDC' in self.recotrack_option:
                raise ValueError(f"CDC QI Identifier not found: {cdc_identifier}")
                replace_cdc_qi = False
        elif 'noCDC' in self.recotrack_option:
            replace_cdc_qi = False
            cdc_identifier = self.get_input_file_names(
                CDCQETeacherTask.get_weightfile_xml_identifier(
                    CDCQETeacherTask, fast_bdt_option=self.fast_bdt_option))[0]
            replace_cdc_qi = True
        if ('DATA' in self.random_seed or 'useVXD' in self.recotrack_option) and 'noVXD' not in self.recotrack_option:
            vxd_identifier = 'datafiles/' + \
                VXDQETeacherTask.get_weightfile_xml_identifier(VXDQETeacherTask, fast_bdt_option=self.fast_bdt_option)
            if os.path.exists(vxd_identifier):
                replace_vxd_qi = True
            elif 'useVXD' in self.recotrack_option:
                raise ValueError(f"VXD QI Identifier not found: {vxd_identifier}")
                replace_vxd_qi = False
        elif 'noVXD' in self.recotrack_option:
            replace_vxd_qi = False
            vxd_identifier = self.get_input_file_names(
                VXDQETeacherTask.get_weightfile_xml_identifier(
                    VXDQETeacherTask, fast_bdt_option=self.fast_bdt_option))[0]
            replace_vxd_qi = True
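# The two blocks above decide whether the CDC (and, in the same way, the VXD)
# quality indicator is replaced by a locally trained weight file: for data or
# an explicit "use<DET>" option the weight file must already exist in
# datafiles/, "no<DET>" disables the replacement, and otherwise the weight
# file comes from the teacher task run in the same chain. A sketch of that
# decision for one detector; the helper name is hypothetical and the
# os.path.exists() check is passed in as a boolean.
```python
def replace_quality_indicator(random_seed: str, recotrack_option: str,
                              detector: str, local_weightfile_exists: bool) -> bool:
    """Return True if the detector QI should use a locally trained weight file."""
    use_flag, no_flag = "use" + detector, "no" + detector
    if ("DATA" in random_seed or use_flag in recotrack_option) and no_flag not in recotrack_option:
        if local_weightfile_exists:
            return True
        if use_flag in recotrack_option:
            # Explicitly requested, but the weight file is missing.
            raise ValueError(f"{detector} QI Identifier not found")
        return False
    if no_flag in recotrack_option:
        return False
    # Weight file is provided by the corresponding teacher task.
    return True


print(replace_quality_indicator("DATA_run1", "", "CDC", False))  # False
```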
        cdc_qe_mva_filter_parameters = None
        if 'deleteCDCQI' in self.recotrack_option:
            cut_index = self.recotrack_option.find('deleteCDCQI') + len('deleteCDCQI')
            cut = int(self.recotrack_option[cut_index:cut_index + 3]) / 100.
                cdc_qe_mva_filter_parameters = {
                    "identifier": cdc_identifier, "cut": cut}
                cdc_qe_mva_filter_parameters = {
        elif replace_cdc_qi:
            cdc_qe_mva_filter_parameters = {
                "identifier": cdc_identifier}
        basf2.conditions.prepend_testing_payloads("localdb/database.txt")
        if cdc_qe_mva_filter_parameters is not None and cdc_identifier is not None:
            name = 'TrackingMVAFilterParameters'
                iovList=(0, 0, 0, -1),
                weightfile_identifier=cdc_identifier,
            cdc_qe_mva_filter_parameters = {'DBPayloadName': name}
        if replace_vxd_qi and vxd_identifier is not None:
            vxd_name = 'VXDQualityEstimatorMVAWeightFileIdentifier'
            with open(vxd_identifier) as f:
                weight_file_content = f.read()
            vxd_name = write_mva_weightfile_content_to_db(vxd_name, weight_file_content, (0, 0, 0, -1))
        if cdc_qe_mva_filter_parameters is not None:
            basf2.set_module_parameters(
                name="TFCDC_TrackQualityEstimator",
                filterParameters=cdc_qe_mva_filter_parameters,
                resetTakenFlag=True,
                deactivateIfDeadBoard=False,
            basf2.set_module_parameters(
                name="VXDQualityEstimatorMVA",
                WeightFileIdentifier=vxd_identifier)
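# With a recotrack option such as the default 'deleteCDCQI080', the three
# characters following 'deleteCDCQI' are read as a percentage, so 080 becomes
# a cut of 0.80 that is handed to the CDC QE filter together with the weight
# file identifier. A stand-alone sketch of that parsing; the function names
# are hypothetical, and the dictionary mirrors the "identifier"/"cut" keys
# used above. Exactly three digits are expected after the marker.
```python
from typing import Optional


def parse_delete_cdc_qi_cut(recotrack_option: str) -> Optional[float]:
    """Return the CDC QI cut encoded as 'deleteCDCQI<NNN>', or None."""
    marker = "deleteCDCQI"
    if marker not in recotrack_option:
        return None
    cut_index = recotrack_option.find(marker) + len(marker)
    return int(recotrack_option[cut_index:cut_index + 3]) / 100.


def cdc_filter_parameters(identifier: str, cut: Optional[float] = None) -> dict:
    """Build the filterParameters dictionary for the CDC quality estimator."""
    parameters = {"identifier": identifier}
    if cut is not None:
        parameters["cut"] = cut
    return parameters


print(parse_delete_cdc_qi_cut("deleteCDCQI080"))  # 0.8
```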
        track_qe_module_name = "TrackQualityEstimatorMVA"
        mc_track_matcher_module_name = "MCRecoTracksMatcher"
        mc_matcher_module_found = False
        qe_module_found = False
        new_path = basf2.create_path()
        for module in path.modules():
            if module.name() == track_qe_module_name:
                new_path.add_module(
                    recoTrackColName='RecoTracks',
                    trackColName='MDSTTracks')
                qe_module_found = True
            elif module.name() == mc_track_matcher_module_name:
                new_path.add_module(module)
                new_path.add_module(
                    "TrackQETrainingDataCollector",
                    TrainingDataOutputName=self.get_output_file_name(self.get_records_file_name()),
                    collectEventFeatures=True
                mc_matcher_module_found = True
                new_path.add_module(module)
        if not qe_module_found:
            raise KeyError(f"No module {track_qe_module_name} found in path")
        if not mc_matcher_module_found:
            raise KeyError(f"No module {mc_track_matcher_module_name} found in path")


class TrackQETeacherBaseTask(Basf2Task):
    A teacher task runs the basf2 mva teacher on the training data provided by a
    data collection task.

    Since teacher tasks are needed for all quality estimators covered by this
    steering file and the only thing that changes is the required data
    collection task and some training parameters, I decided to use inheritance
    and have the basic functionality in this base class/interface and have the
    specific teacher tasks inherit from it.
    n_events_training = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    process_type = b2luigi.Parameter(
    training_target = b2luigi.Parameter(
    exclude_variables = b2luigi.ListParameter(
        hashed=True, default=[]
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]

    def weightfile_identifier_basename(self):
        Property defining the basename for the .xml and .root weightfiles that are created.
        Has to be implemented by the inheriting teacher task class.
        raise NotImplementedError(
            "Teacher Task must define a static weightfile_identifier"

    def get_weightfile_xml_identifier(self, fast_bdt_option=None, recotrack_option=None):
        Name of the xml weightfile that is created by the teacher task.
        It is subsequently used as a local weightfile in the following validation tasks.
        if fast_bdt_option is None:
            fast_bdt_option = self.fast_bdt_option
        if recotrack_option is None and hasattr(self, 'recotrack_option'):
            if isinstance(self.recotrack_option, str):
                recotrack_option = self.recotrack_option
                recotrack_option = self.recotrack_option._default
            recotrack_option = ''
        weightfile_details = create_fbdt_option_string(fast_bdt_option)
        weightfile_name = self.weightfile_identifier_basename + weightfile_details
        if recotrack_option != '':
            weightfile_name = weightfile_name + '_' + recotrack_option
        return weightfile_name + "_weights.xml"

    def tree_name(self):
        Property defining the name of the tree in the ROOT file from the
        ``data_collection_task`` that contains the recorded training data. Must
        be implemented by the inheriting specific teacher task class.
        raise NotImplementedError("Teacher Task must define a static tree_name")

    def random_seed(self):
        Property defining the random seed to be used by the ``GenerateSimTask``.
        Should differ from the random seed in the test data samples. Must
        be implemented by the inheriting specific teacher task class.
        raise NotImplementedError("Teacher Task must define a static random seed")

    def data_collection_task(self) -> Basf2PathTask:
        Property defining the specific ``DataCollectionTask`` to require. Must
        be implemented by the inheriting specific teacher task class.
        raise NotImplementedError(
            "Teacher Task must define a data collection task to require "
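# get_weightfile_xml_identifier() above composes the weight file name from
# the class's basename, the string returned by create_fbdt_option_string()
# and, if set, the recotrack option. A sketch of that composition; the
# helper name is hypothetical and "_OPTIONS" is a made-up placeholder,
# because the exact format of create_fbdt_option_string() is not shown in
# this section.
```python
def weightfile_xml_name(basename: str, option_string: str, recotrack_option: str = "") -> str:
    """Compose '<basename><options>[_<recotrack_option>]_weights.xml'."""
    name = basename + option_string
    if recotrack_option != "":
        name = name + "_" + recotrack_option
    return name + "_weights.xml"


print(weightfile_xml_name("recotrack_mva_qe", "_OPTIONS", "deleteCDCQI080"))
# recotrack_mva_qe_OPTIONS_deleteCDCQI080_weights.xml
```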

        Generate list of luigi Tasks that this Task depends on.
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/qe_records_N' + str(self.n_events_training) + '_' + process + '_' +
                self.random_seed + '.root',
            yield self.data_collection_task(
                num_processes=MasterTask.num_processes,
                n_events=self.n_events_training,
                experiment_number=self.experiment_number,
                random_seed=self.process_type + '_' + self.random_seed,

        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        yield self.add_to_output(self.get_weightfile_xml_identifier())
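# For pre-existing records ('USERECBB' / 'USERECEE' process types) the
# teacher expects a ROOT file whose name is built from the number of events,
# a process label and the seed; the evaluation tasks further down use the
# same pattern with the seed 'test_<acronym>'. A sketch of the naming scheme;
# the helper name is hypothetical, and the process label is whatever the
# elided branches assign, so 'BBBAR' below is only an illustrative assumption.
```python
def existing_records_file_name(n_events: int, process: str, seed: str) -> str:
    """Name of a pre-existing QE records file in datafiles/."""
    return "datafiles/qe_records_N" + str(n_events) + "_" + process + "_" + seed + ".root"


print(existing_records_file_name(100000, "BBBAR", "train_vxd"))
# datafiles/qe_records_N100000_BBBAR_train_vxd.root
print(existing_records_file_name(50000, "BBBAR", "test_" + "cdc"))
# datafiles/qe_records_N50000_BBBAR_test_cdc.root
```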

        Use basf2_mva teacher to create MVA weightfile from collected training
        This is the main process that is dispatched by the ``run`` method that
        is inherited from ``Basf2Task``.
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            records_files = ['datafiles/qe_records_N' + str(self.n_events_training) +
                             '_' + process + '_' + self.random_seed + '.root']
            if hasattr(self, 'recotrack_option'):
                records_files = self.get_input_file_names(
                    self.data_collection_task.get_records_file_name(
                        self.data_collection_task,
                        n_events=self.n_events_training,
                        random_seed=self.process_type + '_' + self.random_seed,
                        recotrack_option=self.recotrack_option))
                records_files = self.get_input_file_names(
                    self.data_collection_task.get_records_file_name(
                        self.data_collection_task,
                        n_events=self.n_events_training,
                        random_seed=self.process_type + '_' + self.random_seed))
        my_basf2_mva_teacher(
            records_files=records_files,
            tree_name=self.tree_name,
            weightfile_identifier=self.get_output_file_name(self.get_weightfile_xml_identifier()),
            target_variable=self.training_target,
            exclude_variables=self.exclude_variables,
            fast_bdt_option=self.fast_bdt_option,


class VXDQETeacherTask(TrackQETeacherBaseTask):
    Task to run basf2 mva teacher on collected data for VXDTF2 track quality estimator
    weightfile_identifier_basename = "vxdtf2_mva_qe"
    random_seed = "train_vxd"
    data_collection_task = VXDQEDataCollectionTask


class CDCQETeacherTask(TrackQETeacherBaseTask):
    Task to run basf2 mva teacher on collected data for CDC track quality estimator
    weightfile_identifier_basename = "cdc_mva_qe"
    tree_name = "records"
    random_seed = "train_cdc"
    data_collection_task = CDCQEDataCollectionTask


class RecoTrackQETeacherTask(TrackQETeacherBaseTask):
    Task to run basf2 mva teacher on collected data for the final, combined
    track quality estimator
    recotrack_option = b2luigi.Parameter(
        default='deleteCDCQI080'
    weightfile_identifier_basename = "recotrack_mva_qe"
    random_seed = "train_rec"
    data_collection_task = RecoTrackQEDataCollectionTask
    cdc_training_target = b2luigi.Parameter()

        Generate list of luigi Tasks that this Task depends on.
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/qe_records_N' + str(self.n_events_training) + '_' + process + '_' +
                self.random_seed + '.root',
            yield self.data_collection_task(
                cdc_training_target=self.cdc_training_target,
                num_processes=MasterTask.num_processes,
                n_events=self.n_events_training,
                experiment_number=self.experiment_number,
                random_seed=self.process_type + '_' + self.random_seed,
                recotrack_option=self.recotrack_option,
                fast_bdt_option=self.fast_bdt_option,
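# The teacher subclasses above satisfy the base-class "properties"
# (weightfile_identifier_basename, tree_name, random_seed, ...) with plain
# class attributes, which is also why other tasks can call e.g.
# CDCQETeacherTask.get_weightfile_xml_identifier(CDCQETeacherTask, ...) with
# the class itself as ``self``. A minimal sketch of that pattern, assuming
# the base class declares these members with @property (the decorators are
# not visible in this section); all class names here are hypothetical.
```python
class TeacherBase:
    @property
    def weightfile_identifier_basename(self):
        raise NotImplementedError("Teacher Task must define a static weightfile_identifier")

    def get_weightfile_xml_identifier(self):
        # Works with an instance or with the class passed explicitly as self,
        # as long as the subclass overrides the property with a plain string.
        return self.weightfile_identifier_basename + "_weights.xml"


class CDCTeacher(TeacherBase):
    # A class attribute found earlier in the MRO shadows the base property.
    weightfile_identifier_basename = "cdc_mva_qe"


class IncompleteTeacher(TeacherBase):
    pass


print(CDCTeacher.get_weightfile_xml_identifier(CDCTeacher))  # cdc_mva_qe_weights.xml
```
# An instance of IncompleteTeacher still reaches the base property and raises
# NotImplementedError, which is the intended reminder to define the attribute.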


class HarvestingValidationBaseTask(Basf2PathTask):
    Run track reconstruction with MVA quality estimator and write out
    (="harvest") a root file with variables useful for the validation.
    n_events_testing = b2luigi.IntParameter()
    n_events_training = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    process_type = b2luigi.Parameter(
    exclude_variables = b2luigi.ListParameter(
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]
    validation_output_file_name = "harvesting_validation.root"
    reco_output_file_name = "reconstruction.root"

    def teacher_task(self) -> TrackQETeacherBaseTask:
        Teacher task to require to provide a quality estimator weightfile for ``add_tracking_with_quality_estimation``
        raise NotImplementedError()

    def add_tracking_with_quality_estimation(self, path: basf2.Path) -> None:
        Add modules for track reconstruction to basf2 path that are to be
        validated. Besides track finding it should include MC matching, fitted
        track creation and a quality estimator module.
        raise NotImplementedError()

        Generate list of luigi Tasks that this Task depends on.
        yield self.teacher_task(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            exclude_variables=self.exclude_variables,
            fast_bdt_option=self.fast_bdt_option,
        if 'USE' in self.process_type:
            if 'BB' in self.process_type:
            elif 'EE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root'
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.process_type + '_test',
                n_events=self.n_events_testing,
                experiment_number=self.experiment_number,

        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        yield self.add_to_output(self.validation_output_file_name)
        yield self.add_to_output(self.reco_output_file_name)

    def create_path(self):
        Create a basf2 path that uses ``add_tracking_with_quality_estimation()``
        and adds the ``CombinedTrackingValidationModule`` to write out variables
        path = basf2.create_path()
        if 'USE' in self.process_type:
            if 'BB' in self.process_type:
            elif 'EE' in self.process_type:
            inputFileNames = ['datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root']
            inputFileNames = self.get_input_file_names(GenerateSimTask.output_file_name(
                GenerateSimTask, n_events=self.n_events_testing, random_seed=self.process_type + '_test'))
            inputFileNames=inputFileNames,
        path.add_module("Gearbox")
        tracking.add_geometry_modules(path)
        tracking.add_hit_preparation_modules(path)
        self.add_tracking_with_quality_estimation(path)
            CombinedTrackingValidationModule(
                output_file_name=self.get_output_file_name(
                    self.validation_output_file_name
            outputFileName=self.get_output_file_name(self.reco_output_file_name),


class VXDQEHarvestingValidationTask(HarvestingValidationBaseTask):
    Run VXDTF2 track reconstruction and write out (="harvest") a root file with
    variables useful for validation of the VXD Quality Estimator.
    validation_output_file_name = "vxd_qe_harvesting_validation.root"
    reco_output_file_name = "vxd_qe_reconstruction.root"
    teacher_task = VXDQETeacherTask

    def add_tracking_with_quality_estimation(self, path):
        Add modules for VXDTF2 tracking with VXD quality estimator to basf2 path.
        tracking.add_vxd_track_finding_vxdtf2(
            reco_tracks="RecoTracks",
            add_mva_quality_indicator=True,
        vxd_identifier = self.get_input_file_names(
            self.teacher_task.get_weightfile_xml_identifier(self.teacher_task, fast_bdt_option=self.fast_bdt_option)
        with open(vxd_identifier) as f:
            weight_file_content = f.read()
        vxd_name = write_mva_weightfile_content_to_db(
            dbobj_name='VXDQualityEstimatorMVAWeightFileIdentifier',
            content=weight_file_content,
            iovList=(0, 0, 0, -1)
        basf2.set_module_parameters(
            name="VXDQualityEstimatorMVA",
            WeightFileIdentifier=vxd_name,
        tracking.add_mc_matcher(path, components=["SVD"])
        tracking.add_track_fit_and_track_creator(path, components=["SVD"])


class CDCQEHarvestingValidationTask(HarvestingValidationBaseTask):
    Run CDC reconstruction and write out (="harvest") a root file with variables
    useful for validation of the CDC Quality Estimator.
    training_target = b2luigi.Parameter()
    validation_output_file_name = "cdc_qe_harvesting_validation.root"
    reco_output_file_name = "cdc_qe_reconstruction.root"
    teacher_task = CDCQETeacherTask

        Generate list of luigi Tasks that this Task depends on.
        yield self.teacher_task(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            training_target=self.training_target,
            exclude_variables=self.exclude_variables,
            fast_bdt_option=self.fast_bdt_option,
        if 'USE' in self.process_type:
            if 'BB' in self.process_type:
            elif 'EE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root'
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.process_type + '_test',
                n_events=self.n_events_testing,
                experiment_number=self.experiment_number,

    def add_tracking_with_quality_estimation(self, path):
        Add modules for CDC standalone tracking with CDC quality estimator to basf2 path.
        tracking.add_cdc_track_finding(
            output_reco_tracks="RecoTracks",
            add_mva_quality_indicator=True,
        basf2.conditions.prepend_testing_payloads("localdb/database.txt")
        cdc_qe_mva_filter_parameters = {
            "identifier": self.get_input_file_names(
                CDCQETeacherTask.get_weightfile_xml_identifier(
                    fast_bdt_option=self.fast_bdt_option))[0]}
        name = 'TrackingMVAFilterParameters'
            iovList=(0, 0, 0, -1),
            weightfile_identifier=self.get_input_file_names(
                CDCQETeacherTask.get_weightfile_xml_identifier(
                    fast_bdt_option=self.fast_bdt_option))[0],
        cdc_qe_mva_filter_parameters = {'DBPayloadName': name}
        basf2.set_module_parameters(
            name="TFCDC_TrackQualityEstimator",
            filterParameters=cdc_qe_mva_filter_parameters,
            deactivateIfDeadBoard=False,
        tracking.add_track_fit_and_track_creator(path, components=["CDC"])
        tracking.add_mc_matcher(path, components=["CDC"])


class RecoTrackQEHarvestingValidationTask(HarvestingValidationBaseTask):
    Run track reconstruction and write out (="harvest") a root file with variables
    useful for validation of the MVA track Quality Estimator.
    cdc_training_target = b2luigi.Parameter()
    validation_output_file_name = "reco_qe_harvesting_validation.root"
    reco_output_file_name = "reco_qe_reconstruction.root"
    teacher_task = RecoTrackQETeacherTask

        Generate list of luigi Tasks that this Task depends on.
        yield CDCQETeacherTask(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            training_target=self.cdc_training_target,
            exclude_variables=MasterTask.exclude_variables_cdc,
            fast_bdt_option=self.fast_bdt_option,
        yield VXDQETeacherTask(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            exclude_variables=MasterTask.exclude_variables_vxd,
            fast_bdt_option=self.fast_bdt_option,
        yield self.teacher_task(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            exclude_variables=self.exclude_variables,
            cdc_training_target=self.cdc_training_target,
            fast_bdt_option=self.fast_bdt_option,
        if 'USE' in self.process_type:
            if 'BB' in self.process_type:
            elif 'EE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/generated_mc_N' + str(self.n_events_testing) + '_' + process + '_test.root'
            yield SplitNMergeSimTask(
                bkgfiles_dir=MasterTask.bkgfiles_by_exp[self.experiment_number],
                random_seed=self.process_type + '_test',
                n_events=self.n_events_testing,
                experiment_number=self.experiment_number,

    def add_tracking_with_quality_estimation(self, path):
        Add modules for reco tracking with all track quality estimators to basf2 path.
        tracking.add_tracking_reconstruction(
            add_cdcTrack_QI=True,
            add_vxdTrack_QI=True,
            add_recoTrack_QI=True,
            skipGeometryAdding=True,
            skipHitPreparerAdding=True,
        basf2.conditions.prepend_testing_payloads("localdb/database.txt")
        cdc_qe_mva_filter_parameters = {
            "identifier": self.get_input_file_names(
                CDCQETeacherTask.get_weightfile_xml_identifier(
                    fast_bdt_option=self.fast_bdt_option))[0]}
        name = 'TrackingMVAFilterParameters'
            iovList=(0, 0, 0, -1),
            weightfile_identifier=self.get_input_file_names(
                CDCQETeacherTask.get_weightfile_xml_identifier(
                    fast_bdt_option=self.fast_bdt_option))[0],
        cdc_qe_mva_filter_parameters = {'DBPayloadName': name}
        basf2.set_module_parameters(
            name="TFCDC_TrackQualityEstimator",
            filterParameters=cdc_qe_mva_filter_parameters,
            deactivateIfDeadBoard=False,
        vxd_identifier = self.get_input_file_names(
            VXDQETeacherTask.get_weightfile_xml_identifier(VXDQETeacherTask, fast_bdt_option=self.fast_bdt_option)
        with open(vxd_identifier) as f:
            weight_file_content = f.read()
        vxd_name = write_mva_weightfile_content_to_db(
            dbobj_name='VXDQualityEstimatorMVAWeightFileIdentifier',
            content=weight_file_content,
            iovList=(0, 0, 0, -1)
        basf2.set_module_parameters(
            name="VXDQualityEstimatorMVA",
            WeightFileIdentifier=vxd_name,
        recotrack_identifier = self.get_input_file_names(
            RecoTrackQETeacherTask.get_weightfile_xml_identifier(RecoTrackQETeacherTask, fast_bdt_option=self.fast_bdt_option)
        with open(recotrack_identifier) as f:
            weight_file_content = f.read()
        recotrack_name = write_mva_weightfile_content_to_db(
            dbobj_name='RecoTrackQualityEstimatorMVAWeightFileIdentifier',
            content=weight_file_content,
            iovList=(0, 0, 0, -1)
        basf2.set_module_parameters(
            name="TrackQualityEstimatorMVA",
            WeightFileIdentifier=recotrack_name,
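# The weight files above are written to the local database with
# iovList=(0, 0, 0, -1). Assuming the usual basf2 reading of an interval of
# validity as (experiment_low, run_low, experiment_high, run_high) with -1
# marking an open upper bound, these payloads are valid for every run of
# experiment 0. The hypothetical helper below only illustrates that reading;
# it is not part of the conditions database API.
```python
from typing import Tuple


def iov_contains(iov: Tuple[int, int, int, int], experiment: int, run: int) -> bool:
    """Check whether (experiment, run) lies inside an interval of validity."""
    exp_low, run_low, exp_high, run_high = iov
    if (experiment, run) < (exp_low, run_low):
        return False
    if exp_high == -1:
        return True  # no upper bound at all
    if run_high == -1:
        return experiment <= exp_high  # all runs of the last experiment
    return (experiment, run) <= (exp_high, run_high)


print(iov_contains((0, 0, 0, -1), 0, 1234))  # True
print(iov_contains((0, 0, 0, -1), 12, 1))  # False
```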


class TrackQEEvaluationBaseTask(Task):
    Base class for evaluating a quality estimator with ``basf2_mva_evaluate.py`` on a
    separate test data set.

    Evaluation tasks for VXD, CDC and combined QE can inherit from it.
    git_hash = b2luigi.Parameter(
        default=get_basf2_git_hash()
    n_events_testing = b2luigi.IntParameter()
    n_events_training = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    process_type = b2luigi.Parameter(
    training_target = b2luigi.Parameter(
    exclude_variables = b2luigi.ListParameter(
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]

    def teacher_task(self) -> TrackQETeacherBaseTask:
        Property defining the specific teacher task to require.
        raise NotImplementedError(
            "Evaluation Tasks must define a teacher task to require "

    def data_collection_task(self) -> Basf2PathTask:
        Property defining the specific ``DataCollectionTask`` to require. Must
        be implemented by the inheriting specific evaluation task class.
        raise NotImplementedError(
            "Evaluation Tasks must define a data collection task to require "

    def task_acronym(self):
        Acronym to distinguish between cdc, vxd and rec(o) MVA
        raise NotImplementedError(
            "Evaluation Tasks must define a task acronym."

        Generate list of luigi Tasks that this Task depends on.
        yield self.teacher_task(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            training_target=self.training_target,
            exclude_variables=self.exclude_variables,
            fast_bdt_option=self.fast_bdt_option,
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/qe_records_N' + str(self.n_events_testing) + '_' + process + '_test_' +
                self.task_acronym + '.root'
            yield self.data_collection_task(
                num_processes=MasterTask.num_processes,
                n_events=self.n_events_testing,
                experiment_number=self.experiment_number,
                random_seed=self.process_type + '_test',

        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        weightfile_details = create_fbdt_option_string(self.fast_bdt_option)
        evaluation_pdf_output = self.teacher_task.weightfile_identifier_basename + weightfile_details + ".zip"
        yield self.add_to_output(evaluation_pdf_output)

    @b2luigi.on_temporary_files
        Run ``basf2_mva_evaluate.py`` subprocess to evaluate QE MVA.
        The MVA weight file created from training on the training data set is
        evaluated on separate test data.
        weightfile_details = create_fbdt_option_string(self.fast_bdt_option)
        evaluation_pdf_output_basename = self.teacher_task.weightfile_identifier_basename + weightfile_details + ".zip"
        evaluation_pdf_output_path = self.get_output_file_name(evaluation_pdf_output_basename)
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            datafiles = 'datafiles/qe_records_N' + str(self.n_events_testing) + '_' + \
                process + '_test_' + self.task_acronym + '.root'
            datafiles = self.get_input_file_names(
                self.data_collection_task.get_records_file_name(
                    self.data_collection_task,
                    n_events=self.n_events_testing,
                    random_seed=self.process_type + '_test_' +
                    self.task_acronym))[0]
        for req in b2luigi.task.flatten(self.requires()):
            if isinstance(req, self.teacher_task):
        if hasattr(teacher_task, 'recotrack_option'):
            records_files = teacher_task.get_input_file_names(
                self.data_collection_task.get_records_file_name(
                    self.data_collection_task,
                    n_events=self.n_events_training,
                    random_seed=self.process_type + '_' + teacher_task.random_seed,
                    recotrack_option=teacher_task.recotrack_option))
            records_files = teacher_task.get_input_file_names(
                self.data_collection_task.get_records_file_name(
                    self.data_collection_task,
                    n_events=self.n_events_training,
                    random_seed=self.process_type + '_' + teacher_task.random_seed))
            "basf2_mva_evaluate.py",
            self.get_input_file_names(
                self.teacher_task.get_weightfile_xml_identifier(
                    fast_bdt_option=self.fast_bdt_option))[0],
            "--train_datafiles",
            self.teacher_task.tree_name,
            evaluation_pdf_output_path,
        log_file_dir = get_log_file_dir(self)
            os.makedirs(log_file_dir, exist_ok=True)
        except FileExistsError:
            print('Directory ' + log_file_dir + ' already exists.')
        stderr_log_file_path = log_file_dir + "stderr"
        stdout_log_file_path = log_file_dir + "stdout"
        with open(stdout_log_file_path, "w") as stdout_file:
            stdout_file.write(f'stdout output of the command:\n{" ".join(cmd)}\n\n')
        if os.path.exists(stderr_log_file_path):
            os.remove(stderr_log_file_path)
        with open(stdout_log_file_path, "a") as stdout_file:
            with open(stderr_log_file_path, "a") as stderr_file:
                    subprocess.run(cmd, check=True, stdout=stdout_file, stderr=stderr_file)
                except subprocess.CalledProcessError as err:
                    stderr_file.write(f"Evaluation failed with error:\n{err}")


class VXDTrackQEEvaluationTask(TrackQEEvaluationBaseTask):
    Run ``basf2_mva_evaluate.py`` for the VXD quality estimator on separate test data
    teacher_task = VXDQETeacherTask
    data_collection_task = VXDQEDataCollectionTask
    task_acronym = 'vxd'
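# The ``run`` method of TrackQEEvaluationBaseTask writes the command line to a
# 'stdout' log, removes a stale 'stderr' log and then appends the subprocess
# output to both, recording a failure instead of raising. A stand-alone
# sketch of the same pattern; run_logged() is a hypothetical name, and
# os.path.join() replaces the string concatenation so the log directory does
# not need a trailing slash.
```python
import os
import subprocess
import sys
import tempfile
from typing import List, Tuple


def run_logged(cmd: List[str], log_file_dir: str) -> Tuple[bool, str, str]:
    """Run cmd, logging to <log_file_dir>/stdout and /stderr; return (ok, paths)."""
    os.makedirs(log_file_dir, exist_ok=True)
    stdout_path = os.path.join(log_file_dir, "stdout")
    stderr_path = os.path.join(log_file_dir, "stderr")
    with open(stdout_path, "w") as stdout_file:
        stdout_file.write(f'stdout output of the command:\n{" ".join(cmd)}\n\n')
    if os.path.exists(stderr_path):
        os.remove(stderr_path)
    with open(stdout_path, "a") as stdout_file, open(stderr_path, "a") as stderr_file:
        try:
            subprocess.run(cmd, check=True, stdout=stdout_file, stderr=stderr_file)
        except subprocess.CalledProcessError as err:
            stderr_file.write(f"Evaluation failed with error:\n{err}")
            return False, stdout_path, stderr_path
    return True, stdout_path, stderr_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ok, out_path, _ = run_logged([sys.executable, "-c", "print('hello')"], tmp)
        with open(out_path) as f:
            print(ok, f.read())
```
# Since os.makedirs(..., exist_ok=True) already tolerates an existing
# directory, the sketch needs no FileExistsError handler for that case.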


class CDCTrackQEEvaluationTask(TrackQEEvaluationBaseTask):
    Run ``basf2_mva_evaluate.py`` for the CDC quality estimator on separate test data
    teacher_task = CDCQETeacherTask
    data_collection_task = CDCQEDataCollectionTask
    task_acronym = 'cdc'


class RecoTrackQEEvaluationTask(TrackQEEvaluationBaseTask):
    Run ``basf2_mva_evaluate.py`` for the final, combined quality estimator on
    teacher_task = RecoTrackQETeacherTask
    data_collection_task = RecoTrackQEDataCollectionTask
    task_acronym = 'rec'
    cdc_training_target = b2luigi.Parameter()

        Generate list of luigi Tasks that this Task depends on.
        yield self.teacher_task(
            n_events_training=self.n_events_training,
            experiment_number=self.experiment_number,
            process_type=self.process_type,
            training_target=self.training_target,
            exclude_variables=self.exclude_variables,
            cdc_training_target=self.cdc_training_target,
            fast_bdt_option=self.fast_bdt_option,
        if 'USEREC' in self.process_type:
            if 'USERECBB' in self.process_type:
            elif 'USERECEE' in self.process_type:
            yield CheckExistingFile(
                filename='datafiles/qe_records_N' + str(self.n_events_testing) + '_' + process + '_test_' +
                self.task_acronym + '.root'
            yield self.data_collection_task(
                num_processes=MasterTask.num_processes,
                n_events=self.n_events_testing,
                experiment_number=self.experiment_number,
                random_seed=self.process_type + "_test",
                cdc_training_target=self.cdc_training_target,


class PlotsFromHarvestingValidationBaseTask(Basf2Task):
    Create a PDF file with validation plots for a quality estimator from the
    ROOT ntuples produced by a harvesting validation task
    n_events_testing = b2luigi.IntParameter()
    n_events_training = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    process_type = b2luigi.Parameter(
    exclude_variables = b2luigi.ListParameter(
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]
    primaries_only = b2luigi.BoolParameter(

    def harvesting_validation_task_instance(self) -> HarvestingValidationBaseTask:
        Specifies the related harvesting validation task, which produces the ROOT
        files with the data that is plotted by this task.
        raise NotImplementedError("Must define a QI harvesting validation task for which to do the plots")

    def output_pdf_file_basename(self):
        Name of the output PDF file containing the validation plots
        validation_harvest_basename = self.harvesting_validation_task_instance.validation_output_file_name
        return validation_harvest_basename.replace(".root", "_plots.pdf")

        Generate list of luigi Tasks that this Task depends on.
        yield self.harvesting_validation_task_instance

        Generate list of output files that the task should produce.
        The task is considered finished if and only if the outputs all exist.
        yield self.add_to_output(self.output_pdf_file_basename)
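# The plotting step computes fake and clone rates for a series of quality
# indicator requirements, qi_cuts = np.linspace(0., 1, 20, endpoint=False),
# via get_uncertain_means_for_qi_cuts(). That helper is not part of this
# section; the sketch below assumes it keeps tracks with
# quality_indicator > cut and attaches a binomial standard error to each
# mean. rates_for_qi_cuts() is a hypothetical stand-in, not the real helper.
```python
import numpy as np
import pandas as pd


def rates_for_qi_cuts(df: pd.DataFrame, column: str, qi_cuts) -> pd.DataFrame:
    """Mean of a boolean column and its binomial error above each QI cut."""
    rows = []
    for cut in qi_cuts:
        selected = df.loc[df["quality_indicator"] > cut, column]
        n = len(selected)
        rate = selected.mean() if n else np.nan
        error = np.sqrt(rate * (1.0 - rate) / n) if n else np.nan
        rows.append({"qi_cut": cut, "rate": rate, "uncertainty": error})
    return pd.DataFrame(rows).set_index("qi_cut")


tracks = pd.DataFrame({
    "quality_indicator": [0.1, 0.6, 0.9, 0.95],
    "is_fake": [True, True, False, False],
})
print(rates_for_qi_cuts(tracks, "is_fake", np.linspace(0., 1, 4, endpoint=False)))
```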

    @b2luigi.on_temporary_files
        Create validation plots from the ROOT ntuples written by the harvesting validation task.
        Main process that is dispatched by the ``run`` method that is inherited
        validation_harvest_basename = self.harvesting_validation_task_instance.validation_output_file_name
        validation_harvest_path = self.get_input_file_names(validation_harvest_basename)[0]
2207 'is_fake',
'is_clone',
'is_matched',
'quality_indicator',
2208 'experiment_number',
'run_number',
'event_number',
'pr_store_array_number',
2209 'pt_estimate',
'z0_estimate',
'd0_estimate',
'tan_lambda_estimate',
2210 'phi0_estimate',
'pt_truth',
'z0_truth',
'd0_truth',
'tan_lambda_truth',
2214 pr_df = uproot.open(validation_harvest_path)[
'pr_tree/pr_tree'].arrays(pr_columns, library=
'pd')
2216 'experiment_number',
2219 'pr_store_array_number',
2224 mc_df = uproot.open(validation_harvest_path)[
'mc_tree/mc_tree'].arrays(mc_columns, library=
'pd')
2225 if self.primaries_only:
2226 mc_df = mc_df[mc_df.is_primary.eq(
True)]
2229 qi_cuts = np.linspace(0., 1, 20, endpoint=
False)
2235 output_pdf_file_path = self.get_output_file_name(self.output_pdf_file_basename)
2236 with PdfPages(output_pdf_file_path, keep_empty=
False)
as pdf:
            # title page with the meta data of this validation
            titlepage_fig, titlepage_ax = plt.subplots()
            titlepage_ax.axis("off")
            title = f"Quality Estimator validation plots from {self.__class__.__name__}"
            titlepage_ax.set_title(title)
            teacher_task = self.harvesting_validation_task_instance.teacher_task
            weightfile_identifier = teacher_task.get_weightfile_xml_identifier(teacher_task, fast_bdt_option=self.fast_bdt_option)
            meta_data = {
                "Date": datetime.today().strftime("%Y-%m-%d %H:%M"),
                "Created by steering file": os.path.realpath(__file__),
                "Created from data in": validation_harvest_path,
                "Background directory": MasterTask.bkgfiles_by_exp[self.experiment_number],
                "weight file": weightfile_identifier,
            }
            if hasattr(self, 'exclude_variables'):
                meta_data["Excluded variables"] = ", ".join(self.exclude_variables)
            meta_data_string = (format_dictionary(meta_data) +
                                "\n\n(For all MVA training parameters look into the produced weight file)")
            luigi_params = get_serialized_parameters(self)
            luigi_param_string = (f"\n\nb2luigi parameters for {self.__class__.__name__}\n" +
                                  format_dictionary(luigi_params))
            title_page_text = meta_data_string + luigi_param_string
            titlepage_ax.text(0, 1, title_page_text, ha="left", va="top", wrap=True, fontsize=8)
            pdf.savefig(titlepage_fig)
            plt.close(titlepage_fig)

            # fake rate as a function of the quality indicator requirement
            fake_rates = get_uncertain_means_for_qi_cuts(pr_df, "is_fake", qi_cuts)
            fake_fig, fake_ax = plt.subplots()
            fake_ax.set_title("Fake rate")
            plot_with_errobands(fake_rates, ax=fake_ax)
            fake_ax.set_ylabel("fake rate")
            fake_ax.set_xlabel("quality indicator requirement")
            pdf.savefig(fake_fig, bbox_inches="tight")
            plt.close(fake_fig)

            # clone rate as a function of the quality indicator requirement
            clone_rates = get_uncertain_means_for_qi_cuts(pr_df, "is_clone", qi_cuts)
            clone_fig, clone_ax = plt.subplots()
            clone_ax.set_title("Clone rate")
            plot_with_errobands(clone_rates, ax=clone_ax)
            clone_ax.set_ylabel("clone rate")
            clone_ax.set_xlabel("quality indicator requirement")
            pdf.savefig(clone_fig, bbox_inches="tight")
            plt.close(clone_fig)

            # merge the quality indicator of the related pattern-recognition track into the MC tracks;
            # the left merge keeps MC tracks that were not found, with a NaN quality indicator
            pr_track_identifiers = ['experiment_number', 'run_number', 'event_number', 'pr_store_array_number']
            mc_df = upd.merge(
                left=mc_df, right=pr_df[pr_track_identifiers + ['quality_indicator']],
                how='left',
                on=pr_track_identifiers
            )

            # fraction of missing MC tracks among those that were either not found
            # or found with a quality indicator above the cut
            missing_fractions = (
                _my_uncertain_mean(mc_df[
                    mc_df.quality_indicator.isnull() | (mc_df.quality_indicator > qi_cut)]['is_missing'])
                for qi_cut in qi_cuts
            )

            findeff_fig, findeff_ax = plt.subplots()
            findeff_ax.set_title("Finding efficiency")
            finding_efficiencies = 1.0 - upd.Series(data=missing_fractions, index=qi_cuts)
            plot_with_errobands(finding_efficiencies, ax=findeff_ax)
            findeff_ax.set_ylabel("finding efficiency")
            findeff_ax.set_xlabel("quality indicator requirement")
            pdf.savefig(findeff_fig, bbox_inches="tight")
            plt.close(findeff_fig)

            # fake rate vs. finding efficiency ROC curve
            fake_roc_fig, fake_roc_ax = plt.subplots()
            fake_roc_ax.set_title("Fake rate vs. finding efficiency ROC curve")
            fake_roc_ax.errorbar(x=finding_efficiencies.nominal_value, y=fake_rates.nominal_value,
                                 xerr=finding_efficiencies.std_dev, yerr=fake_rates.std_dev, elinewidth=0.8)
            fake_roc_ax.set_xlabel('finding efficiency')
            fake_roc_ax.set_ylabel('fake rate')
            pdf.savefig(fake_roc_fig, bbox_inches="tight")
            plt.close(fake_roc_fig)

            # clone rate vs. finding efficiency ROC curve
            clone_roc_fig, clone_roc_ax = plt.subplots()
            clone_roc_ax.set_title("Clone rate vs. finding efficiency ROC curve")
            clone_roc_ax.errorbar(x=finding_efficiencies.nominal_value, y=clone_rates.nominal_value,
                                  xerr=finding_efficiencies.std_dev, yerr=clone_rates.std_dev, elinewidth=0.8)
            clone_roc_ax.set_xlabel('finding efficiency')
            clone_roc_ax.set_ylabel('clone rate')
            pdf.savefig(clone_roc_fig, bbox_inches="tight")
            plt.close(clone_roc_fig)

            # histograms of the track parameter estimates, split into matched,
            # clone and fake tracks, for a few quality indicator requirements
            params = ['d0', 'z0', 'pt', 'tan_lambda', 'phi0']
            label_by_param = {
                # ...
                "tan_lambda": r"$\tan{\lambda}$",
                # ...
            }
            unit_by_param = {
                # ...
                "tan_lambda": "rad",
                # ...
            }
            n_kinematic_bins = 75
            bins_by_param = {
                "pt": np.linspace(0, np.percentile(pr_df['pt_truth'].dropna(), 95), n_kinematic_bins),
                "z0": np.linspace(-0.1, 0.1, n_kinematic_bins),
                "d0": np.linspace(0, 0.01, n_kinematic_bins),
                "tan_lambda": np.linspace(-2, 3, n_kinematic_bins),
                "phi0": np.linspace(0, 2 * np.pi, n_kinematic_bins)
            }

            kinematic_qi_cuts = [0, 0.5, 0.8]
            blue, yellow, green = plt.get_cmap("tab10").colors[0:3]
            for param in params:
                fig, axarr = plt.subplots(ncols=len(kinematic_qi_cuts), sharey=True, sharex=True, figsize=(14, 6))
                fig.suptitle(f"{label_by_param[param]} distributions")
                for i, qi in enumerate(kinematic_qi_cuts):
                    ax = axarr[i]
                    ax.set_title(f"QI > {qi}")
                    incut = pr_df[(pr_df['quality_indicator'] > qi)]
                    incut_matched = incut[incut.is_matched.eq(True)]
                    incut_clones = incut[incut.is_clone.eq(True)]
                    incut_fake = incut[incut.is_fake.eq(True)]

                    # skip the panel if any of the track categories is empty
                    if any(series.empty for series in (incut, incut_matched, incut_clones, incut_fake)):
                        ax.text(0.5, 0.5, "Not enough data in bin", ha="center", va="center", transform=ax.transAxes)
                        continue

                    bins = bins_by_param[param]
                    stacked_histogram_series_tuple = (
                        incut_matched[f'{param}_estimate'],
                        incut_clones[f'{param}_estimate'],
                        incut_fake[f'{param}_estimate'],
                    )
                    histvals, _, _ = ax.hist(stacked_histogram_series_tuple,
                                             # ...
                                             bins=bins, range=(bins.min(), bins.max()),
                                             color=(blue, green, yellow),
                                             label=("matched", "clones", "fakes"))
                    ax.set_xlabel(f'{label_by_param[param]} estimate / ({unit_by_param[param]})')
                    ax.set_ylabel('# tracks')
                axarr[0].legend(loc="upper center", bbox_to_anchor=(0, -0.15))
                pdf.savefig(fig, bbox_inches="tight")
                plt.close(fig)
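

# Illustrative sketch, not part of the original steering file: a plain-Python
# version of the fake and clone rate figures of merit computed above. It assumes
# that ``get_uncertain_means_for_qi_cuts`` returns, for every QI requirement, the
# mean of a boolean column over the tracks with ``quality_indicator > cut``
# together with its uncertainty. The binomial standard error used here and the
# helper name ``rates_for_qi_cuts`` are assumptions made for this example.
import math


def rates_for_qi_cuts(quality_indicators, flags, qi_cuts):
    """Return a list of (rate, standard error) pairs, one per QI cut."""
    results = []
    for cut in qi_cuts:
        # flags (e.g. ``is_fake``) of the tracks passing the QI requirement
        selected = [flag for qi, flag in zip(quality_indicators, flags) if qi > cut]
        n_selected = len(selected)
        if n_selected == 0:
            # no track passes this cut, so the rate is undefined
            results.append((math.nan, math.nan))
            continue
        rate = sum(selected) / n_selected
        # binomial standard error of a fraction estimated from n_selected tracks
        results.append((rate, math.sqrt(rate * (1.0 - rate) / n_selected)))
    return results


# Example: rates_for_qi_cuts([0.1, 0.3, 0.7, 0.9], [True, True, False, False], [0.0, 0.5])
# gives a fake rate of 0.5 without a requirement and 0.0 for QI > 0.5.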


class VXDQEValidationPlotsTask(PlotsFromHarvestingValidationBaseTask):
    """
    Create a PDF file with validation plots for the VXDTF2 track quality
    estimator from the ROOT ntuples produced by a VXDTF2 track QE
    harvesting validation task
    """

    @property
    def harvesting_validation_task_instance(self):
        """
        Harvesting validation task to require, which produces the ROOT files
        with variables to produce the VXD QE validation plots.
        """
        return VXDQEHarvestingValidationTask(
            n_events_testing=self.n_events_testing,
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            exclude_variables=self.exclude_variables,
            num_processes=MasterTask.num_processes,
            fast_bdt_option=self.fast_bdt_option,
        )
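

# Illustrative sketch, not part of the original steering file: the finding
# efficiency computation of the plotting task above, written with plain pandas
# instead of the ``upd`` wrapper with uncertainties, so only nominal values are
# returned. It assumes an MC table with a boolean ``is_missing`` column and a
# pattern-recognition table with a ``quality_indicator``, both keyed by the
# identifier columns in ``keys``. The function name is hypothetical. The left
# merge keeps MC tracks without a pattern-recognition track; their quality
# indicator becomes NaN.
import pandas as pd


def finding_efficiency_for_qi_cuts(mc_df, pr_df, qi_cuts, keys):
    """Return a Series of finding efficiencies indexed by the QI cuts."""
    merged = pd.merge(left=mc_df, right=pr_df[keys + ["quality_indicator"]], how="left", on=keys)
    efficiencies = []
    for cut in qi_cuts:
        # same selection as in the plotting task: MC tracks without a QI
        # (not found) or with a QI above the requirement
        selected = merged[merged["quality_indicator"].isnull() | (merged["quality_indicator"] > cut)]
        efficiencies.append(1.0 - selected["is_missing"].mean())
    return pd.Series(efficiencies, index=qi_cuts)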


class CDCQEValidationPlotsTask(PlotsFromHarvestingValidationBaseTask):
    """
    Create a PDF file with validation plots for the CDC track quality estimator
    from the ROOT ntuples produced by a CDC track QE harvesting validation task
    """

    training_target = b2luigi.Parameter()

    @property
    def harvesting_validation_task_instance(self):
        """
        Harvesting validation task to require, which produces the ROOT files
        with variables to produce the CDC QE validation plots.
        """
        return CDCQEHarvestingValidationTask(
            n_events_testing=self.n_events_testing,
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            training_target=self.training_target,
            exclude_variables=self.exclude_variables,
            num_processes=MasterTask.num_processes,
            fast_bdt_option=self.fast_bdt_option,
        )
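

# Illustrative sketch, not part of the original steering file: the bin contents
# behind the stacked "matched", "clones" and "fakes" histograms of the plotting
# task, computed with ``numpy.histogram`` instead of being drawn with ``ax.hist``.
# In a stacked histogram the top edge of each category is the cumulative sum of
# its counts and the counts of all categories below it. The function name
# ``stacked_counts`` is hypothetical.
import numpy as np


def stacked_counts(values_by_category, bins):
    """Return an array whose row k is the top edge of category k in the stack."""
    counts = np.array([np.histogram(values, bins=bins)[0] for values in values_by_category])
    return np.cumsum(counts, axis=0)


# Example: stacked_counts(([0.1, 0.2, 0.6], [0.15], [0.7, 0.8]), bins=[0.0, 0.5, 1.0])
# stacks matched, clone and fake tracks in two bins.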


class RecoTrackQEValidationPlotsTask(PlotsFromHarvestingValidationBaseTask):
    """
    Create a PDF file with validation plots for the reco MVA track quality
    estimator from the ROOT ntuples produced by a reco track QE
    harvesting validation task
    """

    cdc_training_target = b2luigi.Parameter()

    @property
    def harvesting_validation_task_instance(self):
        """
        Harvesting validation task to require, which produces the ROOT files
        with variables to produce the final MVA track QE validation plots.
        """
        return RecoTrackQEHarvestingValidationTask(
            n_events_testing=self.n_events_testing,
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            cdc_training_target=self.cdc_training_target,
            exclude_variables=self.exclude_variables,
            num_processes=MasterTask.num_processes,
            fast_bdt_option=self.fast_bdt_option,
        )
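

# Illustrative sketch, not part of the original steering file: the design of the
# three ``*ValidationPlotsTask`` classes above, reduced to plain Python without
# b2luigi. The base class only declares ``harvesting_validation_task_instance``
# and raises ``NotImplementedError``. Each subclass overrides that property, and
# the PDF name is derived from the ROOT file name of the harvesting task.
# ``HarvestingTaskStub``, ``PlotsTaskSketch``, ``VXDPlotsTaskSketch`` and the
# file name are hypothetical stand-ins.
class HarvestingTaskStub:
    def __init__(self, validation_output_file_name):
        self.validation_output_file_name = validation_output_file_name


class PlotsTaskSketch:
    @property
    def harvesting_validation_task_instance(self):
        raise NotImplementedError("Must define a QI harvesting validation task for which to do the plots")

    @property
    def output_pdf_file_basename(self):
        basename = self.harvesting_validation_task_instance.validation_output_file_name
        return basename.replace(".root", "_plots.pdf")


class VXDPlotsTaskSketch(PlotsTaskSketch):
    @property
    def harvesting_validation_task_instance(self):
        return HarvestingTaskStub("vxd_qe_harvesting_validation.root")


# Using properties lets ``requires`` and ``process`` access the related task like
# an attribute, while the base class stays unusable until a subclass fills it in.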


class QEWeightsLocalDBCreatorTask(Basf2Task):
    """
    Collect weightfile identifiers from different teacher tasks and merge them
    into a local database for testing.
    """

    n_events_training = b2luigi.IntParameter()
    experiment_number = b2luigi.IntParameter()
    process_type = b2luigi.Parameter(
        # ...
    )
    cdc_training_target = b2luigi.Parameter()
    fast_bdt_option = b2luigi.ListParameter(
        hashed=True, default=[200, 8, 3, 0.1]
    )

    def requires(self):
        """
        Required teacher tasks
        """
        yield VXDQETeacherTask(
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            exclude_variables=MasterTask.exclude_variables_vxd,
            fast_bdt_option=self.fast_bdt_option,
        )
        yield CDCQETeacherTask(
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            training_target=self.cdc_training_target,
            exclude_variables=MasterTask.exclude_variables_cdc,
            fast_bdt_option=self.fast_bdt_option,
        )
        yield RecoTrackQETeacherTask(
            n_events_training=self.n_events_training,
            process_type=self.process_type,
            experiment_number=self.experiment_number,
            cdc_training_target=self.cdc_training_target,
            exclude_variables=MasterTask.exclude_variables_rec,
            fast_bdt_option=self.fast_bdt_option,
        )

    def output(self):
        yield self.add_to_output("localdb.tar")

    def process(self):
        """
        Create local database
        """
        current_path = Path.cwd()
        localdb_archive_path = Path(self.get_output_file_name("localdb.tar")).absolute()
        output_dir = localdb_archive_path.parent

        for task in (VXDQETeacherTask, CDCQETeacherTask, RecoTrackQETeacherTask):
            weightfile_xml_identifier_path = os.path.abspath(self.get_input_file_names(
                task.get_weightfile_xml_identifier(task, fast_bdt_option=self.fast_bdt_option))[0])
            # the local database is written relative to the working directory,
            # so switch to the output directory and always switch back afterwards
            os.chdir(output_dir)
            try:
                # valid for all runs of this experiment (run -1 means open end)
                basf2_mva.upload(
                    weightfile_xml_identifier_path,
                    task.weightfile_identifier_basename,
                    self.experiment_number, 0,
                    self.experiment_number, -1,
                )
            finally:
                os.chdir(current_path)

        # pack the local database into the tar archive that is the task output
        shutil.make_archive(
            # strip only the final ".tar" suffix; split('.')[0] would also cut at dots in directory names
            base_name=localdb_archive_path.with_suffix("").as_posix(),
            format="tar",
            root_dir=output_dir,
            # ...
        )

    def remove_output(self):
        """
        Remove local database and tar archives in output directory
        """
        localdb_archive_path = Path(self.get_output_file_name("localdb.tar"))
        localdb_path = localdb_archive_path.parent / "localdb"

        if localdb_path.exists():
            print(f"Deleting localdb\n{localdb_path}\nwith contents\n ",
                  "\n ".join(f.name for f in localdb_path.iterdir()))
            shutil.rmtree(localdb_path, ignore_errors=False)

        if localdb_archive_path.is_file():
            print(f"Deleting {localdb_archive_path}")
            os.remove(localdb_archive_path)

    def on_failure(self, exception):
        """
        Cleanup: Remove local database to prevent existing outputs when task did not finish successfully
        """
        self.remove_output()
        super().on_failure(exception)
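

# Illustrative sketch, not part of the original steering file: the archiving step
# of ``QEWeightsLocalDBCreatorTask.process`` in isolation. ``shutil.make_archive``
# appends the extension of the chosen format itself, so ``base_name`` has to be the
# target path without ``.tar``. ``Path.with_suffix("")`` removes only the final
# suffix, whereas ``str.split('.')[0]`` would also cut the path at a dot in any
# parent directory name. The function name and the assumption that the database
# lives in a ``localdb`` directory next to the archive (as the cleanup code above
# expects) are made for this example.
import shutil
from pathlib import Path


def archive_localdb(localdb_archive_path):
    """Pack ``<output dir>/localdb`` into ``localdb_archive_path`` and return the archive path."""
    localdb_archive_path = Path(localdb_archive_path).absolute()
    output_dir = localdb_archive_path.parent
    created = shutil.make_archive(
        base_name=localdb_archive_path.with_suffix("").as_posix(),
        format="tar",
        root_dir=output_dir,
        # archive only the database directory, not other files in output_dir
        base_dir="localdb",
    )
    return Path(created)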


class MasterTask(b2luigi.WrapperTask):
    """
    Wrapper task that needs to finish for b2luigi to finish running this steering file.

    It is done if the outputs of all required subtasks exist. It is thus at the
    top of the luigi task graph. Edit the ``requires`` method to steer which
    tasks and with which parameters you want to run.
    """

    process_type = b2luigi.get_setting(
        "process_type", default='BBBAR'
    )
    n_events_training = b2luigi.get_setting(
        "n_events_training", default=20000
    )
    n_events_testing = b2luigi.get_setting(
        "n_events_testing", default=5000
    )
    n_events_per_task = b2luigi.get_setting(
        "n_events_per_task", default=100
    )
    num_processes = b2luigi.get_setting(
        "basf2_processes_per_worker", default=0
    )
    datafiles = b2luigi.get_setting("datafiles")
    bkgfiles_by_exp = b2luigi.get_setting("bkgfiles_by_exp")
    # convert the keys (strings in the settings) to integer experiment numbers
    bkgfiles_by_exp = {int(key): val for (key, val) in bkgfiles_by_exp.items()}
    exclude_variables_cdc = [
        "has_matching_segment",
        # ...
        "cont_layer_variance",
        # ...
        "cont_layer_max_vs_last",
        "cont_layer_first_vs_min",
        # ...
        "cont_layer_occupancy",
        # ...
        "super_layer_variance",
        "super_layer_max_vs_last",
        "super_layer_first_vs_min",
        "super_layer_occupancy",
        "drift_length_mean",
        "drift_length_variance",
        # ...
        "norm_drift_length_mean",
        "norm_drift_length_variance",
        "norm_drift_length_max",
        "norm_drift_length_min",
        "norm_drift_length_sum",
        # ...
    ]
    exclude_variables_vxd = [
        'energyLoss_max', 'energyLoss_min', 'energyLoss_mean', 'energyLoss_std', 'energyLoss_sum',
        'size_max', 'size_min', 'size_mean', 'size_std', 'size_sum',
        'seedCharge_max', 'seedCharge_min', 'seedCharge_mean', 'seedCharge_std', 'seedCharge_sum',
        'tripletFit_P_Mag', 'tripletFit_P_Eta', 'tripletFit_P_Phi', 'tripletFit_P_X', 'tripletFit_P_Y', 'tripletFit_P_Z']

    exclude_variables_rec = [
        # ...
        'N_diff_PXD_SVD_RecoTracks',
        'N_diff_SVD_CDC_RecoTracks',
        # ...
        'Fit_NFailedPoints',
        # ...
        'N_TrackPoints_without_KalmanFitterInfo',
        'N_Hits_without_TrackPoint',
        'SVD_CDC_CDCwall_Chi2',
        'SVD_CDC_CDCwall_Pos_diff_Z',
        'SVD_CDC_CDCwall_Pos_diff_Pt',
        'SVD_CDC_CDCwall_Pos_diff_Theta',
        'SVD_CDC_CDCwall_Pos_diff_Phi',
        'SVD_CDC_CDCwall_Pos_diff_Mag',
        'SVD_CDC_CDCwall_Pos_diff_Eta',
        'SVD_CDC_CDCwall_Mom_diff_Z',
        'SVD_CDC_CDCwall_Mom_diff_Pt',
        'SVD_CDC_CDCwall_Mom_diff_Theta',
        'SVD_CDC_CDCwall_Mom_diff_Phi',
        'SVD_CDC_CDCwall_Mom_diff_Mag',
        'SVD_CDC_CDCwall_Mom_diff_Eta',
        'SVD_CDC_POCA_Pos_diff_Z',
        'SVD_CDC_POCA_Pos_diff_Pt',
        'SVD_CDC_POCA_Pos_diff_Theta',
        'SVD_CDC_POCA_Pos_diff_Phi',
        'SVD_CDC_POCA_Pos_diff_Mag',
        'SVD_CDC_POCA_Pos_diff_Eta',
        'SVD_CDC_POCA_Mom_diff_Z',
        'SVD_CDC_POCA_Mom_diff_Pt',
        'SVD_CDC_POCA_Mom_diff_Theta',
        'SVD_CDC_POCA_Mom_diff_Phi',
        'SVD_CDC_POCA_Mom_diff_Mag',
        'SVD_CDC_POCA_Mom_diff_Eta',
        # ...
        'SVD_FitSuccessful',
        'CDC_FitSuccessful',
        # ...
        'is_Vzero_Daughter',
        # ...
        'weight_firstCDCHit',
        'weight_lastSVDHit',
        # ...
        'smoothedChi2_mean',
        # ...
        'smoothedChi2_median',
        'smoothedChi2_n_zeros',
        'smoothedChi2_firstCDCHit',
        'smoothedChi2_lastSVDHit',
    ] + \
        ["SVD_" + x for x in exclude_variables_vxd] + \
        ["SVDbefore_" + x for x in exclude_variables_vxd]

    def requires(self):
        """
        Generate list of tasks that need to be done for luigi to finish running
        this steering file.
        """
        cdc_training_targets = [
            # ...
        ]
        fast_bdt_options = []
        # ...
        fast_bdt_options.append([350, 6, 5, 0.1])

        experiment_numbers = b2luigi.get_setting("experiment_numbers")

        # one set of tasks per combination of experiment number, CDC training target and FastBDT option
        for experiment_number, cdc_training_target, fast_bdt_option in itertools.product(
                experiment_numbers, cdc_training_targets, fast_bdt_options
        ):
            if b2luigi.get_setting("test_selected_task", default=False):
                for cut in ['000', '070', '090', '095']:
                    yield RecoTrackQEDataCollectionTask(
                        num_processes=self.num_processes,
                        n_events=self.n_events_testing,
                        experiment_number=experiment_number,
                        random_seed=self.process_type + '_test',
                        recotrack_option='useCDC_useVXD_deleteCDCQI' + cut,
                        cdc_training_target=cdc_training_target,
                        fast_bdt_option=fast_bdt_option,
                    )
                yield CDCQEDataCollectionTask(
                    num_processes=self.num_processes,
                    n_events=self.n_events_testing,
                    experiment_number=experiment_number,
                    random_seed=self.process_type + '_test',
                )
                yield CDCQETeacherTask(
                    n_events_training=self.n_events_training,
                    process_type=self.process_type,
                    experiment_number=experiment_number,
                    exclude_variables=self.exclude_variables_cdc,
                    training_target=cdc_training_target,
                    fast_bdt_option=fast_bdt_option,
                )
            else:
                # for data, only the data collection tasks are run
                if 'DATA' in self.process_type:
                    yield VXDQEDataCollectionTask(
                        num_processes=self.num_processes,
                        n_events=self.n_events_testing,
                        experiment_number=experiment_number,
                        random_seed=self.process_type + '_test',
                    )
                    yield CDCQEDataCollectionTask(
                        num_processes=self.num_processes,
                        n_events=self.n_events_testing,
                        experiment_number=experiment_number,
                        random_seed=self.process_type + '_test',
                    )
                    yield RecoTrackQEDataCollectionTask(
                        num_processes=self.num_processes,
                        n_events=self.n_events_testing,
                        experiment_number=experiment_number,
                        random_seed=self.process_type + '_test',
                        recotrack_option='deleteCDCQI080',
                        cdc_training_target=cdc_training_target,
                        fast_bdt_option=fast_bdt_option,
                    )
                else:
                    yield QEWeightsLocalDBCreatorTask(
                        n_events_training=self.n_events_training,
                        process_type=self.process_type,
                        experiment_number=experiment_number,
                        cdc_training_target=cdc_training_target,
                        fast_bdt_option=fast_bdt_option,
                    )

                    if b2luigi.get_setting("run_validation_tasks", default=True):
                        yield RecoTrackQEValidationPlotsTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            experiment_number=experiment_number,
                            cdc_training_target=cdc_training_target,
                            exclude_variables=self.exclude_variables_rec,
                            fast_bdt_option=fast_bdt_option,
                        )
                        yield CDCQEValidationPlotsTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            experiment_number=experiment_number,
                            exclude_variables=self.exclude_variables_cdc,
                            training_target=cdc_training_target,
                            fast_bdt_option=fast_bdt_option,
                        )
                        yield VXDQEValidationPlotsTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            exclude_variables=self.exclude_variables_vxd,
                            experiment_number=experiment_number,
                            fast_bdt_option=fast_bdt_option,
                        )

                    if b2luigi.get_setting("run_mva_evaluate", default=True):
                        # evaluate the trained weight files on the separate testing data sets
                        yield RecoTrackQEEvaluationTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            experiment_number=experiment_number,
                            cdc_training_target=cdc_training_target,
                            exclude_variables=self.exclude_variables_rec,
                            fast_bdt_option=fast_bdt_option,
                        )
                        yield CDCTrackQEEvaluationTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            experiment_number=experiment_number,
                            exclude_variables=self.exclude_variables_cdc,
                            fast_bdt_option=fast_bdt_option,
                            training_target=cdc_training_target,
                        )
                        yield VXDTrackQEEvaluationTask(
                            n_events_training=self.n_events_training,
                            n_events_testing=self.n_events_testing,
                            process_type=self.process_type,
                            experiment_number=experiment_number,
                            exclude_variables=self.exclude_variables_vxd,
                            fast_bdt_option=fast_bdt_option,
                        )
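

# Illustrative sketch, not part of the original steering file: two steps of
# ``MasterTask``, with a hypothetical JSON settings string in place of
# ``b2luigi.get_setting``. JSON object keys are always strings, which is
# presumably why the keys of ``bkgfiles_by_exp`` are converted to ``int`` before
# they are looked up with an integer ``experiment_number``. ``itertools.product``
# then yields one parameter set per combination of experiment number, CDC
# training target and FastBDT option. The function name, the setting values and
# the "bkg_dir" key are assumptions made for this example.
import itertools
import json


def task_parameter_grid(settings_json, cdc_training_targets, fast_bdt_options):
    """Return one parameter dictionary per task combination."""
    settings = json.loads(settings_json)
    # JSON keys are strings, convert them to experiment numbers
    bkgfiles_by_exp = {int(key): val for (key, val) in settings["bkgfiles_by_exp"].items()}
    grid = []
    for experiment_number, cdc_training_target, fast_bdt_option in itertools.product(
            settings["experiment_numbers"], cdc_training_targets, fast_bdt_options):
        grid.append({
            "experiment_number": experiment_number,
            "cdc_training_target": cdc_training_target,
            "fast_bdt_option": fast_bdt_option,
            "bkg_dir": bkgfiles_by_exp[experiment_number],
        })
    return grid


# With two experiment numbers, one training target and two FastBDT options,
# the grid has 2 * 1 * 2 = 4 entries.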


if __name__ == "__main__":
    # optionally limit the number of processed events when testing on data
    nEventsTestOnData = b2luigi.get_setting("n_events_test_on_data", default=-1)
    if nEventsTestOnData > 0 and 'DATA' in b2luigi.get_setting("process_type", default="BBBAR"):
        from ROOT import Belle2
        environment = Belle2.Environment.Instance()
        environment.setNumberEventsOverride(nEventsTestOnData)

    # replace the default conditions by the global tags from the settings, if any are given
    globaltags = b2luigi.get_setting("globaltags", default=[])
    if len(globaltags) > 0:
        basf2.conditions.reset()
        for gt in globaltags:
            basf2.conditions.prepend_globaltag(gt)
    workers = b2luigi.get_setting("workers", default=1)
    b2luigi.process(MasterTask(), workers=workers)
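

# Illustrative sketch, not part of the original steering file: how
# ``MasterTask.exclude_variables_rec`` extends its own list with the VXD
# exclusions under the prefixes ``SVD_`` and ``SVDbefore_``. The function name
# ``prefixed_exclusions`` is hypothetical. The order of the result matches the
# list concatenation in ``MasterTask``: first all ``SVD_`` names, then all
# ``SVDbefore_`` names.
def prefixed_exclusions(rec_exclusions, vxd_exclusions, prefixes=("SVD_", "SVDbefore_")):
    """Return the reco exclusions followed by every VXD exclusion once per prefix."""
    return list(rec_exclusions) + [prefix + name for prefix in prefixes for name in vxd_exclusions]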