11from pathlib
import Path
12from .geometric_datasets
import GraphDataSet
def create_dataloader_mode_tags(configs: dict, tags: list) -> dict:
    """Create the dataset and dataloader for each mode tag and return them.

    Convenience function that, for every ``(tag, path, mode)`` entry in
    ``tags``, builds a ``GraphDataSet`` rooted at ``Path(path, mode)`` and
    wraps it in a ``torch_geometric.loader.DataLoader``.

    Args:
        configs (dict): Training configuration. The keys read here are
            ``configs["output"]["run_name"]``,
            ``configs["dataset"]["config"]`` (forwarded to ``GraphDataSet``
            as keyword arguments) and ``configs["train"]["batch_size"]``.
        tags (list): Iterable of ``(tag, path, mode)`` triples (e.g. one per
            train/val split) containing the dataset paths.

    Returns:
        dict: Maps each ``tag`` to a ``(mode, dataset, dataloader)`` tuple.
            If a ``tag`` appears more than once, the later entry overwrites
            the earlier one.
    """
    mode_tags = {}
    for tag, path, mode in tags:
        # The dataset for this split lives in the ``mode`` subdirectory of
        # ``path``; any extra dataset options come straight from the config.
        dataset = GraphDataSet(
            root=Path(path, mode),
            run_name=configs["output"]["run_name"],
            **configs["dataset"]["config"],
        )
        print(
            f"{type(dataset).__name__} created for {mode} with {len(dataset)} samples\n"
        )

        # NOTE(review): only batch_size is configured for the loader; confirm
        # whether the training split should also be shuffled.
        dataloader = torch_geometric.loader.DataLoader(
            dataset,
            batch_size=configs["train"]["batch_size"],
        )
        mode_tags[tag] = (mode, dataset, dataloader)

    return mode_tags