Belle II Software development
NN_filter_module.py
1
8import numpy as np
9import torch
10
11import basf2 as b2
12from ROOT import Belle2
13from ROOT.Belle2 import DBAccessorBase, DBStoreEntry
14
15from smartBKG import PREPROC_CONFIG, MODEL_CONFIG
16from smartBKG.utils.preprocess import load_particle_list, preprocessed
17from smartBKG.models.gatgap import GATGAPModel
18from smartBKG.utils.dataset import ArrayDataset
19
20DEVICE = torch.device("cpu")
21
22
class NNFilterModule(b2.Module):
    """
    Goals:
      1. Build a graph from an event composed of MCParticles
      2. Apply the well-trained model for reweighting or sampling method to get a score
      3. Execute reweighting or sampling process to get a weight

    Arguments:
        model_file(str): Path to the saved model
        model_config(dict): Parameters to build the model
        preproc_config(dict): Parameters for preprocessing
        threshold(float): Threshold for event selection using reweighting method, value *None* indicating sampling method
        extra_info_var(str): Name of eventExtraInfo to save model prediction to
        global_tag(str): Tag in ConditionDB where the well trained model was stored
        payload(str): Payload for the well trained model in global tag

    Returns:
        Pass or rejected according to random sampling or selection with the given threshold

    Note:
        Score after the NN filter indicating the probability of the event to pass is saved
        under ``EventExtraInfo.extra_info_var``.

        Use ``eventExtraInfo(extra_info_var)`` in ``modularAnalysis.variablesToNtuple`` or
        ``additionalBranches=["EventExtraInfo"]`` in ``mdst.add_mdst_output`` to have access to the scores.
    """

    def __init__(
        self,
        model_file=None,
        model_config=MODEL_CONFIG,
        preproc_config=PREPROC_CONFIG,
        threshold=None,
        extra_info_var="NN_prediction",
        global_tag="SmartBKG_GATGAP",
        payload="GATGAPgen.pth"
    ):
        """
        Initialise the class.

        :param model_file: Path to the saved model file.
        :param model_config: Parameters for building the model.
        :param preproc_config: Parameters for preprocessing.
        :param threshold: Threshold for event selection using reweighting method, value *None* indicating sampling method.
        :param extra_info_var: Name of eventExtraInfo to save model prediction to.
        :param global_tag: Tag in ConditionDB where the well-trained model was stored.
        :param payload: Payload for the well-trained model in global tag.
        """
        super().__init__()
        #: Path to the saved model file
        self.model_file = model_file
        #: Parameters for building the model
        self.model_config = model_config
        #: Parameters for preprocessing
        self.preproc_config = preproc_config
        #: Threshold for event selection using reweighting method, value *None* indicating sampling method
        self.threshold = threshold
        #: Name of eventExtraInfo to save model prediction to
        self.extra_info_var = extra_info_var
        #: Payload for the well-trained model in global tag
        self.payload = payload

        # set additional database conditions for trained neural network
        b2.conditions.prepend_globaltag(global_tag)

    def initialize(self):
        """
        Initialise module before any events are processed
        """
        # read trained model parameters from database when no local file was given
        if not self.model_file:
            accessor = DBAccessorBase(DBStoreEntry.c_RawFile, self.payload, True)
            self.model_file = accessor.getFilename()
        # NOTE(review): torch.load unpickles the file, so ``model_file`` must come from a
        # trusted source (conditions database or a file the user controls).
        trained_parameters = torch.load(self.model_file, map_location=DEVICE)

        #: model with trained parameters
        self.model = GATGAPModel(**self.model_config)
        self.model.load_state_dict(trained_parameters['model_state_dict'])
        # The module is only used for prediction: switch layers such as dropout or
        # batch normalisation to their inference behaviour.
        self.model.eval()

        #: Data store object the model prediction is saved to
        self.EventExtraInfo = Belle2.PyStoreObj('EventExtraInfo')
        # Register the object only if it is not already available in the data store
        if not self.EventExtraInfo.isValid():
            self.EventExtraInfo.registerInDataStore()

        #: Event metadata from the data store
        self.EventInfo = Belle2.PyStoreObj('EventMetaData')

        #: Node features from the preprocessing configuration, with 'PDG' removed
        self.out_features = self.preproc_config['features']
        # NOTE(review): ``out_features`` aliases the list stored in ``preproc_config``, so the
        # removal below also modifies the (possibly shared) configuration in place. Kept as is
        # because other code may rely on it -- confirm before changing it to work on a copy.
        if 'PDG' in self.preproc_config['features']:
            self.out_features.remove('PDG')

    def event(self):
        """
        Collect information from database, build graphs, make predictions and select through sampling or threshold
        """
        # Need to create the eventExtraInfo entry for each event
        self.EventExtraInfo.create()

        df_dict = load_particle_list(
            mcplist=Belle2.PyStoreArray("MCParticles"),
            evtNum=self.EventInfo.getEvent(),
            label=True,
        )
        single_input = preprocessed(df_dict, particle_selection=self.preproc_config['cuts'])
        graph = ArrayDataset(single_input, batch_size=1)[0][0]

        # Output pass probability; no gradients are needed for a pure prediction
        with torch.no_grad():
            score = torch.sigmoid(self.model(graph))
        # Convert to a plain Python float before handing the value to the data store
        pred = float(score.detach().numpy().squeeze())

        # Save the pass probability to EventExtraInfo
        self.EventExtraInfo.addExtraInfo(self.extra_info_var, pred)

        # Module returns bool of whether prediction passes threshold for use in basf2 path flow control.
        # Compare against None explicitly so that a threshold of 0 is honoured instead of
        # silently falling back to random sampling.
        if self.threshold is not None:
            self.return_value(int(pred >= self.threshold))
        else:
            self.return_value(int(pred >= np.random.rand()))
A (simplified) python wrapper for StoreArray.
Definition: PyStoreArray.h:72
a (simplified) python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
threshold
Threshold for event selection using reweighting method, value None indicating sampling method.
model_file
Path to the saved model file.
model
model with trained parameters
EventInfo
Initialise event metadata from data store.
model_config
Parameters for building the model.
def __init__(self, model_file=None, model_config=MODEL_CONFIG, preproc_config=PREPROC_CONFIG, threshold=None, extra_info_var="NN_prediction", global_tag="SmartBKG_GATGAP", payload="GATGAPgen.pth")
extra_info_var
Name of eventExtraInfo to save model prediction to.
preproc_config
Parameters for preprocessing.
payload
Payload for the well-trained model in global tag.
EventExtraInfo
StoreObjPtr to save the model prediction to.