Belle II Software development
ClassificationAnalysis Class Reference

Public Member Functions

def __init__ (self, contact, quantity_name, cut_direction=None, cut=None, lower_bound=None, upper_bound=None, outlier_z_score=None, allow_discrete=None, unit=None)
 
def analyse (self, estimates, truths, auxiliaries={})
 
def contact (self)
 
def contact (self, contact)
 
def write (self, tdirectory=None)
 

Public Attributes

 quantity_name
 cached name of the quantity in the truth-classification analysis
 
 plots
 cached dictionary of plots in the truth-classification analysis
 
 fom
 cached value of the figure of merit in the truth-classification analysis
 
 cut_direction
 cached cut direction of the truth-classification analysis (negative: reject low values, positive: reject high values; None: analyse() tries to determine it from the data)
 
 cut
 cached cut threshold (a number) or cut classifier object of the truth-classification analysis
 
 lower_bound
 cached lower bound for this truth-classification analysis
 
 upper_bound
 cached upper bound for this truth-classification analysis
 
 outlier_z_score
 cached Z-score (for outlier detection) for this truth-classification analysis
 
 allow_discrete
 cached discrete-value flag for this truth-classification analysis
 
 unit
 cached measurement unit for this truth-classification analysis
 
 contact
 contact person (setting it also updates the contact of all cached plots and of the figure of merit)
 

Protected Attributes

 _contact
 cached contact person of the truth-classification analysis
 

Detailed Description

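Perform a truth-classification analysis: compare an estimated quantity with binary truth values and produce validation plots and figures of merit.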

Definition at line 25 of file classification.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  contact,
  quantity_name,
  cut_direction = None,
  cut = None,
  lower_bound = None,
  upper_bound = None,
  outlier_z_score = None,
  allow_discrete = None,
  unit = None 
)
Compare an estimated quantity to the truths by generating standardized validation plots.

Definition at line 28 of file classification.py.

39 ):
40 """Compare an estimated quantity to the truths by generating standardized validation plots."""
41
42 # cached contact person of the truth-classification analysis
43 self._contact = contact
44 # cached name of the quantity in the truth-classification analysis
45 self.quantity_name = quantity_name
46
47 # cached dictionary of plots in the truth-classification analysis (filled by analyse())
48 self.plots = collections.OrderedDict()
49 # cached figure of merit; stays None unless analyse() runs with a cut value
50 self.fom = None
51
52 # cached cut direction: negative rejects low values, positive rejects high values, None lets analyse() try to guess it
53 self.cut_direction = cut_direction
54 # cached cut threshold (a number) or cut classifier object
55 self.cut = cut
56
57 # cached lower bound for this truth-classification analysis
58 self.lower_bound = lower_bound
59 # cached upper bound for this truth-classification analysis
60 self.upper_bound = upper_bound
61 # cached Z-score used for outlier detection
62 self.outlier_z_score = outlier_z_score
63 # cached discrete-value flag
64 self.allow_discrete = allow_discrete
65 # cached measurement unit, used when composing the axis labels
66 self.unit = unit
67

Member Function Documentation

◆ analyse()

def analyse (   self,
  estimates,
  truths,
  auxiliaries = {} 
)
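Compare the concrete estimate to the truth, compute efficiency, purity and background rejection
as figures of merit, and plot the selection as a stacked histogram over the truths.

Parameters
----------
estimates : array_like
    Selection variable to compare to the truths
truths : array_like
    Binary true class values.
auxiliaries : dict, optional
    Additional variables (name -> array_like) against which the classification quality is profiled.
    Entries that hold a single value, or whose name equals the analysed quantity, are skipped.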

Definition at line 68 of file classification.py.

73 ):
74 """Compares the concrete estimate to the truth and efficiency, purity and background rejection
75 as figure of merit and plots the selection as a stacked plot over the truths.
76
77 Parameters
78 ----------
79 estimates : array_like
80 Selection variable to compare to the truths
81 truths : array_like
82 Binary true class values.
83 """
84
85 quantity_name = self.quantity_name
86 axis_label = compose_axis_label(quantity_name, self.unit)
87
88 plot_name = "{quantity_name}_{subplot_name}"
89 plot_name = formatter.format(plot_name, quantity_name=quantity_name)
90
91 signals = truths != 0
92
93 # Some different things become presentable depending on the estimates
94 estimate_is_binary = statistics.is_binary_series(estimates)
95
96 if estimate_is_binary:
97 binary_estimates = estimates != 0
98 cut_value = 0.5
99 cut_direction = -1 # reject low values
100
101 elif self.cut is not None:
102 if isinstance(self.cut, numbers.Number):
103 cut_value = self.cut
104 cut_direction = self.cut_direction
105 cut_classifier = CutClassifier(cut_direction=cut_direction, cut_value=cut_value)
106
107 else:
108 cut_classifier = self.cut
109 cut_classifier = cut_classifier.clone()
110
111 cut_classifier.fit(estimates, truths)
112 binary_estimates = cut_classifier.predict(estimates) != 0
113 cut_direction = cut_classifier.cut_direction
114 cut_value = cut_classifier.cut_value
115
116 if not isinstance(self.cut, numbers.Number):
117 print(formatter.format(plot_name, subplot_name="cut_classifier"), "summary")
118 cut_classifier.describe(estimates, truths)
119 # Without a cut value binary_estimates stays undefined; every later use is guarded by `cut_value is not None`
120 else:
121 cut_value = None
122 cut_direction = self.cut_direction
123
124 lower_bound = self.lower_bound
125 upper_bound = self.upper_bound
126
127 # Stacked histogram
128 signal_bkg_histogram_name = formatter.format(plot_name, subplot_name="signal_bkg_histogram")
129 signal_bkg_histogram = ValidationPlot(signal_bkg_histogram_name)
130 signal_bkg_histogram.hist(
131 estimates,
132 stackby=truths,
133 lower_bound=lower_bound,
134 upper_bound=upper_bound,
135 outlier_z_score=self.outlier_z_score,
136 allow_discrete=self.allow_discrete,
137 )
138 signal_bkg_histogram.xlabel = axis_label
139
140 if lower_bound is None:
141 lower_bound = signal_bkg_histogram.lower_bound
142
143 if upper_bound is None:
144 upper_bound = signal_bkg_histogram.upper_bound
145
146 self.plots['signal_bkg'] = signal_bkg_histogram
147
148 # Purity profile
149 purity_profile_name = formatter.format(plot_name, subplot_name="purity_profile")
150
151 purity_profile = ValidationPlot(purity_profile_name)
152 purity_profile.profile(
153 estimates,
154 truths,
155 lower_bound=lower_bound,
156 upper_bound=upper_bound,
157 outlier_z_score=self.outlier_z_score,
158 allow_discrete=self.allow_discrete,
159 )
160
161 purity_profile.xlabel = axis_label
162 purity_profile.ylabel = 'purity'
163 self.plots["purity"] = purity_profile
164
165 # Try to guess the cut direction from the correlation of purity with the estimate
166 if cut_direction is None:
167 purity_grapherrors = ValidationPlot.convert_tprofile_to_tgrapherrors(purity_profile.plot)
168 correlation = purity_grapherrors.GetCorrelationFactor()
169 if correlation > 0.1:
170 print("Determined cut direction", -1)
171 cut_direction = -1 # reject low values
172 elif correlation < -0.1:
173 print("Determined cut direction", 1)
174 cut_direction = +1 # reject high values
175
176 # Figures of merit
177 if cut_value is not None:
178 fom_name = formatter.format(plot_name, subplot_name="classification_figures_of_merits")
179 fom_description = f"Efficiency, purity and background rejection of the classifiction with {quantity_name}"
180
181 fom_check = "Check that the classifcation quality stays stable."
182
183 fom_title = f"Summary of the classification quality with {quantity_name}"
184
185 classification_fom = ValidationFiguresOfMerit(
186 name=fom_name,
187 title=fom_title,
188 description=fom_description,
189 check=fom_check,
190 contact=self.contact,
191 )
192
193 efficiency = scores.efficiency(truths, binary_estimates)
194 purity = scores.purity(truths, binary_estimates)
195 background_rejection = scores.background_rejection(truths, binary_estimates)
196
197 classification_fom['cut_value'] = cut_value
198 classification_fom['cut_direction'] = cut_direction
199 classification_fom['efficiency'] = efficiency
200 classification_fom['purity'] = purity
201 classification_fom['background_rejection'] = background_rejection
202
203 self.fom = classification_fom
204 # Auxiliary hists
205 for aux_name, aux_values in auxiliaries.items():
206 if statistics.is_single_value_series(aux_values) or aux_name == quantity_name:
207 continue
208
209 aux_axis_label = compose_axis_label(aux_name)
210
211 # Signal + bkg distribution over estimate and auxiliary variable #
212 # ############################################################## #
213 signal_bkg_aux_hist2d_name = formatter.format(plot_name, subplot_name=aux_name + '_signal_bkg_aux2d')
214 signal_bkg_aux_hist2d = ValidationPlot(signal_bkg_aux_hist2d_name)
215 signal_bkg_aux_hist2d.hist2d(
216 aux_values,
217 estimates,
218 stackby=truths,
219 lower_bound=(None, lower_bound),
220 upper_bound=(None, upper_bound),
221 outlier_z_score=self.outlier_z_score,
222 allow_discrete=self.allow_discrete,
223 )
224
225 aux_lower_bound = signal_bkg_aux_hist2d.lower_bound[0]
226 aux_upper_bound = signal_bkg_aux_hist2d.upper_bound[0]
227
228 signal_bkg_aux_hist2d.xlabel = aux_axis_label
229 signal_bkg_aux_hist2d.ylabel = axis_label
230 self.plots[signal_bkg_aux_hist2d_name] = signal_bkg_aux_hist2d
231
232 # Figures of merit as function of the auxiliary variables
233 if cut_value is not None:
234
235 # Auxiliary purity profile #
236 # ######################## #
237 aux_purity_profile_name = formatter.format(plot_name, subplot_name=aux_name + "_aux_purity_profile")
238 aux_purity_profile = ValidationPlot(aux_purity_profile_name)
239 aux_purity_profile.profile(
240 aux_values[binary_estimates],
241 truths[binary_estimates],
242 outlier_z_score=self.outlier_z_score,
243 allow_discrete=self.allow_discrete,
244 lower_bound=aux_lower_bound,
245 upper_bound=aux_upper_bound,
246 )
247
248 aux_purity_profile.xlabel = aux_axis_label
249 aux_purity_profile.ylabel = 'purity'
250 self.plots[aux_purity_profile_name] = aux_purity_profile
251
252 # Auxiliary efficiency profile #
253 # ############################ #
254 aux_efficiency_profile_name = formatter.format(plot_name, subplot_name=aux_name + "_aux_efficiency_profile")
255 aux_efficiency_profile = ValidationPlot(aux_efficiency_profile_name)
256 aux_efficiency_profile.profile(
257 aux_values[signals],
258 binary_estimates[signals],
259 outlier_z_score=self.outlier_z_score,
260 allow_discrete=self.allow_discrete,
261 lower_bound=aux_lower_bound,
262 upper_bound=aux_upper_bound,
263 )
264
265 aux_efficiency_profile.xlabel = aux_axis_label
266 aux_efficiency_profile.ylabel = 'efficiency'
267 self.plots[aux_efficiency_profile_name] = aux_efficiency_profile
268
269 # Auxiliary bkg rejection profile #
270 # ############################### #
271 aux_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=aux_name + "_aux_bkg_rejection_profile")
272 aux_bkg_rejection_profile = ValidationPlot(aux_bkg_rejection_profile_name)
273 aux_bkg_rejection_profile.profile(
274 aux_values[~signals],
275 ~binary_estimates[~signals],
276 outlier_z_score=self.outlier_z_score,
277 allow_discrete=self.allow_discrete,
278 lower_bound=aux_lower_bound,
279 upper_bound=aux_upper_bound,
280 )
281
282 aux_bkg_rejection_profile.xlabel = aux_axis_label
283 aux_bkg_rejection_profile.ylabel = 'bkg rejection'
284 self.plots[aux_bkg_rejection_profile_name] = aux_bkg_rejection_profile
285 # Still no cut direction: retry the correlation guess against the absolute value of the estimates
286 cut_abs = False
287 if cut_direction is None:
288 purity_grapherrors = ValidationPlot.convert_tprofile_to_tgrapherrors(purity_profile.plot,
289 abs_x=True)
290 correlation = purity_grapherrors.GetCorrelationFactor()
291 if correlation > 0.1:
292 print("Determined absolute cut direction", -1)
293 cut_direction = -1 # reject low values
294 cut_abs = True
295 elif correlation < -0.1:
296 print("Determined absolute cut direction", 1)
297 cut_direction = +1 # reject high values
298 cut_abs = True
299 # NOTE(review): when cut_abs is set, lower_bound becomes 0 below but upper_bound keeps its signed-range value; confirm it still covers abs(estimates)
300 if cut_abs:
301 estimates = np.abs(estimates)
302 cut_x_label = "cut " + compose_axis_label("abs(" + quantity_name + ")", self.unit)
303 lower_bound = 0
304 else:
305 cut_x_label = "cut " + axis_label
306
307 # Quantile plots
308 if not estimate_is_binary and cut_direction is not None:
309 # Signal estimate quantiles over auxiliary variable #
310 # ################################################# #
311 if cut_direction > 0:
312 quantiles = [0.5, 0.90, 0.99]
313 else:
314 quantiles = [0.01, 0.10, 0.5]
315
316 for aux_name, aux_values in auxiliaries.items():
317 if statistics.is_single_value_series(aux_values) or aux_name == quantity_name:
318 continue
319
320 aux_axis_label = compose_axis_label(aux_name)
321
322 signal_quantile_aux_profile_name = formatter.format(plot_name, subplot_name=aux_name + '_signal_quantiles_aux2d')
323 signal_quantile_aux_profile = ValidationPlot(signal_quantile_aux_profile_name)
324 signal_quantile_aux_profile.hist2d(
325 aux_values[signals],
326 estimates[signals],
327 quantiles=quantiles,
328 bins=('flat', None),
329 lower_bound=(None, lower_bound),
330 upper_bound=(None, upper_bound),
331 outlier_z_score=self.outlier_z_score,
332 allow_discrete=self.allow_discrete,
333 )
334 signal_quantile_aux_profile.xlabel = aux_axis_label
335 signal_quantile_aux_profile.ylabel = cut_x_label
336 self.plots[signal_quantile_aux_profile_name] = signal_quantile_aux_profile
337
338 # ROC plots
339 if not estimate_is_binary and cut_direction is not None:
340 n_data = len(estimates)
341 n_signals = scores.signal_amount(truths, estimates)
342 n_bkgs = n_data - n_signals
343
344 # work around for numpy sorting nan values as high but we want it as low depending on the cut direction
345 if cut_direction < 0: # reject low
346 sorting_indices = np.argsort(-estimates)
347 else:
348 sorting_indices = np.argsort(estimates)
349
350 sorted_truths = truths[sorting_indices]
351 sorted_estimates = estimates[sorting_indices]
352
353 sorted_n_accepted_signals = np.cumsum(sorted_truths, dtype=float)
354 sorted_efficiencies = sorted_n_accepted_signals / n_signals
355 # NOTE(review): at index i, sorted_n_rejects below equals n_data + 1 - i, whereas the number of entries rejected after accepting the first i + 1 is n_data - 1 - i; confirm the offset is intended
356 sorted_n_rejected_signals = n_signals - sorted_n_accepted_signals
357 sorted_n_rejects = np.arange(len(estimates) + 1, 1, -1)
358 sorted_n_rejected_bkgs = sorted_n_rejects - sorted_n_rejected_signals
359 sorted_bkg_rejections = sorted_n_rejected_bkgs / n_bkgs
360
361 # Efficiency by cut value #
362 # ####################### #
363 efficiency_by_cut_profile_name = formatter.format(plot_name, subplot_name="efficiency_by_cut")
364
365 efficiency_by_cut_profile = ValidationPlot(efficiency_by_cut_profile_name)
366 efficiency_by_cut_profile.profile(
367 sorted_estimates,
368 sorted_efficiencies,
369 lower_bound=lower_bound,
370 upper_bound=upper_bound,
371 outlier_z_score=self.outlier_z_score,
372 allow_discrete=self.allow_discrete,
373 )
374
375 efficiency_by_cut_profile.xlabel = cut_x_label
376 efficiency_by_cut_profile.ylabel = "efficiency"
377
378 self.plots["efficiency_by_cut"] = efficiency_by_cut_profile
379
380 # Background rejection over cut value #
381 # ################################### #
382 bkg_rejection_by_cut_profile_name = formatter.format(plot_name, subplot_name="bkg_rejection_by_cut")
383 bkg_rejection_by_cut_profile = ValidationPlot(bkg_rejection_by_cut_profile_name)
384 bkg_rejection_by_cut_profile.profile(
385 sorted_estimates,
386 sorted_bkg_rejections,
387 lower_bound=lower_bound,
388 upper_bound=upper_bound,
389 outlier_z_score=self.outlier_z_score,
390 allow_discrete=self.allow_discrete,
391 )
392
393 bkg_rejection_by_cut_profile.xlabel = cut_x_label
394 bkg_rejection_by_cut_profile.ylabel = "background rejection"
395
396 self.plots["bkg_rejection_by_cut"] = bkg_rejection_by_cut_profile
397
398 # Purity over efficiency #
399 # ###################### #
400 purity_over_efficiency_profile_name = formatter.format(plot_name, subplot_name="purity_over_efficiency")
401 purity_over_efficiency_profile = ValidationPlot(purity_over_efficiency_profile_name)
402 purity_over_efficiency_profile.profile(
403 sorted_efficiencies,
404 sorted_truths,
405 cumulation_direction=1,
406 lower_bound=0,
407 upper_bound=1
408 )
409 purity_over_efficiency_profile.xlabel = 'efficiency'
410 purity_over_efficiency_profile.ylabel = 'purity'
411
412 self.plots["purity_over_efficiency"] = purity_over_efficiency_profile
413
414 # Cut over efficiency #
415 # ################### #
416 cut_over_efficiency_profile_name = formatter.format(plot_name, subplot_name="cut_over_efficiency")
417 cut_over_efficiency_profile = ValidationPlot(cut_over_efficiency_profile_name)
418 cut_over_efficiency_profile.profile(
419 sorted_efficiencies,
420 sorted_estimates,
421 lower_bound=0,
422 upper_bound=1,
423 outlier_z_score=self.outlier_z_score,
424 allow_discrete=self.allow_discrete,
425 )
426 cut_over_efficiency_profile.set_minimum(lower_bound)
427 cut_over_efficiency_profile.set_maximum(upper_bound)
428 cut_over_efficiency_profile.xlabel = 'efficiency'
429 cut_over_efficiency_profile.ylabel = cut_x_label
430
431 self.plots["cut_over_efficiency"] = cut_over_efficiency_profile
432
433 # Cut over bkg_rejection #
434 # ###################### #
435 cut_over_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name="cut_over_bkg_rejection")
436 cut_over_bkg_rejection_profile = ValidationPlot(cut_over_bkg_rejection_profile_name)
437 cut_over_bkg_rejection_profile.profile(
438 sorted_bkg_rejections,
439 sorted_estimates,
440 lower_bound=0,
441 upper_bound=1,
442 outlier_z_score=self.outlier_z_score,
443 allow_discrete=self.allow_discrete,
444 )
445 cut_over_bkg_rejection_profile.set_minimum(lower_bound)
446 cut_over_bkg_rejection_profile.set_maximum(upper_bound)
447 cut_over_bkg_rejection_profile.xlabel = 'bkg_rejection'
448 cut_over_bkg_rejection_profile.ylabel = cut_x_label
449
450 self.plots["cut_over_bkg_rejection"] = cut_over_bkg_rejection_profile
451
452 # Efficiency over background rejection #
453 # #################################### #
454 efficiency_over_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name="efficiency_over_bkg_rejection")
455 efficiency_over_bkg_rejection_profile = ValidationPlot(efficiency_over_bkg_rejection_profile_name)
456 efficiency_over_bkg_rejection_profile.profile(
457 sorted_bkg_rejections,
458 sorted_efficiencies,
459 lower_bound=0,
460 upper_bound=1
461 )
462
463 efficiency_over_bkg_rejection_profile.xlabel = "bkg rejection"
464 efficiency_over_bkg_rejection_profile.ylabel = "efficiency"
465
466 self.plots["efficiency_over_bkg_rejection"] = efficiency_over_bkg_rejection_profile
467
468 # Re-assign the contact so it reaches the plots and figure of merit created above (presumably via the contact property setter)
469 self.contact = self.contact
470

◆ contact() [1/2]

def contact (   self)
Get the name of the contact person

Definition at line 472 of file classification.py.

472 def contact(self):
473 """Get the name of the contact person"""
474 return self._contact
475 # NOTE(review): decorators are not shown in this listing; presumably this is the getter of a `contact` property

◆ contact() [2/2]

def contact (   self,
  contact 
)
Set the name of the contact person

Definition at line 477 of file classification.py.

477 def contact(self, contact):
478 """Set the name of the contact person"""
479 self._contact = contact
480 # Propagate the new contact to every cached plot
481 for plot in list(self.plots.values()):
482 plot.contact = contact
483 # self.fom stays None until analyse() computes a figure of merit
484 if self.fom:
485 self.fom.contact = contact
486

◆ write()

def write (   self,
  tdirectory = None 
)
Write all cached plots, and the figure of merit if one was computed, to the ROOT TDirectory tdirectory

Definition at line 487 of file classification.py.

487 def write(self, tdirectory=None):
488 """Write the plots to the ROOT TDirectory"""
489 for plot in list(self.plots.values()):
490 plot.write(tdirectory)
491 # The figure of merit is only written if analyse() produced one (it stays None otherwise)
492 if self.fom:
493 self.fom.write(tdirectory)
494
495

Member Data Documentation

◆ _contact

_contact
protected

cached contact person of the truth-classification analysis

Definition at line 43 of file classification.py.

◆ allow_discrete

allow_discrete

cached discrete-value flag for this truth-classification analysis

Definition at line 64 of file classification.py.

◆ contact

contact

contact person (setting it also updates the contact of all cached plots and of the figure of merit)

Definition at line 469 of file classification.py.

◆ cut

cut

cached cut threshold (a number) or cut classifier object of the truth-classification analysis

Definition at line 55 of file classification.py.

◆ cut_direction

cut_direction

cached cut direction of the truth-classification analysis (negative: reject low values, positive: reject high values; None: analyse() tries to determine it from the data)

Definition at line 53 of file classification.py.

◆ fom

fom

cached value of the figure of merit in the truth-classification analysis

Definition at line 50 of file classification.py.

◆ lower_bound

lower_bound

cached lower bound for this truth-classification analysis

Definition at line 58 of file classification.py.

◆ outlier_z_score

outlier_z_score

cached Z-score (for outlier detection) for this truth-classification analysis

Definition at line 62 of file classification.py.

◆ plots

plots

cached dictionary of plots in the truth-classification analysis

Definition at line 48 of file classification.py.

◆ quantity_name

quantity_name

cached name of the quantity in the truth-classification analysis

Definition at line 45 of file classification.py.

◆ unit

unit

cached measurement unit for this truth-classification analysis

Definition at line 66 of file classification.py.

◆ upper_bound

upper_bound

cached upper bound for this truth-classification analysis

Definition at line 60 of file classification.py.


The documentation for this class was generated from the following file: