Belle II Software development
quadTreePlotter.py
1
8
9from trackfindingcdc.cdcdisplay.svgdrawing import attributemaps
10import bisect
11from datetime import datetime
12import tempfile
13import numpy as np
14import matplotlib.pyplot as plt
15import basf2
16import ROOT
17from ROOT import Belle2
18
19from ROOT import gSystem
20gSystem.Load('libtracking')
21gSystem.Load('libtracking_trackFindingCDC')
22
23
24class QuadTreePlotter(basf2.Module):
25 """
26 This Module is able to draw the content coming from a QuadTreeImplementation with debugOutput = True.
27 """
28
29 def __init__(self, queue):
30 """
31 Do not forget to set the ranges! Otherwise you will end up with an empty plot..
32 """
33 basf2.Module.__init__(self)
34
35 self.file_name_of_quad_tree_content = "output.root"
36
38
39 self.range_x_min = 0
40
41 self.range_x_max = 0
42
43 self.range_y_min = 0
44
45 self.range_y_max = 0
46
47
48 self.queue = queue
49
50 self.file_names = []
51
53 """
54 Draw the quad tree content coming from the root file if enabled.
55 """
56
57 import seaborn as sb
58
59 if not self.draw_quad_tree_content:
60 return
61
62 input_file = ROOT.TFile(self.file_name_of_quad_tree_content)
63
64 hist = input_file.Get("histUsed")
65
66 xAxis = hist.GetXaxis()
67 yAxis = hist.GetYaxis()
68
69 x_edges = np.array([xAxis.GetBinLowEdge(iX) for iX in range(1, xAxis.GetNbins() + 2)])
70 y_edges = np.array([yAxis.GetBinLowEdge(iY) for iY in range(1, yAxis.GetNbins() + 2)])
71
72 arr_l = np.array([[hist.GetBinContent(iX, iY) for iY in range(1, yAxis.GetNbins() + 1)]
73 for iX in range(1, xAxis.GetNbins() + 1)])
74
75 hist = input_file.Get("histUnused")
76
77 xAxis = hist.GetXaxis()
78 yAxis = hist.GetYaxis()
79
80 x_edges = np.array([xAxis.GetBinLowEdge(iX) for iX in range(1, xAxis.GetNbins() + 2)])
81 y_edges = np.array([yAxis.GetBinLowEdge(iY) for iY in range(1, yAxis.GetNbins() + 2)])
82
83 l2 = np.array([[hist.GetBinContent(iX, iY) for iY in range(1, yAxis.GetNbins() + 1)]
84 for iX in range(1, xAxis.GetNbins() + 1)])
85
86 cmap = sb.cubehelix_palette(8, start=2, rot=0, dark=0, light=1, reverse=False, as_cmap=True)
87
88 plt.gca().pcolorfast(x_edges, y_edges, (arr_l + l2).T, cmap=cmap)
89
90 x_labels = [f"{x:0.{int(not float(x).is_integer())}f}" if i % 4 == 0 else "" for i, x in enumerate(x_edges)]
91 plt.xticks(x_edges, x_labels)
92 y_labels = [f"{y:0.{int(not float(y).is_integer())}f}" if i % 4 == 0 else "" for i, y in enumerate(y_edges)]
93 plt.yticks(y_edges, y_labels)
94
96 """
97 Save the plot to a svg and show it (maybe a png would be better?)
98 """
99 fileName = tempfile.gettempdir() + "/" + datetime.now().isoformat() + '.svg'
100 plt.savefig(fileName)
101 self.file_names.append(fileName)
102
def init_plotting(self):
    """
    Reset the current matplotlib figure and apply the configured plot ranges.

    The x range comes from ``range_x_min``/``range_x_max`` and the y range
    from ``range_y_min``/``range_y_max``. Axis labels are not set here yet
    (TODO: add axis labels).
    """
    plt.clf()
    x_limits = (self.range_x_min, self.range_x_max)
    y_limits = (self.range_y_min, self.range_y_max)
    plt.xlim(*x_limits)
    plt.ylim(*y_limits)
