Belle II Software  release-08-01-10
record.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 
13 from ROOT import Belle2 # make Belle2 namespace available # noqa
14 from ROOT.Belle2 import TrackFindingCDC as TFCDC
15 
16 import sys
17 import random
18 import numpy as np
19 
20 from tracking.validation.utilities import is_primary
21 
22 import tracking.harvest.harvesting as harvesting
23 import tracking.harvest.refiners as refiners
24 from tracking.harvest.run import HarvestingRun
25 
26 
27 import logging
28 
29 
def get_logger():
    """Return the logger named after this module."""
    return logging.getLogger(__name__)
32 
33 
# Contact address of the script's maintainer.
CONTACT = "oliver.frost@desy.de"
35 
36 
# NOTE(review): the ``class`` statement is missing from the listing this file
# was recovered from.  The HarvestingRun base is inferred from the import at
# the top of the file and from the super() calls below; the class name is
# inferred by analogy with LegendreBinningValidationModule -- confirm both.
class LegendreBinningValidationRun(HarvestingRun):
    """Harvester to generate, postprocess and inspect MC events for track-segment evaluation"""

    #: Number of events to process
    n_events = 10000

    #: Name of the event generator configuration to use
    generator_module = "generic"

    @property
    def output_file_name(self):
        """Get the output ROOT filename"""
        return 'legendre_binning.root'

    def harvesting_module(self, path=None):
        """Harvest and post-process the generated events

        Creates the harvesting module, appends it to ``path`` when a path is
        given, and returns the module in either case.
        """
        # NOTE(review): this assignment was missing from the listing, leaving
        # the local name undefined.  It is restored from the constructor
        # signature of LegendreBinningValidationModule -- confirm.
        harvesting_module = LegendreBinningValidationModule(self.output_file_name)
        if path:
            path.add_module(harvesting_module)
        return harvesting_module

    def create_argument_parser(self, **kwds):
        """Convert command-line arguments to basf2 argument list"""
        # No additional options are registered; the parser of the base class
        # is returned unchanged.
        argument_parser = super().create_argument_parser(**kwds)
        return argument_parser

    def create_path(self):
        """
        Sets up a path that plays back pregenerated events or generates events
        based on the properties in the base class.
        """
        path = super().create_path()

        path.add_module("TFCDC_WireHitPreparer",
                        logLevel=8,
                        flightTimeEstimation="outwards",
                        UseNLoops=1)

        path.add_module('TFCDC_AxialTrackCreatorMCTruth',
                        logLevel=8,
                        useOnlyBeforeTOP=True,
                        fit=True,
                        reconstructedDriftLength=True,
                        reconstructedPositions=True)

        return path
