Belle II Software light-2406-ragdoll
02_SamplePIDAnalysis.py
1#!/usr/bin/env python3
2
3
10
11
21
22import basf2 as b2
23import pidDataUtils as pdu
24import numpy as np
25
# Load the held-out test split used for evaluation. The weighted analysis
# gets its own deep copy, so later processing of one DataFrame cannot
# affect the other.
standard = pdu.read_npz('data/slim_dstar/test.npz')  # PID info from the npz file as a DataFrame
weighted = standard.copy(deep=True)
print(f'{len(standard)} events')

# PID calibration weights (presumably written out by the training step).
weights = np.load('models/net_wgt.npy')

# Build both analysis DataFrames, evaluated left to right:
# standard PID uses no weights; weighted PID applies `weights`.
standard, weighted = (
    pdu.produce_analysis_df(standard),
    pdu.produce_analysis_df(weighted, weights=weights),
)
print(f'{len(standard)} events after cuts')

print('\nValues of the weights')
print(weights)
41
42
def compute_accuracy(df, mask=None):
    """Return the fraction of rows whose ``pid`` value equals their ``labels`` value.

    In this script ``pid`` is assumed to hold the PID assignment and
    ``labels`` the true particle class (TODO confirm against
    ``pidDataUtils.produce_analysis_df``).

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame containing ``pid`` and ``labels`` columns.
    mask : boolean array-like, optional
        Row selection applied with ``df.loc[mask]`` before counting. With
        ``mask = df['labels'] == k`` the result is the efficiency for class k.

    Returns
    -------
    float
        Fraction of matching rows in [0, 1], or NaN when the selection is empty.
    """
    _df = df.loc[mask] if mask is not None else df
    n_rows = len(_df)
    if n_rows == 0:
        # An empty selection has no defined accuracy. Return NaN explicitly
        # instead of letting numpy evaluate 0/0, which emits a RuntimeWarning.
        return float('nan')
    n_correct = int((_df['pid'] == _df['labels']).sum())
    return n_correct / n_rows
46
47
# Compare standard and weighted PID: first the overall accuracy, then a
# per-class efficiency. The efficiency is the accuracy restricted to events
# with a given `labels` value, assuming `labels` holds the true class.
particle_names = {2: "pion", 3: "kaon"}

print('\n no wgt wgt')
standard_acc, weighted_acc = compute_accuracy(standard), compute_accuracy(weighted)
print(f'Accuracy: {standard_acc:.3f} {weighted_acc:.3f}')

for label, lbl in particle_names.items():
    _standard_eff = compute_accuracy(standard, mask=standard["labels"] == label)
    _weighted_eff = compute_accuracy(weighted, mask=weighted["labels"] == label)
    print(f'{lbl} eff: {_standard_eff:.3f} {_weighted_eff:.3f}')
57
# An external package, 'pidplots', also works with these DataFrames and
# provides many helpers for quickly plotting the PID performance.
61
62
# Build a minimal basf2 path. PIDCalibrationWeightCreator receives the weight
# matrix under the name `matrixName` plus an experiment/run validity range
# (-1 as the upper bound presumably means open-ended, per the basf2 IOV
# convention — confirm). EventInfoSetter supplies a handful of dummy events
# (exp 0, run 0) so that the path actually executes.
my_path = b2.create_path()

matrixName = "PIDCalibrationWeight_Example"
weightMatrix = weights.tolist()  # np.ndarray -> nested Python lists for the module parameter

addmatrix = b2.register_module('PIDCalibrationWeightCreator')
for _name, _value in (
    ('matrixName', matrixName),
    ('weightMatrix', weightMatrix),
    ('experimentHigh', -1),
    ('experimentLow', 0),
    ('runHigh', -1),
    ('runLow', 0),
):
    addmatrix.param(_name, _value)

eventinfosetter = b2.register_module('EventInfoSetter')
for _name, _value in (
    ('evtNumList', [10]),
    ('runList', [0]),
    ('expList', [0]),
):
    eventinfosetter.param(_name, _value)

# Module order in the path is preserved: weight creator first, then event info.
for _module in (addmatrix, eventinfosetter):
    my_path.add_module(_module)

b2.process(my_path)