111
112 def event(self):
113 """
114 Draw everything
115 """
116 self.init_plotting()
118 self.save_and_show_file()
119
def terminate(self):
    """
    Termination signal at the end of the event processing.

    Publishes the list of image file names collected so far to the queue
    under the key "quadTree".
    """
    written_files = self.file_names
    self.queue.put("quadTree", written_files)
123
124
126
127 """
128 Implementation of a quad tree plotter for SegmentQuadTrees
129 """
130
131
132 draw_segment_intersection = True and False
133
134 draw_segment = True and False
135
136 draw_segment_averaged = True and False
137
138 draw_segment_fitted = True and False
139
140 draw_mc_information = True and False
141
142 draw_mc_hits = True and False
143
144
145 theta_shifted = False
146
147 maximum_theta = np.pi
148
150 """
151 Calculate the point where the two given hits intersect
152
153 params
154 ------
155 first: hit
156 second: hit
157 """
158 positionFront = first.getRecoPos2D().conformalTransformed()
159 positionBack = second.getRecoPos2D().conformalTransformed()
160
161 theta_cut = np.arctan2((positionBack - positionFront).x(), (positionFront - positionBack).y())
162
163 if self.theta_shifted:
164 while theta_cut < - self.maximum_theta / 2:
165 theta_cut += self.maximum_theta
166 else:
167 while theta_cut < 0:
168 theta_cut += self.maximum_theta
169
170 r_cut = positionFront.x() * np.cos(theta_cut) + positionFront.y() * np.sin(theta_cut)
171
172 return theta_cut, r_cut
173
175 """
176 Transform a given normal coordinate position to a legendre position (conformal transformed)
177
178 params
179 ------
180 position: TrackFindingCDC.Vector2D
181 """
182 position = position.conformalTransformed()
183
184 theta = np.linspace(self.range_x_minrange_x_min, self.range_x_maxrange_x_max, 100)
185 r = position.x() * np.cos(theta) + position.y() * np.sin(theta)
186
187 return theta, r
188
190 """
191 Loop over all segments and execute a function
192
193 params
194 ------
195 f: function
196 """
197 items = Belle2.PyStoreObj("CDCSegment2DVector")
198 wrapped_vector = items.obj()
199 vector = wrapped_vector.get()
200
201 for quad_tree_item in vector:
202 if quad_tree_item.getStereoType() == 0:
203 f(quad_tree_item)
204
def convertToQuadTreePicture(self, phi, mag, charge):
    """
    Convert given track parameters into a point in the legendre space

    The point (theta, r) uses the r = x cos(theta) + y sin(theta)
    parametrisation. Since (theta, r) and (theta + pi, -r) describe the same
    line, theta is folded into the plotted range by shifting it by one period
    (``maximum_theta``). r is negated only when no shift is applied, so both
    branches describe the same line.

    params
    ------
    phi: phi of the track (assumed to lie in (-pi, pi], e.g. an atan2 result)
    mag: magnitude of pt
    charge: charge of the track

    returns
    -------
    tuple (theta, r)
    """
    # NOTE(review): hard-coded conversion; presumably a 1.5 T field and the
    # speed of light in units that turn pt into a curvature -- confirm.
    bfield = 1.5
    speed_of_light = 0.00299792458

    theta = phi + np.pi / 2
    r = 1 / mag * bfield * speed_of_light * charge

    period = self.maximum_theta
    if self.theta_shifted:
        half_period = period / 2
        if theta > half_period or theta < -half_period:
            # Fold into (-period/2, period/2] by a shift of exactly one period
            # (for phi in (-pi, pi]); a shift by half a period would describe a
            # different line. r keeps its sign because of this shift.
            theta = half_period - (half_period - theta) % period
        else:
            r *= -1
    else:
        if theta > period or theta < 0:
            theta = theta % period
        else:
            r *= -1
    return theta, r