81 
82 
class LegendreBinningValidationModule(harvesting.HarvestingModule):

    """Module to collect information about the generated segments and
    compose validation plots on terminate."""

    def __init__(self, output_file_name):
        """Constructor

        Sets up the origin-constrained Riemann fitter and reads the curvature
        bin bounds from ``fine_curv_bounds.txt`` (resolved relative to the
        current working directory).  The file holds one number per line;
        consecutive values are paired as (lower bound, upper bound).

        Raises ValueError if the file does not contain a single complete pair.
        """
        super().__init__(foreach='CDCTrackVector',
                         output_file_name=output_file_name)

        #: by default, there is no method to find matching MC tracks
        self.mc_track_lookup = None

        #: Method to find matching MC hits; assigned in initialize()
        self.mc_hit_lookup = None

        origin_track_fitter = TFCDC.CDCRiemannFitter()
        origin_track_fitter.setOriginConstrained()

        #: Use the CDCRiemannFitter with a constrained origin for track fitting
        self.track_fitter = origin_track_fitter

        # Blank lines (for example a trailing newline at the end of the file)
        # are skipped instead of being handed to float().
        with open('fine_curv_bounds.txt') as curv_bounds_file:
            curv_bounds = [float(curv_bound_line)
                           for curv_bound_line in curv_bounds_file
                           if curv_bound_line.strip()]

        # Pair up (lower, upper) and order the bins by their lower bound.
        # zip() drops an unpaired trailing value.
        bin_bounds = sorted(zip(curv_bounds[0::2], curv_bounds[1::2]))
        if not bin_bounds:
            # peel() indexes the first bin and draws a random bin, so an empty
            # table is rejected here with a clear message.
            raise ValueError("fine_curv_bounds.txt contains no complete pair of curvature bounds")

        #: cached copy of lower bounds
        self.lower_curv_bounds = np.array([lower for lower, _ in bin_bounds])

        #: cached copy of upper bounds
        self.upper_curv_bounds = np.array([upper for _, upper in bin_bounds])

        assert len(self.lower_curv_bounds) == len(self.upper_curv_bounds)

    def initialize(self):
        """Receive signal at the start of event processing"""
        super().initialize()

        self.mc_track_lookup = TFCDC.CDCMCTrackLookUp.getInstance()

        self.mc_hit_lookup = TFCDC.CDCMCHitLookUp.getInstance()

    def prepare(self):
        """Initialize the MC-hit lookup method"""
        TFCDC.CDCMCHitLookUp.getInstance().fill()

    def pick(self, track):
        """Select tracks with more than 3 entries and an associated primary MC particle"""
        mc_track_lookup = self.mc_track_lookup
        mc_particle = mc_track_lookup.getMCParticle(track)

        # Check that mc_particle is not a nullptr
        return mc_particle and is_primary(mc_particle) and track.size() > 3

    def peel(self, track):
        """Aggregate the fitted and per-hit curvature information of one track

        Returns a dict with one entry per harvested quantity.
        """
        track_fitter = self.track_fitter

        rl_drift_circle = 1
        unit_variance = 0
        observations2D = TFCDC.CDCObservations2D(rl_drift_circle, unit_variance)

        for recoHit3D in track:
            observations2D.append(recoHit3D)

        trajectory2D = track_fitter.fit(observations2D)
        trajectory2D.setLocalOrigin(TFCDC.Vector2D(0, 0))

        n_hits = track.size()
        pt = trajectory2D.getAbsMom2D()
        curv = trajectory2D.getCurvature()
        # Curvatures whose magnitude reaches the lower bound of the first bin
        # are recorded unsigned, smaller ones keep their sign.
        curl_curv = abs(self.lower_curv_bounds[0])
        bin_curv = curv if abs(curv) < curl_curv else abs(curv)
        curv_var = trajectory2D.getLocalVariance(0)
        impact = trajectory2D.getGlobalImpact()
        phi0 = trajectory2D.getLocalCircle().phi0()

        circle = trajectory2D.getLocalCircle()
        n12 = circle.n12()

        # One curvature value per hit, computed from the fitted circle's n12
        # vector, the wire reference position and the signed drift length.
        cross_curvs = []
        for recoHit3D in track:
            wire_ref_pos = recoHit3D.getRefPos2D()
            drift_length = recoHit3D.getSignedRecoDriftLength()
            r = wire_ref_pos.norm()
            cross_curv = (-2 * (n12.dot(wire_ref_pos) - drift_length)
                          / (r * r - drift_length * drift_length))
            cross_curvs.append(cross_curv)

        # Median and median absolute deviation of the per-hit values.
        cross_curvs = np.array(cross_curvs)
        cross_curv = np.median(cross_curvs)
        cross_curv_var = np.median(np.abs(cross_curvs - cross_curv))

        basic_curv_precision = TFCDC.PrecisionUtil.getBasicCurvPrecision(cross_curv)
        origin_curv_precision = TFCDC.PrecisionUtil.getOriginCurvPrecision(cross_curv)
        non_origin_curv_precision = TFCDC.PrecisionUtil.getNonOriginCurvPrecision(cross_curv)

        # A scalar is passed to np.digitize so that the index, and with it
        # max_curv_precision, is a scalar instead of a one-element array.
        bin_id = int(np.digitize(abs(cross_curv), self.lower_curv_bounds))
        if bin_id == 0:
            # Value lies below the lower bound of the first bin.
            max_curv_precision = 0.00007
        else:
            # Width of the last bin whose lower bound does not exceed the value.
            max_curv_precision = (self.upper_curv_bounds[bin_id - 1]
                                  - self.lower_curv_bounds[bin_id - 1])

        # Draw a curvature uniformly inside a randomly chosen bin.
        random_bin_id = random.randrange(len(self.upper_curv_bounds))
        random_lower_curv_bound = self.lower_curv_bounds[random_bin_id]
        random_upper_curv_bound = self.upper_curv_bounds[random_bin_id]
        curv_dense = random.uniform(random_lower_curv_bound, random_upper_curv_bound)
        curv_width = random_upper_curv_bound - random_lower_curv_bound

        return dict(
            n_hits=n_hits,
            curvature_estimate=curv,
            curvature_variance=curv_var,
            abs_curvature_estimate=abs(curv),
            # NOTE(review): raises ZeroDivisionError if the fit returns a
            # curvature of exactly zero.
            inv_curv=1.0 / abs(curv),
            cross_curv=cross_curv,
            cross_curv_var=cross_curv_var,
            basic_curv_precision=basic_curv_precision,
            origin_curv_precision=origin_curv_precision,
            non_origin_curv_precision=non_origin_curv_precision,
            max_curv_precision=max_curv_precision,
            pt=pt,
            curv_bin=bin_curv,
            curv_dense=curv_dense,
            curv_width=curv_width,
            impact_estimate=impact,
            phi0=phi0,
        )

    # Refiners to be executed at the end of the harvesting / termination of the module

    #: Save a tree of all collected variables
    save_tree = refiners.save_tree()

    #: Save histograms of the collected variables
    save_histograms = refiners.save_histograms(outlier_z_score=5.0, allow_discrete=True)

    #: Save profiles of the curvature variance versus curvature estimate and phi0
    save_profiles = refiners.save_profiles(x=['curvature_estimate', 'phi0'],
                                           y='curvature_variance',
                                           outlier_z_score=5.0)

    #: Save profiles of the curvature and precision estimates versus cross_curv
    save_cross_curv_profile = refiners.save_profiles(x=['cross_curv'],
                                                     y=['cross_curv_var',
                                                        'curvature_estimate',
                                                        'basic_curv_precision',
                                                        'origin_curv_precision',
                                                        'non_origin_curv_precision',
                                                        'max_curv_precision',
                                                        ],
                                                     outlier_z_score=5.0)

    #: Save a scatter plot of the number of hits versus the curvature estimate
    save_scatter = refiners.save_scatters(x=['curvature_estimate'], y='n_hits')
