self.timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")
if self.configs["output"] is not None:
    if ("path" in self.configs["output"]) and (
        self.configs["output"]["path"] is not None
    ):
        self.configs["output"]["run_name"],
# Use automatic mixed precision only when training on a CUDA device.
use_amp = configs["train"]["mixed_precision"] and self.device == torch.device("cuda")
from torch.cuda.amp import autocast, GradScaler

# GradScaler handles loss scaling for mixed-precision training.
scaler = GradScaler(enabled=True)
def _update_model(engine, batch):
    ...
    # Collate a list of graphs into a single PyG Batch; otherwise move the
    # ready-made batch to the target device.
    batch = (
        Batch.from_data_list(batch).to(device)
        if isinstance(batch, list)
        else batch.to(device)
    )
    ...
    # Node-, edge-, and graph-level targets.
    x_y, edge_y, u_y = batch.x_y, batch.edge_y, batch.u_y
    ...
    if use_amp:
        # Forward pass and loss under autocast; backward/step through the scaler.
        with autocast(enabled=True):
            x_pred, e_pred, u_pred = model(batch)
            loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        ...
    else:
        x_pred, e_pred, u_pred = model(batch)
        loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
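# Hedged sketch (not from the original file): the canonical torch.cuda.amp update
# cycle that the fragments above follow; the elided lines presumably complete it.
# All names below (model, optimizer, loss_fn, inputs, targets) are assumed
# placeholders for illustration only.
#
#     optimizer.zero_grad()
#     with autocast(enabled=True):
#         predictions = model(inputs)
#         loss = loss_fn(predictions, targets)
#     scaler.scale(loss).backward()   # backward pass on the scaled loss
#     scaler.step(optimizer)          # unscales grads, skips the step on inf/NaN
#     scaler.update()                 # adapts the scale factor for the next step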
self.trainer = ignite.engine.Engine(_update_model)

# Step the learning-rate scheduler at the start of every iteration.
ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(scheduler)
self.trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, ig_scheduler)
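# Hedged sketch (optimizer and gamma are assumed placeholders): ignite's
# LRScheduler wraps a standard PyTorch scheduler, and attaching it at
# ITERATION_STARTED sets the learning rate before each optimizer step.
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
#     ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(scheduler)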
for tag in self.tags:
    ...
    # Per-tag metric definitions; each output_transform selects the relevant
    # (prediction, target, ...) entries out of the engine's output tuple.
    "loss": ignite.metrics.Loss(
        ...
        output_transform=lambda x: [
        ...
        ignore_index=ignore_index,
        ...
        output_transform=lambda x: [
            x[1], x[4], x[6], x[5], x[7], x[8],
        ...
        ignore_index=ignore_index,
        ...
        output_transform=lambda x: [x[0], x[3], x[5], x[7], x[8]],
        ...
        ignore_index=ignore_index,
        ...
        output_transform=lambda x: [
            x[0], x[3], x[1], x[4], x[6], x[5], x[7], x[8],
def _predict_on_batch(engine, batch):
    ...
    batch = (
        Batch.from_data_list(batch).to(device)
        if isinstance(batch, list)
        else batch.to(device)
    )
    ...
    x_y, edge_y, u_y, edge_index, torch_batch = (
        ...
    # Graph count: the graph id of the last node, plus one.
    num_graph = batch.batch[torch_batch.shape[0] - 1] + 1
    ...
    # Inference only: no gradients, optionally under autocast.
    with torch.no_grad():
        if use_amp:
            with autocast(enabled=True):
                x_pred, e_pred, u_pred = model(batch)
        else:
            x_pred, e_pred, u_pred = model(batch)
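# Hedged note: with torch_geometric, the graph count can also be read directly,
# which is equivalent to the index arithmetic above when every node has a graph
# id in batch.batch:
#
#     num_graph = batch.num_graphs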
self.evaluators[tag] = ignite.engine.Engine(_predict_on_batch)

# Attach each configured metric to the evaluator under its name.
for metric_name, metric in metrics.items():
    metric.attach(self.evaluators[tag], metric_name)
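# Hedged illustration (metric, tag, and loader names assumed): once attached, a
# metric is recomputed on every evaluator run and exposed via the engine state:
#
#     acc = ignite.metrics.Accuracy(output_transform=lambda x: (x[0], x[3]))
#     acc.attach(self.evaluators["Validation"], "accuracy")
#     self.evaluators["Validation"].run(val_loader)
#     print(self.evaluators["Validation"].state.metrics["accuracy"])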
Creates the various ignite handlers (callbacks).

Args:
    cfg_filename (str): Name of the config YAML file to use when saving configs.
self.run_dir.mkdir(parents=True, exist_ok=True)
...
# Save a timestamped copy of the cleaned configs alongside the run outputs.
with open(
    self.run_dir / f"{self.timestamp}_{cfg_filename}", "w"
) as outfile:
    yaml.dump(cleaned_configs, outfile, default_flow_style=False)
# Stop training early when the validation score stops improving.
early_handler = ignite.handlers.EarlyStopping(
    patience=self.configs["train"]["early_stop_patience"],
    ...
)
self.evaluators["Validation"].add_event_handler(
    ignite.engine.Events.EPOCH_COMPLETED, early_handler
)
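# Hedged sketch: EarlyStopping requires a score_function (likely among the
# elided arguments above) returning a value that increases as the model
# improves; the metric key below is an assumption:
#
#     def score_function(engine):
#         return engine.state.metrics["perfectEvent"]
#
#     early_handler = ignite.handlers.EarlyStopping(
#         patience=self.configs["train"]["early_stop_patience"],
#         score_function=score_function,
#         trainer=self.trainer,
#     )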
# Keep the best model according to the validation "perfectEvent" score.
best_model_handler = ignite.handlers.Checkpoint(
    ...
    save_handler=ignite.handlers.DiskSaver(
        self.run_dir, create_dir=True, require_empty=False
    ),
    ...
    score_name="validation_perfectEvent",
    ...
    global_step_transform=ignite.handlers.global_step_from_engine(
        ...
)
self.evaluators["Validation"].add_event_handler(
    ignite.engine.Events.EPOCH_COMPLETED, best_model_handler
)
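# Hedged sketch (assumes the handler saved a dict containing a "model" entry):
# the best checkpoint written by DiskSaver can later be restored with:
#
#     ckpt = torch.load(self.run_dir / best_model_handler.last_checkpoint)
#     ignite.handlers.Checkpoint.load_objects(to_load={"model": model}, checkpoint=ckpt)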
Callback to run evaluation and report the results.

:param trainer: Trainer passed by ignite to this method.
:type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_
:param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
:type mode_tags: dict
for tag, values in mode_tags.items():
    ...
    # Evaluate under autocast when mixed precision is enabled on a CUDA device.
    if self.configs["train"]["mixed_precision"] and self.device == torch.device("cuda"):
        with torch.cuda.amp.autocast():
            evaluator.run(values[2], epoch_length=None)
    else:
        evaluator.run(values[2], epoch_length=None)

    # Collect the computed metrics and assemble a readable report.
    metrics = evaluator.state.metrics
    message = [f"{tag} Results - Epoch: {trainer.state.epoch}"]
    message.extend([f"Avg {m}: {metrics[m]:.4f}" for m in metrics])
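# Hedged: the assembled message parts would typically be joined and emitted via
# the project's logger; the separator and logger name here are assumptions:
#
#     logging.getLogger(__name__).info(" | ".join(message))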