10 from .
import statistics
12 from .plot
import ValidationPlot, compose_axis_label
13 from .fom
import ValidationFiguresOfMerit
14 from .tolerate_missing_key_formatter
import TolerateMissingKeyFormatter
16 formatter = TolerateMissingKeyFormatter()
20 """Perform truth-classification analysis"""
34 """Compare an estimated quantity to the truths by generating standardized validation plots."""
42 self.
plots = collections.OrderedDict()
68 """Compares the concrete estimates to the truths, computes efficiency, purity and background rejection
69 as figures of merit and plots the selection as a stacked plot over the truths.
73 estimates : array_like
74 Selection variable to compare to the truths
76 Binary true class values.
80 axis_label = compose_axis_label(quantity_name, self.
unit)
82 plot_name =
"{quantity_name}_{subplot_name}"
83 plot_name = formatter.format(plot_name, quantity_name=quantity_name)
88 estimate_is_binary = statistics.is_binary_series(estimates)
90 if estimate_is_binary:
91 binary_estimates = estimates != 0
95 elif self.
cut is not None:
96 if isinstance(self.
cut, numbers.Number):
99 cut_classifier =
CutClassifier(cut_direction=cut_direction, cut_value=cut_value)
102 cut_classifier = self.
cut
103 cut_classifier = cut_classifier.clone()
105 cut_classifier.fit(estimates, truths)
106 binary_estimates = cut_classifier.predict(estimates) != 0
107 cut_direction = cut_classifier.cut_direction
108 cut_value = cut_classifier.cut_value
110 if not isinstance(self.
cut, numbers.Number):
111 print(formatter.format(plot_name, subplot_name=
"cut_classifier"),
"summary")
112 cut_classifier.describe(estimates, truths)
122 signal_bkg_histogram_name = formatter.format(plot_name, subplot_name=
"signal_bkg_histogram")
124 signal_bkg_histogram.hist(
127 lower_bound=lower_bound,
128 upper_bound=upper_bound,
132 signal_bkg_histogram.xlabel = axis_label
134 if lower_bound
is None:
135 lower_bound = signal_bkg_histogram.lower_bound
137 if upper_bound
is None:
138 upper_bound = signal_bkg_histogram.upper_bound
140 self.
plots[
'signal_bkg'] = signal_bkg_histogram
143 purity_profile_name = formatter.format(plot_name, subplot_name=
"purity_profile")
146 purity_profile.profile(
149 lower_bound=lower_bound,
150 upper_bound=upper_bound,
155 purity_profile.xlabel = axis_label
156 purity_profile.ylabel =
'purity'
157 self.
plots[
"purity"] = purity_profile
160 if cut_direction
is None:
161 purity_grapherrors = ValidationPlot.convert_tprofile_to_tgrapherrors(purity_profile.plot)
162 correlation = purity_grapherrors.GetCorrelationFactor()
163 if correlation > 0.1:
164 print(
"Determined cut direction", -1)
166 elif correlation < -0.1:
167 print(
"Determined cut direction", 1)
171 if cut_value
is not None:
172 fom_name = formatter.format(plot_name, subplot_name=
"classification_figures_of_merits")
173 fom_description =
"Efficiency, purity and background rejection of the classifiction with {quantity_name}".format(
174 quantity_name=quantity_name
177 fom_check =
"Check that the classifcation quality stays stable."
179 fom_title =
"Summary of the classification quality with {quantity_name}".format(
180 quantity_name=quantity_name
186 description=fom_description,
191 efficiency = scores.efficiency(truths, binary_estimates)
192 purity = scores.purity(truths, binary_estimates)
193 background_rejection = scores.background_rejection(truths, binary_estimates)
195 classification_fom[
'cut_value'] = cut_value
196 classification_fom[
'cut_direction'] = cut_direction
197 classification_fom[
'efficiency'] = efficiency
198 classification_fom[
'purity'] = purity
199 classification_fom[
'background_rejection'] = background_rejection
201 self.
fom = classification_fom
203 for aux_name, aux_values
in auxiliaries.items():
204 if statistics.is_single_value_series(aux_values)
or aux_name == quantity_name:
207 aux_axis_label = compose_axis_label(aux_name)
211 signal_bkg_aux_hist2d_name = formatter.format(plot_name, subplot_name=aux_name +
'_signal_bkg_aux2d')
212 signal_bkg_aux_hist2d =
ValidationPlot(signal_bkg_aux_hist2d_name)
213 signal_bkg_aux_hist2d.hist2d(
217 lower_bound=(
None, lower_bound),
218 upper_bound=(
None, upper_bound),
223 aux_lower_bound = signal_bkg_aux_hist2d.lower_bound[0]
224 aux_upper_bound = signal_bkg_aux_hist2d.upper_bound[0]
226 signal_bkg_aux_hist2d.xlabel = aux_axis_label
227 signal_bkg_aux_hist2d.ylabel = axis_label
228 self.
plots[signal_bkg_aux_hist2d_name] = signal_bkg_aux_hist2d
231 if cut_value
is not None:
235 aux_purity_profile_name = formatter.format(plot_name, subplot_name=aux_name +
"_aux_purity_profile")
237 aux_purity_profile.profile(
238 aux_values[binary_estimates],
239 truths[binary_estimates],
242 lower_bound=aux_lower_bound,
243 upper_bound=aux_upper_bound,
246 aux_purity_profile.xlabel = aux_axis_label
247 aux_purity_profile.ylabel =
'purity'
248 self.
plots[aux_purity_profile_name] = aux_purity_profile
252 aux_efficiency_profile_name = formatter.format(plot_name, subplot_name=aux_name +
"_aux_efficiency_profile")
253 aux_efficiency_profile =
ValidationPlot(aux_efficiency_profile_name)
254 aux_efficiency_profile.profile(
256 binary_estimates[signals],
259 lower_bound=aux_lower_bound,
260 upper_bound=aux_upper_bound,
263 aux_efficiency_profile.xlabel = aux_axis_label
264 aux_efficiency_profile.ylabel =
'efficiency'
265 self.
plots[aux_efficiency_profile_name] = aux_efficiency_profile
269 aux_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=aux_name +
"_aux_bkg_rejection_profile")
270 aux_bkg_rejection_profile =
ValidationPlot(aux_bkg_rejection_profile_name)
271 aux_bkg_rejection_profile.profile(
272 aux_values[~signals],
273 ~binary_estimates[~signals],
276 lower_bound=aux_lower_bound,
277 upper_bound=aux_upper_bound,
280 aux_bkg_rejection_profile.xlabel = aux_axis_label
281 aux_bkg_rejection_profile.ylabel =
'bkg rejection'
282 self.
plots[aux_bkg_rejection_profile_name] = aux_bkg_rejection_profile
285 if cut_direction
is None:
286 purity_grapherrors = ValidationPlot.convert_tprofile_to_tgrapherrors(purity_profile.plot,
288 correlation = purity_grapherrors.GetCorrelationFactor()
289 if correlation > 0.1:
290 print(
"Determined absolute cut direction", -1)
293 elif correlation < -0.1:
294 print(
"Determined absolute cut direction", 1)
299 estimates = np.abs(estimates)
300 cut_x_label =
"cut " + compose_axis_label(
"abs(" + quantity_name +
")", self.
unit)
303 cut_x_label =
"cut " + axis_label
306 if not estimate_is_binary
and cut_direction
is not None:
309 if cut_direction > 0:
310 quantiles = [0.5, 0.90, 0.99]
312 quantiles = [0.01, 0.10, 0.5]
314 for aux_name, aux_values
in auxiliaries.items():
315 if statistics.is_single_value_series(aux_values)
or aux_name == quantity_name:
318 aux_axis_label = compose_axis_label(aux_name)
320 signal_quantile_aux_profile_name = formatter.format(plot_name, subplot_name=aux_name +
'_signal_quantiles_aux2d')
321 signal_quantile_aux_profile =
ValidationPlot(signal_quantile_aux_profile_name)
322 signal_quantile_aux_profile.hist2d(
327 lower_bound=(
None, lower_bound),
328 upper_bound=(
None, upper_bound),
332 signal_quantile_aux_profile.xlabel = aux_axis_label
333 signal_quantile_aux_profile.ylabel = cut_x_label
334 self.
plots[signal_quantile_aux_profile_name] = signal_quantile_aux_profile
337 if not estimate_is_binary
and cut_direction
is not None:
338 n_data = len(estimates)
339 n_signals = scores.signal_amount(truths, estimates)
340 n_bkgs = n_data - n_signals
343 if cut_direction < 0:
344 sorting_indices = np.argsort(-estimates)
346 sorting_indices = np.argsort(estimates)
348 sorted_truths = truths[sorting_indices]
349 sorted_estimates = estimates[sorting_indices]
351 sorted_n_accepted_signals = np.cumsum(sorted_truths, dtype=float)
352 sorted_efficiencies = sorted_n_accepted_signals / n_signals
354 sorted_n_rejected_signals = n_signals - sorted_n_accepted_signals
355 sorted_n_rejects = np.arange(len(estimates) + 1, 1, -1)
356 sorted_n_rejected_bkgs = sorted_n_rejects - sorted_n_rejected_signals
357 sorted_bkg_rejections = sorted_n_rejected_bkgs / n_bkgs
361 efficiency_by_cut_profile_name = formatter.format(plot_name, subplot_name=
"efficiency_by_cut")
363 efficiency_by_cut_profile =
ValidationPlot(efficiency_by_cut_profile_name)
364 efficiency_by_cut_profile.profile(
367 lower_bound=lower_bound,
368 upper_bound=upper_bound,
373 efficiency_by_cut_profile.xlabel = cut_x_label
374 efficiency_by_cut_profile.ylabel =
"efficiency"
376 self.
plots[
"efficiency_by_cut"] = efficiency_by_cut_profile
380 bkg_rejection_by_cut_profile_name = formatter.format(plot_name, subplot_name=
"bkg_rejection_by_cut")
381 bkg_rejection_by_cut_profile =
ValidationPlot(bkg_rejection_by_cut_profile_name)
382 bkg_rejection_by_cut_profile.profile(
384 sorted_bkg_rejections,
385 lower_bound=lower_bound,
386 upper_bound=upper_bound,
391 bkg_rejection_by_cut_profile.xlabel = cut_x_label
392 bkg_rejection_by_cut_profile.ylabel =
"background rejection"
394 self.
plots[
"bkg_rejection_by_cut"] = bkg_rejection_by_cut_profile
398 purity_over_efficiency_profile_name = formatter.format(plot_name, subplot_name=
"purity_over_efficiency")
399 purity_over_efficiency_profile =
ValidationPlot(purity_over_efficiency_profile_name)
400 purity_over_efficiency_profile.profile(
403 cumulation_direction=1,
407 purity_over_efficiency_profile.xlabel =
'efficiency'
408 purity_over_efficiency_profile.ylabel =
'purity'
410 self.
plots[
"purity_over_efficiency"] = purity_over_efficiency_profile
414 cut_over_efficiency_profile_name = formatter.format(plot_name, subplot_name=
"cut_over_efficiency")
415 cut_over_efficiency_profile =
ValidationPlot(cut_over_efficiency_profile_name)
416 cut_over_efficiency_profile.profile(
424 cut_over_efficiency_profile.set_minimum(lower_bound)
425 cut_over_efficiency_profile.set_maximum(upper_bound)
426 cut_over_efficiency_profile.xlabel =
'efficiency'
427 cut_over_efficiency_profile.ylabel = cut_x_label
429 self.
plots[
"cut_over_efficiency"] = cut_over_efficiency_profile
433 cut_over_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=
"cut_over_bkg_rejection")
434 cut_over_bkg_rejection_profile =
ValidationPlot(cut_over_bkg_rejection_profile_name)
435 cut_over_bkg_rejection_profile.profile(
436 sorted_bkg_rejections,
443 cut_over_bkg_rejection_profile.set_minimum(lower_bound)
444 cut_over_bkg_rejection_profile.set_maximum(upper_bound)
445 cut_over_bkg_rejection_profile.xlabel =
'bkg_rejection'
446 cut_over_bkg_rejection_profile.ylabel = cut_x_label
448 self.
plots[
"cut_over_bkg_rejection"] = cut_over_bkg_rejection_profile
452 efficiency_over_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=
"efficiency_over_bkg_rejection")
453 efficiency_over_bkg_rejection_profile =
ValidationPlot(efficiency_over_bkg_rejection_profile_name)
454 efficiency_over_bkg_rejection_profile.profile(
455 sorted_bkg_rejections,
461 efficiency_over_bkg_rejection_profile.xlabel =
"bkg rejection"
462 efficiency_over_bkg_rejection_profile.ylabel =
"efficiency"
464 self.
plots[
"efficiency_over_bkg_rejection"] = efficiency_over_bkg_rejection_profile
471 """Get the name of the contact person"""
476 """Set the name of the contact person"""
479 for plot
in list(self.
plots.values()):
480 plot.contact = contact
483 self.
fom.contact = contact
486 """Write the plots to the ROOT TDirectory"""
487 for plot
in list(self.
plots.values()):
488 plot.write(tdirectory)
496 """Simple classifier cutting on a single variable"""
498 def __init__(self, cut_direction=1, cut_value=np.nan):
507 """Get the value of the cut direction"""
512 """Get the value of the cut threshold"""
516 """Return a clone of this object"""
517 return copy.copy(self)
520 """Get the value of the cut threshold"""
523 def fit(self, estimates, truths):
524 """Fit to determine the cut threshold"""
529 """Select estimates that satisfy the cut"""
531 raise ValueError(
"Cut value not set. Forgot to fit?")
534 binary_estimates = estimates >= self.
cut_value_
536 binary_estimates = estimates <= self.
cut_value_
538 return binary_estimates
541 """Describe the cut selection and its efficiency, purity and background rejection"""
543 print(
"Cut accepts >= ", self.
cut_value_,
'with')
545 print(
"Cut accepts <= ", self.
cut_value_,
'with')
547 binary_estimates = self.
predict(estimates)
549 efficiency = scores.efficiency(truths, binary_estimates)
550 purity = scores.purity(truths, binary_estimates)
551 background_rejection = scores.background_rejection(truths, binary_estimates)
553 print(
"efficiency", efficiency)
554 print(
"purity", purity)
555 print(
"background_rejection", background_rejection)
558 def cut_at_background_rejection(background_rejection=0.5, cut_direction=1):
563 """Apply cut on the background rejection"""
565 def __init__(self, background_rejection=0.5, cut_direction=1):
567 super(CutAtBackgroundRejectionClassifier, self).
__init__(cut_direction=cut_direction, cut_value=np.nan)
572 """Find the cut value that satisfies the desired background-rejection level"""
573 n_data = len(estimates)
574 n_signals = scores.signal_amount(truths, estimates)
575 n_bkgs = n_data - n_signals
577 sorting_indices = np.argsort(estimates)
580 orginal_sorting_indices = sorting_indices
581 sorting_indices = sorting_indices[::-1]
583 sorted_truths = truths[sorting_indices]
584 sorted_estimates = estimates[sorting_indices]
586 sorted_n_accepted_signals = np.cumsum(sorted_truths, dtype=float)
587 sorted_efficiencies = sorted_n_accepted_signals / n_signals
589 sorted_n_rejected_signals = n_signals - sorted_n_accepted_signals
590 sorted_n_rejects = np.arange(len(estimates) + 1, 1, -1)
591 sorted_n_rejected_bkgs = sorted_n_rejects - sorted_n_rejected_signals
592 sorted_bkg_rejections = sorted_n_rejected_bkgs / n_bkgs
594 cut_index, = np.searchsorted(sorted_bkg_rejections[::-1], (self.
background_rejection,), side=
'right')
596 cut_value = sorted_estimates[-cut_index - 1]