Belle II Software development
create_trainer.py
import ignite
import torch
from torch_geometric.data import Batch
import numpy as np
import collections.abc
from datetime import datetime
from pathlib import Path
import yaml
from .metrics import PerfectLCA, PerfectEvent, PerfectMasses


class GraFEIIgniteTrainer:
    """
    Class to set up the ignite trainer and hold everything associated with it.

    :param model: The actual PyTorch model.
    :type model: `Model <https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html>`_
    :param optimizer: Optimizer used in training.
    :type optimizer: `Optimizer <https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer>`_
    :param loss_fn: Loss function.
    :type loss_fn: `Loss <https://pytorch.org/docs/stable/nn.html#loss-functions>`_
    :param device: Device to use.
    :type device: `Device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
    :param configs: Dictionary of run configs from the loaded YAML config file.
    :type configs: dict
    :param tags: Tags used to identify the train and validation evaluators, e.g. "Training", "Validation".
    :type tags: list
    :param scheduler: Learning rate scheduler.
    :type scheduler: `Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_
    :param ignore_index: Label index to ignore when calculating metrics, e.g. padding.
    :type ignore_index: int
    """
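
    # Minimal shape of the ``configs`` dictionary assumed by this class,
    # inferred from the accesses below (the actual config file may carry more keys):
    #   output:
    #     path: /path/to/outputs     # or null to disable checkpointing
    #     run_name: some_run_name
    #   train:
    #     mixed_precision: true
    #     early_stop_patience: 10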

    def __init__(
        self,
        model,
        optimizer,
        loss_fn,
        device,
        configs,
        tags,
        scheduler=None,
        ignore_index=-1.0,
    ):
        """
        Initialization.
        """
        self.model = model
        self.optimizer = optimizer
        self.configs = configs
        self.tags = tags
        self.ignore_index = ignore_index
        self.device = device

        # Run timestamp to distinguish trainings
        self.timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")

        # Output directory for checkpoints
        self.run_dir = None
        if self.configs["output"] is not None:
            if ("path" in self.configs["output"]) and (
                self.configs["output"]["path"] is not None
            ):
                self.run_dir = Path(
                    self.configs["output"]["path"],
                    self.configs["output"]["run_name"],
                )

        # Set up the ignite trainer
        use_amp = configs["train"]["mixed_precision"] and self.device == torch.device(
            "cuda"
        )

        if use_amp:
            from torch.cuda.amp import autocast, GradScaler

            scaler = GradScaler(enabled=True)

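        # Standard torch.cuda.amp usage below: autocast runs the forward pass in
        # reduced precision where it is safe to do so, while GradScaler scales the
        # loss to avoid float16 gradient underflow and unscales it again before
        # the optimizer step.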
        def _update_model(engine, batch):
            # Set the model to training mode
            model.train()

            optimizer.zero_grad()

            batch = (
                Batch.from_data_list(batch).to(device)
                if isinstance(batch, list)
                else batch.to(device)
            )

            x_y, edge_y, u_y = batch.x_y, batch.edge_y, batch.u_y

            if use_amp:
                with autocast(enabled=True):
                    x_pred, e_pred, u_pred = model(batch)
                    loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                x_pred, e_pred, u_pred = model(batch)
                loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
                loss.backward()
                optimizer.step()

            return loss.item()
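        # ignite calls _update_model once per batch; its return value (the batch
        # loss) is exposed as self.trainer.state.output.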
        self.trainer = ignite.engine.Engine(_update_model)

        if scheduler:
            # Wrap the PyTorch scheduler so that ignite advances it at the start
            # of every iteration
            ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(scheduler)
            self.trainer.add_event_handler(
                ignite.engine.Events.ITERATION_STARTED, ig_scheduler
            )

        # Set up the train and validation evaluators
        self.evaluators = {}

        for tag in self.tags:
            # Setup metrics
            metrics = {
                # ignite.metrics.Loss takes (y_pred, y, **kwargs) arguments.
                # MultiTrainLoss needs six arguments in total,
                # so the additional ones are passed in a dictionary.
                "loss": ignite.metrics.Loss(
                    loss_fn,
                    output_transform=lambda x: [
                        x[0],
                        x[3],
                        {
                            "edge_input": x[1],
                            "edge_target": x[4],
                            "u_input": x[2],
                            "u_target": x[5],
                        },
                    ],
                    device=device,
                ),
                "perfectLCA": PerfectLCA(
                    ignore_index=ignore_index,
                    device=device,
                    output_transform=lambda x: [
                        x[1], x[4], x[6], x[5], x[7], x[8],
                    ],
                ),
                "perfectMasses": PerfectMasses(
                    ignore_index=ignore_index,
                    device=device,
                    output_transform=lambda x: [x[0], x[3], x[5], x[7], x[8]],
                ),
                "perfectEvent": PerfectEvent(
                    ignore_index=ignore_index,
                    device=device,
                    output_transform=lambda x: [
                        x[0], x[3], x[1], x[4], x[6], x[5], x[7], x[8],
                    ],
                ),
            }

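            # Index map for the output_transform lambdas above, following the
            # tuple returned by _predict_on_batch below:
            #   0: x_pred, 1: e_pred, 2: u_pred, 3: x_y, 4: edge_y, 5: u_y,
            #   6: edge_index, 7: torch_batch, 8: num_graph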
            def _predict_on_batch(engine, batch):
                # Set the model to evaluation mode
                model.eval()

                batch = (
                    Batch.from_data_list(batch).to(device)
                    if isinstance(batch, list)
                    else batch.to(device)
                )

                x_y, edge_y, u_y, edge_index, torch_batch = (
                    batch.x_y,
                    batch.edge_y,
                    batch.u_y,
                    batch.edge_index,
                    batch.batch,
                )
                # The batch vector assigns each node to its graph, so its last
                # entry is (number of graphs - 1)
                num_graph = batch.batch[torch_batch.shape[0] - 1] + 1

                with torch.no_grad():
                    if use_amp:
                        with autocast(enabled=True):
                            x_pred, e_pred, u_pred = model(batch)
                    else:
                        x_pred, e_pred, u_pred = model(batch)

                return (
                    x_pred,
                    e_pred,
                    u_pred,
                    x_y,
                    edge_y,
                    u_y,
                    edge_index,
                    torch_batch,
                    num_graph,
                )

            self.evaluators[tag] = ignite.engine.Engine(_predict_on_batch)

            # Attach each metric so it is accumulated during every evaluator run
            # and reported in evaluator.state.metrics[metric_name]
            for metric_name, metric in metrics.items():
                metric.attach(self.evaluators[tag], metric_name)

    def _score_fn(self, engine):
        """Score used for early stopping: the negated loss, since
        ignite's EarlyStopping expects a score that increases."""
        return -engine.state.metrics["loss"]

    def _perfect_score_fn(self, engine):
        """Score used for checkpoints."""
        return engine.state.metrics["perfectEvent"]

    def _clean_config_dict(self, configs):
        """
        Clean configs to prepare them for writing to file, recursively
        converting numpy arrays into plain lists so that yaml.dump can
        serialize them.
        """
        for k, v in configs.items():
            if isinstance(v, collections.abc.Mapping):
                configs[k] = self._clean_config_dict(configs[k])
            elif isinstance(v, np.ndarray):
                configs[k] = v.tolist()
            else:
                configs[k] = v
        return configs
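
    # Example of _clean_config_dict: {"cut": np.array([0.5, 0.9])} becomes
    # {"cut": [0.5, 0.9]}, with nested dicts handled recursively and all
    # other values passed through unchanged.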

    def setup_handlers(self, cfg_filename="config.yaml"):
        """
        Creates the various ignite handlers (callbacks).

        Args:
            cfg_filename (str): Name of the config yaml file to use when saving the configs.
        """
        # Create the output directory
        if self.run_dir is not None:
            self.run_dir.mkdir(parents=True, exist_ok=True)
            # And save the configs; done here so they are only saved when
            # checkpointing is set up
            with open(
                self.run_dir / f"{self.timestamp}_{cfg_filename}", "w"
            ) as outfile:
                cleaned_configs = self._clean_config_dict(self.configs)
                yaml.dump(cleaned_configs, outfile, default_flow_style=False)

        # Set up early stopping
        early_handler = ignite.handlers.EarlyStopping(
            patience=self.configs["train"]["early_stop_patience"],
            score_function=self._score_fn,
            trainer=self.trainer,
            min_delta=1e-3,
        )
        self.evaluators["Validation"].add_event_handler(
            ignite.engine.Events.EPOCH_COMPLETED, early_handler
        )

        # Configure saving the best performing model
        if self.run_dir is not None:
            to_save = {
                "model": self.model,
                "optimizer": self.optimizer,
                "trainer": self.trainer,
            }
            # Note that early stopping above is judged by the validation loss, while
            # the best model is saved according to the validation perfectEvent score.
            # This lets training continue through perfectEvent plateaus as long as the
            # model is still changing (and hopefully improving again after some time).
            best_model_handler = ignite.handlers.Checkpoint(
                to_save=to_save,
                save_handler=ignite.handlers.DiskSaver(
                    self.run_dir, create_dir=True, require_empty=False
                ),
                filename_prefix=self.timestamp,
                score_function=self._perfect_score_fn,
                score_name="validation_perfectEvent",
                n_saved=1,
                global_step_transform=ignite.handlers.global_step_from_engine(
                    self.evaluators["Validation"]
                ),
            )
            self.evaluators["Validation"].add_event_handler(
                ignite.engine.Events.EPOCH_COMPLETED, best_model_handler
            )

        return

    # Set up the end-of-epoch validation procedure:
    # run the evaluators and print the epoch results
    def log_results(self, trainer, mode_tags):
        """
        Callback to run evaluation and report the results.

        :param trainer: Trainer passed by ignite to this method.
        :type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_
        :param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
        :type mode_tags: dict
        """

        for tag, values in mode_tags.items():
            evaluator = self.evaluators[tag]

            # The run needs to be wrapped in autocast since the metrics (i.e. the loss)
            # are calculated without autocast switched on. This is mostly fine, except
            # that it fails to correctly cast the class weights tensor passed to the loss.
            if self.configs["train"]["mixed_precision"] and self.device == torch.device(
                "cuda"
            ):
                with torch.cuda.amp.autocast():
                    evaluator.run(values[2], epoch_length=None)
            else:
                evaluator.run(values[2], epoch_length=None)

            metrics = evaluator.state.metrics
            message = [f"{tag} Results - Epoch: {trainer.state.epoch}"]
            message.extend([f"Avg {m}: {metrics[m]:.4f}" for m in metrics])
            print(", ".join(message))
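
For orientation, here is a minimal usage sketch. All names not defined in this file (MyModel, my_loss_fn, the datasets and dataloaders, and the config values) are hypothetical placeholders; the class only requires that loss_fn accepts (x_pred, x_y, e_pred, edge_y, u_pred, u_y) and that the tags include "Validation" before setup_handlers is called.

import ignite
import torch

model = MyModel()  # hypothetical model returning (x_pred, e_pred, u_pred)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

configs = {
    "output": {"path": "./runs", "run_name": "example_run"},
    "train": {"mixed_precision": False, "early_stop_patience": 10},
}

trainer = GraFEIIgniteTrainer(
    model=model,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    loss_fn=my_loss_fn,  # hypothetical, e.g. the MultiTrainLoss mentioned above
    device=device,
    configs=configs,
    tags=["Training", "Validation"],
)
trainer.setup_handlers()

# (mode, dataset, dataloader) tuples; log_results reads the dataloader at index 2
mode_tags = {
    "Training": ("train", train_dataset, train_loader),  # hypothetical
    "Validation": ("val", val_dataset, val_loader),  # hypothetical
}
trainer.trainer.add_event_handler(
    ignite.engine.Events.EPOCH_COMPLETED, trainer.log_results, mode_tags
)

trainer.trainer.run(train_loader, max_epochs=50)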