228
229 def event(self):
230 """
231 Draw everything according to the given options
232
233 Attributes
234 ----------
235 draw_segment_intersection
236 draw_segment
237 draw_segment_averaged
238 draw_segment_fitted
239 draw_mc_information
240 draw_mc_hits
241 """
242 if self.theta_shifted:
243
245
247 else:
250
251
253
255
256 self.init_plotting()
257 # self.plotQuadTreeContent()
258
260 map = attributemaps.SegmentMCTrackIdColorMap()
261
262 def f(segment):
263 theta, r = self.calculateIntersectionInQuadTreePicture(segment.front(), segment.back())
264 plt.plot(theta, r, color=list(map(0, segment)), marker="o")
265
266 self.forAllAxialSegments(f)
267
268 if self.draw_segment:
269 map = attributemaps.SegmentMCTrackIdColorMap()
270
271 def f(segment):
272 theta, r = self.calculatePositionInQuadTreePicture(segment.front().getRecoPos2D())
273 plt.plot(theta, r, color=list(map(0, segment)), marker="", ls="-")
274
275 self.forAllAxialSegments(f)
276
277 if self.draw_segment_averaged:
278 map = attributemaps.SegmentMCTrackIdColorMap()
279
280 def f(segment):
281 middle_index = int(np.round(segment.size() / 2.0))
282 middle_point = list(segment.items())[middle_index]
283 theta_front, r_front = self.calculateIntersectionInQuadTreePicture(segment.front(), middle_point)
284 theta_back, r_back = self.calculateIntersectionInQuadTreePicture(middle_point, segment.back())
285
286 plt.plot([theta_front, theta_back], [r_front, r_back], color=list(map(0, segment)), marker="o", ls="-")
287
288 self.forAllAxialSegments(f)
289
290 if self.draw_segment_fitted:
291 map = attributemaps.SegmentMCTrackIdColorMap()
293
294 def f(segment):
295 trajectory = fitter.fit(segment)
296 momentum = trajectory.getUnitMom2D(Belle2.TrackFindingCDC.Vector2D(0, 0)).scale(trajectory.getAbsMom2D())
297 theta, r = self.convertToQuadTreePicture(momentum.phi(), momentum.norm(), trajectory.getChargeSign())
298 plt.plot(theta, r, color=list(map(0, segment)), marker="o")
299
300 self.forAllAxialSegments(f)
301
302 if self.draw_hits:
303 cdcHits = Belle2.PyStoreArray("CDCHits")
304 storedWireHits = Belle2.PyStoreObj('CDCWireHitVector')
305 wireHits = storedWireHits.obj().get()
306
307 array = Belle2.PyStoreArray("MCTrackCands")
308 cdc_hits = [cdcHits[i] for track in array for i in track.getHitIDs()]
309
310 for cdcHit in cdcHits:
311 if cdcHit in cdc_hits:
312 continue
313 wireHit = wireHits.at(bisect.bisect_left(wireHits, cdcHit))
314 theta, r = self.calculatePositionInQuadTreePicture(wireHit.getRefPos2D())
315
316 plt.plot(theta, r, marker="", color="black", ls="-", alpha=0.8)
317
318 if self.draw_mc_hits:
319 storedWireHits = Belle2.PyStoreObj('CDCWireHitVector')
320 wireHits = storedWireHits.obj().get()
321
322 map = attributemaps.listColors
323 array = Belle2.PyStoreArray("MCTrackCands")
324 cdcHits = Belle2.PyStoreArray("CDCHits")
325
326 for track in array:
327 mcTrackID = track.getMcTrackId()
328
329 for cdcHitID in track.getHitIDs(Belle2.Const.CDC):
330 cdcHit = cdcHits[cdcHitID]
331 wireHit = wireHits.at(bisect.bisect_left(wireHits, cdcHit))
332
333 theta, r = self.calculatePositionInQuadTreePicture(wireHit.getRefPos2D())
334
335 plt.plot(theta, r, marker="", color=map[mcTrackID % len(map)], ls="-", alpha=0.2)
336
337 if self.draw_mc_information:
338 map = attributemaps.listColors
339 array = Belle2.PyStoreArray("MCTrackCands")
340
341 for track in array:
342 momentum = track.getMomSeed()
343
344 # HARDCODED!!! Temporary solution
345 theta, r = self.convertToQuadTreePicture(momentum.Phi(), momentum.Mag(), track.getChargeSeed())
346 mcTrackID = track.getMcTrackId()
347
348 plt.plot(theta, r, marker="o", color="black", ms=10)
349 plt.plot(theta, r, marker="o", color=map[mcTrackID % len(map)], ms=5)
350
351 self.save_and_show_file()
352
353
355
356 """
357 Implementation of a quad tree plotter for StereoHitAssignment
358 """
359
360
361 draw_mc_hits = False
362
363 draw_mc_tracks = False
364
365 draw_track_hits = False
366
367 draw_last_track = True
368
369 delete_bad_hits = False
370
372 """
373 Convert a genfit::TrackCand into a TrackFindingCDC.CDCTrajectory3D
374
375 params
376 ------
377 track: genfit::TrackCand
378 """
381
382 position = Vector3D(track.getPosSeed())
383 momentum = Vector3D(track.getMomSeed())
384 charge = track.getChargeSeed()
385
386 return Trajectory3D(position, momentum, charge)
387
388 def create_reco_hit3D(self, cdcHit, trajectory3D, rlInfo):
389 """
390 Use a cdc hit and a trajectory to reconstruct a CDCRecoHit3D
391
392 params
393 ------
394 cdcHit: CDCHit
395 trajectory3D: TrackFindingCDC.CDCTrajectory3D
396 rlInfo: RightLeftInfo ( = short)
397 """
398 storedWireHits = Belle2.PyStoreObj('CDCWireHitVector')
399 wireHits = storedWireHits.obj().get()
400
402 wireHit = wireHits.at(bisect.bisect_left(wireHits, cdcHit))
403 rightLeftWireHit = Belle2.TrackFindingCDC.CDCRLWireHit(wireHit, rlInfo)
404 if rightLeftWireHit.getStereoType() != 0:
405 recoHit3D = CDCRecoHit3D.reconstruct(rightLeftWireHit, trajectory3D.getTrajectory2D())
406 return recoHit3D
407 else:
408 return None
409
410 def get_plottable_line(self, recoHit3D):
411 """
412 Minim the task of the StereoQuadTree by showing the line of quadtree nodes
413 a hit belongs to
414 """
416 arr_l = np.array((np.array(recoHit3D.getRecoPos3D().z()) - z0) / recoHit3D.getArcLength2D())
417 return arr_l, z0
418
def plot_hit_line(self, recoHit3D, color):
    """
    Draw the quad tree line of a single recoHit3D.

    Nothing is drawn for a falsy hit (e.g. ``None``) or for an axial hit
    (stereo type 0).
    """
    # Guard clauses: skip missing hits and axial hits.
    if not recoHit3D or recoHit3D.getStereoType() == 0:
        return

    tan_lambda, z0 = self.get_plottable_line(recoHit3D)
    plt.plot(tan_lambda, z0, marker="", ls="-", alpha=0.4, color=color)
