Belle II Software  release-06-01-15
retention.py
1 # !/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 """\
13 Provides class for tracking retention rate of each cut in a skim.
14 """
15 
16 import os
17 import matplotlib.pyplot as plt
18 
19 from ROOT import Belle2
20 
21 import basf2 as b2
22 
23 
class RetentionCheck(b2.Module):
    """Check the retention rate and the number of candidates for a given set of particle lists.

    The module stores its results in the static variable "summary".

    To monitor the effect of every module of an initial path, this module should be added after
    each module of the path. A function was written (`skim.utils.retention.pathWithRetentionCheck`) to do it:

    >>> path = pathWithRetentionCheck(particle_lists, path)

    After the path processing, the result of the RetentionCheck can be printed with

    >>> RetentionCheck.print_results()

    or plotted with (check the corresponding documentation)

    >>> RetentionCheck.plot_retention(...)

    and the summary dictionary can be accessed through

    >>> RetentionCheck.summary

    Authors:

        Cyrille Praz, Slavomira Stefkova

    Parameters:

        module_name (str): name of the module after which the retention rate is measured
        module_number (int): index of the module after which the retention rate is measured
        particle_lists (list(str)): list of particle list names which will be tracked by the module
    """

    # Static dictionary containing the results, shared by all instances. Layout:
    # summary["<NNNN>. <module name>"][<particle list>] = {"retention_rate", "#candidates",
    #                                                      "#evts_with_candidates", "total_#events"}
    summary = {}
    # If the -o option is provided to basf2, this variable stores the output name used for the plotting.
    output_override = None

    def __init__(self, module_name='', module_number=0, particle_lists=None):
        """Register this instance in ``summary`` and reset its per-list counters.

        Parameters:

            module_name (str): name of the module after which the retention rate is measured
            module_number (int): index of the module after which the retention rate is measured
            particle_lists (list(str)): particle list names to track (``None`` means no list)
        """
        super().__init__()

        self.module_name = str(module_name)
        self.module_number = int(module_number)

        # Keep a private copy: the counters below are keyed on the list content at construction
        # time, so a later mutation of the caller's list must not add names without a counter.
        self.particle_lists = [] if particle_lists is None else list(particle_lists)

        # Total number of candidates seen per particle list.
        self.candidate_count = {pl: 0 for pl in self.particle_lists}
        # Number of events with at least one candidate per particle list.
        self.event_with_candidate_count = {pl: 0 for pl in self.particle_lists}

        # Key of this instance in the shared summary, e.g. "0003. <module name>".
        self._key = "{:04}. {}".format(self.module_number, self.module_name)
        type(self).summary[self._key] = {}

        # Query the output override only once; it is stored on the class, not on the instance.
        if type(self).output_override is None:
            type(self).output_override = Belle2.Environment.Instance().getOutputFileOverride()

    def event(self):
        """Update the candidate counters of every tracked particle list for the current event.

        Particle lists that are not valid in the current event are skipped.
        """
        for particle_list in self.particle_lists:

            pl = Belle2.PyStoreObj(Belle2.ParticleList.Class(), particle_list)
            if not pl.isValid():
                continue

            n_candidates = pl.getListSize()
            self.candidate_count[particle_list] += n_candidates
            if n_candidates != 0:
                self.event_with_candidate_count[particle_list] += 1

    def terminate(self):
        """Compute the retention rate of every tracked particle list and store it in ``summary``.

        The retention rate is the number of events with at least one candidate divided by the
        number of events reported by the Environment. If that number is not positive, a warning
        is emitted and the retention rate is set to zero.
        """
        N = Belle2.Environment.Instance().getNumberOfEvents()

        # Warn once rather than once per particle list.
        if N <= 0 and self.particle_lists:
            b2.B2WARNING("Belle2.Environment.Instance().getNumberOfEvents() gives 0 or less.")

        for particle_list in self.particle_lists:

            n_events_with_candidates = self.event_with_candidate_count[particle_list]
            retention_rate = float(n_events_with_candidates) / N if N > 0 else 0.0

            type(self).summary[self._key][particle_list] = {
                "retention_rate": retention_rate,
                "#candidates": self.candidate_count[particle_list],
                "#evts_with_candidates": n_events_with_candidates,
                "total_#events": N,
            }

    @classmethod
    def print_results(cls):
        """Print one summary table per particle list, should be called after the path processing.

        For a given particle list, modules are listed starting from the first one with a
        non-zero retention rate; the modules before it are omitted.
        """
        table_headline = "{:<100}|{:>9}|{:>12}|{:>22}|{:>12}|\n"
        table_line = "{:<100}|{:>9.3f}|{:>12}|{:>22}|{:>12}|\n"
        separator = "=" * 160 + "\n"

        summary_tables = {}  # one summary table per particle list
        at_least_one_entry = {}  # True once a non-zero retention was seen for a given particle list

        for module, module_results in cls.summary.items():

            # The first column is 100 characters wide: shorten module names which are too long.
            label = module[:96] + "..." if len(module) > 100 else module

            for particle_list, list_results in module_results.items():

                if particle_list not in summary_tables:
                    at_least_one_entry[particle_list] = False
                    summary_tables[particle_list] = table_headline.format(
                        "Module", "Retention", "# Candidates", "# Evts with candidates", "Total # evts")
                    summary_tables[particle_list] += separator

                # Deliberately not in an "else" branch: the first module in which a list shows up
                # must get its row too, as it does in plot_retention.
                if list_results["retention_rate"] > 0 or at_least_one_entry[particle_list]:
                    at_least_one_entry[particle_list] = True
                    # Name the columns explicitly instead of relying on the dictionary order.
                    summary_tables[particle_list] += table_line.format(
                        label,
                        list_results["retention_rate"],
                        list_results["#candidates"],
                        list_results["#evts_with_candidates"],
                        list_results["total_#events"])

        for particle_list, summary_table in summary_tables.items():
            b2.B2INFO("\n" + "=" * 160 + "\n" +
                      "Results of the modules RetentionCheck for the list " + particle_list + ".\n" +
                      "=" * 160 + "\n" +
                      "Note: the module RetentionCheck is defined in skim/scripts/skim/utils/retention.py\n" +
                      "=" * 160 + "\n" +
                      summary_table +
                      "=" * 160 + "\n" +
                      "End of the results of the modules RetentionCheck for the list " + particle_list + ".\n" +
                      "=" * 160 + "\n"
                      )

    @classmethod
    def plot_retention(cls, particle_list, plot_title="", save_as=None, module_name_max_length=80):
        """ Plot the result of the RetentionCheck for a given particle list.

        Example of use (to be put after process(path)):

        >>> RetentionCheck.plot_retention('B+:semileptonic','skim:feiSLBplus','retention_plots/plot.pdf')

        Nothing is plotted (a warning is emitted instead) if the particle list is missing from the
        results of one of the modules, or if its retention rate is zero for all the modules.
        The figure is written to disk only if save_as or the -o argument of basf2 is provided.

        Parameters:

            particle_list (str): particle list name
            plot_title (str): plot title (overwritten by the -o argument in basf2)
            save_as (str): output filename (overwritten by the -o argument in basf2)
            module_name_max_length (int): if the module name length is higher than this value, do not display the full name
        """
        module_name = []
        retention = []

        at_least_one_entry = False
        for module, results in cls.summary.items():

            if particle_list not in results:
                b2.B2WARNING(particle_list + " is not present in the results of the RetentionCheck for the module {}."
                             .format(module))
                return

            retention_rate = results[particle_list]['retention_rate']

            # Start at the first module with a non-zero retention and keep every module after it.
            if retention_rate > 0 or at_least_one_entry:
                at_least_one_entry = True
                if len(module) > module_name_max_length and module_name_max_length > 3:  # module name too long
                    module = module[:module_name_max_length - 3] + "..."
                module_name.append(module)
                retention.append(100 * retention_rate)

        if not at_least_one_entry:
            b2.B2WARNING(particle_list + " seems to have a zero retention rate when created (if created).")
            return

        plt.figure()
        bars = plt.barh(module_name, retention, label=particle_list, color=(0.67, 0.15, 0.31, 0.6))

        # Write the retention value (in percent) next to each bar.
        for bar in bars:
            yval = bar.get_width()
            plt.text(0.5, bar.get_y() + bar.get_height() / 2.0 + 0.1, str(round(yval, 3)))

        plt.gca().invert_yaxis()
        plt.xticks(rotation=45)
        plt.xlim(0, 100)
        plt.axvline(x=10.0, linewidth=1, linestyle="--", color='k', alpha=0.5)
        plt.xlabel('Retention Rate [%]')
        plt.legend(loc='lower right')

        if save_as or cls.output_override:
            if cls.output_override:
                # Drop the extension(s) of the file name only: splitting the whole path on "."
                # would truncate it at the first dot of a directory (e.g. "./out.root" -> "").
                directory, filename = os.path.split(cls.output_override)
                stem = filename.split(".")[0] or filename
                plot_title = os.path.join(directory, stem)
                save_as = plot_title + '.pdf'
            output_directory = os.path.dirname(save_as)
            if output_directory:
                os.makedirs(output_directory, exist_ok=True)
            plt.title(plot_title)
            plt.savefig(save_as, bbox_inches="tight")
            # abspath is also correct when save_as is already an absolute path.
            b2.B2RESULT("Retention rate results for list {} saved in {}."
                        .format(particle_list, os.path.abspath(save_as)))
220 
221 
def pathWithRetentionCheck(particle_lists, path):
    """ Return a new path with the module RetentionCheck inserted between each module of a given path.

    This allows for checking how the retention rate is modified by each module of the path.

    Example of use (to be put just before process(path)):

    >>> path = pathWithRetentionCheck(['B+:semileptonic'], path)

    Warning: pathWithRetentionCheck(['B+:semileptonic'], path) does not modify path,
    it only returns a new one.

    After the path processing, the result of the RetentionCheck can be printed with

    >>> RetentionCheck.print_results()

    or plotted with (check the corresponding documentation)

    >>> RetentionCheck.plot_retention(...)

    and the summary dictionary can be accessed through

    >>> RetentionCheck.summary

    Parameters:

        particle_lists (list(str)): list of particle list names which will be tracked by RetentionCheck
        path (basf2.Path): initial path (it is not modified, see warning above and example of use)

    Returns:

        basf2.Path: new path in which every module of the initial path is followed by a RetentionCheck
    """
    new_path = b2.Path()
    for module_number, module in enumerate(path.modules()):
        new_path.add_module(module)

        name = module.name()
        if 'ParticleSelector' in name:
            # The value of the first parameter is appended to the name; it is taken to be the
            # cut string. NOTE(review): this relies on the parameter order of the module.
            params = module.available_params()
            if len(params) > 0:
                # str() guards the concatenation against a non-string parameter value.
                name += '(' + str(params[0].values) + ')'

        new_path.add_module(RetentionCheck(name, module_number, particle_lists))
    return new_path
static Environment & Instance()
Static method to get a reference to the Environment instance.
Definition: Environment.cc:29
a (simplified) python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
def plot_retention(cls, particle_list, plot_title="", save_as=None, module_name_max_length=80)
Definition: retention.py:160