Belle II Software development
classification.py
#!/usr/bin/env python3

import numpy as np
import collections
import numbers
import copy

from tracking.validation import scores, statistics

from tracking.validation.plot import ValidationPlot, compose_axis_label
from tracking.validation.fom import ValidationFiguresOfMerit
from tracking.validation.tolerate_missing_key_formatter import TolerateMissingKeyFormatter

formatter = TolerateMissingKeyFormatter()


class ClassificationAnalysis:
    """Perform truth-classification analysis"""

    def __init__(
            self,
            contact,
            quantity_name,
            cut_direction=None,
            cut=None,
            lower_bound=None,
            upper_bound=None,
            outlier_z_score=None,
            allow_discrete=None,
            unit=None
    ):
40 """Compare an estimated quantity to the truths by generating standardized validation plots."""
41
42
43 self._contact = contact
44
45 self.quantity_name = quantity_name
46
47
48 self.plots = collections.OrderedDict()
49
50 self.fom = None
51
52
53 self.cut_direction = cut_direction
54
55 self.cut = cut
56
57
58 self.lower_bound = lower_bound
59
60 self.upper_bound = upper_bound
61
62 self.outlier_z_score = outlier_z_score
63
64 self.allow_discrete = allow_discrete
65
66 self.unit = unit

    def analyse(
            self,
            estimates,
            truths,
            auxiliaries={}
    ):
74 """Compares the concrete estimate to the truth and efficiency, purity and background rejection
75 as figure of merit and plots the selection as a stacked plot over the truths.
76
77 Parameters
78 ----------
79 estimates : array_like
80 Selection variable to compare to the truths
81 truths : array_like
82 Binary true class values.
83 """

        quantity_name = self.quantity_name
        axis_label = compose_axis_label(quantity_name, self.unit)

        plot_name = "{quantity_name}_{subplot_name}"
        plot_name = formatter.format(plot_name, quantity_name=quantity_name)

        signals = truths != 0

        # Some different things become presentable depending on the estimates
        estimate_is_binary = statistics.is_binary_series(estimates)

        if estimate_is_binary:
            binary_estimates = estimates != 0
            cut_value = 0.5
            cut_direction = -1  # reject low values

        elif self.cut is not None:
            if isinstance(self.cut, numbers.Number):
                cut_value = self.cut
                cut_direction = self.cut_direction
                cut_classifier = CutClassifier(cut_direction=cut_direction, cut_value=cut_value)

            else:
                cut_classifier = self.cut
                cut_classifier = cut_classifier.clone()

            cut_classifier.fit(estimates, truths)
            binary_estimates = cut_classifier.predict(estimates) != 0
            cut_direction = cut_classifier.cut_direction
            cut_value = cut_classifier.cut_value

            if not isinstance(self.cut, numbers.Number):
                print(formatter.format(plot_name, subplot_name="cut_classifier"), "summary")
                cut_classifier.describe(estimates, truths)

        else:
            cut_value = None
            cut_direction = self.cut_direction

        lower_bound = self.lower_bound
        upper_bound = self.upper_bound

        # Stacked histogram
        signal_bkg_histogram_name = formatter.format(plot_name, subplot_name="signal_bkg_histogram")
        signal_bkg_histogram = ValidationPlot(signal_bkg_histogram_name)
        signal_bkg_histogram.hist(
            estimates,
            stackby=truths,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            outlier_z_score=self.outlier_z_score,
            allow_discrete=self.allow_discrete,
        )
        signal_bkg_histogram.xlabel = axis_label

        if lower_bound is None:
            lower_bound = signal_bkg_histogram.lower_bound

        if upper_bound is None:
            upper_bound = signal_bkg_histogram.upper_bound

        self.plots['signal_bkg'] = signal_bkg_histogram

        # Purity profile
        purity_profile_name = formatter.format(plot_name, subplot_name="purity_profile")

        purity_profile = ValidationPlot(purity_profile_name)
        purity_profile.profile(
            estimates,
            truths,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            outlier_z_score=self.outlier_z_score,
            allow_discrete=self.allow_discrete,
        )

        purity_profile.xlabel = axis_label
        purity_profile.ylabel = 'purity'
        self.plots["purity"] = purity_profile

        # Try to guess the cut direction from the correlation
        if cut_direction is None:
            purity_grapherrors = ValidationPlot.convert_tprofile_to_tgrapherrors(purity_profile.plot)
            correlation = purity_grapherrors.GetCorrelationFactor()
            if correlation > 0.1:
                print("Determined cut direction", -1)
                cut_direction = -1  # reject low values
            elif correlation < -0.1:
                print("Determined cut direction", 1)
                cut_direction = +1  # reject high values

        # Figures of merit
        if cut_value is not None:
            fom_name = formatter.format(plot_name, subplot_name="classification_figures_of_merits")
            fom_description = f"Efficiency, purity and background rejection of the classification with {quantity_name}"

            fom_check = "Check that the classification quality stays stable."

            fom_title = f"Summary of the classification quality with {quantity_name}"

            classification_fom = ValidationFiguresOfMerit(
                name=fom_name,
                title=fom_title,
                description=fom_description,
                check=fom_check,
                contact=self.contact,
            )

            efficiency = scores.efficiency(truths, binary_estimates)
            purity = scores.purity(truths, binary_estimates)
            background_rejection = scores.background_rejection(truths, binary_estimates)

            classification_fom['cut_value'] = cut_value
            classification_fom['cut_direction'] = cut_direction
            classification_fom['efficiency'] = efficiency
            classification_fom['purity'] = purity
            classification_fom['background_rejection'] = background_rejection

            self.fom = classification_fom

        # Auxiliary hists
        for aux_name, aux_values in auxiliaries.items():
            if statistics.is_single_value_series(aux_values) or aux_name == quantity_name:
                continue

            aux_axis_label = compose_axis_label(aux_name)

            # Signal + bkg distribution over estimate and auxiliary variable #
            # ############################################################## #
            signal_bkg_aux_hist2d_name = formatter.format(plot_name, subplot_name=aux_name + '_signal_bkg_aux2d')
            signal_bkg_aux_hist2d = ValidationPlot(signal_bkg_aux_hist2d_name)
            signal_bkg_aux_hist2d.hist2d(
                aux_values,
                estimates,
                stackby=truths,
                lower_bound=(None, lower_bound),
                upper_bound=(None, upper_bound),
                outlier_z_score=self.outlier_z_score,
                allow_discrete=self.allow_discrete,
            )

            aux_lower_bound = signal_bkg_aux_hist2d.lower_bound[0]
            aux_upper_bound = signal_bkg_aux_hist2d.upper_bound[0]

            signal_bkg_aux_hist2d.xlabel = aux_axis_label
            signal_bkg_aux_hist2d.ylabel = axis_label
            self.plots[signal_bkg_aux_hist2d_name] = signal_bkg_aux_hist2d

            # Figures of merit as function of the auxiliary variables
            if cut_value is not None:

                # Auxiliary purity profile #
                # ######################## #
                aux_purity_profile_name = formatter.format(plot_name, subplot_name=aux_name + "_aux_purity_profile")
                aux_purity_profile = ValidationPlot(aux_purity_profile_name)
                aux_purity_profile.profile(
                    aux_values[binary_estimates],
                    truths[binary_estimates],
                    outlier_z_score=self.outlier_z_score,
                    allow_discrete=self.allow_discrete,
                    lower_bound=aux_lower_bound,
                    upper_bound=aux_upper_bound,
                )

                aux_purity_profile.xlabel = aux_axis_label
                aux_purity_profile.ylabel = 'purity'
                self.plots[aux_purity_profile_name] = aux_purity_profile

                # Auxiliary efficiency profile #
                # ############################ #
                aux_efficiency_profile_name = formatter.format(plot_name, subplot_name=aux_name + "_aux_efficiency_profile")
                aux_efficiency_profile = ValidationPlot(aux_efficiency_profile_name)
                aux_efficiency_profile.profile(
                    aux_values[signals],
                    binary_estimates[signals],
                    outlier_z_score=self.outlier_z_score,
                    allow_discrete=self.allow_discrete,
                    lower_bound=aux_lower_bound,
                    upper_bound=aux_upper_bound,
                )

                aux_efficiency_profile.xlabel = aux_axis_label
                aux_efficiency_profile.ylabel = 'efficiency'
                self.plots[aux_efficiency_profile_name] = aux_efficiency_profile

                # Auxiliary bkg rejection profile #
                # ############################### #
                aux_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=aux_name + "_aux_bkg_rejection_profile")
                aux_bkg_rejection_profile = ValidationPlot(aux_bkg_rejection_profile_name)
                aux_bkg_rejection_profile.profile(
                    aux_values[~signals],
                    ~binary_estimates[~signals],
                    outlier_z_score=self.outlier_z_score,
                    allow_discrete=self.allow_discrete,
                    lower_bound=aux_lower_bound,
                    upper_bound=aux_upper_bound,
                )

                aux_bkg_rejection_profile.xlabel = aux_axis_label
                aux_bkg_rejection_profile.ylabel = 'bkg rejection'
                self.plots[aux_bkg_rejection_profile_name] = aux_bkg_rejection_profile

        cut_abs = False
        if cut_direction is None:
            purity_grapherrors = ValidationPlot.convert_tprofile_to_tgrapherrors(purity_profile.plot,
                                                                                 abs_x=True)
            correlation = purity_grapherrors.GetCorrelationFactor()
            if correlation > 0.1:
                print("Determined absolute cut direction", -1)
                cut_direction = -1  # reject low values
                cut_abs = True
            elif correlation < -0.1:
                print("Determined absolute cut direction", 1)
                cut_direction = +1  # reject high values
                cut_abs = True

        if cut_abs:
            estimates = np.abs(estimates)
            cut_x_label = "cut " + compose_axis_label("abs(" + quantity_name + ")", self.unit)
            lower_bound = 0
        else:
            cut_x_label = "cut " + axis_label

        # Quantile plots
        if not estimate_is_binary and cut_direction is not None:
            # Signal estimate quantiles over auxiliary variable #
            # ################################################# #
            if cut_direction > 0:
                quantiles = [0.5, 0.90, 0.99]
            else:
                quantiles = [0.01, 0.10, 0.5]

            for aux_name, aux_values in auxiliaries.items():
                if statistics.is_single_value_series(aux_values) or aux_name == quantity_name:
                    continue

                aux_axis_label = compose_axis_label(aux_name)

                signal_quantile_aux_profile_name = formatter.format(plot_name, subplot_name=aux_name + '_signal_quantiles_aux2d')
                signal_quantile_aux_profile = ValidationPlot(signal_quantile_aux_profile_name)
                signal_quantile_aux_profile.hist2d(
                    aux_values[signals],
                    estimates[signals],
                    quantiles=quantiles,
                    bins=('flat', None),
                    lower_bound=(None, lower_bound),
                    upper_bound=(None, upper_bound),
                    outlier_z_score=self.outlier_z_score,
                    allow_discrete=self.allow_discrete,
                )
                signal_quantile_aux_profile.xlabel = aux_axis_label
                signal_quantile_aux_profile.ylabel = cut_x_label
                self.plots[signal_quantile_aux_profile_name] = signal_quantile_aux_profile

        # ROC plots
        if not estimate_is_binary and cut_direction is not None:
            n_data = len(estimates)
            n_signals = scores.signal_amount(truths, estimates)
            n_bkgs = n_data - n_signals

            # Work around numpy sorting NaN values as high; depending on the cut direction we want them treated as low
            if cut_direction < 0:  # reject low
                sorting_indices = np.argsort(-estimates)
            else:
                sorting_indices = np.argsort(estimates)

            sorted_truths = truths[sorting_indices]
            sorted_estimates = estimates[sorting_indices]

            sorted_n_accepted_signals = np.cumsum(sorted_truths, dtype=float)
            sorted_efficiencies = sorted_n_accepted_signals / n_signals

            sorted_n_rejected_signals = n_signals - sorted_n_accepted_signals
            sorted_n_rejects = np.arange(len(estimates) + 1, 1, -1)
            sorted_n_rejected_bkgs = sorted_n_rejects - sorted_n_rejected_signals
            sorted_bkg_rejections = sorted_n_rejected_bkgs / n_bkgs

            # Efficiency by cut value #
            # ####################### #
            efficiency_by_cut_profile_name = formatter.format(plot_name, subplot_name="efficiency_by_cut")

            efficiency_by_cut_profile = ValidationPlot(efficiency_by_cut_profile_name)
            efficiency_by_cut_profile.profile(
                sorted_estimates,
                sorted_efficiencies,
                lower_bound=lower_bound,
                upper_bound=upper_bound,
                outlier_z_score=self.outlier_z_score,
                allow_discrete=self.allow_discrete,
            )

            efficiency_by_cut_profile.xlabel = cut_x_label
            efficiency_by_cut_profile.ylabel = "efficiency"

            self.plots["efficiency_by_cut"] = efficiency_by_cut_profile

            # Background rejection over cut value #
            # ################################### #
            bkg_rejection_by_cut_profile_name = formatter.format(plot_name, subplot_name="bkg_rejection_by_cut")
            bkg_rejection_by_cut_profile = ValidationPlot(bkg_rejection_by_cut_profile_name)
            bkg_rejection_by_cut_profile.profile(
                sorted_estimates,
                sorted_bkg_rejections,
                lower_bound=lower_bound,
                upper_bound=upper_bound,
                outlier_z_score=self.outlier_z_score,
                allow_discrete=self.allow_discrete,
            )

            bkg_rejection_by_cut_profile.xlabel = cut_x_label
            bkg_rejection_by_cut_profile.ylabel = "background rejection"

            self.plots["bkg_rejection_by_cut"] = bkg_rejection_by_cut_profile

            # Purity over efficiency #
            # ###################### #
            purity_over_efficiency_profile_name = formatter.format(plot_name, subplot_name="purity_over_efficiency")
            purity_over_efficiency_profile = ValidationPlot(purity_over_efficiency_profile_name)
            purity_over_efficiency_profile.profile(
                sorted_efficiencies,
                sorted_truths,
                cumulation_direction=1,
                lower_bound=0,
                upper_bound=1
            )
            purity_over_efficiency_profile.xlabel = 'efficiency'
            purity_over_efficiency_profile.ylabel = 'purity'

            self.plots["purity_over_efficiency"] = purity_over_efficiency_profile

            # Cut over efficiency #
            # ################### #
            cut_over_efficiency_profile_name = formatter.format(plot_name, subplot_name="cut_over_efficiency")
            cut_over_efficiency_profile = ValidationPlot(cut_over_efficiency_profile_name)
            cut_over_efficiency_profile.profile(
                sorted_efficiencies,
                sorted_estimates,
                lower_bound=0,
                upper_bound=1,
                outlier_z_score=self.outlier_z_score,
                allow_discrete=self.allow_discrete,
            )
            cut_over_efficiency_profile.set_minimum(lower_bound)
            cut_over_efficiency_profile.set_maximum(upper_bound)
            cut_over_efficiency_profile.xlabel = 'efficiency'
            cut_over_efficiency_profile.ylabel = cut_x_label

            self.plots["cut_over_efficiency"] = cut_over_efficiency_profile

            # Cut over bkg_rejection #
            # ###################### #
            cut_over_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name="cut_over_bkg_rejection")
            cut_over_bkg_rejection_profile = ValidationPlot(cut_over_bkg_rejection_profile_name)
            cut_over_bkg_rejection_profile.profile(
                sorted_bkg_rejections,
                sorted_estimates,
                lower_bound=0,
                upper_bound=1,
                outlier_z_score=self.outlier_z_score,
                allow_discrete=self.allow_discrete,
            )
            cut_over_bkg_rejection_profile.set_minimum(lower_bound)
            cut_over_bkg_rejection_profile.set_maximum(upper_bound)
            cut_over_bkg_rejection_profile.xlabel = 'bkg_rejection'
            cut_over_bkg_rejection_profile.ylabel = cut_x_label

            self.plots["cut_over_bkg_rejection"] = cut_over_bkg_rejection_profile

            # Efficiency over background rejection #
            # #################################### #
            efficiency_over_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name="efficiency_over_bkg_rejection")
            efficiency_over_bkg_rejection_profile = ValidationPlot(efficiency_over_bkg_rejection_profile_name)
            efficiency_over_bkg_rejection_profile.profile(
                sorted_bkg_rejections,
                sorted_efficiencies,
                lower_bound=0,
                upper_bound=1
            )

            efficiency_over_bkg_rejection_profile.xlabel = "bkg rejection"
            efficiency_over_bkg_rejection_profile.ylabel = "efficiency"

            self.plots["efficiency_over_bkg_rejection"] = efficiency_over_bkg_rejection_profile

    @property
    def contact(self):
        """Get the name of the contact person"""
        return self._contact

    @contact.setter
    def contact(self, contact):
        """Set the name of the contact person"""
        self._contact = contact

        for plot in list(self.plots.values()):
            plot.contact = contact

        if self.fom:
            self.fom.contact = contact

    def write(self, tdirectory=None):
        """Write the plots to the ROOT TDirectory"""
        for plot in list(self.plots.values()):
            plot.write(tdirectory)

        if self.fom:
            self.fom.write(tdirectory)


class CutClassifier:
    """Simple classifier cutting on a single variable"""

    def __init__(self, cut_direction=1, cut_value=np.nan):
        """Constructor"""
        self.cut_direction_ = cut_direction
        self.cut_value_ = cut_value

    @property
    def cut_direction(self):
        """Get the value of the cut direction"""
        return self.cut_direction_

    @property
    def cut_value(self):
        """Get the value of the cut threshold"""
        return self.cut_value_

    def clone(self):
        """Return a clone of this object"""
        return copy.copy(self)

    def determine_cut_value(self, estimates, truths):
        """Get the value of the cut threshold"""
        return self.cut_value_  # do not change cut value from constructed one

    def fit(self, estimates, truths):
        """Fit to determine the cut threshold"""
        self.cut_value_ = self.determine_cut_value(estimates, truths)
        return self

    def predict(self, estimates):
        """Select estimates that satisfy the cut"""
        if self.cut_value_ is None:
            raise ValueError("Cut value not set. Forgot to fit?")

        if self.cut_direction_ < 0:
            binary_estimates = estimates >= self.cut_value_
        else:
            binary_estimates = estimates <= self.cut_value_

        return binary_estimates

    def describe(self, estimates, truths):
        """Describe the cut selection and its efficiency, purity and background rejection"""
        if self.cut_direction_ < 0:
            print("Cut accepts >= ", self.cut_value_, 'with')
        else:
            print("Cut accepts <= ", self.cut_value_, 'with')

        binary_estimates = self.predict(estimates)

        efficiency = scores.efficiency(truths, binary_estimates)
        purity = scores.purity(truths, binary_estimates)
        background_rejection = scores.background_rejection(truths, binary_estimates)

        print("efficiency", efficiency)
        print("purity", purity)
        print("background_rejection", background_rejection)


def cut_at_background_rejection(background_rejection=0.5, cut_direction=1):
    """Construct a classifier that determines its cut from the requested background-rejection level"""
    return CutAtBackgroundRejectionClassifier(background_rejection, cut_direction)
562
563
565 """Apply cut on the background rejection"""

    def __init__(self, background_rejection=0.5, cut_direction=1):
        """Constructor"""
        super().__init__(cut_direction=cut_direction, cut_value=np.nan)
        self.background_rejection = background_rejection

    def determine_cut_value(self, estimates, truths):
        """Find the cut value that satisfies the desired background-rejection level"""
        n_data = len(estimates)
        n_signals = scores.signal_amount(truths, estimates)
        n_bkgs = n_data - n_signals

        sorting_indices = np.argsort(estimates)
        if self.cut_direction_ < 0:  # reject low
            # Keep a reference to keep the content alive
            original_sorting_indices = sorting_indices  # noqa
            sorting_indices = sorting_indices[::-1]

        sorted_truths = truths[sorting_indices]
        sorted_estimates = estimates[sorting_indices]

        sorted_n_accepted_signals = np.cumsum(sorted_truths, dtype=float)
        # sorted_efficiencies = sorted_n_accepted_signals / n_signals

        sorted_n_rejected_signals = n_signals - sorted_n_accepted_signals
        sorted_n_rejects = np.arange(len(estimates) + 1, 1, -1)
        sorted_n_rejected_bkgs = sorted_n_rejects - sorted_n_rejected_signals
        sorted_bkg_rejections = sorted_n_rejected_bkgs / n_bkgs

        cut_index, = np.searchsorted(sorted_bkg_rejections[::-1], (self.background_rejection,), side='right')

        cut_value = sorted_estimates[-cut_index - 1]
        return cut_value
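
For illustration only (not part of the module), a minimal sketch of the cut-classifier workflow on toy NumPy arrays. The toy numbers and the requested background-rejection level are invented, and the import path is inferred from the module's own imports; running it assumes the tracking.validation package (and thus a basf2 environment) is available so that scores.signal_amount and the other score functions resolve.

import numpy as np
from tracking.validation.classification import CutClassifier, cut_at_background_rejection

# Invented toy data: larger estimates indicate signal, so cut_direction=-1 (reject low values).
truths = np.array([1, 1, 1, 0, 0, 0, 1, 0], dtype=bool)
estimates = np.array([0.9, 0.8, 0.7, 0.6, 0.3, 0.2, 0.55, 0.4])

# Either fix the threshold by hand ...
classifier = CutClassifier(cut_direction=-1, cut_value=0.5)
accepted = classifier.predict(estimates)  # boolean mask, accepts estimates >= 0.5

# ... or derive it from a requested background-rejection level.
auto_classifier = cut_at_background_rejection(background_rejection=0.75, cut_direction=-1)
auto_classifier.fit(estimates, truths)
auto_classifier.describe(estimates, truths)  # prints the chosen cut and efficiency/purity/background rejection
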

cut: cached value of the threshold in the truth-classification analysis
quantity_name: cached name of the quantity in the truth-classification analysis
upper_bound: cached upper bound for this truth-classification analysis
outlier_z_score: cached Z-score (for outlier detection) for this truth-classification analysis
plots: cached dictionary of plots in the truth-classification analysis
allow_discrete: cached discrete-value flag for this truth-classification analysis
unit: cached measurement unit for this truth-classification analysis
_contact: cached contact person of the truth-classification analysis
cut_direction: cached value of the cut direction (< or >) in the truth-classification analysis
lower_bound: cached lower bound for this truth-classification analysis
fom: cached value of the figure of merit in the truth-classification analysis
def __init__(self, contact, quantity_name, cut_direction=None, cut=None, lower_bound=None, upper_bound=None, outlier_z_score=None, allow_discrete=None, unit=None)

background_rejection: cached copy of the background-rejection threshold
def __init__(self, background_rejection=0.5, cut_direction=1)
def determine_cut_value(self, estimates, truths)

cut_direction_: cached copy of the cut direction (< or >)
cut_value_: cached copy of the cut threshold
def __init__(self, cut_direction=1, cut_value=np.nan)
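
Finally, a sketch of how the analysis class itself might be driven end to end, using the constructor, analyse and write calls listed above. The contact string and the toy arrays are hypothetical, and writing the plots assumes a PyROOT installation, since ValidationPlot.write takes a ROOT TDirectory (here an open TFile).

import numpy as np
import ROOT
from tracking.validation.classification import ClassificationAnalysis

rng = np.random.default_rng(42)
truths = rng.integers(0, 2, size=1000).astype(float)
# Invented estimates that correlate with the truth, mimicking a track-quality indicator.
estimates = truths + rng.normal(0.0, 0.5, size=1000)
auxiliaries = {"pt": rng.uniform(0.1, 3.0, size=1000)}  # hypothetical auxiliary variable

analysis = ClassificationAnalysis(
    contact="tracking validation contact",  # hypothetical contact
    quantity_name="track_quality",
    cut_direction=-1,  # reject low values
    cut=0.5,
)
analysis.analyse(estimates, truths, auxiliaries=auxiliaries)

tfile = ROOT.TFile("classification_validation.root", "RECREATE")
analysis.write(tfile)
tfile.Close()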