Belle II Software development
record.py
1#!/usr/bin/env python3
2
3
10
11
12from ROOT import Belle2 # noqa: make Belle2 namespace available # noqa
13from ROOT.Belle2 import TrackFindingCDC as TFCDC
14
15import sys
16import random
17import numpy as np
18
19from tracking.validation.utilities import is_primary
20
21import tracking.harvest.harvesting as harvesting
22import tracking.harvest.refiners as refiners
23from tracking.harvest.run import HarvestingRun
24
25
26import logging
27
28
def get_logger():
    """Return the logger registered under this module's name."""
    module_name = __name__
    return logging.getLogger(module_name)
31
32
## E-mail address of the contact person for this script
CONTACT = 'oliver.frost@desy.de'
34
35
# NOTE(review): source line 36 (the class statement owning the docstring and members below)
# is missing from this listing; the class name cannot be recovered from the visible text.
37 """Harvester to generate, postprocess and inspect MC events for track-segment evaluation"""
38
# Number of events for the run — presumably consumed by the HarvestingRun base class; TODO confirm
39 n_events = 10000
40
# Generator setting for the run — presumably interpreted by the base class; TODO confirm
41 generator_module = "generic"
42
43 @property
# NOTE(review): source line 44 (presumably the `def` line of this property) is missing from
# this listing, so the property's name is not visible here.
45 """Get the output ROOT filename"""
# Fixed file name for the harvested output
46 return 'legendre_binning.root'
47
48 def harvesting_module(self, path=None):
49 """Harvest and post-process the generated events"""
# NOTE(review): source line 50 is missing from this listing; it presumably binds the local
# `harvesting_module` that is added to the path and returned below.
# Add the module to the given path (if any) and return it in either case.
51 if path:
52 path.add_module(harvesting_module)
53 return harvesting_module
54
def create_argument_parser(self, **kwds):
    """Create the command-line argument parser for this run.

    No options are added here; the parser built by the base class is returned unchanged.

    Args:
        **kwds: Forwarded verbatim to the base-class ``create_argument_parser``.

    Returns:
        The argument parser created by the base class.
    """
    return super().create_argument_parser(**kwds)
59
def create_path(self):
    """
    Extend the base-class path (which plays back pregenerated events or generates events
    based on the properties in the base class) with CDC wire-hit preparation and
    MC-truth axial track creation.
    """
    path = super().create_path()

    # NOTE(review): both modules use logLevel=8; the meaning of that value is not visible here — confirm.
    wire_hit_preparer_params = {
        "logLevel": 8,
        "flightTimeEstimation": "outwards",
        "UseNLoops": 1,
    }
    path.add_module("TFCDC_WireHitPreparer", **wire_hit_preparer_params)

    mc_track_creator_params = {
        "logLevel": 8,
        "useOnlyBeforeTOP": True,
        "fit": True,
        "reconstructedDriftLength": True,
        "reconstructedPositions": True,
    }
    path.add_module('TFCDC_AxialTrackCreatorMCTruth', **mc_track_creator_params)

    return path
