Belle II Software release-09-00-03


Public Member Functions
| def | __init__ (self, array, batch_size=1024, shuffle=True, seed=None, weighted=False) |
| def | __len__ (self) |
| def | maybe_permuted (self, array) |
| def | __iter__ (self) |
| def | __getitem__ (self, slicer) |

Static Public Member Functions
| def | to_tensor (array) |

Public Attributes
| array | Awkward array containing the dataset. |
| batch_size | Batch size for the iterable dataset. |
| shuffle | Whether to shuffle the data. |
| seed | Random seed for shuffling; the same seed is used by all workers. |
| weighted | Whether the dataset includes weights. |

Dataset initialized from a pre-processed awkward array. The dataset already yields complete batches. To iterate through it, wrap it in a `torch.utils.data.DataLoader` with `batch_size=1` and `collate_fn=lambda x: x[0]`. Each item is a tuple of a batched DGL graph and its labels. If `weighted=True`, the tuple also contains the weights, which requires a `weight` column in the array.
Definition at line 81 of file dataset.py.
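The `batch_size=1` / `collate_fn=lambda x: x[0]` combination works because the dataset already yields whole batches. For an iterable-style dataset, the DataLoader collects `batch_size` yielded items into a list and passes that list to `collate_fn`. With `batch_size=1` the list has exactly one element, so `x[0]` returns the pre-built batch unchanged rather than running PyTorch's default collation over it. The stdlib-only sketch below mimics that grouping step. `toy_pre_batched` and `toy_loader` are hypothetical stand-ins, not part of the Belle II code or of PyTorch.

```python
def toy_pre_batched(n_items, batch_size):
    """Hypothetical stand-in for ArrayDataset iteration: yields whole batches."""
    items = list(range(n_items))
    for start in range(0, n_items, batch_size):
        batch = items[start:start + batch_size]
        labels = [x % 2 for x in batch]  # hypothetical labels
        yield batch, labels


def toy_loader(dataset_iter, collate_fn, batch_size=1):
    """Groups `batch_size` yielded items into a list and hands it to collate_fn."""
    buffer = []
    for item in dataset_iter:
        buffer.append(item)
        if len(buffer) == batch_size:
            yield collate_fn(buffer)
            buffer = []
    if buffer:  # like drop_last=False: emit the final partial group
        yield collate_fn(buffer)


# With batch_size=1 each list holds one pre-built batch; x[0] unwraps it.
batches = list(toy_loader(toy_pre_batched(5, 2), collate_fn=lambda x: x[0]))
```

Without the unwrapping `collate_fn`, every batch would arrive wrapped in an extra one-element list.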
def __init__(self, array, batch_size=1024, shuffle=True, seed=None, weighted=False)
Initialize the ArrayDataset for PyTorch training and inference.
Arguments:
array: Awkward array containing the dataset.
batch_size (int): Batch size for the iterable dataset.
shuffle (bool): Whether to shuffle the data.
seed: Random seed for shuffling.
weighted (bool): Whether the dataset includes weights.
Definition at line 92 of file dataset.py.
def __getitem__(self, slicer)
Get a single instance or a new ArrayDataset for a slice.
Arguments:
slicer (int or slice): Index or slice.
Returns:
ArrayDataset: New ArrayDataset instance.
Definition at line 190 of file dataset.py.
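Slicing is a convenient way to split one dataset into, for example, training and validation parts. The sketch below assumes what the description suggests. An integer index returns a single entry. A slice returns a new dataset that carries over the batching and shuffling settings. `ToyArrayDataset` is hypothetical and uses a plain list in place of the awkward array.

```python
class ToyArrayDataset:
    """Hypothetical stand-in for ArrayDataset; a list replaces the awkward array."""

    def __init__(self, array, batch_size=1024, shuffle=True, seed=None, weighted=False):
        self.array = array
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.weighted = weighted

    def __getitem__(self, slicer):
        sliced = self.array[slicer]
        if isinstance(slicer, slice):
            # A slice yields a new dataset with the same settings.
            return ToyArrayDataset(
                sliced,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                seed=self.seed,
                weighted=self.weighted,
            )
        # An integer index yields a single entry.
        return sliced


ds = ToyArrayDataset(list(range(10)), batch_size=4, seed=7)
train, valid = ds[:8], ds[8:]
```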
def __iter__(self)
Iterate over batches with changing random seeds.
Yields:
tuple: Batched dgl graph, labels, and optionally weights.
Definition at line 160 of file dataset.py.
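One reading of "changing random seeds" goes as follows. Each pass derives its permutation from the stored seed and then advances the seed. Consecutive epochs therefore see different orders, while a run started from the same seed is reproducible. The sketch below assumes exactly that behaviour. `ToyEpochIterator` is hypothetical and uses a list and `random.Random` in place of the awkward array and whatever RNG the real code uses.

```python
import random


class ToyEpochIterator:
    """Hypothetical sketch: reproducible reshuffling that changes every epoch."""

    def __init__(self, array, batch_size, seed=0):
        self.array = array
        self.batch_size = batch_size
        self.seed = seed

    def __iter__(self):
        order = list(range(len(self.array)))
        random.Random(self.seed).shuffle(order)
        # Advance the seed so the next epoch gets a different permutation.
        self.seed += 1
        for start in range(0, len(order), self.batch_size):
            yield [self.array[i] for i in order[start:start + self.batch_size]]


data = list(range(10))
it = ToyEpochIterator(data, batch_size=4, seed=42)
epoch1 = list(it)  # permutation from seed 42
epoch2 = list(it)  # permutation from seed 43
```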
def __len__(self)
Get the number of batches.
Returns:
int: Number of batches.
Definition at line 120 of file dataset.py.
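Assume the final partial batch is yielded rather than dropped. The number of batches is then the ceiling of the entry count divided by `batch_size`. For example, 2500 entries with the default `batch_size=1024` give 3 batches (1024, 1024 and 452 entries). `n_batches` is a hypothetical helper that illustrates this arithmetic; it is not the actual method body.

```python
def n_batches(n_entries, batch_size):
    """Ceiling division without floats: the last, partial batch counts too."""
    return (n_entries + batch_size - 1) // batch_size


print(n_batches(2500, 1024))
```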
def maybe_permuted(self, array)
Permute the array if `shuffle` is enabled; otherwise return it unchanged.
Arguments:
array (awkward array): Input array.
Returns:
array: Permuted or original array.
Definition at line 129 of file dataset.py.
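The conditional permutation can be sketched as a seeded shuffle of the entry indices, followed by reordering the entries with those indices. With an awkward array, the same reordering could be done by indexing with an integer array of the permuted positions. The function below is a hypothetical stand-in that takes `shuffle` and `seed` as explicit parameters and uses `random.Random`. The real implementation may use a different RNG.

```python
import random


def maybe_permuted(array, shuffle=True, seed=None):
    """Hypothetical sketch: shuffled copy if `shuffle` is set, else the input itself."""
    if not shuffle:
        return array
    order = list(range(len(array)))
    random.Random(seed).shuffle(order)  # seeded, hence reproducible
    return [array[i] for i in order]
```

Returning the input object unchanged when `shuffle` is off avoids a needless copy, for example during inference.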
def to_tensor(array)  [static]
Convert an awkward array to a torch tensor.
Arguments:
array (awkward array): Input awkward array.
Returns:
torch.Tensor: Converted tensor.
Definition at line 145 of file dataset.py.
array
Awkward array containing the dataset.
Definition at line 110 of file dataset.py.
batch_size
Batch size for the iterable dataset.
Definition at line 112 of file dataset.py.
seed
Random seed for shuffling; the same seed is used by all workers.
Definition at line 116 of file dataset.py.
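The seed must be consistent across workers because each DataLoader worker iterates its own copy of the dataset. If every copy shuffles with the same seed, all workers agree on one permutation and can split it without overlap. With different seeds, some entries could be seen twice in an epoch and others not at all. The sketch assumes a round-robin split of batches by worker index, and `worker_batches` is hypothetical. In PyTorch, the worker index and count would come from `torch.utils.data.get_worker_info()`.

```python
import random


def worker_batches(n_entries, batch_size, seed, worker_id, num_workers):
    """Hypothetical sketch: shared-seed shuffle, then every num_workers-th batch."""
    order = list(range(n_entries))
    random.Random(seed).shuffle(order)  # identical permutation in every worker
    batches = [order[i:i + batch_size] for i in range(0, n_entries, batch_size)]
    return batches[worker_id::num_workers]  # disjoint shard per worker


shards = [worker_batches(10, 3, seed=5, worker_id=w, num_workers=2) for w in range(2)]
```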
shuffle
Whether to shuffle the data.
Definition at line 114 of file dataset.py.
weighted
Whether the dataset includes weights.
Definition at line 118 of file dataset.py.