import awkward as ak
import numpy as np
import torch
import dgl

from smartBKG import DEFAULT_NODE_FEATURES
 
def get_batched_graph(array, node_features, index_column=("particles", "motherIndex")):
    """
    Generate a batched DGL graph from an awkward array.

    Args:
        array (awkward array): Awkward array containing particle information.
        node_features (dict): Mapping of field names to columns in the array.
        index_column (tuple): Nested column path of the mother indices used to build the edges.

    Returns:
        dgl.DGLGraph: Batched DGL graph.

    Note:
        This function assumes the input array has a nested structure representing particle relations.
    """
    mother_index = array[index_column]
    array_index = ak.local_index(mother_index, axis=1)

    # Build edges in both directions between each particle and its mother
    src, dst = (
        ak.concatenate([mother_index, array_index], axis=1),
        ak.concatenate([array_index, mother_index], axis=1),
    )
    # Drop edges to removed mothers (index -1) and accidental self-references
    mask = (src != -1) & (dst != -1) & (src != dst)
    src = src[mask]
    dst = dst[mask]
    # Add a self-loop for every node
    src, dst = (
        ak.concatenate([src, array_index], axis=1),
        ak.concatenate([dst, array_index], axis=1),
    )
    # Offsets shift per-event particle indices to global node ids in the batched graph
    offsets = np.append(0, np.cumsum(ak.num(mother_index).to_numpy())[:-1])
 
    # Flatten to global node indices for the whole batch
    src_flat, dst_flat = (
        torch.tensor(ak.to_numpy(ak.flatten(src + offsets))),
        torch.tensor(ak.to_numpy(ak.flatten(dst + offsets))),
    )

    batched = dgl.graph((src_flat, dst_flat))
    batched.set_batch_num_nodes(torch.tensor(ak.num(mother_index).to_numpy()))
    batched.set_batch_num_edges(torch.tensor(ak.num(src).to_numpy()))
 
    # Attach node features, stacking multiple columns into a 2D feature matrix
    for field, columns in node_features.items():
        feats = array[columns]
        if len(feats.fields) == 0:
            # A single column: just flatten it
            flat_feats = ak.to_numpy(ak.flatten(feats), allow_missing=False)
        else:
            # A record of several columns: flatten each and stack them as feature columns
            flat_feats = np.stack(
                [
                    ak.to_numpy(ak.flatten(x), allow_missing=False)
                    for x in ak.unzip(feats)
                ],
                axis=1,
            )
        batched.ndata[field] = torch.tensor(flat_feats, dtype=torch.float32)

    return batched
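
# Usage sketch (illustrative, not part of the original module): build a tiny
# two-event awkward array with the nested ("particles", "motherIndex") layout
# assumed by the default index_column and batch it into a single DGL graph.
# The particle fields "PDG" and "energy" are placeholder names chosen for this
# sketch, not fields required by smartBKG.
def _demo_get_batched_graph():
    """Minimal example of get_batched_graph on a hand-made toy array."""
    particles = ak.zip({
        "motherIndex": [[-1, 0, 0], [-1, 0]],   # -1 marks particles without a stored mother
        "PDG": [[300553, 511, -511], [300553, 22]],
        "energy": [[10.58, 5.29, 5.29], [10.58, 1.0]],
    })
    toy = ak.zip({"particles": particles}, depth_limit=1)
    # Each entry of node_features maps an ndata key to a column path in the array
    graph = get_batched_graph(
        toy,
        node_features={"energy": ("particles", "energy")},
    )
    # 3 + 2 nodes; mother-daughter edges in both directions plus one self-loop per node
    print(graph.num_nodes(), graph.num_edges(), graph.ndata["energy"].shape)
    return graph
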
 
class ArrayDataset(torch.utils.data.IterableDataset):
    """
    Dataset initialized from a pre-processed awkward array.

    Use `torch.utils.data.DataLoader` with `collate_fn=lambda x: x[0]`
    and `batch_size=1` to iterate through it.

    Yields a tuple of a batched dgl graph and labels. Optionally also weights if
    `weighted=True`. This requires a column `weight` in the array.
    """
 
    def __init__(self, array, batch_size=1024, shuffle=True, seed=None, weighted=False):
        """
        Initialize the ArrayDataset for PyTorch training and inference.

        :param array: Awkward array containing the dataset.
        :param batch_size (int): Batch size for the iterable dataset.
        :param shuffle (bool): Whether to shuffle the data.
        :param seed: Random seed for shuffling, shared consistently by all workers.
        :param weighted (bool): Whether the dataset includes weights.
        """
        super().__init__()
        self.array = array
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.weighted = weighted
        self.seed = seed if seed is not None else np.random.SeedSequence().entropy
 
    def __len__(self):
        """
        Get the number of batches.

        Returns:
            int: Number of batches.
        """
        return int(np.ceil(len(self.array) / self.batch_size))

    def maybe_permuted(self, array):
        """
        Possibly permute the array based on the shuffle parameter.

        Args:
            array (awkward array): Input array.

        Returns:
            array: Permuted or original array.
        """
        if not self.shuffle:
            return array
        perm = np.random.default_rng(self.seed).permutation(len(array))
        return self.array[perm]

    def to_tensor(self, array):
        """
        Convert an awkward array to a torch tensor.

        Args:
            array (awkward array): Input awkward array.

        Returns:
            torch.Tensor: Converted tensor.
        """
        return torch.tensor(
            ak.to_numpy(array, allow_missing=False),
            dtype=torch.float32,
        )

    def __iter__(self):
        """
        Iterate over batches with changing random seeds.

        Yields:
            tuple: Batched dgl graph, labels, and optionally weights.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
        else:
            # Single-process data loading
            num_workers = 1
            worker_id = 0
        array = self.maybe_permuted(self.array)
        # Split the batch start indices evenly across the workers
        starts = np.arange(0, len(self.array), self.batch_size)
        per_worker = np.array_split(starts, num_workers)
        for start in per_worker[worker_id]:
            ak_slice = array[start: start + self.batch_size]
            output = [
                get_batched_graph(ak_slice, DEFAULT_NODE_FEATURES),
                # labels are read from the `label` column of the array
                self.to_tensor(ak_slice.label),
            ]
            if self.weighted:
                output.append(self.to_tensor(ak_slice.weight))
            yield tuple(output)
        # Change the seed so the next epoch is shuffled differently
        self.seed += 1
 
    def __getitem__(self, slicer):
        """
        Get a single instance or a new ArrayDataset for a slice.

        Args:
            slicer (int or slice): Index or slice.

        Returns:
            ArrayDataset: New ArrayDataset instance, or a single batch for an
                integer index.
        """
        kwargs = dict(
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            seed=self.seed,
            weighted=self.weighted,
        )
        # Respect the shuffle setting also for direct indexing
        array = self.maybe_permuted(self.array)
        if not isinstance(slicer, int):
            return ArrayDataset(array[slicer], **kwargs)
        # For a single integer index, return one ready-made instance
        slicer = slice(slicer, slicer + 1)
        kwargs["batch_size"] = 1
        return next(iter(ArrayDataset(array[slicer], **kwargs)))
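
# Usage sketch (illustrative, not part of the original module): how the class
# docstring's DataLoader recipe is meant to be used. `preprocessed_array` is a
# placeholder for an awkward array produced by the smartBKG preprocessing; it is
# assumed to carry the particle columns expected by DEFAULT_NODE_FEATURES, a
# `label` column and, if weighted=True, a `weight` column.
def _demo_array_dataset(preprocessed_array):
    """Minimal example: batched iteration and slicing of an ArrayDataset."""
    from torch.utils.data import DataLoader

    dataset = ArrayDataset(preprocessed_array, batch_size=512, shuffle=True)
    # batch_size=1 and the unwrapping collate_fn are required because the
    # dataset already yields fully batched graphs.
    loader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
    for graph, labels in loader:
        print(graph.num_nodes(), labels.shape)

    # Integer indexing returns a single (graph, labels) instance,
    # while slicing returns a new ArrayDataset over the selected events.
    single = dataset[0]
    subset = dataset[:100]
    return single, subset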
 