Belle II Software development
Public Member Functions
  def __init__(self, array, batch_size=1024, shuffle=True, seed=None, weighted=False)
  def __len__(self)
  def maybe_permuted(self, array)
  def __iter__(self)
  def __getitem__(self, slicer)

Static Public Member Functions
  def to_tensor(array)

Public Attributes
  array: Awkward array containing the dataset.
  batch_size: Batch size for the iterable dataset.
  shuffle: Whether to shuffle the data.
  seed: Random seed for shuffling, consistent seed for all workers.
  weighted: Whether the dataset includes weights.
Dataset initialized from a pre-processed awkward array. To iterate through it, use `torch.utils.data.DataLoader` with `batch_size=1` and `collate_fn=lambda x: x[0]`. Each iteration yields a tuple of a batched DGL graph and its labels, plus the weights if `weighted=True`; this requires a `weight` column in the array.
Definition at line 81 of file dataset.py.
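Why `batch_size=1` and `collate_fn=lambda x: x[0]`: with `batch_size=1`, `DataLoader` takes one item at a time from the dataset and passes it to `collate_fn` inside a one-element list. Since every item this dataset yields is already a complete batch, the lambda only unwraps that list. The pure-Python sketch below mimics this behaviour without torch. `toy_batches` and `loader_batch_size_1` are hypothetical stand-ins, not part of PyTorch or Belle II, and a placeholder string stands in for the batched DGL graph.

```python
from typing import Any, Callable, Iterable, Iterator, List, Tuple


def toy_batches() -> Iterator[Tuple[str, List[int]]]:
    # Hypothetical stand-in for ArrayDataset: every item is already a full batch
    # (a placeholder string takes the place of the batched DGL graph).
    yield ("graph_batch_0", [0, 1, 1])
    yield ("graph_batch_1", [1, 0])


def loader_batch_size_1(
    dataset: Iterable[Any], collate_fn: Callable[[List[Any]], Any]
) -> Iterator[Any]:
    # Hypothetical helper mimicking DataLoader(dataset, batch_size=1, collate_fn=...)
    # in a single process: each item is wrapped in a one-element list, then collated.
    for item in dataset:
        yield collate_fn([item])


# collate_fn=lambda x: x[0] unwraps the one-element list, so each pre-built
# batch reaches the training loop unchanged.
batches = list(loader_batch_size_1(toy_batches(), collate_fn=lambda x: x[0]))
for graph, labels in batches:
    print(graph, labels)
```

With the real classes, the corresponding call is `DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])`.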
def __init__(self, array, batch_size=1024, shuffle=True, seed=None, weighted=False)
Initialize the ArrayDataset for PyTorch training and inference.
Arguments:
  array: Awkward array containing the dataset.
  batch_size (int): Batch size for the iterable dataset.
  shuffle (bool): Whether to shuffle the data.
  seed: Random seed for shuffling.
  weighted (bool): Whether the dataset includes weights.
Definition at line 92 of file dataset.py.
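The sketch below mirrors the documented constructor defaults and shows one way to enforce the `weight` column requirement stated in the class description. It is a hypothetical stand-in: `ToyArrayDataset` takes a list of field names instead of an awkward array (for an awkward record array the names are available as `array.fields`), and whether the real `__init__` performs these checks is an assumption.

```python
from typing import Optional, Sequence


class ToyArrayDataset:
    # Hypothetical stand-in mirroring the documented constructor arguments;
    # `fields` replaces the awkward array (for a record array, see `array.fields`).

    def __init__(
        self,
        fields: Sequence[str],
        batch_size: int = 1024,
        shuffle: bool = True,
        seed: Optional[int] = None,
        weighted: bool = False,
    ) -> None:
        # Assumption: fail early if weighted=True but no `weight` column exists,
        # which the class description says is required.
        if weighted and "weight" not in fields:
            raise ValueError("weighted=True requires a 'weight' column")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        self.fields = list(fields)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.weighted = weighted


ds = ToyArrayDataset(["x", "label", "weight"], batch_size=256, weighted=True)
print(ds.batch_size, ds.shuffle, ds.weighted)
```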
def __getitem__(self, slicer)
Get a single instance, or a new ArrayDataset for a slice.
Arguments:
  slicer (int or slice): Index or slice.
Returns:
  A single instance for an integer index, or a new ArrayDataset instance for a slice.
Definition at line 190 of file dataset.py.
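Slicing is convenient for splitting one dataset into training and validation parts. In the minimal sketch below, a plain list replaces the awkward array. `ToySliceableDataset` is hypothetical, and the sketch assumes that a slice keeps `batch_size`, `shuffle`, `seed` and `weighted` from the parent, which the documentation above does not state explicitly.

```python
from typing import Any, List, Optional, Union


class ToySliceableDataset:
    # Hypothetical stand-in: a plain list replaces the awkward array.

    def __init__(self, array: List[Any], batch_size: int = 1024,
                 shuffle: bool = True, seed: Optional[int] = None,
                 weighted: bool = False) -> None:
        self.array = array
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.weighted = weighted

    def __getitem__(self, slicer: Union[int, slice]) -> Any:
        if isinstance(slicer, int):
            # Integer index: return the single entry.
            return self.array[slicer]
        # Slice: wrap the sub-array in a new dataset that keeps the same settings.
        return ToySliceableDataset(
            self.array[slicer],
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            seed=self.seed,
            weighted=self.weighted,
        )


full = ToySliceableDataset(list(range(100)), batch_size=16, seed=3)
train, valid = full[:80], full[80:]
print(len(train.array), len(valid.array), train.batch_size)  # 80 20 16
```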
def __iter__(self)
Iterate over batches with changing random seeds.
Yields:
  tuple: Batched DGL graph, labels, and optionally weights.
Definition at line 160 of file dataset.py.
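One way to get "changing random seeds" that stay reproducible is to fix a base seed once and offset it by a pass counter on every call to `__iter__`. The single-process sketch below assumes that scheme. `ToyEpochIterator` is hypothetical, integers stand in for events, and how the real class coordinates its seed across `DataLoader` workers is not shown here.

```python
import random
from typing import Iterator, List, Optional


class ToyEpochIterator:
    # Hypothetical single-process sketch: reshuffle with a new, reproducible
    # seed on every pass over the data.

    def __init__(self, data: List[int], batch_size: int,
                 shuffle: bool = True, seed: Optional[int] = None) -> None:
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Fix the base seed once, so any copy of this object made afterwards
        # starts from the same value.
        self.seed = seed if seed is not None else random.randrange(2**32)
        self.epoch = 0

    def __iter__(self) -> Iterator[List[int]]:
        order = list(range(len(self.data)))
        if self.shuffle:
            # Assumption: derive this pass's seed from the base seed and a pass counter.
            random.Random(self.seed + self.epoch).shuffle(order)
        self.epoch += 1
        # Yield consecutive chunks; the last batch may be smaller.
        for start in range(0, len(order), self.batch_size):
            yield [self.data[i] for i in order[start:start + self.batch_size]]


loader = ToyEpochIterator(list(range(10)), batch_size=4, seed=0)
epoch_1 = list(loader)
epoch_2 = list(loader)
print([len(b) for b in epoch_1])  # [4, 4, 2]
```

Re-creating the object with the same `seed` reproduces the same sequence of orderings, while consecutive passes over one object use different seeds.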
def __len__(self)
Get the number of batches.
Returns:
  int: Number of batches.
Definition at line 120 of file dataset.py.
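If the last, partially filled batch is kept (an assumption; the documentation does not say whether it is dropped), the number of batches is the ceiling of the number of entries divided by `batch_size`. The hypothetical helper below computes it with integer arithmetic, which avoids floating-point rounding for large counts.

```python
def num_batches(num_entries: int, batch_size: int) -> int:
    # Ceiling division in integer arithmetic: a final, partially filled batch
    # counts as one batch (assumption; dropping it would use floor division).
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    return (num_entries + batch_size - 1) // batch_size


print(num_batches(10_000, 1024))  # 10
```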
def maybe_permuted(self, array)
Permute the array if the shuffle parameter is set; otherwise return it unchanged.
Arguments:
  array (awkward array): Input array.
Returns:
  array: Permuted or original array.
Definition at line 129 of file dataset.py.
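The sketch below illustrates the shuffle-or-pass-through logic as a free function on a plain list. `maybe_permuted` here is a hypothetical stand-in: the real method reads `shuffle` (and presumably `seed`) from the dataset, and using Python's `random` module is an assumption.

```python
import random
from typing import List, Optional, TypeVar

T = TypeVar("T")


def maybe_permuted(array: List[T], shuffle: bool, seed: Optional[int]) -> List[T]:
    # Hypothetical free-function version: the real method reads shuffle (and
    # presumably the seed) from the dataset instead of taking them as arguments.
    if not shuffle:
        return array  # unchanged, no copy
    order = list(range(len(array)))
    random.Random(seed).shuffle(order)
    # Reorder by index, analogous to indexing an awkward array with an integer array.
    return [array[i] for i in order]


events = ["e0", "e1", "e2", "e3", "e4"]
print(maybe_permuted(events, shuffle=True, seed=7))
```

With awkward and NumPy, the same idea could be written as `array[np.random.default_rng(seed).permutation(len(array))]`, since awkward arrays accept integer-array indices.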
def to_tensor(array) [static]
Convert an awkward array to a torch tensor.
Arguments:
  array (awkward array): Input awkward array.
Returns:
  torch.Tensor: Converted tensor.
Definition at line 145 of file dataset.py.
array
Awkward array containing the dataset.
Definition at line 110 of file dataset.py.
batch_size
Batch size for the iterable dataset.
Definition at line 112 of file dataset.py.
seed
Random seed for shuffling, consistent seed for all workers.
Definition at line 116 of file dataset.py.
shuffle
Whether to shuffle the data.
Definition at line 114 of file dataset.py.
weighted
Whether the dataset includes weights.
Definition at line 118 of file dataset.py.