15def _update_config_dict(d, u):
17 Updates the config dictionary.
20 `recursive solution <https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth>`_
21 because it's a nested dict of varying depth.
23 d (dict): Dictionary to update.
24 u (dict): Dictionary of configs.
27 dict: Updated dictionary of configs.
29 for k, v
in u.items():
30 if isinstance(v, collections.abc.Mapping):
31 d[k] = _update_config_dict(d.get(k, {}), v)
def load_config(cfg_path=None, model=None, dataset=None, run_name=None, samples=None, **kwargs):
    """
    Load default configs followed by user configs and populate dataset tags.

    Args:
        cfg_path(str or Path): Path to user config yaml.
        model(str): Name of model to use (overwrites loaded config).
        dataset(int): Individual dataset to load (overwrites loaded config).
        run_name(str): Name of training run (overwrites loaded config).
        samples(int): Number of samples to train on (overwrites loaded config).
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        dict, list: Loaded training configuration dictionary
        and list of tuples containing (tag name, dataset path, tag key).
    """
    # Load the packaged default configuration first.
    with open(basf2.find_file('data/analysis/grafei_config.yaml')) as file:
        configs = yaml.safe_load(file)

    # Merge the user configuration on top of the defaults, if one was given.
    if cfg_path is not None:
        with open(cfg_path) as file:
            user_configs = yaml.safe_load(file)
        # An empty YAML file loads as None; treat it as "no overrides"
        # rather than failing inside the merge.
        if user_configs is not None:
            configs = _update_config_dict(configs, user_configs)

    # Explicit arguments take precedence over anything read from file,
    # but only when they were actually supplied.
    if model is not None:
        configs['train']['model'] = model
    # Set the datasets to load. If neither the argument nor the configs
    # specify one, the key is still created (with the value None).
    if (dataset is not None) or ('datasets' not in configs['dataset']):
        configs['dataset']['datasets'] = dataset
    if run_name is not None:
        configs['output']['run_name'] = run_name

    # Build the dataset tags; this also stores `samples` in the dataset config.
    tags = _generate_dataset_tags(configs, samples)

    return configs, tags
81def _generate_dataset_tags(configs, samples=None):
83 Generate the different dataset tags and assign their file paths.
84 This helps us keep track of the train/val datasets throughout training
and evaluation.
87 config (dict): Training configuration dictionary.
88 samples (dict): Number of training samples.
91 list: List of tuples containing (tag name, dataset path, tag key).
94 source_confs = configs[
'dataset']
98 (
"Training", source_confs[
'path'],
'train'),
99 (
"Validation", source_confs[
'path'],
'val'),
103 source_confs[
'config'][
'samples'] = samples