# Source code for smartBKG.NN_filter_module

##########################################################################
# basf2 (Belle II Analysis Software Framework)                           #
# Author: The Belle II Collaboration                                     #
#                                                                        #
# See git log for contributors and copyright holders.                    #
# This file is licensed under LGPL-3.0, see LICENSE.md.                  #
##########################################################################
import os
import numpy as np
import pandas as pd

from collections import defaultdict
import basf2 as b2
from ROOT import Belle2
from ROOT.Belle2 import DBAccessorBase, DBStoreEntry

from smartBKG import TOKENIZE_DICT, PREPROC_CONFIG, MODEL_CONFIG


def check_status_bit(status_bit):
    """
    Tell whether a status word belongs to a usable particle.

    A particle counts as usable when none of the bits 4 to 7 of
    ``status_bit`` is set (IsVirtual, Initial, ISRPhoton, FSRPhoton).

    :param status_bit: integer status word to test
    :return: True if all four flag bits are cleared, False otherwise
    """
    # Combine the four rejection flags into one mask and test them at once.
    unusable_mask = (
        (1 << 4)  # IsVirtual
        | (1 << 5)  # Initial
        | (1 << 6)  # ISRPhoton
        | (1 << 7)  # FSRPhoton
    )
    return (status_bit & unusable_mask) == 0


