Belle II Software  release-08-01-10
NN_filter_module.py
1 
8 import os
9 import numpy as np
10 import pandas as pd
11 
12 from collections import defaultdict
13 import basf2 as b2
14 from ROOT import Belle2
15 from ROOT.Belle2 import DBAccessorBase, DBStoreEntry
16 
17 from smartBKG import TOKENIZE_DICT, PREPROC_CONFIG, MODEL_CONFIG
18 
19 
def check_status_bit(status_bit):
    """
    Tell whether an MCParticle status word describes a usable particle.

    The particle is considered unusable if any of these status bits is set:
    bit 4 (IsVirtual), bit 5 (Initial), bit 6 (ISRPhoton), bit 7 (FSRPhoton).

    :param status_bit: integer status word, e.g. from ``MCParticle.getStatus()``
    :return: True if none of the rejecting bits is set, False otherwise
    """
    rejecting_bits = (
        (1 << 4)    # IsVirtual
        | (1 << 5)  # Initial
        | (1 << 6)  # ISRPhoton
        | (1 << 7)  # FSRPhoton
    )
    return (status_bit & rejecting_bits) == 0
30 
31 
class NNFilterModule(b2.Module):
    """
    Goals:
       1. Build a graph from an event composed of MCParticles
       2. Apply the well-trained model for reweighting or sampling method to get a score
       3. Execute reweighting or sampling process to get a weight

    Arguments:
        model_file(str): Path to saved model
        model_config(dict): Parameters to build the model
        preproc_config(dict): Parameters to provide information for preprocessing
        threshold(float): Threshold for event selection using reweighting method, value *None* indicating sampling method
        extra_info_var(str): Name of eventExtraInfo to save model prediction to
        global_tag(str): Tag in ConditionDB where the well trained model was stored
        payload(str): Payload for the well trained model in global tag

    Returns:
        Pass or rejected according to random sampling or selection with the given threshold

    Note:
        Score after the NN filter indicating the probability of the event to pass is saved
        under ``EventExtraInfo.extra_info_var``.

        Use ``eventExtraInfo(extra_info_var)`` in ``modularAnalysis.variablesToNtuple`` or
        ``additionalBranches=["EventExtraInfo"]`` in ``mdst.add_mdst_output`` to have access to the scores.
    """

    def __init__(
        self,
        model_file=None,
        model_config=MODEL_CONFIG,
        preproc_config=PREPROC_CONFIG,
        threshold=None,
        extra_info_var="NN_prediction",
        global_tag="SmartBKG_GATGAP",
        payload="GATGAPgen.pth"
    ):
        """
        Initialise the class.

        :param model_file: path to the trained model parameters; if empty, the file is
            taken from the conditions database payload ``payload`` in ``initialize``
        :param model_config: keyword arguments passed to ``GATGAPModel``
        :param preproc_config: preprocessing settings; ``'features'`` lists the node
            features and ``'cuts'`` lists ``DataFrame.query`` expressions that are
            combined with ``and`` for the particle-level selection
        :param threshold: selection threshold on the predicted pass probability;
            ``None`` selects events by random sampling instead
        :param extra_info_var: name of the EventExtraInfo entry the prediction is saved to
        :param global_tag: global tag prepended to the conditions database
        :param payload: name of the raw-file payload holding the trained model
        """
        super().__init__()

        #: path to the trained model parameters (resolved from the DB if empty)
        self.model_file = model_file

        #: keyword arguments for building the model
        self.model_config = model_config

        #: preprocessing configuration (features and particle-level cuts)
        self.preproc_config = preproc_config

        #: selection threshold on the prediction; None means random sampling
        self.threshold = threshold

        #: EventExtraInfo key under which the prediction is stored
        self.extra_info_var = extra_info_var

        #: conditions DB payload name of the trained model
        self.payload = payload

        # set additional database conditions for trained neural network
        b2.conditions.prepend_globaltag(global_tag)

    def initialize(self):
        """
        Initialise module before any events are processed
        """
        import torch
        from smartBKG.models.gatgap import GATGAPModel

        device = torch.device("cpu")

        # read trained model parameters from the conditions database unless a file was given
        if not self.model_file:
            accessor = DBAccessorBase(DBStoreEntry.c_RawFile, self.payload, True)
            self.model_file = accessor.getFilename()
        # NOTE(review): torch.load unpickles the file; only load trusted payloads.
        trained_parameters = torch.load(self.model_file, map_location=device)

        #: model with trained parameters
        self.model = GATGAPModel(**self.model_config)
        self.model.load_state_dict(trained_parameters['model_state_dict'])
        # Put the model in inference mode (affects dropout / batch-norm layers, if any);
        # nn.Module instances start in training mode.
        self.model.eval()

        #: StoreObjPtr for EventExtraInfo, where the prediction is saved
        self.EventExtraInfo = Belle2.PyStoreObj('EventExtraInfo')
        if not self.EventExtraInfo.isValid():
            self.EventExtraInfo.registerInDataStore()

        #: node features of the selected particles of the current event, keyed by variable name
        self.gen_vars = defaultdict(list)

        #: float node features fed to the model; PDG is excluded because it is passed
        #: separately as the token input ``x_pdg``. A new list is built so that the
        #: (possibly shared, module-level default) config list is never mutated.
        self.out_features = [
            feature for feature in self.preproc_config['features'] if feature != 'PDG'
        ]

        #: particle-level selection expression for DataFrame.query (None: no cut)
        cuts = self.preproc_config.get("cuts")
        self.cut_query = " and ".join(cuts) if cuts else None

    def event(self):
        """
        Collect information from database, build graphs, make predictions and select through sampling or threshold
        """
        import torch
        # Initialize for every event
        self.gen_vars.clear()

        # Need to create the eventExtraInfo entry for each event
        self.EventExtraInfo.create()

        array_indices, mother_indices = self._collect_particles(
            Belle2.PyStoreArray("MCParticles")
        )

        graph = self.build_graph(
            array_indices=array_indices, mother_indices=mother_indices,
            PDGs=self.gen_vars['PDG'], Features=[self.gen_vars[key] for key in self.out_features],
            symmetrize=True, add_self_loops=True
        )

        # Output pass probability; no gradients are needed for inference
        with torch.no_grad():
            pred = float(torch.sigmoid(self.model(graph)).numpy().squeeze())

        # Save the pass probability to EventExtraInfo
        self.EventExtraInfo.addExtraInfo(self.extra_info_var, pred)

        # Module returns whether the prediction passes, for use in basf2 path flow control.
        # Compare against None explicitly so that a threshold of 0.0 is honoured.
        if self.threshold is not None:
            self.return_value(int(pred >= self.threshold))
        else:
            # NOTE(review): numpy's global RNG is presumably not controlled by the basf2
            # random seed; verify if reproducible sampling is required.
            self.return_value(int(pred >= np.random.rand()))

    def _collect_particles(self, mcplist):
        """
        Fill ``self.gen_vars`` with the node features of the selected particles.

        A particle is selected if it is a primary particle, passes ``check_status_bit``
        and passes the particle-level cuts from ``preproc_config``.

        :param mcplist: the MCParticles store array
        :return: tuple ``(array_indices, mother_indices)`` of the selected particles
        """
        array_indices = []
        mother_indices = []
        root_prod_time = None

        for i, mcp in enumerate(mcplist):
            # Record the production time of the root particle (index 0) for the correction
            # of jitter. This is done before any selection so it is always defined.
            if i == 0:
                root_prod_time = mcp.getProductionTime()

            if not mcp.isPrimaryParticle():
                continue
            # Check mc particle is usable
            if not check_status_bit(mcp.getStatus()):
                continue

            four_vec = mcp.get4Vector()
            prod_vec = mcp.getProductionVertex()

            # build generated variables as node features
            self.gen_vars['prodTime'].append(mcp.getProductionTime() - root_prod_time)
            self.gen_vars['energy'].append(mcp.getEnergy())
            self.gen_vars['x'].append(prod_vec.x())
            self.gen_vars['y'].append(prod_vec.y())
            self.gen_vars['z'].append(prod_vec.z())
            self.gen_vars['px'].append(four_vec.Px())
            self.gen_vars['py'].append(four_vec.Py())
            self.gen_vars['pz'].append(four_vec.Pz())
            self.gen_vars['PDG'].append(
                TOKENIZE_DICT[int(mcp.getPDG())]
            )

            # Particle level cutting: drop the just-appended features if the particle fails
            if not self._passes_cuts():
                for values in self.gen_vars.values():
                    values.pop()
                continue

            # Collect indices for graph; particles without mother point to array index 0
            array_indices.append(mcp.getArrayIndex())
            mother = mcp.getMother()
            mother_indices.append(mother.getArrayIndex() if mother else 0)

        return array_indices, mother_indices

    def _passes_cuts(self):
        """
        Return True if the most recently appended particle passes the particle-level cuts.
        """
        if self.cut_query is None:
            return True
        # Evaluate the cut on a single-row frame holding only the last particle instead of
        # rebuilding a frame from every particle collected so far.
        last_row = pd.DataFrame({key: [values[-1]] for key, values in self.gen_vars.items()})
        return not last_row.query(self.cut_query).empty

    def mapped_mother_indices(self, array_indices, mother_indices):
        """
        Map the mother indices to an enumerated list. The one-hot encoded version
        of that list then corresponds to the adjacency matrix.

        Example (``self`` is not used):
            >>> mapped_mother_indices(
            ...     [0, 1, 3, 5, 6, 7, 8, 9, 10],
            ...     [0, 0, 0, 1, 1, 1, 5, 5, 7]
            ... )
            [0, 0, 0, 1, 1, 1, 3, 3, 5]

        Args:
            array_indices: list or array of indices. Each index has to be unique.
            mother_indices: list or array of mother indices.

        Returns:
            List of mapped indices

        Raises:
            KeyError: if a mother index is not contained in ``array_indices``
        """
        position_of = {index: position for position, index in enumerate(array_indices)}
        return [position_of[mother] for mother in mother_indices]

    def build_graph(self, array_indices, mother_indices, PDGs, Features,
                    symmetrize=True, add_self_loops=True):
        """
        Build graph from preprocessed particle information

        Every node gets an edge from its mother; optionally also the reverse edge and
        exactly one self loop.

        :param array_indices: MCParticle array indices of the nodes (unique)
        :param mother_indices: array index of each node's mother
        :param PDGs: tokenized PDG code per node, stored as ``ndata["x_pdg"]``
        :param Features: one value list per feature, transposed into ``ndata["x_feature"]``
        :param symmetrize: also add daughter -> mother edges
        :param add_self_loops: add one self loop per node
        :return: the DGL graph
        """
        # DGL picks its backend when it is first imported, so the environment variable
        # must be set before the import to take effect.
        os.environ["DGLBACKEND"] = "pytorch"
        import torch
        import dgl

        # Build adjacency mapping
        adjacency = self.mapped_mother_indices(array_indices, mother_indices)

        # Build graph
        src = np.asarray(adjacency, dtype=np.int64)
        dst = np.arange(len(src))
        src_new, dst_new = src, dst
        if symmetrize:
            src_new, dst_new = (
                np.concatenate([src, dst]),
                np.concatenate([dst, src])
            )
        # Remove self-loops (the Y(4S)) to avoid duplicated self loops. A boolean mask
        # also works when no edge is left (e.g. a graph with a single node).
        keep = src_new != dst_new
        src_new, dst_new = src_new[keep], dst_new[keep]
        if add_self_loops:
            src_new, dst_new = (
                np.concatenate([src_new, dst]),
                np.concatenate([dst_new, dst])
            )
        graph = dgl.graph((src_new, dst_new))
        graph.ndata["x_pdg"] = torch.tensor(PDGs, dtype=torch.int32)
        graph.ndata["x_feature"] = torch.tensor(np.transpose(Features), dtype=torch.float32)

        return graph
A (simplified) python wrapper for StoreArray.
Definition: PyStoreArray.h:72
A (simplified) Python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
def build_graph(self, array_indices, mother_indices, PDGs, Features, symmetrize=True, add_self_loops=True)
model
model with trained parameters
def __init__(self, model_file=None, model_config=MODEL_CONFIG, preproc_config=PREPROC_CONFIG, threshold=None, extra_info_var="NN_prediction", global_tag="SmartBKG_GATGAP", payload="GATGAPgen.pth")
def mapped_mother_indices(self, array_indices, mother_indices)
EventExtraInfo
StoreObjPtr (EventExtraInfo) to save the NN prediction to.