import basf2 as b2
import numpy as np
import torch
from ROOT import Belle2
from ROOT.Belle2 import DBAccessorBase, DBStoreEntry

from smartBKG import PREPROC_CONFIG, MODEL_CONFIG
from smartBKG.utils.preprocess import load_particle_list, preprocessed
from smartBKG.models.gatgap import GATGAPModel
from smartBKG.utils.dataset import ArrayDataset

DEVICE = torch.device("cpu")
1. Build a graph from an event composed of MCParticles
2. Apply the well-trained model for the reweighting or sampling method to get a score
3. Execute the reweighting or sampling process to get a weight
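Step 1 turns the MCParticle decay tree into a graph. The real smartBKG preprocessing (`load_particle_list`, `preprocessed`) is not shown in this section, so the following is only a toy sketch. The particle record format, the helper `build_graph`, and the choice to store every mother-daughter edge in both directions are hypothetical assumptions made for illustration. The sketch shows the general idea: particles become nodes, and mother-daughter links become edges.

```python
# Hypothetical toy record: each particle has a PDG code and the index of its
# mother in the same list (-1 for primary particles). This is NOT the smartBKG
# input format, only an illustration of turning a decay tree into a graph.
particles = [
    {"pdg": 300553, "mother": -1},  # Upsilon(4S)
    {"pdg": 511, "mother": 0},      # B0
    {"pdg": -511, "mother": 0},     # anti-B0
    {"pdg": 211, "mother": 1},      # pi+ from the B0
]


def build_graph(particles):
    """Return node features (here just PDG codes) and an edge list."""
    nodes = [p["pdg"] for p in particles]
    edges = []
    for idx, p in enumerate(particles):
        mother = p["mother"]
        if mother >= 0:
            # Assumption: store both directions so that information can
            # flow from mother to daughter and back during message passing.
            edges.append((mother, idx))
            edges.append((idx, mother))
    return nodes, edges


nodes, edges = build_graph(particles)
print(nodes)
print(edges)
```

A real implementation would attach richer node features (momenta, vertex positions, and so on) and would likely use tensors rather than Python lists. The edge-list layout is kept here only because it is easy to inspect.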
threshold (float): Threshold for event selection using the reweighting method; a value of *None* selects the sampling method
extra_info_var (str): Name of the eventExtraInfo to save the model prediction to
global_tag (str): Tag in the ConditionDB where the well-trained model is stored
payload (str): Payload for the well-trained model in the global tag
Passed or rejected according to random sampling or to selection with the given threshold
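The pass/reject rule described here can be sketched in plain Python. This sketch mirrors the two `self.return_value(...)` calls in the event method. The function name `select_event` and the injectable `rand` argument are hypothetical, and Python's `random.random` stands in for `np.random.rand`.

```python
import random


def select_event(pred, threshold=None, rand=random.random):
    """Return 1 (keep) or 0 (reject) for a model score ``pred`` in [0, 1].

    threshold is None: sampling, keep the event with probability ``pred``.
    otherwise:         threshold selection, keep the event if pred >= threshold.
    ``rand`` can be replaced by a fixed function to make tests deterministic.
    """
    if threshold is None:
        # A uniform draw on [0, 1) is <= pred with probability pred.
        return int(pred >= rand())
    return int(pred >= threshold)


print(select_event(0.7, threshold=0.5))  # threshold mode
print(select_event(0.7))                 # sampling mode, random outcome
```

Passing the random source as an argument is only a testing convenience. With the default, repeated calls on the same score can return different answers, which is the intended behaviour of the sampling mode.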
Score after the NN filter, indicating the probability of the event to pass, is saved under ``EventExtraInfo.extra_info_var``.
Use ``eventExtraInfo(extra_info_var)`` in ``modularAnalysis.variablesToNtuple`` or ``additionalBranches=["EventExtraInfo"]`` in ``mdst.add_mdst_output`` to have access to the scores.
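Once the scores are stored (for example, as an ntuple column), the fraction of events each mode would keep can be estimated directly from them. In sampling mode, every event survives with probability equal to its score, so the expected retention is the mean score. In threshold mode, the retention is the fraction of scores at or above the threshold. The sketch below is minimal and assumes the scores have already been read into a Python list; reading the ROOT file is not shown, and `expected_pass_fraction` is a hypothetical helper.

```python
from statistics import mean


def expected_pass_fraction(scores, threshold=None):
    """Expected fraction of events kept, given saved NN scores.

    threshold is None: sampling mode, so the expectation is the mean score.
    otherwise:         threshold mode, fraction of scores >= threshold.
    """
    scores = list(scores)
    if not scores:
        raise ValueError("need at least one score")
    if threshold is None:
        return mean(scores)
    # Summing booleans counts how many scores pass the threshold.
    return sum(s >= threshold for s in scores) / len(scores)


scores = [0.2, 0.4, 0.6, 0.8]
print(expected_pass_fraction(scores))
print(expected_pass_fraction(scores, threshold=0.5))
```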
    model_config=MODEL_CONFIG,
    preproc_config=PREPROC_CONFIG,
    extra_info_var="NN_prediction",
    global_tag="SmartBKG_GATGAP",
    payload="GATGAPgen.pth"
:param model_file: Path to the saved model file.
:param model_config: Parameters for building the model.
:param preproc_config: Parameters for preprocessing.
:param threshold: Threshold for event selection using the reweighting method; a value of *None* selects the sampling method.
:param extra_info_var: Name of the eventExtraInfo to save the model prediction to.
:param global_tag: Tag in the ConditionDB where the well-trained model is stored.
:param payload: Payload for the well-trained model in the global tag.
b2.conditions.prepend_globaltag(global_tag)
Initialise module before any events are processed
accessor = DBAccessorBase(DBStoreEntry.c_RawFile, self.payload, True)
trained_parameters = torch.load(self.model_file, map_location=DEVICE)
self.model.load_state_dict(trained_parameters['model_state_dict'])
Collect information from database, build graphs, make predictions and select through sampling or threshold
single_input = preprocessed(df_dict, particle_selection=self.preproc_config['cuts'])
graph = ArrayDataset(single_input, batch_size=1)[0][0]
pred = torch.sigmoid(self.model(graph)).detach().numpy().squeeze()
self.return_value(int(pred >= self.threshold))
self.return_value(int(pred >= np.random.rand()))
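Why does `pred >= np.random.rand()` implement sampling? A uniform draw on [0, 1) is at or below `pred` with probability `pred`. As a result, an event is kept with probability equal to its score after the sigmoid. The short Monte Carlo check below illustrates this using only the standard library. The helpers `sigmoid` and `sampled_pass_rate` are hypothetical, and the logit argument stands in for the raw model output before `torch.sigmoid` is applied.

```python
import math
import random


def sigmoid(logit):
    # Same mapping torch.sigmoid applies element-wise: 1 / (1 + exp(-x)).
    return 1.0 / (1.0 + math.exp(-logit))


def sampled_pass_rate(logit, n_events=100_000, seed=0):
    """Return (pred, observed fraction kept by the rule pred >= U[0, 1))."""
    rng = random.Random(seed)
    pred = sigmoid(logit)
    kept = sum(int(pred >= rng.random()) for _ in range(n_events))
    return pred, kept / n_events


pred, rate = sampled_pass_rate(-0.85)
print(f"score {pred:.4f}, observed pass rate {rate:.4f}")
```

With 100 000 trials, the statistical spread of the observed rate is roughly `sqrt(pred * (1 - pred) / n_events)`, well below 0.01. The printed rate should therefore sit close to the score.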
A (simplified) Python wrapper for StoreArray.
A (simplified) Python wrapper for StoreObjPtr.
threshold
Threshold for event selection using reweighting method, value None indicating sampling method.
model_file
Path to the saved model file.
model
Model with trained parameters.
EventInfo
Initialise event metadata from data store.
out_features
Node features.
model_config
Parameters for building the model.
def __init__(self, model_file=None, model_config=MODEL_CONFIG, preproc_config=PREPROC_CONFIG, threshold=None, extra_info_var="NN_prediction", global_tag="SmartBKG_GATGAP", payload="GATGAPgen.pth")
extra_info_var
Name of eventExtraInfo to save model prediction to.
preproc_config
Parameters for preprocessing.
payload
Payload for the well-trained model in global tag.
EventExtraInfo
StoreArray to save weights to.