15from ROOT
import Belle2
16from variables
import variables
as vm
18from torch_geometric.data
import Batch
19from grafei.modules.LCASaverModule
import get_object_list, write_hist
20from grafei.model.geometric_network
import GraFEIModel
21from grafei.model.normalize_features
import normalize_features
22from grafei.model.edge_features
import compute_edge_features
23from grafei.model.lca_to_adjacency
import lca_to_adjacency, InvalidLCAMatrix, select_good_decay
24from grafei.model.tree_utils
import masses_to_classes
26warnings.filterwarnings(
27 action=
"ignore", category=RuntimeWarning, message=
"Mean of empty slice.*"
33 Applies graFEI model to a particle list in basf2.
34 GraFEI information
is stored
as extraInfos.
37 particle_list (str): Name of particle list.
38 cfg_path (str): Path to config file. If `
None` the config file
in the
global tag
is used.
39 param_file (str): Path to parameter file containing the model. If `
None` the parameter file
in the
global tag
is used.
40 sig_side_lcas (list): List containing LCAS matrix of signal-side.
41 sig_side_masses (list): List containing mass hypotheses of signal-side.
42 gpu (bool): Whether to run on a GPU.
43 payload_config_name (str): Name of config file payload. The default should be kept,
except in basf2 examples.
44 payload_model_name (str): Name of model file payload. The default should be kept,
except in basf2 examples.
55 payload_config_name="graFEIConfigFile",
56 payload_model_name="graFEIModelFile",
69 self.sig_side_lcas = torch.tensor(sig_side_lcas) if sig_side_lcas
else None
81 Called at the beginning.
100 "cuda" if (self.
gpu and torch.cuda.is_available())
else "cpu"
113 if self.
configs[
"model"][
"B_reco"] == 0:
116 elif self.
configs[
"model"][
"B_reco"] == 1:
119 elif self.
configs[
"model"][
"B_reco"] == 2:
123 b2.B2FATAL(
"The B_reco setting in the config file is incorrect.")
131 ]
and self.
device == torch.device(
"cuda")
155 nfeat_in_dim=n_infeatures,
156 efeat_in_dim=e_infeatures,
157 gfeat_in_dim=g_infeatures,
162 self.
model.load_state_dict(
171 b2.B2DEBUG(10,
"Model structure:\n", {self.
model})
175 Called at the beginning of each event.
177 b2.B2DEBUG(10, "---- Processing new event ----")
183 for candidate
in candidate_list:
185 p_list = get_object_list(candidate.getFinalStateDaughters())
188 n_nodes = len(p_list)
191 masses = np.array([abs(p.getPDGCode())
for p
in p_list])
194 graFEI_nFSP = n_nodes
195 graFEI_nPhotons_preFit = (masses == 22).sum()
196 graFEI_nCharged_preFit = graFEI_nFSP - graFEI_nPhotons_preFit
197 graFEI_nElectrons_preFit = (masses == 11).sum()
198 graFEI_nMuons_preFit = (masses == 13).sum()
199 graFEI_nPions_preFit = (masses == 211).sum()
200 graFEI_nKaons_preFit = (masses == 321).sum()
201 graFEI_nProtons_preFit = (masses == 2212).sum()
202 graFEI_nLeptons_preFit = graFEI_nElectrons_preFit + graFEI_nMuons_preFit
203 graFEI_nOthers_preFit = graFEI_nCharged_preFit - \
204 (graFEI_nLeptons_preFit + graFEI_nPions_preFit + graFEI_nKaons_preFit + graFEI_nProtons_preFit)
206 candidate.addExtraInfo(
"graFEI_nFSP", graFEI_nFSP)
207 candidate.addExtraInfo(
"graFEI_nCharged_preFit", graFEI_nCharged_preFit)
208 candidate.addExtraInfo(
"graFEI_nPhotons_preFit", graFEI_nPhotons_preFit)
209 candidate.addExtraInfo(
"graFEI_nElectrons_preFit", graFEI_nElectrons_preFit)
210 candidate.addExtraInfo(
"graFEI_nMuons_preFit", graFEI_nMuons_preFit)
211 candidate.addExtraInfo(
"graFEI_nPions_preFit", graFEI_nPions_preFit)
212 candidate.addExtraInfo(
"graFEI_nKaons_preFit", graFEI_nKaons_preFit)
213 candidate.addExtraInfo(
"graFEI_nProtons_preFit", graFEI_nProtons_preFit)
214 candidate.addExtraInfo(
"graFEI_nLeptons_preFit", graFEI_nLeptons_preFit)
215 candidate.addExtraInfo(
"graFEI_nOthers_preFit", graFEI_nOthers_preFit)
220 f
"Skipping candidate with {n_nodes} reconstructed FSPs"
230 for p, particle
in enumerate(p_list):
232 feat = feat[feat.find(
"feat_") + 5:]
233 x_nodes[p, f] = vm.evaluate(feat, particle)
235 feat = feat[feat.find(
"feat_") + 5:]
236 x_dis[p, f] = vm.evaluate(feat, particle)
237 b2.B2DEBUG(11,
"Node features:\n", x_nodes)
241 np.concatenate([x_nodes, x_dis], axis=1))
if self.
edge_features != []
else [])
242 edge_index = torch.tensor(list(itertools.permutations(range(n_nodes), 2)), dtype=torch.long)
243 b2.B2DEBUG(11,
"Edge features:\n", x_edges)
247 np.array([[n_nodes]], dtype=float)
251 b2.B2DEBUG(11,
"Global features:\n", x_global)
254 torch_batch = torch.zeros(size=[n_nodes], dtype=torch.long)
257 np.nan_to_num(x_nodes, copy=
False)
258 np.nan_to_num(x_edges, copy=
False)
259 np.nan_to_num(x_global, copy=
False)
274 x = torch.tensor(x_nodes, dtype=torch.float).to(self.
device)
275 edge_index = edge_index.t().contiguous().to(self.
device)
276 edge_attr = torch.tensor(x_edges, dtype=torch.float).to(self.
device)
277 u = torch.tensor(x_global, dtype=torch.float).to(self.
device)
278 torch_batch = torch_batch.to(self.
device)
282 x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
286 with torch.no_grad():
287 x_pred, e_pred, u_pred = self.
model(batch)
295 edge_probs = torch.softmax(e_pred, dim=1)
296 edge_probability, predicted_LCA = edge_probs.max(dim=1)
299 mass_probs = torch.softmax(x_pred, dim=1)
300 mass_probability, predicted_masses = mass_probs.max(dim=1)
301 b2.B2DEBUG(10,
"Predicted mass classes:\n", predicted_masses)
302 b2.B2DEBUG(11,
"Mass class probabilities:\n", mass_probability)
305 graFEI_nPhotons_postFit = (predicted_masses == 6).sum()
306 graFEI_nCharged_postFit = graFEI_nFSP - graFEI_nPhotons_postFit
307 graFEI_nElectrons_postFit = (predicted_masses == 1).sum()
308 graFEI_nMuons_postFit = (predicted_masses == 2).sum()
309 graFEI_nPions_postFit = (predicted_masses == 3).sum()
310 graFEI_nKaons_postFit = (predicted_masses == 4).sum()
311 graFEI_nProtons_postFit = (predicted_masses == 5).sum()
312 graFEI_nLeptons_postFit = graFEI_nElectrons_postFit + graFEI_nMuons_postFit
313 graFEI_nOthers_postFit = (predicted_masses == 0).sum()
316 for i, p
in enumerate(p_list):
317 p.addExtraInfo(
"graFEI_massHypothesis", predicted_masses[i])
320 edge_probability_square = torch.sparse_coo_tensor(
321 edge_index, edge_probability
323 predicted_LCA_square = torch.sparse_coo_tensor(
324 edge_index, predicted_LCA, dtype=int
326 b2.B2DEBUG(10,
"Predicted LCA:\n", predicted_LCA_square)
327 b2.B2DEBUG(11,
"Edge class probabilities:\n", edge_probability_square)
330 edge_probability_unique = edge_probability_square[
331 edge_probability_square.tril(diagonal=-1) > 0
335 predicted_matched = np.array(
336 [
False if torch.all(i == 0)
else True for i
in predicted_LCA_square]
338 b2.B2DEBUG(10,
"Predicted matched particles:\n", predicted_matched)
340 predicted_matched_noPhotons = predicted_matched[masses != 22]
343 graFEI_nPredictedUnmatched = (~predicted_matched).sum()
344 graFEI_nPredictedUnmatched_noPhotons = (
345 (~predicted_matched_noPhotons).sum()
346 if predicted_matched_noPhotons.size != 0
351 predicted_LCA_square_matched = predicted_LCA_square[predicted_matched]
352 predicted_LCA_square_matched = predicted_LCA_square_matched[:, predicted_matched]
355 predicted_masses_matched = predicted_masses[predicted_matched]
359 if not torch.all(predicted_LCA_square == 0):
363 except InvalidLCAMatrix:
370 good_decay, root_level, sig_side_fsps = select_good_decay(predicted_LCA_square_matched,
371 predicted_masses_matched,
374 graFEI_goodEvent = int((self.
max_level == root_level)
and good_decay)
378 p_list_matched = list(np.array(p_list)[predicted_matched])
379 for i, particle
in enumerate(p_list_matched):
380 if i
in sig_side_fsps:
381 particle.addExtraInfo(
"graFEI_sigSide", 1)
383 particle.addExtraInfo(
"graFEI_sigSide", 0)
385 b2.B2DEBUG(11,
"This LCA describes a valid tree")
388 "Predicted LCA on matched particles:\n",
389 predicted_LCA_square_matched,
391 b2.B2DEBUG(11,
"Adjacency matrix:\n", adjacency)
394 for particle
in p_list:
395 if not particle.hasExtraInfo(
"graFEI_sigSide"):
396 particle.addExtraInfo(
"graFEI_sigSide", -1)
399 graFEI_probEdgeProd = edge_probability_unique.prod().item()
400 graFEI_probEdgeMean = edge_probability_unique.mean().item()
401 graFEI_probEdgeGeom = torch.pow(edge_probability_unique.prod(), 1/n_nodes).item()
404 candidate.addExtraInfo(
"graFEI_probEdgeProd", graFEI_probEdgeProd)
405 candidate.addExtraInfo(
"graFEI_probEdgeMean", graFEI_probEdgeMean)
406 candidate.addExtraInfo(
"graFEI_probEdgeGeom", graFEI_probEdgeGeom)
407 candidate.addExtraInfo(
"graFEI_validTree", graFEI_validTree)
408 candidate.addExtraInfo(
"graFEI_goodEvent", graFEI_goodEvent)
409 candidate.addExtraInfo(
"graFEI_nPhotons_postFit", graFEI_nPhotons_postFit)
410 candidate.addExtraInfo(
"graFEI_nCharged_postFit", graFEI_nCharged_postFit)
411 candidate.addExtraInfo(
"graFEI_nElectrons_postFit", graFEI_nElectrons_postFit)
412 candidate.addExtraInfo(
"graFEI_nMuons_postFit", graFEI_nMuons_postFit)
413 candidate.addExtraInfo(
"graFEI_nPions_postFit", graFEI_nPions_postFit)
414 candidate.addExtraInfo(
"graFEI_nKaons_postFit", graFEI_nKaons_postFit)
415 candidate.addExtraInfo(
"graFEI_nProtons_postFit", graFEI_nProtons_postFit)
416 candidate.addExtraInfo(
"graFEI_nLeptons_postFit", graFEI_nLeptons_postFit)
417 candidate.addExtraInfo(
"graFEI_nOthers_postFit", graFEI_nOthers_postFit)
418 candidate.addExtraInfo(
"graFEI_nPredictedUnmatched", graFEI_nPredictedUnmatched)
419 candidate.addExtraInfo(
"graFEI_nPredictedUnmatched_noPhotons", graFEI_nPredictedUnmatched_noPhotons)
424 parentID = np.array([vm.evaluate(
"ancestorBIndex", p)
for p
in p_list], dtype=int)
425 b2.B2DEBUG(10,
"Ancestor true ID:\n", parentID)
428 p_indices = np.array(
430 p.getMCParticle().getArrayIndex()
if parentID[i] >= 0
else -1
431 for (i, p)
in enumerate(p_list)
435 p_masses = masses_to_classes(
438 p.getMCParticle().getPDG()
if parentID[i] >= 0
else -1
439 for (i, p)
in enumerate(p_list)
443 b2.B2DEBUG(10,
"True mass classes:\n", p_masses)
445 evt_primary = np.array(
447 p.getMCParticle().isPrimaryParticle()
450 for (i, p)
in enumerate(p_list)
453 b2.B2DEBUG(10,
"Is primary particle:\n", evt_primary)
458 B_indices = parentID[np.logical_and(evt_primary, predicted_matched)]
460 10,
"Ancestor ID of predicted matched particles:\n", B_indices
462 B_indices = list(set(B_indices))
465 graFEI_truth_perfectLCA = 0
466 graFEI_truth_isSemileptonic = -1
467 graFEI_truth_nFSP = -1
468 graFEI_truth_perfectMasses = int((predicted_masses.numpy() == p_masses).all()
470 graFEI_truth_nPhotons = (p_masses == 6).sum()
471 graFEI_truth_nElectrons = (p_masses == 1).sum()
472 graFEI_truth_nMuons = (p_masses == 2).sum()
473 graFEI_truth_nPions = (p_masses == 3).sum()
474 graFEI_truth_nKaons = (p_masses == 4).sum()
475 graFEI_truth_nProtons = (p_masses == 5).sum()
476 graFEI_truth_nOthers = (p_masses == 0).sum()
482 if self.
mc_particle ==
"Upsilon(4S):MC" and gen_list.getListSize() > 1:
484 f
"Found {gen_list.getListSize()} true Upsilon(4S) in the generated MC (??)")
486 if gen_list.getListSize() > 0:
488 for genP
in gen_list.obj():
489 mcp = genP.getMCParticle()
492 if self.
mc_particle !=
"Upsilon(4S):MC" and len(B_indices) != 1:
496 array_index = mcp.getArrayIndex()
499 if self.
mc_particle !=
"Upsilon(4S):MC" and array_index != B_indices[0]:
520 if len(leaf_hist) < 2:
524 true_LCA_square = np.zeros(
525 [len(leaf_hist), len(leaf_hist)], dtype=int
529 graFEI_truth_nFSP = len(leaf_hist)
532 for x, y
in itertools.combinations(enumerate(leaf_hist), 2):
534 i
for i
in leaf_hist[x[1]]
if i
in leaf_hist[y[1]]
536 true_LCA_square[x[0], y[0]] = levels[intersection[-1]]
537 true_LCA_square[y[0], x[0]] = levels[intersection[-1]]
540 y_leaves = list(leaf_hist.keys())
546 np.where(y_leaves == i)[0].item()
555 true_LCA_square = true_LCA_square[locs, :][:, locs]
561 vm.evaluate(
"ancestorBIndex", p) == array_index
564 )
if self.
mc_particle !=
"Upsilon(4S):MC" else evt_primary
566 primaries_from_right_cand = np.logical_and(evt_primary, x_rows)
569 true_LCA_square = np.where(
570 primaries_from_right_cand, true_LCA_square, 0
573 true_LCA_square = np.where(
574 primaries_from_right_cand[:,
None], true_LCA_square, 0
578 true_LCA_square = torch.tensor(true_LCA_square, dtype=int)
579 b2.B2DEBUG(10,
"True LCA:\n", true_LCA_square)
582 if (true_LCA_square == predicted_LCA_square).all():
583 graFEI_truth_perfectLCA = 1
584 b2.B2DEBUG(10,
"LCA perfectly reconstructed!")
587 graFEI_truth_isSemileptonic = int(semilep_flag)
590 graFEI_truth_perfectEvent = int(graFEI_truth_perfectLCA
and graFEI_truth_perfectMasses)
593 candidate.addExtraInfo(
"graFEI_truth_perfectLCA", graFEI_truth_perfectLCA)
594 candidate.addExtraInfo(
"graFEI_truth_perfectMasses", graFEI_truth_perfectMasses)
595 candidate.addExtraInfo(
"graFEI_truth_perfectEvent", graFEI_truth_perfectEvent)
596 candidate.addExtraInfo(
"graFEI_truth_isSemileptonic", graFEI_truth_isSemileptonic)
597 candidate.addExtraInfo(
"graFEI_truth_nFSP", graFEI_truth_nFSP)
598 candidate.addExtraInfo(
"graFEI_truth_nPhotons", graFEI_truth_nPhotons)
599 candidate.addExtraInfo(
"graFEI_truth_nElectrons", graFEI_truth_nElectrons)
600 candidate.addExtraInfo(
"graFEI_truth_nMuons", graFEI_truth_nMuons)
601 candidate.addExtraInfo(
"graFEI_truth_nPions", graFEI_truth_nPions)
602 candidate.addExtraInfo(
"graFEI_truth_nKaons", graFEI_truth_nKaons)
603 candidate.addExtraInfo(
"graFEI_truth_nProtons", graFEI_truth_nProtons)
604 candidate.addExtraInfo(
"graFEI_truth_nOthers", graFEI_truth_nOthers)
Base class for DBObjPtr and DBArray for easier common treatment.
static Environment & Instance()
Static method to get a reference to the Environment instance.
A (simplified) Python wrapper for StoreObjPtr.
storeTrueInfo
Figure out if we're running on data or MC.
payload_model_name
Model file name in the payload.
param_file
PyTorch parameter file path.
model
The model. The correct edge_classes value is taken from the config file.
particle_list
Input particle list.
discarded_features
Discarded node features.
def __init__(self, particle_list, cfg_path=None, param_file=None, sig_side_lcas=None, sig_side_masses=None, gpu=False, payload_config_name="graFEIConfigFile", payload_model_name="graFEIModelFile")
edge_features
Edge features.
normalize
Normalize features.
node_features
Node features.
sig_side_masses
Chosen sig-side mass hypotheses.
device
Figure out which device all this is running on: CPU or GPU.
mc_particle
Top MC particle.
cfg_path
Config yaml file path.
glob_features
Global features.
payload_config_name
Config file name in the payload.
sig_side_lcas
Chosen sig-side LCAS.