Belle II Software  light-2403-persian
dataset_split.py
1 
8 
9 
10 import torch_geometric
11 from pathlib import Path
12 from .geometric_datasets import GraphDataSet
13 
14 
def create_dataloader_mode_tags(configs, tags):
    """
    Convenience function to create the dataset/dataloader for each mode tag (train/val) and return them.

    Args:
        configs (dict): Training configuration. The keys read here are
            ``configs["output"]["run_name"]``, ``configs["dataset"]["config"]``
            (unpacked as keyword arguments for ``GraphDataSet``) and
            ``configs["train"]["batch_size"]``.
        tags (list): Iterable of ``(tag, path, mode)`` triples. ``tag`` becomes the key
            in the returned dictionary, and ``Path(path, mode)`` is used as the dataset root.

    Returns:
        dict: Mode tag dictionary mapping each ``tag`` to a tuple of
        ``(mode, dataset, dataloader)``. Empty if ``tags`` is empty.
    """

    mode_tags = {}

    for tag, path, mode in tags:
        # Each mode gets its own dataset rooted at <path>/<mode>.
        dataset = GraphDataSet(
            root=Path(path, mode),
            run_name=configs["output"]["run_name"],
            **configs["dataset"]["config"],
        )

        print(
            f"{type(dataset).__name__} created for {mode} with {len(dataset)} samples\n"
        )

        # NOTE(review): shuffling and dropping the last incomplete batch are applied
        # to every entry in ``tags``, not only to the training one -- confirm this is
        # intended for validation as well.
        dataloader = torch_geometric.loader.DataLoader(
            dataset,
            batch_size=configs["train"]["batch_size"],
            shuffle=True,
            drop_last=True,
        )

        mode_tags[tag] = (mode, dataset, dataloader)

    return mode_tags