80
81
class LegendreBinningValidationModule(harvesting.HarvestingModule):

    """Module to collect information about the CDC tracks (``CDCTrackVector``) and
    compose validation plots on terminate."""

    def __init__(self, output_file_name):
        """Constructor

        Reads the curvature bin bounds from ``fine_curv_bounds.txt`` (opened relative to the
        current working directory) and sets up an origin-constrained Riemann fitter.

        Args:
            output_file_name: Name of the ROOT file the harvested variables are written to.

        Raises:
            ValueError: If the bounds file is empty or holds an odd number of values, i.e.
                it cannot be split into (lower, upper) pairs.
        """
        super().__init__(foreach='CDCTrackVector',
                         output_file_name=output_file_name)

        ## by default, there is no method to find matching MC tracks
        self.mc_track_lookup = None

        ## by default, there is no method to find matching MC hits
        self.mc_hit_lookup = None

        origin_track_fitter = TFCDC.CDCRiemannFitter()
        origin_track_fitter.setOriginConstrained()

        ## Use the CDCRiemannFitter with a constrained origin for track fitting
        self.track_fitter = origin_track_fitter

        # One bound per line. Blank lines are skipped instead of failing inside float().
        with open('fine_curv_bounds.txt') as curv_bounds_file:
            curv_bounds = [float(line) for line in curv_bounds_file if line.strip()]

        # The values form consecutive (lower, upper) pairs. zip() below would silently drop an
        # unpaired trailing value, and an empty binning makes peel() fail, so reject both here.
        if not curv_bounds or len(curv_bounds) % 2 != 0:
            raise ValueError(f"fine_curv_bounds.txt must contain a non-zero, even number of "
                             f"curvature bounds, got {len(curv_bounds)}")

        # Sort the bins by their lower bound; np.digitize in peel() needs monotonic bins.
        bin_bounds = sorted(zip(curv_bounds[0::2], curv_bounds[1::2]))

        ## cached copy of the lower curvature bin bounds (ascending)
        self.lower_curv_bounds = np.array([lower for lower, _ in bin_bounds])

        ## cached copy of the upper curvature bin bounds, aligned with lower_curv_bounds
        self.upper_curv_bounds = np.array([upper for _, upper in bin_bounds])

    def initialize(self):
        """Receive signal at the start of event processing"""
        super().initialize()

        ## Method to find matching MC tracks
        self.mc_track_lookup = TFCDC.CDCMCTrackLookUp.getInstance()

        ## Method to find matching MC hits
        self.mc_hit_lookup = TFCDC.CDCMCHitLookUp.getInstance()

    def prepare(self):
        """Fill the singleton MC-hit lookup (``CDCMCHitLookUp``)"""
        TFCDC.CDCMCHitLookUp.getInstance().fill()

    def pick(self, track):
        """Select tracks with more than three hits that are matched to a primary MC particle"""
        mc_particle = self.mc_track_lookup.getMCParticle(track)

        # Check that mc_particle is not a nullptr
        return mc_particle and is_primary(mc_particle) and track.size() > 3

    def peel(self, track):
        """Aggregate the track and MC information for the curvature-binning analysis

        The track is fitted with the origin-constrained fitter. Curvature estimates, their
        spread, and the width of the curvature bin around them are then derived.

        Args:
            track: CDC track to analyse; iterated as a sequence of 3D reco hits.

        Returns:
            dict: One entry per harvested variable.
        """
        rl_drift_circle = 1
        unit_variance = 0
        observations2D = TFCDC.CDCObservations2D(rl_drift_circle, unit_variance)

        for recoHit3D in track:
            observations2D.append(recoHit3D)

        trajectory2D = self.track_fitter.fit(observations2D)
        trajectory2D.setLocalOrigin(TFCDC.Vector2D(0, 0))

        n_hits = track.size()
        pt = trajectory2D.getAbsMom2D()
        curv = trajectory2D.getCurvature()
        # Curvatures whose magnitude reaches the first lower bin bound are folded onto |curv|.
        curl_curv = abs(self.lower_curv_bounds[0])
        bin_curv = curv if abs(curv) < curl_curv else abs(curv)
        curv_var = trajectory2D.getLocalVariance(0)
        impact = trajectory2D.getGlobalImpact()

        circle = trajectory2D.getLocalCircle()
        phi0 = circle.phi0()
        n12 = circle.n12()

        # Per-hit curvature estimate built from the wire reference position, the signed
        # drift length and the n12 vector of the fitted circle.
        cross_curvs = []
        for recoHit3D in track:
            wire_ref_pos = recoHit3D.getRefPos2D()
            drift_length = recoHit3D.getSignedRecoDriftLength()
            r = wire_ref_pos.norm()
            cross_curvs.append(-2 * (n12.dot(wire_ref_pos) - drift_length) /
                               (r * r - drift_length * drift_length))

        # Median and median absolute deviation as robust location / spread estimates.
        cross_curvs = np.array(cross_curvs)
        cross_curv = np.median(cross_curvs)
        cross_curv_var = np.median(np.abs(cross_curvs - cross_curv))

        basic_curv_precision = TFCDC.PrecisionUtil.getBasicCurvPrecision(cross_curv)
        origin_curv_precision = TFCDC.PrecisionUtil.getOriginCurvPrecision(cross_curv)
        non_origin_curv_precision = TFCDC.PrecisionUtil.getNonOriginCurvPrecision(cross_curv)

        # Width of the bin whose lower bound lies just below |cross_curv|. Digitizing a scalar
        # (rather than a one-element list) yields a scalar index, so max_curv_precision is a
        # plain number in every branch instead of sometimes being a length-1 array.
        bin_id = int(np.digitize(abs(cross_curv), self.lower_curv_bounds))
        if bin_id == 0:
            # |cross_curv| lies below the first lower bound: use a fixed fallback width.
            max_curv_precision = 0.00007
        else:
            max_curv_precision = self.upper_curv_bounds[bin_id - 1] - self.lower_curv_bounds[bin_id - 1]

        # Draw a curvature uniformly from a randomly chosen bin, together with that bin's width.
        random_bin_id = random.randrange(len(self.upper_curv_bounds))
        random_lower_curv_bound = self.lower_curv_bounds[random_bin_id]
        random_upper_curv_bound = self.upper_curv_bounds[random_bin_id]
        curv_dense = random.uniform(random_lower_curv_bound, random_upper_curv_bound)
        curv_width = random_upper_curv_bound - random_lower_curv_bound

        # A fit with exactly zero curvature would otherwise raise ZeroDivisionError.
        inv_curv = 1.0 / abs(curv) if curv != 0 else float('inf')

        return dict(
            n_hits=n_hits,
            curvature_estimate=curv,
            curvature_variance=curv_var,
            abs_curvature_estimate=abs(curv),
            inv_curv=inv_curv,
            cross_curv=cross_curv,
            cross_curv_var=cross_curv_var,
            basic_curv_precision=basic_curv_precision,
            origin_curv_precision=origin_curv_precision,
            non_origin_curv_precision=non_origin_curv_precision,
            max_curv_precision=max_curv_precision,
            pt=pt,
            curv_bin=bin_curv,
            curv_dense=curv_dense,
            curv_width=curv_width,
            impact_estimate=impact,
            phi0=phi0,
        )

    # Refiners to be executed at the end of the harvesting / termination of the module

    ## Save a tree of all collected variables
    save_tree = refiners.save_tree()

    ## Save histograms of the collected variables
    save_histograms = refiners.save_histograms(outlier_z_score=5.0, allow_discrete=True)

    ## Save profiles of the curvature variance versus curvature and phi0
    save_profiles = refiners.save_profiles(x=['curvature_estimate', 'phi0'],
                                           y='curvature_variance',
                                           outlier_z_score=5.0)

    ## Save profiles of the spread and precision estimates versus the cross curvature
    save_cross_curv_profile = refiners.save_profiles(x=['cross_curv'],
                                                     y=['cross_curv_var',
                                                        'curvature_estimate',
                                                        'basic_curv_precision',
                                                        'origin_curv_precision',
                                                        'non_origin_curv_precision',
                                                        'max_curv_precision',
                                                        ],
                                                     outlier_z_score=5.0)

    ## Save a scatter plot of the number of hits versus the curvature estimate
    save_scatter = refiners.save_scatters(x=['curvature_estimate'], y='n_hits')
233
234
235def main():
# Entry point: configure the harvesting run from the command line and execute it.
# NOTE(review): source line 236 is missing from this listing; it presumably defines the
# `run` object used below, which is not defined anywhere in the visible code.
237 run.configure_and_execute_from_commandline()
238
239
if __name__ == "__main__":
    # Log INFO and above to stdout as "LEVEL:message" before running.
    log_format = '%(levelname)s:%(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO, stream=sys.stdout)
    main()
mc_track_lookup
by default, there is no method to find matching MC tracks
Definition: record.py:93
mc_hit_lookup
Method to find matching MC hits.
Definition: record.py:121
track_fitter
Use the CDCRiemannFitter with a constrained origin for track fitting.
Definition: record.py:98
lower_curv_bounds
Cached copy of the lower curvature bin bounds.
Definition: record.py:109
upper_curv_bounds
Cached copy of the upper curvature bin bounds.
Definition: record.py:111
def __init__(self, output_file_name)
Definition: record.py:87
def create_argument_parser(self, **kwds)
Definition: record.py:55
def harvesting_module(self, path=None)
Definition: record.py:48
None output_file_name
Disable the writing of an output ROOT file.
Definition: run.py:20
None output_file_name
There is no default for the name of the output TFile.
Definition: mixins.py:60
Definition: main.py:1