22formatter = TolerateMissingKeyFormatter()
26 """Perform truth-classification analysis"""
40 """Compare an estimated quantity to the truths by generating standardized validation plots."""
48 self.
plots = collections.OrderedDict()
74 """Compares the concrete estimate to the truth and efficiency, purity and background rejection
75 as figure of merit
and plots the selection
as a stacked plot over the truths.
79 estimates : array_like
80 Selection variable to compare to the truths
82 Binary true
class values.
86 axis_label = compose_axis_label(quantity_name, self.unit)
88 plot_name = "{quantity_name}_{subplot_name}"
89 plot_name = formatter.format(plot_name, quantity_name=quantity_name)
94 estimate_is_binary = statistics.is_binary_series(estimates)
96 if estimate_is_binary:
97 binary_estimates = estimates != 0
101 elif self.
cut is not None:
102 if isinstance(self.
cut, numbers.Number):
105 cut_classifier =
CutClassifier(cut_direction=cut_direction, cut_value=cut_value)
108 cut_classifier = self.
cut
109 cut_classifier = cut_classifier.clone()
111 cut_classifier.fit(estimates, truths)
112 binary_estimates = cut_classifier.predict(estimates) != 0
113 cut_direction = cut_classifier.cut_direction
114 cut_value = cut_classifier.cut_value
116 if not isinstance(self.
cut, numbers.Number):
117 print(formatter.format(plot_name, subplot_name=
"cut_classifier"),
"summary")
118 cut_classifier.describe(estimates, truths)
128 signal_bkg_histogram_name = formatter.format(plot_name, subplot_name=
"signal_bkg_histogram")
130 signal_bkg_histogram.hist(
133 lower_bound=lower_bound,
134 upper_bound=upper_bound,
138 signal_bkg_histogram.xlabel = axis_label
140 if lower_bound
is None:
141 lower_bound = signal_bkg_histogram.lower_bound
143 if upper_bound
is None:
144 upper_bound = signal_bkg_histogram.upper_bound
146 self.
plots[
'signal_bkg'] = signal_bkg_histogram
149 purity_profile_name = formatter.format(plot_name, subplot_name=
"purity_profile")
152 purity_profile.profile(
155 lower_bound=lower_bound,
156 upper_bound=upper_bound,
161 purity_profile.xlabel = axis_label
162 purity_profile.ylabel =
'purity'
163 self.
plots[
"purity"] = purity_profile
166 if cut_direction
is None:
167 purity_grapherrors = ValidationPlot.convert_tprofile_to_tgrapherrors(purity_profile.plot)
168 correlation = purity_grapherrors.GetCorrelationFactor()
169 if correlation > 0.1:
170 print(
"Determined cut direction", -1)
172 elif correlation < -0.1:
173 print(
"Determined cut direction", 1)
177 if cut_value
is not None:
178 fom_name = formatter.format(plot_name, subplot_name=
"classification_figures_of_merits")
179 fom_description = f
"Efficiency, purity and background rejection of the classification with {quantity_name}"
181 fom_check =
"Check that the classification quality stays stable."
183 fom_title = f
"Summary of the classification quality with {quantity_name}"
188 description=fom_description,
193 efficiency = scores.efficiency(truths, binary_estimates)
194 purity = scores.purity(truths, binary_estimates)
195 background_rejection = scores.background_rejection(truths, binary_estimates)
197 classification_fom[
'cut_value'] = cut_value
198 classification_fom[
'cut_direction'] = cut_direction
199 classification_fom[
'efficiency'] = efficiency
200 classification_fom[
'purity'] = purity
201 classification_fom[
'background_rejection'] = background_rejection
203 self.
fom = classification_fom
205 for aux_name, aux_values
in auxiliaries.items():
206 if statistics.is_single_value_series(aux_values)
or aux_name == quantity_name:
209 aux_axis_label = compose_axis_label(aux_name)
213 signal_bkg_aux_hist2d_name = formatter.format(plot_name, subplot_name=aux_name +
'_signal_bkg_aux2d')
214 signal_bkg_aux_hist2d =
ValidationPlot(signal_bkg_aux_hist2d_name)
215 signal_bkg_aux_hist2d.hist2d(
219 lower_bound=(
None, lower_bound),
220 upper_bound=(
None, upper_bound),
225 aux_lower_bound = signal_bkg_aux_hist2d.lower_bound[0]
226 aux_upper_bound = signal_bkg_aux_hist2d.upper_bound[0]
228 signal_bkg_aux_hist2d.xlabel = aux_axis_label
229 signal_bkg_aux_hist2d.ylabel = axis_label
230 self.
plots[signal_bkg_aux_hist2d_name] = signal_bkg_aux_hist2d
233 if cut_value
is not None:
237 aux_purity_profile_name = formatter.format(plot_name, subplot_name=aux_name +
"_aux_purity_profile")
239 aux_purity_profile.profile(
240 aux_values[binary_estimates],
241 truths[binary_estimates],
244 lower_bound=aux_lower_bound,
245 upper_bound=aux_upper_bound,
248 aux_purity_profile.xlabel = aux_axis_label
249 aux_purity_profile.ylabel =
'purity'
250 self.
plots[aux_purity_profile_name] = aux_purity_profile
254 aux_efficiency_profile_name = formatter.format(plot_name, subplot_name=aux_name +
"_aux_efficiency_profile")
255 aux_efficiency_profile =
ValidationPlot(aux_efficiency_profile_name)
256 aux_efficiency_profile.profile(
258 binary_estimates[signals],
261 lower_bound=aux_lower_bound,
262 upper_bound=aux_upper_bound,
265 aux_efficiency_profile.xlabel = aux_axis_label
266 aux_efficiency_profile.ylabel =
'efficiency'
267 self.
plots[aux_efficiency_profile_name] = aux_efficiency_profile
271 aux_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=aux_name +
"_aux_bkg_rejection_profile")
272 aux_bkg_rejection_profile =
ValidationPlot(aux_bkg_rejection_profile_name)
273 aux_bkg_rejection_profile.profile(
274 aux_values[~signals],
275 ~binary_estimates[~signals],
278 lower_bound=aux_lower_bound,
279 upper_bound=aux_upper_bound,
282 aux_bkg_rejection_profile.xlabel = aux_axis_label
283 aux_bkg_rejection_profile.ylabel =
'bkg rejection'
284 self.
plots[aux_bkg_rejection_profile_name] = aux_bkg_rejection_profile
287 if cut_direction
is None:
288 purity_grapherrors = ValidationPlot.convert_tprofile_to_tgrapherrors(purity_profile.plot,
290 correlation = purity_grapherrors.GetCorrelationFactor()
291 if correlation > 0.1:
292 print(
"Determined absolute cut direction", -1)
295 elif correlation < -0.1:
296 print(
"Determined absolute cut direction", 1)
301 estimates = np.abs(estimates)
302 cut_x_label =
"cut " + compose_axis_label(
"abs(" + quantity_name +
")", self.
unit)
305 cut_x_label =
"cut " + axis_label
308 if not estimate_is_binary
and cut_direction
is not None:
311 if cut_direction > 0:
312 quantiles = [0.5, 0.90, 0.99]
314 quantiles = [0.01, 0.10, 0.5]
316 for aux_name, aux_values
in auxiliaries.items():
317 if statistics.is_single_value_series(aux_values)
or aux_name == quantity_name:
320 aux_axis_label = compose_axis_label(aux_name)
322 signal_quantile_aux_profile_name = formatter.format(plot_name, subplot_name=aux_name +
'_signal_quantiles_aux2d')
323 signal_quantile_aux_profile =
ValidationPlot(signal_quantile_aux_profile_name)
324 signal_quantile_aux_profile.hist2d(
329 lower_bound=(
None, lower_bound),
330 upper_bound=(
None, upper_bound),
334 signal_quantile_aux_profile.xlabel = aux_axis_label
335 signal_quantile_aux_profile.ylabel = cut_x_label
336 self.
plots[signal_quantile_aux_profile_name] = signal_quantile_aux_profile
339 if not estimate_is_binary
and cut_direction
is not None:
340 n_data = len(estimates)
341 n_signals = scores.signal_amount(truths, estimates)
342 n_bkgs = n_data - n_signals
345 if cut_direction < 0:
346 sorting_indices = np.argsort(-estimates)
348 sorting_indices = np.argsort(estimates)
350 sorted_truths = truths[sorting_indices]
351 sorted_estimates = estimates[sorting_indices]
353 sorted_n_accepted_signals = np.cumsum(sorted_truths, dtype=float)
354 sorted_efficiencies = sorted_n_accepted_signals / n_signals
356 sorted_n_rejected_signals = n_signals - sorted_n_accepted_signals
357 sorted_n_rejects = np.arange(len(estimates) + 1, 1, -1)
358 sorted_n_rejected_bkgs = sorted_n_rejects - sorted_n_rejected_signals
359 sorted_bkg_rejections = sorted_n_rejected_bkgs / n_bkgs
363 efficiency_by_cut_profile_name = formatter.format(plot_name, subplot_name=
"efficiency_by_cut")
365 efficiency_by_cut_profile =
ValidationPlot(efficiency_by_cut_profile_name)
366 efficiency_by_cut_profile.profile(
369 lower_bound=lower_bound,
370 upper_bound=upper_bound,
375 efficiency_by_cut_profile.xlabel = cut_x_label
376 efficiency_by_cut_profile.ylabel =
"efficiency"
378 self.
plots[
"efficiency_by_cut"] = efficiency_by_cut_profile
382 bkg_rejection_by_cut_profile_name = formatter.format(plot_name, subplot_name=
"bkg_rejection_by_cut")
383 bkg_rejection_by_cut_profile =
ValidationPlot(bkg_rejection_by_cut_profile_name)
384 bkg_rejection_by_cut_profile.profile(
386 sorted_bkg_rejections,
387 lower_bound=lower_bound,
388 upper_bound=upper_bound,
393 bkg_rejection_by_cut_profile.xlabel = cut_x_label
394 bkg_rejection_by_cut_profile.ylabel =
"background rejection"
396 self.
plots[
"bkg_rejection_by_cut"] = bkg_rejection_by_cut_profile
400 purity_over_efficiency_profile_name = formatter.format(plot_name, subplot_name=
"purity_over_efficiency")
401 purity_over_efficiency_profile =
ValidationPlot(purity_over_efficiency_profile_name)
402 purity_over_efficiency_profile.profile(
405 cumulation_direction=1,
409 purity_over_efficiency_profile.xlabel =
'efficiency'
410 purity_over_efficiency_profile.ylabel =
'purity'
412 self.
plots[
"purity_over_efficiency"] = purity_over_efficiency_profile
416 cut_over_efficiency_profile_name = formatter.format(plot_name, subplot_name=
"cut_over_efficiency")
417 cut_over_efficiency_profile =
ValidationPlot(cut_over_efficiency_profile_name)
418 cut_over_efficiency_profile.profile(
426 cut_over_efficiency_profile.set_minimum(lower_bound)
427 cut_over_efficiency_profile.set_maximum(upper_bound)
428 cut_over_efficiency_profile.xlabel =
'efficiency'
429 cut_over_efficiency_profile.ylabel = cut_x_label
431 self.
plots[
"cut_over_efficiency"] = cut_over_efficiency_profile
435 cut_over_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=
"cut_over_bkg_rejection")
436 cut_over_bkg_rejection_profile =
ValidationPlot(cut_over_bkg_rejection_profile_name)
437 cut_over_bkg_rejection_profile.profile(
438 sorted_bkg_rejections,
445 cut_over_bkg_rejection_profile.set_minimum(lower_bound)
446 cut_over_bkg_rejection_profile.set_maximum(upper_bound)
447 cut_over_bkg_rejection_profile.xlabel =
'bkg_rejection'
448 cut_over_bkg_rejection_profile.ylabel = cut_x_label
450 self.
plots[
"cut_over_bkg_rejection"] = cut_over_bkg_rejection_profile
454 efficiency_over_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=
"efficiency_over_bkg_rejection")
455 efficiency_over_bkg_rejection_profile =
ValidationPlot(efficiency_over_bkg_rejection_profile_name)
456 efficiency_over_bkg_rejection_profile.profile(
457 sorted_bkg_rejections,
463 efficiency_over_bkg_rejection_profile.xlabel =
"bkg rejection"
464 efficiency_over_bkg_rejection_profile.ylabel =
"efficiency"
466 self.
plots[
"efficiency_over_bkg_rejection"] = efficiency_over_bkg_rejection_profile
473 """Get the name of the contact person"""
478 """Set the name of the contact person"""
481 for plot
in list(self.
plots.values()):
482 plot.contact = contact
485 self.
fom.contact = contact
488 """Write the plots to the ROOT TDirectory"""
489 for plot
in list(self.
plots.values()):
490 plot.write(tdirectory)
498 """Simple classifier cutting on a single variable"""
500 def __init__(self, cut_direction=1, cut_value=np.nan):
509 """Get the value of the cut direction"""
514 """Get the value of the cut threshold"""
518 """Return a clone of this object"""
519 return copy.copy(self)
522 """Get the value of the cut threshold"""
525 def fit(self, estimates, truths):
526 """Fit to determine the cut threshold"""
531 """Select estimates that satisfy the cut"""
533 raise ValueError(
"Cut value not set. Forgot to fit?")
536 binary_estimates = estimates >= self.
cut_value_
538 binary_estimates = estimates <= self.
cut_value_
540 return binary_estimates
543 """Describe the cut selection and its efficiency, purity and background rejection"""
545 print(
"Cut accepts >= ", self.
cut_value_,
'with')
547 print(
"Cut accepts <= ", self.
cut_value_,
'with')
549 binary_estimates = self.
predict(estimates)
551 efficiency = scores.efficiency(truths, binary_estimates)
552 purity = scores.purity(truths, binary_estimates)
553 background_rejection = scores.background_rejection(truths, binary_estimates)
555 print(
"efficiency", efficiency)
556 print(
"purity", purity)
557 print(
"background_rejection", background_rejection)
560def cut_at_background_rejection(background_rejection=0.5, cut_direction=1):
565 """Apply cut on the background rejection"""
567 def __init__(self, background_rejection=0.5, cut_direction=1):
569 super().__init__(cut_direction=cut_direction, cut_value=np.nan)
574 """Find the cut value that satisfies the desired background-rejection level"""
575 n_data = len(estimates)
576 n_signals = scores.signal_amount(truths, estimates)
577 n_bkgs = n_data - n_signals
579 sorting_indices = np.argsort(estimates)
582 original_sorting_indices = sorting_indices
583 sorting_indices = sorting_indices[::-1]
585 sorted_truths = truths[sorting_indices]
586 sorted_estimates = estimates[sorting_indices]
588 sorted_n_accepted_signals = np.cumsum(sorted_truths, dtype=float)
591 sorted_n_rejected_signals = n_signals - sorted_n_accepted_signals
592 sorted_n_rejects = np.arange(len(estimates) + 1, 1, -1)
593 sorted_n_rejected_bkgs = sorted_n_rejects - sorted_n_rejected_signals
594 sorted_bkg_rejections = sorted_n_rejected_bkgs / n_bkgs
596 cut_index, = np.searchsorted(sorted_bkg_rejections[::-1], (self.
background_rejection,), side=
'right')
598 cut_value = sorted_estimates[-cut_index - 1]
cut
cached value of the threshold in the truth-classification analysis
quantity_name
cached name of the quantity in the truth-classification analysis
upper_bound
cached upper bound for this truth-classification analysis
def write(self, tdirectory=None)
outlier_z_score
cached Z-score (for outlier detection) for this truth-classification analysis
plots
cached dictionary of plots in the truth-classification analysis
allow_discrete
cached discrete-value flag for this truth-classification analysis
unit
cached measurement unit for this truth-classification analysis
_contact
cached contact person of the truth-classification analysis
cut_direction
cached value of the cut direction (< or >) in the truth-classification analysis
def contact(self, contact)
def __init__(self, contact, quantity_name, cut_direction=None, cut=None, lower_bound=None, upper_bound=None, outlier_z_score=None, allow_discrete=None, unit=None)
lower_bound
cached lower bound for this truth-classification analysis
fom
cached value of the figure of merit in the truth-classification analysis
background_rejection
cachec copy of the background-rejection threshold
def determine_cut_value(self, estimates, truths)
def __init__(self, background_rejection=0.5, cut_direction=1)
def fit(self, estimates, truths)
def determine_cut_value(self, estimates, truths)
cut_direction_
cached copy of the cut direction (< or >)
cut_value_
cached copy of the cut threshold
def describe(self, estimates, truths)
def __init__(self, cut_direction=1, cut_value=np.nan)
def predict(self, estimates)