Belle II Software  light-2403-persian
GraFEIIgniteTrainer Class Reference

Public Member Functions

def __init__ (self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0)
 
def setup_handlers (self, cfg_filename="config.yaml")
 
def log_results (self, trainer, mode_tags)
 

Public Attributes

 model
 The PyTorch model being trained and evaluated.
 
 optimizer
 Optimizer used in training.
 
 configs
 Dictionary of run configs from the loaded YAML config file.
 
 tags
 Tags used to key the evaluators, e.g. "Training", "Validation".
 
 ignore_index
 Label index to ignore when calculating metrics, e.g. padding.
 
 device
 Device (CPU or GPU) used for training and evaluation.
 
 timestamp
 Run timestamp to distinguish trainings.
 
 run_dir
 Output directory for checkpoints; None unless an output path is set in the configs.
 
 trainer
 Ignite trainer.
 
 evaluators
 Dictionary of Ignite evaluators keyed by tag, e.g. for training and validation.
 

Private Member Functions

def _score_fn (self, engine)
 
def _perfect_score_fn (self, engine)
 
def _clean_config_dict (self, configs)
 

Detailed Description

Class to set up the Ignite trainer and hold all associated objects.

:param model: The actual PyTorch model.
:type model: `Model <https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html>`_
:param optimizer: Optimizer used in training.
:type optimizer: `Optimizer <https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer>`_
:param loss_fn: Loss function.
:type loss_fn: `Loss <https://pytorch.org/docs/stable/nn.html#loss-functions>`_
:param device: Device to use.
:type device: `Device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
:param configs: Dictionary of run configs from loaded yaml config file.
:type configs: dict
:param tags: Various tags to sort train and validation evaluators by, e.g. "Training", "Validation".
:type tags: list
:param scheduler: Learning rate scheduler.
:type scheduler: `Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_
:param ignore_index: Label index to ignore when calculating metrics, e.g. padding.
:type ignore_index: int

Definition at line 21 of file create_trainer.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  model,
  optimizer,
  loss_fn,
  device,
  configs,
  tags,
  scheduler = None,
  ignore_index = -1.0 
)
Initialization.

Definition at line 43 of file create_trainer.py.

53  ):
54  """
55  Initialization.
56  """
57 
58  self.model = model
59 
60  self.optimizer = optimizer
61 
62  self.configs = configs
63 
64  self.tags = tags
65 
66  self.ignore_index = ignore_index
67 
68  self.device = device
69 
70 
71  self.timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")
72 
73 
74  self.run_dir = None
75  if self.configs["output"] is not None:
76  if ("path" in self.configs["output"].keys()) and (
77  self.configs["output"]["path"] is not None
78  ):
79  self.run_dir = Path(
80  self.configs["output"]["path"],
81  self.configs["output"]["run_name"],
82  )
83 
84  # Setup ignite trainer
85  use_amp = configs["train"]["mixed_precision"] and self.device == torch.device(
86  "cuda"
87  )
88 
89  if use_amp:
90  from torch.cuda.amp import autocast
91  from torch.cuda.amp import GradScaler
92 
93  scaler = GradScaler(enabled=True)
94 
95  def _update_model(engine, batch):
96  # This just sets the training mode
97  model.train()
98 
99  optimizer.zero_grad()
100 
101  batch = (
102  Batch.from_data_list(batch).to(device)
103  if isinstance(batch, list)
104  else batch.to(device)
105  )
106 
107  x_y, edge_y, u_y = batch.x_y, batch.edge_y, batch.u_y
108 
109  if use_amp:
110  with autocast(enabled=True):
111  x_pred, e_pred, u_pred = model(batch)
112  loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
113  scaler.scale(loss).backward()
114  scaler.step(optimizer)
115  scaler.update()
116  else:
117  x_pred, e_pred, u_pred = model(batch)
118  loss = loss_fn(x_pred, x_y, e_pred, edge_y, u_pred, u_y)
119  loss.backward()
120  optimizer.step()
121 
122  return loss.item()
123 
124 
125  self.trainer = ignite.engine.Engine(_update_model)
126 
127  if scheduler:
128  ig_scheduler = ignite.handlers.param_scheduler.LRScheduler(scheduler)
129  self.trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, ig_scheduler)
130 
131 
132  self.evaluators = {}
133 
134  for tag in self.tags:
135  # Setup metrics
136  metrics = {
137  # ignite.metrics.Loss takes (y_pred, y, **kwargs) arguments.
138  # MultiTrainLoss needs in total 6 arguments to be computed,
139  # so the additional ones are passed in a dictionary.
140  "loss": ignite.metrics.Loss(
141  loss_fn,
142  output_transform=lambda x: [
143  x[0], x[3],
144  {
145  "edge_input": x[1],
146  "edge_target": x[4],
147  "u_input": x[2],
148  "u_target": x[5],
149  },
150  ],
151  device=device,
152  ),
153  "perfectLCA": PerfectLCA(
154  ignore_index=ignore_index,
155  device=device,
156  output_transform=lambda x: [
157  x[1], x[4], x[6], x[5], x[7], x[8],
158  ],
159  ),
160  "perfectMasses": PerfectMasses(
161  ignore_index=ignore_index,
162  device=device,
163  output_transform=lambda x: [x[0], x[3], x[5], x[7], x[8]],
164  ),
165  "perfectEvent": PerfectEvent(
166  ignore_index=ignore_index,
167  device=device,
168  output_transform=lambda x: [
169  x[0], x[3], x[1], x[4], x[6], x[5], x[7], x[8],
170  ],
171  ),
172  }
173 
174  def _predict_on_batch(engine, batch):
175  model.eval() # It just enables evaluation mode
176 
177  batch = (
178  Batch.from_data_list(batch).to(device)
179  if isinstance(batch, list)
180  else batch.to(device)
181  )
182 
183  x_y, edge_y, u_y, edge_index, torch_batch = (
184  batch.x_y,
185  batch.edge_y,
186  batch.u_y,
187  batch.edge_index,
188  batch.batch,
189  )
190  num_graph = batch.batch[torch_batch.shape[0] - 1] + 1
191 
192  with torch.no_grad():
193  if use_amp:
194  with autocast(enabled=True):
195  x_pred, e_pred, u_pred = model(batch)
196  else:
197  x_pred, e_pred, u_pred = model(batch)
198 
199  return (
200  x_pred,
201  e_pred,
202  u_pred,
203  x_y,
204  edge_y,
205  u_y,
206  edge_index,
207  torch_batch,
208  num_graph,
209  )
210 
211  self.evaluators[tag] = ignite.engine.Engine(_predict_on_batch)
212 
213  for metric_name, metric in zip(metrics.keys(), metrics.values()):
214  metric.attach(self.evaluators[tag], metric_name)
215 

Member Function Documentation

◆ _clean_config_dict()

def _clean_config_dict (   self,
  configs 
)
private
Clean configs to prepare them for writing to file.

Definition at line 224 of file create_trainer.py.

◆ _perfect_score_fn()

def _perfect_score_fn (   self,
  engine 
)
private
Metric to use for checkpointing.

Definition at line 220 of file create_trainer.py.

◆ _score_fn()

def _score_fn (   self,
  engine 
)
private
Metric to use for early stopping.

Definition at line 216 of file create_trainer.py.

◆ log_results()

def log_results (   self,
  trainer,
  mode_tags 
)
Callback to run evaluation and report the results.

:param trainer: Trainer passed by ignite to this method.
:type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_
:param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
:type mode_tags: dict

Definition at line 296 of file create_trainer.py.

◆ setup_handlers()

def setup_handlers (   self,
  cfg_filename = "config.yaml" 
)
Creates the various ignite handlers (callbacks).

Args:
    cfg_filename (str): Name of config yaml file to use when saving configs.

Definition at line 237 of file create_trainer.py.


The documentation for this class was generated from the following file: