Belle II Software development
dataset_split.py
1
8
9
10import torch_geometric
11from pathlib import Path
12from .geometric_datasets import GraphDataSet
13
14
def create_dataloader_mode_tags(configs, tags):
    """
    Convenience function to create the dataset/dataloader for each mode tag (train/val) and return them.

    Args:
        configs (dict): Training configuration. The entries read here are
            ``configs["output"]["run_name"]``, ``configs["dataset"]["config"]``
            (unpacked as keyword arguments of ``GraphDataSet``) and
            ``configs["train"]["batch_size"]``.
        tags (list): Iterable of ``(tag, path, mode)`` triples. The dataset of
            each entry is rooted at ``Path(path, mode)``.

    Returns:
        dict: Mode tag dictionary mapping each ``tag`` to a tuple of
        ``(mode, dataset, dataloader)``. If the same ``tag`` appears more than
        once, the last entry wins.
    """
    mode_tags = {}

    for tag, path, mode in tags:
        dataset = GraphDataSet(
            root=Path(path, mode),
            run_name=configs["output"]["run_name"],
            **configs["dataset"]["config"],
        )

        print(
            f"{type(dataset).__name__} created for {mode} with {len(dataset)} samples\n"
        )

        # NOTE(review): shuffling and dropping the last incomplete batch are
        # applied to every mode, not only to training -- confirm this is the
        # intended behaviour for validation data.
        dataloader = torch_geometric.loader.DataLoader(
            dataset,
            batch_size=configs["train"]["batch_size"],
            shuffle=True,
            drop_last=True,
        )

        mode_tags[tag] = (mode, dataset, dataloader)

    return mode_tags