Belle II Software  release-05-01-25
record.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 import basf2
5 
6 import ROOT
7 from ROOT import Belle2 # make Belle2 namespace available
8 from ROOT import std
9 from ROOT.Belle2 import TrackFindingCDC as TFCDC
10 
11 import os
12 import sys
13 import random
14 import numpy as np
15 
16 from tracking.validation.utilities import is_primary
17 
18 import tracking.harvest.harvesting as harvesting
19 import tracking.harvest.refiners as refiners
20 from tracking.harvest.run import HarvestingRun
21 
22 import argparse
23 
24 import logging
25 
26 
def get_logger():
    """Return the logger registered under this module's name."""
    module_logger = logging.getLogger(__name__)
    return module_logger
29 
# Contact e-mail address for this script.
CONTACT = 'oliver.frost@desy.de'
31 
32 
class LegendreBinningValidationRun(HarvestingRun):
    """Harvester to generate, postprocess and inspect MC events for track-segment evaluation"""

    #: Number of events configured for this run
    #: (presumably consumed by the HarvestingRun base class - confirm).
    n_events = 10000

    #: Name of the event generator configuration
    #: (presumably interpreted by the HarvestingRun base class - confirm).
    generator_module = "generic"

    @property
    def output_file_name(self):
        """Get the output ROOT filename"""
        return 'legendre_binning.root'

    def harvesting_module(self, path=None):
        """Harvest and post-process the generated events

        Creates a LegendreBinningValidationModule writing to
        ``output_file_name``. If a truthy ``path`` is given, the module is
        also appended to it. The module is returned in either case.
        """
        module = LegendreBinningValidationModule(self.output_file_name)
        if path:
            path.add_module(module)
        return module

    def create_argument_parser(self, **kwds):
        """Convert command-line arguments to basf2 argument list

        No additional arguments are registered here; the parser built by the
        base class is returned unchanged.
        """
        argument_parser = super().create_argument_parser(**kwds)
        return argument_parser

    def create_path(self):
        """
        Sets up a path that plays back pregenerated events or generates events
        based on the properties in the base class.

        On top of the base-class path, the wire hit preparation and the
        MC-truth based axial track creation modules are appended.
        """
        path = super().create_path()

        path.add_module("TFCDC_WireHitPreparer",
                        logLevel=8,
                        flightTimeEstimation="outwards",
                        UseNLoops=1)

        path.add_module('TFCDC_AxialTrackCreatorMCTruth',
                        logLevel=8,
                        useOnlyBeforeTOP=True,
                        fit=True,
                        reconstructedDriftLength=True,
                        reconstructedPositions=True)

        return path
