Belle II Software development
NNFilterModule Class Reference

Public Member Functions

def __init__ (self, model_file=None, model_config=MODEL_CONFIG, preproc_config=PREPROC_CONFIG, threshold=None, extra_info_var="NN_prediction", global_tag="SmartBKG_GATGAP", payload="GATGAPgen.pth")
 
def initialize (self)
 
def event (self)
 

Public Attributes

 model_file
 Path to the saved model file.
 
 model_config
 Parameters for building the model.
 
 preproc_config
 Parameters for preprocessing.
 
 threshold
 Threshold for event selection using reweighting method, value None indicating sampling method.
 
 extra_info_var
 Name of eventExtraInfo to save model prediction to.
 
 payload
 Payload for the well-trained model in global tag.
 
 model
 Model with trained parameters.
 
 EventExtraInfo
 StoreObjPtr (EventExtraInfo) to save the model prediction to.
 
 EventInfo
 Event metadata (EventMetaData) read from the data store.
 
 out_features
 Node features (the preprocessing features, excluding PDG).
 

Detailed Description

Goals:
   1. Build a graph from an event composed of MCParticles
   2. Apply the well-trained model for reweighting or sampling method to get a score
   3. Execute reweighting or sampling process to get a weight
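Steps 2 and 3 can be sketched in plain NumPy. This is a minimal illustration and not the module's code. `select_event` is a hypothetical helper. The sketch assumes the raw network output is a single logit that is turned into a pass probability with a sigmoid, as `event()` does with `torch.sigmoid`. With a threshold the decision is deterministic. Without one, the event is kept with a probability equal to its score.

```python
import numpy as np


def select_event(logit, threshold=None, rng=None):
    """Hypothetical helper mirroring the filter's pass/reject decision.

    Returns (probability, passed), where passed is 1 or 0.
    """
    # Turn the raw network output (a logit) into a pass probability
    prob = 1.0 / (1.0 + np.exp(-logit))
    if threshold is not None:
        # Selection with a fixed threshold
        passed = int(prob >= threshold)
    else:
        # Sampling: keep the event with probability `prob`
        rng = rng if rng is not None else np.random.default_rng()
        passed = int(prob >= rng.random())
    return float(prob), passed


rng = np.random.default_rng(0)
print(select_event(2.0, threshold=0.5))  # high score, passes the threshold
kept = sum(select_event(0.0, rng=rng)[1] for _ in range(10_000))
print(kept / 10_000)  # close to 0.5, the sigmoid of 0
```

Because `rng.random()` is uniform on [0, 1), the check `prob >= u` succeeds with probability `prob`. In sampling mode, the fraction of kept events therefore tracks the mean score.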

Arguments:
   model_file(str): Path to the saved model
   model_config(dict): Parameters to build the model
   preproc_config(dict): Parameters for preprocessing
   threshold(float): Threshold for event selection using reweighting method, value *None* indicating sampling method
   extra_info_var(str): Name of eventExtraInfo to save model prediction to
   global_tag(str): Tag in ConditionDB where the well trained model was stored
   payload(str): Payload for the well trained model in global tag

Returns:
   Whether the event passes or is rejected, decided by random sampling or by comparison with the given threshold; the result is set as the module return value

Note:
    The NN filter score, i.e. the probability that the event passes, is saved in EventExtraInfo
    under the name given by ``extra_info_var``.

    Use ``eventExtraInfo(extra_info_var)`` in ``modularAnalysis.variablesToNtuple`` or
    ``additionalBranches=["EventExtraInfo"]`` in ``mdst.add_mdst_output`` to have access to the scores.

Definition at line 23 of file NN_filter_module.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  model_file = None,
  model_config = MODEL_CONFIG,
  preproc_config = PREPROC_CONFIG,
  threshold = None,
  extra_info_var = "NN_prediction",
  global_tag = "SmartBKG_GATGAP",
  payload = "GATGAPgen.pth" 
)
Initialise the class.
:param model_file: Path to the saved model file.
:param model_config: Parameters for building the model.
:param preproc_config: Parameters for preprocessing.
:param threshold: Threshold for event selection using reweighting method, value *None* indicating sampling method.
:param extra_info_var: Name of eventExtraInfo to save model prediction to.
:param global_tag: Tag in ConditionDB where the well-trained model was stored.
:param payload: Payload for the well-trained model in global tag.

Definition at line 50 of file NN_filter_module.py.

59 ):
60     """
61     Initialise the class.
62     :param model_file: Path to the saved model file.
63     :param model_config: Parameters for building the model.
64     :param preproc_config: Parameters for preprocessing.
65     :param threshold: Threshold for event selection using reweighting method, value *None* indicating sampling method.
66     :param extra_info_var: Name of eventExtraInfo to save model prediction to.
67     :param global_tag: Tag in ConditionDB where the well-trained model was stored.
68     :param payload: Payload for the well-trained model in global tag.
69     """
70     super().__init__()
71
72     self.model_file = model_file
73
74     self.model_config = model_config
75
76     self.preproc_config = preproc_config
77
78     self.threshold = threshold
79
80     self.extra_info_var = extra_info_var
81
82     self.payload = payload
83
84     # set additional database conditions for trained neural network
85     b2.conditions.prepend_globaltag(global_tag)
86

Member Function Documentation

◆ event()

def event (   self)
Collect information from database, build graphs, make predictions and select through sampling or threshold

Definition at line 112 of file NN_filter_module.py.

112 def event(self):
113     """
114     Collect information from database, build graphs, make predictions and select through sampling or threshold
115     """
116     # Need to create the eventExtraInfo entry for each event
117     self.EventExtraInfo.create()
118     df_dict = load_particle_list(mcplist=Belle2.PyStoreArray("MCParticles"), evtNum=self.EventInfo.getEvent(), label=True)
119     single_input = preprocessed(df_dict, particle_selection=self.preproc_config['cuts'])
120     graph = ArrayDataset(single_input, batch_size=1)[0][0]
121     # Output pass probability
122     pred = torch.sigmoid(self.model(graph)).detach().numpy().squeeze()
123
124     # Save the pass probability to EventExtraInfo
125     self.EventExtraInfo.addExtraInfo(self.extra_info_var, pred)
126
127     # Module returns bool of whether prediction passes threshold for use in basf2 path flow control
128     if self.threshold is not None:
129         self.return_value(int(pred >= self.threshold))
130     else:
131         self.return_value(int(pred >= np.random.rand()))
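The way the threshold branch is tested matters when the threshold is `0.0`. A truthiness check (`if threshold:`) treats `0.0` like `None` and silently switches to the sampling method. An explicit `is not None` check does not. The sketch below uses hypothetical function names to contrast the two checks on a low-scoring event.

```python
import random


def decide_truthy(pred, threshold):
    # Truthiness check: 0.0 is treated the same as "no threshold"
    if threshold:
        return int(pred >= threshold)
    return int(pred >= random.random())


def decide_explicit(pred, threshold):
    # Only None selects the sampling method
    if threshold is not None:
        return int(pred >= threshold)
    return int(pred >= random.random())


random.seed(0)
# With threshold 0.0 every event should pass, but the truthy version samples
print(sum(decide_truthy(0.1, 0.0) for _ in range(1000)))    # roughly 100
print(sum(decide_explicit(0.1, 0.0) for _ in range(1000)))  # 1000
```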

◆ initialize()

def initialize (   self)
Initialise module before any events are processed

Definition at line 87 of file NN_filter_module.py.

87 def initialize(self):
88     """
89     Initialise module before any events are processed
90     """
91     # read trained model parameters from database
92     if not self.model_file:
93         accessor = DBAccessorBase(DBStoreEntry.c_RawFile, self.payload, True)
94         self.model_file = accessor.getFilename()
95     trained_parameters = torch.load(self.model_file, map_location=DEVICE)
96
97
98     self.model = GATGAPModel(**self.model_config)
99     self.model.load_state_dict(trained_parameters['model_state_dict'])
100
101
102     self.EventExtraInfo = Belle2.PyStoreObj('EventExtraInfo')
103     if not self.EventExtraInfo.isValid():
104         self.EventExtraInfo.registerInDataStore()
105
106     self.EventInfo = Belle2.PyStoreObj('EventMetaData')
107
108     self.out_features = list(self.preproc_config['features'])
109     if 'PDG' in self.preproc_config['features']:
110         self.out_features.remove('PDG')
111
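Assigning `self.preproc_config['features']` directly to `self.out_features`, rather than a copy, makes both names refer to the same list object. A later `remove('PDG')` then also drops `'PDG'` from `preproc_config`. If the module-level default was passed in, the shared default is modified as well. The sketch below uses a stand-in configuration: `DEFAULT_PREPROC` and the feature names other than `'PDG'` are placeholders, not the real `PREPROC_CONFIG`.

```python
import copy

# Stand-in for a shared default config; feature names are placeholders
DEFAULT_PREPROC = {"features": ["PDG", "feature_a", "feature_b"]}


def out_features_aliased(preproc_config):
    features = preproc_config["features"]  # same list object as the config
    if "PDG" in features:
        features.remove("PDG")  # also mutates the config
    return features


def out_features_copied(preproc_config):
    features = list(preproc_config["features"])  # shallow copy
    if "PDG" in features:
        features.remove("PDG")  # config is left untouched
    return features


cfg = copy.deepcopy(DEFAULT_PREPROC)
print(out_features_copied(cfg), cfg["features"])
# ['feature_a', 'feature_b'] ['PDG', 'feature_a', 'feature_b']
out_features_aliased(cfg)
print(cfg["features"])
# ['feature_a', 'feature_b']
```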

Member Data Documentation

◆ EventExtraInfo

EventExtraInfo

StoreObjPtr (EventExtraInfo) to save the model prediction to.

Definition at line 102 of file NN_filter_module.py.

◆ EventInfo

EventInfo

Event metadata (EventMetaData) read from the data store.

Definition at line 106 of file NN_filter_module.py.

◆ extra_info_var

extra_info_var

Name of eventExtraInfo to save model prediction to.

Definition at line 80 of file NN_filter_module.py.

◆ model

model

Model with trained parameters.

Definition at line 98 of file NN_filter_module.py.

◆ model_config

model_config

Parameters for building the model.

Definition at line 74 of file NN_filter_module.py.

◆ model_file

model_file

Path to the saved model file.

Definition at line 72 of file NN_filter_module.py.

◆ out_features

out_features

Node features (the preprocessing features, excluding PDG).

Definition at line 108 of file NN_filter_module.py.

◆ payload

payload

Payload for the well-trained model in global tag.

Definition at line 82 of file NN_filter_module.py.

◆ preproc_config

preproc_config

Parameters for preprocessing.

Definition at line 76 of file NN_filter_module.py.

◆ threshold

threshold

Threshold for event selection using reweighting method, value None indicating sampling method.

Definition at line 78 of file NN_filter_module.py.


The documentation for this class was generated from the following file: