# Source code for grafei.model.geometric_datasets

##########################################################################
# basf2 (Belle II Analysis Software Framework)                           #
# Author: The Belle II Collaboration                                     #
#                                                                        #
# See git log for contributors and copyright holders.                    #
# This file is licensed under LGPL-3.0, see LICENSE.md.                  #
##########################################################################


import itertools
from pathlib import Path
import numpy as np
import torch
from .tree_utils import masses_to_classes
from .dataset_utils import populate_avail_samples, preload_root_data
from .edge_features import compute_edge_features
from .normalize_features import normalize_features
from torch_geometric.data import Data, InMemoryDataset
import uproot


def _preload(self):
    """
    Locates the input ROOT files, resolves the feature lists and preloads the data.

    The tree layout (available ``feat_*`` branches and the ``isB`` flag) is read
    from the first file only. On return, ``self`` holds ``x_files``, ``features``,
    ``discarded``, ``B_reco``, ``edge_features``, ``global_features``, ``x``, ``y``
    and ``avail_samples``.

    ``self.edge_features`` and ``self.global_features`` are replaced by their
    prefixed (``edge_`` / ``glob_``) versions, so calling this twice on the same
    object would prefix them twice.

    Returns:
        int: Number of available (possibly randomly sub-sampled) training samples.

    Raises:
        RuntimeError: If no ``.root`` file is found under ``self.root``.
    """

    # ``self.root`` may be a str (PyG's Dataset.__init__ re-assigns it), so wrap it
    root = Path(self.root)

    #: Going to use x_files as an array that always exists
    self.x_files = sorted(root.glob("**/*.root"))

    # Select the first N files (useful for testing)
    if self.n_files is not None:
        if self.n_files > len(self.x_files):
            print(
                f"WARNING: n_files specified ({self.n_files}) is greater than files in path given ({len(self.x_files)})"
            )

        self.x_files = self.x_files[: self.n_files]

    if len(self.x_files) == 0:
        raise RuntimeError(f"No files found in {self.root}")

    # Read the node feature branches and the reconstruction mode from the first file
    with uproot.open(self.x_files[0])["Tree"] as t:
        file_features = [f for f in t.keys() if f.startswith("feat_")]
        #: B reco flag
        self.B_reco = int(t["isB"].array(library="np")[0])
        assert self.B_reco in [0, 1, 2], "B_reco should be 0, 1 or 2, something went wrong"

    requested = set(self.node_features)
    available = set(file_features)

    #: Discarded features (present in the file but not requested)
    self.discarded = [f for f in file_features if f[len("feat_"):] not in requested]
    #: Features, in the order requested by the user
    self.features = [
        f"feat_{f}" for f in self.node_features if f"feat_{f}" in available
    ]

    # Requested features that the file does not provide would otherwise vanish silently
    missing = [f for f in self.node_features if f"feat_{f}" not in available]
    if missing:
        print(f"WARNING: requested node features not found in input files: {missing}")

    print(f"Input node features: {self.features}")
    print(f"Discarded node features: {self.discarded}")

    #: Edge features (``None`` treated as "no edge features", like global features)
    self.edge_features = [f"edge_{f}" for f in self.edge_features] if self.edge_features else []
    #: Global features
    self.global_features = [f"glob_{f}" for f in self.global_features] if self.global_features else []
    print(f"Input edge features: {self.edge_features}")
    print(f"Input global features: {self.global_features}")

    #: Preload data
    self.x, self.y = preload_root_data(
        self.x_files,
        self.features,
        self.discarded,
    )

    #: Need to populate a list of available training samples
    self.avail_samples = populate_avail_samples(
        self.x,
        self.y,
        self.B_reco,
    )

    # Select a subset of available samples if requested
    n_avail = len(self.avail_samples)
    if self.samples and self.samples < n_avail:
        print(f"Selecting random subset of {self.samples} samples")
        self.avail_samples = [
            self.avail_samples[i]
            for i in np.random.choice(n_avail, self.samples, replace=False)
        ]
    elif self.samples and self.samples > n_avail:
        # Equal counts simply keep every sample, so only warn on a strict excess
        print(
            f"WARNING: No. samples specified ({self.samples}) exceeds number of samples loaded ({len(self.avail_samples)})"
        )

    return len(self.avail_samples)


def _process_graph(self, idx):
    """
    Actually builds the graph object for one training sample.

    ``self.avail_samples[idx]`` is unpacked as ``(file_id, evt, p_index)``.
    The nodes are the reconstructed particles selected for this sample:
    when ``self.B_reco`` is 0, every particle of the event; otherwise, the
    particles with ``b_index == p_index`` plus a random subset of the unmatched
    ones (``b_index == -1``). Edges connect every ordered pair of distinct nodes.

    Args:
        idx (int): Index of training example to be processed.

    Returns:
        torch_geometric.data.Data: Graph with node features ``x`` [N, F_x],
        ``edge_index`` [2, E] with E = N*(N-1) (also for N < 2), edge features
        ``edge_attr``, global features ``u`` and targets ``x_y`` [N],
        ``edge_y`` [E] and ``u_y`` [1, 1].
    """

    file_id, evt, p_index = self.avail_samples[idx]

    x_item = self.x[file_id]
    y_item = self.y[file_id][p_index]

    evt_b_index = x_item["b_index"][evt]
    evt_leaves = x_item["leaves"][evt]
    evt_primary = x_item["primary"][evt]

    y_leaves = y_item["LCA_leaves"][evt]
    # Use this to correctly reshape LCA (might be able to just use shape of y_leaves?)
    n_LCA = y_item["n_LCA"][evt]

    # Boolean mask of the rows of X to fetch for this sample
    x_rows = (evt_b_index != -1) if not self.B_reco else evt_b_index == int(p_index)

    # Find the unmatched particles
    unmatched_rows = evt_b_index == -1

    if np.any(unmatched_rows) and self.B_reco:
        # Random boolean array, one entry per leaf; AND-ing it with the unmatched
        # mask keeps a random subset of the unmatched leaves
        rand_mask = np.random.choice(a=[False, True], size=unmatched_rows.size)
        unmatched_rows = np.logical_and(unmatched_rows, rand_mask)

    # Add the unmatched rows to the current decay's rows
    x_rows = np.logical_or(x_rows, unmatched_rows)
    n_nodes = int(x_rows.sum())

    # Node features (kept and discarded), restricted to the selected rows
    x = np.empty((n_nodes, len(self.features)))
    x_dis = np.empty((n_nodes, len(self.discarded)))
    for col, feat in enumerate(self.features):
        x[:, col] = x_item["features"][feat][evt][x_rows]
    for col, feat in enumerate(self.discarded):
        x_dis[:, col] = x_item["discarded"][feat][evt][x_rows]

    # Edge features are derived from all node features (kept and discarded).
    # Called unconditionally, also when ``self.edge_features`` is empty, so the
    # edge_attr layout is whatever compute_edge_features returns in that case.
    x_edges = compute_edge_features(
        self.edge_features,
        self.features + self.discarded,
        np.concatenate([x, x_dis], axis=1),
    )

    # Global features: shape [1, F_g]; the stored keys carry a ``_{p_index}`` suffix
    if self.global_features:
        x_global = np.array(
            [[x_item["global"][f"{feat}_{p_index}"][evt] for feat in self.global_features]]
        )
    else:
        x_global = []

    x_leaves = evt_leaves[x_rows]

    # Set nans to zero, this is a surrogate value, may change in future
    np.nan_to_num(x, copy=False)
    np.nan_to_num(x_edges, copy=False)
    np.nan_to_num(x_global, copy=False)

    # Normalize any features that should be
    if self.normalize is not None:
        normalize_features(
            self.normalize,
            self.features,
            x,
            self.edge_features,
            x_edges,
            self.global_features,
            x_global,
        )

    # Position of each selected leaf in the LCA ordering.
    # Secondaries aren't in the LCA leaves list so they get a dummy 0.
    # dtype=int keeps the array usable as an index even when it is empty.
    locs = np.array(
        [
            np.where(y_leaves == leaf)[0].item() if leaf in y_leaves else 0
            for leaf in x_leaves
        ],
        dtype=int,
    )

    # Full LCA matrix of the decay
    y_edge = y_item["LCA"][evt].reshape((n_LCA, n_LCA)).astype(int)
    # Node targets from the true mcPDG of the FSPs
    y_mass = masses_to_classes(x_item["mc_pdg"][evt][x_rows])

    # Reorder rows/cols into node order (locs is intentionally unsorted, so it is
    # applied as a fancy index); dummy locs insert duplicate rows/cols
    y_edge = y_edge[locs, :][:, locs]
    # Set everything that's not primary (unmatched and secondaries) to 0.
    # A 1-D mask broadcasts along the last axis, so it acts on columns;
    # the [:, None] version acts on rows.
    primary = evt_primary[x_rows]
    y_edge = np.where(primary, y_edge, 0)  # Set the columns
    y_edge = np.where(primary[:, None], y_edge, 0)  # Set the rows

    # Set diagonal to -1 (only off-diagonal entries are used as targets below)
    np.fill_diagonal(y_edge, -1)

    # Fully connected graph without self-loops. Both tensors below enumerate the
    # off-diagonal entries in row-major order, so they line up element by element.
    off_diag = ~np.eye(n_nodes, dtype=bool)

    # Target edge tensor: shape [E]
    edge_y = torch.tensor(y_edge[off_diag], dtype=torch.long)
    # Edge indices: shape [2, E], which stays [2, 0] for graphs with < 2 nodes
    edge_index = torch.tensor(np.stack(np.nonzero(off_diag)), dtype=torch.long)

    # Target global tensor: shape [1, 1]
    u_y = torch.tensor([[1]], dtype=torch.float)

    # Target node tensor: shape [N]
    x_y = torch.tensor(y_mass, dtype=torch.long)

    return Data(
        x=torch.tensor(x, dtype=torch.float),
        edge_index=edge_index,
        edge_attr=torch.tensor(x_edges, dtype=torch.float),
        u=torch.tensor(x_global, dtype=torch.float),
        x_y=x_y,
        edge_y=edge_y,
        u_y=u_y,
    )


class GraphDataSet(InMemoryDataset):
    """
    Dataset handler for converting Belle II data to PyTorch geometric InMemoryDataset.

    The ROOT format expects the tree in every file to be named ``Tree``,
    and all node features to have the format ``feat_FEATNAME``.

    .. note:: This expects the files under root to have the structure ``root/**/<file_name>.root``
        where the root path is different for train and val.
        The ``**/`` is to handle subdirectories, e.g. ``sub00``.

    Args:
        root (str or pathlib.Path): Path to ROOT files.
        n_files (int): Load only ``n_files`` files.
        samples (int): Load only ``samples`` events.
        features (list): List of node features names.
        edge_features (list): List of edge features names.
        global_features (list): List of global features names.
        normalize: Normalization settings forwarded to ``normalize_features``;
            ``None`` disables normalization.
    """

    def __init__(
        self,
        root,
        n_files=None,
        samples=None,
        features=None,
        edge_features=None,
        global_features=None,
        normalize=None,
        **kwargs,
    ):
        """
        Initialization.

        Any ``root/processed/*.pt`` file is deleted first, so the graphs are
        always rebuilt from the ROOT files with the current arguments.
        Extra ``kwargs`` are accepted and ignored.
        """
        # ``None`` defaults avoid sharing one mutable list between instances
        features = [] if features is None else features
        edge_features = [] if edge_features is None else edge_features
        global_features = [] if global_features is None else global_features

        assert isinstance(
            features, list
        ), f'Argument "features" must be a list and not {type(features)}'
        assert len(features) > 0, "You need to use at least one node feature"

        #: Root path
        self.root = Path(root)

        #: Normalize
        self.normalize = normalize

        #: Number of files
        self.n_files = n_files
        #: Node features
        self.node_features = features
        #: Edge features
        self.edge_features = edge_features
        #: Global features
        self.global_features = global_features
        #: Samples
        self.samples = samples

        # Delete stale processed files so that ``process`` always runs with the
        # current arguments instead of reusing a cached dataset
        file_path = Path(self.root, "processed")
        for f in file_path.glob("*.pt"):
            f.unlink(missing_ok=True)

        # Needs to be called after having assigned all attributes.
        # Pass the Path: Dataset.__init__ re-assigns ``self.root`` from this
        # argument, and a str would reach ``process`` as a str.
        super().__init__(self.root, None, None, None)

        try:
            # The file was just written by ``process``, so full unpickling is safe;
            # torch >= 2.6 defaults to weights_only=True, which rejects Data objects.
            loaded = torch.load(self.processed_paths[0], weights_only=False)
        except TypeError:
            # torch < 1.13 has no ``weights_only`` argument
            loaded = torch.load(self.processed_paths[0])

        #: Data and Slices
        self.data, self.slices = loaded

    @property
    def processed_file_names(self):
        """
        Name of processed file.
        """
        return ["processed_data.pt"]

    def process(self):
        """
        Processes the data to create graph objects and stores them in ``root/processed/processed_data.pt``
        where the root path is different for train and val.

        Called internally by PyTorch Geometric.
        """
        num_samples = _preload(self)
        data_list = [_process_graph(self, i) for i in range(num_samples)]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        # Release the preloaded arrays; only the collated file on disk is needed afterwards
        del self.x, self.y, self.avail_samples, data_list, data, slices