Belle II Software development
quadTreePlotter.py
1
8
9from trackfindingcdc.cdcdisplay.svgdrawing import attributemaps
10import bisect
11from datetime import datetime
12import tempfile
13import numpy as np
14import matplotlib.pyplot as plt
15import basf2
16import ROOT
17from ROOT import Belle2
18
19from ROOT import gSystem
20# @cond internal_test
21gSystem.Load('libtracking')
22gSystem.Load('libtracking_trackFindingCDC')
23
24
class QuadTreePlotter(basf2.Module):
    """
    This Module is able to draw the content coming from a QuadTreeImplementation with debugOutput = True.
    """

    def __init__(self, queue):
        """
        Do not forget to set the ranges! Otherwise you will end up with an empty plot..

        params
        ------
        queue: object with a put(key, value) method; the list of written file names
               is put into it under the key "quadTree" in terminate()
        """
        basf2.Module.__init__(self)

        #: name of the ROOT file holding the histograms "histUsed" and "histUnused"
        self.file_name_of_quad_tree_content = "output.root"
        #: if False, plot_quad_tree_content draws nothing
        self.draw_quad_tree_content = True
        #: lower limit of the x axis
        self.range_x_min = 0
        #: upper limit of the x axis
        self.range_x_max = 0
        #: lower limit of the y axis
        self.range_y_min = 0
        #: upper limit of the y axis
        self.range_y_max = 0
        #: queue the file names are put into in terminate()
        self.queue = queue
        #: names of all svg files written by save_and_show_file (one per call)
        self.file_names = []

    @staticmethod
    def _read_histogram(input_file, name):
        """
        Read the 2D histogram `name` from the open ROOT file into numpy arrays.

        returns
        -------
        (x_edges, y_edges, content): the bin edges of both axes (nbins + 1 values each)
        and the bin contents indexed as content[x, y]; under- and overflow bins are skipped.

        raises
        ------
        RuntimeError if the histogram cannot be read (missing object or unreadable file)
        """
        hist = input_file.Get(name)
        # Get() returns a null proxy (falsy) if the object is missing or the file is a zombie
        if not hist:
            raise RuntimeError(f"Could not read histogram '{name}' from '{input_file.GetName()}'")

        x_axis = hist.GetXaxis()
        y_axis = hist.GetYaxis()
        n_x = x_axis.GetNbins()
        n_y = y_axis.GetNbins()

        # bin numbering in ROOT starts at 1; the low edge of bin n + 1 is the upper edge of the last bin
        x_edges = np.array([x_axis.GetBinLowEdge(i_x) for i_x in range(1, n_x + 2)])
        y_edges = np.array([y_axis.GetBinLowEdge(i_y) for i_y in range(1, n_y + 2)])
        content = np.array([[hist.GetBinContent(i_x, i_y) for i_y in range(1, n_y + 1)]
                            for i_x in range(1, n_x + 1)])
        return x_edges, y_edges, content

    @staticmethod
    def _edge_labels(edges):
        """
        Build tick labels for the given bin edges: every fourth edge is labelled (integers without
        decimals, other values with one decimal), all others get an empty label.
        """
        return [f"{edge:0.{int(not float(edge).is_integer())}f}" if i % 4 == 0 else ""
                for i, edge in enumerate(edges)]

    def plot_quad_tree_content(self):
        """
        Draw the quad tree content coming from the root file if enabled.

        The bin contents of "histUsed" and "histUnused" are summed and drawn as a colour map.
        The bin edges of "histUnused" are used for the axes, so both histograms are
        expected to share the same binning.
        """
        if not self.draw_quad_tree_content:
            return

        # seaborn is only needed for the colour map, so it is imported only when drawing
        import seaborn as sb

        input_file = ROOT.TFile(self.file_name_of_quad_tree_content)
        try:
            _, _, used_content = self._read_histogram(input_file, "histUsed")
            x_edges, y_edges, unused_content = self._read_histogram(input_file, "histUnused")
        finally:
            # all content has been copied into numpy arrays, so the file can be released
            input_file.Close()

        cmap = sb.cubehelix_palette(8, start=2, rot=0, dark=0, light=1, reverse=False, as_cmap=True)

        plt.gca().pcolorfast(x_edges, y_edges, (used_content + unused_content).T, cmap=cmap)

        plt.xticks(x_edges, self._edge_labels(x_edges))
        plt.yticks(y_edges, self._edge_labels(y_edges))

    def save_and_show_file(self):
        """
        Save the current plot as an svg file in the temporary directory (named after the current
        timestamp) and remember the file name (maybe a png would be better?)
        """
        fileName = tempfile.gettempdir() + "/" + datetime.now().isoformat() + '.svg'
        plt.savefig(fileName)
        self.file_names.append(fileName)

    def init_plotting(self):
        """
        Clear the current figure and set the plot ranges
        We need to implement axes labels later!
        """
        plt.clf()
        plt.xlim(self.range_x_min, self.range_x_max)
        plt.ylim(self.range_y_min, self.range_y_max)

    def event(self):
        """
        Draw everything
        """
        self.init_plotting()
        self.plot_quad_tree_content()
        self.save_and_show_file()

    def terminate(self):
        """Termination signal at the end of the event processing: put all written file names into the queue"""
        self.queue.put("quadTree", self.file_names)
