12 from collections 
import defaultdict
 
   14 from ROOT 
import Belle2
 
   15 from ROOT.Belle2 
import DBAccessorBase, DBStoreEntry
 
   17 from smartBKG 
import TOKENIZE_DICT, PREPROC_CONFIG, MODEL_CONFIG
 
   20 def check_status_bit(status_bit):
 
   22     Returns True if conditions are satisfied (not an unusable particle) 
   25         (status_bit & 1 << 4 == 0) &  
 
   26         (status_bit & 1 << 5 == 0) &  
 
   27         (status_bit & 1 << 6 == 0) &  
 
   28         (status_bit & 1 << 7 == 0)  
 
   35        1. Build a graph from an event composed of MCParticles 
   36        2. Apply the well-trained model for reweighting or sampling method to get a score 
   37        3. Execute reweighting or sampling process to get a weight 
   40        model_file(str): Path to saved model 
   41        model_config(dict): Parameters to build the model 
   42        preproc_config(dict): Parameters to provide information for preprocessing 
   43        threshold(float): Threshold for event selection using reweighting method, value *None* indicating sampling method 
   44        extra_info_var(str): Name of eventExtraInfo to save model prediction to 
   45        global_tag(str): Tag in ConditionDB where the well trained model was stored 
   46        payload(str): Payload for the well trained model in global tag 
   49        Passed or rejected according to random sampling or selection with the given threshold 
   52         Score after the NN filter indicating the probability of the event to pass is saved 
   53         under ``EventExtraInfo.extra_info_var``. 
   55         Use ``eventExtraInfo(extra_info_var)`` in ``modularAnalysis.variablesToNtuple`` or 
   56         ``additionalBranches=["EventExtraInfo"]`` in ``mdst.add_mdst_output`` to have access to the scores. 
   62         model_config=MODEL_CONFIG,
 
   63         preproc_config=PREPROC_CONFIG,
 
   65         extra_info_var="NN_prediction",
 
   66         global_tag="SmartBKG_GATGAP",
 
   67         payload="GATGAPgen.pth"
 
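   71         :param model_file: Path to saved model 
   72         :param model_config: Parameters to build the model 
   73         :param preproc_config: Parameters to provide information for preprocessing 
   74         :param threshold: Threshold for event selection using reweighting method, value *None* indicating sampling method 
   75         :param extra_info_var: Name of eventExtraInfo to save model prediction to 
   76         :param global_tag: Tag in ConditionDB where the well trained model was stored 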
   94         b2.conditions.prepend_globaltag(global_tag)
 
   98         Initialise module before any events are processed 
  101         from smartBKG.models.gatgap 
import GATGAPModel
 
  103         DEVICE = torch.device(
"cpu")
 
  107             accessor = DBAccessorBase(DBStoreEntry.c_RawFile, self.
payloadpayload, 
True)
 
  108             self.
model_filemodel_file = accessor.getFilename()
 
  109         trained_parameters = torch.load(self.
model_filemodel_file, map_location=DEVICE)
 
  113         self.
modelmodel.load_state_dict(trained_parameters[
'model_state_dict'])
 
  129         Collect information from database, build graphs, make predictions and select through sampling or threshold 
  143         for i, mcp 
in enumerate(mcplist):
 
  144             if mcp.isPrimaryParticle():
 
  146                 if not check_status_bit(mcp.getStatus()):
 
  149                 prodTime = mcp.getProductionTime()
 
  152                     root_prodTime = prodTime
 
  153                 prodTime -= root_prodTime
 
  155                 four_vec = mcp.get4Vector()
 
  156                 prod_vec = mcp.getProductionVertex()
 
  159                 self.
gen_varsgen_vars[
'prodTime'].append(prodTime)
 
  160                 self.
gen_varsgen_vars[
'energy'].append(mcp.getEnergy())
 
  161                 self.
gen_varsgen_vars[
'x'].append(prod_vec.x())
 
  162                 self.
gen_varsgen_vars[
'y'].append(prod_vec.y())
 
  163                 self.
gen_varsgen_vars[
'z'].append(prod_vec.z())
 
  164                 self.
gen_varsgen_vars[
'px'].append(four_vec.Px())
 
  165                 self.
gen_varsgen_vars[
'py'].append(four_vec.Py())
 
  166                 self.
gen_varsgen_vars[
'pz'].append(four_vec.Pz())
 
  167                 self.
gen_varsgen_vars[
'PDG'].append(
 
  168                     TOKENIZE_DICT[int(mcp.getPDG())]
 
  172                 df = pd.DataFrame(self.
gen_varsgen_vars).tail(1)
 
  173                 df.query(
" and ".join(self.
preproc_configpreproc_config[
"cuts"]), inplace=
True)
 
  175                     for values 
in self.
gen_varsgen_vars.values():
 
  180                 array_indices.append(mcp.getArrayIndex())
 
  181                 mother = mcp.getMother()
 
  183                     mother_indices.append(mother.getArrayIndex())
 
  185                     mother_indices.append(0)
 
  188             array_indices=array_indices, mother_indices=mother_indices,
 
  190             symmetrize=
True, add_self_loops=
True 
  194         pred = torch.sigmoid(self.
modelmodel(graph)).detach().numpy().squeeze()
 
  201             self.return_value(int(pred >= self.
thresholdthreshold))
 
  203             self.return_value(int(pred >= np.random.rand()))
 
  207         Map the mother indices to an enumerated list. The one-hot encoded version 
  208         of that list then corresponds to the adjacency matrix. 
  211            >>> mapped_mother_indices( 
  212            ...    [0, 1, 3, 5, 6, 7, 8, 9, 10], 
  213            ...    [0, 0, 0, 1, 1, 1, 5, 5, 7] 
  215            [0, 0, 0, 1, 1, 1, 3, 3, 5] 
  218            array_indices: list or array of indices. Each index has to be unique. 
  219            mother_indices: list or array of mother indices. 
  222            List of mapped indices 
  224         idx_dict = {v: i 
for i, v 
in enumerate(array_indices)}
 
  225         return [idx_dict[m] 
for m 
in mother_indices]
 
  227     def build_graph(self, array_indices, mother_indices, PDGs, Features,
 
  228                     symmetrize=True, add_self_loops=True):
 
  230         Build graph from preprocessed particle information 
  234         os.environ[
"DGLBACKEND"] = 
"pytorch" 
  241         dst = np.arange(len(src))
 
  242         src_new, dst_new = src, dst
 
  245                 np.concatenate([src, dst]),
 
  246                 np.concatenate([dst, src])
 
  249         src_new, dst_new = map(
 
  250             np.array, zip(*[(s, d) 
for s, d 
in zip(src_new, dst_new) 
if not s == d])
 
  254                 np.concatenate([src_new, dst]),
 
  255                 np.concatenate([dst_new, dst])
 
  257         graph = dgl.graph((src_new, dst_new))
 
  258         graph.ndata[
"x_pdg"] = torch.tensor(PDGs, dtype=torch.int32)
 
  259         graph.ndata[
"x_feature"] = torch.tensor(np.transpose(Features), dtype=torch.float32)
 
A (simplified) Python wrapper for StoreArray.
A (simplified) Python wrapper for StoreObjPtr.
def build_graph(self, array_indices, mother_indices, PDGs, Features, symmetrize=True, add_self_loops=True)
model
model with trained parameters
out_features
node features
gen_vars
generated variables
def __init__(self, model_file=None, model_config=MODEL_CONFIG, preproc_config=PREPROC_CONFIG, threshold=None, extra_info_var="NN_prediction", global_tag="SmartBKG_GATGAP", payload="GATGAPgen.pth")
def mapped_mother_indices(self, array_indices, mother_indices)
EventExtraInfo
StoreArray to save weights to.