from torch_geometric.data import Batch
from datetime import datetime
from pathlib import Path

from .metrics import PerfectLCA, PerfectEvent, PerfectMasses
Class to set up the ignite trainer and hold everything associated with it.

:param model: The actual PyTorch model.
:type model: `Model <https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html>`_
:param optimizer: Optimizer used in training.
:type optimizer: `Optimizer <https://pytorch.org/docs/stable/optim.html>`_
:param loss_fn: Loss function.
:type loss_fn: `Loss <https://pytorch.org/docs/stable/nn.html>`_
:param device: Device to use.
:type device: `Device <https://pytorch.org/docs/stable/tensor_attributes.html>`_
:param configs: Dictionary of run configs from the loaded yaml config file.
:param tags: Various tags to sort train and validation evaluators by, e.g. "Training", "Validation".
:param scheduler: Learning rate scheduler.
:type scheduler: `Scheduler <https://pytorch.org/docs/stable/optim.html>`_
:param ignore_index: Label index to ignore when calculating metrics, e.g. padding.
:type ignore_index: int
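For orientation, here is a hedged construction sketch. The class name `IgniteTrainer` and every argument below are hypothetical placeholders; only the constructor signature (listed in the summary at the end of this section) is fixed by the source:

import torch
import yaml

model = torch.nn.Linear(4, 2)  # placeholder for the actual graph model
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()  # placeholder loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open("config.yaml") as f:  # hypothetical config file
    configs = yaml.safe_load(f)

trainer = IgniteTrainer(
    model,
    optimizer,
    loss_fn,
    device,
    configs,
    tags=["Training", "Validation"],
    scheduler=None,
    ignore_index=-1.0,
)

The constructor body below is responsible for building the ignite engines, metrics and evaluators.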
self.timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")

if self.configs["output"] is not None:
    if ("path" in self.configs["output"].keys()) and (
        self.configs["output"]["path"] is not None
    ):
        # Output directory for checkpoints and saved configs.
        self.run_dir = Path(
            self.configs["output"]["path"],
            self.configs["output"]["run_name"],
        )
use_amp = configs["train"]["mixed_precision"] and self.device == torch.device("cuda")

if use_amp:
    from torch.cuda.amp import autocast
    from torch.cuda.amp import GradScaler

    # Gradient scaling avoids underflow of small fp16 gradients.
    scaler = GradScaler(enabled=True)
def _update_model(engine, batch):
    model.train()
    optimizer.zero_grad()

    # PyG dataloaders may yield either a list of graphs or an already-collated Batch.
    batch = (
        Batch.from_data_list(batch).to(device)
        if isinstance(batch, list)
        else batch.to(device)
    )

    # Node-, edge- and graph-level targets.
    x_y, edge_y, u_y = batch.x_y, batch.edge_y, batch.u_y

    if use_amp:
        with autocast(enabled=True):
            x_pred, e_pred, u_pred = model(batch)
            loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        x_pred, e_pred, u_pred = model(batch)
        loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
        loss.backward()
        optimizer.step()

    return loss.item()
self.trainer = ignite.engine.Engine(_update_model)

if scheduler:
    ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(scheduler)
    self.trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, ig_scheduler)
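ignite cannot drive a bare PyTorch scheduler by itself; LRScheduler wraps it so that it is stepped once per ITERATION_STARTED event, i.e. the new learning rate is in effect for the upcoming iteration. A minimal, self-contained sketch of the same wiring (model, optimizer and schedule values are hypothetical):

import torch
import ignite

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
torch_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

trainer = ignite.engine.Engine(lambda engine, batch: None)  # dummy update fn
ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(torch_scheduler)
trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, ig_scheduler)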
for tag in self.tags:
    metrics = {
        "loss": ignite.metrics.Loss(
            loss_fn,
            output_transform=lambda x: [...],  # transform elided in the source
        ),
        # The metric keys below are inferred from the metric classes.
        "perfectLCA": PerfectLCA(
            ignore_index=ignore_index,
            output_transform=lambda x: [
                x[1], x[4], x[6], x[5], x[7], x[8],
            ],
        ),
        "perfectMasses": PerfectMasses(
            ignore_index=ignore_index,
            output_transform=lambda x: [x[0], x[3], x[5], x[7], x[8]],
        ),
        "perfectEvent": PerfectEvent(
            ignore_index=ignore_index,
            output_transform=lambda x: [
                x[0], x[3], x[1], x[4], x[6], x[5], x[7], x[8],
            ],
        ),
    }
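Each output_transform slices the 9-tuple returned by _predict_on_batch below, so every metric receives exactly the arguments it expects. A minimal, self-contained sketch of the mechanism with a stock ignite metric (all names and data hypothetical):

import torch
import ignite

def _eval_step(engine, batch):
    pred, target = batch
    return pred, target, "extra"  # engines may return more than a metric needs

evaluator = ignite.engine.Engine(_eval_step)
acc = ignite.metrics.Accuracy(output_transform=lambda out: (out[0], out[1]))
acc.attach(evaluator, "acc")

data = [(torch.tensor([[0.9, 0.1]]), torch.tensor([0]))]
state = evaluator.run(data)
print(state.metrics["acc"])  # 1.0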
def _predict_on_batch(engine, batch):
    batch = (
        Batch.from_data_list(batch).to(device)
        if isinstance(batch, list)
        else batch.to(device)
    )

    x_y, edge_y, u_y, edge_index, torch_batch = (
        batch.x_y, batch.edge_y, batch.u_y, batch.edge_index, batch.batch,
    )

    num_graph = batch.batch[torch_batch.shape[0] - 1] + 1

    with torch.no_grad():
        if use_amp:
            with autocast(enabled=True):
                x_pred, e_pred, u_pred = model(batch)
        else:
            x_pred, e_pred, u_pred = model(batch)

    # Tuple order matches the output_transform indices used above:
    # 0: x_pred, 1: e_pred, 2: u_pred, 3: x_y, 4: edge_y,
    # 5: u_y, 6: edge_index, 7: torch_batch, 8: num_graph.
    return x_pred, e_pred, u_pred, x_y, edge_y, u_y, edge_index, torch_batch, num_graph
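PyTorch Geometric collates all graphs of a batch into one large disconnected graph and records graph membership in batch.batch: entry i is the index of the graph that node i belongs to. Because these indices are sorted, the last entry plus one equals the number of graphs, which is what the num_graph line above computes. A tiny sketch (hypothetical tensor):

import torch

# Three graphs with 3, 2 and 1 nodes respectively.
torch_batch = torch.tensor([0, 0, 0, 1, 1, 2])
num_graph = torch_batch[torch_batch.shape[0] - 1] + 1
print(int(num_graph))  # 3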
# Still inside the for-tag loop: one evaluator per tag, with all metrics attached.
self.evaluators[tag] = ignite.engine.Engine(_predict_on_batch)

for metric_name, metric in metrics.items():
    metric.attach(self.evaluators[tag], metric_name)
217 """Metric to use for early stoppging"""
218 return -engine.state.metrics[
"loss"]
221 """Metric to use for checkpoints"""
222 return engine.state.metrics[
"perfectEvent"]
def _clean_config_dict(self, configs):
    """Clean configs to prepare them for writing to file."""
    for k, v in configs.items():
        if isinstance(v, collections.abc.Mapping):
            # Recurse into nested config sections.
            configs[k] = self._clean_config_dict(v)
        elif isinstance(v, np.ndarray):
            # YAML cannot serialize numpy arrays directly.
            configs[k] = v.tolist()
    return configs
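The cleaning step matters because yaml.dump cannot serialize numpy arrays. A standalone, self-contained equivalent of the method for illustration (the function name is mine, not from the source):

import collections.abc
import numpy as np

def clean_config_dict(configs):
    for k, v in configs.items():
        if isinstance(v, collections.abc.Mapping):
            configs[k] = clean_config_dict(v)  # recurse into nested sections
        elif isinstance(v, np.ndarray):
            configs[k] = v.tolist()  # ndarray -> plain list
    return configs

print(clean_config_dict({"train": {"lrs": np.array([1e-3, 1e-4])}, "seed": 42}))
# {'train': {'lrs': [0.001, 0.0001]}, 'seed': 42}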
def setup_handlers(self, cfg_filename="config.yaml"):
    """
    Creates the various ignite handlers (callbacks).

    Args:
        cfg_filename (str): Name of config yaml file to use when saving configs.
    """
    self.run_dir.mkdir(parents=True, exist_ok=True)

    # Write the cleaned run configs next to the checkpoints.
    cleaned_configs = self._clean_config_dict(self.configs)
    with open(
        self.run_dir / f"{self.timestamp}_{cfg_filename}", "w"
    ) as outfile:
        yaml.dump(cleaned_configs, outfile, default_flow_style=False)
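default_flow_style=False makes yaml.dump emit block style (one key per line) instead of inline mappings, which keeps the saved config human-readable. For example (hypothetical config):

import yaml

print(yaml.dump({"train": {"epochs": 10, "lr": 0.001}}, default_flow_style=False))
# train:
#   epochs: 10
#   lr: 0.001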
    # Stop training when the validation loss stops improving.
    early_handler = ignite.handlers.EarlyStopping(
        patience=self.configs["train"]["early_stop_patience"],
        score_function=self._score_fn,
        trainer=self.trainer,
    )
    self.evaluators["Validation"].add_event_handler(
        ignite.engine.Events.EPOCH_COMPLETED, early_handler
    )
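ignite's EarlyStopping terminates the trainer once the value returned by score_function has not improved for `patience` consecutive evaluator runs; since larger scores count as better, _score_fn returns the negated loss. A minimal, self-contained sketch of the same wiring (engines and numbers hypothetical):

import ignite

trainer = ignite.engine.Engine(lambda engine, batch: None)
evaluator = ignite.engine.Engine(lambda engine, batch: None)

early = ignite.handlers.EarlyStopping(
    patience=3,  # allow 3 evaluations without improvement
    score_function=lambda engine: -engine.state.metrics.get("loss", 0.0),
    trainer=trainer,  # the engine whose run gets terminated
)
evaluator.add_event_handler(ignite.engine.Events.COMPLETED, early)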
    # Keep the checkpoint with the best validation perfectEvent score.
    best_model_handler = ignite.handlers.Checkpoint(
        {"model": self.model},
        save_handler=ignite.handlers.DiskSaver(
            self.run_dir, create_dir=True, require_empty=False
        ),
        score_function=self._perfect_score_fn,
        score_name="validation_perfectEvent",
        global_step_transform=ignite.handlers.global_step_from_engine(self.trainer),
    )
    self.evaluators["Validation"].add_event_handler(
        ignite.engine.Events.EPOCH_COMPLETED, best_model_handler
    )
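Checkpoint writes one file per tracked object, embedding score_name and the score value in the file name, and keeps only the best n_saved files. A hedged standalone sketch of the same pattern (objects and paths hypothetical):

import torch
import ignite

model = torch.nn.Linear(4, 2)
evaluator = ignite.engine.Engine(lambda engine, batch: None)

ckpt = ignite.handlers.Checkpoint(
    {"model": model},  # objects to serialize
    save_handler=ignite.handlers.DiskSaver("./ckpt", create_dir=True, require_empty=False),
    score_function=lambda engine: engine.state.metrics.get("perfectEvent", 0.0),
    score_name="validation_perfectEvent",  # appears in the checkpoint file name
    n_saved=1,  # keep only the best checkpoint
)
evaluator.add_event_handler(ignite.engine.Events.EPOCH_COMPLETED, ckpt)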
def log_results(self, trainer, mode_tags):
    """
    Callback to run evaluation and report the results.

    :param trainer: Trainer passed by ignite to this method.
    :type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html>`_
    :param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
    :type mode_tags: dict
    """
    for tag, values in mode_tags.items():
        evaluator = self.evaluators[tag]

        # Run the evaluator on the dataloader (values[2]), under AMP if enabled.
        if self.configs["train"]["mixed_precision"] and self.device == torch.device("cuda"):
            with torch.cuda.amp.autocast():
                evaluator.run(values[2], epoch_length=None)
        else:
            evaluator.run(values[2], epoch_length=None)

        metrics = evaluator.state.metrics
        message = [f"{tag} Results - Epoch: {trainer.state.epoch}"]
        message.extend([f"Avg {m}: {metrics[m]:.4f}" for m in metrics])
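log_results is meant to be registered on the trainer as an epoch-level callback; ignite passes the engine as the first argument and forwards any extra arguments after it. A hedged wiring sketch, where ignite_trainer is an instance of this class and the datasets/loaders are built elsewhere (all names hypothetical):

import ignite

mode_tags = {
    "Training": ("train", train_dataset, train_loader),
    "Validation": ("val", val_dataset, val_loader),
}

ignite_trainer.trainer.add_event_handler(
    ignite.engine.Events.EPOCH_COMPLETED,
    ignite_trainer.log_results,
    mode_tags,  # forwarded to log_results after the engine argument
)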
Attributes:
    timestamp: Run timestamp to distinguish trainings.
    run_dir: Output directory for checkpoints.
    evaluators: Train and validation evaluators.
    ignore_index: Index to ignore.

Methods:
    def __init__(self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0)
    def setup_handlers(self, cfg_filename="config.yaml")
    def log_results(self, trainer, mode_tags)
    def _score_fn(self, engine)
    def _perfect_score_fn(self, engine)
    def _clean_config_dict(self, configs)