Belle II Software development
Public Member Functions
def __init__ (self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0)
def setup_handlers (self, cfg_filename="config.yaml")
def log_results (self, trainer, mode_tags)
Public Attributes
model: The PyTorch model being trained.
optimizer: Optimizer used in training.
configs: Dictionary of run configs loaded from the YAML config file.
tags: Tags to sort the train and validation evaluators by, e.g. "Training", "Validation".
ignore_index: Label index to ignore when calculating metrics, e.g. padding.
device: Device to use (CPU or GPU).
timestamp: Run timestamp used to distinguish training runs.
run_dir: Output directory for checkpoints.
trainer: Ignite trainer.
evaluators: Train and validation evaluators, sorted by tag.
Protected Member Functions
def _score_fn (self, engine)
def _perfect_score_fn (self, engine)
def _clean_config_dict (self, configs)
Class to set up the ignite trainer and hold everything associated with it.
:param model: The actual PyTorch model.
:type model: `Model <https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html>`_
:param optimizer: Optimizer used in training.
:type optimizer: `Optimizer <https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer>`_
:param loss_fn: Loss function.
:type loss_fn: `Loss <https://pytorch.org/docs/stable/nn.html#loss-functions>`_
:param device: Device to use.
:type device: `Device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
:param configs: Dictionary of run configs from the loaded YAML config file.
:type configs: dict
:param tags: Various tags to sort train and validation evaluators by, e.g. "Training", "Validation".
:type tags: list
:param scheduler: Learning rate scheduler.
:type scheduler: `Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_
:param ignore_index: Label index to ignore when calculating metrics, e.g. padding.
:type ignore_index: int
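The ignore_index parameter marks labels that are skipped when metrics are calculated, e.g. padding. The framework-free sketch below shows that masking on flat Python lists. The helper name masked_accuracy is hypothetical; the class presumably computes its metrics through ignite rather than through a function like this.

```python
from typing import Sequence


def masked_accuracy(
    preds: Sequence[int], labels: Sequence[float], ignore_index: float = -1.0
) -> float:
    """Accuracy over positions whose label is not ``ignore_index``.

    Hypothetical helper for illustration only, not part of create_trainer.py.
    """
    # Keep only (prediction, label) pairs whose label is not the ignored value.
    kept = [(p, y) for p, y in zip(preds, labels) if y != ignore_index]
    if not kept:
        # Nothing left to score, e.g. a sample made entirely of padding.
        return 0.0
    correct = sum(1 for p, y in kept if p == y)
    return correct / len(kept)


# The last two labels are padding and do not count towards the score.
print(masked_accuracy([1, 0, 2, 2, 2], [1, 0, 1, -1, -1]))  # 0.6666666666666666
```

Without the mask, the two padded positions would be scored as errors and the result would drop to 0.4.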
Definition at line 21 of file create_trainer.py.
def __init__ (self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0)
Initialization.
Definition at line 43 of file create_trainer.py.
def _clean_config_dict (self, configs)  [protected]
Clean configs to prepare them for writing to file.
Definition at line 224 of file create_trainer.py.
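What "cleaning" involves is not spelled out here. The sketch below assumes it means recursively converting values a config writer may not be able to serialize (tuples, pathlib paths, arbitrary objects) into plain lists, strings and numbers. The names clean_config_dict and _clean_value are hypothetical stand-ins, and which types the real method converts is an assumption.

```python
from pathlib import PurePath
from typing import Any


def clean_config_dict(configs: dict) -> dict:
    """Return a new dict with every value converted to a plain type.

    Hypothetical stand-in for ``_clean_config_dict``; the converted types
    are assumptions.
    """
    return {key: _clean_value(value) for key, value in configs.items()}


def _clean_value(value: Any) -> Any:
    if isinstance(value, dict):
        # Recurse into nested sections of the config.
        return clean_config_dict(value)
    if isinstance(value, (list, tuple)):
        # Sequences become plain lists with cleaned elements.
        return [_clean_value(item) for item in value]
    if isinstance(value, PurePath):
        # Paths are stored as their string form.
        return str(value)
    if value is None or isinstance(value, (bool, int, float, str)):
        # Already a plain scalar.
        return value
    # Anything else (e.g. a device object) falls back to its repr.
    return repr(value)


cfg = {"train": {"epochs": 10, "output_dir": PurePath("runs")}, "hidden_sizes": (64, 32)}
print(clean_config_dict(cfg))
# {'train': {'epochs': 10, 'output_dir': 'runs'}, 'hidden_sizes': [64, 32]}
```

This sketch returns a new dict and leaves the input untouched, so the live configs keep their original types after they have been written out.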
def _perfect_score_fn (self, engine)  [protected]
Metric to use for checkpoints.
Definition at line 220 of file create_trainer.py.
def _score_fn (self, engine)  [protected]
Metric to use for early stopping.
Definition at line 216 of file create_trainer.py.
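Both score methods take a single engine argument. That matches the score_function expected by pytorch-ignite handlers such as ignite.handlers.EarlyStopping and ignite.handlers.ModelCheckpoint, which treat a higher returned value as better. The sketch below assumes the early-stopping score is a negated loss and the checkpoint score is an accuracy-like fraction. The metric names "loss" and "perfect" are hypothetical placeholders, and a SimpleNamespace stands in for an evaluated Engine.

```python
from types import SimpleNamespace


def score_fn(engine) -> float:
    """Early-stopping score: negated loss, so that higher is better.

    The metric name "loss" is a hypothetical placeholder.
    """
    return -engine.state.metrics["loss"]


def perfect_score_fn(engine) -> float:
    """Checkpoint score: an accuracy-like fraction, already higher-is-better.

    The metric name "perfect" is a hypothetical placeholder.
    """
    return engine.state.metrics["perfect"]


# Stand-in for an ignite Engine after evaluation: only state.metrics is read.
engine = SimpleNamespace(state=SimpleNamespace(metrics={"loss": 0.42, "perfect": 0.8}))
print(score_fn(engine), perfect_score_fn(engine))  # -0.42 0.8
```

Negating the loss is needed only because these handlers always maximize the score. A metric that already improves upwards can be returned as-is.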
def log_results (self, trainer, mode_tags)
Callback to run evaluation and report the results.
:param trainer: Trainer passed by ignite to this method.
:type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_
:param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
:type mode_tags: dict
Definition at line 296 of file create_trainer.py.
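Ignite calls handlers registered with Engine.add_event_handler with the engine as the first argument, followed by any extra arguments given at registration. This is consistent with log_results receiving the trainer plus mode_tags. The stand-alone sketch below shows the evaluation loop. It assumes mode_tags maps each tag to a (mode, dataset, dataloader) tuple, and it takes the evaluators as an explicit argument where a method on the class would presumably use self.evaluators. FakeEvaluator is a hypothetical stand-in for an ignite evaluator.

```python
from types import SimpleNamespace


class FakeEvaluator:
    """Hypothetical stand-in for an ignite evaluator: run() fills state.metrics."""

    def __init__(self):
        self.state = SimpleNamespace(metrics={})

    def run(self, data):
        # Pretend metric: mean of the values yielded by the "dataloader".
        self.state.metrics = {"mean": sum(data) / len(data)}
        return self.state


def log_results(trainer, mode_tags, evaluators):
    """Evaluate every mode and print its metrics for the current epoch.

    Assumes mode_tags maps tag -> (mode, dataset, dataloader) and evaluators
    maps the same tag to an evaluator. Returns the metrics for inspection.
    """
    results = {}
    for tag, (mode, _dataset, loader) in mode_tags.items():
        evaluator = evaluators[tag]
        evaluator.run(loader)
        metrics = dict(evaluator.state.metrics)
        results[tag] = metrics
        print(f"[epoch {trainer.state.epoch}] {tag} ({mode}): {metrics}")
    return results


trainer = SimpleNamespace(state=SimpleNamespace(epoch=3))
mode_tags = {"Training": ("train", None, [1, 2, 3]), "Validation": ("val", None, [4, 6])}
evaluators = {tag: FakeEvaluator() for tag in mode_tags}
log_results(trainer, mode_tags, evaluators)
# [epoch 3] Training (train): {'mean': 2.0}
# [epoch 3] Validation (val): {'mean': 5.0}
```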
def setup_handlers (self, cfg_filename="config.yaml")
Creates the various ignite handlers (callbacks).
Args:
    cfg_filename (str): Name of the config YAML file to use when saving configs.
Definition at line 237 of file create_trainer.py.
configs
Dictionary of run configs loaded from the YAML config file.
Definition at line 62 of file create_trainer.py.
device
Device to use (CPU or GPU).
Definition at line 68 of file create_trainer.py.
evaluators
Train and validation evaluators, sorted by tag.
Definition at line 132 of file create_trainer.py.
ignore_index
Label index to ignore when calculating metrics, e.g. padding.
Definition at line 66 of file create_trainer.py.
model
The PyTorch model being trained.
Definition at line 58 of file create_trainer.py.
optimizer
Optimizer used in training.
Definition at line 60 of file create_trainer.py.
run_dir
Output directory for checkpoints.
Definition at line 74 of file create_trainer.py.
tags
Tags to sort the train and validation evaluators by, e.g. "Training", "Validation".
Definition at line 64 of file create_trainer.py.
timestamp
Run timestamp used to distinguish training runs.
Definition at line 71 of file create_trainer.py.
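The timestamp attribute exists to keep separate training runs apart. One way to derive a per-run checkpoint directory from such a timestamp is sketched below. The helper make_run_dir, the timestamp format and the run_ prefix are all assumptions, not the layout used in create_trainer.py.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional


def make_run_dir(output_path: Optional[str], now: Optional[datetime] = None) -> Optional[Path]:
    """Build a per-run checkpoint directory from a run timestamp.

    Hypothetical helper: format and directory layout are assumptions.
    """
    if output_path is None:
        # No output location configured: no run directory.
        return None
    if now is None:
        now = datetime.now()
    # Seconds resolution keeps two runs started in the same minute apart.
    timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
    return Path(output_path) / f"run_{timestamp}"


print(make_run_dir("checkpoints", datetime(2024, 5, 1, 13, 30, 0)))
# checkpoints/run_2024-05-01_13-30-00  (POSIX path separator)
```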
trainer
Ignite trainer.
Definition at line 125 of file create_trainer.py.