60 self.optimizer = optimizer
62 self.configs = configs
66 self.ignore_index = ignore_index
71 self.timestamp = datetime.now().strftime(
"%Y.%m.%d_%H.%M")
75 if self.configs[
"output"]
is not None:
76 if (
"path" in self.configs[
"output"].keys())
and (
77 self.configs[
"output"][
"path"]
is not None
80 self.configs[
"output"][
"path"],
81 self.configs[
"output"][
"run_name"],
85 use_amp = configs[
"train"][
"mixed_precision"]
and self.device == torch.device(
90 from torch.cuda.amp
import autocast
91 from torch.cuda.amp
import GradScaler
93 scaler = GradScaler(enabled=
True)
95 def _update_model(engine, batch):
102 Batch.from_data_list(batch).to(device)
103 if isinstance(batch, list)
104 else batch.to(device)
107 x_y, edge_y, u_y = batch.x_y, batch.edge_y, batch.u_y
110 with autocast(enabled=
True):
111 x_pred, e_pred, u_pred = model(batch)
112 loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
113 scaler.scale(loss).backward()
114 scaler.step(optimizer)
117 x_pred, e_pred, u_pred = model(batch)
118 loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
125 self.trainer = ignite.engine.Engine(_update_model)
128 ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(scheduler)
129 self.trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, ig_scheduler)
134 for tag
in self.tags:
140 "loss": ignite.metrics.Loss(
142 output_transform=
lambda x: [
153 "perfectLCA": PerfectLCA(
154 ignore_index=ignore_index,
156 output_transform=
lambda x: [
157 x[1], x[4], x[6], x[5], x[7], x[8],
160 "perfectMasses": PerfectMasses(
161 ignore_index=ignore_index,
163 output_transform=
lambda x: [x[0], x[3], x[5], x[7], x[8]],
165 "perfectEvent": PerfectEvent(
166 ignore_index=ignore_index,
168 output_transform=
lambda x: [
169 x[0], x[3], x[1], x[4], x[6], x[5], x[7], x[8],
174 def _predict_on_batch(engine, batch):
178 Batch.from_data_list(batch).to(device)
179 if isinstance(batch, list)
180 else batch.to(device)
183 x_y, edge_y, u_y, edge_index, torch_batch = (
190 num_graph = batch.batch[torch_batch.shape[0] - 1] + 1
192 with torch.no_grad():
194 with autocast(enabled=
True):
195 x_pred, e_pred, u_pred = model(batch)
197 x_pred, e_pred, u_pred = model(batch)
211 self.evaluators[tag] = ignite.engine.Engine(_predict_on_batch)
213 for metric_name, metric
in zip(metrics.keys(), metrics.values()):
214 metric.attach(self.evaluators[tag], metric_name)