Belle II Software development
Public Member Functions
| __init__ (self, array, batch_size=1024, shuffle=True, seed=None, weighted=False) |
| __len__ (self) |
| maybe_permuted (self, array) |
| __iter__ (self) |
| __getitem__ (self, slicer) |
Static Public Member Functions
| to_tensor (array) |
Public Attributes
| array = array |
Awkward array containing the dataset.
| batch_size = batch_size |
Batch size for the iterable dataset.
| shuffle = shuffle |
Whether to shuffle the data.
| seed = seed if seed is not None else np.random.SeedSequence().entropy |
Random seed for shuffling, consistent seed for all workers.
| weighted = weighted |
Whether the dataset includes weights.
Dataset initialized from a pre-processed awkward array. To iterate through it, wrap it in `torch.utils.data.DataLoader` with `collate_fn=lambda x: x[0]` and `batch_size=1`. Each iteration yields a tuple of a batched dgl graph and labels, plus weights if `weighted=True`, which requires a `weight` column in the array.
Definition at line 81 of file dataset.py.
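The recommended `collate_fn=lambda x: x[0]` with `batch_size=1` makes sense if the dataset already yields complete batches: the loader then wraps each one in a one-element list, and the collate function unwraps it again. The sketch below illustrates only that unwrapping step. `toy_loader` is a hypothetical plain-Python stand-in for the loader's batching logic, not torch code, and the string "graphs" are placeholders.

```python
from itertools import islice


def toy_loader(dataset, batch_size, collate_fn):
    """Hypothetical stand-in for a loader's batching step (not torch code)."""
    iterator = iter(dataset)
    while True:
        # Group `batch_size` items into a list, as a data loader would.
        chunk = list(islice(iterator, batch_size))
        if not chunk:
            return
        yield collate_fn(chunk)


# Placeholder for the ready-made (graph, labels) batches a dataset would yield.
prebuilt_batches = [("graph_0", [0, 1]), ("graph_1", [1, 0])]

# With batch_size=1 every chunk is a one-element list; x[0] unwraps it.
unwrapped = list(toy_loader(prebuilt_batches, batch_size=1, collate_fn=lambda x: x[0]))

# Without unwrapping, each batch stays nested inside an extra list.
wrapped = list(toy_loader(prebuilt_batches, batch_size=1, collate_fn=lambda x: x))
```

With the real class, the equivalent call would presumably be `DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])`, as stated in the description above.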
| __init__ | ( | self, array, batch_size = 1024, shuffle = True, seed = None, weighted = False | ) |
Initialize the ArrayDataset for PyTorch training and inference.
Arguments:
array: Awkward array containing the dataset.
batch_size (int): Batch size for the iterable dataset.
shuffle (bool): Whether to shuffle the data.
seed: Random seed for shuffling.
weighted (bool): Whether the dataset includes weights.
Definition at line 92 of file dataset.py.
| __getitem__ | ( | self, slicer | ) |
Get a single instance or a new ArrayDataset for a slice.
Arguments:
slicer (int or slice): Index or slice.
Returns:
ArrayDataset: New ArrayDataset instance.
Definition at line 190 of file dataset.py.
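The int-or-slice behaviour described above follows a common Python pattern, sketched here with a hypothetical `ToyDataset` that holds a plain list. It is an illustration of the pattern only and makes no claim about how `dataset.py` implements it; in particular, carrying `batch_size` over to the sliced copy is an assumption.

```python
class ToyDataset:
    """Hypothetical list-backed dataset illustrating int-or-slice indexing."""

    def __init__(self, items, batch_size=2):
        self.items = list(items)
        self.batch_size = batch_size

    def __getitem__(self, slicer):
        if isinstance(slicer, slice):
            # A slice produces a new dataset with the same settings (assumed).
            return ToyDataset(self.items[slicer], batch_size=self.batch_size)
        # An integer index returns the single entry itself.
        return self.items[slicer]


ds = ToyDataset("abcdef")
single = ds[1]
subset = ds[2:4]
```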
| __iter__ | ( | self | ) |
Iterate over batches with changing random seeds.
Yields:
tuple: Batched dgl graph, labels, and optionally weights.
Definition at line 160 of file dataset.py.
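One way to obtain "changing random seeds" while keeping workers consistent is to derive each pass's shuffle from a stored seed and advance that seed after every pass. The sketch below assumes exactly that scheme; `ToyBatches`, the use of the standard `random` module, and the `seed += 1` update are all illustrative choices, not taken from `dataset.py`.

```python
import random


class ToyBatches:
    """Hypothetical iterable that reshuffles with a new seed on every pass."""

    def __init__(self, items, batch_size, seed):
        self.items = list(items)
        self.batch_size = batch_size
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)
        self.seed += 1  # assumed update rule: the next pass uses a new seed
        order = list(range(len(self.items)))
        rng.shuffle(order)
        # Yield consecutive chunks of the shuffled order as batches.
        for start in range(0, len(order), self.batch_size):
            yield [self.items[i] for i in order[start:start + self.batch_size]]


ds = ToyBatches(range(10), batch_size=4, seed=0)
first_pass = list(ds)
second_pass = list(ds)

# A second instance built with the same seed reproduces the first pass.
replica_pass = list(ToyBatches(range(10), batch_size=4, seed=0))
```

Because every instance starts from the same stored seed, copies of the dataset in different workers would shuffle identically on each pass.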
| __len__ | ( | self | ) |
Get the number of batches.
Returns:
int: Number of batches.
Definition at line 120 of file dataset.py.
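As a rough illustration, the number of batches can be computed from the number of entries and the batch size. The sketch assumes a final, smaller batch is kept rather than dropped; the source does not say which convention `__len__` uses, and `number_of_batches` is a hypothetical helper.

```python
import math


def number_of_batches(n_entries, batch_size=1024):
    """Hypothetical helper: batches needed to cover n_entries."""
    # Assumption: a trailing partial batch counts as one batch.
    return math.ceil(n_entries / batch_size)
```

Under this assumption, 2049 entries with the default batch size of 1024 give three batches, the last holding a single entry.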
| maybe_permuted | ( | self, array | ) |
Possibly permute the array based on the shuffle parameter.
Arguments:
array (awkward array): Input array.
Returns:
array: Permuted or original array.
Definition at line 129 of file dataset.py.
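A minimal sketch of conditional permutation, written as a free function on a plain list. The real method operates on an awkward array and reads `shuffle` and `seed` from the instance; the explicit `shuffle` and `seed` parameters and the use of the standard `random` module here are assumptions made to keep the example self-contained.

```python
import random


def maybe_permuted(items, shuffle=True, seed=None):
    """Return items in a seeded random order, or unchanged if shuffle is off."""
    if not shuffle:
        return items
    rng = random.Random(seed)
    order = list(range(len(items)))
    rng.shuffle(order)
    # Reorder by index so the input itself is left untouched.
    return [items[i] for i in order]


data = list(range(10))
unchanged = maybe_permuted(data, shuffle=False)
permuted_a = maybe_permuted(data, shuffle=True, seed=1)
permuted_b = maybe_permuted(data, shuffle=True, seed=1)
```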
| to_tensor | ( | array | ) | static |
Convert an awkward array to a torch tensor.
Arguments:
array (awkward array): Input awkward array.
Returns:
torch.Tensor: Converted tensor.
Definition at line 145 of file dataset.py.
| array = array |
Awkward array containing the dataset.
Definition at line 110 of file dataset.py.
| batch_size = batch_size |
Batch size for the iterable dataset.
Definition at line 112 of file dataset.py.
| seed = seed if seed is not None else np.random.SeedSequence().entropy |
Random seed for shuffling, consistent seed for all workers.
Definition at line 116 of file dataset.py.
| shuffle = shuffle |
Whether to shuffle the data.
Definition at line 114 of file dataset.py.
| weighted = weighted |
Whether the dataset includes weights.
Definition at line 118 of file dataset.py.