234 
235 
def main():
    """Create the harvesting run and execute it as configured on the command line."""
    # NOTE(review): the statement creating ``run`` was missing from the listing
    # this file was recovered from, leaving the name undefined.  The class name
    # is inferred by analogy with LegendreBinningValidationModule -- confirm.
    run = LegendreBinningValidationRun()
    run.configure_and_execute_from_commandline()
239 
240 
if __name__ == "__main__":
    # Send log records of level INFO and above to stdout before the run starts.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(levelname)s:%(message)s')
    main()
mc_track_lookup
by default, there is no method to find matching MC tracks
Definition: record.py:94
mc_hit_lookup
Method to find matching MC hits.
Definition: record.py:122
track_fitter
Use the CDCRiemannFitter with a constrained origin for track fitting.
Definition: record.py:99
lower_curv_bounds
cached copy of lower bounds
Definition: record.py:110
upper_curv_bounds
cached copy of upper bounds
Definition: record.py:112
def __init__(self, output_file_name)
Definition: record.py:88
def create_argument_parser(self, **kwds)
Definition: record.py:56
def harvesting_module(self, path=None)
Definition: record.py:49
output_file_name
Disable the writing of an output ROOT file.
Definition: run.py:20
output_file_name
There is no default for the name of the output TFile.
Definition: mixins.py:61
Definition: main.py:1
int main(int argc, char **argv)
Run all tests.
Definition: test_main.cc:91