71        self.
timestamp = datetime.now().strftime(
"%Y.%m.%d_%H.%M")
 
   75        if self.
configs[
"output"] 
is not None:
 
   76            if (
"path" in self.
configs[
"output"].keys()) 
and (
 
   77                self.
configs[
"output"][
"path"] 
is not None 
   81                    self.
configs[
"output"][
"run_name"],
 
   85        use_amp = configs[
"train"][
"mixed_precision"] 
and self.
device == torch.device(
 
   90            from torch.cuda.amp 
import autocast
 
   91            from torch.cuda.amp 
import GradScaler
 
   93            scaler = GradScaler(enabled=
True)
 
   95        def _update_model(engine, batch):
 
  102                Batch.from_data_list(batch).to(device)
 
  103                if isinstance(batch, list)
 
  104                else batch.to(device)
 
  107            x_y, edge_y, u_y = batch.x_y, batch.edge_y, batch.u_y
 
  110                with autocast(enabled=
True):
 
  111                    x_pred, e_pred, u_pred = 
model(batch)
 
  112                    loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
 
  113                scaler.scale(loss).backward()
 
  114                scaler.step(optimizer)
 
  117                x_pred, e_pred, u_pred = 
model(batch)
 
  118                loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
 
  125        self.
trainer = ignite.engine.Engine(_update_model)
 
  128            ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(scheduler)
 
  129            self.
trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, ig_scheduler)
 
  134        for tag 
in self.
tags:
 
  140                "loss": ignite.metrics.Loss(
 
  142                    output_transform=
lambda x: [
 
  154                    ignore_index=ignore_index,
 
  156                    output_transform=
lambda x: [
 
  157                        x[1], x[4], x[6], x[5], x[7], x[8],
 
  161                    ignore_index=ignore_index,
 
  163                    output_transform=
lambda x: [x[0], x[3], x[5], x[7], x[8]],
 
  166                    ignore_index=ignore_index,
 
  168                    output_transform=
lambda x: [
 
  169                        x[0], x[3], x[1], x[4], x[6], x[5], x[7], x[8],
 
  174            def _predict_on_batch(engine, batch):
 
  178                    Batch.from_data_list(batch).to(device)
 
  179                    if isinstance(batch, list)
 
  180                    else batch.to(device)
 
  183                x_y, edge_y, u_y, edge_index, torch_batch = (
 
  190                num_graph = batch.batch[torch_batch.shape[0] - 1] + 1
 
  192                with torch.no_grad():
 
  194                        with autocast(enabled=
True):
 
  195                            x_pred, e_pred, u_pred = 
model(batch)
 
  197                        x_pred, e_pred, u_pred = 
model(batch)
 
  211            self.
evaluators[tag] = ignite.engine.Engine(_predict_on_batch)
 
  213            for metric_name, metric 
in zip(metrics.keys(), metrics.values()):
 
  214                metric.attach(self.
evaluators[tag], metric_name)
 
 
  239        Creates the various ignite handlers (callbacks). 
  242            cfg_filename (str): Name of config yaml file to use when saving configs. 
  246            self.
run_dir.mkdir(parents=
True, exist_ok=
True)
 
  249                self.
run_dir / f
"{self.timestamp}_{cfg_filename}", 
"w" 
  252                yaml.dump(cleaned_configs, outfile, default_flow_style=
False)
 
  255        early_handler = ignite.handlers.EarlyStopping(
 
  256            patience=self.
configs[
"train"][
"early_stop_patience"],
 
  261        self.
evaluators[
"Validation"].add_event_handler(
 
  262            ignite.engine.Events.EPOCH_COMPLETED, early_handler
 
  275            best_model_handler = ignite.handlers.Checkpoint(
 
  277                save_handler=ignite.handlers.DiskSaver(
 
  278                    self.
run_dir, create_dir=
True, require_empty=
False 
  282                score_name=
"validation_perfectEvent",
 
  284                global_step_transform=ignite.handlers.global_step_from_engine(
 
  288            self.
evaluators[
"Validation"].add_event_handler(
 
  289                ignite.engine.Events.EPOCH_COMPLETED, best_model_handler
 
 
  298        Callback to run evaluation and report the results. 
  300        :param trainer: Trainer passed by ignite to this method. 
  301        :type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_ 
  302        :param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples. 
  303        :type mode_tags: dict 
  306        for tag, values 
in mode_tags.items():
 
  311            if self.
configs[
"train"][
"mixed_precision"] 
and self.
device == torch.device(
"cuda"):
 
  312                with torch.cuda.amp.autocast():
 
  313                    evaluator.run(values[2], epoch_length=
None)
 
  315                evaluator.run(values[2], epoch_length=
None)
 
  317            metrics = evaluator.state.metrics
 
  318            message = [f
"{tag} Results - Epoch: {trainer.state.epoch}"]
 
  319            message.extend([f
"Avg {m}: {metrics[m]:.4f}" for m 
in metrics])