124
125
class SegmentQuadTreePlotter(QuadTreePlotter):

    """
    Implementation of a quad tree plotter for SegmentQuadTrees
    """

    # Drawing switches: ``True and False`` evaluates to False; remove the ``and False``
    # to switch an option on.

    #: draw the intersection of the curves of the front and back hit of every axial segment
    draw_segment_intersection = True and False
    #: draw the curve of the front hit of every axial segment
    draw_segment = True and False
    #: draw the front/middle and middle/back intersections of every axial segment, connected by a line
    draw_segment_averaged = True and False
    #: draw the parameters of an origin circle fit to every axial segment
    draw_segment_fitted = True and False
    #: draw the seed parameters of the MC track candidates
    draw_mc_information = True and False
    #: draw the curves of the CDC hits belonging to MC track candidates
    draw_mc_hits = True and False
    #: draw the curves of all CDC hits that do not belong to any MC track candidate
    draw_hits = True and False

    #: use the theta range [-maximum_theta / 2, maximum_theta / 2] instead of [0, maximum_theta]
    theta_shifted = False
    #: width of the theta range
    maximum_theta = np.pi

    def calculateIntersectionInQuadTreePicture(self, first, second):
        """
        Calculate the point where the curves of the two given hits intersect

        params
        ------
        first: hit
        second: hit

        returns
        -------
        (theta, r) of the intersection; theta is raised by multiples of maximum_theta
        until it is not below the lower end of the theta range
        """
        positionFront = first.getRecoPos2D().conformalTransformed()
        positionBack = second.getRecoPos2D().conformalTransformed()

        theta_cut = np.arctan2((positionBack - positionFront).x(), (positionFront - positionBack).y())

        if self.theta_shifted:
            while theta_cut < - self.maximum_theta / 2:
                theta_cut += self.maximum_theta
        else:
            while theta_cut < 0:
                theta_cut += self.maximum_theta

        r_cut = positionFront.x() * np.cos(theta_cut) + positionFront.y() * np.sin(theta_cut)

        return theta_cut, r_cut

    def calculatePositionInQuadTreePicture(self, position):
        """
        Transform a given normal coordinate position to a legendre curve (conformal transformed):
        r(theta) = x * cos(theta) + y * sin(theta), sampled at 100 points across the x range

        params
        ------
        position: TrackFindingCDC.Vector2D
        """
        position = position.conformalTransformed()

        theta = np.linspace(self.range_x_min, self.range_x_max, 100)
        r = position.x() * np.cos(theta) + position.y() * np.sin(theta)

        return theta, r

    def forAllAxialSegments(self, f):
        """
        Call f for every segment in the "CDCSegment2DVector" store object whose stereo type is 0 (axial)

        params
        ------
        f: function taking one segment
        """
        items = Belle2.PyStoreObj("CDCSegment2DVector")
        wrapped_vector = items.obj()
        vector = wrapped_vector.get()

        for quad_tree_item in vector:
            if quad_tree_item.getStereoType() == 0:
                f(quad_tree_item)

    def convertToQuadTreePicture(self, phi, mag, charge):
        """
        Convert given track parameters into a point in the legendre space

        params
        ------
        phi: phi of the track
        mag: magnitude of pt
        charge: charge of the track
        """
        theta = phi + np.pi / 2
        # NOTE(review): hard-coded conversion, presumably a 1.5 T field times the speed of light
        # in matching units — confirm before reusing
        r = 1 / mag * 1.5 * 0.00299792458 * charge
        if self.theta_shifted:
            if theta > self.maximum_theta / 2 or theta < -self.maximum_theta / 2:
                theta = theta % self.maximum_theta - self.maximum_theta / 2
            else:
                r *= -1
        else:
            if theta > self.maximum_theta or theta < 0:
                theta = theta % self.maximum_theta
            else:
                r *= -1
        return theta, r

    def _draw_segment_intersections(self):
        """Draw the front/back intersection of every axial segment as a marker."""
        color_map = attributemaps.SegmentMCTrackIdColorMap()

        def draw(segment):
            theta, r = self.calculateIntersectionInQuadTreePicture(segment.front(), segment.back())
            plt.plot(theta, r, color=list(color_map(0, segment)), marker="o")

        self.forAllAxialSegments(draw)

    def _draw_segments(self):
        """Draw the curve of the front hit of every axial segment."""
        color_map = attributemaps.SegmentMCTrackIdColorMap()

        def draw(segment):
            theta, r = self.calculatePositionInQuadTreePicture(segment.front().getRecoPos2D())
            plt.plot(theta, r, color=list(color_map(0, segment)), marker="", ls="-")

        self.forAllAxialSegments(draw)

    def _draw_segments_averaged(self):
        """Draw the front/middle and middle/back intersections of every axial segment, connected by a line."""
        color_map = attributemaps.SegmentMCTrackIdColorMap()

        def draw(segment):
            middle_index = int(np.round(segment.size() / 2.0))
            middle_point = list(segment.items())[middle_index]
            theta_front, r_front = self.calculateIntersectionInQuadTreePicture(segment.front(), middle_point)
            theta_back, r_back = self.calculateIntersectionInQuadTreePicture(middle_point, segment.back())

            plt.plot([theta_front, theta_back], [r_front, r_back], color=list(color_map(0, segment)), marker="o", ls="-")

        self.forAllAxialSegments(draw)

    def _draw_segments_fitted(self):
        """Fit every axial segment with the origin circle fitter and draw the resulting track parameters."""
        color_map = attributemaps.SegmentMCTrackIdColorMap()
        fitter = Belle2.TrackFindingCDC.CDCRiemannFitter.getOriginCircleFitter()

        def draw(segment):
            trajectory = fitter.fit(segment)
            momentum = trajectory.getUnitMom2D(Belle2.TrackFindingCDC.Vector2D(0, 0)).scale(trajectory.getAbsMom2D())
            theta, r = self.convertToQuadTreePicture(momentum.phi(), momentum.norm(), trajectory.getChargeSign())
            plt.plot(theta, r, color=list(color_map(0, segment)), marker="o")

        self.forAllAxialSegments(draw)

    def _draw_hits(self):
        """Draw the curves of all CDC hits that are not part of any MC track candidate in black."""
        cdcHits = Belle2.PyStoreArray("CDCHits")
        storedWireHits = Belle2.PyStoreObj('CDCWireHitVector')
        wireHits = storedWireHits.obj().get()

        mc_track_cands = Belle2.PyStoreArray("MCTrackCands")
        # only the CDC hit ids index into "CDCHits"; ids of other detectors must be excluded
        mc_hit_ids = {hit_id for track in mc_track_cands for hit_id in track.getHitIDs(Belle2.Const.CDC)}

        for hit_index, cdcHit in enumerate(cdcHits):
            if hit_index in mc_hit_ids:
                continue
            wireHit = wireHits.at(bisect.bisect_left(wireHits, cdcHit))
            theta, r = self.calculatePositionInQuadTreePicture(wireHit.getRefPos2D())

            plt.plot(theta, r, marker="", color="black", ls="-", alpha=0.8)

    def _draw_mc_hits(self):
        """Draw the curves of the CDC hits of every MC track candidate, coloured by MC track id."""
        storedWireHits = Belle2.PyStoreObj('CDCWireHitVector')
        wireHits = storedWireHits.obj().get()

        colors = attributemaps.listColors
        mc_track_cands = Belle2.PyStoreArray("MCTrackCands")
        cdcHits = Belle2.PyStoreArray("CDCHits")

        for track in mc_track_cands:
            mcTrackID = track.getMcTrackId()

            for cdcHitID in track.getHitIDs(Belle2.Const.CDC):
                cdcHit = cdcHits[cdcHitID]
                wireHit = wireHits.at(bisect.bisect_left(wireHits, cdcHit))

                theta, r = self.calculatePositionInQuadTreePicture(wireHit.getRefPos2D())

                plt.plot(theta, r, marker="", color=colors[mcTrackID % len(colors)], ls="-", alpha=0.2)

    def _draw_mc_information(self):
        """Draw the seed parameters of every MC track candidate as a coloured marker with a black border."""
        colors = attributemaps.listColors
        mc_track_cands = Belle2.PyStoreArray("MCTrackCands")

        for track in mc_track_cands:
            momentum = track.getMomSeed()

            # HARDCODED!!! Temporary solution
            theta, r = self.convertToQuadTreePicture(momentum.Phi(), momentum.Mag(), track.getChargeSeed())
            mcTrackID = track.getMcTrackId()

            plt.plot(theta, r, marker="o", color="black", ms=10)
            plt.plot(theta, r, marker="o", color=colors[mcTrackID % len(colors)], ms=5)

    def event(self):
        """
        Draw everything according to the given options

        Attributes
        ----------
        draw_segment_intersection
        draw_segment
        draw_segment_averaged
        draw_segment_fitted
        draw_hits
        draw_mc_hits
        draw_mc_information
        """
        if self.theta_shifted:
            self.range_x_min = -self.maximum_theta / 2
            self.range_x_max = self.maximum_theta / 2
        else:
            self.range_x_min = 0
            self.range_x_max = self.maximum_theta

        self.range_y_min = -0.08
        self.range_y_max = 0.08

        self.init_plotting()

        if self.draw_segment_intersection:
            self._draw_segment_intersections()

        if self.draw_segment:
            self._draw_segments()

        if self.draw_segment_averaged:
            self._draw_segments_averaged()

        if self.draw_segment_fitted:
            self._draw_segments_fitted()

        if self.draw_hits:
            self._draw_hits()

        if self.draw_mc_hits:
            self._draw_mc_hits()

        if self.draw_mc_information:
            self._draw_mc_information()

        self.save_and_show_file()
353
354
class StereoQuadTreePlotter(QuadTreePlotter):

    """
    Implementation of a quad tree plotter for StereoHitAssignment
    """

    #: draw both right/left hypotheses of the hits of every MC track candidate (using its seed trajectory)
    draw_mc_hits = False
    #: draw the reconstructed hits of every MC track candidate as points
    draw_mc_tracks = False
    #: draw the hits of all found tracks
    draw_track_hits = False
    #: draw all wire hits reconstructed with the start trajectory of the last found track
    draw_last_track = True
    #: for the last track, skip hits whose right/left hypothesis disagrees with the MC truth
    #: or that lie outside the cell z bounds
    delete_bad_hits = False

    def create_trajectory_from_track(self, track):
        """
        Convert a genfit::TrackCand into a TrackFindingCDC.CDCTrajectory3D using its seed
        position, momentum and charge

        params
        ------
        track: genfit::TrackCand
        """
        Vector3D = Belle2.TrackFindingCDC.Vector3D
        Trajectory3D = Belle2.TrackFindingCDC.CDCTrajectory3D

        position = Vector3D(track.getPosSeed())
        momentum = Vector3D(track.getMomSeed())
        charge = track.getChargeSeed()

        return Trajectory3D(position, momentum, charge)

    def create_reco_hit3D(self, cdcHit, trajectory3D, rlInfo):
        """
        Use a cdc hit and a trajectory to reconstruct a CDCRecoHit3D

        params
        ------
        cdcHit: CDCHit
        trajectory3D: TrackFindingCDC.CDCTrajectory3D
        rlInfo: RightLeftInfo ( = short)

        returns
        -------
        the reconstructed CDCRecoHit3D, or None if the wire hit has stereo type 0 (axial)
        """
        storedWireHits = Belle2.PyStoreObj('CDCWireHitVector')
        wireHits = storedWireHits.obj().get()

        wireHit = wireHits.at(bisect.bisect_left(wireHits, cdcHit))
        rightLeftWireHit = Belle2.TrackFindingCDC.CDCRLWireHit(wireHit, rlInfo)
        if rightLeftWireHit.getStereoType() == 0:
            return None
        return Belle2.TrackFindingCDC.CDCRecoHit3D.reconstruct(rightLeftWireHit, trajectory3D.getTrajectory2D())

    def get_plottable_line(self, recoHit3D):
        """
        Mimic the task of the StereoQuadTree by showing the line of quadtree nodes
        a hit belongs to: tan(lambda) = (z - z0) / s evaluated at both ends of the z0 (y) range

        returns
        -------
        (tan_lambda values, z0 values), two points each
        """
        z0 = [self.range_y_min, self.range_y_max]
        arr_l = np.array((np.array(recoHit3D.getRecoPos3D().z()) - z0) / recoHit3D.getArcLength2D())
        return arr_l, z0

    def plot_hit_line(self, recoHit3D, color):
        """
        Draw the line of one recoHit3D; None and axial hits (stereo type 0) are ignored
        """
        if recoHit3D:
            if recoHit3D.getStereoType() == 0:
                return

            arr_l, z0 = self.get_plottable_line(recoHit3D)
            plt.plot(arr_l, z0, marker="", ls="-", alpha=0.4, color=color)

    def event(self):
        """
        Draw the hit content according to the attributes

        Attributes
        ----------
        draw_mc_hits
        draw_mc_tracks
        draw_track_hits
        draw_last_track
        delete_bad_hits
        """

        self.draw_quad_tree_content = True

        self.range_x_min = -2 - np.sqrt(3)
        self.range_x_max = 2 + np.sqrt(3)

        self.range_y_min = -100
        self.range_y_max = -self.range_y_min

        self.init_plotting()
        self.plot_quad_tree_content()

        colors = attributemaps.listColors
        cdcHits = Belle2.PyStoreArray("CDCHits")

        items = Belle2.PyStoreObj("CDCTrackVector")
        wrapped_vector = items.obj()
        track_vector = wrapped_vector.get()

        mcHitLookUp = Belle2.TrackFindingCDC.CDCMCHitLookUp.getInstance()
        mcHitLookUp.fill()

        storedWireHits = Belle2.PyStoreObj('CDCWireHitVector')
        wireHits = storedWireHits.obj().get()

        if self.draw_mc_hits:
            mc_track_cands = Belle2.PyStoreArray("MCTrackCands")

            for track in mc_track_cands:
                mcTrackID = track.getMcTrackId()
                trajectory = self.create_trajectory_from_track(track)
                color = colors[mcTrackID % len(colors)]

                for cdcHitID in track.getHitIDs(Belle2.Const.CDC):
                    cdcHit = cdcHits[cdcHitID]

                    # draw both the left (-1) and the right (1) hypothesis
                    for rlInfo in (-1, 1):
                        self.plot_hit_line(self.create_reco_hit3D(cdcHit, trajectory, rlInfo), color=color)

        if self.draw_mc_tracks:
            mc_track_cands = Belle2.PyStoreArray("MCTrackCands")

            for track in mc_track_cands:
                mcTrackID = track.getMcTrackId()
                trajectory = self.create_trajectory_from_track(track)
                z0 = trajectory.getTrajectorySZ().getZ0()

                for cdcHitID in track.getHitIDs(Belle2.Const.CDC):
                    cdcHit = cdcHits[cdcHitID]
                    recoHit3D = self.create_reco_hit3D(cdcHit, trajectory, mcHitLookUp.getRLInfo(cdcHit))

                    if recoHit3D:
                        arr_l = (recoHit3D.getRecoPos3D().z() - z0) / recoHit3D.getArcLength2D()
                        plt.plot(arr_l, z0, marker="o", color=colors[mcTrackID % len(colors)], ls="", alpha=0.2)

        if self.draw_track_hits:
            for track_index, track in enumerate(track_vector):
                for recoHit3D in track.items():
                    self.plot_hit_line(recoHit3D, color=colors[track_index % len(colors)])

        if self.draw_last_track and len(track_vector) != 0:

            last_track = track_vector[-1]
            trajectory = last_track.getStartTrajectory3D().getTrajectory2D()
            # materialise the track hits once instead of once per reconstructed hit
            last_track_hits = list(last_track.items())

            for wireHit in wireHits:
                for rlInfo in (-1, 1):
                    recoHit3D = Belle2.TrackFindingCDC.CDCRecoHit3D.reconstruct(wireHit, rlInfo, trajectory)

                    # The right/left information lives on the reconstructed hit, not on the plain
                    # wire hit from the CDCWireHitVector.
                    if (self.delete_bad_hits and
                            (recoHit3D.getRLInfo() != mcHitLookUp.getRLInfo(recoHit3D.getWireHit().getHit()) or
                             not recoHit3D.isInCellZBounds())):
                        continue

                    if recoHit3D in last_track_hits:
                        color = colors[len(track_vector) % len(colors)]
                    elif rlInfo == 1:
                        color = "black"
                    else:
                        color = "gray"
                    self.plot_hit_line(recoHit3D, color)

        plt.xlabel(r"$\tan \ \lambda$")
        plt.ylabel(r"$z_0$")
        self.save_and_show_file()
