11 from pathlib
import Path
12 import collections.abc
def _update_config_dict(d, u):
    """
    Updates the config dictionary.

    Uses a `recursive solution <https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth>`_
    because it's a nested dict of varying depth: nested mappings in ``u`` are
    merged key by key into ``d``, while any other value in ``u`` replaces the
    corresponding value in ``d``.

    Args:
        d (dict): Dictionary to update. It is modified in place.
        u (dict): Dictionary of configs to merge on top of ``d``. ``None``
            (e.g. the result of loading an empty YAML file) is treated as
            "no overrides".

    Returns:
        dict: Updated dictionary of configs (the same object as ``d``).
    """
    if u is None:
        return d
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            base = d.get(k)
            # A missing key, or an existing non-mapping value (e.g. an empty
            # YAML key, which loads as None), cannot be merged into; start
            # from an empty dict so the nested override is applied whole.
            if not isinstance(base, collections.abc.Mapping):
                base = {}
            d[k] = _update_config_dict(base, v)
        else:
            d[k] = v
    return d
def load_config(cfg_path=None, model=None, dataset=None, run_name=None, samples=None, **kwargs):
    """
    Load default configs followed by user configs and populate dataset tags.

    The defaults are read from ``config.yaml`` next to this module. The user
    config (if given) is merged on top of them, then any explicitly passed
    arguments override the merged values.

    Args:
        cfg_path (str or Path): Path to user config yaml.
        model (str): Name of model to use (overwrites loaded config).
        dataset (int): Individual dataset to load (overwrites loaded config).
        run_name (str): Name of training run (overwrites loaded config).
        samples (int): Number of samples to train on (overwrites loaded config).
        **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        dict, list: Loaded training configuration dictionary
        and list of tuples containing (tag name, dataset path, tag key).
    """
    cwd = Path(__file__).resolve().parent

    # Default configs live alongside this module.
    with open(cwd / 'config.yaml', encoding='utf-8') as file:
        configs = yaml.safe_load(file)

    # User configs are merged on top of the defaults, key by key.
    if cfg_path is not None:
        with open(cfg_path, encoding='utf-8') as file:
            user_configs = yaml.safe_load(file)
        # An empty YAML file loads as None: nothing to merge.
        if user_configs:
            configs = _update_config_dict(configs, user_configs)

    # Explicit arguments override the loaded values only when actually given,
    # so omitting one keeps whatever the YAML files set.
    if model is not None:
        configs['train']['model'] = model

    # Also guarantee the 'datasets' key exists even when no override is given.
    if dataset is not None or 'datasets' not in configs['dataset']:
        configs['dataset']['datasets'] = dataset

    if run_name is not None:
        configs['output']['run_name'] = run_name

    tags = _generate_dataset_tags(configs, samples)

    return configs, tags
85 def _generate_dataset_tags(configs, samples=None):
87 Generate the different dataset tags and assign their file paths.
88 This helps us keep track of the train/val datasets throughout training and evaluation.
91 config (dict): Training configuration dictionary.
92 samples (dict): Number of training samples.
95 list: List of tuples containing (tag name, dataset path, tag key).
98 source_confs = configs[
'dataset']
102 (
"Training", source_confs[
'path'],
'train'),
103 (
"Validation", source_confs[
'path'],
'val'),
107 source_confs[
'config'][
'samples'] = samples