Belle II Software  light-2403-persian
GraFEIModule.py
1 
8 
9 
10 import itertools
11 import numpy as np
12 import yaml
13 import warnings
14 import basf2 as b2
15 from ROOT import Belle2
16 from variables import variables as vm
17 import torch
18 from torch_geometric.data import Batch
19 from grafei.modules.LCASaverModule import get_object_list, write_hist
20 from grafei.model.geometric_network import GraFEIModel
21 from grafei.model.normalize_features import normalize_features
22 from grafei.model.edge_features import compute_edge_features
23 from grafei.model.lca_to_adjacency import lca_to_adjacency, InvalidLCAMatrix, select_good_decay
24 from grafei.model.tree_utils import masses_to_classes
25 
# Silence NumPy's "Mean of empty slice" RuntimeWarning (emitted when averaging an empty
# array), which would otherwise flood the log during event processing.
warnings.filterwarnings(
    "ignore",
    message="Mean of empty slice.*",
    category=RuntimeWarning,
)
29 
30 
class GraFEIModule(b2.Module):
    """
    Applies graFEI model to a particle list in basf2.
    GraFEI information is stored as extraInfos.

    Args:
        particle_list (str): Name of particle list.
        cfg_path (str): Path to config file. If `None` the config file in the global tag is used.
        param_file (str): Path to parameter file containing the model. If `None` the parameter file in the global tag is used.
        sig_side_lcas (list): List containing LCAS matrix of signal-side.
        sig_side_masses (list): List containing mass hypotheses of signal-side.
        gpu (bool): Whether to run on a GPU.
        payload_config_name (str): Name of config file payload. The default should be kept, except in basf2 examples.
        payload_model_name (str): Name of model file payload. The default should be kept, except in basf2 examples.
    """

    def __init__(
        self,
        particle_list,
        cfg_path=None,
        param_file=None,
        sig_side_lcas=None,
        sig_side_masses=None,
        gpu=False,
        payload_config_name="graFEIConfigFile",
        payload_model_name="graFEIModelFile",
    ):
        """
        Initialization.

        Only stores the arguments; the config and the model are loaded in ``initialize``.
        """
        super().__init__()
        #: Input particle list.
        self.particle_list = particle_list
        #: Config yaml file path.
        self.cfg_path = cfg_path
        #: PyTorch parameter file path.
        self.param_file = param_file
        #: Chosen sig-side LCAS.
        # Explicit None/length check instead of plain truthiness so that NumPy arrays
        # (whose truth value is ambiguous) are accepted; an empty sequence still gives None.
        self.sig_side_lcas = (
            torch.tensor(sig_side_lcas)
            if sig_side_lcas is not None and len(sig_side_lcas) > 0
            else None
        )
        #: Chosen sig-side mass hypotheses.
        self.sig_side_masses = sig_side_masses
        #: If running on GPU.
        self.gpu = gpu
        #: Config file name in the payload.
        self.payload_config_name = payload_config_name
        #: Model file name in the payload.
        self.payload_model_name = payload_model_name

    def initialize(self):
        """
        Called at the beginning.

        Resolves the config and parameter files (from the DB payloads when the user did not
        provide them), reads the YAML configuration, builds the GraFEI model and loads its
        trained parameters onto the selected device.
        """
        # Get weights and configs from the DB if they are not provided from the user
        if not self.cfg_path:
            config = Belle2.DBAccessorBase(
                Belle2.DBStoreEntry.c_RawFile, self.payload_config_name, True
            )
            self.cfg_path = config.getFilename()
        if not self.param_file:
            model = Belle2.DBAccessorBase(
                Belle2.DBStoreEntry.c_RawFile, self.payload_model_name, True
            )
            self.param_file = model.getFilename()

        #: Figure out if we're running on data or MC.
        self.storeTrueInfo = Belle2.Environment.Instance().isMC()

        #: Figure out which device all this is running on - CPU or GPU.
        self.device = torch.device(
            "cuda" if (self.gpu and torch.cuda.is_available()) else "cpu"
        )

        # Load configs; the context manager closes the file handle once it is parsed
        with open(self.cfg_path, "r") as cfg_file:
            #: Configuration dictionary read from the YAML config file.
            self.configs = yaml.safe_load(cfg_file)

        #: Top MC particle.
        self.mc_particle = None
        #: Max LCAS level.
        self.max_level = None
        # B or Ups reco? 0 = Ups, 1 = B0, 2 = B+
        b_reco = self.configs["model"]["B_reco"]
        if b_reco == 0:
            self.mc_particle = "Upsilon(4S):MC"
            self.max_level = 6
        elif b_reco == 1:
            self.mc_particle = "B0:MC"
            self.max_level = 5
        elif b_reco == 2:
            self.mc_particle = "B+:MC"
            self.max_level = 5
        else:
            b2.B2FATAL("The B_reco setting in the config file is incorrect.")

        dataset_cfg = self.configs["dataset"]["config"]

        #: Normalize features.
        self.normalize = dataset_cfg["normalize"]

        #: Mixed precision.
        # NOTE: not consulted by event() at the moment; the model always runs in full precision.
        self.use_amp = self.configs["train"]["mixed_precision"] and self.device == torch.device("cuda")

        # Naming convention: node, edge and global features get a feat_/edge_/glob_ prefix
        #: Node features.
        self.node_features = [f"feat_{name}" for name in dataset_cfg["features"]] if dataset_cfg["features"] else []
        #: Edge features.
        self.edge_features = (
            [f"edge_{name}" for name in dataset_cfg["edge_features"]] if dataset_cfg["edge_features"] else []
        )
        #: Global features.
        self.glob_features = (
            [f"glob_{name}" for name in dataset_cfg["global_features"]] if dataset_cfg["global_features"] else []
        )

        #: Discarded node features.
        # Evaluated per particle only as extra inputs for the edge-feature computation.
        self.discarded_features = ["feat_x", "feat_y", "feat_z", "feat_px", "feat_py", "feat_p"]

        #: The model
        # The correct edge_classes is taken from the config file.
        self.model = GraFEIModel(
            nfeat_in_dim=len(self.node_features),
            efeat_in_dim=len(self.edge_features),
            gfeat_in_dim=len(self.glob_features),
            **self.configs["model"],
        )

        # Load parameters' values
        self.model.load_state_dict(
            torch.load(self.param_file, map_location=self.device)["model"]
        )

        # Activate evaluation mode
        self.model.eval()
        # Push model to GPU in case
        self.model.to(self.device)

        b2.B2DEBUG(10, "Model structure:\n", self.model)

    def event(self):
        """
        Called at the beginning of each event.

        Runs graFEI on every candidate of the input particle list.
        """
        b2.B2DEBUG(10, "---- Processing new event ----")

        # Get the B candidate list
        candidate_list = get_object_list(Belle2.PyStoreObj(self.particle_list).obj())

        for candidate in candidate_list:
            self._process_candidate(candidate)

    def _process_candidate(self, candidate):
        """
        Run graFEI on one candidate and store the results as extraInfo.

        The candidate receives the ``graFEI_*`` counters/probabilities (and, on MC, the
        ``graFEI_truth_*`` variables); each final-state particle receives ``graFEI_sigSide``.
        Candidates with fewer than two FSPs only get the pre-fit counters.
        """
        # Get FSPs
        p_list = get_object_list(candidate.getFinalStateDaughters())
        n_nodes = len(p_list)

        # Particle nature: absolute PDG code of each FSP
        masses = np.array([abs(p.getPDGCode()) for p in p_list])

        # Number of charged and photons (from the reconstructed PDG hypotheses)
        graFEI_nFSP = n_nodes
        graFEI_nPhotons_preFit = int((masses == 22).sum())
        graFEI_nCharged_preFit = graFEI_nFSP - graFEI_nPhotons_preFit
        graFEI_nElectrons_preFit = int((masses == 11).sum())
        graFEI_nMuons_preFit = int((masses == 13).sum())
        graFEI_nPions_preFit = int((masses == 211).sum())
        graFEI_nKaons_preFit = int((masses == 321).sum())
        graFEI_nProtons_preFit = int((masses == 2212).sum())
        graFEI_nLeptons_preFit = graFEI_nElectrons_preFit + graFEI_nMuons_preFit
        graFEI_nOthers_preFit = graFEI_nCharged_preFit - (
            graFEI_nLeptons_preFit + graFEI_nPions_preFit + graFEI_nKaons_preFit + graFEI_nProtons_preFit
        )

        candidate.addExtraInfo("graFEI_nFSP", graFEI_nFSP)
        candidate.addExtraInfo("graFEI_nCharged_preFit", graFEI_nCharged_preFit)
        candidate.addExtraInfo("graFEI_nPhotons_preFit", graFEI_nPhotons_preFit)
        candidate.addExtraInfo("graFEI_nElectrons_preFit", graFEI_nElectrons_preFit)
        candidate.addExtraInfo("graFEI_nMuons_preFit", graFEI_nMuons_preFit)
        candidate.addExtraInfo("graFEI_nPions_preFit", graFEI_nPions_preFit)
        candidate.addExtraInfo("graFEI_nKaons_preFit", graFEI_nKaons_preFit)
        candidate.addExtraInfo("graFEI_nProtons_preFit", graFEI_nProtons_preFit)
        candidate.addExtraInfo("graFEI_nLeptons_preFit", graFEI_nLeptons_preFit)
        candidate.addExtraInfo("graFEI_nOthers_preFit", graFEI_nOthers_preFit)

        # Trivial decay tree
        if n_nodes < 2:
            b2.B2WARNING(
                f"Skipping candidate with {n_nodes} reconstructed FSPs"
            )
            return

        # Fill node features (and the discarded ones, only needed for edge features)
        x_nodes = np.empty((n_nodes, len(self.node_features)))
        x_dis = np.empty((n_nodes, len(self.discarded_features)))
        for p, particle in enumerate(p_list):
            for f, feat in enumerate(self.node_features):
                x_nodes[p, f] = vm.evaluate(feat[feat.find("feat_") + 5:], particle)
            for f, feat in enumerate(self.discarded_features):
                x_dis[p, f] = vm.evaluate(feat[feat.find("feat_") + 5:], particle)
        b2.B2DEBUG(11, "Node features:\n", x_nodes)

        # Fill edge features array
        x_edges = (
            compute_edge_features(
                self.edge_features,
                self.node_features + self.discarded_features,
                np.concatenate([x_nodes, x_dis], axis=1),
            )
            if self.edge_features
            else []
        )
        # Fully connected directed graph; shape (2, n_nodes * (n_nodes - 1)), kept on the CPU
        edge_index_cpu = torch.tensor(
            list(itertools.permutations(range(n_nodes), 2)), dtype=torch.long
        ).t().contiguous()
        b2.B2DEBUG(11, "Edge features:\n", x_edges)

        # Fill global features # TODO: get them from basf2
        x_global = np.array([[n_nodes]], dtype=float) if self.glob_features else []
        b2.B2DEBUG(11, "Global features:\n", x_global)

        # Fill tensor to assign each node to a graph (trivial since we have only one graph per decay)
        torch_batch = torch.zeros(size=[n_nodes], dtype=torch.long)

        # Set nans to zero, this is a surrogate value, may change in future
        np.nan_to_num(x_nodes, copy=False)
        np.nan_to_num(x_edges, copy=False)
        np.nan_to_num(x_global, copy=False)

        # Normalize any features that should be
        if self.normalize is not None:
            # NOTE(review): assumes normalize_features rescales the arrays in place (its return
            # value is not used) — TODO confirm against grafei.model.normalize_features.
            normalize_features(
                self.normalize,
                self.node_features,
                x_nodes,
                self.edge_features,
                x_edges,
                self.glob_features,
                x_global,
            )

        # Convert everything to torch tensors and send them to the model's device
        batch = Batch(
            x=torch.tensor(x_nodes, dtype=torch.float).to(self.device),
            edge_index=edge_index_cpu.to(self.device),
            edge_attr=torch.tensor(x_edges, dtype=torch.float).to(self.device),
            u=torch.tensor(x_global, dtype=torch.float).to(self.device),
            batch=torch_batch.to(self.device),
        )

        # Evaluate model
        with torch.no_grad():
            x_pred, e_pred, _ = self.model(batch)

        # Bring predictions back to the CPU: the post-processing below mixes them with NumPy
        # arrays (.numpy(), boolean masks) and CPU tensors, which fails for CUDA tensors.
        x_pred = x_pred.cpu()
        e_pred = e_pred.cpu()

        # Select edges from predictions
        edge_probability, predicted_LCA = torch.softmax(e_pred, dim=1).max(dim=1)

        # Select masses from predictions
        mass_probability, predicted_masses = torch.softmax(x_pred, dim=1).max(dim=1)
        b2.B2DEBUG(10, "Predicted mass classes:\n", predicted_masses)
        b2.B2DEBUG(11, "Mass class probabilities:\n", mass_probability)

        # Count number of predicted particles for each mass hypothesis
        graFEI_nPhotons_postFit = (predicted_masses == 6).sum().item()
        graFEI_nCharged_postFit = graFEI_nFSP - graFEI_nPhotons_postFit
        graFEI_nElectrons_postFit = (predicted_masses == 1).sum().item()
        graFEI_nMuons_postFit = (predicted_masses == 2).sum().item()
        graFEI_nPions_postFit = (predicted_masses == 3).sum().item()
        graFEI_nKaons_postFit = (predicted_masses == 4).sum().item()
        graFEI_nProtons_postFit = (predicted_masses == 5).sum().item()
        graFEI_nLeptons_postFit = graFEI_nElectrons_postFit + graFEI_nMuons_postFit
        graFEI_nOthers_postFit = (predicted_masses == 0).sum().item()

        # Get square (n_nodes x n_nodes) matrices
        edge_probability_square = torch.sparse_coo_tensor(
            edge_index_cpu, edge_probability
        ).to_dense()
        predicted_LCA_square = torch.sparse_coo_tensor(
            edge_index_cpu, predicted_LCA, dtype=int
        ).to_dense()
        b2.B2DEBUG(10, "Predicted LCA:\n", predicted_LCA_square)
        b2.B2DEBUG(11, "Edge class probabilities:\n", edge_probability_square)

        # Remove symmetric elements from probability (keep the strict lower triangle)
        edge_probability_unique = edge_probability_square[
            edge_probability_square.tril(diagonal=-1) > 0
        ]

        # A particle is predicted as matched when its LCA row is not all zeros
        predicted_matched = (predicted_LCA_square != 0).any(dim=1).numpy()
        b2.B2DEBUG(10, "Predicted matched particles:\n", predicted_matched)
        # Same but ignoring photons
        predicted_matched_noPhotons = predicted_matched[masses != 22]

        # Get number of predicted as unmatched
        graFEI_nPredictedUnmatched = int((~predicted_matched).sum())
        graFEI_nPredictedUnmatched_noPhotons = (
            int((~predicted_matched_noPhotons).sum())
            if predicted_matched_noPhotons.size != 0
            else 0
        )

        # Get LCA and predicted masses of predicted matched only
        predicted_LCA_square_matched = predicted_LCA_square[predicted_matched][:, predicted_matched]
        predicted_masses_matched = predicted_masses[predicted_matched]

        # Check if LCA describes a tree graph
        graFEI_validTree = 0
        adjacency = None
        if not torch.all(predicted_LCA_square == 0):
            try:
                adjacency = lca_to_adjacency(predicted_LCA_square_matched)
                graFEI_validTree = 1
            except InvalidLCAMatrix:
                # Not a valid tree: graFEI_validTree stays 0
                pass

        # Check if event is good, depending on the chosen sig-side LCA matrix/masses
        graFEI_goodEvent = 0
        if graFEI_validTree:
            good_decay, root_level, sig_side_fsps = select_good_decay(
                predicted_LCA_square_matched,
                predicted_masses_matched,
                self.sig_side_lcas,
                self.sig_side_masses,
            )
            graFEI_goodEvent = int((self.max_level == root_level) and good_decay)

            if graFEI_goodEvent:
                # Find sig- and tag-side FSPs (1 = sig-side, 0 = tag-side);
                # sig_side_fsps indexes the predicted-matched particles
                p_list_matched = [p for p, matched in zip(p_list, predicted_matched) if matched]
                for i, particle in enumerate(p_list_matched):
                    particle.addExtraInfo("graFEI_sigSide", 1 if i in sig_side_fsps else 0)

            b2.B2DEBUG(11, "This LCA describes a valid tree")
            b2.B2DEBUG(
                11,
                "Predicted LCA on matched particles:\n",
                predicted_LCA_square_matched,
            )
            b2.B2DEBUG(11, "Adjacency matrix:\n", adjacency)

        # Particles not assigned to B decays get -1
        for particle in p_list:
            if not particle.hasExtraInfo("graFEI_sigSide"):
                particle.addExtraInfo("graFEI_sigSide", -1)

        # Define B probabilities
        graFEI_probEdgeProd = edge_probability_unique.prod().item()
        graFEI_probEdgeMean = edge_probability_unique.mean().item()
        # NOTE(review): the exponent is 1/n_nodes, not 1/(number of edges), so this is not
        # the geometric mean of the edge probabilities — kept unchanged for compatibility.
        graFEI_probEdgeGeom = torch.pow(edge_probability_unique.prod(), 1 / n_nodes).item()

        # Add extra info for each B candidate
        candidate.addExtraInfo("graFEI_probEdgeProd", graFEI_probEdgeProd)
        candidate.addExtraInfo("graFEI_probEdgeMean", graFEI_probEdgeMean)
        candidate.addExtraInfo("graFEI_probEdgeGeom", graFEI_probEdgeGeom)
        candidate.addExtraInfo("graFEI_validTree", graFEI_validTree)
        candidate.addExtraInfo("graFEI_goodEvent", graFEI_goodEvent)
        candidate.addExtraInfo("graFEI_nPhotons_postFit", graFEI_nPhotons_postFit)
        candidate.addExtraInfo("graFEI_nCharged_postFit", graFEI_nCharged_postFit)
        candidate.addExtraInfo("graFEI_nElectrons_postFit", graFEI_nElectrons_postFit)
        candidate.addExtraInfo("graFEI_nMuons_postFit", graFEI_nMuons_postFit)
        candidate.addExtraInfo("graFEI_nPions_postFit", graFEI_nPions_postFit)
        candidate.addExtraInfo("graFEI_nKaons_postFit", graFEI_nKaons_postFit)
        candidate.addExtraInfo("graFEI_nProtons_postFit", graFEI_nProtons_postFit)
        candidate.addExtraInfo("graFEI_nLeptons_postFit", graFEI_nLeptons_postFit)
        candidate.addExtraInfo("graFEI_nOthers_postFit", graFEI_nOthers_postFit)
        candidate.addExtraInfo("graFEI_nPredictedUnmatched", graFEI_nPredictedUnmatched)
        candidate.addExtraInfo("graFEI_nPredictedUnmatched_noPhotons", graFEI_nPredictedUnmatched_noPhotons)

        # Add MC truth information
        if self.storeTrueInfo:
            self._add_truth_info(
                candidate, p_list, predicted_masses, predicted_matched, predicted_LCA_square
            )

    def _add_truth_info(self, candidate, p_list, predicted_masses, predicted_matched, predicted_LCA_square):
        """
        Compare the graFEI predictions of one candidate with MC truth and store the result.

        Args:
            candidate: Candidate that receives the ``graFEI_truth_*`` extraInfos.
            p_list (list): Final-state particles of the candidate.
            predicted_masses (torch.Tensor): Predicted mass class of each FSP (CPU tensor).
            predicted_matched (numpy.ndarray): Boolean mask of FSPs predicted as matched.
            predicted_LCA_square (torch.Tensor): Predicted LCA matrix of all FSPs (CPU tensor).
        """
        is_ups = self.mc_particle == "Upsilon(4S):MC"

        # Get the true IDs of the ancestors (if it's a B); negative when there is none
        parentID = np.array([vm.evaluate("ancestorBIndex", p) for p in p_list], dtype=int)
        b2.B2DEBUG(10, "Ancestor true ID:\n", parentID)
        has_B_ancestor = parentID >= 0

        # Get particle indices
        p_indices = np.array(
            [
                p.getMCParticle().getArrayIndex() if from_B else -1
                for p, from_B in zip(p_list, has_B_ancestor)
            ]
        )
        # Get particle masses
        p_masses = masses_to_classes(
            np.array(
                [
                    p.getMCParticle().getPDG() if from_B else -1
                    for p, from_B in zip(p_list, has_B_ancestor)
                ]
            )
        )
        b2.B2DEBUG(10, "True mass classes:\n", p_masses)
        # And primary information
        evt_primary = np.array(
            [
                p.getMCParticle().isPrimaryParticle() if from_B else False
                for p, from_B in zip(p_list, has_B_ancestor)
            ]
        )
        b2.B2DEBUG(10, "Is primary particle:\n", evt_primary)

        # Get unique B indices associated to each predicted matched particle which is also a primary
        # The idea is that if a primary particle coming from the other B is categorized as unmatched,
        # then it's ok and the decay could still have a perfectLCA
        B_indices = parentID[np.logical_and(evt_primary, predicted_matched)]
        b2.B2DEBUG(
            10, "Ancestor ID of predicted matched particles:\n", B_indices
        )
        B_indices = list(set(B_indices))

        # Initialize truth-matching variables
        graFEI_truth_perfectLCA = 0  # 1 if LCA perfectly reconstructed
        graFEI_truth_isSemileptonic = -1  # 0 if hadronic, 1 is semileptonic, -1 if not matched
        graFEI_truth_nFSP = -1  # Number of true FSPs
        # Check if all the masses are predicted correctly
        graFEI_truth_perfectMasses = int((predicted_masses.numpy() == p_masses).all())
        graFEI_truth_nPhotons = (p_masses == 6).sum()
        graFEI_truth_nElectrons = (p_masses == 1).sum()
        graFEI_truth_nMuons = (p_masses == 2).sum()
        graFEI_truth_nPions = (p_masses == 3).sum()
        graFEI_truth_nKaons = (p_masses == 4).sum()
        graFEI_truth_nProtons = (p_masses == 5).sum()
        graFEI_truth_nOthers = (p_masses == 0).sum()

        # Get the generated B's
        gen_list = Belle2.PyStoreObj(self.mc_particle)

        if is_ups and gen_list.getListSize() > 1:
            b2.B2WARNING(
                f"Found {gen_list.getListSize()} true Upsilon(4S) in the generated MC (??)")

        if gen_list.getListSize() > 0:
            # Here we look if the candidate has a perfectly reconstructed LCA
            for genP in gen_list.obj():
                mcp = genP.getMCParticle()
                # If storing true info on B decays and we have matched particles coming
                # from different Bs the decay will not have a perfectLCA
                if not is_ups and len(B_indices) != 1:
                    break

                # Get array index of MC particle
                array_index = mcp.getArrayIndex()

                # If we are reconstructing Bs, skip the other in the event
                if not is_ups and array_index != B_indices[0]:
                    continue

                # Write leaf history
                leaf_hist, levels, _, _, semilep_flag = write_hist(
                    particle=mcp,
                    leaf_hist={},
                    levels={},
                    hist=[],
                    pdg={},
                    leaf_pdg={},
                    semilep_flag=False,
                )

                # Skip B decays with trivial LCA (should be always false except for B -> nunu ?)
                if len(leaf_hist) < 2:
                    continue

                # Number of true FSPs
                graFEI_truth_nFSP = len(leaf_hist)

                # True LCA between generated leaves: level of the deepest common ancestor
                true_LCA_square = np.zeros([len(leaf_hist), len(leaf_hist)], dtype=int)
                for (ix, x_leaf), (iy, y_leaf) in itertools.combinations(enumerate(leaf_hist), 2):
                    intersection = [i for i in leaf_hist[x_leaf] if i in leaf_hist[y_leaf]]
                    true_LCA_square[ix, iy] = levels[intersection[-1]]
                    true_LCA_square[iy, ix] = levels[intersection[-1]]

                # Get LCA indices in order that the leaves appear in reconstructed particles
                # Secondaries aren't in the LCA leaves list so they get a 0
                leaf_position = {leaf: pos for pos, leaf in enumerate(leaf_hist)}
                locs = np.array([leaf_position.get(i, 0) for i in p_indices], dtype=int)

                # Insert dummy rows for secondaries
                true_LCA_square = true_LCA_square[locs, :][:, locs]

                # Zero every row and column that is not a primary from the right candidate
                # (unmatched particles and secondaries)
                x_rows = evt_primary if is_ups else (parentID == array_index)
                primaries_from_right_cand = np.logical_and(evt_primary, x_rows)
                true_LCA_square = np.where(
                    np.outer(primaries_from_right_cand, primaries_from_right_cand), true_LCA_square, 0
                )

                # Convert LCA to tensor
                true_LCA_square = torch.tensor(true_LCA_square, dtype=int)
                b2.B2DEBUG(10, "True LCA:\n", true_LCA_square)

                # Check if perfect LCA
                if (true_LCA_square == predicted_LCA_square).all():
                    graFEI_truth_perfectLCA = 1
                    b2.B2DEBUG(10, "LCA perfectly reconstructed!")

                # Assign semileptonic flag
                graFEI_truth_isSemileptonic = int(semilep_flag)

        # Perfect event = perfectLCA and perfectMasses
        graFEI_truth_perfectEvent = int(graFEI_truth_perfectLCA and graFEI_truth_perfectMasses)

        # Write extra info
        candidate.addExtraInfo("graFEI_truth_perfectLCA", graFEI_truth_perfectLCA)
        candidate.addExtraInfo("graFEI_truth_perfectMasses", graFEI_truth_perfectMasses)
        candidate.addExtraInfo("graFEI_truth_perfectEvent", graFEI_truth_perfectEvent)
        candidate.addExtraInfo("graFEI_truth_isSemileptonic", graFEI_truth_isSemileptonic)
        candidate.addExtraInfo("graFEI_truth_nFSP", graFEI_truth_nFSP)
        candidate.addExtraInfo("graFEI_truth_nPhotons", graFEI_truth_nPhotons)
        candidate.addExtraInfo("graFEI_truth_nElectrons", graFEI_truth_nElectrons)
        candidate.addExtraInfo("graFEI_truth_nMuons", graFEI_truth_nMuons)
        candidate.addExtraInfo("graFEI_truth_nPions", graFEI_truth_nPions)
        candidate.addExtraInfo("graFEI_truth_nKaons", graFEI_truth_nKaons)
        candidate.addExtraInfo("graFEI_truth_nProtons", graFEI_truth_nProtons)
        candidate.addExtraInfo("graFEI_truth_nOthers", graFEI_truth_nOthers)