535
536
class QueueStereoQuadTreePlotter(StereoQuadTreePlotter):

    """
    A wrapper around the svg drawer in the tracking package that
    writes its output files as a list to the queue
    """

    def __init__(self, queue, label, *args, **kwargs):
        """ The same as the base class, except:

        Arguments
        ---------

        queue: The queue to write to
        label: The key name in the queue
        """
        # The base class constructor requires the queue as its first argument; forward it
        # unless the caller already passed it explicitly in args/kwargs.
        if not args and "queue" not in kwargs:
            args = (queue,)
        StereoQuadTreePlotter.__init__(self, *args, **kwargs)

        #: queue the list of written file names is put into in terminate()
        self.queue = queue
        #: key under which the file list is put into the queue
        self.label = label
        #: names of all svg files written by save_and_show_file (one per call)
        self.file_list = []

    def terminate(self):
        """ Overwrite the terminate to put the list to the queue"""
        # NOTE: the base implementation additionally puts its own file_names list under "quadTree";
        # it stays empty here because save_and_show_file records into file_list instead.
        StereoQuadTreePlotter.terminate(self)
        self.queue.put(self.label, self.file_list)

    def save_and_show_file(self):
        """ Overwrite the function to listen for every new filename """
        fileName = tempfile.gettempdir() + "/" + datetime.now().isoformat() + '.svg'
        plt.savefig(fileName)
        self.file_list.append(fileName)
576
577# @endcond
A (simplified) Python wrapper for StoreArray.
A (simplified) Python wrapper for StoreObjPtr.
Definition PyStoreObj.h:67
Interface class to the Monte Carlo information for individual hits.
Class representing an oriented hit wire including a hypothesis of whether the causing track passes left ...
Class representing a three dimensional reconstructed hit.
static CDCRecoHit3D reconstruct(const CDCRecoHit2D &recoHit2D, const CDCTrajectory2D &trajectory2D)
Reconstructs the three dimensional hit from the two dimensional and the two dimensional trajectory.
static const CDCRiemannFitter & getOriginCircleFitter()
Static getter for an origin circle fitter.
Particle full three dimensional trajectory.
A two dimensional vector which is equipped with functions for correct handling of orientation relate...
Definition Vector2D.h:32
A three dimensional vector.
Definition Vector3D.h:33
STL class.