Belle II Software development
GraFEIIgniteTrainer Class Reference

Public Member Functions

def __init__ (self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0)
 
def setup_handlers (self, cfg_filename="config.yaml")
 
def log_results (self, trainer, mode_tags)
 

Public Attributes

 model
 Model.
 
 optimizer
 Optimizer.
 
 configs
 Configs.
 
 tags
 Tags.
 
 ignore_index
 Label index to ignore when calculating metrics.
 
 device
 CPU or GPU.
 
 timestamp
 Run timestamp (format ``%Y.%m.%d_%H.%M``) used to distinguish training runs.
 
 run_dir
 Output directory for checkpoints and saved configs; ``None`` if no output path is configured.
 
 trainer
 Ignite trainer.
 
 evaluators
 Train and validation evaluators (Ignite engines), keyed by tag.
 

Protected Member Functions

def _score_fn (self, engine)
 
def _perfect_score_fn (self, engine)
 
def _clean_config_dict (self, configs)
 

Detailed Description

Class to set up the Ignite trainer and hold all associated objects.

:param model: The actual PyTorch model.
:type model: `Model <https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html>`_
:param optimizer: Optimizer used in training.
:type optimizer: `Optimizer <https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer>`_
:param loss_fn: Loss function.
:type loss_fn: `Loss <https://pytorch.org/docs/stable/nn.html#loss-functions>`_
:param device: Device to use.
:type device: `Device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
:param configs: Dictionary of run configs from loaded yaml config file.
:type configs: dict
:param tags: Various tags to sort train and validation evaluators by, e.g. "Training", "Validation".
:type tags: list
:param scheduler: Learning rate scheduler.
:type scheduler: `Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_
:param ignore_index: Label index to ignore when calculating metrics, e.g. padding.
:type ignore_index: int (note: the default in the signature is the float ``-1.0``)

Definition at line 21 of file create_trainer.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  model,
  optimizer,
  loss_fn,
  device,
  configs,
  tags,
  scheduler = None,
  ignore_index = -1.0 
)
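Initialization: stores the passed objects, determines the output directory, and creates the Ignite trainer and one evaluator per tag.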

Definition at line 43 of file create_trainer.py.

53 ):
54 """
55 Initialization.
56 """
57
58 self.model = model
59
60 self.optimizer = optimizer
61
62 self.configs = configs
63
64 self.tags = tags
65
66 self.ignore_index = ignore_index
67
68 self.device = device
69
70
71 self.timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")
72
73
74 self.run_dir = None
75 if self.configs["output"] is not None:
76 if ("path" in self.configs["output"].keys()) and (
77 self.configs["output"]["path"] is not None
78 ):
79 self.run_dir = Path(
80 self.configs["output"]["path"],
81 self.configs["output"]["run_name"],
82 )
83
84 # Setup ignite trainer
85 use_amp = configs["train"]["mixed_precision"] and self.device == torch.device(
86 "cuda"
87 )
88
89 if use_amp:
90 from torch.cuda.amp import autocast
91 from torch.cuda.amp import GradScaler
92
93 scaler = GradScaler(enabled=True)
94
95 def _update_model(engine, batch):
96 # This just sets the training mode
97 model.train()
98
99 optimizer.zero_grad()
100
101 batch = (
102 Batch.from_data_list(batch).to(device)
103 if isinstance(batch, list)
104 else batch.to(device)
105 )
106
107 x_y, edge_y, u_y = batch.x_y, batch.edge_y, batch.u_y
108
109 if use_amp:
110 with autocast(enabled=True):
111 x_pred, e_pred, u_pred = model(batch)
112 loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
113 scaler.scale(loss).backward()
114 scaler.step(optimizer)
115 scaler.update()
116 else:
117 x_pred, e_pred, u_pred = model(batch)
118 loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
119 loss.backward()
120 optimizer.step()
121
122 return loss.item()
123
124
125 self.trainer = ignite.engine.Engine(_update_model)
126
127 if scheduler:
128 ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(scheduler)
129 self.trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, ig_scheduler)
130
131
132 self.evaluators = {}
133
134 for tag in self.tags:
135 # Setup metrics
136 metrics = {
137 # ignite.metrics.Loss takes (y_pred, y, **kwargs) arguments.
138 # MultiTrainLoss needs in total 6 arguments to be computed,
139 # so the additional ones are passed in a dictionary.
140 "loss": ignite.metrics.Loss(
141 loss_fn,
142 output_transform=lambda x: [
143 x[0], x[3],
144 {
145 "edge_input": x[1],
146 "edge_target": x[4],
147 "u_input": x[2],
148 "u_target": x[5],
149 },
150 ],
151 device=device,
152 ),
153 "perfectLCA": PerfectLCA(
154 ignore_index=ignore_index,
155 device=device,
156 output_transform=lambda x: [
157 x[1], x[4], x[6], x[5], x[7], x[8],
158 ],
159 ),
160 "perfectMasses": PerfectMasses(
161 ignore_index=ignore_index,
162 device=device,
163 output_transform=lambda x: [x[0], x[3], x[5], x[7], x[8]],
164 ),
165 "perfectEvent": PerfectEvent(
166 ignore_index=ignore_index,
167 device=device,
168 output_transform=lambda x: [
169 x[0], x[3], x[1], x[4], x[6], x[5], x[7], x[8],
170 ],
171 ),
172 }
173
174 def _predict_on_batch(engine, batch):
175 model.eval() # It just enables evaluation mode
176
177 batch = (
178 Batch.from_data_list(batch).to(device)
179 if isinstance(batch, list)
180 else batch.to(device)
181 )
182
183 x_y, edge_y, u_y, edge_index, torch_batch = (
184 batch.x_y,
185 batch.edge_y,
186 batch.u_y,
187 batch.edge_index,
188 batch.batch,
189 )
190 num_graph = batch.batch[torch_batch.shape[0] - 1] + 1
191
192 with torch.no_grad():
193 if use_amp:
194 with autocast(enabled=True):
195 x_pred, e_pred, u_pred = model(batch)
196 else:
197 x_pred, e_pred, u_pred = model(batch)
198
199 return (
200 x_pred,
201 e_pred,
202 u_pred,
203 x_y,
204 edge_y,
205 u_y,
206 edge_index,
207 torch_batch,
208 num_graph,
209 )
210
211 self.evaluators[tag] = ignite.engine.Engine(_predict_on_batch)
212
213 for metric_name, metric in zip(metrics.keys(), metrics.values()):
214 metric.attach(self.evaluators[tag], metric_name)
215

Member Function Documentation

◆ _clean_config_dict()

def _clean_config_dict (   self,
  configs 
)
protected
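Clean configs to prepare them for writing to file by recursively converting NumPy arrays to lists. The dictionary is modified in place and also returned.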

Definition at line 224 of file create_trainer.py.

224 def _clean_config_dict(self, configs):
225 """
226 Clean configs to prepare them for writing to file.
227 """
228 for k, v in configs.items():
229 if isinstance(v, collections.abc.Mapping):
230 configs[k] = self._clean_config_dict(configs[k])
231 elif isinstance(v, np.ndarray):
232 configs[k] = v.tolist()
233 else:
234 configs[k] = v
235 return configs
236

◆ _perfect_score_fn()

def _perfect_score_fn (   self,
  engine 
)
protected
Metric to use for checkpoints: returns the ``perfectEvent`` metric of the given engine.

Definition at line 220 of file create_trainer.py.

220 def _perfect_score_fn(self, engine):
221 """Metric to use for checkpoints"""
222 return engine.state.metrics["perfectEvent"]
223

◆ _score_fn()

def _score_fn (   self,
  engine 
)
protected
Metric to use for early stopping: returns the negated ``loss`` metric of the given engine, so that a higher score is better.

Definition at line 216 of file create_trainer.py.

216 def _score_fn(self, engine):
217 """Metric to use for early stoppging"""
218 return -engine.state.metrics["loss"]
219

◆ log_results()

def log_results (   self,
  trainer,
  mode_tags 
)
Callback to run evaluation for each tag in ``mode_tags`` and print the resulting metrics.

:param trainer: Trainer passed by ignite to this method.
:type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_
:param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
:type mode_tags: dict

Definition at line 296 of file create_trainer.py.

296 def log_results(self, trainer, mode_tags):
297 """
298 Callback to run evaluation and report the results.
299
300 :param trainer: Trainer passed by ignite to this method.
301 :type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_
302 :param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
303 :type mode_tags: dict
304 """
305
306 for tag, values in mode_tags.items():
307 evaluator = self.evaluators[tag]
308
309 # Need to wrap this in autocast since it caculates metrics (i.e. loss) without autocast switched on
310 # This is mostly fine except it fails to correctly cast the class weights tensor passed to the loss
311 if self.configs["train"]["mixed_precision"] and self.device == torch.device("cuda"):
312 with torch.cuda.amp.autocast():
313 evaluator.run(values[2], epoch_length=None)
314 else:
315 evaluator.run(values[2], epoch_length=None)
316
317 metrics = evaluator.state.metrics
318 message = [f"{tag} Results - Epoch: {trainer.state.epoch}"]
319 message.extend([f"Avg {m}: {metrics[m]:.4f}" for m in metrics])
320 print(message)

◆ setup_handlers()

def setup_handlers (   self,
  cfg_filename = "config.yaml" 
)
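Creates the various ignite handlers (callbacks): early stopping on the validation evaluator and, if an output directory is set, saving the configs and checkpointing the best model.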

Args:
    cfg_filename (str): Name of config yaml file to use when saving configs.

Definition at line 237 of file create_trainer.py.

237 def setup_handlers(self, cfg_filename="config.yaml"):
238 """
239 Creates the various ignite handlers (callbacks).
240
241 Args:
242 cfg_filename (str): Name of config yaml file to use when saving configs.
243 """
244 # Create the output directory
245 if self.run_dir is not None:
246 self.run_dir.mkdir(parents=True, exist_ok=True)
247 # And save the configs, putting here to only save when setting up checkpointing
248 with open(
249 self.run_dir / f"{self.timestamp}_{cfg_filename}", "w"
250 ) as outfile:
251 cleaned_configs = self._clean_config_dict(self.configs)
252 yaml.dump(cleaned_configs, outfile, default_flow_style=False)
253
254 # Setup early stopping
255 early_handler = ignite.handlers.EarlyStopping(
256 patience=self.configs["train"]["early_stop_patience"],
257 score_function=self._score_fn,
258 trainer=self.trainer,
259 min_delta=1e-3,
260 )
261 self.evaluators["Validation"].add_event_handler(
262 ignite.engine.Events.EPOCH_COMPLETED, early_handler
263 )
264
265 # Configure saving the best performing model
266 if self.run_dir is not None:
267 to_save = {
268 "model": self.model,
269 "optimizer": self.optimizer,
270 "trainer": self.trainer,
271 }
272 # Note that we judge early stopping above by the validation loss, but save the best model
273 # according to validation perfectEvent score. This lets training continue for perfectEvent plateaus
274 # so long as the model is still changing (and hopefully improving again after some time).
275 best_model_handler = ignite.handlers.Checkpoint(
276 to_save=to_save,
277 save_handler=ignite.handlers.DiskSaver(
278 self.run_dir, create_dir=True, require_empty=False
279 ),
280 filename_prefix=self.timestamp,
281 score_function=self._perfect_score_fn,
282 score_name="validation_perfectEvent",
283 n_saved=1,
284 global_step_transform=ignite.handlers.global_step_from_engine(
285 self.evaluators["Validation"]
286 ),
287 )
288 self.evaluators["Validation"].add_event_handler(
289 ignite.engine.Events.EPOCH_COMPLETED, best_model_handler
290 )
291
292 return
293

Member Data Documentation

◆ configs

configs

Configs.

Definition at line 62 of file create_trainer.py.

◆ device

device

CPU or GPU.

Definition at line 68 of file create_trainer.py.

◆ evaluators

evaluators

Train and validation evaluators (Ignite engines), keyed by tag.

Definition at line 132 of file create_trainer.py.

◆ ignore_index

ignore_index

Label index to ignore when calculating metrics.

Definition at line 66 of file create_trainer.py.

◆ model

model

Model.

Definition at line 58 of file create_trainer.py.

◆ optimizer

optimizer

Optimizer.

Definition at line 60 of file create_trainer.py.

◆ run_dir

run_dir

Output directory for checkpoints and saved configs; ``None`` if no output path is configured.

Definition at line 74 of file create_trainer.py.

◆ tags

tags

Tags.

Definition at line 64 of file create_trainer.py.

◆ timestamp

timestamp

Run timestamp (format ``%Y.%m.%d_%H.%M``) used to distinguish training runs.

Definition at line 71 of file create_trainer.py.

◆ trainer

trainer

Ignite trainer.

Definition at line 125 of file create_trainer.py.


The documentation for this class was generated from the following file: