Belle II Software development
ArrayDataset Class Reference
Inheritance diagram for ArrayDataset:

Public Member Functions

def __init__ (self, array, batch_size=1024, shuffle=True, seed=None, weighted=False)
 
def __len__ (self)
 
def maybe_permuted (self, array)
 
def __iter__ (self)
 
def __getitem__ (self, slicer)
 

Static Public Member Functions

def to_tensor (array)
 

Public Attributes

 array
 Awkward array containing the dataset.
 
 batch_size
 Batch size for the iterable dataset.
 
 shuffle
 Whether to shuffle the data.
 
 seed
 Random seed for shuffling; the same seed is used by all workers.
 
 weighted
 Whether the dataset includes weights.
 

Detailed Description

Dataset initialized from a pre-processed awkward array.

Use `torch.utils.data.DataLoader` with `collate_fn=lambda x: x[0]`
and `batch_size=1` to iterate through it.

Yields a tuple of a batched dgl graph and labels, and optionally also weights if
`weighted=True`; this requires a `weight` column in the array.
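
A minimal usage sketch (the array name `events`, the seed value, and the worker
settings are illustrative, not part of the API):

import torch
from torch.utils.data import DataLoader

# `events` is assumed to be a pre-processed awkward array with the columns the
# dataset expects (including `label`, and `weight` if weighted=True)
dataset = ArrayDataset(events, batch_size=1024, shuffle=True, seed=42)
loader = DataLoader(
    dataset,
    batch_size=1,                  # batching happens inside ArrayDataset
    collate_fn=lambda x: x[0],     # unwrap the single pre-batched item
    num_workers=2,
    persistent_workers=True,       # lets the per-epoch seed increment take effect
)
for batch in loader:
    graph, labels = batch[:2]      # plus weights as batch[2] if weighted=True
    ...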

Definition at line 81 of file dataset.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (self, array, batch_size=1024, shuffle=True, seed=None, weighted=False)
Initialize the ArrayDataset for Pytorch training and inference.

:param array: Awkward array containing the dataset.
:param batch_size (int): Batch size for the iterable dataset.
:param shuffle (bool): Whether to shuffle the data.
:param seed: Random seed for shuffling.
:param weighted (bool): Whether the dataset includes weights.

Definition at line 92 of file dataset.py.

def __init__(self, array, batch_size=1024, shuffle=True, seed=None, weighted=False):
    """
    Initialize the ArrayDataset for Pytorch training and inference.

    :param array: Awkward array containing the dataset.
    :param batch_size (int): Batch size for the iterable dataset.
    :param shuffle (bool): Whether to shuffle the data.
    :param seed: Random seed for shuffling.
    :param weighted (bool): Whether the dataset includes weights.
    """
    self.array = array
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.seed = seed if seed is not None else np.random.SeedSequence().entropy
    self.weighted = weighted

Member Function Documentation

◆ __getitem__()

def __getitem__ (self, slicer)
Get a single instance or a new ArrayDataset for a slice.

Arguments:
    slicer (int or slice): Index or slice.

Returns:
    ArrayDataset: New ArrayDataset instance.

Definition at line 190 of file dataset.py.

def __getitem__(self, slicer):
    """
    Get a single instance or a new ArrayDataset for a slice.

    Arguments:
        slicer (int or slice): Index or slice.

    Returns:
        ArrayDataset: New ArrayDataset instance.
    """
    kwargs = dict(
        batch_size=self.batch_size,
        shuffle=self.shuffle,
        seed=self.seed,
        weighted=self.weighted,
    )
    array = self.maybe_permuted(self.array)
    if not isinstance(slicer, int):
        return ArrayDataset(array[slicer], **kwargs)
    slicer = slice(slicer, slicer + 1)
    kwargs["batch_size"] = 1
    return next(iter(ArrayDataset(array[slicer], **kwargs)))
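
A short sketch of the two access modes (assuming the `dataset` from the usage
example above, with `weighted=False`):

first_ten = dataset[:10]        # slice: a new ArrayDataset over the first ten rows
graph, label = dataset[0][:2]   # int index: a single instance, already collated as a batch of one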

◆ __iter__()

def __iter__ (self)
Iterate over batches with changing random seeds.

Yields:
    tuple: Batched dgl graph, labels, and optionally weights.

Definition at line 160 of file dataset.py.

def __iter__(self):
    """
    Iterate over batches with changing random seeds.

    Yields:
        tuple: Batched dgl graph, labels, and optionally weights.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
    else:
        num_workers = 1
        worker_id = 0
    array = self.maybe_permuted(self.array)
    starts = list(range(0, len(self.array), self.batch_size))
    per_worker = np.array_split(starts, num_workers)
    for start in per_worker[worker_id]:
        ak_slice = array[start: start + self.batch_size]
        output = [
            get_batched_graph(ak_slice, DEFAULT_NODE_FEATURES),
            self.to_tensor(ak_slice.label),
        ]
        if self.weighted:
            output.append(self.to_tensor(ak_slice.weight))
        yield tuple(output)
    # increase the seed to get a new order of instances in the next iteration
    # note: need to use persistent_workers=True in the DataLoader for this to take effect
    self.seed += 1
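
A sketch of the iteration behaviour (assuming `weighted=False` and the `dataset`
from the usage example above): each pass reshuffles the data because the seed is
incremented at the end of `__iter__`. When iterating through a `DataLoader` with
worker processes, `persistent_workers=True` is needed so the workers' dataset
copies keep the incremented seed between epochs.

for epoch in range(2):
    # with weighted=True each item would also carry a weight tensor
    for graph, labels in dataset:
        ...  # the batch order differs between the two epochs (seed was incremented)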

◆ __len__()

def __len__ (self)
Get the number of batches.

Returns:
    int: Number of batches.

Definition at line 120 of file dataset.py.

def __len__(self):
    """
    Get the number of batches.

    Returns:
        int: Number of batches.
    """
    return int(math.ceil(len(self.array) / self.batch_size))
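
For example, an array with 2500 rows and `batch_size=1024` gives
`len(dataset) == 3`, i.e. `ceil(2500 / 1024)`.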

◆ maybe_permuted()

def maybe_permuted (self, array)
Possibly permute the array based on the shuffle parameter.

Arguments:
    array (awkward array): Input array.

Returns:
    array: Permuted or original array.

Definition at line 129 of file dataset.py.

def maybe_permuted(self, array):
    """
    Possibly permute the array based on the shuffle parameter.

    Arguments:
        array (awkward array): Input array.

    Returns:
        array: Permuted or original array.
    """
    if not self.shuffle or len(self.array) == 1:
        return array
    perm = np.random.default_rng(self.seed).permutation(len(array))
    return self.array[perm]
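
A small sketch of the shuffling behaviour (using the `dataset` from the usage
example above; the assertion is purely illustrative):

import numpy as np

# with shuffle=False (or a single-row array) the input is returned unchanged;
# otherwise the order is a deterministic function of dataset.seed, so two calls
# with the same seed produce the same permutation, and all DataLoader workers
# (which share the seed) see the same shuffled order
perm_a = np.random.default_rng(dataset.seed).permutation(len(dataset.array))
perm_b = np.random.default_rng(dataset.seed).permutation(len(dataset.array))
assert np.array_equal(perm_a, perm_b)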

◆ to_tensor()

def to_tensor (array)   [static]
Convert an awkward array to a torch tensor.

Arguments:
    array (awkward array): Input awkward array.

Returns:
    torch.Tensor: Converted tensor.

Definition at line 145 of file dataset.py.

def to_tensor(array):
    """
    Convert an awkward array to a torch tensor.

    Arguments:
        array (awkward array): Input awkward array.

    Returns:
        torch.Tensor: Converted tensor.
    """
    return torch.tensor(
        ak.to_numpy(array, allow_missing=False),
        dtype=torch.float32,
    ).reshape(-1, 1)
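
A small illustrative example of the conversion (the values are made up):

import awkward as ak

labels = ak.Array([0.0, 1.0, 1.0])
t = ArrayDataset.to_tensor(labels)
# t is a float32 tensor of shape (3, 1): tensor([[0.], [1.], [1.]])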

Member Data Documentation

◆ array

array

Awkward array containing the dataset.

Definition at line 110 of file dataset.py.

◆ batch_size

batch_size

Batch size for the iterable dataset.

Definition at line 112 of file dataset.py.

◆ seed

seed

Random seed for shuffling; the same seed is used by all workers.

Definition at line 116 of file dataset.py.

◆ shuffle

shuffle

Whether to shuffle the data.

Definition at line 114 of file dataset.py.

◆ weighted

weighted

Whether the dataset includes weights.

Definition at line 118 of file dataset.py.


The documentation for this class was generated from the following file:
dataset.py