def speedup(y_true, filter_prob, retention_rate, balance_weights=None, from_logits=True,
            t_gen=0.08, t_NN=0.63, t_SR=97.04):
    """
    Calculate the speedup achieved by sampling method.

    The filtered workflow generates every event and runs the network on it, but
    only the events the filter keeps (expected weight ``filter_prob``) go through
    detector simulation and reconstruction. The unfiltered workflow generates and
    fully simulates ``sumw**2 / sumw2`` events, where ``sumw`` is the weighted
    count of positives and ``sumw2`` the sum of their weights divided by
    ``filter_prob``.

    Args:
        y_true (torch.Tensor): True labels (ground truth). Entries equal to 1
            are treated as the positive class and entries equal to 0 as the
            negative class. Array-likes are converted with ``torch.as_tensor``.
        filter_prob (torch.Tensor): Probabilities (or logits, see
            ``from_logits``) predicted by the filter model, one per event.
            Array-likes are converted with ``torch.as_tensor``.
        retention_rate (float): Weight applied to positive-class counts;
            negative-class counts are weighted by ``1 - retention_rate``.
            Presumably the fraction of generated events retained by the
            downstream selection -- TODO confirm.
        balance_weights (torch.Tensor, optional): Per-event weights for
            balancing the dataset. Default is None (every event weighs 1).
        from_logits (bool): If True, interpret ``filter_prob`` as logits and
            apply sigmoid. Default is True.
        t_gen (float): Typical processing time for event generation
            (default is 0.08 ms/event).
        t_NN (float): Typical processing time for neural network processing
            (default is 0.63 ms/event).
        t_SR (float): Typical processing time for detector simulation and
            reconstruction (default is 97.04 ms/event).

    Returns:
        torch.Tensor: 0-dim tensor equal to ``t_filter / t_nofilter``.

    NOTE(review): despite the function name, the returned value is the ratio of
    filtered to unfiltered processing time, so values below 1 mean the filter
    saves time and the conventional "speedup factor" is its reciprocal. Kept
    as-is for compatibility with existing callers -- confirm intended meaning.
    """
    # as_tensor is a no-op for tensors and keeps array-like inputs usable.
    y_true = torch.as_tensor(y_true)
    filter_prob = torch.as_tensor(filter_prob)
    if from_logits:
        filter_prob = torch.sigmoid(filter_prob)

    if balance_weights is None:
        # Build from the tensor itself: torch.tensor(existing_tensor) makes an
        # extra copy and emits a UserWarning.
        balance_weights = torch.ones_like(filter_prob)
    else:
        balance_weights = torch.as_tensor(balance_weights)

    is_pos = y_true == 1
    is_neg = y_true == 0
    weighted_pass = filter_prob * balance_weights
    weighted_reject = (1 - filter_prob) * balance_weights

    # Expected weighted confusion-matrix counts: positives scaled by
    # retention_rate, negatives by (1 - retention_rate).
    N_TP = weighted_pass[is_pos].sum() * retention_rate
    N_FP = weighted_pass[is_neg].sum() * (1 - retention_rate)
    N_TN = weighted_reject[is_neg].sum() * (1 - retention_rate)
    N_FN = weighted_reject[is_pos].sum() * retention_rate

    # Kept events pay generation + network + full simulation;
    # rejected events pay only generation + network.
    t_simulated_filter = (
        (N_TP + N_FP) * (t_gen + t_NN + t_SR)
        + (N_TN + N_FN) * (t_gen + t_NN)
    )

    # Looks like a Kish-style effective sample size for positives that carry
    # weight 1 / filter_prob -- confirm. sumw assumes y_true holds only 0/1.
    # A positive with filter_prob == 0 makes sumw2 infinite, so the returned
    # ratio is then not finite.
    sumw = (y_true * balance_weights).sum()
    sumw2 = (1 / filter_prob * balance_weights)[is_pos].sum()
    N_simulated_nofilter = (sumw ** 2) / sumw2

    # Without a filter every event is generated and fully simulated (no NN cost).
    t_simulated_nofilter = N_simulated_nofilter * (t_gen + t_SR)

    return t_simulated_filter / t_simulated_nofilter