Base class for DBObjPtr and DBArray for easier common treatment.
static Environment & Instance()
Static method to get a reference to the Environment instance.
Definition: Environment.cc:28
a (simplified) python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
gpu
If running on GPU.
Definition: GraFEIModule.py:73
storeTrueInfo
Figure out if we're running on data or MC.
Definition: GraFEIModule.py:96
payload_model_name
Model file name in the payload.
Definition: GraFEIModule.py:77
param_file
PyTorch parameter file path.
Definition: GraFEIModule.py:67
model
The model. The correct edge_classes is taken from the config file.
particle_list
Input particle list.
Definition: GraFEIModule.py:63
use_amp
Mixed precision.
discarded_features
Discarded node features.
max_level
Max LCAS level.
def __init__(self, particle_list, cfg_path=None, param_file=None, sig_side_lcas=None, sig_side_masses=None, gpu=False, payload_config_name="graFEIConfigFile", payload_model_name="graFEIModelFile")
Definition: GraFEIModule.py:57
edge_features
Edge features.
normalize
Normalize features.
node_features
Node features.
sig_side_masses
Chosen sig-side mass hypotheses.
Definition: GraFEIModule.py:71
device
Figure out which device all this is running on - CPU or GPU.
Definition: GraFEIModule.py:99
mc_particle
Top MC particle.
cfg_path
Config yaml file path.
Definition: GraFEIModule.py:65
glob_features
Global features.
payload_config_name
Config file name in the payload.
Definition: GraFEIModule.py:75
sig_side_lcas
Chosen sig-side LCAS.
Definition: GraFEIModule.py:69