Belle II Software development
ArrayDataset Class Reference
Inheritance diagram for ArrayDataset:

Public Member Functions

 __init__ (self, array, batch_size=1024, shuffle=True, seed=None, weighted=False)
 
 __len__ (self)
 
 maybe_permuted (self, array)
 
 __iter__ (self)
 
 __getitem__ (self, slicer)
 

Static Public Member Functions

 to_tensor (array)
 

Public Attributes

 array = array
 Awkward array containing the dataset.
 
 batch_size = batch_size
 Batch size for the iterable dataset.
 
 shuffle = shuffle
 Whether to shuffle the data.
 
 seed = seed if seed is not None else np.random.SeedSequence().entropy
 Random seed for shuffling; the same seed is used by all workers.
 
 weighted = weighted
 Whether the dataset includes weights.
 

Detailed Description

Dataset initialized from a pre-processed awkward array.

Use `torch.utils.data.DataLoader` with `collate_fn=lambda x: x[0]`
and `batch_size=1` to iterate through it.

Yields a tuple of a batched dgl graph and labels. If `weighted=True`, the tuple
also contains weights, which requires a `weight` column in the array.
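
A minimal end-to-end sketch of this recipe; the parquet file name and the import path of ArrayDataset are placeholders, adapt them to your setup.

import awkward as ak
import torch

from grafei.model.dataset import ArrayDataset  # hypothetical import path

# Pre-processed awkward array with node features, a `label` column and,
# if weights are used, a `weight` column.
array = ak.from_parquet("preprocessed_events.parquet")  # placeholder file name
dataset = ArrayDataset(array, batch_size=1024, shuffle=True)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,               # the dataset already yields full batches
    collate_fn=lambda x: x[0],  # unwrap the single pre-batched item
)

for graph, labels in loader:    # with weighted=True: graph, labels, weights
    pass                        # feed the batched dgl graph and labels to the model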

Definition at line 81 of file dataset.py.

Constructor & Destructor Documentation

◆ __init__()

__init__ (self, array, batch_size = 1024, shuffle = True, seed = None, weighted = False)
Initialize the ArrayDataset for PyTorch training and inference.

:param array: Awkward array containing the dataset.
:param batch_size (int): Batch size for the iterable dataset.
:param shuffle (bool): Whether to shuffle the data.
:param seed: Random seed for shuffling.
:param weighted (bool): Whether the dataset includes weights.

Definition at line 92 of file dataset.py.

def __init__(
    self,
    array,
    batch_size=1024,
    shuffle=True,
    seed=None,
    weighted=False,
):
    """
    Initialize the ArrayDataset for PyTorch training and inference.

    :param array: Awkward array containing the dataset.
    :param batch_size (int): Batch size for the iterable dataset.
    :param shuffle (bool): Whether to shuffle the data.
    :param seed: Random seed for shuffling.
    :param weighted (bool): Whether the dataset includes weights.
    """
    self.array = array
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.seed = seed if seed is not None else np.random.SeedSequence().entropy
    self.weighted = weighted
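
A hypothetical construction sketch: a weighted dataset with a fixed seed so that shuffling is reproducible across runs. It assumes `events` is a pre-processed awkward array with the node features, a `label` column and a `weight` column.

dataset = ArrayDataset(
    events,
    batch_size=512,
    shuffle=True,
    seed=42,        # fixed seed, so maybe_permuted() gives the same order every run
    weighted=True,  # batches then include a weight tensor built from the `weight` column
)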

Member Function Documentation

◆ __getitem__()

__getitem__ (self, slicer)
Get a single instance or a new ArrayDataset for a slice.

Arguments:
    slicer (int or slice): Index or slice.

Returns:
    ArrayDataset: New ArrayDataset instance.

Definition at line 190 of file dataset.py.

def __getitem__(self, slicer):
    """
    Get a single instance or a new ArrayDataset for a slice.

    Arguments:
        slicer (int or slice): Index or slice.

    Returns:
        ArrayDataset: New ArrayDataset instance.
    """
    kwargs = dict(
        batch_size=self.batch_size,
        shuffle=self.shuffle,
        seed=self.seed,
        weighted=self.weighted,
    )
    array = self.maybe_permuted(self.array)
    if not isinstance(slicer, int):
        return ArrayDataset(array[slicer], **kwargs)
    slicer = slice(slicer, slicer + 1)
    kwargs["batch_size"] = 1
    return next(iter(ArrayDataset(array[slicer], **kwargs)))
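
A short sketch of the two access modes, assuming `dataset` is an unweighted ArrayDataset as built in the first sketch above (the split point is a placeholder).

# Slicing returns a new ArrayDataset that keeps batch_size, shuffle, seed and weighted.
train = dataset[:8000]
val = dataset[8000:]

# Integer indexing builds a one-element dataset internally and materializes it;
# for an unweighted dataset the result is a (graph, labels) tuple.
graph, labels = dataset[0]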

◆ __iter__()

__iter__ ( self)
Iterate over batches with changing random seeds.

Yields:
    tuple: Batched dgl graph, labels, and optionally weights.

Definition at line 160 of file dataset.py.

def __iter__(self):
    """
    Iterate over batches with changing random seeds.

    Yields:
        tuple: Batched dgl graph, labels, and optionally weights.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
    else:
        num_workers = 1
        worker_id = 0
    array = self.maybe_permuted(self.array)
    starts = list(range(0, len(self.array), self.batch_size))
    per_worker = np.array_split(starts, num_workers)
    for start in per_worker[worker_id]:
        ak_slice = array[start: start + self.batch_size]
        output = [
            get_batched_graph(ak_slice, DEFAULT_NODE_FEATURES),
            self.to_tensor(ak_slice.label),
        ]
        if self.weighted:
            output.append(self.to_tensor(ak_slice.weight))
        yield tuple(output)
    # increase the seed to get a new order of instances in the next iteration
    # note: need to use persistent_workers=True in the DataLoader for this to take effect
    self.seed += 1
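
As the comment above notes, the per-epoch seed increment only survives if the DataLoader workers stay alive between epochs. A minimal sketch, assuming `dataset` was built with `shuffle=True`:

import torch

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    collate_fn=lambda x: x[0],
    num_workers=4,            # batch start indices are split across the workers
    persistent_workers=True,  # keep workers alive so `self.seed += 1` takes effect
)

for epoch in range(10):       # each epoch iterates the array in a new order
    for batch in loader:
        pass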

◆ __len__()

__len__ ( self)
Get the number of batches.

Returns:
    int: Number of batches.

Definition at line 120 of file dataset.py.

def __len__(self):
    """
    Get the number of batches.

    Returns:
        int: Number of batches.
    """
    return int(math.ceil(len(self.array) / self.batch_size))
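
For example, with the default `batch_size` of 1024, an array of 2500 events gives `ceil(2500 / 1024) = 3` batches.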

◆ maybe_permuted()

maybe_permuted (self, array)
Possibly permute the array based on the shuffle parameter.

Arguments:
    array (awkward array): Input array.

Returns:
    array: Permuted or original array.

Definition at line 129 of file dataset.py.

def maybe_permuted(self, array):
    """
    Possibly permute the array based on the shuffle parameter.

    Arguments:
        array (awkward array): Input array.

    Returns:
        array: Permuted or original array.
    """
    if not self.shuffle or len(self.array) == 1:
        return array
    perm = np.random.default_rng(self.seed).permutation(len(array))
    return self.array[perm]
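
A small sketch of the seed behaviour, with a toy awkward array and the hypothetical import from above: when shuffling is enabled, the permutation is fully determined by `seed`, so two datasets built with the same seed see the data in the same order.

import awkward as ak

toy = ak.Array({"x": [0, 1, 2, 3, 4]})
a = ArrayDataset(toy, shuffle=True, seed=7)
b = ArrayDataset(toy, shuffle=True, seed=7)
assert ak.to_list(a.maybe_permuted(a.array)) == ak.to_list(b.maybe_permuted(b.array))

# With shuffle=False the input array is returned unchanged.
c = ArrayDataset(toy, shuffle=False)
assert ak.to_list(c.maybe_permuted(c.array)) == ak.to_list(toy)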

◆ to_tensor()

to_tensor ( array)
static
Convert an awkward array to a torch tensor.

Arguments:
    array (awkward array): Input awkward array.

Returns:
    torch.Tensor: Converted tensor.

Definition at line 145 of file dataset.py.

@staticmethod
def to_tensor(array):
    """
    Convert an awkward array to a torch tensor.

    Arguments:
        array (awkward array): Input awkward array.

    Returns:
        torch.Tensor: Converted tensor.
    """
    return torch.tensor(
        ak.to_numpy(array, allow_missing=False),
        dtype=torch.float32,
    ).reshape(-1, 1)
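
A small sketch, reusing the imports from above: to_tensor turns numeric awkward content into a float32 column vector, the shape used for the label and weight tensors.

labels = ak.Array([0, 1, 1, 0])
t = ArrayDataset.to_tensor(labels)
print(t.shape, t.dtype)  # torch.Size([4, 1]) torch.float32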

Member Data Documentation

◆ array

array = array

Awkward array containing the dataset.

Definition at line 110 of file dataset.py.

◆ batch_size

batch_size = batch_size

Batch size for the iterable dataset.

Definition at line 112 of file dataset.py.

◆ seed

seed = seed if seed is not None else np.random.SeedSequence().entropy

Random seed for shuffling; the same seed is used by all workers.

Definition at line 116 of file dataset.py.

◆ shuffle

shuffle = shuffle

Whether to shuffle the data.

Definition at line 114 of file dataset.py.

◆ weighted

weighted = weighted

Whether the dataset includes weights.

Definition at line 118 of file dataset.py.


The documentation for this class was generated from the following file:

dataset.py