class NNFilterModule(b2.Module):
    """
    Goals:
      1. Build a graph from an event composed of MCParticles
      2. Apply the well-trained model for reweighting or sampling method to get a score
      3. Execute reweighting or sampling process to get a weight

    Arguments:
        model_file(str): Path to saved model
        model_config(dict): Parameters to build the model
        preproc_config(dict): Parameters to provide information for preprocessing
        threshold(float): Threshold for event selection using reweighting method,
            value *None* indicating sampling method
        extra_info_var(str): Name of eventExtraInfo to save model prediction to
        global_tag(str): Tag in ConditionDB where the well trained model was stored
        payload(str): Payload for the well trained model in global tag

    Returns:
        Pass or rejected according to random sampling or selection with the given threshold

    Note:
        Score after the NN filter indicating the probability of the event to pass is saved
        under ``EventExtraInfo.extra_info_var``.

        Use ``eventExtraInfo(extra_info_var)`` in ``modularAnalysis.variablesToNtuple`` or
        ``additionalBranches=["EventExtraInfo"]`` in ``mdst.add_mdst_output`` to have access to the scores.
    """

    def __init__(
        self,
        model_file=None,
        model_config=MODEL_CONFIG,
        preproc_config=PREPROC_CONFIG,
        threshold=None,
        extra_info_var="NN_prediction",
        global_tag="SmartBKG_GATGAP",
        payload="GATGAPgen.pth"
    ):
        """
        Initialise the class.

        :param model_file: Path to the saved model; when empty the file is
            resolved from the conditions database in ``initialize``
        :param model_config: Keyword arguments used to build the model
        :param preproc_config: Preprocessing settings; the keys ``features``
            and ``cuts`` are read by this module
        :param threshold: Selection threshold on the model score, or *None*
            to select events by random sampling
        :param extra_info_var: Name of the eventExtraInfo entry receiving the score
        :param global_tag: Global tag prepended to the conditions database setup
        :param payload: Name of the payload holding the trained model
        """
        super().__init__()
        #: Path to the saved model (resolved from the database if empty)
        self.model_file = model_file
        #: Keyword arguments used to build the model
        self.model_config = model_config
        #: Preprocessing settings (node features and particle-level cuts)
        self.preproc_config = preproc_config
        #: Selection threshold, *None* meaning random sampling
        self.threshold = threshold
        #: Name of the eventExtraInfo entry the score is written to
        self.extra_info_var = extra_info_var
        #: Payload name of the trained model in the conditions database
        self.payload = payload

        # set additional database conditions for trained neural network
        b2.conditions.prepend_globaltag(global_tag)

    def initialize(self):
        """
        Initialise module before any events are processed
        """
        # DGL picks its backend when it is first imported, so the variable is
        # set before the model import (which presumably pulls in DGL).
        os.environ["DGLBACKEND"] = "pytorch"
        import torch
        from smartBKG.models.gatgap import GATGAPModel

        DEVICE = torch.device("cpu")

        # read trained model parameters from the database if no file was given
        if not self.model_file:
            accessor = DBAccessorBase(DBStoreEntry.c_RawFile, self.payload, True)
            self.model_file = accessor.getFilename()
        trained_parameters = torch.load(self.model_file, map_location=DEVICE)

        #: model with trained parameters
        # NOTE(review): the model is never switched to eval mode; confirm that
        # GATGAPModel has no dropout / batch-norm layers that would need it.
        self.model = GATGAPModel(**self.model_config)
        self.model.load_state_dict(trained_parameters['model_state_dict'])

        #: StoreObj to save the model prediction to
        self.EventExtraInfo = Belle2.PyStoreObj('EventExtraInfo')
        if not self.EventExtraInfo.isValid():
            self.EventExtraInfo.registerInDataStore()

        #: generated variables
        self.gen_vars = defaultdict(list)

        #: node features (everything configured except the PDG token)
        # Built as a new list so that the (possibly shared, module-level)
        # configuration is not modified in place.
        self.out_features = [
            feature for feature in self.preproc_config['features'] if feature != 'PDG'
        ]

    def event(self):
        """
        Collect information from database, build graphs,
        make predictions and select through sampling or threshold
        """
        import torch

        # Initialize for every event
        self.gen_vars.clear()

        # Need to create the eventExtraInfo entry for each event
        self.EventExtraInfo.create()

        mcplist = Belle2.PyStoreArray("MCParticles")

        # The cut expression is identical for every particle, build it once.
        cut_expression = " and ".join(self.preproc_config["cuts"])

        array_indices = []
        mother_indices = []

        for i, mcp in enumerate(mcplist):
            # record the production time of root particle for the correction of jitter;
            # done before any filtering so it is defined even if the root is skipped
            if i == 0:
                root_prodTime = mcp.getProductionTime()

            if not mcp.isPrimaryParticle():
                continue
            # Check mc particle is useable
            if not check_status_bit(mcp.getStatus()):
                continue

            prodTime = mcp.getProductionTime() - root_prodTime
            four_vec = mcp.get4Vector()
            prod_vec = mcp.getProductionVertex()

            # build generated variables as node features
            self.gen_vars['prodTime'].append(prodTime)
            self.gen_vars['energy'].append(mcp.getEnergy())
            self.gen_vars['x'].append(prod_vec.x())
            self.gen_vars['y'].append(prod_vec.y())
            self.gen_vars['z'].append(prod_vec.z())
            self.gen_vars['px'].append(four_vec.Px())
            self.gen_vars['py'].append(four_vec.Py())
            self.gen_vars['pz'].append(four_vec.Pz())
            self.gen_vars['PDG'].append(
                TOKENIZE_DICT[int(mcp.getPDG())]
            )

            # Particle level cutting: evaluate the cuts on a one-row frame of
            # the particle just added instead of rebuilding the whole table.
            if cut_expression:
                latest = pd.DataFrame(
                    {key: [values[-1]] for key, values in self.gen_vars.items()}
                )
                if latest.query(cut_expression).empty:
                    # particle rejected: drop the values appended above
                    for values in self.gen_vars.values():
                        values.pop()
                    continue

            # Collect indices for graph
            array_indices.append(mcp.getArrayIndex())
            mother = mcp.getMother()
            # particles without mother are attached to array index 0
            mother_indices.append(mother.getArrayIndex() if mother else 0)

        graph = self.build_graph(
            array_indices=array_indices,
            mother_indices=mother_indices,
            PDGs=self.gen_vars['PDG'],
            Features=[self.gen_vars[key] for key in self.out_features],
            symmetrize=True,
            add_self_loops=True
        )

        # Output pass probability (no gradients are needed for inference)
        with torch.no_grad():
            pred = torch.sigmoid(self.model(graph)).detach().numpy().squeeze()

        # Save the pass probability to EventExtraInfo
        self.EventExtraInfo.addExtraInfo(self.extra_info_var, pred)

        # Module returns bool of whether prediction passes threshold for use in basf2 path flow control.
        # Only ``None`` selects the sampling method, so a threshold of 0 is honoured.
        if self.threshold is not None:
            self.return_value(int(pred >= self.threshold))
        else:
            self.return_value(int(pred >= np.random.rand()))

    def mapped_mother_indices(self, array_indices, mother_indices):
        """
        Map the mother indices to an enumerated list. The one-hot encoded version
        of that list then corresponds to the adjacency matrix.

        Example:
            >>> mapped_mother_indices(
            ...     [0, 1, 3, 5, 6, 7, 8, 9, 10],
            ...     [0, 0, 0, 1, 1, 1, 5, 5, 7]
            ... )
            [0, 0, 0, 1, 1, 1, 3, 3, 5]

        Args:
            array_indices: list or array of indices. Each index has to be unique.
            mother_indices: list or array of mother indices.

        Returns:
            List of mapped indices

        Raises:
            KeyError: if a mother index does not appear in ``array_indices``
        """
        idx_dict = {v: i for i, v in enumerate(array_indices)}
        return [idx_dict[m] for m in mother_indices]

    def build_graph(self, array_indices, mother_indices, PDGs, Features,
                    symmetrize=True, add_self_loops=True):
        """
        Build graph from preprocessed particle information

        Args:
            array_indices: array indices of the selected particles.
            mother_indices: array index of the mother of each selected particle.
            PDGs: tokenized PDG code per particle, stored as ``x_pdg``.
            Features: list of per-feature value lists, transposed and stored as ``x_feature``.
            symmetrize: add the reversed copy of every mother-daughter edge.
            add_self_loops: add one self loop per particle.

        Returns:
            The DGL graph with node data ``x_pdg`` and ``x_feature``
        """
        # The backend has to be chosen before DGL is imported for the first time.
        os.environ["DGLBACKEND"] = "pytorch"
        import torch
        import dgl

        # Build adjacency mapping
        adjacency = self.mapped_mother_indices(array_indices, mother_indices)

        # Build graph: one edge from every mother (src) to its daughter (dst)
        src = np.asarray(adjacency, dtype=np.int64)
        dst = np.arange(len(src), dtype=np.int64)
        src_new, dst_new = src, dst
        if symmetrize:
            src_new, dst_new = (
                np.concatenate([src, dst]),
                np.concatenate([dst, src])
            )

        # remove self-loops (the Y(4S)) to avoid duplicated self loops;
        # a mask also works when no edge is left (e.g. a single particle)
        keep = src_new != dst_new
        src_new, dst_new = src_new[keep], dst_new[keep]

        if add_self_loops:
            src_new, dst_new = (
                np.concatenate([src_new, dst]),
                np.concatenate([dst_new, dst])
            )

        # Fix the node count so it always matches the number of feature rows.
        graph = dgl.graph((src_new, dst_new), num_nodes=len(dst))
        graph.ndata["x_pdg"] = torch.tensor(PDGs, dtype=torch.int32)
        graph.ndata["x_feature"] = torch.tensor(np.transpose(Features), dtype=torch.float32)

        return graph