11from pathlib 
import Path
 
   12from .geometric_datasets 
import GraphDataSet
 
   15def create_dataloader_mode_tags(configs, tags):
 
   17    Convenience function to create the dataset/dataloader for each mode tag (train/val) and return them. 
   20        configs (dict): Training configuration. 
   21        tags (list): Mode tags train/val containing dataset paths. 
   24        dict: Mode tag dictionary containing tuples of (mode, dataset, dataloader). 
   29    for tag, path, mode 
in tags:
 
   30        dataset = GraphDataSet(
 
   31            root=Path(path, mode),
 
   32            run_name=configs[
"output"][
"run_name"],
 
   33            **configs[
"dataset"][
"config"],
 
   37            f
"{type(dataset).__name__} created for {mode} with {dataset.__len__()} samples\n" 
   40        dataloader = torch_geometric.loader.DataLoader(
 
   41            dataset, batch_size=configs[
"train"][
"batch_size"],
 
   46        mode_tags[tag] = (mode, dataset, dataloader)