Belle II Software light-2406-ragdoll
GraFEIModule Class Reference

Public Member Functions

def __init__ (self, particle_list, cfg_path=None, param_file=None, sig_side_lcas=None, sig_side_masses=None, gpu=False, payload_config_name="graFEIConfigFile", payload_model_name="graFEIModelFile")
 
def initialize (self)
 
def event (self)
 

Public Attributes

 particle_list
 Input particle list.
 
 cfg_path
 Config yaml file path.
 
 param_file
 PyTorch parameter file path.
 
 sig_side_lcas
 Chosen sig-side LCAS.
 
 sig_side_masses
 Chosen sig-side mass hypotheses.
 
 gpu
 Whether to run on a GPU (falls back to CPU if CUDA is unavailable).
 
 payload_config_name
 Config file name in the payload.
 
 payload_model_name
 Model file name in the payload.
 
 storeTrueInfo
 Figure out if we're running on data or MC.
 
 device
 Device the model runs on (CPU or GPU).
 
 configs
 Configuration loaded from the YAML config file.
 
 mc_particle
 Name of the top MC particle list ("Upsilon(4S):MC", "B0:MC" or "B+:MC").
 
 max_level
 Maximum LCAS level (6 for Upsilon(4S) reconstruction, 5 for B reconstruction).
 
 normalize
 Feature normalization settings from the config file.
 
 use_amp
 Whether to use mixed precision (only enabled on CUDA).
 
 node_features
 Node features.
 
 edge_features
 Edge features.
 
 glob_features
 Global features.
 
 discarded_features
 Discarded node features (evaluated only to compute edge features, not passed to the model as node inputs).
 
 model
 The model. The correct edge_classes is taken from the config file.
 

Detailed Description

Applies the graFEI model to a particle list in basf2.
GraFEI information is stored as extraInfo on each candidate and on its final-state particles.

Args:
    particle_list (str): Name of particle list.
    cfg_path (str): Path to config file. If `None`, the config file from the global tag is used.
    param_file (str): Path to parameter file containing the model. If `None`, the parameter file from the global tag is used.
    sig_side_lcas (list): LCAS matrix of the signal side.
    sig_side_masses (list): Mass hypotheses of the signal-side particles.
    gpu (bool): Whether to run on a GPU.
    payload_config_name (str): Name of config file payload. The default should be kept, except in basf2 examples.
    payload_model_name (str): Name of model file payload. The default should be kept, except in basf2 examples.

Definition at line 31 of file GraFEIModule.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__(self,
             particle_list,
             cfg_path=None,
             param_file=None,
             sig_side_lcas=None,
             sig_side_masses=None,
             gpu=False,
             payload_config_name="graFEIConfigFile",
             payload_model_name="graFEIModelFile")
Initialization.

Definition at line 47 of file GraFEIModule.py.

57     ):
58         """
59         Initialization.
60         """
61         super().__init__()
62
63         self.particle_list = particle_list
64
65         self.cfg_path = cfg_path
66
67         self.param_file = param_file
68
69         self.sig_side_lcas = torch.tensor(sig_side_lcas) if sig_side_lcas else None
70
71         self.sig_side_masses = sig_side_masses
72
73         self.gpu = gpu
74
75         self.payload_config_name = payload_config_name
76
77         self.payload_model_name = payload_model_name
78

Member Function Documentation

◆ event()

def event(self)
Called once per event: runs the model on each candidate in the particle list and stores the results as extraInfo.

Definition at line 173 of file GraFEIModule.py.

173     def event(self):
174         """
175         Called at the beginning of each event.
176         """
177         b2.B2DEBUG(10, "---- Processing new event ----")
178
179         # Get the B candidate list
180         candidate_list = get_object_list(Belle2.PyStoreObj(self.particle_list).obj())
181
182         # Get the particle candidate(s)
183         for candidate in candidate_list:
184             # Get FSPs
185             p_list = get_object_list(candidate.getFinalStateDaughters())
186
187             # Number of FSPs
188             n_nodes = len(p_list)
189
190             # Particle nature
191             masses = np.array([abs(p.getPDGCode()) for p in p_list])
192
193             # Number of charged and photons
194             graFEI_nFSP = n_nodes
195             graFEI_nPhotons_preFit = (masses == 22).sum()
196             graFEI_nCharged_preFit = graFEI_nFSP - graFEI_nPhotons_preFit
197             graFEI_nElectrons_preFit = (masses == 11).sum()
198             graFEI_nMuons_preFit = (masses == 13).sum()
199             graFEI_nPions_preFit = (masses == 211).sum()
200             graFEI_nKaons_preFit = (masses == 321).sum()
201             graFEI_nProtons_preFit = (masses == 2212).sum()
202             graFEI_nLeptons_preFit = graFEI_nElectrons_preFit + graFEI_nMuons_preFit
203             graFEI_nOthers_preFit = graFEI_nCharged_preFit - \
204                 (graFEI_nLeptons_preFit + graFEI_nPions_preFit + graFEI_nKaons_preFit + graFEI_nProtons_preFit)
205
206             candidate.addExtraInfo("graFEI_nFSP", graFEI_nFSP)
207             candidate.addExtraInfo("graFEI_nCharged_preFit", graFEI_nCharged_preFit)
208             candidate.addExtraInfo("graFEI_nPhotons_preFit", graFEI_nPhotons_preFit)
209             candidate.addExtraInfo("graFEI_nElectrons_preFit", graFEI_nElectrons_preFit)
210             candidate.addExtraInfo("graFEI_nMuons_preFit", graFEI_nMuons_preFit)
211             candidate.addExtraInfo("graFEI_nPions_preFit", graFEI_nPions_preFit)
212             candidate.addExtraInfo("graFEI_nKaons_preFit", graFEI_nKaons_preFit)
213             candidate.addExtraInfo("graFEI_nProtons_preFit", graFEI_nProtons_preFit)
214             candidate.addExtraInfo("graFEI_nLeptons_preFit", graFEI_nLeptons_preFit)
215             candidate.addExtraInfo("graFEI_nOthers_preFit", graFEI_nOthers_preFit)
216
217             # Trivial decay tree
218             if n_nodes < 2:
219                 b2.B2WARNING(
220                     f"Skipping candidate with {n_nodes} reconstructed FSPs"
221                 )
222
223                 continue
224
225             # Initialize node features array
226             x_nodes = np.empty((n_nodes, len(self.node_features)))
227             x_dis = np.empty((n_nodes, len(self.discarded_features)))
228
229             # Fill node features array
230             for p, particle in enumerate(p_list):
231                 for f, feat in enumerate(self.node_features):
232                     feat = feat[feat.find("feat_") + 5:]
233                     x_nodes[p, f] = vm.evaluate(feat, particle)
234                 for f, feat in enumerate(self.discarded_features):
235                     feat = feat[feat.find("feat_") + 5:]
236                     x_dis[p, f] = vm.evaluate(feat, particle)
237             b2.B2DEBUG(11, "Node features:\n", x_nodes)
238
239             # Fill edge features array
240             x_edges = (compute_edge_features(self.edge_features, self.node_features + self.discarded_features,
241                                              np.concatenate([x_nodes, x_dis], axis=1)) if self.edge_features != [] else [])
242             edge_index = torch.tensor(list(itertools.permutations(range(n_nodes), 2)), dtype=torch.long)
243             b2.B2DEBUG(11, "Edge features:\n", x_edges)
244
245             # Fill global features # TODO: get them from basf2
246             x_global = (
247                 np.array([[n_nodes]], dtype=float)
248                 if self.glob_features != []
249                 else []
250             )
251             b2.B2DEBUG(11, "Global features:\n", x_global)
252
253             # Fill tensor to assign each node to a graph (trivial since we have only one graph per decay)
254             torch_batch = torch.zeros(size=[n_nodes], dtype=torch.long)
255
256             # Set nans to zero, this is a surrogate value, may change in future
257             np.nan_to_num(x_nodes, copy=False)
258             np.nan_to_num(x_edges, copy=False)
259             np.nan_to_num(x_global, copy=False)
260
261             # Normalize any features that should be
262             if self.normalize is not None:
264                     self.normalize,
265                     self.node_features,
266                     x_nodes,
267                     self.edge_features,
268                     x_edges,
269                     self.glob_features,
270                     x_global,
271                 )
272
273             # Convert everything to torch tensors and/or send to some device in case
274             x = torch.tensor(x_nodes, dtype=torch.float).to(self.device)
275             edge_index = edge_index.t().contiguous().to(self.device)
276             edge_attr = torch.tensor(x_edges, dtype=torch.float).to(self.device)
277             u = torch.tensor(x_global, dtype=torch.float).to(self.device)
278             torch_batch = torch_batch.to(self.device)
279
280             # Create Batch object to be passed to model
281             batch = Batch(
282                 x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
283             )
284
285             # Evaluate model
286             with torch.no_grad():
287                 x_pred, e_pred, u_pred = self.model(batch)
288                 # if self.use_amp:
289                 #     with autocast(enabled=True):
290                 #         x_pred, e_pred, u_pred = self.model(batch)
291                 # else:
292                 #     x_pred, e_pred, u_pred = self.model(batch)
293
294             # Select edges from predictions
295             edge_probs = torch.softmax(e_pred, dim=1)
296             edge_probability, predicted_LCA = edge_probs.max(dim=1)
297
298             # Select masses from predictions
299             mass_probs = torch.softmax(x_pred, dim=1)
300             mass_probability, predicted_masses = mass_probs.max(dim=1)
301             b2.B2DEBUG(10, "Predicted mass classes:\n", predicted_masses)
302             b2.B2DEBUG(11, "Mass class probabilities:\n", mass_probability)
303
304             # Count number of predicted particles for each mass hypothesis
305             graFEI_nPhotons_postFit = (predicted_masses == 6).sum()
306             graFEI_nCharged_postFit = graFEI_nFSP - graFEI_nPhotons_postFit
307             graFEI_nElectrons_postFit = (predicted_masses == 1).sum()
308             graFEI_nMuons_postFit = (predicted_masses == 2).sum()
309             graFEI_nPions_postFit = (predicted_masses == 3).sum()
310             graFEI_nKaons_postFit = (predicted_masses == 4).sum()
311             graFEI_nProtons_postFit = (predicted_masses == 5).sum()
312             graFEI_nLeptons_postFit = graFEI_nElectrons_postFit + graFEI_nMuons_postFit
313             graFEI_nOthers_postFit = (predicted_masses == 0).sum()
314
315             # Assign new mass hypotheses as extraInfo
316             for i, p in enumerate(p_list):
317                 p.addExtraInfo("graFEI_massHypothesis", predicted_masses[i])
318
320 edge_probability_square = torch.sparse_coo_tensor(
321 edge_index, edge_probability
322 ).to_dense()
323 predicted_LCA_square = torch.sparse_coo_tensor(
324 edge_index, predicted_LCA, dtype=int
325 ).to_dense()
326 b2.B2DEBUG(10, "Predicted LCA:\n", predicted_LCA_square)
327 b2.B2DEBUG(11, "Edge class probabilities:\n", edge_probability_square)
328
329 # Remove symmetric elements from probability
330 edge_probability_unique = edge_probability_square[
331 edge_probability_square.tril(diagonal=-1) > 0
332 ]
333
334 # Get particles predicted as matched by the model
335 predicted_matched = np.array(
336 [False if torch.all(i == 0) else True for i in predicted_LCA_square]
337 )
338 b2.B2DEBUG(10, "Predicted matched particles:\n", predicted_matched)
339 # Same but ignoring photons
340 predicted_matched_noPhotons = predicted_matched[masses != 22]
341
342 # Get number of predicted as unmatched
343 graFEI_nPredictedUnmatched = (~predicted_matched).sum()
344 graFEI_nPredictedUnmatched_noPhotons = (
345 (~predicted_matched_noPhotons).sum()
346 if predicted_matched_noPhotons.size != 0
347 else 0
348 )
349
350 # Get LCA of predicted matched only
351 predicted_LCA_square_matched = predicted_LCA_square[predicted_matched]
352 predicted_LCA_square_matched = predicted_LCA_square_matched[:, predicted_matched]
353
354 # Get predicted masses of predicted matched only
355 predicted_masses_matched = predicted_masses[predicted_matched]
356
357 # Check if LCA describes a tree graph
358 graFEI_validTree = 0
359 if not torch.all(predicted_LCA_square == 0):
360 try:
361 adjacency = lca_to_adjacency(predicted_LCA_square_matched)
362 graFEI_validTree = 1
363 except InvalidLCAMatrix:
364 pass
365
366 # Check if event is good, depending on the chosen sig-side LCA matrix/masses
367 graFEI_goodEvent = 0
368 if graFEI_validTree:
369 # Check if the event is good
370 good_decay, root_level, sig_side_fsps = select_good_decay(predicted_LCA_square_matched,
371 predicted_masses_matched,
372 self.sig_side_lcas,
373 self.sig_side_masses)
374 graFEI_goodEvent = int((self.max_level == root_level) and good_decay)
375
376 if graFEI_goodEvent:
377 # Find sig- and tag-side FSPs (1 = sig-side, 0 = tag-side)
378 p_list_matched = list(np.array(p_list)[predicted_matched])
379 for i, particle in enumerate(p_list_matched):
380 if i in sig_side_fsps:
381 particle.addExtraInfo("graFEI_sigSide", 1)
382 else:
383 particle.addExtraInfo("graFEI_sigSide", 0)
384
385 b2.B2DEBUG(11, "This LCA describes a valid tree")
386 b2.B2DEBUG(
387 11,
388 "Predicted LCA on matched particles:\n",
389 predicted_LCA_square_matched,
390 )
391 b2.B2DEBUG(11, "Adjacency matrix:\n", adjacency)
392
393 # Particles not assigned to B decays get -1
394 for particle in p_list:
395 if not particle.hasExtraInfo("graFEI_sigSide"):
396 particle.addExtraInfo("graFEI_sigSide", -1)
397
398 # Define B probabilities
399 graFEI_probEdgeProd = edge_probability_unique.prod().item()
400 graFEI_probEdgeMean = edge_probability_unique.mean().item()
401 graFEI_probEdgeGeom = torch.pow(edge_probability_unique.prod(), 1/n_nodes).item()
402
403 # Add extra info for each B candidate
404 candidate.addExtraInfo("graFEI_probEdgeProd", graFEI_probEdgeProd)
405 candidate.addExtraInfo("graFEI_probEdgeMean", graFEI_probEdgeMean)
406 candidate.addExtraInfo("graFEI_probEdgeGeom", graFEI_probEdgeGeom)
407 candidate.addExtraInfo("graFEI_validTree", graFEI_validTree)
408 candidate.addExtraInfo("graFEI_goodEvent", graFEI_goodEvent)
409 candidate.addExtraInfo("graFEI_nPhotons_postFit", graFEI_nPhotons_postFit)
410 candidate.addExtraInfo("graFEI_nCharged_postFit", graFEI_nCharged_postFit)
411 candidate.addExtraInfo("graFEI_nElectrons_postFit", graFEI_nElectrons_postFit)
412 candidate.addExtraInfo("graFEI_nMuons_postFit", graFEI_nMuons_postFit)
413 candidate.addExtraInfo("graFEI_nPions_postFit", graFEI_nPions_postFit)
414 candidate.addExtraInfo("graFEI_nKaons_postFit", graFEI_nKaons_postFit)
415 candidate.addExtraInfo("graFEI_nProtons_postFit", graFEI_nProtons_postFit)
416 candidate.addExtraInfo("graFEI_nLeptons_postFit", graFEI_nLeptons_postFit)
417 candidate.addExtraInfo("graFEI_nOthers_postFit", graFEI_nOthers_postFit)
418 candidate.addExtraInfo("graFEI_nPredictedUnmatched", graFEI_nPredictedUnmatched)
419 candidate.addExtraInfo("graFEI_nPredictedUnmatched_noPhotons", graFEI_nPredictedUnmatched_noPhotons)
420
421             # Add MC truth information
422             if self.storeTrueInfo:
423                 # Get the true IDs of the ancestors (if it's a B)
424                 parentID = np.array([vm.evaluate("ancestorBIndex", p) for p in p_list], dtype=int)
425                 b2.B2DEBUG(10, "Ancestor true ID:\n", parentID)
426
427                 # Get particle indices
428                 p_indices = np.array(
429                     [
430                         p.getMCParticle().getArrayIndex() if parentID[i] >= 0 else -1
431                         for (i, p) in enumerate(p_list)
432                     ]
433                 )
434                 # Get particle masses
435                 p_masses = masses_to_classes(
436                     np.array(
437                         [
438                             p.getMCParticle().getPDG() if parentID[i] >= 0 else -1
439                             for (i, p) in enumerate(p_list)
440                         ]
441                     )
442                 )
443                 b2.B2DEBUG(10, "True mass classes:\n", p_masses)
444                 # And primary information
445                 evt_primary = np.array(
446                     [
447                         p.getMCParticle().isPrimaryParticle()
448                         if parentID[i] >= 0
449                         else False
450                         for (i, p) in enumerate(p_list)
451                     ]
452                 )
453                 b2.B2DEBUG(10, "Is primary particle:\n", evt_primary)
454
455                 # Get unique B indices associated to each predicted matched particle which is also a primary
456                 # The idea is that if a primary particle coming from the other B is categorized as unmatched,
457                 # then it's ok and the decay could still have a perfectLCA
458                 B_indices = parentID[np.logical_and(evt_primary, predicted_matched)]
459                 b2.B2DEBUG(
460                     10, "Ancestor ID of predicted matched particles:\n", B_indices
461                 )
462                 B_indices = list(set(B_indices))
463
464                 # Initialize truth-matching variables
465                 graFEI_truth_perfectLCA = 0  # 1 if LCA perfectly reconstructed
466                 graFEI_truth_isSemileptonic = -1  # 0 if hadronic, 1 if semileptonic, -1 if not matched
467                 graFEI_truth_nFSP = -1  # Number of true FSPs
468                 graFEI_truth_perfectMasses = int((predicted_masses.numpy() == p_masses).all()
469                                                  )  # Check if all the masses are predicted correctly
470                 graFEI_truth_nPhotons = (p_masses == 6).sum()
471                 graFEI_truth_nElectrons = (p_masses == 1).sum()
472                 graFEI_truth_nMuons = (p_masses == 2).sum()
473                 graFEI_truth_nPions = (p_masses == 3).sum()
474                 graFEI_truth_nKaons = (p_masses == 4).sum()
475                 graFEI_truth_nProtons = (p_masses == 5).sum()
476                 graFEI_truth_nOthers = (p_masses == 0).sum()
477
478                 # Get the generated B's
479                 gen_list = Belle2.PyStoreObj(self.mc_particle)
480
481                 # Iterate over generated Ups
482                 if self.mc_particle == "Upsilon(4S):MC" and gen_list.getListSize() > 1:
483                     b2.B2WARNING(
484                         f"Found {gen_list.getListSize()} true Upsilon(4S) in the generated MC (??)")
485
486                 if gen_list.getListSize() > 0:
487                     # Here we look if the candidate has a perfectly reconstructed LCA
488                     for genP in gen_list.obj():
489                         mcp = genP.getMCParticle()
490                         # If storing true info on B decays and we have matched particles coming
491                         # from different Bs the decay will not have a perfectLCA
492                         if self.mc_particle != "Upsilon(4S):MC" and len(B_indices) != 1:
493                             break
494
495                         # Get array index of MC particle
496                         array_index = mcp.getArrayIndex()
497
498                         # If we are reconstructing Bs, skip the other in the event
499                         if self.mc_particle != "Upsilon(4S):MC" and array_index != B_indices[0]:
500                             continue
501
502                         # Write leaf history
503                         (
504                             leaf_hist,
505                             levels,
506                             _,
507                             _,
508                             semilep_flag,
509                         ) = write_hist(
510                             particle=mcp,
511                             leaf_hist={},
512                             levels={},
513                             hist=[],
514                             pdg={},
515                             leaf_pdg={},
516                             semilep_flag=False,
517                         )
518
519                         # Skip B decays with trivial LCA (should be always false except for B -> nunu ?)
520                         if len(leaf_hist) < 2:
521                             continue
522
523                         # Initialize LCA...
524                         true_LCA_square = np.zeros(
525                             [len(leaf_hist), len(leaf_hist)], dtype=int
526                         )
527
528                         # Number of true FSPs
529                         graFEI_truth_nFSP = len(leaf_hist)
530
531                         # ... and fill it!
532                         for x, y in itertools.combinations(enumerate(leaf_hist), 2):
533                             intersection = [
534                                 i for i in leaf_hist[x[1]] if i in leaf_hist[y[1]]
535                             ]
536                             true_LCA_square[x[0], y[0]] = levels[intersection[-1]]
537                             true_LCA_square[y[0], x[0]] = levels[intersection[-1]]
538
539                         x_leaves = p_indices
540                         y_leaves = list(leaf_hist.keys())
541
542                         # Get LCA indices in order that the leaves appear in reconstructed particles
543                         # Secondaries aren't in the LCA leaves list so they get a 0
544                         locs = np.array(
545                             [
546                                 np.where(y_leaves == i)[0].item()
547                                 if (i in y_leaves)
548                                 else 0
549                                 for i in x_leaves
550                             ],
551                             dtype=int,
552                         )
553
554                         # Insert dummy rows for secondaries
555                         true_LCA_square = true_LCA_square[locs, :][:, locs]
556
557                         # Set everything that's not primary (unmatched and secondaries) rows/cols to 0
558                         # Note we only consider the subset of leaves that made it into x_rows
559                         x_rows = np.array(
560                             [
561                                 vm.evaluate("ancestorBIndex", p) == array_index
562                                 for p in p_list
563                             ]
564                         ) if self.mc_particle != "Upsilon(4S):MC" else evt_primary
565
566                         primaries_from_right_cand = np.logical_and(evt_primary, x_rows)
567
568                         # Set the rows
569                         true_LCA_square = np.where(
570                             primaries_from_right_cand, true_LCA_square, 0
571                         )
572                         # Set the columns
573                         true_LCA_square = np.where(
574                             primaries_from_right_cand[:, None], true_LCA_square, 0
575                         )
576
577                         # Convert LCA to tensor
578                         true_LCA_square = torch.tensor(true_LCA_square, dtype=int)
579                         b2.B2DEBUG(10, "True LCA:\n", true_LCA_square)
580
581                         # Check if perfect LCA
582                         if (true_LCA_square == predicted_LCA_square).all():
583                             graFEI_truth_perfectLCA = 1
584                             b2.B2DEBUG(10, "LCA perfectly reconstructed!")
585
586                         # Assign semileptonic flag
587                         graFEI_truth_isSemileptonic = int(semilep_flag)
588
589                 # Perfect event = perfectLCA and perfectMasses
590                 graFEI_truth_perfectEvent = int(graFEI_truth_perfectLCA and graFEI_truth_perfectMasses)
591
592                 # Write extra info
593                 candidate.addExtraInfo("graFEI_truth_perfectLCA", graFEI_truth_perfectLCA)
594                 candidate.addExtraInfo("graFEI_truth_perfectMasses", graFEI_truth_perfectMasses)
595                 candidate.addExtraInfo("graFEI_truth_perfectEvent", graFEI_truth_perfectEvent)
596                 candidate.addExtraInfo("graFEI_truth_isSemileptonic", graFEI_truth_isSemileptonic)
597                 candidate.addExtraInfo("graFEI_truth_nFSP", graFEI_truth_nFSP)
598                 candidate.addExtraInfo("graFEI_truth_nPhotons", graFEI_truth_nPhotons)
599                 candidate.addExtraInfo("graFEI_truth_nElectrons", graFEI_truth_nElectrons)
600                 candidate.addExtraInfo("graFEI_truth_nMuons", graFEI_truth_nMuons)
601                 candidate.addExtraInfo("graFEI_truth_nPions", graFEI_truth_nPions)
602                 candidate.addExtraInfo("graFEI_truth_nKaons", graFEI_truth_nKaons)
603                 candidate.addExtraInfo("graFEI_truth_nProtons", graFEI_truth_nProtons)
604                 candidate.addExtraInfo("graFEI_truth_nOthers", graFEI_truth_nOthers)
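The species counts stored as `graFEI_n*_preFit` and `graFEI_n*_postFit` can be reproduced outside basf2 with the standard library. The sketch below assumes the class indices used in the listing (0 = other, 1 = electron, 2 = muon, 3 = pion, 4 = kaon, 5 = proton, 6 = photon). `pdg_to_class` is a hypothetical stand-in and is not the real `masses_to_classes`. Note that the two "Others" counts are defined differently. Pre-fit, "Others" is the number of charged FSPs not identified as e, mu, pi, K or p. Post-fit, it is the number of particles predicted in class 0.

```python
from collections import Counter

# Hypothetical stand-in for masses_to_classes(): |PDG code| -> mass class.
# Class indices follow the listing: 0 = other, 1 = e, 2 = mu, 3 = pi,
# 4 = K, 5 = p, 6 = photon.
PDG_TO_CLASS = {11: 1, 13: 2, 211: 3, 321: 4, 2212: 5, 22: 6}
CLASS_NAMES = {0: "Others", 1: "Electrons", 2: "Muons", 3: "Pions",
               4: "Kaons", 5: "Protons", 6: "Photons"}


def pdg_to_class(pdg_code):
    """Map a signed PDG code to a mass class; unlisted species give 0."""
    return PDG_TO_CLASS.get(abs(pdg_code), 0)


def prefit_counts(pdg_codes):
    """Multiplicities stored as graFEI_n*_preFit, computed from |PDG| codes."""
    codes = [abs(c) for c in pdg_codes]
    n_fsp = len(codes)
    c = {name: codes.count(pdg) for name, pdg in
         [("Photons", 22), ("Electrons", 11), ("Muons", 13),
          ("Pions", 211), ("Kaons", 321), ("Protons", 2212)]}
    c["Charged"] = n_fsp - c["Photons"]
    c["Leptons"] = c["Electrons"] + c["Muons"]
    # Pre-fit "Others": charged FSPs not identified as e, mu, pi, K or p
    c["Others"] = c["Charged"] - (c["Leptons"] + c["Pions"] + c["Kaons"] + c["Protons"])
    out = {f"graFEI_n{k}_preFit": v for k, v in c.items()}
    out["graFEI_nFSP"] = n_fsp
    return out


def postfit_counts(predicted_classes):
    """Multiplicities stored as graFEI_n*_postFit, from predicted mass classes."""
    hist = Counter(predicted_classes)
    c = {name: hist.get(cls, 0) for cls, name in CLASS_NAMES.items()}
    c["Charged"] = len(predicted_classes) - c["Photons"]
    c["Leptons"] = c["Electrons"] + c["Muons"]
    # Post-fit "Others" is simply the number of particles predicted in class 0
    return {f"graFEI_n{k}_postFit": v for k, v in c.items()}
```

For example, a candidate with FSPs pi+, pi-, gamma and e- has three charged FSPs and no pre-fit "Others". A particle predicted in class 0 is counted as charged post-fit, because `graFEI_nCharged_postFit` is `nFSP` minus the predicted photons.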
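The edge post-processing in `event()` can be sketched with plain Python lists. It consists of a softmax over the edge classes, an arg-max per edge, a scatter into a square matrix, the "predicted matched" rule and the three probability summaries. This illustration assumes that `edge_logits` is ordered like `itertools.permutations(range(n_nodes), 2)`, as `edge_index` is in the listing. `softmax` and `summarize_edges` are hypothetical names. As in the listing, the geometric mean uses the exponent `1/n_nodes`, not one over the number of unique edges.

```python
import itertools
import math


def softmax(logits):
    """Numerically stable softmax of a single list of logits."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def summarize_edges(n_nodes, edge_logits):
    """Sketch of event()'s edge post-processing for one candidate.

    edge_logits[k] holds the LCA-class logits of the k-th ordered pair,
    with pairs ordered like itertools.permutations(range(n_nodes), 2).
    """
    edge_index = list(itertools.permutations(range(n_nodes), 2))
    if len(edge_logits) != len(edge_index):
        raise ValueError("expected one logit vector per ordered pair")

    # Predicted LCA class and its probability, scattered into square matrices
    lca = [[0] * n_nodes for _ in range(n_nodes)]
    prob = [[0.0] * n_nodes for _ in range(n_nodes)]
    for (i, j), logits in zip(edge_index, edge_logits):
        probs = softmax(logits)
        best = max(range(len(probs)), key=probs.__getitem__)
        lca[i][j] = best
        prob[i][j] = probs[best]

    # "Predicted matched": the particle's LCA row is not entirely zero
    matched = [any(row) for row in lca]

    # One probability per unordered pair: strict lower triangle, values > 0
    unique = [prob[i][j] for i in range(n_nodes) for j in range(i) if prob[i][j] > 0]
    prod = math.prod(unique)
    return {
        "lca": lca,
        "matched": matched,
        "probEdgeProd": prod,
        "probEdgeMean": sum(unique) / len(unique),
        "probEdgeGeom": prod ** (1 / n_nodes),  # exponent 1/n_nodes, as in the listing
    }
```

With three particles where only the pair (0, 1) gets a non-zero LCA class, particle 2 ends up with an all-zero row and counts towards `graFEI_nPredictedUnmatched`.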

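Truth matching works in two steps. First, the true LCA matrix is built from `write_hist`'s leaf histories: each pair of leaves gets the level of their last shared ancestor. Second, the matrix is reordered to the order of the reconstructed particles, and the rows and columns of particles that are not primaries from the right candidate are set to zero. The stdlib sketch below assumes two input formats: `leaf_hist` maps each leaf to its ancestors ordered from the root downwards, and `levels` maps each ancestor to its LCA level. These formats and the helper names are assumptions for illustration, not the real `write_hist` output specification.

```python
import itertools


def true_lca_matrix(leaf_hist, levels):
    """Build a true LCA matrix from per-leaf ancestor histories.

    Assumed formats (hypothetical):
      leaf_hist: {leaf_id: [ancestor ids ordered from the root downwards]}
      levels:    {ancestor_id: LCA level assigned to that ancestor}
    """
    leaves = list(leaf_hist)
    n = len(leaves)
    lca = [[0] * n for _ in range(n)]
    for (a, leaf_a), (b, leaf_b) in itertools.combinations(enumerate(leaves), 2):
        common = [anc for anc in leaf_hist[leaf_a] if anc in leaf_hist[leaf_b]]
        # The last shared ancestor in root-to-leaf order is the lowest common one
        lca[a][b] = lca[b][a] = levels[common[-1]]
    return leaves, lca


def align_to_reco(leaves, lca, reco_mc_indices, keep):
    """Reorder the true LCA to reconstructed-particle order, zeroing rows/cols not kept."""
    # Particles whose MC index is not a leaf (e.g. secondaries) borrow position 0,
    # as in the listing; their row and column are expected to be zeroed via `keep`.
    locs = [leaves.index(i) if i in leaves else 0 for i in reco_mc_indices]
    n = len(locs)
    return [[lca[locs[r]][locs[c]] if keep[r] and keep[c] else 0 for c in range(n)]
            for r in range(n)]
```

The candidate then gets `graFEI_truth_perfectLCA = 1` only if the aligned matrix equals the predicted LCA matrix element by element.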
◆ initialize()

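def initialize(self)
Called once at the beginning of processing: fetches the config and parameter files from the database if no paths were given, then builds the model and loads its parameters.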

Definition at line 79 of file GraFEIModule.py.

79     def initialize(self):
80         """
81         Called at the beginning.
82         """
83         # Get weights and configs from the DB if they are not provided from the user
84         if not self.cfg_path:
85             config = Belle2.DBAccessorBase(
86                 Belle2.DBStoreEntry.c_RawFile, self.payload_config_name, True
87             )
88             self.cfg_path = config.getFilename()
89         if not self.param_file:
90             model = Belle2.DBAccessorBase(
91                 Belle2.DBStoreEntry.c_RawFile, self.payload_model_name, True
92             )
93             self.param_file = model.getFilename()
94
95
96         self.storeTrueInfo = Belle2.Environment.Instance().isMC()
97
98
99         self.device = torch.device(
100             "cuda" if (self.gpu and torch.cuda.is_available()) else "cpu"
101         )
102
103         # Load configs
104         cfg_file = open(self.cfg_path, "r")
105
106         self.configs = yaml.safe_load(cfg_file)
107
108
109         self.mc_particle = None
110
111         self.max_level = None
112         # B or Ups reco? 0 = Ups, 1 = B0, 2 = B+
113         if self.configs["model"]["B_reco"] == 0:
114             self.mc_particle = "Upsilon(4S):MC"
115             self.max_level = 6
116         elif self.configs["model"]["B_reco"] == 1:
117             self.mc_particle = "B0:MC"
118             self.max_level = 5
119         elif self.configs["model"]["B_reco"] == 2:
120             self.mc_particle = "B+:MC"
121             self.max_level = 5
122         else:
123             b2.B2FATAL("The B_reco setting in the config file is incorrect.")
124
125
126         self.normalize = self.configs["dataset"]["config"]["normalize"]
127
128
129         self.use_amp = self.configs["train"][
130             "mixed_precision"
131         ] and self.device == torch.device("cuda")
132
133
134         self.node_features = self.configs["dataset"]["config"]["features"]
135
136         self.edge_features = self.configs["dataset"]["config"]["edge_features"]
137
138         self.glob_features = self.configs["dataset"]["config"]["global_features"]
139
140         # Naming convention
141         self.node_features = [f"feat_{name}" for name in self.node_features] if self.node_features else []
142         self.edge_features = [f"edge_{name}" for name in self.edge_features] if self.edge_features else []
143         self.glob_features = [f"glob_{name}" for name in self.glob_features] if self.glob_features else []
144
145         self.discarded_features = ["feat_x", "feat_y", "feat_z", "feat_px", "feat_py", "feat_p"]
146
147         # Extract the number of features
148         n_infeatures = len(self.node_features)
149         e_infeatures = len(self.edge_features)
150         g_infeatures = len(self.glob_features)
151
152
154         self.model = GraFEIModel(
155             nfeat_in_dim=n_infeatures,
156             efeat_in_dim=e_infeatures,
157             gfeat_in_dim=g_infeatures,
158             **self.configs["model"],
159         )
160
161         # Load parameters' values
162         self.model.load_state_dict(
163             torch.load(self.param_file, map_location=self.device)["model"]
164         )
165
166         # Activate evaluation mode
167         self.model.eval()
168         # Push model to GPU in case
169         self.model.to(self.device)
170
171         b2.B2DEBUG(10, "Model structure:\n", {self.model})
172

Member Data Documentation

◆ cfg_path

cfg_path

Config yaml file path.

Definition at line 65 of file GraFEIModule.py.

◆ configs

configs

Configuration loaded from the YAML config file.

Definition at line 106 of file GraFEIModule.py.

◆ device

device

Device the model runs on (CPU or GPU).

Definition at line 99 of file GraFEIModule.py.

◆ discarded_features

discarded_features

Discarded node features (evaluated only to compute edge features, not passed to the model as node inputs).

Definition at line 145 of file GraFEIModule.py.

◆ edge_features

edge_features

Edge features.

Definition at line 136 of file GraFEIModule.py.

◆ glob_features

glob_features

Global features.

Definition at line 138 of file GraFEIModule.py.

◆ gpu

gpu

Whether to run on a GPU (falls back to CPU if CUDA is unavailable).

Definition at line 73 of file GraFEIModule.py.

◆ max_level

max_level

Maximum LCAS level (6 for Upsilon(4S) reconstruction, 5 for B reconstruction).

Definition at line 111 of file GraFEIModule.py.

◆ mc_particle

mc_particle

Name of the top MC particle list ("Upsilon(4S):MC", "B0:MC" or "B+:MC").

Definition at line 109 of file GraFEIModule.py.
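The choice of `mc_particle` and `max_level` in `initialize()` is a three-way switch on `configs["model"]["B_reco"]`. The following is a table-driven sketch of the same mapping. It uses a hypothetical `reco_target` helper, and a plain `ValueError` stands in for `b2.B2FATAL`.

```python
# Hypothetical helper mirroring the B_reco branch in initialize().
B_RECO_SETTINGS = {
    0: ("Upsilon(4S):MC", 6),  # Upsilon(4S) reconstruction
    1: ("B0:MC", 5),           # neutral B reconstruction
    2: ("B+:MC", 5),           # charged B reconstruction
}


def reco_target(b_reco):
    """Return (mc_particle, max_level) for the config's model.B_reco value."""
    try:
        return B_RECO_SETTINGS[b_reco]
    except KeyError:
        # initialize() calls b2.B2FATAL here; a plain exception stands in for it
        raise ValueError(f"The B_reco setting in the config file is incorrect: {b_reco!r}") from None
```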

◆ model

model

The model. The correct edge_classes is taken from the config file.

Definition at line 154 of file GraFEIModule.py.

◆ node_features

node_features

Node features.

Definition at line 134 of file GraFEIModule.py.
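In `initialize()`, feature names from the config get a `feat_`, `edge_` or `glob_` prefix. `event()` then strips everything up to and including `feat_` before passing the name to `vm.evaluate`. The hypothetical helpers below mimic both steps. The feature names used with them are placeholders, not a real graFEI config.

```python
def with_prefix(names, prefix):
    """Naming convention from initialize(): prepend feat_/edge_/glob_ (None or [] -> [])."""
    return [f"{prefix}_{name}" for name in names] if names else []


def basf2_variable(feature):
    """What event() passes to vm.evaluate(): the text after the first 'feat_'."""
    return feature[feature.find("feat_") + 5:]


# Fixed list of node features that event() evaluates only for edge-feature computation
DISCARDED_FEATURES = ["feat_x", "feat_y", "feat_z", "feat_px", "feat_py", "feat_p"]
```

Only `node_features` sets the model's node input width (`nfeat_in_dim`). The discarded features are evaluated for every particle but enter only through `compute_edge_features`.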

◆ normalize

normalize

Feature normalization settings from the config file.

Definition at line 126 of file GraFEIModule.py.

◆ param_file

param_file

PyTorch parameter file path.

Definition at line 67 of file GraFEIModule.py.

◆ particle_list

particle_list

Input particle list.

Definition at line 63 of file GraFEIModule.py.

◆ payload_config_name

payload_config_name

Config file name in the payload.

Definition at line 75 of file GraFEIModule.py.

◆ payload_model_name

payload_model_name

Model file name in the payload.

Definition at line 77 of file GraFEIModule.py.

◆ sig_side_lcas

sig_side_lcas

Chosen sig-side LCAS.

Definition at line 69 of file GraFEIModule.py.

◆ sig_side_masses

sig_side_masses

Chosen sig-side mass hypotheses.

Definition at line 71 of file GraFEIModule.py.

◆ storeTrueInfo

storeTrueInfo

Figure out if we're running on data or MC.

Definition at line 96 of file GraFEIModule.py.

◆ use_amp

use_amp

Whether to use mixed precision (only enabled on CUDA).

Definition at line 129 of file GraFEIModule.py.
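Both `device` and `use_amp` are derived in `initialize()`. CUDA is used only when `gpu=True` and CUDA is available, and mixed precision additionally requires `train.mixed_precision` in the config. The minimal sketch below reproduces that logic with plain values in place of the `torch` calls; the helper names are hypothetical. In the listed `event()`, the `autocast` branch is commented out, so `use_amp` does not currently change how the model is evaluated.

```python
def select_device(gpu_requested, cuda_available):
    """initialize(): CUDA only if the user asked for a GPU and CUDA is available."""
    return "cuda" if (gpu_requested and cuda_available) else "cpu"


def use_mixed_precision(mixed_precision_cfg, device):
    """initialize(): mixed precision requires the config flag and a CUDA device."""
    return bool(mixed_precision_cfg) and device == "cuda"
```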


The documentation for this class was generated from the following file: GraFEIModule.py