import awkward as ak
import dgl
import numpy as np
import torch

from smartBKG import DEFAULT_NODE_FEATURES


def get_batched_graph(array, node_features, index_column=("particles", "motherIndex")):
    """
    Generate a batched DGL graph from an awkward array.

    Args:
        array (awkward array): Awkward array containing particle information.
        node_features (dict): Mapping of field names to columns in the array.
        index_column (tuple): Nested column names pointing to the mother indices.

    Returns:
        dgl.DGLGraph: Batched DGL graph.

    Note:
        This function assumes the input array has a nested structure representing
        particle relations.
    """
    mother_index = array[index_column]
    array_index = ak.local_index(mother_index, axis=1)

    # edges in both directions: mother -> daughter and daughter -> mother
    src, dst = (
        ak.concatenate([mother_index, array_index], axis=1),
        ak.concatenate([array_index, mother_index], axis=1),
    )

    # drop edges to missing mothers (index -1) and self-references, then add self-loops
    mask = (src != -1) & (dst != -1) & (src != dst)
    src, dst = src[mask], dst[mask]
    src, dst = (
        ak.concatenate([src, array_index], axis=1),
        ak.concatenate([dst, array_index], axis=1),
    )

    # offsets turn per-event particle indices into global node indices of the batched graph
    offsets = np.append(0, np.cumsum(ak.num(mother_index).to_numpy())[:-1])
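    # Example of the arithmetic (numbers assumed for illustration): for a batch of
    # two events with 3 and 2 particles, ak.num(mother_index) == [3, 2], so
    # offsets == [0, 3]; particle 1 of the second event becomes node 3 + 1 = 4
    # of the batched graph.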

    src_flat, dst_flat = (
        torch.tensor(ak.to_numpy(ak.flatten(src + offsets))),
        torch.tensor(ak.to_numpy(ak.flatten(dst + offsets))),
    )

    batched = dgl.graph((src_flat, dst_flat))
    batched.set_batch_num_nodes(torch.tensor(ak.num(mother_index).to_numpy()))
    batched.set_batch_num_edges(torch.tensor(ak.num(src).to_numpy()))

    # attach node features, stacking multi-column features into one 2D array
    for field, columns in node_features.items():
        feats = array[columns]
        if len(feats.fields) == 0:
            flat_feats = ak.to_numpy(ak.flatten(feats), allow_missing=False)
        else:
            flat_feats = np.stack(
                [
                    ak.to_numpy(ak.flatten(x), allow_missing=False)
                    for x in ak.unzip(feats)
                ],
                axis=-1,
            )
        batched.ndata[field] = torch.tensor(flat_feats, dtype=torch.float32)

    return batched
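

# Illustrative sketch of how `get_batched_graph` can be called. The toy events,
# the "PDG"/"energy" field names and the `node_features` mapping below are made
# up for this example; the real pipeline passes preprocessed arrays together
# with `DEFAULT_NODE_FEATURES`.
def _demo_get_batched_graph():
    # two toy events with 3 and 2 particles; motherIndex == -1 marks particles
    # whose mother is not stored
    particles = ak.zip(
        {
            "motherIndex": ak.Array([[-1, 0, 0], [-1, 0]]),
            "PDG": ak.Array([[300553, 511, -511], [511, 13]]),
            "energy": ak.Array([[10.58, 5.29, 5.29], [5.29, 1.2]]),
        }
    )
    events = ak.zip({"particles": particles}, depth_limit=1)

    graph = get_batched_graph(
        events,
        node_features={"x": ("particles", ["PDG", "energy"])},
    )
    print(graph.batch_size)        # 2 events merged into one batched graph
    print(graph.ndata["x"].shape)  # torch.Size([5, 2]): one row per particle
    return graph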


class ArrayDataset(torch.utils.data.IterableDataset):
    """
    Dataset initialized from a pre-processed awkward array.

    Use `torch.utils.data.DataLoader` with `collate_fn=lambda x: x[0]`
    and `batch_size=1` to iterate through it (see the usage sketch at the
    end of this module).

    Yields a tuple of a batched dgl graph and labels. Optionally also weights if
    `weighted=True`, which requires a column `weight` in the array.
    """

    def __init__(self, array, batch_size=1024, shuffle=True, seed=None, weighted=False):
        """
        Initialize the ArrayDataset for PyTorch training and inference.

        :param array: Awkward array containing the dataset.
        :param batch_size (int): Batch size for the iterable dataset.
        :param shuffle (bool): Whether to shuffle the data.
        :param seed: Random seed for shuffling, consistent across all workers.
        :param weighted (bool): Whether the dataset includes weights.
        """
        self.array = array
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.weighted = weighted
        # fall back to a fresh entropy value so that all workers share one seed
        self.seed = seed if seed is not None else np.random.SeedSequence().entropy

    def __len__(self):
        """
        Get the number of batches.

        Returns:
            int: Number of batches.
        """
        return int(np.ceil(len(self.array) / self.batch_size))

    def maybe_permuted(self, array):
        """
        Possibly permute the array based on the shuffle parameter.

        Args:
            array (awkward array): Input array.

        Returns:
            array: Permuted or original array.
        """
        if not self.shuffle:
            return array
        perm = np.random.default_rng(self.seed).permutation(len(array))
        return self.array[perm]

    @staticmethod
    def to_tensor(array):
        """
        Convert an awkward array to a torch tensor.

        Args:
            array (awkward array): Input awkward array.

        Returns:
            torch.Tensor: Converted tensor.
        """
        return torch.tensor(
            ak.to_numpy(array, allow_missing=False),
            dtype=torch.float32,
        )

    def __iter__(self):
        """
        Iterate over batches with changing random seeds.

        Yields:
            tuple: Batched dgl graph, labels, and optionally weights.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
        else:
            num_workers = 1
            worker_id = 0
        array = self.maybe_permuted(self.array)
        # split the batch start indices across the dataloader workers
        starts = np.arange(0, len(array), self.batch_size)
        per_worker = np.array_split(starts, num_workers)
        for start in per_worker[worker_id]:
            ak_slice = array[start: start + self.batch_size]
            output = [
                get_batched_graph(ak_slice, DEFAULT_NODE_FEATURES),
                # labels are expected in a `label` column, analogous to `weight`
                self.to_tensor(ak_slice.label),
            ]
            if self.weighted:
                output.append(self.to_tensor(ak_slice.weight))
            yield tuple(output)
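
    # Worker-splitting example for __iter__ (assumed numbers): with
    # len(array) == 4096, batch_size == 1024 and num_workers == 2,
    # starts == [0, 1024, 2048, 3072]; np.array_split gives worker 0 the starts
    # [0, 1024] and worker 1 the starts [2048, 3072], so no batch is yielded twice.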

    def __getitem__(self, slicer):
        """
        Get a single instance or a new ArrayDataset for a slice.

        Args:
            slicer (int or slice): Index or slice.

        Returns:
            ArrayDataset or tuple: A new ArrayDataset for a slice, or a single
                yielded instance for an integer index.
        """
        array = self.array
        kwargs = dict(
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            seed=self.seed,
            weighted=self.weighted,
        )
        if not isinstance(slicer, int):
            return ArrayDataset(array[slicer], **kwargs)
        # for a single index, return the first (and only) batch directly
        slicer = slice(slicer, slicer + 1)
        kwargs["batch_size"] = 1
        return next(iter(ArrayDataset(array[slicer], **kwargs)))
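

# Usage sketch for training/inference (illustrative only): `array` stands for a
# preprocessed awkward array that provides the columns expected by
# DEFAULT_NODE_FEATURES together with `label` (and `weight` if `weighted=True`).
#
#     dataset = ArrayDataset(array, batch_size=1024, shuffle=True)
#     dataloader = torch.utils.data.DataLoader(
#         dataset,
#         batch_size=1,               # batching happens inside the dataset
#         collate_fn=lambda x: x[0],  # unwrap the single pre-batched element
#     )
#     for graph, labels in dataloader:
#         ...  # forward pass, loss, backward, etc.
#
#     # Integer indexing yields one instance, slicing yields a new ArrayDataset:
#     graph, label = dataset[0]
#     first_hundred = dataset[:100]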