Belle II Software  release-08-01-10
best_candidate_selection.py
1 
8 import basf2
9 import math
10 import random
11 from collections import defaultdict
12 import modularAnalysis as ma
13 from ROOT import Belle2
14 
15 
class Generator(basf2.Module):
    """Fill MCParticles with electrons whose momenta will be ranked later.

    Every event receives five momentum vectors with components drawn from
    ``random.randrange(1, 5)``; each vector is used for two electrons so the
    rankings have to deal with ties. A final electron with all momentum
    components set to nan is appended last.
    """

    def initialize(self):
        """Register the MCParticles array in the data store."""
        #: MCParticles array the generated electrons are appended to
        self.mcp = Belle2.PyStoreArray("MCParticles")
        self.mcp.registerInDataStore()

    def _add_electron(self, px, py, pz):
        """Append one electron (PDG 11, mass from PDG) with the given momentum."""
        electron = self.mcp.appendNew()
        electron.setPDG(11)
        electron.setMassFromPDG()
        electron.setMomentum(px, py, pz)

    def event(self):
        """Generate the electrons, ensuring some overlap in the momenta."""
        print("New event:")
        for _ in range(5):
            # three draws in x, y, z order, exactly one vector per pair
            momentum = tuple(random.randrange(1, 5) for _ in range(3))
            for _ in range(2):
                self._add_electron(*momentum)

        self._add_electron(math.nan, math.nan, math.nan)
42 
43 
class RankChecker(basf2.Module):
    """Check if the ranks are actually what we want.

    The ``e-`` list is ranked upstream by ``M`` (default output name), by
    ``px`` descending and by ``py`` ascending, each of the latter with and
    without ``allowMultiRank``. This module recomputes the expected ranks from
    the particle momenta and compares them with the stored extra info.
    """

    #: extra info variables written by the px/py rankings; every particle must carry all of them
    rank_variables = ("px_high_single", "px_high_multi", "py_low_single", "py_low_multi")

    def initialize(self):
        """Create particle list object"""
        #: the ranked electron list under test
        self.plist = Belle2.PyStoreObj("e-")

    @staticmethod
    def _value_ranks(values, descending):
        """Map every distinct value to its expected multi-rank (starting at 1).

        NaN values always go to the end of the ranking, whatever the sort
        direction, so they are sorted with a key of -inf (descending) or
        +inf (ascending).
        """
        nan_key = -math.inf if descending else math.inf
        ordered = sorted(set(values), reverse=descending,
                         key=lambda v: nan_key if math.isnan(v) else v)
        return {v: i for i, v in enumerate(ordered, 1)}

    def event(self):
        """And check all the ranks"""
        # make a list of all the values and a dict of all the extra infos
        px = []
        py = []
        einfo = defaultdict(list)
        for particle in self.plist:
            px.append(particle.getPx())
            py.append(particle.getPy())
            # get all names of existing extra infos but convert to a native list of python strings to avoid
            # possible crashes if the std::vector returned by the particle goes out of scope
            names = [str(n) for n in particle.getExtraInfoNames()]
            for n in names:
                einfo[n].append(particle.getExtraInfo(n))

        # check the default name is set correctly if we don't specify an output variable
        print(list(einfo.keys()))
        assert 'M_rank' in einfo.keys(), "Default name is not as expected"

        # einfo is a defaultdict: a missing (or partially set) rank variable
        # would make the zip() loops below check nothing or compare misaligned
        # values, so require every rank on every particle first
        for name in self.rank_variables:
            assert len(einfo[name]) == len(px), \
                f"Extra info {name} is set for {len(einfo[name])} of {len(px)} particles"

        # Now determine the correct ranks if multiple values are allowed:
        # value -> rank for all unique values, nans always last
        px_value_ranks = self._value_ranks(px, descending=True)
        py_value_ranks = self._value_ranks(py, descending=False)

        # Ok, test if the rank from extra info actually corresponds to what we want
        for v, r in zip(px, einfo["px_high_multi"]):
            print(f"Value: {v}, rank: {r}, should be: {px_value_ranks[v]}")
            assert r == px_value_ranks[v], "Rank is not correct"

        for v, r in zip(py, einfo["py_low_multi"]):
            print(f"Value: {v}, rank: {r}, should be: {py_value_ranks[v]}")
            assert r == py_value_ranks[v], "Rank is not correct"

        # so we checked multiRank=True. But for multiRank=False this is more
        # complicated because ranking a second time will destroy the order
        # of the previous sorts. But we can at least check if all the ranks
        # form a range from 1..n if we sort them (compared here shifted to 0..n-1)
        simple_range = list(range(len(px)))
        px_single_ranks = sorted(int(r) - 1 for r in einfo["px_high_single"])
        assert simple_range == px_single_ranks, "sorted ranks don't form a range from 1..n"
        # but the second two rankings are on the same variable in the same
        # order so they need to keep the order stable. so for py_low_single the
        # ranks need to be the range without sorting
        py_single_ranks = [int(r) - 1 for r in einfo["py_low_single"]]
        assert simple_range == py_single_ranks, "ranks don't form a range from 1..n"
