Belle II Software development
Public Member Functions
def __init__(self, model, data_set, log_dir=None, save_name=None, monitoring_size=10000)
def train_model(self)

Public Attributes
model: model to be trained (DefaultModel object)
data_set: dataset (TFData object)
monitoring_size: number of training events used for monitoring
log_dir: directory for TensorBoard logging
termination_criterion: criterion for terminating the training
current_epoch: current epoch
best_epoch: best epoch
save_name: path and name for saving the weight files
train_monitor: training monitor
valid_monitor: validation monitor
train_writer: tf.summary writer for training
valid_writer: tf.summary writer for validation
optimizer: optimizer, set for each epoch

Protected Member Functions
def _prepare_monitoring(self)
def _prepare_tensorboard(self, log_dir)
def _train_epoch(self, current_epoch)
def _save_best_state(self, cross_entropy)
def _closing_ops(self)

Protected Attributes
_time: current time

Handles the training of the network model.
Definition at line 450 of file tensorflow_dnn_model.py.
def __init__(self, model, data_set, log_dir=None, save_name=None, monitoring_size=10000)
Class to train a predefined model.
:param model: DefaultModel obj, the model to train
:param data_set: TFData obj, the dataset
:param log_dir: str, directory name for TensorBoard logging
:param save_name: str, path and name for saving the weight files
:param monitoring_size: int, number of events from the training fraction used for monitoring
Definition at line 455 of file tensorflow_dnn_model.py.
def _closing_ops(self) [protected]
Closing operations.
Definition at line 625 of file tensorflow_dnn_model.py.
def _prepare_monitoring(self) [protected]
Checks the dataset sizes for evaluation. These samples are used after each epoch to collect summary statistics and to test the early-stopping criteria.
Definition at line 508 of file tensorflow_dnn_model.py.
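The size check described above can be sketched as follows. This is a minimal, hypothetical illustration, not the code in tensorflow_dnn_model.py: the helper names `cap_monitoring_sizes` and `monitoring_subset` are invented here, and taking the first events of each sample as the fixed monitoring subset is an assumption.

```python
def cap_monitoring_sizes(n_train, n_valid, monitoring_size=10000):
    """Return the (train, valid) monitoring sample sizes actually used.

    Hypothetical helper: the requested monitoring_size is capped at the
    number of events available in each sample.
    """
    if monitoring_size <= 0:
        raise ValueError("monitoring_size must be positive")
    return min(monitoring_size, n_train), min(monitoring_size, n_valid)


def monitoring_subset(events, size):
    """Assumption: use the first `size` events as the fixed monitoring sample."""
    return events[:size]
```

For example, with the default monitoring_size of 10000, a training sample of 50000 events and a validation sample of 4000 events would give monitoring samples of 10000 and 4000 events.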
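def _prepare_tensorboard(self, log_dir) [protected]
Prepares TensorBoard logging, including the training and validation summary writers (train_writer, valid_writer).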
Definition at line 524 of file tensorflow_dnn_model.py.
def _save_best_state(self, cross_entropy) [protected]
Saves the model as a checkpoint only if the cross entropy on the validation sample reaches a new global minimum.
Definition at line 607 of file tensorflow_dnn_model.py.
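The "save only at a new global minimum" rule can be illustrated with a small bookkeeping class. This is a hypothetical stand-in: the class `BestStateTracker` and its attributes are invented for illustration, and instead of writing a weight file it only records the epochs at which one would be written.

```python
import math


class BestStateTracker:
    """Hypothetical stand-in for the bookkeeping behind _save_best_state.

    A checkpoint is written only when the validation cross entropy drops
    below every value seen so far (a new global minimum).
    """

    def __init__(self):
        self.best_cross_entropy = math.inf
        self.best_epoch = 0
        self.saved_epochs = []  # epochs at which a checkpoint would be written

    def update(self, epoch, cross_entropy):
        """Return True if this epoch's state should be saved."""
        if cross_entropy < self.best_cross_entropy:
            self.best_cross_entropy = cross_entropy
            self.best_epoch = epoch
            # The real trainer would write the weight file here (for
            # example to save_name); this sketch only records the epoch.
            self.saved_epochs.append(epoch)
            return True
        return False
```

With validation cross entropies of 0.7, 0.5, 0.6 and 0.4 in epochs 0 to 3, states would be saved in epochs 0, 1 and 3, but not in epoch 2, because 0.6 is not below the earlier minimum of 0.5.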
def _train_epoch(self, current_epoch) [protected]
Trains the model for one epoch.
Definition at line 538 of file tensorflow_dnn_model.py.
def train_model(self)
Trains the model.
Definition at line 634 of file tensorflow_dnn_model.py.
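How current_epoch, best_epoch and termination_criterion might interact in the training loop can be sketched as an early-stopping loop. This is an assumption, not the documented behaviour: here termination_criterion is taken to be the number of epochs allowed without a new validation minimum, and the function `train_with_early_stopping` is hypothetical. A precomputed list of per-epoch validation losses stands in for actually training and evaluating each epoch.

```python
def train_with_early_stopping(epoch_losses, termination_criterion):
    """Hypothetical sketch of a train_model-style loop.

    epoch_losses: iterable of validation cross entropies, one per epoch,
        standing in for training and evaluating an epoch.
    termination_criterion: assumed number of epochs allowed without a new
        validation minimum before training stops.
    Returns (best_epoch, last_epoch_run).
    """
    best_loss = float("inf")
    best_epoch = 0
    current_epoch = 0
    for current_epoch, loss in enumerate(epoch_losses):
        if loss < best_loss:
            best_loss = loss
            best_epoch = current_epoch  # a checkpoint would be saved here
        elif current_epoch - best_epoch >= termination_criterion:
            break  # no improvement for termination_criterion epochs
    return best_epoch, current_epoch
```

With losses 0.9, 0.7, 0.6, 0.65, 0.64, 0.63, 0.5 and a criterion of 3, this sketch stops after epoch 5 with best epoch 2, so the later improvement to 0.5 is never reached. Stopping on epochs since the last best result is one common choice; the criterion actually used in tensorflow_dnn_model.py may differ.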
_time [protected]
Current time.
Definition at line 471 of file tensorflow_dnn_model.py.
best_epoch
Best epoch, initialised in the constructor.
Definition at line 493 of file tensorflow_dnn_model.py.
current_epoch
Current epoch, initialised in the constructor.
Definition at line 490 of file tensorflow_dnn_model.py.
data_set
Dataset (TFData object).
Definition at line 477 of file tensorflow_dnn_model.py.
log_dir
Directory for TensorBoard logging.
Definition at line 484 of file tensorflow_dnn_model.py.
model
Model to be trained (DefaultModel object).
Definition at line 474 of file tensorflow_dnn_model.py.
monitoring_size
Number of training events used for monitoring.
Definition at line 481 of file tensorflow_dnn_model.py.
optimizer
Optimizer, set for each epoch.
Definition at line 543 of file tensorflow_dnn_model.py.
save_name
Path and name for saving the weight files.
Definition at line 503 of file tensorflow_dnn_model.py.
termination_criterion
Termination criterion.
Definition at line 487 of file tensorflow_dnn_model.py.
train_monitor
Training monitor.
Definition at line 514 of file tensorflow_dnn_model.py.
train_writer
tf.summary writer for training.
Definition at line 532 of file tensorflow_dnn_model.py.
valid_monitor
Validation monitor.
Definition at line 516 of file tensorflow_dnn_model.py.
valid_writer
tf.summary writer for validation.
Definition at line 535 of file tensorflow_dnn_model.py.