77 
78 
class LegendreBinningValidationModule(harvesting.HarvestingModule):

    """Module to collect information about the generated segments and
    compose validation plots on terminate."""

    def __init__(self, output_file_name):
        """Constructor

        Sets up the origin-constrained Riemann fitter and reads the curvature
        bin bounds from ``fine_curv_bounds.txt`` (resolved relative to the
        current working directory). The file holds one float per line;
        consecutive values are paired as (lower bound, upper bound) of a bin.

        Raises:
            ValueError: if the file does not contain at least one complete
                (lower, upper) pair, or a non-blank line is not a float.
        """
        super().__init__(foreach='CDCTrackVector',
                         output_file_name=output_file_name)

        #: by default, there is no method to find matching MC tracks
        self.mc_track_lookup = None

        #: Method to find matching MC hits; assigned in initialize()
        self.mc_hit_lookup = None

        origin_track_fitter = TFCDC.CDCRiemannFitter()
        origin_track_fitter.setOriginConstrained()

        #: Use the CDCRiemannFitter with a constrained origin for track fitting
        self.track_fitter = origin_track_fitter

        # Blank lines (e.g. a trailing newline) are skipped instead of
        # crashing the float conversion.
        with open('fine_curv_bounds.txt') as curv_bounds_file:
            curv_bounds = [float(line) for line in curv_bounds_file if line.strip()]

        # Pair up consecutive values; a dangling odd value is dropped by zip.
        bin_bounds = sorted(zip(curv_bounds[0::2], curv_bounds[1::2]))
        if not bin_bounds:
            # Fail early: peel() indexes the first bin and draws a random bin.
            raise ValueError("fine_curv_bounds.txt contains no complete curvature bin")

        #: cached copy of lower bounds
        self.lower_curv_bounds = np.array([lower for lower, _ in bin_bounds])

        #: cached copy of upper bounds
        self.upper_curv_bounds = np.array([upper for _, upper in bin_bounds])

        assert len(self.lower_curv_bounds) == len(self.upper_curv_bounds)

    def initialize(self):
        """Receive signal at the start of event processing"""
        super().initialize()

        self.mc_track_lookup = TFCDC.CDCMCTrackLookUp.getInstance()

        self.mc_hit_lookup = TFCDC.CDCMCHitLookUp.getInstance()

    def prepare(self):
        """Initialize the MC-hit lookup method"""
        TFCDC.CDCMCHitLookUp.getInstance().fill()

    def pick(self, track):
        """Select tracks with more than 3 hits and an associated primary MC particle"""
        mc_track_lookup = self.mc_track_lookup
        mc_particle = mc_track_lookup.getMCParticle(track)

        # Check that mc_particle is not a nullptr
        return mc_particle and is_primary(mc_particle) and track.size() > 3

    def peel(self, track):
        """Aggregate the track and fit information for the curvature binning analysis

        Fits the hits of ``track`` with the origin-constrained fitter, derives
        a per-hit curvature estimate and looks up the width of the curvature
        bin the track falls into.

        Returns:
            dict: the quantities recorded for this track, keyed by name.
        """
        track_fitter = self.track_fitter

        rl_drift_circle = 1
        unit_variance = 0
        observations2D = TFCDC.CDCObservations2D(rl_drift_circle, unit_variance)

        for recoHit3D in track:
            observations2D.append(recoHit3D)

        trajectory2D = track_fitter.fit(observations2D)
        trajectory2D.setLocalOrigin(TFCDC.Vector2D(0, 0))

        n_hits = track.size()
        pt = trajectory2D.getAbsMom2D()
        curv = trajectory2D.getCurvature()
        # Curvatures beyond the smallest lower bin bound are recorded unsigned.
        curl_curv = abs(self.lower_curv_bounds[0])
        bin_curv = curv if abs(curv) < curl_curv else abs(curv)
        curv_var = trajectory2D.getLocalVariance(0)
        impact = trajectory2D.getGlobalImpact()

        circle = trajectory2D.getLocalCircle()
        phi0 = circle.phi0()
        n12 = circle.n12()

        # One curvature estimate per hit, from the fitted circle normal and
        # the hit's wire reference position and signed drift length.
        cross_curvs = []
        for recoHit3D in track:
            wire_ref_pos = recoHit3D.getRefPos2D()
            drift_length = recoHit3D.getSignedRecoDriftLength()
            r = wire_ref_pos.norm()
            cross_curv = (-2 * (n12.dot(wire_ref_pos) - drift_length)
                          / (r * r - drift_length * drift_length))
            cross_curvs.append(cross_curv)

        cross_curvs = np.array(cross_curvs)
        cross_curv = np.median(cross_curvs)
        # Despite the name this is the median absolute deviation of the per-hit values.
        cross_curv_var = np.median(np.abs(cross_curvs - cross_curv))

        basic_curv_precision = TFCDC.PrecisionUtil.getBasicCurvPrecision(cross_curv)
        origin_curv_precision = TFCDC.PrecisionUtil.getOriginCurvPrecision(cross_curv)
        non_origin_curv_precision = TFCDC.PrecisionUtil.getNonOriginCurvPrecision(cross_curv)

        # Digitize a scalar so that bin_id is a plain integer; passing a list
        # yields a one-element array and would make max_curv_precision an
        # array in one branch and a float in the other.
        bin_id = int(np.digitize(abs(cross_curv), self.lower_curv_bounds))
        if bin_id == 0:
            # Below the lowest lower bound: fall back to a fixed width.
            max_curv_precision = 0.00007
        else:
            max_curv_precision = self.upper_curv_bounds[bin_id - 1] - self.lower_curv_bounds[bin_id - 1]

        # Sample a curvature uniformly inside a randomly chosen bin.
        random_bin_id = random.randrange(len(self.upper_curv_bounds))
        random_lower_curv_bound = self.lower_curv_bounds[random_bin_id]
        random_upper_curv_bound = self.upper_curv_bounds[random_bin_id]
        curv_dense = random.uniform(random_lower_curv_bound, random_upper_curv_bound)
        curv_width = random_upper_curv_bound - random_lower_curv_bound

        # A vanishing curvature has an infinite inverse; avoid ZeroDivisionError.
        abs_curv = abs(curv)
        inv_curv = 1.0 / abs_curv if abs_curv != 0 else float("inf")

        return dict(
            n_hits=n_hits,
            curvature_estimate=curv,
            curvature_variance=curv_var,
            abs_curvature_estimate=abs_curv,
            inv_curv=inv_curv,
            cross_curv=cross_curv,
            cross_curv_var=cross_curv_var,
            basic_curv_precision=basic_curv_precision,
            origin_curv_precision=origin_curv_precision,
            non_origin_curv_precision=non_origin_curv_precision,
            max_curv_precision=max_curv_precision,
            pt=pt,
            curv_bin=bin_curv,
            curv_dense=curv_dense,
            curv_width=curv_width,
            impact_estimate=impact,
            phi0=phi0,
        )

    # Refiners to be executed at the end of the harvesting / termination of the module

    #: Save a tree of all collected variables
    save_tree = refiners.save_tree()

    #: Save histograms of the collected variables
    save_histograms = refiners.save_histograms(outlier_z_score=5.0, allow_discrete=True)

    #: Save profiles of the curvature variance versus curvature estimate and phi0
    save_profiles = refiners.save_profiles(x=['curvature_estimate', 'phi0'],
                                           y='curvature_variance',
                                           outlier_z_score=5.0)

    #: Save profiles of the curvature precision measures versus the per-hit median curvature
    save_cross_curv_profile = refiners.save_profiles(x=['cross_curv'],
                                                     y=['cross_curv_var',
                                                        'curvature_estimate',
                                                        'basic_curv_precision',
                                                        'origin_curv_precision',
                                                        'non_origin_curv_precision',
                                                        'max_curv_precision',
                                                        ],
                                                     outlier_z_score=5.0)

    #: Save a scatter plot of the number of hits versus the curvature estimate
    save_scatter = refiners.save_scatters(x=['curvature_estimate'], y='n_hits')
233 
234 
def main():
    """Create the validation run and execute it as configured on the command line."""
    # The run object must exist before it can be configured and executed.
    run = LegendreBinningValidationRun()
    run.configure_and_execute_from_commandline()
238 
if __name__ == "__main__":
    # Send INFO-and-above log records to stdout as "LEVEL:message" before running.
    logging.basicConfig(
        format='%(levelname)s:%(message)s',
        level=logging.INFO,
        stream=sys.stdout,
    )
    main()
record.LegendreBinningValidationRun.create_argument_parser
def create_argument_parser(self, **kwds)
Definition: record.py:52
record.LegendreBinningValidationModule
Definition: record.py:79
tracking.harvest.refiners
Definition: refiners.py:1
record.LegendreBinningValidationModule.peel
def peel(self, track)
Definition: record.py:132
tracking.harvest.harvesting
Definition: harvesting.py:1
record.LegendreBinningValidationModule.pick
def pick(self, track)
Definition: record.py:124
record.LegendreBinningValidationModule.track_fitter
track_fitter
Use the CDCRiemannFitter with a constrained origin for track fitting.
Definition: record.py:95
record.LegendreBinningValidationModule.mc_hit_lookup
mc_hit_lookup
Method to find matching MC hits.
Definition: record.py:118
tracking.harvest.run
Definition: run.py:1
record.LegendreBinningValidationModule.mc_track_lookup
mc_track_lookup
By default, there is no method to find matching MC tracks.
Definition: record.py:90
main
int main(int argc, char **argv)
Run all tests.
Definition: test_main.cc:77
tracking.harvest.run.HarvestingRunMixin.output_file_name
output_file_name
Disable the writing of an output ROOT file.
Definition: run.py:12
record.LegendreBinningValidationRun.harvesting_module
def harvesting_module(self, path=None)
Definition: record.py:45
record.LegendreBinningValidationModule.prepare
def prepare(self)
Definition: record.py:120
record.LegendreBinningValidationModule.initialize
def initialize(self)
Definition: record.py:112
tracking.validation.utilities
Definition: utilities.py:1
record.LegendreBinningValidationModule.upper_curv_bounds
upper_curv_bounds
cached copy of upper bounds
Definition: record.py:108
record.LegendreBinningValidationRun
Definition: record.py:33
tracking.harvest.run.HarvestingRun
Definition: run.py:69
record.LegendreBinningValidationModule.__init__
def __init__(self, output_file_name)
Definition: record.py:84
record.LegendreBinningValidationRun.create_path
def create_path(self)
Definition: record.py:57
record.LegendreBinningValidationModule.lower_curv_bounds
lower_curv_bounds
cached copy of lower bounds
Definition: record.py:106