# Belle II Software development
# config.py

import basf2
import yaml
import collections.abc
14
15def _update_config_dict(d, u):
16 """
17 Updates the config dictionary.
18
19 .. seealso:: Need a
20 `recursive solution <https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth>`_
21 because it's a nested dict of varying depth.
22 Args:
23 d (dict): Dictionary to update.
24 u (dict): Dictionary of configs.
25
26 Returns:
27 dict: Updated dictionary of configs.
28 """
29 for k, v in u.items():
30 if isinstance(v, collections.abc.Mapping):
31 d[k] = _update_config_dict(d.get(k, {}), v)
32 else:
33 d[k] = v
34 return d
35
36
def load_config(cfg_path=None, model=None, dataset=None, run_name=None, samples=None, **kwargs):
    """
    Load default configs followed by user configs and populate dataset tags.

    Args:
        cfg_path(str or Path): Path to user config yaml.
        model(str): Name of model to use (overwrites loaded config).
        dataset(int): Individual dataset to load (overwrites loaded config).
        run_name(str): Name of training run (overwrites loaded config).
        samples(int): Number of samples to train on (overwrites loaded config).

    Returns:
        dict, list: Loaded training configuration dictionary
        and list of tuples containing (tag name, dataset path, tag key).
    """
    # Start from the shipped defaults
    with open(basf2.find_file('data/analysis/grafei_config.yaml')) as f:
        configs = yaml.safe_load(f)

    # Merge the user's yaml on top of the defaults (nested-dict aware,
    # so a plain dict.update would clobber whole sub-sections)
    if cfg_path is not None:
        with open(cfg_path) as f:
            user_configs = yaml.safe_load(f)
        configs = _update_config_dict(configs, user_configs)

    # Command-line style overrides take precedence over any file
    if model is not None:
        configs['train']['model'] = model

    # Select the dataset(s) to load; when neither the caller nor the config
    # specifies one, store None which means "load all"
    if dataset is not None or 'datasets' not in configs['dataset']:
        configs['dataset']['datasets'] = dataset

    if run_name is not None:
        configs['output']['run_name'] = run_name

    # Build the (tag name, path, key) tuples used throughout training/eval
    return configs, _generate_dataset_tags(configs, samples)
79
80
81def _generate_dataset_tags(configs, samples=None):
82 """
83 Generate the different dataset tags and assign their file paths.
84 This helps us keep track of the train/val datasets throughout training and evaluation.
85
86 Args:
87 config (dict): Training configuration dictionary.
88 samples (dict): Number of training samples.
89
90 Returns:
91 list: List of tuples containing (tag name, dataset path, tag key).
92 """
93 # Fetch whichever data source we're loading
94 source_confs = configs['dataset']
95
96 # Set up appropriate dataset tags
97 tags = [
98 ("Training", source_confs['path'], 'train'),
99 ("Validation", source_confs['path'], 'val'),
100 ]
101
102 # And overwrite any data source specific configs
103 source_confs['config']['samples'] = samples
104
105 return tags
106