![]() |
Belle II Software development
|
Public Member Functions | |
__init__ (self, model, data_set, log_dir=None, save_name=None, monitoring_size=10000) | |
train_model (self) | |
Public Attributes | |
model = model | |
model | |
data_set = data_set | |
dataset | |
monitoring_size = monitoring_size | |
monitoring size | |
log_dir = log_dir | |
log dir | |
termination_criterion = self.model.termination_criterion | |
termination criterion | |
int | current_epoch = 0 |
initialise current epoch | |
float | best_epoch = -np.inf |
initialise best epoch | |
save_name = save_name | |
set the path and name for saving the weight files | |
int | train_monitor = -1 |
training monitor | |
int | valid_monitor = -1 |
validation monitor | |
train_writer = tf.summary.create_file_writer(log_dir_train) | |
tf.summary writer for training | |
valid_writer = tf.summary.create_file_writer(log_dir_valid) | |
tf.summary writer for validation | |
optimizer = self.model.get_optimizer(current_epoch) | |
set optimizer for this epoch | |
Protected Member Functions | |
_prepare_monitoring (self) | |
_prepare_tensorboard (self, log_dir) | |
_train_epoch (self, current_epoch) | |
_save_best_state (self, cross_entropy) | |
_closing_ops (self) | |
Protected Attributes | |
_time = time.time() | |
current time | |
Handles the training of the network model.
Definition at line 450 of file tensorflow_dnn_model.py.
__init__ | ( | self, | |
model, | |||
data_set, | |||
log_dir = None, | |||
save_name = None, | |||
monitoring_size = 10000 ) |
Class to train a predefined model.
:param model: DefaultModel object
:param data_set: TFData object
:param log_dir: str, directory name for TensorBoard logging
:param save_name: str, path and name for saving the weight files
:param monitoring_size: int, number of events from the training fraction used for monitoring
Definition at line 455 of file tensorflow_dnn_model.py.
_closing_ops | ( | self | ) |
protected |
Closing operations.
Definition at line 625 of file tensorflow_dnn_model.py.
_prepare_monitoring | ( | self | ) |
protected |
Check the dataset sizes for evaluation. These samples are used after each epoch to collect summary statistics and to test the early-stopping criteria.
Definition at line 508 of file tensorflow_dnn_model.py.
_prepare_tensorboard | ( | self, | log_dir | ) |
protected |
Prepare TensorBoard logging.
Definition at line 524 of file tensorflow_dnn_model.py.
_save_best_state | ( | self, | cross_entropy | ) |
protected |
Save the model as a checkpoint only if a global minimum is reached on the validation sample.
Definition at line 607 of file tensorflow_dnn_model.py.
_train_epoch | ( | self, | current_epoch | ) |
protected |
Train a single epoch.
Definition at line 538 of file tensorflow_dnn_model.py.
train_model | ( | self | ) |
Train the model.
Definition at line 634 of file tensorflow_dnn_model.py.
_time = time.time() |
protected |
current time
Definition at line 471 of file tensorflow_dnn_model.py.
best_epoch = -np.inf |
initialise best epoch
Definition at line 493 of file tensorflow_dnn_model.py.
int current_epoch = 0 |
initialise current epoch
Definition at line 490 of file tensorflow_dnn_model.py.
data_set = data_set |
dataset
Definition at line 477 of file tensorflow_dnn_model.py.
log_dir = log_dir |
log dir
Definition at line 484 of file tensorflow_dnn_model.py.
model = model |
model
Definition at line 474 of file tensorflow_dnn_model.py.
monitoring_size = monitoring_size |
monitoring size
Definition at line 481 of file tensorflow_dnn_model.py.
optimizer = self.model.get_optimizer(current_epoch) |
set optimizer for this epoch
Definition at line 543 of file tensorflow_dnn_model.py.
save_name = save_name |
set the path and name for saving the weight files
Definition at line 503 of file tensorflow_dnn_model.py.
termination_criterion = self.model.termination_criterion |
termination criterion
Definition at line 487 of file tensorflow_dnn_model.py.
int train_monitor = -1 |
training monitor
Definition at line 514 of file tensorflow_dnn_model.py.
train_writer = tf.summary.create_file_writer(log_dir_train) |
tf.summary writer for training
Definition at line 532 of file tensorflow_dnn_model.py.
int valid_monitor = -1 |
validation monitor
Definition at line 516 of file tensorflow_dnn_model.py.
valid_writer = tf.summary.create_file_writer(log_dir_valid) |
tf.summary writer for validation
Definition at line 535 of file tensorflow_dnn_model.py.