Belle II Software
release-06-02-00
| Public Member Functions | |
| --- | --- |
| def | `__init__(self, model, data_set, sess, log_dir=None, save_name=None, monitoring_size=100000, input_placeholders=None)` |
| def | `train_model(self)` |
| Public Attributes | Description |
| --- | --- |
| `model` | model being trained (`DefaultModel` object) |
| `data_set` | data set (`TFData` object) |
| `monitoring_size` | number of training events used for monitoring |
| `sess` | TensorFlow session (`tf.Session`) |
| `log_dir` | TensorBoard log directory |
| `x` | input placeholder for the features |
| `y_` | input placeholder for the targets |
| `monitoring_params` | monitoring parameters for the early-stopping criterion, loss function, etc. |
| `termination_criterion` | termination criterion |
| `max_epochs` | maximum number of epochs (global training parameter) |
| `current_epoch` | current epoch |
| `minimizer` | optimizer |
| `train_log_dict` | training log dictionary |
| `saver` | saver |
| `save_name` | path and name for saving the weight files |
| `train_monitor` | training monitor |
| `valid_monitor` | validation monitor |
| `train_writer` | training writer |
| `test_writer` | test writer |
| `merged_summary` | merged summary |
| `epoch_parameters` | epoch parameters |
| Private Member Functions | |
| --- | --- |
| def | `_prepare_monitoring(self)` |
| def | `_prepare_tensorboard(self, log_dir)` |
| def | `_add_to_basf2_collections(self)` |
| def | `_save_best_state(self, monitoring_params, label_name='mean_cross_entropy')` |
| def | `_closing_ops(self)` |
| def | `_train_epoch(self, current_epoch)` |

| Private Attributes | Description |
| --- | --- |
| `_time` | time |
Handles the training of the network model.
Definition at line 498 of file tensorflow_dnn_model.py.
`def __init__(self, model, data_set, sess, log_dir=None, save_name=None, monitoring_size=100000, input_placeholders=None)`

Class to train a predefined model.

- `model`: `DefaultModel` object
- `data_set`: `TFData` object
- `sess`: `tensorflow.Session` object
- `log_dir`: str, directory name for TensorBoard logging
- `save_name`: str, path and name for saving the weight files
- `monitoring_size`: int, number of events from the training fraction used for monitoring
- `input_placeholders`: list of `tf.placeholder` objects, `[features, targets]`
Definition at line 503 of file tensorflow_dnn_model.py.
`def _add_to_basf2_collections(self)` (private)

Adds to the basf2 collections.
Definition at line 622 of file tensorflow_dnn_model.py.
`def _closing_ops(self)` (private)

Closing operations.
Definition at line 651 of file tensorflow_dnn_model.py.
`def _prepare_monitoring(self)` (private)

Checks the data set sizes for evaluation.
Definition at line 593 of file tensorflow_dnn_model.py.
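The brief above does not say how the check is done. A minimal sketch, assuming the number of monitoring events is simply capped at the size of the training sample; `clamp_monitoring_size` is a hypothetical helper, not a basf2 function:

```python
def clamp_monitoring_size(monitoring_size: int, n_train_events: int) -> int:
    """Return how many training events to use for monitoring (hypothetical helper).

    Assumption: if fewer training events exist than requested, all of them
    are used instead of raising an error.
    """
    if monitoring_size <= 0:
        raise ValueError("monitoring_size must be positive")
    return min(monitoring_size, n_train_events)


# With the default monitoring_size of 100000 and a small data set,
# the whole training sample is used for monitoring.
print(clamp_monitoring_size(100000, 25000))   # prints 25000
print(clamp_monitoring_size(100000, 500000))  # prints 100000
```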
`def _prepare_tensorboard(self, log_dir)` (private)

Prepares TensorBoard logging in `log_dir`.
Definition at line 607 of file tensorflow_dnn_model.py.
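`def _save_best_state(self, monitoring_params, label_name='mean_cross_entropy')` (private)

Saves the model only if a global minimum of the monitored quantity (by default `mean_cross_entropy`) is reached on the validation set.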
Definition at line 632 of file tensorflow_dnn_model.py.
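A minimal sketch of the "save only at a new global minimum" rule, assuming `monitoring_params` is a dict keyed by label names and that the actual saving is delegated to a callback (in a TensorFlow 1 setup this might wrap `tf.train.Saver.save`). `BestStateSaver` is hypothetical and not part of basf2:

```python
from typing import Callable, Dict, Optional


class BestStateSaver:
    """Hypothetical helper: calls `save_fn` only when the monitored
    validation quantity reaches a new global minimum."""

    def __init__(self, save_fn: Callable[[], None]) -> None:
        self.save_fn = save_fn
        self.best_value: Optional[float] = None  # lowest value seen so far

    def update(self, monitoring_params: Dict[str, float],
               label_name: str = "mean_cross_entropy") -> bool:
        """Save if `monitoring_params[label_name]` is the lowest value seen so far.

        Returns True when the state was saved.
        """
        value = monitoring_params[label_name]
        if self.best_value is None or value < self.best_value:
            self.best_value = value
            self.save_fn()
            return True
        return False


saved = []
saver = BestStateSaver(save_fn=lambda: saved.append("checkpoint"))
for loss in [0.70, 0.65, 0.66, 0.60, 0.61]:
    saver.update({"mean_cross_entropy": loss})
print(len(saved), saver.best_value)  # prints 3 0.6
```

Only the epochs with losses 0.70, 0.65 and 0.60 trigger a save; the later, worse values leave the stored state untouched.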
`def _train_epoch(self, current_epoch)` (private)

Trains the model for one epoch (`current_epoch`).
Definition at line 666 of file tensorflow_dnn_model.py.
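One common way to structure a training epoch is to shuffle the training events and visit them in mini-batches. The sketch below assumes that scheme; the batch size and `iterate_minibatches` are hypothetical. In a TensorFlow 1 graph each batch would then typically be fed to the minimizer with `sess.run(minimizer, feed_dict={x: ..., y_: ...})`:

```python
import random
from typing import Iterator, List


def iterate_minibatches(n_events: int, batch_size: int,
                        seed: int = 0) -> Iterator[List[int]]:
    """Yield shuffled event indices in batches (hypothetical helper).

    Assumption: one epoch visits every training event exactly once; the last
    batch may be smaller than `batch_size`.
    """
    indices = list(range(n_events))
    random.Random(seed).shuffle(indices)  # reproducible shuffle per epoch seed
    for start in range(0, n_events, batch_size):
        yield indices[start:start + batch_size]


batches = list(iterate_minibatches(n_events=10, batch_size=4))
print([len(b) for b in batches])  # prints [4, 4, 2]
```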
`def train_model(self)`

Trains the model.
Definition at line 730 of file tensorflow_dnn_model.py.
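A minimal sketch of an epoch loop bounded by `max_epochs` with an early-stopping check. It assumes that `termination_criterion` counts consecutive epochs without improvement of the validation loss; this page does not define it, so treat that as an assumption. `train_with_early_stopping` and its `train_epoch` callback are hypothetical:

```python
from typing import Callable, List


def train_with_early_stopping(train_epoch: Callable[[int], float],
                              max_epochs: int,
                              termination_criterion: int) -> List[float]:
    """Run epochs until `max_epochs` is reached or the validation loss has not
    improved for `termination_criterion` consecutive epochs.

    Assumption: `train_epoch(epoch)` trains one epoch and returns the
    validation loss; this is a hypothetical interface, not the basf2 one.
    """
    history: List[float] = []
    best = float("inf")
    epochs_without_improvement = 0
    for current_epoch in range(max_epochs):
        loss = train_epoch(current_epoch)
        history.append(loss)
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= termination_criterion:
                break  # early stopping
    return history


losses = [0.9, 0.7, 0.6, 0.62, 0.61, 0.63, 0.5]
history = train_with_early_stopping(lambda epoch: losses[epoch],
                                    max_epochs=len(losses),
                                    termination_criterion=3)
print(history)  # prints [0.9, 0.7, 0.6, 0.62, 0.61, 0.63]
```

With `termination_criterion=3`, training stops after epoch 5 (counting from 0), because epochs 3 to 5 did not improve on 0.6; the better value in epoch 6 is never reached.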