104 
105 
# a fixed seed keeps the generated momenta reproducible between runs
random.seed(5)

# ten events, each filled with the generated electrons
path = basf2.Path()
path.add_module("EventInfoSetter", evtNumList=10)
path.add_module(Generator())

# load these electrons and rank them: once by mass with the default output
# variable name, then by px (highest first) and py (lowest first), each
# without and with allowMultiRank
ma.fillParticleListFromMC("e-", "", path=path)
ma.rankByHighest("e-", "M", path=path)
for rank_function, rank_variable, multi_rank, output_name in (
        (ma.rankByHighest, "px", False, "px_high_single"),
        (ma.rankByHighest, "px", True, "px_high_multi"),
        (ma.rankByLowest, "py", False, "py_low_single"),
        (ma.rankByLowest, "py", True, "py_low_multi"),
):
    rank_function("e-", rank_variable, allowMultiRank=multi_rank, outputVariable=output_name, path=path)

# verify the ranks written above
path.add_module(RankChecker())

# number of best candidates kept by the numBest tests below (also used by their checks)
numBest_value = 2
126 
127 
class NumBestChecker(basf2.Module):
    """Check if 'numBest' works correctly.

    Parameters (passed via ``path.add_module``):
      numBest: number of candidates the selection was asked to keep (required)
      allowMultiRank: whether the selection ran with allowMultiRank=True
        (default False); selects which list is checked
    """

    def __init__(self):
        """Initializing the parameters."""
        super().__init__()
        #: Number of candidates to keep (must be given as parameter, otherwise assert will fail).
        self.num_best = None
        #: True if the checked selection used allowMultiRank=True
        self.allow_multirank = False

    def param(self, kwargs):
        """Checking for module parameters to distinguish between the different test cases."""
        self.num_best = kwargs.pop('numBest')
        self.allow_multirank = kwargs.pop('allowMultiRank', False)
        super().param(kwargs)

    def initialize(self):
        """Create particle list 'e-:numBest(MultiRank)' object, depending on parameter choice."""
        # param() is only called when keyword parameters are given; without
        # them num_best stays None and event() would die with a TypeError
        assert self.num_best is not None, "NumBestChecker needs the 'numBest' parameter"
        list_name = 'e-:numBestMultiRank' if self.allow_multirank else 'e-:numBest'
        #: the particle list whose size is checked
        self.plist = Belle2.PyStoreObj(list_name)

    def event(self):
        """Check if 'e-:numBest' and 'e-:numBestMultiRank' have the expected size"""
        size = self.plist.getListSize()
        if self.allow_multirank:
            px = [particle.getPx() for particle in self.plist]
            # recompute multi-ranks on the kept particles (highest px first, nan last)
            px_value_ranks = {v: i for i, v in enumerate(sorted(set(px), reverse=True,
                                                                key=lambda v: -math.inf if math.isnan(v) else v),
                                                         1)}
            remaining_particles = [v for v in px if px_value_ranks[v] <= self.num_best]
            # remaining_particles is a subset of the list, so this holds only
            # if every kept particle has a rank <= numBest
            assert size <= len(remaining_particles), "numBest test with multirank failed: " \
                f"there should be {len(remaining_particles)} Particles in the list " \
                f"instead of {size}!"
        else:
            # The test fails if size > numBest as this is passed as a parameter into the module
            assert size <= self.num_best, f"numBest test failed: there are too many Particles ({size}) in the list!"
171 
172 
# Two numBest tests, each filling its own fresh list, keeping only the best
# candidates and checking the result: first ranked by p, then ranked by px
# with allowMultiRank. The modules are added in the order fill, rank, check.
for list_name, rank_variable, extra_options in (
        ("e-:numBest", "p", {}),
        ("e-:numBestMultiRank", "px", {"allowMultiRank": True}),
):
    ma.fillParticleListFromMC(list_name, "", path=path)
    ma.rankByHighest(list_name, rank_variable, numBest=numBest_value, **extra_options, path=path)
    path.add_module(NumBestChecker(), numBest=numBest_value, **extra_options)

basf2.process(path)
A (simplified) Python wrapper for StoreArray.
Definition: PyStoreArray.h:72
A (simplified) Python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
num_best
Number of candidates to keep (must be given as parameter, otherwise assert will fail).