Belle II Software  release-08-01-10
02_SamplePIDAnalysis.py
1 #!/usr/bin/env python3
2 
3 
10 
11 
21 
22 import basf2 as b2
23 import pidDataUtils as pdu
24 import numpy as np
25 
# Read in the dataset (specifically the held-out test set from our training split.)
standard = pdu.read_npz('data/slim_dstar/test.npz')  # read PID info from npz into a DataFrame
weighted = standard.copy(deep=True)  # independent copy, so both analyses start from the same rows
print(f'{len(standard)} events')

# Load the weights
weights = np.load('models/net_wgt.npy')

# Prepare DataFrames for analysis.
# The two calls differ only in whether the calibration weights are passed.
standard = pdu.produce_analysis_df(standard)  # Standard PID: _no_ weights
weighted = pdu.produce_analysis_df(weighted, weights=weights)  # Weighted PID: uses the calibration weights
print(f'{len(standard)} events after cuts')

print('\nValues of the weights')
print(weights)
41 
42 
def compute_accuracy(df, mask=None):
    """Return the fraction of selected rows where 'pid' equals 'labels'.

    Args:
        df: DataFrame providing the 'pid' and 'labels' columns.
        mask: Optional row selector handed to ``df.loc`` (e.g. a boolean
            Series). When None, every row of ``df`` is used.

    Returns:
        The number of selected rows with ``pid == labels`` divided by the
        number of selected rows, or NaN when no rows are selected.
    """
    selected = df if mask is None else df.loc[mask]
    n_rows = len(selected)
    if n_rows == 0:
        # Empty selection: report NaN explicitly instead of dividing by zero.
        return float('nan')
    n_correct = (selected['pid'] == selected['labels']).values.sum()
    return n_correct / n_rows
46 
47 
# Print a small comparison table: unweighted vs. weighted PID.
# NOTE(review): the column spacing inside these strings may have been
# collapsed when this listing was extracted - confirm against upstream.
print('\n no wgt wgt')
standard_acc = compute_accuracy(standard)
weighted_acc = compute_accuracy(weighted)
print(f'Accuracy: {standard_acc:.3f} {weighted_acc:.3f}')

# Per-label efficiency: restrict each DataFrame to rows carrying the given
# label, then reuse compute_accuracy on that subset.
# Label 2 is reported as "pion" and label 3 as "kaon".
for label, lbl in ((2, "pion"), (3, "kaon")):
    _standard_eff = compute_accuracy(standard, mask=standard["labels"] == label)
    _weighted_eff = compute_accuracy(weighted, mask=weighted["labels"] == label)
    print(f'{lbl} eff: {_standard_eff:.3f} {_weighted_eff:.3f}')
57 
# There is also an external package, 'pidplots', that interfaces with these
# DataFrames and provides many methods for quickly making plots of the PID
# performance.
61 
62 
# Create the path that will produce the data object holding the weights.
my_path = b2.create_path()

matrixName = "PIDCalibrationWeight_Example"
weightMatrix = weights.tolist()  # convert np.ndarray to list(list)

# Module that receives the weight matrix together with the experiment/run
# range it applies to.
# NOTE(review): -1 for the upper bounds presumably means "open-ended" -
# confirm against the PIDCalibrationWeightCreator documentation.
addmatrix = b2.register_module('PIDCalibrationWeightCreator')
for _name, _value in {
    'matrixName': matrixName,
    'weightMatrix': weightMatrix,
    'experimentHigh': -1,
    'experimentLow': 0,
    'runHigh': -1,
    'runLow': 0,
}.items():
    addmatrix.param(_name, _value)

# Event source configured with a single experiment 0 / run 0 entry of 10 events.
eventinfosetter = b2.register_module('EventInfoSetter')
for _name, _value in {
    'evtNumList': [10],
    'runList': [0],
    'expList': [0],
}.items():
    eventinfosetter.param(_name, _value)

# Module order is kept as in the original script.
my_path.add_module(addmatrix)
my_path.add_module(eventinfosetter)

b2.process(my_path)