Belle II Software development
Public Member Functions
__init__ (self, array, batch_size=1024, shuffle=True, seed=None, weighted=False)
__len__ (self)
maybe_permuted (self, array)
__iter__ (self)
__getitem__ (self, slicer)
Static Public Member Functions
to_tensor (array)
Public Attributes
array = array
Awkward array containing the dataset.
batch_size = batch_size
Batch size for the iterable dataset.
shuffle = shuffle
Whether to shuffle the data.
seed = seed if seed is not None else np.random.SeedSequence().entropy
Random seed for shuffling, consistent seed for all workers.
weighted = weighted
Whether the dataset includes weights.
Dataset initialized from a pre-processed awkward array. To iterate through it, use `torch.utils.data.DataLoader` with `collate_fn=lambda x: x[0]` and `batch_size=1`. Each iteration yields a tuple of a batched dgl graph and labels, plus weights if `weighted=True`. The weighted mode requires a column `weight` in the array.
Definition at line 81 of file dataset.py.
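The `DataLoader` settings recommended above can be illustrated with a minimal sketch. It assumes the dataset already yields complete batches, so the loader only has to unwrap the one-element list it builds for each step. `ToyBatchedDataset` is a hypothetical stand-in that yields plain tensors instead of dgl graphs; it is not the class documented here.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyBatchedDataset(IterableDataset):
    """Hypothetical stand-in for ArrayDataset: yields ready-made batches."""

    def __init__(self, n_items, batch_size):
        self.data = torch.arange(n_items, dtype=torch.float32)
        self.batch_size = batch_size

    def __iter__(self):
        # Each yielded item is already a full batch (features, labels).
        for start in range(0, len(self.data), self.batch_size):
            features = self.data[start:start + self.batch_size]
            labels = (features % 2).long()
            yield features, labels


dataset = ToyBatchedDataset(n_items=10, batch_size=4)

# batch_size=1 makes the loader wrap each yielded batch in a list of
# length one; the collate_fn takes that single element back out.
loader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])

batch_shapes = [tuple(features.shape) for features, labels in loader]
```

With these toy numbers the loader produces three batches, the last one holding the two leftover items.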
__init__(self, array, batch_size=1024, shuffle=True, seed=None, weighted=False)
Initialize the ArrayDataset for PyTorch training and inference.
:param array: Awkward array containing the dataset.
:param batch_size (int): Batch size for the iterable dataset.
:param shuffle (bool): Whether to shuffle the data.
:param seed: Random seed for shuffling.
:param weighted (bool): Whether the dataset includes weights.
Definition at line 92 of file dataset.py.
__getitem__(self, slicer)
Get a single instance or a new ArrayDataset for a slice.
Arguments:
    slicer (int or slice): Index or slice.
Returns:
    ArrayDataset: New ArrayDataset instance.
Definition at line 190 of file dataset.py.
__iter__(self)
Iterate over batches with changing random seeds.
Yields:
    tuple: Batched dgl graph, labels, and optionally weights.
Definition at line 160 of file dataset.py.
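How the seed changes between iterations is not stated on this page. One possible scheme, shown here purely as an illustration, combines a fixed base seed with an epoch counter so that every pass gets a different but reproducible order. `EpochShuffler` and its `epoch` attribute are hypothetical names, and plain index arrays stand in for the batched graphs.

```python
import numpy as np


class EpochShuffler:
    """Hypothetical helper, not part of dataset.py."""

    def __init__(self, n_samples, batch_size, seed):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.seed = seed
        self.epoch = 0

    def __iter__(self):
        # Seed the generator with (base seed, epoch) so each pass over
        # the data draws a fresh permutation that can still be replayed.
        rng = np.random.default_rng([self.seed, self.epoch])
        self.epoch += 1
        order = rng.permutation(self.n_samples)
        for start in range(0, self.n_samples, self.batch_size):
            yield order[start:start + self.batch_size]


shuffler = EpochShuffler(n_samples=10, batch_size=4, seed=123)
first_epoch = np.concatenate(list(shuffler))
second_epoch = np.concatenate(list(shuffler))

# A second helper with the same base seed replays the first epoch.
replay = np.concatenate(list(EpochShuffler(n_samples=10, batch_size=4, seed=123)))
```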
__len__(self)
Get the number of batches.
Returns:
    int: Number of batches.
Definition at line 120 of file dataset.py.
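The page does not say whether a final, partially filled batch is counted. The sketch below assumes it is, which gives a ceiling division; `n_batches` is a hypothetical helper and the real `__len__` may round differently.

```python
import math


def n_batches(n_samples, batch_size=1024):
    """Number of batches, assuming the last partial batch is kept."""
    return math.ceil(n_samples / batch_size)


# 2500 samples at the default batch size: two full batches plus one
# partial batch of 452 samples.
example = n_batches(2500)
```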
maybe_permuted(self, array)
Possibly permute the array based on the shuffle parameter.
Arguments:
    array (awkward array): Input array.
Returns:
    array: Permuted or original array.
Definition at line 129 of file dataset.py.
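A minimal sketch of the behaviour described above, written as a free function and using a NumPy array in place of an awkward array. The explicit `shuffle` and `seed` parameters are assumptions made for illustration; the documented method takes only `array` and presumably reads both values from the instance.

```python
import numpy as np


def maybe_permuted(array, shuffle=True, seed=None):
    """Return a row-permuted copy if shuffle is set, else the input itself."""
    if not shuffle:
        return array
    rng = np.random.default_rng(seed)
    # Index with a random permutation of the row positions.
    return array[rng.permutation(len(array))]


data = np.arange(6)
kept = maybe_permuted(data, shuffle=False)
shuffled = maybe_permuted(data, shuffle=True, seed=7)
```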
to_tensor(array) [static]
Convert an awkward array to a torch tensor.
Arguments:
    array (awkward array): Input awkward array.
Returns:
    torch.Tensor: Converted tensor.
Definition at line 145 of file dataset.py.
array = array
Awkward array containing the dataset.
Definition at line 110 of file dataset.py.
batch_size = batch_size
Batch size for the iterable dataset.
Definition at line 112 of file dataset.py.
seed = seed if seed is not None else np.random.SeedSequence().entropy
Random seed for shuffling, consistent seed for all workers.
Definition at line 116 of file dataset.py.
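The default shown above can be tried in isolation. `np.random.SeedSequence().entropy` draws fresh entropy from the operating system and returns it as a Python integer, so resolving it once and storing it lets every consumer of that value reproduce the same random stream. `resolve_seed` is a hypothetical wrapper around the expression, added only for this sketch.

```python
import numpy as np


def resolve_seed(seed=None):
    """Keep a user-supplied seed, otherwise draw one from OS entropy."""
    return seed if seed is not None else np.random.SeedSequence().entropy


seed = resolve_seed()

# Two generators built from the stored seed produce the same order,
# which is what keeps separately seeded consumers consistent.
order_a = np.random.default_rng(seed).permutation(8)
order_b = np.random.default_rng(seed).permutation(8)
```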
shuffle = shuffle
Whether to shuffle the data.
Definition at line 114 of file dataset.py.
weighted = weighted
Whether the dataset includes weights.
Definition at line 118 of file dataset.py.