Belle II Software release-09-00-03
Public Member Functions
| def | __init__ (self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0) |
| def | setup_handlers (self, cfg_filename="config.yaml") |
| def | log_results (self, trainer, mode_tags) |
Public Attributes
| model | The PyTorch model. |
| optimizer | Optimizer used in training. |
| configs | Dictionary of run configs from the loaded YAML config file. |
| tags | Tags to sort train and validation evaluators by, e.g. "Training", "Validation". |
| ignore_index | Label index to ignore when calculating metrics, e.g. padding. |
| device | Device to use (CPU or GPU). |
| timestamp | Run timestamp used to distinguish trainings. |
| run_dir | Output directory for checkpoints. |
| trainer | Ignite trainer. |
| evaluators | Train and validation evaluators. |
Protected Member Functions
| def | _score_fn (self, engine) |
| def | _perfect_score_fn (self, engine) |
| def | _clean_config_dict (self, configs) |
Class to set up the ignite trainer and hold everything associated with it.

:param model: The actual PyTorch model.
:type model: `Model <https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html>`_
:param optimizer: Optimizer used in training.
:type optimizer: `Optimizer <https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer>`_
:param loss_fn: Loss function.
:type loss_fn: `Loss <https://pytorch.org/docs/stable/nn.html#loss-functions>`_
:param device: Device to use.
:type device: `Device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
:param configs: Dictionary of run configs from the loaded YAML config file.
:type configs: dict
:param tags: Various tags to sort train and validation evaluators by, e.g. "Training", "Validation".
:type tags: list
:param scheduler: Learning rate scheduler.
:type scheduler: `Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_
:param ignore_index: Label index to ignore when calculating metrics, e.g. padding.
:type ignore_index: int
Definition at line 21 of file create_trainer.py.
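To illustrate what `ignore_index` means for metrics, here is a minimal pure-Python sketch, not the class's own metric code (which this page does not show). `masked_accuracy` is a hypothetical helper: positions whose label equals `ignore_index` (e.g. padding) are dropped before accuracy is computed. The default mirrors the documented `-1.0`; because Python treats `-1 == -1.0` as equal, integer labels are still matched even though the parameter is documented as `int`.

```python
from typing import Sequence


def masked_accuracy(
    predictions: Sequence[int],
    labels: Sequence[int],
    ignore_index: float = -1.0,
) -> float:
    """Fraction of correct predictions, skipping positions whose label equals ignore_index."""
    # Keep only (prediction, label) pairs whose label is not the ignored (e.g. padding) index.
    kept = [(p, y) for p, y in zip(predictions, labels) if y != ignore_index]
    if not kept:
        raise ValueError("every label equals ignore_index; accuracy is undefined")
    correct = sum(1 for p, y in kept if p == y)
    return correct / len(kept)


# The two padded positions (label -1) are excluded, leaving 3 of 4 correct.
print(masked_accuracy([2, 0, 1, 1, 5, 7], [2, 0, 1, 0, -1, -1]))  # 0.75
```

Raising on an all-padding input is a design choice of this sketch; returning a sentinel such as `float("nan")` would be an equally valid convention.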
| def __init__ (self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0) |
Initialization.
Definition at line 43 of file create_trainer.py.
| def _clean_config_dict (self, configs) | protected |
Clean configs to prepare them for writing to file.
Definition at line 224 of file create_trainer.py.
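The page only says that configs are cleaned before being written to file. As one possible reading, assuming "cleaning" means converting values into plain built-in types that a YAML or JSON dumper accepts, a recursive sketch could look like this. `clean_config` and the config keys are hypothetical, and `json.dumps` stands in for the YAML writer so the example needs only the standard library.

```python
import json
from pathlib import Path
from typing import Any


def clean_config(value: Any) -> Any:
    """Recursively convert a config structure into plain built-in types (hypothetical)."""
    if isinstance(value, dict):
        # Stringify keys and clean every value.
        return {str(k): clean_config(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        # Sequences become lists; sets are sorted so the output order is stable.
        items = sorted(value, key=str) if isinstance(value, set) else value
        return [clean_config(v) for v in items]
    if isinstance(value, (str, int, float, bool)) or value is None:
        return value
    # Anything else (e.g. Path objects) is written out as its string form.
    return str(value)


# Hypothetical config keys, for illustration only.
configs = {"output": {"path": Path("/tmp/runs")}, "train": {"batch_sizes": (32, 64)}}
cleaned = clean_config(configs)
print(json.dumps(cleaned))
```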
| def _perfect_score_fn (self, engine) | protected |
Metric to use for checkpoints.
Definition at line 220 of file create_trainer.py.
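In ignite, `ModelCheckpoint` accepts a `score_function` that takes the engine and returns a float, and keeps the `n_saved` objects with the highest scores. The sketch below reproduces that retention rule in plain Python rather than calling ignite; `perfect_score_fn`, `TopNCheckpoints` and the `"perfect_event"` metric key are hypothetical stand-ins, and `SimpleNamespace` fakes the `engine.state` object.

```python
import heapq
from types import SimpleNamespace


def perfect_score_fn(engine) -> float:
    """Checkpoint score: read an (assumed) metric from the engine state."""
    return engine.state.metrics["perfect_event"]


class TopNCheckpoints:
    """Keep only the n highest-scoring checkpoints, as a score-based checkpointer does."""

    def __init__(self, score_fn, n_saved: int = 2):
        self.score_fn = score_fn
        self.n_saved = n_saved
        self._heap = []  # min-heap of (score, epoch); the worst kept entry is on top

    def __call__(self, engine) -> None:
        item = (self.score_fn(engine), engine.state.epoch)
        if len(self._heap) < self.n_saved:
            heapq.heappush(self._heap, item)
        elif item > self._heap[0]:
            # The new checkpoint beats the worst kept one, so it replaces it.
            heapq.heapreplace(self._heap, item)

    def kept_epochs(self):
        return sorted(epoch for _, epoch in self._heap)


checkpointer = TopNCheckpoints(perfect_score_fn, n_saved=2)
for epoch, score in enumerate([0.10, 0.40, 0.25, 0.35], start=1):
    engine = SimpleNamespace(state=SimpleNamespace(epoch=epoch, metrics={"perfect_event": score}))
    checkpointer(engine)
print(checkpointer.kept_epochs())  # [2, 4]
```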
| def _score_fn (self, engine) | protected |
Metric to use for early stopping.
Definition at line 216 of file create_trainer.py.
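ignite's `EarlyStopping` handler takes a `score_function` and a `patience`, and counts a score as an improvement only when it is higher than the best one seen so far (plus `min_delta`). A loss therefore has to be negated before it can serve as a stopping score. Whether `_score_fn` itself negates its metric is not shown on this page; the plain-Python sketch below, with hypothetical names (`score_fn`, `PatienceStopper`), only illustrates the patience logic.

```python
from types import SimpleNamespace


def score_fn(engine) -> float:
    """Early-stopping score: negated loss, because a higher score must mean 'better'."""
    return -engine.state.metrics["loss"]


class PatienceStopper:
    """Signal a stop after `patience` consecutive evaluations without improvement."""

    def __init__(self, score_fn, patience: int, min_delta: float = 0.0):
        self.score_fn = score_fn
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.counter = 0

    def __call__(self, engine) -> bool:
        """Return True when training should stop."""
        score = self.score_fn(engine)
        if self.best is None or score > self.best + self.min_delta:
            self.best = score  # improvement: remember it and reset the counter
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = PatienceStopper(score_fn, patience=2)
for epoch, loss in enumerate([0.9, 0.7, 0.72, 0.71, 0.5], start=1):
    engine = SimpleNamespace(state=SimpleNamespace(metrics={"loss": loss}))
    if stopper(engine):
        print(f"stopping after epoch {epoch}")  # stopping after epoch 4
        break
```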
| def log_results (self, trainer, mode_tags) |
Callback to run evaluation and report the results.
:param trainer: Trainer passed by ignite to this method.
:type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_
:param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
:type mode_tags: dict
Definition at line 296 of file create_trainer.py.
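The `mode_tags` argument is described as a dictionary of `(mode, dataset, dataloader)` tuples. A minimal sketch of how such a callback could walk that dictionary follows, assuming the keys are the tags (e.g. "Training", "Validation") and that one evaluator exists per tag. The extra `evaluators` argument, `mean_loss_evaluator` and the batch format are hypothetical stand-ins for the class's ignite evaluators.

```python
from types import SimpleNamespace


def mean_loss_evaluator(dataloader):
    """Hypothetical stand-in for an evaluator run: average the per-batch 'loss' values."""
    losses = [batch["loss"] for batch in dataloader]
    return {"loss": sum(losses) / len(losses)}


def log_results(trainer, mode_tags, evaluators):
    """Evaluate every entry of mode_tags and report its metrics (hypothetical sketch)."""
    results = {}
    for tag, (mode, dataset, dataloader) in mode_tags.items():
        metrics = evaluators[tag](dataloader)
        summary = ", ".join(f"{name}={value:.3f}" for name, value in metrics.items())
        print(f"epoch {trainer.state.epoch} [{mode}] {len(dataset)} samples: {summary}")
        results[tag] = metrics
    return results


trainer = SimpleNamespace(state=SimpleNamespace(epoch=3))
mode_tags = {
    "Training": ("train", list(range(8)), [{"loss": 0.5}, {"loss": 0.3}]),
    "Validation": ("val", list(range(4)), [{"loss": 0.6}]),
}
evaluators = {tag: mean_loss_evaluator for tag in mode_tags}
log_results(trainer, mode_tags, evaluators)
```

Returning the collected metrics, in addition to printing them, keeps the callback easy to test; an ignite handler's return value is otherwise ignored.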
| def setup_handlers (self, cfg_filename="config.yaml") |
Creates the various ignite handlers (callbacks).
Args:
cfg_filename (str): Name of config yaml file to use when saving configs.
Definition at line 237 of file create_trainer.py.
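ignite attaches callbacks with `Engine.add_event_handler(event_name, handler, *args, **kwargs)`; the handler may receive the engine as its first argument, followed by the extra arguments. That is consistent with the `log_results(self, trainer, mode_tags)` signature, although how `setup_handlers` actually registers it is not shown here. The standard-library mock below (`MiniEngine` and `record_epoch` are hypothetical) illustrates only that calling convention, not ignite itself.

```python
from collections import defaultdict
from types import SimpleNamespace


class MiniEngine:
    """Hypothetical stand-in for an event-driven engine such as ignite's Engine."""

    def __init__(self):
        self._handlers = defaultdict(list)
        self.state = SimpleNamespace(epoch=0)

    def add_event_handler(self, event, handler, *args):
        # Remember the handler together with its extra positional arguments.
        self._handlers[event].append((handler, args))

    def fire(self, event):
        # Call each handler with the engine first, then its stored arguments.
        for handler, args in self._handlers[event]:
            handler(self, *args)

    def run(self, max_epochs):
        for epoch in range(1, max_epochs + 1):
            self.state.epoch = epoch
            self.fire("EPOCH_COMPLETED")


def record_epoch(engine, tag, sink):
    """Handler shaped like log_results(trainer, mode_tags): engine first, then extras."""
    sink.append((tag, engine.state.epoch))


seen = []
engine = MiniEngine()
engine.add_event_handler("EPOCH_COMPLETED", record_epoch, "Validation", seen)
engine.run(max_epochs=2)
print(seen)  # [('Validation', 1), ('Validation', 2)]
```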
| configs |
Dictionary of run configs from the loaded YAML config file.
Definition at line 62 of file create_trainer.py.
| device |
Device to use (CPU or GPU).
Definition at line 68 of file create_trainer.py.
| evaluators |
Train and validation evaluators.
Definition at line 132 of file create_trainer.py.
| ignore_index |
Label index to ignore when calculating metrics, e.g. padding.
Definition at line 66 of file create_trainer.py.
| model |
The PyTorch model.
Definition at line 58 of file create_trainer.py.
| optimizer |
Optimizer used in training.
Definition at line 60 of file create_trainer.py.
| run_dir |
Output directory for checkpoints.
Definition at line 74 of file create_trainer.py.
| tags |
Tags to sort train and validation evaluators by, e.g. "Training", "Validation".
Definition at line 64 of file create_trainer.py.
| timestamp |
Run timestamp used to distinguish trainings.
Definition at line 71 of file create_trainer.py.
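The page states that a run timestamp distinguishes trainings and that `run_dir` holds checkpoints. One way to combine the two, assuming a `run_<timestamp>` directory layout under a hypothetical `checkpoints` base directory, is sketched below. The time is passed in explicitly so the result is reproducible, whereas a real run would typically use `datetime.now()` and then create the directory with `Path.mkdir(parents=True, exist_ok=True)`.

```python
from datetime import datetime
from pathlib import Path


def make_run_dir(base: Path, now: datetime) -> Path:
    """Build a per-run output directory path from a timestamp (hypothetical layout)."""
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    return base / f"run_{timestamp}"


run_dir = make_run_dir(Path("checkpoints"), datetime(2024, 5, 17, 14, 3, 9))
print(run_dir.name)  # run_20240517_140309
```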
| trainer |
Ignite trainer.
Definition at line 125 of file create_trainer.py.