Belle II Software development
metrics.py
1
8import torch
9
10
def speedup(y_true, filter_prob, retention_rate, balance_weights=None, from_logits=True,
            t_gen=0.08, t_NN=0.63, t_SR=97.04):
    """
    Calculate the speedup achieved by sampling method.

    Arguments:
        y_true (torch.Tensor or array-like): True labels (ground truth); entries equal to 1
            are treated as positives and entries equal to 0 as negatives.
        filter_prob (torch.Tensor or array-like): Probabilities (or logits, see `from_logits`)
            predicted by the filter model.
        retention_rate (float): The rate at which events are retained by the filter.
        balance_weights (torch.Tensor, optional): Weights for balancing the dataset.
            Default is None, in which case every event gets weight 1.
        from_logits (bool): If True, interpret `filter_prob` as logits and apply sigmoid. Default is True.
        t_gen (float): Typical processing time for event generation (default is 0.08 ms/event).
        t_NN (float): Typical processing time for neural network processing (default is 0.63 ms/event).
        t_SR (float): Typical processing time for detector simulation and reconstruction
            (default is 97.04 ms/event).

    Returns:
        torch.Tensor: Zero-dimensional tensor holding the ratio of the processing time with
        the filter to the processing time estimated without it.
    """
    # as_tensor hands back existing tensors untouched (no copy, no warning) and
    # converts lists / NumPy arrays so the masking and arithmetic below work.
    y_true = torch.as_tensor(y_true)
    filter_prob = torch.as_tensor(filter_prob)

    if from_logits:
        filter_prob = torch.sigmoid(filter_prob)
    if balance_weights is None:
        # Set balance_weights to 1 by default if the dataset is balanced.
        # ones_like directly on the tensor: wrapping it in torch.tensor() would
        # copy it and trigger PyTorch's "copy construct" UserWarning.
        balance_weights = torch.ones_like(filter_prob)

    is_positive = y_true == 1
    is_negative = y_true == 0

    # Weighted probability of passing / being rejected by the filter, computed once.
    w_pass = filter_prob * balance_weights
    w_reject = (1 - filter_prob) * balance_weights

    N_TP = w_pass[is_positive].sum() * retention_rate  # Number of true positives
    N_FP = w_pass[is_negative].sum() * (1 - retention_rate)  # Number of false positives
    N_TN = w_reject[is_negative].sum() * (1 - retention_rate)  # Number of true negatives
    N_FN = w_reject[is_positive].sum() * retention_rate  # Number of false negatives

    # Add up all types of events and multiply them by their processing times:
    # events passing the filter pay for generation, NN and simulation/reconstruction,
    # rejected events only pay for generation and the NN.
    t_simulated_filter = (
        (N_TP + N_FP) * (t_gen + t_NN + t_SR)
        + (N_TN + N_FN) * (t_gen + t_NN)
    )

    # Effective number of events derived from the positive-class weights.
    sumw = (y_true * balance_weights).sum()
    sumw2 = (1 / filter_prob * balance_weights)[is_positive].sum()
    N_simulated_nofilter = (sumw ** 2) / sumw2

    # Without the filter there is no NN step: only generation + simulation/reconstruction.
    t_simulated_nofilter = N_simulated_nofilter * (t_gen + t_SR)

    # NOTE(review): this is time-with-filter over time-without-filter; confirm the
    # orientation matches the intended definition of "speedup".
    return t_simulated_filter / t_simulated_nofilter