23 formatter = TolerateMissingKeyFormatter()
27 """Perform truth-classification analysis"""
41 """Compare an estimated quantity to the truths by generating standardized validation plots."""
49 self.
plotsplots = collections.OrderedDict()
75 """Compares the concrete estimate to the truth and efficiency, purity and background rejection
76 as figure of merit and plots the selection as a stacked plot over the truths.
80 estimates : array_like
81 Selection variable to compare to the truths
83 Binary true class values.
87 axis_label = compose_axis_label(quantity_name, self.
unitunit)
89 plot_name =
"{quantity_name}_{subplot_name}"
90 plot_name = formatter.format(plot_name, quantity_name=quantity_name)
95 estimate_is_binary = statistics.is_binary_series(estimates)
97 if estimate_is_binary:
98 binary_estimates = estimates != 0
102 elif self.
cutcut
is not None:
103 if isinstance(self.
cutcut, numbers.Number):
104 cut_value = self.
cutcut
106 cut_classifier =
CutClassifier(cut_direction=cut_direction, cut_value=cut_value)
109 cut_classifier = self.
cutcut
110 cut_classifier = cut_classifier.clone()
112 cut_classifier.fit(estimates, truths)
113 binary_estimates = cut_classifier.predict(estimates) != 0
114 cut_direction = cut_classifier.cut_direction
115 cut_value = cut_classifier.cut_value
117 if not isinstance(self.
cutcut, numbers.Number):
118 print(formatter.format(plot_name, subplot_name=
"cut_classifier"),
"summary")
119 cut_classifier.describe(estimates, truths)
129 signal_bkg_histogram_name = formatter.format(plot_name, subplot_name=
"signal_bkg_histogram")
131 signal_bkg_histogram.hist(
134 lower_bound=lower_bound,
135 upper_bound=upper_bound,
139 signal_bkg_histogram.xlabel = axis_label
141 if lower_bound
is None:
142 lower_bound = signal_bkg_histogram.lower_bound
144 if upper_bound
is None:
145 upper_bound = signal_bkg_histogram.upper_bound
147 self.
plotsplots[
'signal_bkg'] = signal_bkg_histogram
150 purity_profile_name = formatter.format(plot_name, subplot_name=
"purity_profile")
153 purity_profile.profile(
156 lower_bound=lower_bound,
157 upper_bound=upper_bound,
162 purity_profile.xlabel = axis_label
163 purity_profile.ylabel =
'purity'
164 self.
plotsplots[
"purity"] = purity_profile
167 if cut_direction
is None:
168 purity_grapherrors = ValidationPlot.convert_tprofile_to_tgrapherrors(purity_profile.plot)
169 correlation = purity_grapherrors.GetCorrelationFactor()
170 if correlation > 0.1:
171 print(
"Determined cut direction", -1)
173 elif correlation < -0.1:
174 print(
"Determined cut direction", 1)
178 if cut_value
is not None:
179 fom_name = formatter.format(plot_name, subplot_name=
"classification_figures_of_merits")
180 fom_description =
"Efficiency, purity and background rejection of the classifiction with {quantity_name}".format(
181 quantity_name=quantity_name
184 fom_check =
"Check that the classifcation quality stays stable."
186 fom_title =
"Summary of the classification quality with {quantity_name}".format(
187 quantity_name=quantity_name
193 description=fom_description,
198 efficiency = scores.efficiency(truths, binary_estimates)
199 purity = scores.purity(truths, binary_estimates)
200 background_rejection = scores.background_rejection(truths, binary_estimates)
202 classification_fom[
'cut_value'] = cut_value
203 classification_fom[
'cut_direction'] = cut_direction
204 classification_fom[
'efficiency'] = efficiency
205 classification_fom[
'purity'] = purity
206 classification_fom[
'background_rejection'] = background_rejection
208 self.
fomfom = classification_fom
210 for aux_name, aux_values
in auxiliaries.items():
211 if statistics.is_single_value_series(aux_values)
or aux_name == quantity_name:
214 aux_axis_label = compose_axis_label(aux_name)
218 signal_bkg_aux_hist2d_name = formatter.format(plot_name, subplot_name=aux_name +
'_signal_bkg_aux2d')
219 signal_bkg_aux_hist2d =
ValidationPlot(signal_bkg_aux_hist2d_name)
220 signal_bkg_aux_hist2d.hist2d(
224 lower_bound=(
None, lower_bound),
225 upper_bound=(
None, upper_bound),
230 aux_lower_bound = signal_bkg_aux_hist2d.lower_bound[0]
231 aux_upper_bound = signal_bkg_aux_hist2d.upper_bound[0]
233 signal_bkg_aux_hist2d.xlabel = aux_axis_label
234 signal_bkg_aux_hist2d.ylabel = axis_label
235 self.
plotsplots[signal_bkg_aux_hist2d_name] = signal_bkg_aux_hist2d
238 if cut_value
is not None:
242 aux_purity_profile_name = formatter.format(plot_name, subplot_name=aux_name +
"_aux_purity_profile")
244 aux_purity_profile.profile(
245 aux_values[binary_estimates],
246 truths[binary_estimates],
249 lower_bound=aux_lower_bound,
250 upper_bound=aux_upper_bound,
253 aux_purity_profile.xlabel = aux_axis_label
254 aux_purity_profile.ylabel =
'purity'
255 self.
plotsplots[aux_purity_profile_name] = aux_purity_profile
259 aux_efficiency_profile_name = formatter.format(plot_name, subplot_name=aux_name +
"_aux_efficiency_profile")
260 aux_efficiency_profile =
ValidationPlot(aux_efficiency_profile_name)
261 aux_efficiency_profile.profile(
263 binary_estimates[signals],
266 lower_bound=aux_lower_bound,
267 upper_bound=aux_upper_bound,
270 aux_efficiency_profile.xlabel = aux_axis_label
271 aux_efficiency_profile.ylabel =
'efficiency'
272 self.
plotsplots[aux_efficiency_profile_name] = aux_efficiency_profile
276 aux_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=aux_name +
"_aux_bkg_rejection_profile")
277 aux_bkg_rejection_profile =
ValidationPlot(aux_bkg_rejection_profile_name)
278 aux_bkg_rejection_profile.profile(
279 aux_values[~signals],
280 ~binary_estimates[~signals],
283 lower_bound=aux_lower_bound,
284 upper_bound=aux_upper_bound,
287 aux_bkg_rejection_profile.xlabel = aux_axis_label
288 aux_bkg_rejection_profile.ylabel =
'bkg rejection'
289 self.
plotsplots[aux_bkg_rejection_profile_name] = aux_bkg_rejection_profile
292 if cut_direction
is None:
293 purity_grapherrors = ValidationPlot.convert_tprofile_to_tgrapherrors(purity_profile.plot,
295 correlation = purity_grapherrors.GetCorrelationFactor()
296 if correlation > 0.1:
297 print(
"Determined absolute cut direction", -1)
300 elif correlation < -0.1:
301 print(
"Determined absolute cut direction", 1)
306 estimates = np.abs(estimates)
307 cut_x_label =
"cut " + compose_axis_label(
"abs(" + quantity_name +
")", self.
unitunit)
310 cut_x_label =
"cut " + axis_label
313 if not estimate_is_binary
and cut_direction
is not None:
316 if cut_direction > 0:
317 quantiles = [0.5, 0.90, 0.99]
319 quantiles = [0.01, 0.10, 0.5]
321 for aux_name, aux_values
in auxiliaries.items():
322 if statistics.is_single_value_series(aux_values)
or aux_name == quantity_name:
325 aux_axis_label = compose_axis_label(aux_name)
327 signal_quantile_aux_profile_name = formatter.format(plot_name, subplot_name=aux_name +
'_signal_quantiles_aux2d')
328 signal_quantile_aux_profile =
ValidationPlot(signal_quantile_aux_profile_name)
329 signal_quantile_aux_profile.hist2d(
334 lower_bound=(
None, lower_bound),
335 upper_bound=(
None, upper_bound),
339 signal_quantile_aux_profile.xlabel = aux_axis_label
340 signal_quantile_aux_profile.ylabel = cut_x_label
341 self.
plotsplots[signal_quantile_aux_profile_name] = signal_quantile_aux_profile
344 if not estimate_is_binary
and cut_direction
is not None:
345 n_data = len(estimates)
346 n_signals = scores.signal_amount(truths, estimates)
347 n_bkgs = n_data - n_signals
350 if cut_direction < 0:
351 sorting_indices = np.argsort(-estimates)
353 sorting_indices = np.argsort(estimates)
355 sorted_truths = truths[sorting_indices]
356 sorted_estimates = estimates[sorting_indices]
358 sorted_n_accepted_signals = np.cumsum(sorted_truths, dtype=float)
359 sorted_efficiencies = sorted_n_accepted_signals / n_signals
361 sorted_n_rejected_signals = n_signals - sorted_n_accepted_signals
362 sorted_n_rejects = np.arange(len(estimates) + 1, 1, -1)
363 sorted_n_rejected_bkgs = sorted_n_rejects - sorted_n_rejected_signals
364 sorted_bkg_rejections = sorted_n_rejected_bkgs / n_bkgs
368 efficiency_by_cut_profile_name = formatter.format(plot_name, subplot_name=
"efficiency_by_cut")
370 efficiency_by_cut_profile =
ValidationPlot(efficiency_by_cut_profile_name)
371 efficiency_by_cut_profile.profile(
374 lower_bound=lower_bound,
375 upper_bound=upper_bound,
380 efficiency_by_cut_profile.xlabel = cut_x_label
381 efficiency_by_cut_profile.ylabel =
"efficiency"
383 self.
plotsplots[
"efficiency_by_cut"] = efficiency_by_cut_profile
387 bkg_rejection_by_cut_profile_name = formatter.format(plot_name, subplot_name=
"bkg_rejection_by_cut")
388 bkg_rejection_by_cut_profile =
ValidationPlot(bkg_rejection_by_cut_profile_name)
389 bkg_rejection_by_cut_profile.profile(
391 sorted_bkg_rejections,
392 lower_bound=lower_bound,
393 upper_bound=upper_bound,
398 bkg_rejection_by_cut_profile.xlabel = cut_x_label
399 bkg_rejection_by_cut_profile.ylabel =
"background rejection"
401 self.
plotsplots[
"bkg_rejection_by_cut"] = bkg_rejection_by_cut_profile
405 purity_over_efficiency_profile_name = formatter.format(plot_name, subplot_name=
"purity_over_efficiency")
406 purity_over_efficiency_profile =
ValidationPlot(purity_over_efficiency_profile_name)
407 purity_over_efficiency_profile.profile(
410 cumulation_direction=1,
414 purity_over_efficiency_profile.xlabel =
'efficiency'
415 purity_over_efficiency_profile.ylabel =
'purity'
417 self.
plotsplots[
"purity_over_efficiency"] = purity_over_efficiency_profile
421 cut_over_efficiency_profile_name = formatter.format(plot_name, subplot_name=
"cut_over_efficiency")
422 cut_over_efficiency_profile =
ValidationPlot(cut_over_efficiency_profile_name)
423 cut_over_efficiency_profile.profile(
431 cut_over_efficiency_profile.set_minimum(lower_bound)
432 cut_over_efficiency_profile.set_maximum(upper_bound)
433 cut_over_efficiency_profile.xlabel =
'efficiency'
434 cut_over_efficiency_profile.ylabel = cut_x_label
436 self.
plotsplots[
"cut_over_efficiency"] = cut_over_efficiency_profile
440 cut_over_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=
"cut_over_bkg_rejection")
441 cut_over_bkg_rejection_profile =
ValidationPlot(cut_over_bkg_rejection_profile_name)
442 cut_over_bkg_rejection_profile.profile(
443 sorted_bkg_rejections,
450 cut_over_bkg_rejection_profile.set_minimum(lower_bound)
451 cut_over_bkg_rejection_profile.set_maximum(upper_bound)
452 cut_over_bkg_rejection_profile.xlabel =
'bkg_rejection'
453 cut_over_bkg_rejection_profile.ylabel = cut_x_label
455 self.
plotsplots[
"cut_over_bkg_rejection"] = cut_over_bkg_rejection_profile
459 efficiency_over_bkg_rejection_profile_name = formatter.format(plot_name, subplot_name=
"efficiency_over_bkg_rejection")
460 efficiency_over_bkg_rejection_profile =
ValidationPlot(efficiency_over_bkg_rejection_profile_name)
461 efficiency_over_bkg_rejection_profile.profile(
462 sorted_bkg_rejections,
468 efficiency_over_bkg_rejection_profile.xlabel =
"bkg rejection"
469 efficiency_over_bkg_rejection_profile.ylabel =
"efficiency"
471 self.
plotsplots[
"efficiency_over_bkg_rejection"] = efficiency_over_bkg_rejection_profile
478 """Get the name of the contact person"""
483 """Set the name of the contact person"""
486 for plot
in list(self.
plotsplots.values()):
487 plot.contact = contact
490 self.
fomfom.contact = contact
493 """Write the plots to the ROOT TDirectory"""
494 for plot
in list(self.
plotsplots.values()):
495 plot.write(tdirectory)
503 """Simple classifier cutting on a single variable"""
505 def __init__(self, cut_direction=1, cut_value=np.nan):
514 """Get the value of the cut direction"""
519 """Get the value of the cut threshold"""
523 """Return a clone of this object"""
524 return copy.copy(self)
527 """Get the value of the cut threshold"""
530 def fit(self, estimates, truths):
531 """Fit to determine the cut threshold"""
536 """Select estimates that satisfy the cut"""
538 raise ValueError(
"Cut value not set. Forgot to fit?")
541 binary_estimates = estimates >= self.
cut_value_cut_value_
543 binary_estimates = estimates <= self.
cut_value_cut_value_
545 return binary_estimates
548 """Describe the cut selection and its efficiency, purity and background rejection"""
550 print(
"Cut accepts >= ", self.
cut_value_cut_value_,
'with')
552 print(
"Cut accepts <= ", self.
cut_value_cut_value_,
'with')
554 binary_estimates = self.
predictpredict(estimates)
556 efficiency = scores.efficiency(truths, binary_estimates)
557 purity = scores.purity(truths, binary_estimates)
558 background_rejection = scores.background_rejection(truths, binary_estimates)
560 print(
"efficiency", efficiency)
561 print(
"purity", purity)
562 print(
"background_rejection", background_rejection)
565 def cut_at_background_rejection(background_rejection=0.5, cut_direction=1):
570 """Apply cut on the background rejection"""
572 def __init__(self, background_rejection=0.5, cut_direction=1):
574 super(CutAtBackgroundRejectionClassifier, self).
__init__(cut_direction=cut_direction, cut_value=np.nan)
579 """Find the cut value that satisfies the desired background-rejection level"""
580 n_data = len(estimates)
581 n_signals = scores.signal_amount(truths, estimates)
582 n_bkgs = n_data - n_signals
584 sorting_indices = np.argsort(estimates)
587 original_sorting_indices = sorting_indices
588 sorting_indices = sorting_indices[::-1]
590 sorted_truths = truths[sorting_indices]
591 sorted_estimates = estimates[sorting_indices]
593 sorted_n_accepted_signals = np.cumsum(sorted_truths, dtype=float)
596 sorted_n_rejected_signals = n_signals - sorted_n_accepted_signals
597 sorted_n_rejects = np.arange(len(estimates) + 1, 1, -1)
598 sorted_n_rejected_bkgs = sorted_n_rejects - sorted_n_rejected_signals
599 sorted_bkg_rejections = sorted_n_rejected_bkgs / n_bkgs
601 cut_index, = np.searchsorted(sorted_bkg_rejections[::-1], (self.
background_rejectionbackground_rejection,), side=
'right')
603 cut_value = sorted_estimates[-cut_index - 1]
def analyse(self, estimates, truths, auxiliaries={})
cut
cached value of the threshold in the truth-classification analysis
quantity_name
cached name of the quantity in the truth-classification analysis
upper_bound
cached upper bound for this truth-classification analysis
def write(self, tdirectory=None)
outlier_z_score
cached Z-score (for outlier detection) for this truth-classification analysis
plots
cached dictionary of plots in the truth-classification analysis
allow_discrete
cached discrete-value flag for this truth-classification analysis
unit
cached measurement unit for this truth-classification analysis
_contact
cached contact person of the truth-classification analysis
cut_direction
cached value of the cut direction (< or >) in the truth-classification analysis
def contact(self, contact)
def __init__(self, contact, quantity_name, cut_direction=None, cut=None, lower_bound=None, upper_bound=None, outlier_z_score=None, allow_discrete=None, unit=None)
lower_bound
cached lower bound for this truth-classification analysis
fom
cached value of the figure of merit in the truth-classification analysis
background_rejection
cached copy of the background-rejection threshold
def determine_cut_value(self, estimates, truths)
def __init__(self, background_rejection=0.5, cut_direction=1)
def fit(self, estimates, truths)
def determine_cut_value(self, estimates, truths)
cut_direction_
cached copy of the cut direction (< or >)
cut_value_
cached copy of the cut threshold
def describe(self, estimates, truths)
def __init__(self, cut_direction=1, cut_value=np.nan)
def predict(self, estimates)