Belle II Software light-2406-ragdoll
best_candidate_selection.py
1
8import basf2
9import math
10import random
11from collections import defaultdict
12import modularAnalysis as ma
13from ROOT import Belle2
14
15
class Generator(basf2.Module):
    """Fill MCParticles with electrons whose momenta only serve as sort keys.

    Per event, five random integer momentum triples are drawn and every
    triple is used for two electrons, so identical values are guaranteed
    (needed to exercise multiRank). A final electron gets NaN for all three
    momentum components."""

    def initialize(self):
        """Register the MCParticles array so that event() can append to it"""
        # MC particle array the generated electrons are written into
        self.mcp = Belle2.PyStoreArray("MCParticles")
        self.mcp.registerInDataStore()

    def event(self):
        """Create ten electrons in overlapping pairs plus one NaN-momentum electron."""
        print("New event:")

        # Draw all random triples up front; the random calls happen in the same
        # order (px, py, pz for each of the five triples) as when drawing lazily.
        momenta = []
        for _ in range(5):
            triple = tuple(random.randrange(1, 5) for _ in range(3))
            momenta += [triple, triple]
        # the NaN candidate always comes last
        momenta.append((math.nan, math.nan, math.nan))

        for px, py, pz in momenta:
            electron = self.mcp.appendNew()
            electron.setPDG(11)
            electron.setMassFromPDG()
            electron.setMomentum(px, py, pz)
42
43
class RankChecker(basf2.Module):
    """Check if the ranks are actually what we want.

    The ranks stored as extra info on the 'e-' particles by the
    rankByHighest/rankByLowest calls are compared against ranks recomputed
    here in python from the particles' px and py values."""

    def initialize(self):
        """Create particle list object"""
        # particle list object: the 'e-' list whose particles receive the
        # M_rank, px_high_*, and py_low_* extra infos from the ranking calls.
        # event() iterates over it, so it must be attached here.
        self.plist = Belle2.PyStoreObj("e-")

    @staticmethod
    def _value_ranks(values, highest_first):
        """Map every distinct value to its 1-based rank (multiRank semantics).

        Equal values share one rank. NaN must always end up last, so while
        sorting it is replaced by -inf when the highest value ranks first and
        by +inf when the lowest value ranks first.
        """
        nan_substitute = -math.inf if highest_first else math.inf
        ordered = sorted(set(values), reverse=highest_first,
                         key=lambda v: nan_substitute if math.isnan(v) else v)
        return {v: i for i, v in enumerate(ordered, 1)}

    def event(self):
        """And check all the ranks"""
        # make a list of all the values and a dict of all the extra infos
        px = []
        py = []
        einfo = defaultdict(list)
        for particle in self.plist:
            px.append(particle.getPx())
            py.append(particle.getPy())
            # get all names of existing extra infos but convert to a native list of python strings to avoid
            # possible crashes if the std::vector returned by the particle goes out of scope
            names = [str(n) for n in particle.getExtraInfoNames()]
            for n in names:
                einfo[n].append(particle.getExtraInfo(n))

        # check the default name is set correctly if we don't specify an output variable
        print(list(einfo.keys()))
        assert 'M_rank' in einfo.keys(), "Default name is not as expected"

        # Now determine the correct ranks if multiple values are allowed:
        # in theory we just need to loop over sorted(set(values)), but NaN
        # has to go to the end of the list regardless of sort order
        # (handled by _value_ranks).
        px_value_ranks = self._value_ranks(px, highest_first=True)
        py_value_ranks = self._value_ranks(py, highest_first=False)

        # zip() stops at the shorter sequence, so a rank that is missing on
        # some (or all) particles would make the comparisons below pass
        # vacuously. Require exactly one rank per particle.
        for name in ("px_high_multi", "py_low_multi"):
            assert len(einfo[name]) == len(px), f"extra info {name} is not set for every particle"

        # Ok, test if the rank from extra info actually corresponds to what we
        # want
        for v, r in zip(px, einfo["px_high_multi"]):
            print(f"Value: {v}, rank: {r}, should be: {px_value_ranks[v]}")
            assert r == px_value_ranks[v], "Rank is not correct"

        for v, r in zip(py, einfo["py_low_multi"]):
            print(f"Value: {v}, rank: {r}, should be: {py_value_ranks[v]}")
            assert r == py_value_ranks[v], "Rank is not correct"

        # so we checked multiRank=True. But for multiRank=False this is more
        # complicated because ranking a second time will destroy the order
        # of the previous sorts. But we can at least check if all the ranks
        # form a range from 1..n if we sort them
        simple_range = list(range(len(px)))
        px_single_ranks = sorted(int(r) - 1 for r in einfo["px_high_single"])
        assert simple_range == px_single_ranks, "sorted ranks don't form a range from 1..n"
        # but the second two rankings are on the same variable in the same
        # order so they need to keep the order stable. so for py_low_single the
        # ranks need to be the range without sorting
        py_single_ranks = [int(r) - 1 for r in einfo["py_low_single"]]
        assert simple_range == py_single_ranks, "ranks don't form a range from 1..n"
