Source code for grafei.model.dataset_split
##########################################################################
# basf2 (Belle II Analysis Software Framework) #
# Author: The Belle II Collaboration #
# #
# See git log for contributors and copyright holders. #
# This file is licensed under LGPL-3.0, see LICENSE.md. #
##########################################################################
import torch_geometric
from pathlib import Path
from .geometric_datasets import GraphDataSet
[docs]
def create_dataloader_mode_tags(configs, tags, *, shuffle=True, drop_last=True):
    """
    Convenience function to create the dataset/dataloader for each mode tag (train/val) and return them.

    Args:
        configs (dict): Training configuration. The keys read are
            ``configs["output"]["run_name"]``, ``configs["dataset"]["config"]``
            (a mapping of extra keyword arguments forwarded to ``GraphDataSet``)
            and ``configs["train"]["batch_size"]``.
        tags (iterable): Iterable of ``(tag, path, mode)`` triples, e.g. one per
            train/val split. The dataset root for each entry is ``Path(path, mode)``.
        shuffle (bool): Forwarded to every ``DataLoader``. Defaults to ``True``,
            which matches the previously hard-coded behaviour.
        drop_last (bool): Forwarded to every ``DataLoader``. Defaults to ``True``,
            which matches the previously hard-coded behaviour. With ``True``, the
            final incomplete batch is discarded. For a validation split, this
            means some samples are never evaluated, so callers may want ``False``.

    Returns:
        dict: Maps each ``tag`` to a tuple ``(mode, dataset, dataloader)``.
        If the same ``tag`` appears more than once, the last entry wins.
    """
    mode_tags = {}

    for tag, path, mode in tags:
        dataset = GraphDataSet(
            root=Path(path, mode),
            run_name=configs["output"]["run_name"],
            **configs["dataset"]["config"],
        )

        print(
            f"{type(dataset).__name__} created for {mode} with {len(dataset)} samples\n"
        )

        dataloader = torch_geometric.loader.DataLoader(
            dataset,
            batch_size=configs["train"]["batch_size"],
            shuffle=shuffle,
            drop_last=drop_last,
        )

        mode_tags[tag] = (mode, dataset, dataloader)

    return mode_tags