Belle II Software light-2505-deimos
Public Member Functions
| __init__ (self, model, data_set, log_dir=None, save_name=None, monitoring_size=10000) |
| train_model (self) |
Public Attributes
| model = model | model to train (DefaultModel object) |
| data_set = data_set | dataset (TFData object) |
| monitoring_size = monitoring_size | number of events from the training fraction used for monitoring |
| log_dir = log_dir | directory name for TensorBoard logging |
| termination_criterion = self.model.termination_criterion | termination criterion, taken from the model |
| int current_epoch = 0 | current epoch, initialised to 0 |
| best_epoch = -np.inf | best epoch, initialised to -np.inf |
| save_name = save_name | path and name for saving the weight files |
| int train_monitor = -1 | training monitor |
| int valid_monitor = -1 | validation monitor |
| train_writer = tf.summary.create_file_writer(log_dir_train) | tf.summary.SummaryWriter for training |
| valid_writer = tf.summary.create_file_writer(log_dir_valid) | tf.summary.SummaryWriter for validation |
| optimizer = self.model.get_optimizer(current_epoch) | optimizer for the current epoch |
Protected Member Functions
| _prepare_monitoring (self) |
| _prepare_tensorboard (self, log_dir) |
| _train_epoch (self, current_epoch) |
| _save_best_state (self, cross_entropy) |
| _closing_ops (self) |
Protected Attributes
| _time = time.time() | current time |
Handles the training of the network model.
Definition at line 450 of file tensorflow_dnn_model.py.
| __init__ (self, model, data_set, log_dir=None, save_name=None, monitoring_size=10000) |
Class to train a predefined model.
:param model: DefaultModel object
:param data_set: TFData object
:param log_dir: str, directory name for TensorBoard logging
:param save_name: str, path and name for saving the weight files
:param monitoring_size: int, number of events from the training fraction used for monitoring
Definition at line 455 of file tensorflow_dnn_model.py.
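The constructor stores its arguments and initialises the training state listed under Public Attributes. As a minimal sketch of that initial state without TensorFlow, the hypothetical TrainerState class below mirrors the documented defaults; StubModel is a placeholder that exposes only termination_criterion (its value of 10 is arbitrary), and math.inf stands in for np.inf.

```python
import math
import time


class StubModel:
    """Hypothetical placeholder model exposing only termination_criterion."""
    termination_criterion = 10


class TrainerState:
    """Sketch mirroring the attributes the documented constructor sets."""

    def __init__(self, model, data_set, log_dir=None, save_name=None,
                 monitoring_size=10000):
        self._time = time.time()          # construction time
        self.model = model
        self.data_set = data_set
        self.monitoring_size = monitoring_size
        self.log_dir = log_dir
        # copied from the model, as documented for termination_criterion
        self.termination_criterion = model.termination_criterion
        self.current_epoch = 0
        # -inf sentinel, equivalent to the documented -np.inf default
        self.best_epoch = -math.inf
        self.save_name = save_name


state = TrainerState(StubModel(), data_set=None, log_dir="logs")
```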
| _closing_ops (self) | protected |
Closing operations.
Definition at line 625 of file tensorflow_dnn_model.py.
| _prepare_monitoring (self) | protected |
Checks the dataset sizes for evaluation. These samples are used after each epoch to collect summary statistics and to test the early-stopping criteria.
Definition at line 508 of file tensorflow_dnn_model.py.
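The exact sizing rule used by _prepare_monitoring is not documented on this page. The sketch below assumes each monitor starts at -1, as documented for train_monitor and valid_monitor, and is then set to monitoring_size capped at the number of events actually available; the function name monitoring_sizes and its event-count arguments are hypothetical.

```python
def monitoring_sizes(monitoring_size, n_train, n_valid):
    """Return hypothetical (train_monitor, valid_monitor) event counts.

    Assumption: -1 marks "no monitoring sample"; otherwise the monitor
    uses monitoring_size events, capped at the events available.
    """
    train_monitor = -1
    valid_monitor = -1
    if n_train > 0:
        train_monitor = min(monitoring_size, n_train)
    if n_valid > 0:
        valid_monitor = min(monitoring_size, n_valid)
    return train_monitor, valid_monitor
```

Capping at the available events keeps the per-epoch evaluation cost bounded by monitoring_size while never requesting more events than a small sample holds.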
| _prepare_tensorboard (self, log_dir) | protected |
Prepares TensorBoard logging.
Definition at line 524 of file tensorflow_dnn_model.py.
| _save_best_state (self, cross_entropy) | protected |
Saves the model as a checkpoint only if a new global minimum is reached on the validation sample.
Definition at line 607 of file tensorflow_dnn_model.py.
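_save_best_state keeps a checkpoint only when the validation cross entropy reaches a new global minimum. A minimal sketch of that bookkeeping in plain Python follows; the class name BestStateSaver and the save_fn callback are hypothetical stand-ins for the real checkpoint writer.

```python
import math


class BestStateSaver:
    """Track the lowest validation cross entropy seen so far (sketch)."""

    def __init__(self, save_fn):
        self.save_fn = save_fn        # hypothetical checkpoint writer
        self.best_value = math.inf    # no checkpoint saved yet
        self.best_epoch = -math.inf   # mirrors the documented -np.inf default

    def update(self, epoch, cross_entropy):
        """Call save_fn only if cross_entropy beats every earlier value."""
        if cross_entropy < self.best_value:
            self.best_value = cross_entropy
            self.best_epoch = epoch
            self.save_fn(epoch)
            return True
        return False


saved = []
saver = BestStateSaver(saved.append)
for epoch, ce in enumerate([0.70, 0.55, 0.60, 0.50]):
    saver.update(epoch, ce)
```

With the strict < comparison, a tie with the previous best leaves the existing checkpoint in place; here checkpoints are written at epochs 0, 1 and 3 only.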
| _train_epoch (self, current_epoch) | protected |
Trains the model for one epoch.
Definition at line 538 of file tensorflow_dnn_model.py.
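According to the attribute list, _train_epoch obtains its optimizer from self.model.get_optimizer(current_epoch), so the optimizer can change from epoch to epoch. The toy sketch below illustrates that pattern in plain Python rather than TensorFlow: get_learning_rate is a hypothetical stand-in for get_optimizer using a step-decay schedule (base rate, decay factor and step size are arbitrary), and the model is a one-parameter least-squares fit trained with plain SGD.

```python
def get_learning_rate(epoch, base_rate=0.1, decay=0.5, step=2):
    """Hypothetical stand-in for model.get_optimizer(epoch): step decay."""
    return base_rate * decay ** (epoch // step)


def train_epoch(weight, batches, epoch):
    """Run one epoch of SGD on y = w * x with a squared-error loss."""
    lr = get_learning_rate(epoch)  # optimizer settings chosen per epoch
    for xs, ys in batches:
        # gradient of mean((w*x - y)^2) with respect to w
        grad = sum(2 * (weight * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        weight -= lr * grad
    return weight


batches = [([1.0, 2.0], [3.0, 6.0])]  # toy data with true weight 3
w = 0.0
for epoch in range(4):
    w = train_epoch(w, batches, epoch)
```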
| train_model (self) |
Trains the model.
Definition at line 634 of file tensorflow_dnn_model.py.
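train_model drives the epoch loop, and the class keeps a termination_criterion taken from the model together with a best_epoch. A minimal sketch of such a loop with early stopping follows, assuming termination_criterion acts as a patience (the number of epochs without a new minimum before training stops); that interpretation, max_epochs and the train_epoch/evaluate callbacks are assumptions, not the documented implementation.

```python
import math


def train_model(train_epoch, evaluate, max_epochs, patience):
    """Epoch loop with early stopping (hypothetical sketch).

    train_epoch(epoch) runs one epoch; evaluate(epoch) returns the
    validation cross entropy. Returns (best_epoch, best_value, epochs_run).
    """
    best_value = math.inf
    best_epoch = -math.inf  # same sentinel as the documented default
    epochs_run = 0
    for epoch in range(max_epochs):
        train_epoch(epoch)
        value = evaluate(epoch)
        epochs_run += 1
        if value < best_value:
            # new global minimum: a checkpoint would be saved here
            best_value, best_epoch = value, epoch
        elif epoch - best_epoch >= patience:
            # no improvement for `patience` epochs: stop early
            break
    return best_epoch, best_value, epochs_run


losses = [0.9, 0.7, 0.72, 0.71, 0.73, 0.5]
result = train_model(lambda e: None, lambda e: losses[e], len(losses), patience=3)
```

In this run training stops after epoch 4 because three epochs pass without beating the minimum from epoch 1, so the later value 0.5 is never reached.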
| _time = time.time() | protected |
Current time, recorded in the constructor.
Definition at line 471 of file tensorflow_dnn_model.py.
| best_epoch = -np.inf |
Best epoch, initialised to -np.inf.
Definition at line 493 of file tensorflow_dnn_model.py.
| int current_epoch = 0 |
Current epoch, initialised to 0.
Definition at line 490 of file tensorflow_dnn_model.py.
| data_set = data_set |
Dataset (TFData object).
Definition at line 477 of file tensorflow_dnn_model.py.
| log_dir = log_dir |
Directory name for TensorBoard logging.
Definition at line 484 of file tensorflow_dnn_model.py.
| model = model |
Model to train (DefaultModel object).
Definition at line 474 of file tensorflow_dnn_model.py.
| monitoring_size = monitoring_size |
Number of events from the training fraction used for monitoring.
Definition at line 481 of file tensorflow_dnn_model.py.
| optimizer = self.model.get_optimizer(current_epoch) |
Optimizer for the current epoch, set in _train_epoch.
Definition at line 543 of file tensorflow_dnn_model.py.
| save_name = save_name |
Path and name for saving the weight files.
Definition at line 503 of file tensorflow_dnn_model.py.
| termination_criterion = self.model.termination_criterion |
Termination criterion, taken from the model.
Definition at line 487 of file tensorflow_dnn_model.py.
| int train_monitor = -1 |
Training monitor, initialised to -1 in _prepare_monitoring.
Definition at line 514 of file tensorflow_dnn_model.py.
| train_writer = tf.summary.create_file_writer(log_dir_train) |
tf.summary.SummaryWriter for training, created in _prepare_tensorboard.
Definition at line 532 of file tensorflow_dnn_model.py.
| int valid_monitor = -1 |
Validation monitor, initialised to -1 in _prepare_monitoring.
Definition at line 516 of file tensorflow_dnn_model.py.
| valid_writer = tf.summary.create_file_writer(log_dir_valid) |
tf.summary.SummaryWriter for validation, created in _prepare_tensorboard.
Definition at line 535 of file tensorflow_dnn_model.py.