104
105
class NumBestChecker(basf2.Module):
    """Verify that ranking with 'numBest' leaves no more candidates than allowed."""

    def __init__(self):
        """Start with unset configuration; param() fills it in."""
        super().__init__()
        # Number of candidates to keep, supplied via param(numBest=...)
        self.num_best = None
        # True when the checked list was ranked with allowMultiRank=True
        self.allow_multirank = False

    def param(self, kwargs):
        """Take the test-case keys out of kwargs and pass the rest on to basf2.Module.param."""
        self.num_best = kwargs.pop('numBest')
        self.allow_multirank = kwargs.pop('allowMultiRank', False)
        super().param(kwargs)

    def initialize(self):
        """Attach to 'e-:numBestMultiRank' or 'e-:numBest' depending on allowMultiRank."""
        list_name = 'e-:numBestMultiRank' if self.allow_multirank else 'e-:numBest'
        # particle list object
        self.plist = Belle2.PyStoreObj(list_name)

    def event(self):
        """Check if 'e-:numBest' and 'e-:numBestMultiRank' have the expected size"""
        size = self.plist.getListSize()

        if not self.allow_multirank:
            # The test fails if size > numBest_value as this is passed as a parameter into the module
            assert size <= self.num_best, f"numBest test failed: there are too many Particles ({size}) in the list!"
            return

        # With multiRank, equal px values share one rank, so count how many of
        # the surviving particles have a px rank within num_best (NaN last).
        momenta_x = [candidate.getPx() for candidate in self.plist]

        def nan_last(value):
            return -math.inf if math.isnan(value) else value

        ordered = sorted(set(momenta_x), reverse=True, key=nan_last)
        rank_of = {value: position for position, value in enumerate(ordered, 1)}
        kept = [value for value in momenta_x if rank_of[value] <= self.num_best]

        assert size <= len(kept), ("numBest test with multirank failed: "
                                   f"there should be {len(kept)} Particles in the list "
                                   f"instead of {size}!")
149
150
# fixed seed so the generated momenta are reproducible
random.seed(5)

# ten events, each filled with electrons by the Generator module
path = basf2.Path()
path.add_module("EventInfoSetter", evtNumList=10)
path.add_module(Generator())

# load these electrons into the default 'e-' list
ma.fillParticleListFromMC("e-", "", path=path)

# rank by M without an output variable so the default extra-info name is used
ma.rankByHighest("e-", "M", path=path)

# (ranking function, variable, allowMultiRank, output variable), applied in this order
rankings = (
    (ma.rankByHighest, "px", False, "px_high_single"),
    (ma.rankByHighest, "px", True, "px_high_multi"),
    (ma.rankByLowest, "py", False, "py_low_single"),
    (ma.rankByLowest, "py", True, "py_low_multi"),
)
for rank_function, variable, multi_rank, output_name in rankings:
    rank_function("e-", variable, allowMultiRank=multi_rank, outputVariable=output_name, path=path)

# verify all the ranks computed above
path.add_module(RankChecker())

# numBest = 2 is used for the ranking and again by NumBestChecker for its assert
numBest_value = 2

# one fresh list without and one with multiRank: fill, rank keeping only the
# best candidates, then check the surviving list size
numbest_setups = (
    ("e-:numBest", "p", {}),
    ("e-:numBestMultiRank", "px", {"allowMultiRank": True}),
)
for list_name, variable, multi_rank_options in numbest_setups:
    ma.fillParticleListFromMC(list_name, "", path=path)
    ma.rankByHighest(list_name, variable, numBest=numBest_value, path=path, **multi_rank_options)
    path.add_module(NumBestChecker(), numBest=numBest_value, **multi_rank_options)

basf2.process(path)
A (simplified) python wrapper for StoreArray.
Definition: PyStoreArray.h:72
A (simplified) python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
num_best
Number of candidates to keep (must be given as parameter, otherwise assert will fail).