# Belle II Software  light-2403-persian
# config.py
1 
8 
9 
10 import yaml
11 from pathlib import Path
12 import collections.abc
13 
14 
15 def _update_config_dict(d, u):
16  """
17  Updates the config dictionary.
18 
19  .. seealso:: Need a
20  `recursive solution <https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth>`_
21  because it's a nested dict of varying depth.
22 
23  Args:
24  d (dict): Dictionary to update.
25  u (dict): Dictionary of configs.
26 
27  Returns:
28  dict: Updated dictionary of configs.
29  """
30  for k, v in u.items():
31  if isinstance(v, collections.abc.Mapping):
32  d[k] = _update_config_dict(d.get(k, {}), v)
33  else:
34  d[k] = v
35  return d
36 
37 
def load_config(cfg_path=None, model=None, dataset=None, run_name=None, samples=None, **kwargs):
    """
    Load default configs followed by user configs and populate dataset tags.

    Args:
        cfg_path (str or Path): Path to user config yaml.
        model (str): Name of model to use (overwrites loaded config).
        dataset (int): Individual dataset to load (overwrites loaded config).
        run_name (str): Name of training run (overwrites loaded config).
        samples (int): Number of samples to train on (overwrites loaded config).

    Returns:
        dict, list: Loaded training configuration dictionary
        and list of tuples containing (tag name, dataset path, tag key).
    """
    # The default config.yaml ships alongside this module
    here = Path(__file__).resolve().parent

    # Start from the packaged defaults
    with open(here / 'config.yaml') as f:
        configs = yaml.safe_load(f)

    # Merge user configuration on top, if one was given.  A plain
    # dict.update() would clobber whole nested sections, so use the
    # depth-aware merge helper instead.
    if cfg_path is not None:
        with open(cfg_path) as f:
            configs = _update_config_dict(configs, yaml.safe_load(f))

    # Command-line overrides take precedence over anything loaded from file
    if model is not None:
        configs['train']['model'] = model
    # Select the requested dataset; when nothing was requested and the loaded
    # config does not name one either, store None (meaning: load all datasets)
    if dataset is not None or 'datasets' not in configs['dataset']:
        configs['dataset']['datasets'] = dataset
    if run_name is not None:
        configs['output']['run_name'] = run_name

    # Finally, derive the train/val dataset tags from the merged configuration
    return configs, _generate_dataset_tags(configs, samples)
83 
84 
85 def _generate_dataset_tags(configs, samples=None):
86  """
87  Generate the different dataset tags and assign their file paths.
88  This helps us keep track of the train/val datasets throughout training and evaluation.
89 
90  Args:
91  config (dict): Training configuration dictionary.
92  samples (dict): Number of training samples.
93 
94  Returns:
95  list: List of tuples containing (tag name, dataset path, tag key).
96  """
97  # Fetch whichever data source we're loading
98  source_confs = configs['dataset']
99 
100  # Set up appropriate dataset tags
101  tags = [
102  ("Training", source_confs['path'], 'train'),
103  ("Validation", source_confs['path'], 'val'),
104  ]
105 
106  # And overwrite any data source specific configs
107  source_confs['config']['samples'] = samples
108 
109  return tags