# Belle II Software development
# config.py

import basf2
import yaml
import collections.abc
14
15def _update_config_dict(d, u):
16 """
17 Updates the config dictionary.
18
19 .. seealso:: Need a
20 `recursive solution <https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth>`_
21 because it's a nested dict of varying depth.
22
23 Args:
24 d (dict): Dictionary to update.
25 u (dict): Dictionary of configs.
26
27 Returns:
28 dict: Updated dictionary of configs.
29 """
30 for k, v in u.items():
31 if isinstance(v, collections.abc.Mapping):
32 d[k] = _update_config_dict(d.get(k, {}), v)
33 else:
34 d[k] = v
35 return d
36
37
def load_config(cfg_path=None, model=None, dataset=None, run_name=None, samples=None, **kwargs):
    """
    Load default configs followed by user configs and populate dataset tags.

    Args:
        cfg_path(str or Path): Path to user config yaml.
        model(str): Name of model to use (overwrites loaded config).
        dataset(int): Individual dataset to load (overwrites loaded config).
        run_name(str): Name of training run (overwrites loaded config).
        samples(int): Number of samples to train on (overwrites loaded config).

    Returns:
        dict, list: Loaded training configuration dictionary
        and list of tuples containing (tag name, dataset path, tag key).
    """
    # The defaults shipped with basf2 form the base configuration.
    with open(basf2.find_file('data/analysis/grafei_config.yaml')) as f:
        configs = yaml.safe_load(f)

    # A user config, when given, is deep-merged on top of the defaults.
    # A plain dict.update() would replace whole nested sections, so the
    # recursive helper is used instead.
    if cfg_path is not None:
        with open(cfg_path) as f:
            user_configs = yaml.safe_load(f)
        configs = _update_config_dict(configs, user_configs)

    # Explicit keyword arguments take highest precedence.
    if model is not None:
        configs['train']['model'] = model
    # Select individual datasets; fall back to all (None) when unset.
    if dataset is not None or 'datasets' not in configs['dataset']:
        configs['dataset']['datasets'] = dataset
    if run_name is not None:
        configs['output']['run_name'] = run_name

    # Build the (tag name, dataset path, tag key) tuples used downstream.
    tags = _generate_dataset_tags(configs, samples)

    return configs, tags
80
81
82def _generate_dataset_tags(configs, samples=None):
83 """
84 Generate the different dataset tags and assign their file paths.
85 This helps us keep track of the train/val datasets throughout training and evaluation.
86
87 Args:
88 config (dict): Training configuration dictionary.
89 samples (dict): Number of training samples.
90
91 Returns:
92 list: List of tuples containing (tag name, dataset path, tag key).
93 """
94 # Fetch whichever data source we're loading
95 source_confs = configs['dataset']
96
97 # Set up appropriate dataset tags
98 tags = [
99 ("Training", source_confs['path'], 'train'),
100 ("Validation", source_confs['path'], 'val'),
101 ]
102
103 # And overwrite any data source specific configs
104 source_confs['config']['samples'] = samples
105
106 return tags