# basf2 (Belle II Analysis Software Framework)                           #
# Author: The Belle II Collaboration                                     #
#                                                                        #
# See git log for contributors and copyright holders.                    #
# This file is licensed under LGPL-3.0, see LICENSE.md.                  #

from datetime import datetime
from pathlib import Path

import ignite
import numpy as np
import torch
import yaml
from torch_geometric.data import Batch

from .metrics import PerfectLCA, PerfectEvent, PerfectMasses

class GraFEIIgniteTrainer:
    """
    Class to setup the ignite trainer and hold all the things associated.

    :param model: The actual PyTorch model.
    :type model: torch.nn.Module
    :param optimizer: Optimizer used in training.
    :type optimizer: torch.optim.Optimizer
    :param loss_fn: Loss function, called as
        ``loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)``.
    :type loss_fn: callable
    :param device: Device to use.
    :type device: torch.device
    :param configs: Dictionary of run configs from loaded yaml config file.
    :type configs: dict
    :param tags: Various tags to sort train and validation evaluators by, e.g. "Training", "Validation".
    :type tags: list
    :param scheduler: Learning rate scheduler (wrapped in an ignite ``LRScheduler`` handler).
    :type scheduler: PyTorch learning rate scheduler or None
    :param ignore_index: Label index to ignore when calculating metrics, e.g. padding.
    :type ignore_index: int
    """

    def __init__(
        self,
        model,
        optimizer,
        loss_fn,
        device,
        configs,
        tags,
        scheduler=None,
        ignore_index=-1.0,
    ):
        """
        Initialization.
        """
        #: Model
        self.model = model
        #: Optimizer
        self.optimizer = optimizer
        #: Configs
        self.configs = configs
        #: Tags
        self.tags = tags
        #: Index to ignore
        self.ignore_index = ignore_index
        #: CPU or GPU
        self.device = device

        #: Run timestamp to distinguish trainings
        self.timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")
        #: Output directory for checkpoints
        self.run_dir = self._resolve_run_dir(self.configs)

        # Mixed precision is decided once and shared by trainer and evaluators
        use_amp = self._amp_enabled()

        #: Ignite trainer
        self.trainer = self._build_trainer(loss_fn, scheduler, use_amp)

        #: Setup train and validation evaluators
        self.evaluators = {
            tag: self._build_evaluator(loss_fn, use_amp) for tag in self.tags
        }

    @staticmethod
    def _resolve_run_dir(configs):
        """
        Work out the output directory from the ``output`` section of the configs.

        :param configs: Run configs; must contain an ``output`` key.
        :type configs: dict

        :return: ``<path>/<run_name>`` if an output path is configured, else None.
        :rtype: pathlib.Path or None
        """
        output_cfg = configs["output"]
        if output_cfg is None or output_cfg.get("path") is None:
            return None
        return Path(output_cfg["path"], output_cfg["run_name"])

    def _amp_enabled(self):
        """
        Tell whether automatic mixed precision should be used.

        It is used only if requested in the ``train`` configs and the device is a CUDA device.
        The device *type* is compared, so that indexed devices (e.g. ``cuda:0``) and
        devices given as strings are recognised as well.

        :return: True if mixed precision should be used.
        :rtype: bool
        """
        if not self.configs["train"]["mixed_precision"]:
            return False
        return torch.device(self.device).type == "cuda"

    @staticmethod
    def _to_device(batch, device):
        """
        Collate a list of graphs into a single batch if needed and move it to ``device``.
        """
        if isinstance(batch, list):
            batch = Batch.from_data_list(batch)
        return batch.to(device)

    def _build_trainer(self, loss_fn, scheduler, use_amp):
        """
        Create the ignite engine performing one optimisation step per batch.

        :param loss_fn: Loss function.
        :param scheduler: Learning rate scheduler or None.
        :param use_amp: Whether to run forward pass and loss under CUDA autocast.
        :type use_amp: bool

        :return: Engine whose per-iteration output is the loss value as a float.
        :rtype: ignite.engine.Engine
        """
        # Bind the objects once so that the step function keeps using
        # what was passed at construction time
        model, optimizer, device = self.model, self.optimizer, self.device

        if use_amp:
            scaler = torch.cuda.amp.GradScaler(enabled=True)

        def _update_model(engine, batch):
            # This just sets the training mode
            model.train()

            optimizer.zero_grad()
            batch = self._to_device(batch, device)
            x_y, edge_y, u_y = batch.x_y, batch.edge_y, batch.u_y

            if use_amp:
                # Only forward pass and loss run under autocast,
                # the scaled backward pass and optimizer step happen outside of it
                with torch.cuda.amp.autocast(enabled=True):
                    x_pred, e_pred, u_pred = model(batch)
                    loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                x_pred, e_pred, u_pred = model(batch)
                loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
                loss.backward()
                optimizer.step()

            return loss.item()

        trainer = ignite.engine.Engine(_update_model)

        if scheduler:
            ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(scheduler)
            trainer.add_event_handler(
                ignite.engine.Events.ITERATION_STARTED, ig_scheduler
            )

        return trainer

    def _build_metrics(self, loss_fn):
        """
        Create a fresh set of metrics for one evaluator.

        The ``output_transform`` functions index into the tuple returned by the evaluator step:
        ``(x_pred, e_pred, u_pred, x_y, edge_y, u_y, edge_index, torch_batch, num_graph)``.

        :param loss_fn: Loss function.

        :return: Dictionary of metric name -> ignite metric.
        :rtype: dict
        """
        device = self.device
        ignore_index = self.ignore_index

        return {
            # ignite.metrics.Loss takes (y_pred, y, **kwargs) arguments.
            # MultiTrainLoss needs in total 6 arguments to be computed,
            # so the additional ones are passed in a dictionary.
            "loss": ignite.metrics.Loss(
                loss_fn,
                output_transform=lambda x: [
                    x[0],
                    x[3],
                    {
                        "edge_input": x[1],
                        "edge_target": x[4],
                        "u_input": x[2],
                        "u_target": x[5],
                    },
                ],
                device=device,
            ),
            "perfectLCA": PerfectLCA(
                ignore_index=ignore_index,
                device=device,
                output_transform=lambda x: [x[1], x[4], x[6], x[5], x[7], x[8]],
            ),
            "perfectMasses": PerfectMasses(
                ignore_index=ignore_index,
                device=device,
                output_transform=lambda x: [x[0], x[3], x[5], x[7], x[8]],
            ),
            "perfectEvent": PerfectEvent(
                ignore_index=ignore_index,
                device=device,
                output_transform=lambda x: [
                    x[0],
                    x[3],
                    x[1],
                    x[4],
                    x[6],
                    x[5],
                    x[7],
                    x[8],
                ],
            ),
        }

    def _build_evaluator(self, loss_fn, use_amp):
        """
        Create an ignite engine running inference on a batch, with all metrics attached.

        :param loss_fn: Loss function.
        :param use_amp: Whether to run the forward pass under CUDA autocast.
        :type use_amp: bool

        :return: Evaluator with its own metric instances.
        :rtype: ignite.engine.Engine
        """
        model, device = self.model, self.device

        def _predict_on_batch(engine, batch):
            model.eval()  # It just enables evaluation mode

            batch = self._to_device(batch, device)
            torch_batch = batch.batch
            # Last entry of the batch-assignment vector + 1
            # (presumably the number of graphs, assuming the vector is sorted -- confirm)
            num_graph = torch_batch[-1] + 1

            with torch.no_grad():
                if use_amp:
                    with torch.cuda.amp.autocast(enabled=True):
                        x_pred, e_pred, u_pred = model(batch)
                else:
                    x_pred, e_pred, u_pred = model(batch)

            return (
                x_pred,
                e_pred,
                u_pred,
                batch.x_y,
                batch.edge_y,
                batch.u_y,
                batch.edge_index,
                torch_batch,
                num_graph,
            )

        evaluator = ignite.engine.Engine(_predict_on_batch)
        for metric_name, metric in self._build_metrics(loss_fn).items():
            metric.attach(evaluator, metric_name)

        return evaluator

    def _score_fn(self, engine):
        """Metric to use for early stopping (negative loss, so that higher is better)"""
        return -engine.state.metrics["loss"]

    def _perfect_score_fn(self, engine):
        """Metric to use for checkpoints"""
        return engine.state.metrics["perfectEvent"]

    def _clean_config_dict(self, configs):
        """
        Clean configs to prepare them for writing to file.

        Nested dictionaries are cleaned recursively, NumPy arrays are converted to lists
        and NumPy scalars to plain Python scalars.
        The input dictionary is left untouched, a cleaned copy is returned instead,
        so that the configs used by the training are not modified when they are saved.

        :param configs: Dictionary of configs.
        :type configs: dict

        :return: Cleaned copy of the configs.
        :rtype: dict
        """
        cleaned = {}
        for k, v in configs.items():
            if isinstance(v, dict):
                cleaned[k] = self._clean_config_dict(v)
            elif isinstance(v, np.ndarray):
                cleaned[k] = v.tolist()
            elif isinstance(v, np.generic):
                cleaned[k] = v.item()
            else:
                cleaned[k] = v
        return cleaned

    def setup_handlers(self, cfg_filename="config.yaml"):
        """
        Creates the various ignite handlers (callbacks).

        An evaluator tagged ``"Validation"`` must exist, since early stopping
        and checkpointing are attached to it.

        Args:
            cfg_filename (str): Name of config yaml file to use when saving configs.
        """
        validation_evaluator = self.evaluators["Validation"]

        # Create the output directory
        if self.run_dir is not None:
            self.run_dir.mkdir(parents=True, exist_ok=True)
            # And save the configs, putting here to only save when setting up checkpointing
            with open(
                self.run_dir / f"{self.timestamp}_{cfg_filename}", "w"
            ) as outfile:
                cleaned_configs = self._clean_config_dict(self.configs)
                yaml.dump(cleaned_configs, outfile, default_flow_style=False)

        # Setup early stopping
        early_handler = ignite.handlers.EarlyStopping(
            patience=self.configs["train"]["early_stop_patience"],
            score_function=self._score_fn,
            trainer=self.trainer,
            min_delta=1e-3,
        )
        validation_evaluator.add_event_handler(
            ignite.engine.Events.EPOCH_COMPLETED, early_handler
        )

        # Configure saving the best performing model
        if self.run_dir is not None:
            to_save = {
                "model": self.model,
                "optimizer": self.optimizer,
                "trainer": self.trainer,
            }
            # Note that we judge early stopping above by the validation loss, but save the best model
            # according to validation perfectEvent score. This lets training continue for perfectEvent plateaus
            # so long as the model is still changing (and hopefully improving again after some time).
            best_model_handler = ignite.handlers.Checkpoint(
                to_save=to_save,
                save_handler=ignite.handlers.DiskSaver(
                    self.run_dir, create_dir=True, require_empty=False
                ),
                filename_prefix=self.timestamp,
                score_function=self._perfect_score_fn,
                score_name="validation_perfectEvent",
                n_saved=1,
                global_step_transform=ignite.handlers.global_step_from_engine(
                    validation_evaluator
                ),
            )
            validation_evaluator.add_event_handler(
                ignite.engine.Events.EPOCH_COMPLETED, best_model_handler
            )

        return

    # Set up end of epoch validation procedure
    # Tell it to print epoch results
    def log_results(self, trainer, mode_tags):
        """
        Callback to run evaluation and report the results.

        :param trainer: Trainer passed by ignite to this method.
        :type trainer: ignite.engine.Engine
        :param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
        :type mode_tags: dict
        """
        use_amp = self._amp_enabled()

        for tag, values in mode_tags.items():
            evaluator = self.evaluators[tag]
            dataloader = values[2]

            # Need to wrap this in autocast since it calculates metrics (i.e. loss) without autocast switched on
            # This is mostly fine except it fails to correctly cast the class weights tensor passed to the loss
            if use_amp:
                with torch.cuda.amp.autocast():
                    evaluator.run(dataloader, epoch_length=None)
            else:
                evaluator.run(dataloader, epoch_length=None)

            metrics = evaluator.state.metrics
            message = [f"{tag} Results - Epoch: {trainer.state.epoch}"]
            message.extend([f"Avg {m}: {metrics[m]:.4f}" for m in metrics])
            print(message)