12 from collections
import defaultdict
14 from ROOT
import Belle2
15 from ROOT.Belle2
import DBAccessorBase, DBStoreEntry
17 from smartBKG
import TOKENIZE_DICT, PREPROC_CONFIG, MODEL_CONFIG
def check_status_bit(status_bit):
    """
    Returns True if conditions are satisfied (not an unusable particle)

    A particle counts as usable only when bits 4, 5, 6 and 7 of its status
    word are all cleared; any other bits are ignored.

    Args:
        status_bit: Integer status word of a particle (elsewhere in this file
            it is the value returned by ``mcp.getStatus()``).

    Returns:
        True if none of the bits 4-7 is set, False otherwise.
    """
    # NOTE(review): bits 4-7 presumably map to the Belle II MCParticle status
    # flags for virtual, initial, ISR-photon and FSR-photon particles --
    # confirm against the MCParticle status-bit definition.
    unusable_mask = (1 << 4) | (1 << 5) | (1 << 6) | (1 << 7)
    # A single masked comparison is equivalent to testing each bit separately
    # and combining the four results with ``&``.
    return (status_bit & unusable_mask) == 0
35 1. Build a graph from an event composed of MCParticles
36 2. Apply the well-trained model for reweighting or sampling method to get a score
37 3. Execute reweighting or sampling process to get a weight
40 model_file(str): Path to saved model
41 model_config(dict): Parameters to build the model
42 preproc_config(dict): Parameters to provide information for preprocessing
43 threshold(float): Threshold for event selection using reweighting method, value *None* indicating sampling method
44 extra_info_var(str): Name of eventExtraInfo to save model prediction to
45 global_tag(str): Tag in ConditionDB where the well trained model was stored
46 payload(str): Payload for the well trained model in global tag
49 Passed or rejected according to random sampling or selection with the given threshold
52 Score after the NN filter indicating the probability of the event to pass is saved
53 under ``EventExtraInfo.extra_info_var``.
55 Use ``eventExtraInfo(extra_info_var)`` in ``modularAnalysis.variablesToNtuple`` or
56 ``additionalBranches=["EventExtraInfo"]`` in ``mdst.add_mdst_output`` to have access to the scores.
62 model_config=MODEL_CONFIG,
63 preproc_config=PREPROC_CONFIG,
65 extra_info_var="NN_prediction",
66 global_tag="SmartBKG_GATGAP",
67 payload="GATGAPgen.pth"
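71 :param model_file: Path to saved model
72 :param model_config: Parameters to build the model
73 :param preproc_config: Parameters to provide information for preprocessing
74 :param threshold: Threshold for event selection using reweighting method, value *None* indicating sampling method
75 :param extra_info_var: Name of eventExtraInfo to save model prediction to
76 :param global_tag: Tag in ConditionDB where the well trained model was stored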
94 b2.conditions.prepend_globaltag(global_tag)
98 Initialise module before any events are processed
101 from smartBKG.models.gatgap
import GATGAPModel
103 DEVICE = torch.device(
"cpu")
107 accessor = DBAccessorBase(DBStoreEntry.c_RawFile, self.
payloadpayload,
True)
108 self.
model_filemodel_file = accessor.getFilename()
109 trained_parameters = torch.load(self.
model_filemodel_file, map_location=DEVICE)
113 self.
modelmodel.load_state_dict(trained_parameters[
'model_state_dict'])
129 Collect information from database, build graphs, make predictions and select through sampling or threshold
143 for i, mcp
in enumerate(mcplist):
144 if mcp.isPrimaryParticle():
146 if not check_status_bit(mcp.getStatus()):
149 prodTime = mcp.getProductionTime()
152 root_prodTime = prodTime
153 prodTime -= root_prodTime
155 four_vec = mcp.get4Vector()
156 prod_vec = mcp.getProductionVertex()
159 self.
gen_varsgen_vars[
'prodTime'].append(prodTime)
160 self.
gen_varsgen_vars[
'energy'].append(mcp.getEnergy())
161 self.
gen_varsgen_vars[
'x'].append(prod_vec.x())
162 self.
gen_varsgen_vars[
'y'].append(prod_vec.y())
163 self.
gen_varsgen_vars[
'z'].append(prod_vec.z())
164 self.
gen_varsgen_vars[
'px'].append(four_vec.Px())
165 self.
gen_varsgen_vars[
'py'].append(four_vec.Py())
166 self.
gen_varsgen_vars[
'pz'].append(four_vec.Pz())
167 self.
gen_varsgen_vars[
'PDG'].append(
168 TOKENIZE_DICT[int(mcp.getPDG())]
172 df = pd.DataFrame(self.
gen_varsgen_vars).tail(1)
173 df.query(
" and ".join(self.
preproc_configpreproc_config[
"cuts"]), inplace=
True)
175 for values
in self.
gen_varsgen_vars.values():
180 array_indices.append(mcp.getArrayIndex())
181 mother = mcp.getMother()
183 mother_indices.append(mother.getArrayIndex())
185 mother_indices.append(0)
188 array_indices=array_indices, mother_indices=mother_indices,
190 symmetrize=
True, add_self_loops=
True
194 pred = torch.sigmoid(self.
modelmodel(graph)).detach().numpy().squeeze()
201 self.return_value(int(pred >= self.
thresholdthreshold))
203 self.return_value(int(pred >= np.random.rand()))
def mapped_mother_indices(self, array_indices, mother_indices):
    """
    Map the mother indices to an enumerated list. The one-hot encoded version
    of that list then corresponds to the adjacency matrix.

    Each mother index is replaced by the position at which that index occurs
    in ``array_indices``.

    Example:
        >>> mapped_mother_indices(
        ...     [0, 1, 3, 5, 6, 7, 8, 9, 10],
        ...     [0, 0, 0, 1, 1, 1, 5, 5, 7]
        ... )
        [0, 0, 0, 1, 1, 1, 3, 3, 5]

    Args:
        array_indices: list or array of indices. Each index has to be unique.
        mother_indices: list or array of mother indices.

    Returns:
        List of mapped indices

    Raises:
        KeyError: if a mother index does not occur in ``array_indices``.
    """
    # Lookup table: original array index -> position in the enumerated list.
    position_of = {index: position for position, index in enumerate(array_indices)}
    return [position_of[mother] for mother in mother_indices]
227 def build_graph(self, array_indices, mother_indices, PDGs, Features,
228 symmetrize=True, add_self_loops=True):
230 Build graph from preprocessed particle information
234 os.environ[
"DGLBACKEND"] =
"pytorch"
241 dst = np.arange(len(src))
242 src_new, dst_new = src, dst
245 np.concatenate([src, dst]),
246 np.concatenate([dst, src])
249 src_new, dst_new = map(
250 np.array, zip(*[(s, d)
for s, d
in zip(src_new, dst_new)
if not s == d])
254 np.concatenate([src_new, dst]),
255 np.concatenate([dst_new, dst])
257 graph = dgl.graph((src_new, dst_new))
258 graph.ndata[
"x_pdg"] = torch.tensor(PDGs, dtype=torch.int32)
259 graph.ndata[
"x_feature"] = torch.tensor(np.transpose(Features), dtype=torch.float32)
A (simplified) Python wrapper for StoreArray.
A (simplified) Python wrapper for StoreObjPtr.
def build_graph(self, array_indices, mother_indices, PDGs, Features, symmetrize=True, add_self_loops=True)
model
model with trained parameters
out_features
node features
gen_vars
generated variables
def __init__(self, model_file=None, model_config=MODEL_CONFIG, preproc_config=PREPROC_CONFIG, threshold=None, extra_info_var="NN_prediction", global_tag="SmartBKG_GATGAP", payload="GATGAPgen.pth")
def mapped_mother_indices(self, array_indices, mother_indices)
EventExtraInfo
StoreArray to save weights to.