Belle II Software development
Public Member Functions
__init__ (self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0)
setup_handlers (self, cfg_filename="config.yaml")
log_results (self, trainer, mode_tags)
Public Attributes
model = model
Model.
optimizer = optimizer
Optimizer.
configs = configs
Configs.
tags = tags
Tags.
ignore_index = ignore_index
Index to ignore.
device = device
CPU or GPU.
timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")
Run timestamp to distinguish trainings.
run_dir = None
Output directory for checkpoints.
trainer = ignite.engine.Engine(_update_model)
Ignite trainer.
dict evaluators = {}
Set up train and validation evaluators.
Protected Member Functions
_score_fn (self, engine)
_perfect_score_fn (self, engine)
_clean_config_dict (self, configs)
Class to set up the ignite trainer and hold all associated objects.
:param model: The actual PyTorch model.
:type model: `Model <https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html>`_
:param optimizer: Optimizer used in training.
:type optimizer: `Optimizer <https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer>`_
:param loss_fn: Loss function.
:type loss_fn: `Loss <https://pytorch.org/docs/stable/nn.html#loss-functions>`_
:param device: Device to use.
:type device: `Device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
:param configs: Dictionary of run configs from the loaded YAML config file.
:type configs: dict
:param tags: Various tags to sort train and validation evaluators by, e.g. "Training", "Validation".
:type tags: list
:param scheduler: Learning rate scheduler.
:type scheduler: `Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_
:param ignore_index: Label index to ignore when calculating metrics, e.g. padding.
:type ignore_index: int
Definition at line 21 of file create_trainer.py.
__init__ (self, model, optimizer, loss_fn, device, configs, tags, scheduler=None, ignore_index=-1.0)
Initialization.
Definition at line 43 of file create_trainer.py.
protected
Clean configs to prepare them for writing to file.
Definition at line 224 of file create_trainer.py.
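Writing configs to file requires every value to be something a YAML writer can represent. A minimal sketch, assuming (hypothetically) that "cleaning" means recursively turning tuples into lists and any non-primitive object (such as a path or device) into its string form; the helper name `clean_config_dict` is illustrative and the real `_clean_config_dict` may behave differently.

```python
from pathlib import Path

# Types a plain YAML/JSON writer can serialise directly.
_PRIMITIVES = (str, int, float, bool, type(None))


def clean_config_dict(configs):
    """Hypothetical stand-in for _clean_config_dict: return a copy of
    `configs` that contains only primitives, lists and dicts."""
    if isinstance(configs, dict):
        # Stringify keys so every key is a plain string.
        return {str(k): clean_config_dict(v) for k, v in configs.items()}
    if isinstance(configs, (list, tuple)):
        # Tuples become lists, which every writer understands.
        return [clean_config_dict(v) for v in configs]
    if isinstance(configs, _PRIMITIVES):
        return configs
    # Fallback: anything else (Path, device objects, ...) becomes a string.
    return str(configs)


cfg = {"output": {"path": Path("runs")}, "train": {"lr": 1e-3, "betas": (0.9, 0.999)}}
print(clean_config_dict(cfg))
# {'output': {'path': 'runs'}, 'train': {'lr': 0.001, 'betas': [0.9, 0.999]}}
```

Converting unknown objects to strings keeps the written file human-readable, at the cost of not being able to load those values back as the original objects.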
protected
Metric to use for checkpoints.
Definition at line 220 of file create_trainer.py.
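A checkpoint metric is a function of the engine that returns a score where higher is better, so that only the best models are kept. The sketch below illustrates that idea with the standard library only; the metric key `"perfect_event"`, the class `TopNCheckpointer` and the engine stand-in built from `SimpleNamespace` are hypothetical and not part of this module or of ignite.

```python
import heapq
from types import SimpleNamespace


def perfect_score_fn(engine):
    """Hypothetical checkpoint metric; the key 'perfect_event' is an assumption."""
    return engine.state.metrics["perfect_event"]


class TopNCheckpointer:
    """Keep only the `n_saved` highest-scoring checkpoints (higher is better)."""

    def __init__(self, score_fn, n_saved=2):
        self.score_fn = score_fn
        self.n_saved = n_saved
        self._heap = []  # min-heap of (score, name): worst kept checkpoint on top

    def __call__(self, engine, name):
        score = self.score_fn(engine)
        if len(self._heap) < self.n_saved:
            heapq.heappush(self._heap, (score, name))
        elif score > self._heap[0][0]:
            # The new checkpoint beats the worst one kept, so replace it.
            heapq.heapreplace(self._heap, (score, name))

    def saved(self):
        """Kept checkpoints, best first."""
        return sorted(self._heap, reverse=True)


ckpt = TopNCheckpointer(perfect_score_fn, n_saved=2)
for epoch, value in enumerate([0.10, 0.35, 0.20, 0.40], start=1):
    evaluator = SimpleNamespace(state=SimpleNamespace(metrics={"perfect_event": value}))
    ckpt(evaluator, f"model_epoch{epoch}.pt")
print(ckpt.saved())  # [(0.4, 'model_epoch4.pt'), (0.35, 'model_epoch2.pt')]
```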
protected
Metric to use for early stopping.
Definition at line 216 of file create_trainer.py.
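An early-stopping metric is polled after each evaluation, and training stops once the score has not improved for a given number of evaluations ("patience"). The sketch below is a simplified, stdlib-only rendering of the patience logic used by ignite's `EarlyStopping` handler (it ignores that handler's `cumulative_delta` option); the metric key `"efficiency"` and the class `EarlyStopper` are hypothetical.

```python
from types import SimpleNamespace


def score_fn(engine):
    """Hypothetical early-stopping metric; the key 'efficiency' is an assumption."""
    return engine.state.metrics["efficiency"]


class EarlyStopper:
    """A score counts as an improvement only if it exceeds the best score
    by more than min_delta; otherwise the patience counter increases."""

    def __init__(self, patience, score_function, min_delta=0.0):
        self.patience = patience
        self.score_function = score_function
        self.min_delta = min_delta
        self.best_score = None
        self.counter = 0

    def __call__(self, engine):
        """Return True when training should stop."""
        score = self.score_function(engine)
        if self.best_score is not None and score <= self.best_score + self.min_delta:
            self.counter += 1
            return self.counter >= self.patience
        # First call or a genuine improvement: reset the counter.
        self.best_score = score
        self.counter = 0
        return False


stopper = EarlyStopper(patience=2, score_function=score_fn)
stopped_at = None
for epoch, value in enumerate([0.50, 0.60, 0.58, 0.59, 0.70], start=1):
    evaluator = SimpleNamespace(state=SimpleNamespace(metrics={"efficiency": value}))
    if stopper(evaluator):
        stopped_at = epoch
        break
print(stopped_at)  # 4
```

Training stops at epoch 4 because epochs 3 and 4 both fail to beat the best score of 0.60, even though epoch 5 would have improved on it; this is the usual trade-off when choosing the patience.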
log_results (self, trainer, mode_tags)
Callback to run evaluation and report the results.
:param trainer: Trainer passed by ignite to this method.
:type trainer: `Engine <https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine>`_
:param mode_tags: Dictionary of mode tags containing (mode, dataset, dataloader) tuples.
:type mode_tags: dict
Definition at line 296 of file create_trainer.py.
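The callback loops over the mode tags, runs the matching evaluator on each dataloader and reports the resulting metrics. A minimal sketch, assuming (hypothetically) that `mode_tags` maps each tag to a `(mode, dataset, dataloader)` tuple and that evaluators are keyed by the same tags; `FakeEvaluator` stands in for an ignite evaluator, and `evaluators` is passed explicitly here instead of being read from `self`.

```python
from types import SimpleNamespace


class FakeEvaluator:
    """Stand-in for an ignite evaluator: run() fills state.metrics."""

    def __init__(self, metrics):
        self._metrics = metrics
        self.state = SimpleNamespace(metrics={})

    def run(self, dataloader):
        # A real evaluator would iterate over `dataloader` and compute metrics.
        self.state.metrics = dict(self._metrics)
        return self.state


def log_results(trainer, mode_tags, evaluators):
    """Hypothetical sketch: evaluate every mode and collect its metrics."""
    report = {}
    for tag, (mode, _dataset, dataloader) in mode_tags.items():
        state = evaluators[tag].run(dataloader)
        report[tag] = state.metrics
        print(f"[epoch {trainer.state.epoch}] {mode}: {state.metrics}")
    return report


trainer = SimpleNamespace(state=SimpleNamespace(epoch=3))
mode_tags = {"Training": ("train", None, [1, 2]), "Validation": ("val", None, [3])}
evaluators = {"Training": FakeEvaluator({"loss": 0.4}), "Validation": FakeEvaluator({"loss": 0.5})}
report = log_results(trainer, mode_tags, evaluators)
```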
setup_handlers (self, cfg_filename="config.yaml")
Creates the various ignite handlers (callbacks).
Args:
    cfg_filename (str): Name of config YAML file to use when saving configs.
Definition at line 237 of file create_trainer.py.
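In ignite, handlers are plain callables attached to engine events and called with the engine as their first argument. The stdlib sketch below imitates that pattern with a hypothetical `MiniEngine` (not ignite's API) and registers a config-saving handler that writes to `run_dir / cfg_filename`; the event names, the flat `key: value` format and the helper `save_configs` are assumptions for illustration.

```python
import tempfile
from collections import defaultdict
from pathlib import Path


class MiniEngine:
    """Tiny stand-in for an engine's event registry (illustration only)."""

    def __init__(self):
        self._handlers = defaultdict(list)

    def add_event_handler(self, event, handler, *args):
        self._handlers[event].append((handler, args))

    def fire(self, event):
        # Each handler receives the engine first, then its registered args.
        for handler, args in self._handlers[event]:
            handler(self, *args)


def save_configs(engine, configs, path):
    # Flat "key: value" lines; a real setup would use a YAML writer instead.
    path.write_text("".join(f"{k}: {v}\n" for k, v in configs.items()))


def setup_handlers(engine, configs, run_dir, cfg_filename="config.yaml"):
    """Hypothetical sketch: register callbacks on the engine's events."""
    engine.add_event_handler("STARTED", save_configs, configs, Path(run_dir) / cfg_filename)
    engine.add_event_handler("EPOCH_COMPLETED", lambda e: print("epoch completed"))


engine = MiniEngine()
run_dir = tempfile.mkdtemp()
setup_handlers(engine, {"lr": 0.001, "epochs": 10}, run_dir)
engine.fire("STARTED")
print((Path(run_dir) / "config.yaml").read_text())
```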
configs = configs
Configs.
Definition at line 62 of file create_trainer.py.
device = device
CPU or GPU.
Definition at line 68 of file create_trainer.py.
dict evaluators = {}
Set up train and validation evaluators.
Definition at line 132 of file create_trainer.py.
ignore_index = ignore_index
Index to ignore.
Definition at line 66 of file create_trainer.py.
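Ignoring a label index means that positions carrying that label (for example padding) are left out of a metric entirely. The real metrics presumably operate on tensors; this stdlib sketch, with the hypothetical helper `masked_accuracy`, only shows the masking idea using the same default of `-1.0`.

```python
def masked_accuracy(predictions, labels, ignore_index=-1.0):
    """Accuracy over positions whose label is not ignore_index (e.g. padding)."""
    kept = [(p, y) for p, y in zip(predictions, labels) if y != ignore_index]
    if not kept:
        raise ValueError("all labels are ignored")
    return sum(p == y for p, y in kept) / len(kept)


preds = [2, 0, 1, 1, 3]
labels = [2, 1, 1, -1, -1]  # the last two entries are padding
print(masked_accuracy(preds, labels))
```

Only the first three positions count, two of which are correct, giving 2/3; without the mask the padded positions would distort the result.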
model = model
Model.
Definition at line 58 of file create_trainer.py.
optimizer = optimizer
Optimizer.
Definition at line 60 of file create_trainer.py.
run_dir = None
Output directory for checkpoints.
Definition at line 74 of file create_trainer.py.
tags = tags
Tags.
Definition at line 64 of file create_trainer.py.
timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")
Run timestamp to distinguish trainings.
Definition at line 71 of file create_trainer.py.
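The timestamp format string yields minute-resolution names such as `2024.03.05_14.07`. The sketch below shows the format on a fixed date; deriving an output directory from it (`output/run_<timestamp>`) is a hypothetical layout, not necessarily how `run_dir` is actually built.

```python
from datetime import datetime
from pathlib import Path

# Same format string as the `timestamp` attribute.
timestamp = datetime.now().strftime("%Y.%m.%d_%H.%M")

# Hypothetical layout: one output sub-directory per run, named by timestamp.
run_dir = Path("output") / f"run_{timestamp}"

fixed = datetime(2024, 3, 5, 14, 7).strftime("%Y.%m.%d_%H.%M")
print(fixed)  # 2024.03.05_14.07
```

Because the format stops at minutes, two trainings started within the same minute would produce the same timestamp.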
trainer = ignite.engine.Engine(_update_model)
Ignite trainer.
Definition at line 125 of file create_trainer.py.
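An ignite `Engine` is built around a process function that it calls as `process_function(engine, batch)` for every batch, storing the return value in `engine.state.output`. The sketch below keeps that shape but, to stay dependency-free, trains a hypothetical one-parameter model `y = w * x` with hand-written SGD and a minimal `run` loop; it is not the module's actual `_update_model`, which presumably uses the PyTorch model, optimizer and loss function.

```python
from types import SimpleNamespace

# Hypothetical one-parameter model y = w * x, trained with squared-error loss.
params = {"w": 0.0}
LEARNING_RATE = 0.1


def _update_model(engine, batch):
    """Process function in the shape ignite expects: (engine, batch) -> output.
    With PyTorch this would be zero_grad / forward / loss / backward / step."""
    x, y = batch
    pred = params["w"] * x
    loss = (pred - y) ** 2
    grad = 2 * (pred - y) * x            # d(loss)/dw
    params["w"] -= LEARNING_RATE * grad  # plain SGD step
    return loss


def run(process_fn, data, max_epochs):
    """Minimal loop mimicking Engine.run: call process_fn on every batch."""
    engine = SimpleNamespace(state=SimpleNamespace(epoch=0, output=None))
    for epoch in range(1, max_epochs + 1):
        engine.state.epoch = epoch
        for batch in data:
            engine.state.output = process_fn(engine, batch)
    return engine.state


state = run(_update_model, data=[(1.0, 2.0), (2.0, 4.0)], max_epochs=20)
print(round(params["w"], 3))  # close to 2.0
```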