Belle II Software light-2509-fornax
Public Member Functions
- `__init__(self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0)`
- `setup_handlers(self, cfg_filename="config.yaml")`
- `log_results(self, trainer, mode_tags)`
Public Attributes

| Attribute | Description |
|---|---|
| `model = model` | Model. |
| `optimizer = optimizer` | Optimizer. |
| `configs = configs` | Configs. |
| `tags = tags` | Tags. |
| `ignore_index = ignore_index` | Index to ignore. |
| `device = device` | CPU or GPU. |
| `timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")` | Run timestamp used to distinguish training runs. |
| `run_dir = None` | Output directory for checkpoints. |
| `trainer = ignite.engine.Engine(_update_model)` | Ignite trainer. |
| `dict evaluators = {}` | Train and validation evaluators. |
Protected Member Functions
- `_score_fn(self, engine)`
- `_perfect_score_fn(self, engine)`
- `_clean_config_dict(self, configs)`
Class to set up the ignite trainer and hold everything associated with it.

:param model: The actual PyTorch model.
:type model: `Model <https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html>`_
:param optimizer: Optimizer used in training.
:type optimizer: `Optimizer <https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer>`_
:param loss_fn: Loss function.
:type loss_fn: `Loss <https://pytorch.org/docs/stable/nn.html#loss-functions>`_
:param device: Device to use.
:type device: `Device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
:param configs: Dictionary of run configs from the loaded YAML config file.
:type configs: dict
:param tags: Various tags to sort train and validation evaluators by, e.g. "Training", "Validation".
:type tags: list
:param scheduler: Learning rate scheduler.
:type scheduler: `Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_
:param ignore_index: Label index to ignore when calculating metrics, e.g. padding.
:type ignore_index: int
Definition at line 21 of file create_trainer.py.
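To illustrate what `ignore_index` means for metrics, here is a minimal sketch of an accuracy that skips padded positions. `masked_accuracy` is a hypothetical helper written in plain Python, not the metric code used by create_trainer.py. Note that the documented default `-1.0` compares equal to an integer label `-1` in Python.

```python
# Hypothetical helper, not the Belle II implementation: accuracy that skips
# every position whose label equals ignore_index (e.g. padding).
def masked_accuracy(predictions, labels, ignore_index=-1):
    """Fraction of correct predictions, excluding labels equal to ignore_index."""
    correct = 0
    counted = 0
    for pred, label in zip(predictions, labels):
        if label == ignore_index:
            continue  # ignored position: neither right nor wrong
        counted += 1
        correct += int(pred == label)
    return correct / counted if counted else 0.0


preds = [2, 0, 1, 1, 3]
labels = [2, 1, 1, -1, -1]  # the last two entries are padding
print(masked_accuracy(preds, labels))  # 2 correct out of 3 counted positions
```

Excluding ignored positions from both numerator and denominator keeps padding from inflating or deflating the score.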
`__init__(self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0)`
Initialization.
Definition at line 43 of file create_trainer.py.
protected
Clean configs to prepare them for writing to file.
Definition at line 224 of file create_trainer.py.
protected
Metric to use for checkpoints.
Definition at line 220 of file create_trainer.py.
protected
Metric to use for early stopping.
Definition at line 216 of file create_trainer.py.
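In pytorch-ignite, the `EarlyStopping(patience, score_function, trainer)` and `ModelCheckpoint(..., score_function=..., score_name=...)` handlers call `score_function(engine)` and treat a larger value as an improvement. The protected `_score_fn` and `_perfect_score_fn` methods take an `engine` argument, which fits that convention. The sketch below shows the convention only; the metric names `"loss"` and `"perfect_fraction"` are hypothetical, and a `SimpleNamespace` stands in for an ignite evaluator so the example runs without ignite.

```python
from types import SimpleNamespace


# Hypothetical metric keys; the real ones used by create_trainer.py are not shown here.
def score_from_loss(engine):
    """Early-stopping style score: ignite treats higher as better, so negate the loss."""
    return -engine.state.metrics["loss"]


def score_from_perfect_fraction(engine):
    """Checkpoint style score: a fraction where higher is already better."""
    return engine.state.metrics["perfect_fraction"]


# Stand-in for an ignite evaluator whose metrics have just been computed.
evaluator = SimpleNamespace(
    state=SimpleNamespace(metrics={"loss": 0.42, "perfect_fraction": 0.31})
)
print(score_from_loss(evaluator))
print(score_from_perfect_fraction(evaluator))
```

With ignite installed, such a function would typically be wired up as `handler = EarlyStopping(patience=10, score_function=score_from_loss, trainer=trainer)` followed by `evaluator.add_event_handler(Events.COMPLETED, handler)`; the patience value here is only an example.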
`log_results(self, trainer, mode_tags)`
Callback to run evaluation and report the results.

:param trainer: Trainer passed by ignite to this method.
:type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_
:param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
:type mode_tags: dict
Definition at line 296 of file create_trainer.py.
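A minimal sketch of what a callback of this shape could do, assuming `mode_tags` maps a tag such as "Training" to a `(mode, dataset, dataloader)` tuple and that the evaluators are stored in a dict under the same tags. `StubEvaluator` and `log_results_sketch` are hypothetical stand-ins so the example runs without PyTorch or ignite; a real ignite evaluator exposes `run(data)` and `state.metrics` in the same way.

```python
from types import SimpleNamespace


class StubEvaluator:
    """Hypothetical stand-in for an ignite evaluator: run() fills state.metrics."""

    def __init__(self, metrics):
        self._metrics = metrics
        self.state = SimpleNamespace(metrics={})

    def run(self, dataloader):
        # A real evaluator would iterate over the dataloader and update its metrics.
        self.state.metrics = dict(self._metrics)


def log_results_sketch(trainer, mode_tags, evaluators):
    """Run each evaluator on its dataloader and build one report line per tag."""
    lines = []
    for tag, (mode, _dataset, dataloader) in mode_tags.items():
        evaluator = evaluators[tag]
        evaluator.run(dataloader)
        summary = ", ".join(
            f"{name}={value:.3f}" for name, value in sorted(evaluator.state.metrics.items())
        )
        lines.append(f"Epoch {trainer.state.epoch} [{tag}/{mode}] {summary}")
    return lines


trainer = SimpleNamespace(state=SimpleNamespace(epoch=3))
evaluators = {
    "Training": StubEvaluator({"loss": 0.5, "accuracy": 0.8}),
    "Validation": StubEvaluator({"loss": 0.6, "accuracy": 0.75}),
}
mode_tags = {"Training": ("train", None, []), "Validation": ("val", None, [])}
for line in log_results_sketch(trainer, mode_tags, evaluators):
    print(line)
```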
`setup_handlers(self, cfg_filename="config.yaml")`
Creates the various ignite handlers (callbacks).
Args:
cfg_filename (str): Name of config yaml file to use when saving configs.
Definition at line 237 of file create_trainer.py.
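One step described here is saving the configs under `cfg_filename`. The sketch below is a hypothetical, stdlib-only illustration that writes a flat dict of scalar values (ints, floats, plain strings) as simple `key: value` YAML lines into a run directory. `save_configs` is not part of create_trainer.py, and a real implementation would more likely use a YAML library and handle nested sections.

```python
import tempfile
from pathlib import Path


def save_configs(configs, run_dir, cfg_filename="config.yaml"):
    """Hypothetical helper: write flat scalar configs as 'key: value' lines."""
    path = Path(run_dir) / cfg_filename
    lines = [f"{key}: {value}" for key, value in configs.items()]
    path.write_text("\n".join(lines) + "\n")
    return path


with tempfile.TemporaryDirectory() as tmp:
    saved = save_configs({"epochs": 50, "lr": 0.001}, tmp)
    print(saved.name)
    print(saved.read_text())
```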
`configs = configs`
Dictionary of run configs from the loaded YAML config file.
Definition at line 62 of file create_trainer.py.
`device = device`
Device to use (CPU or GPU).
Definition at line 68 of file create_trainer.py.
`dict evaluators = {}`
Train and validation evaluators.
Definition at line 132 of file create_trainer.py.
`ignore_index = ignore_index`
Label index to ignore when calculating metrics, e.g. padding.
Definition at line 66 of file create_trainer.py.
`model = model`
The PyTorch model.
Definition at line 58 of file create_trainer.py.
`optimizer = optimizer`
Optimizer used in training.
Definition at line 60 of file create_trainer.py.
`run_dir = None`
Output directory for checkpoints.
Definition at line 74 of file create_trainer.py.
`tags = tags`
Tags to sort train and validation evaluators by, e.g. "Training", "Validation".
Definition at line 64 of file create_trainer.py.
`timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")`
Run timestamp used to distinguish training runs.
Definition at line 71 of file create_trainer.py.
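A small sketch of how a timestamp in this format yields distinct, sortable directory names per run. `run_directory` and the `"checkpoints"` base directory are hypothetical; only the format string is taken from the attribute above.

```python
from datetime import datetime
from pathlib import Path

# Same format string as the timestamp attribute.
TIMESTAMP_FORMAT = "%Y.%m.%d_%H.%M"


def run_directory(base_dir, when):
    """Hypothetical helper: one output directory per training run, named by start time."""
    return Path(base_dir) / when.strftime(TIMESTAMP_FORMAT)


print(run_directory("checkpoints", datetime(2025, 9, 1, 14, 30)))
```

Because the format is zero-padded from year down to minute, the names sort chronologically as plain strings; since it stops at minutes, two runs started within the same minute would receive the same name.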
`trainer = ignite.engine.Engine(_update_model)`
Ignite trainer.
Definition at line 125 of file create_trainer.py.
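In ignite, `Engine(process_function)` calls `process_function(engine, batch)` for every batch and stores the return value in `engine.state.output`; handlers attached with `engine.add_event_handler(Events.EPOCH_COMPLETED, handler, *args)` receive the engine first, followed by the extra arguments, which matches how `log_results(self, trainer, mode_tags)` receives the trainer. The toy `MiniEngine` below is a hypothetical stand-in that mirrors only this contract so it runs without ignite, and `update_model` is a placeholder for the real `_update_model`.

```python
from types import SimpleNamespace


class MiniEngine:
    """Toy stand-in for ignite.engine.Engine (illustration only, not the ignite API)."""

    def __init__(self, process_function):
        self._process_function = process_function
        self._epoch_handlers = []
        self.state = SimpleNamespace(epoch=0, iteration=0, output=None)

    def add_epoch_completed_handler(self, handler, *args):
        # Handlers are later called as handler(engine, *args), as in ignite.
        self._epoch_handlers.append((handler, args))

    def run(self, data, max_epochs=1):
        for _ in range(max_epochs):
            self.state.epoch += 1
            for batch in data:
                self.state.iteration += 1
                self.state.output = self._process_function(self, batch)
            for handler, args in self._epoch_handlers:
                handler(self, *args)
        return self.state


def update_model(engine, batch):
    """Placeholder update step; a real one would run forward, loss, backward, optimizer.step()."""
    return sum(batch) / len(batch)  # pretend per-batch "loss"


log = []
trainer = MiniEngine(update_model)
trainer.add_epoch_completed_handler(
    lambda engine, tag: log.append((tag, engine.state.epoch, engine.state.output)), "Training"
)
trainer.run([[1.0, 3.0], [2.0, 4.0]], max_epochs=2)
print(log)
```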