429
430 def event(self):
431 """
432 Draw the hit content according to the attributes
433
434 Attributes
435 ----------
436 draw_mc_hits
437 draw_mc_tracks
438 draw_track_hits
439 draw_last_track
440 delete_bad_hits
441 """
442
444
445
446 self.range_x_minrange_x_min = -2 - np.sqrt(3)
447
448 self.range_x_maxrange_x_max = 2 + np.sqrt(3)
449
450
452
454
455 self.init_plotting()
457
458 map = attributemaps.listColors
459 cdcHits = Belle2.PyStoreArray("CDCHits")
460
461 items = Belle2.PyStoreObj("CDCTrackVector")
462 wrapped_vector = items.obj()
463 track_vector = wrapped_vector.get()
464
465 mcHitLookUp = Belle2.TrackFindingCDC.CDCMCHitLookUp().getInstance()
466 mcHitLookUp.fill()
467
468 storedWireHits = Belle2.PyStoreObj('CDCWireHitVector')
469 wireHits = storedWireHits.obj().get()
470
471 if self.draw_mc_hits:
472 mc_track_cands = Belle2.PyStoreArray("MCTrackCands")
473
474 for track in mc_track_cands:
475 mcTrackID = track.getMcTrackId()
476 trajectory = self.create_trajectory_from_track(track)
477
478 for cdcHitID in track.getHitIDs(Belle2.Const.CDC):
479 cdcHit = cdcHits[cdcHitID]
480
481 leftRecoHit3D = self.create_reco_hit3D(cdcHit, trajectory, -1)
482 rightRecoHit3D = self.create_reco_hit3D(cdcHit, trajectory, 1)
483
484 self.plot_hit_line(leftRecoHit3D, color=map[mcTrackID % len(map)])
485 self.plot_hit_line(rightRecoHit3D, color=map[mcTrackID % len(map)])
486
487 if self.draw_mc_tracks:
488 mc_track_cands = Belle2.PyStoreArray("MCTrackCands")
489
490 for track in mc_track_cands:
491 mcTrackID = track.getMcTrackId()
492 trajectory = self.create_trajectory_from_track(track)
493 z0 = trajectory.getTrajectorySZ().getZ0()
494
495 for cdcHitID in track.getHitIDs(Belle2.Const.CDC):
496 cdcHit = cdcHits[cdcHitID]
497 recoHit3D = self.create_reco_hit3D(cdcHit, trajectory, mcHitLookUp.getRLInfo(cdcHit))
498
499 if recoHit3D:
500 arr_l = (recoHit3D.getRecoPos3D().z() - z0) / recoHit3D.getArcLength2D()
501 plt.plot(arr_l, z0, marker="o", color=map[mcTrackID % len(map)], ls="", alpha=0.2)
502
503 if self.draw_track_hits:
504 for id, track in enumerate(track_vector):
505 for recoHit3D in list(track.items()):
506 self.plot_hit_line(recoHit3D, color=map[id % len(map)])
507
508 if self.draw_last_track and len(track_vector) != 0:
509
510 last_track = track_vector[-1]
511 trajectory = last_track.getStartTrajectory3D().getTrajectory2D()
512
513 for wireHit in wireHits:
514 for rlInfo in (-1, 1):
515 recoHit3D = Belle2.TrackFindingCDC.CDCRecoHit3D.reconstruct(wireHit, rlInfo, trajectory)
516
517 if (self.delete_bad_hits and
518 (wireHit.getRLInfo() != mcHitLookUp.getRLInfo(wireHit.getWireHit().getHit()) or
519 not recoHit3D.isInCellZBounds())):
520 continue
521
522 if recoHit3D in list(last_track.items()):
523 color = map[len(track_vector) % len(map)]
524 else:
525 if wireHit.getRLInfo() == 1:
526 color = "black"
527 else:
528 color = "gray"
529 self.plot_hit_line(recoHit3D, color)
530
531 plt.xlabel(r"$\tan \ \lambda$")
532 plt.ylabel(r"$z_0$")
533 self.save_and_show_file()
534
535
537
538 """
539 A wrapper around the svg drawer in the tracking package that
540 writes its output files as a list to the queue
541 """
542
def __init__(self, queue, label, *args, **kwargs):
    """ The same as the base class, except:

    Arguments
    ---------

    queue: The queue to write to. It is also handed to the base class
        constructor unless the caller already passes one there
        (positionally in ``args`` or as ``queue=`` in ``kwargs``).
    label: The key name in the queue
    """

    #: The queue to write to
    self.queue = queue
    #: The label for writing to the queue.
    self.label = label

    # The base constructor needs its own ``queue`` argument (presumably the
    # inherited QuadTreePlotter.__init__(self, queue)); forward ours when the
    # caller did not supply one, otherwise Cls(queue, label) raises TypeError.
    if not args and "queue" not in kwargs:
        args = (queue,)
    StereoQuadTreePlotter.__init__(self, *args, **kwargs)

    #: Names of the image files written by this module
    self.file_list = []
560
def terminate(self):
    """ Overwrite the terminate to put the list to the queue.

    Runs the base-class termination first, then publishes the collected
    file names under ``self.label``.
    """
    StereoQuadTreePlotter.terminate(self)
    key, files = self.label, self.file_list
    self.queue.put(key, files)
565
567 """ Overwrite the function to listen for every new filename """
568
569 from datetime import datetime
570 from matplotlib import pyplot as plt
571
572 fileName = tempfile.gettempdir() + "/" + datetime.now().isoformat() + '.svg'
573 plt.savefig(fileName)
574 self.file_list.append(fileName)
A (simplified) python wrapper for StoreArray.
Definition: PyStoreArray.h:72
a (simplified) python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
Interface class to the Monte Carlo information for individual hits.
Class representing an oriented hit wire including a hypothesis whether the causing track passes left ...
Definition: CDCRLWireHit.h:41
Class representing a three dimensional reconstructed hit.
Definition: CDCRecoHit3D.h:52
static CDCRecoHit3D reconstruct(const CDCRecoHit2D &recoHit2D, const CDCTrajectory2D &trajectory2D)
Reconstructs the three dimensional hit from the two dimensional reco hit and the two dimensional trajectory.
Definition: CDCRecoHit3D.cc:56
static const CDCRiemannFitter & getOriginCircleFitter()
Static getter for an origin circle fitter.
Particle full three dimensional trajectory.
A two dimensional vector which is equipped with functions for correct handling of orientation relate...
Definition: Vector2D.h:32
A three dimensional vector.
Definition: Vector3D.h:33
range_x_min
cached minimum x value
draw_quad_tree_content
cached flag to draw QuadTree
range_y_max
cached maximum y value
range_x_max
cached maximum x value
file_name_of_quad_tree_content
cached output filename
file_names
cached array of output filenames (one file per image)
range_y_min
cached minimum y value
queue
cached value of the queue input parameter
def __init__(self, queue, label, *args, **kwargs)
label
The label for writing to the queue.
bool draw_segment_averaged
by default, do not draw an averaged segment
bool draw_segment
by default, do not draw a segment
bool draw_segment_fitted
by default, do not draw a fitted segment
range_x_min
lower x bound for polar angle
bool draw_mc_hits
by default, do not draw the MC hits
bool draw_mc_information
by default, do not draw the MC information
def convertToQuadTreePicture(self, phi, mag, charge)
def calculatePositionInQuadTreePicture(self, position)
range_x_max
upper x bound for polar angle
np maximum_theta
an alias for the maximum value of the polar angle
bool draw_segment_intersection
by default, do not draw a segment intersection
bool theta_shifted
by default, polar angles and cuts are in the range (0,pi) rather than (-pi/2,+pi/2)
def calculateIntersectionInQuadTreePicture(self, first, second)
def plot_hit_line(self, recoHit3D, color)
bool draw_mc_hits
by default, do not draw the MC hits
bool delete_bad_hits
by default, do not delete the bad track hits
draw_quad_tree_content
by default, draw the QuadTree
bool draw_last_track
by default, draw the last track
def create_reco_hit3D(self, cdcHit, trajectory3D, rlInfo)
bool draw_track_hits
by default, do not draw the track hits
bool draw_mc_tracks
by default, do not draw the MC tracks