##########################################################################
# basf2 (Belle II Analysis Software Framework)                           #
# Author: The Belle II Collaboration                                     #
#                                                                        #
# See git log for contributors and copyright holders.                    #
# This file is licensed under LGPL-3.0, see LICENSE.md.                  #
##########################################################################

import ignite
import torch
from torch_geometric.data import Batch
import numpy as np
import collections.abc
from datetime import datetime
from pathlib import Path
import yaml
from .metrics import PerfectLCA, PerfectEvent, PerfectMasses


class GraFEIIgniteTrainer:
    """
    Class to set up the ignite trainer and hold everything associated with it.

    :param model: The actual PyTorch model.
    :type model: `Model <https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html>`_
    :param optimizer: Optimizer used in training.
    :type optimizer: `Optimizer <https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer>`_
    :param loss_fn: Loss function.
    :type loss_fn: `Loss <https://pytorch.org/docs/stable/nn.html#loss-functions>`_
    :param device: Device to use.
    :type device: `Device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
    :param configs: Dictionary of run configs from the loaded yaml config file.
    :type configs: dict
    :param tags: Various tags to sort train and validation evaluators by, e.g. "Training", "Validation".
    :type tags: list
    :param scheduler: Learning rate scheduler.
    :type scheduler: `Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_
    :param ignore_index: Label index to ignore when calculating metrics, e.g. padding.
    :type ignore_index: int
    """

    def __init__(
        self,
        model,
        optimizer,
        loss_fn,
        device,
        configs,
        tags,
        scheduler=None,
        ignore_index=-1.0,
    ):
        """
        Initialization.
        """
        self.model = model

        self.optimizer = optimizer

        self.configs = configs

        self.tags = tags

        self.ignore_index = ignore_index

        self.device = device

        self.timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")

        self.run_dir = None
        if self.configs["output"] is not None:
            if ("path" in self.configs["output"]) and (
                self.configs["output"]["path"] is not None
            ):
                self.run_dir = Path(
                    self.configs["output"]["path"],
                    self.configs["output"]["run_name"],
                )

        # Setup ignite trainer
        use_amp = configs["train"]["mixed_precision"] and self.device == torch.device(
            "cuda"
        )

        if use_amp:
            from torch.cuda.amp import autocast
            from torch.cuda.amp import GradScaler

            scaler = GradScaler(enabled=True)
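            # The scaler multiplies the loss before backward() so that float16
            # gradients do not underflow; scaler.step() unscales the gradients
            # again before the optimizer update.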

        def _update_model(engine, batch):
            # This just sets the training mode
            model.train()

            optimizer.zero_grad()

            batch = (
                Batch.from_data_list(batch).to(device)
                if isinstance(batch, list)
                else batch.to(device)
            )

            x_y, edge_y, u_y = batch.x_y, batch.edge_y, batch.u_y

            if use_amp:
                with autocast(enabled=True):
                    x_pred, e_pred, u_pred = model(batch)
                    loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                x_pred, e_pred, u_pred = model(batch)
                loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
                loss.backward()
                optimizer.step()

            return loss.item()

        self.trainer = ignite.engine.Engine(_update_model)

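        # Wrap the plain PyTorch scheduler so that ignite advances the
        # learning rate at the start of every training iteration.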
        if scheduler:
            ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(scheduler)
            self.trainer.add_event_handler(
                ignite.engine.Events.ITERATION_STARTED, ig_scheduler
            )

        self.evaluators = {}

        for tag in self.tags:
            # Setup metrics
            metrics = {
                # ignite.metrics.Loss takes (y_pred, y, **kwargs) arguments.
                # MultiTrainLoss needs in total six arguments to be computed,
                # so the additional ones are passed in a dictionary.
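                # The indices used in the output_transform lambdas refer to the
                # tuple returned by _predict_on_batch below: 0=x_pred, 1=e_pred,
                # 2=u_pred, 3=x_y, 4=edge_y, 5=u_y, 6=edge_index, 7=torch_batch,
                # 8=num_graph.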
                "loss": ignite.metrics.Loss(
                    loss_fn,
                    output_transform=lambda x: [
                        x[0], x[3],
                        {
                            "edge_input": x[1],
                            "edge_target": x[4],
                            "u_input": x[2],
                            "u_target": x[5],
                        },
                    ],
                    device=device,
                ),
                "perfectLCA": PerfectLCA(
                    ignore_index=ignore_index,
                    device=device,
                    output_transform=lambda x: [
                        x[1], x[4], x[6], x[5], x[7], x[8],
                    ],
                ),
                "perfectMasses": PerfectMasses(
                    ignore_index=ignore_index,
                    device=device,
                    output_transform=lambda x: [x[0], x[3], x[5], x[7], x[8]],
                ),
                "perfectEvent": PerfectEvent(
                    ignore_index=ignore_index,
                    device=device,
                    output_transform=lambda x: [
                        x[0], x[3], x[1], x[4], x[6], x[5], x[7], x[8],
                    ],
                ),
            }

            def _predict_on_batch(engine, batch):
                model.eval()  # It just enables evaluation mode

                batch = (
                    Batch.from_data_list(batch).to(device)
                    if isinstance(batch, list)
                    else batch.to(device)
                )

                x_y, edge_y, u_y, edge_index, torch_batch = (
                    batch.x_y,
                    batch.edge_y,
                    batch.u_y,
                    batch.edge_index,
                    batch.batch,
                )
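                # batch.batch assigns each node to its graph, so the last entry
                # plus one gives the number of graphs in the batch.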
                num_graph = batch.batch[torch_batch.shape[0] - 1] + 1

                with torch.no_grad():
                    if use_amp:
                        with autocast(enabled=True):
                            x_pred, e_pred, u_pred = model(batch)
                    else:
                        x_pred, e_pred, u_pred = model(batch)

                return (
                    x_pred,
                    e_pred,
                    u_pred,
                    x_y,
                    edge_y,
                    u_y,
                    edge_index,
                    torch_batch,
                    num_graph,
                )

            self.evaluators[tag] = ignite.engine.Engine(_predict_on_batch)

            for metric_name, metric in metrics.items():
                metric.attach(self.evaluators[tag], metric_name)

    def _score_fn(self, engine):
        """Metric to use for early stopping"""
        # ignite's EarlyStopping counts a higher score as an improvement,
        # so return the negative of the validation loss.
        return -engine.state.metrics["loss"]

    def _perfect_score_fn(self, engine):
        """Metric to use for checkpoints"""
        return engine.state.metrics["perfectEvent"]

    def _clean_config_dict(self, configs):
        """
        Clean configs to prepare them for writing to file.
        """
        for k, v in configs.items():
            if isinstance(v, collections.abc.Mapping):
                configs[k] = self._clean_config_dict(configs[k])
            elif isinstance(v, np.ndarray):
                configs[k] = v.tolist()
            else:
                configs[k] = v
        return configs

    def setup_handlers(self, cfg_filename="config.yaml"):
        """
        Creates the various ignite handlers (callbacks).

        Args:
            cfg_filename (str): Name of config yaml file to use when saving configs.
        """
        # Create the output directory
        if self.run_dir is not None:
            self.run_dir.mkdir(parents=True, exist_ok=True)
            # And save the configs; done here so they are only written when checkpointing is set up
            with open(
                self.run_dir / f"{self.timestamp}_{cfg_filename}", "w"
            ) as outfile:
                cleaned_configs = self._clean_config_dict(self.configs)
                yaml.dump(cleaned_configs, outfile, default_flow_style=False)

        # Setup early stopping
        early_handler = ignite.handlers.EarlyStopping(
            patience=self.configs["train"]["early_stop_patience"],
            score_function=self._score_fn,
            trainer=self.trainer,
            min_delta=1e-3,
        )
        self.evaluators["Validation"].add_event_handler(
            ignite.engine.Events.EPOCH_COMPLETED, early_handler
        )

        # Configure saving the best performing model
        if self.run_dir is not None:
            to_save = {
                "model": self.model,
                "optimizer": self.optimizer,
                "trainer": self.trainer,
            }
            # Note that early stopping above is judged by the validation loss, but the best
            # model is saved according to the validation perfectEvent score. This lets training
            # continue through perfectEvent plateaus as long as the model is still changing
            # (and hopefully improving again after some time).
            best_model_handler = ignite.handlers.Checkpoint(
                to_save=to_save,
                save_handler=ignite.handlers.DiskSaver(
                    self.run_dir, create_dir=True, require_empty=False
                ),
                filename_prefix=self.timestamp,
                score_function=self._perfect_score_fn,
                score_name="validation_perfectEvent",
                n_saved=1,
                global_step_transform=ignite.handlers.global_step_from_engine(
                    self.evaluators["Validation"]
                ),
            )
            self.evaluators["Validation"].add_event_handler(
                ignite.engine.Events.EPOCH_COMPLETED, best_model_handler
            )

        return

    # Set up the end-of-epoch validation procedure:
    # tell it to print the epoch results.
    def log_results(self, trainer, mode_tags):
        """
        Callback to run evaluation and report the results.

        :param trainer: Trainer passed by ignite to this method.
        :type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_
        :param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
        :type mode_tags: dict
        """

        for tag, values in mode_tags.items():
            evaluator = self.evaluators[tag]

            # Need to wrap evaluation in autocast since otherwise the metrics (i.e. the loss)
            # are calculated with autocast switched off. This is mostly fine, except that it
            # fails to correctly cast the class weights tensor passed to the loss.
            if self.configs["train"]["mixed_precision"] and self.device == torch.device("cuda"):
                with torch.cuda.amp.autocast():
                    evaluator.run(values[2], epoch_length=None)
            else:
                evaluator.run(values[2], epoch_length=None)

            metrics = evaluator.state.metrics
            message = [f"{tag} Results - Epoch: {trainer.state.epoch}"]
            message.extend([f"Avg {m}: {metrics[m]:.4f}" for m in metrics])
            print(", ".join(message))
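

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not executed as part of this
# module). `GraphModel`, `multi_train_loss`, `configs`, `train_loader` and
# `val_loader` are hypothetical stand-ins; `configs` is assumed to follow the
# yaml schema read above, i.e. to provide configs["output"],
# configs["train"]["mixed_precision"] and configs["train"]["early_stop_patience"],
# and the loss must accept the six arguments (x_pred, x_y, e_pred, edge_y,
# u_pred, u_y) used in _update_model.
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = GraphModel().to(device)
#   trainer = GraFEIIgniteTrainer(
#       model=model,
#       optimizer=torch.optim.Adam(model.parameters()),
#       loss_fn=multi_train_loss,
#       device=device,
#       configs=configs,
#       tags=["Training", "Validation"],
#   )
#   trainer.setup_handlers()  # requires a "Validation" evaluator for early stopping
#
#   # mode_tags maps each tag to a (mode, dataset, dataloader) tuple;
#   # log_results only uses the dataloader (values[2]).
#   mode_tags = {
#       "Training": ("train", None, train_loader),
#       "Validation": ("val", None, val_loader),
#   }
#   trainer.trainer.add_event_handler(
#       ignite.engine.Events.EPOCH_COMPLETED, trainer.log_results, mode_tags
#   )
#   trainer.trainer.run(train_loader, max_epochs=20)
# ---------------------------------------------------------------------------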