Belle II Software development
dataset.py
1
8import math
9import awkward as ak
10import numpy as np
11import dgl
12import torch
13
14from smartBKG import DEFAULT_NODE_FEATURES
15
16
def _node_offsets(num_nodes):
    """
    Return the index of each event's first node in the flattened node list.

    Arguments:
        num_nodes (numpy.ndarray): number of nodes (particles) per event.

    Returns:
        numpy.ndarray: exclusive cumulative sum of ``num_nodes``; it always has
        exactly one entry per event, including for an empty batch.
    """
    return np.cumsum(num_nodes) - num_nodes


def _flat_node_features(feats):
    """
    Flatten a jagged per-particle feature selection into a dense numpy array.

    Arguments:
        feats (awkward array): per-event lists of particle features, either a
            plain array (no fields) or a record array.

    Returns:
        numpy.ndarray: 1D (one entry per node) for a plain array, or 2D with one
        column per record field.
    """
    if len(feats.fields) == 0:
        return ak.to_numpy(ak.flatten(feats), allow_missing=False)
    return np.stack(
        [ak.to_numpy(ak.flatten(x), allow_missing=False) for x in ak.unzip(feats)],
        axis=1,
    )


def get_batched_graph(array, node_features, index_column=("particles", "motherIndex")):
    """
    Generate a batched DGL graph from an awkward array.

    Each event becomes one graph component. Every particle is connected to its
    mother in both directions, edges to removed mothers (index -1) are dropped,
    and every particle gets exactly one self-loop.

    Arguments:
        array (awkward array): containing particle information.
        node_features (dict): maps node-data names to the column selection used
            as ``array[columns]``; the flattened values are stored as float32
            in ``ndata``.
        index_column (tuple): field path selecting the per-particle mother index.

    Returns:
        dgl.DGLGraph: Batched DGL graph.

    Note:
        This function assumes the input array has a nested structure representing particle relations.
    """
    mother_index = array[index_column]
    array_index = ak.local_index(mother_index, axis=1)

    # mother -> daughter and daughter -> mother edges
    src = ak.concatenate([mother_index, array_index], axis=1)
    dst = ak.concatenate([array_index, mother_index], axis=1)

    # remove edges to mothers that have been removed (represented by index -1)
    # and remove self-loops (src == dst); they are re-added exactly once below
    mask = (src != -1) & (dst != -1) & (src != dst)
    src, dst = src[mask], dst[mask]

    # now add a single self-loop for each array index
    src = ak.concatenate([src, array_index], axis=1)
    dst = ak.concatenate([dst, array_index], axis=1)

    # count nodes per event explicitly (not via layout offsets) such that it
    # works also for ListArray
    num_nodes = ak.num(mother_index).to_numpy()
    offsets = _node_offsets(num_nodes)

    # shift local indices by the per-event offset to get a single graph of
    # disconnected node groups
    src_flat = torch.tensor(ak.to_numpy(ak.flatten(src + offsets)))
    dst_flat = torch.tensor(ak.to_numpy(ak.flatten(dst + offsets)))

    # pass the node count explicitly so it always agrees with the batch info
    batched = dgl.graph((src_flat, dst_flat), num_nodes=int(num_nodes.sum()))
    batched.set_batch_num_nodes(torch.tensor(num_nodes))
    batched.set_batch_num_edges(torch.tensor(ak.num(src).to_numpy()))
    for field, columns in node_features.items():
        batched.ndata[field] = torch.tensor(
            _flat_node_features(array[columns]), dtype=torch.float32
        )

    return batched
79
80
class ArrayDataset(torch.utils.data.IterableDataset):
    """
    Dataset initialized from a pre-processed awkward array.

    Use `torch.utils.data.DataLoader` with `collate_fn=lambda x: x[0]`
    and `batch_size=1` to iterate through it.

    Yields a tuple of a batched dgl graph and labels. Optionally also weights if
    `weighted=True`. This requires a column `weight` in the array. The array
    must always provide a `label` column.
    """

    def __init__(
        self,
        array,
        batch_size=1024,
        shuffle=True,
        seed=None,
        weighted=False,
    ):
        """
        Initialize the ArrayDataset for Pytorch training and inference.

        :param array: Awkward array containing the dataset.
        :param batch_size (int): Batch size for the iterable dataset.
        :param shuffle (bool): Whether to shuffle the data.
        :param seed: Random seed for shuffling.
        :param weighted (bool): Whether the dataset includes weights.
        """
        #: Awkward array containing the dataset.
        self.array = array
        #: Batch size for the iterable dataset.
        self.batch_size = batch_size
        #: Whether to shuffle the data.
        self.shuffle = shuffle
        #: Random seed for shuffling, consistent seed for all workers.
        # Drawn once here so that every worker copy of this dataset computes
        # the same permutation and the workers' batch subsets stay disjoint.
        self.seed = seed if seed is not None else np.random.SeedSequence().entropy
        #: Whether the dataset includes weights.
        self.weighted = weighted

    def __len__(self):
        """
        Get the number of batches.

        Returns:
            int: Number of batches.
        """
        return int(math.ceil(len(self.array) / self.batch_size))

    def maybe_permuted(self, array):
        """
        Possibly permute the array based on the shuffle parameter.

        The permutation is fully determined by the current ``self.seed``.

        Arguments:
            array (awkward array): Input array.

        Returns:
            array: Permuted or original array.
        """
        if not self.shuffle or len(array) <= 1:
            return array
        perm = np.random.default_rng(self.seed).permutation(len(array))
        return array[perm]

    @staticmethod
    def to_tensor(array):
        """
        Convert an awkward array to a torch tensor.

        Arguments:
            array (awkward array): Input awkward array.

        Returns:
            torch.Tensor: Converted float32 tensor of shape (N, 1).
        """
        return torch.tensor(
            ak.to_numpy(array, allow_missing=False),
            dtype=torch.float32,
        ).reshape(-1, 1)

    def _make_batch(self, ak_slice):
        """Build the (graph, labels[, weights]) tuple for one slice of events."""
        output = [
            get_batched_graph(ak_slice, DEFAULT_NODE_FEATURES),
            self.to_tensor(ak_slice.label),
        ]
        if self.weighted:
            output.append(self.to_tensor(ak_slice.weight))
        return tuple(output)

    def __iter__(self):
        """
        Iterate over batches with changing random seeds.

        When running inside DataLoader workers, every worker processes a
        disjoint subset of the batches.

        Yields:
            tuple: Batched dgl graph, labels, and optionally weights.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            num_workers, worker_id = 1, 0
        else:
            num_workers, worker_id = worker_info.num_workers, worker_info.id
        array = self.maybe_permuted(self.array)
        starts = list(range(0, len(array), self.batch_size))
        for start in np.array_split(starts, num_workers)[worker_id]:
            yield self._make_batch(array[start: start + self.batch_size])
        # increase the seed to get a new order of instances in the next iteration
        # note: need to use persistent_workers=True in the DataLoader for this to take effect
        self.seed += 1

    def __getitem__(self, slicer):
        """
        Get a single instance or a new ArrayDataset for a slice.

        Indexing is applied to the (possibly permuted) array.

        Arguments:
            slicer (int or slice): Index (negative values count from the end)
                or any other selection accepted by the array.

        Returns:
            ArrayDataset for non-integer selections; for an integer index the
            (graph, labels[, weights]) tuple of that single instance.

        Raises:
            IndexError: if an integer index is out of range.
        """
        array = self.maybe_permuted(self.array)
        if not isinstance(slicer, (int, np.integer)):
            return ArrayDataset(
                array[slicer],
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                seed=self.seed,
                weighted=self.weighted,
            )
        length = len(array)
        index = int(slicer)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError(f"index {slicer} is out of range for dataset of length {length}")
        # build the single-instance batch directly; going through __iter__ would
        # make the result depend on DataLoader worker information
        return self._make_batch(array[index: index + 1])
def to_tensor(array)
Definition: dataset.py:145
def __getitem__(self, slicer)
Definition: dataset.py:190
def maybe_permuted(self, array)
Definition: dataset.py:129
array
Awkward array containing the dataset.
Definition: dataset.py:110
def __iter__(self)
Definition: dataset.py:160
weighted
Whether the dataset includes weights.
Definition: dataset.py:118
def __len__(self)
Definition: dataset.py:120
def __init__(self, array, batch_size=1024, shuffle=True, seed=None, weighted=False)
Definition: dataset.py:99
seed
Random seed for shuffling, consistent seed for all workers.
Definition: dataset.py:116
batch_size
Batch size for the iterable dataset.
Definition: dataset.py:112
shuffle
Whether to shuffle the data.
Definition: dataset.py:114