Belle II Software light-2406-ragdoll
Trainer Class Reference

Public Member Functions

def __init__ (self, model, data_set, log_dir=None, save_name=None, monitoring_size=10000)
 
def train_model (self)
 

Public Attributes

 model
 model
 
 data_set
 dataset
 
 monitoring_size
 monitoring size
 
 log_dir
 log dir
 
 termination_criterion
 termination criterion
 
 current_epoch
 initialise current epoch
 
 best_epoch
 initialise best epoch
 
 save_name
 set the path and name for saving the weightfiles
 
 train_monitor
 train_monitor
 
 valid_monitor
 validation monitor
 
 train_writer
 tf.summary.writer for training
 
 valid_writer
 tf.summary.writer for validation
 
 optimizer
 set optimizer for this epoch
 

Protected Member Functions

def _prepare_monitoring (self)
 
def _prepare_tensorboard (self, log_dir)
 
def _train_epoch (self, current_epoch)
 
def _save_best_state (self, cross_entropy)
 
def _closing_ops (self)
 

Protected Attributes

 _time
 current time
 

Detailed Description

handling the training of the network model

Definition at line 450 of file tensorflow_dnn_model.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  model,
  data_set,
  log_dir = None,
  save_name = None,
  monitoring_size = 10000 
)
class to train a predefined model
:param model: DefaultModel obj
:param data_set: TFData obj
:param log_dir: str, directory name of tensorboard logging
:param save_name: str, path and name for saving the weightfiles
:param monitoring_size: int, number of events of training fraction used for monitoring

Definition at line 455 of file tensorflow_dnn_model.py.

def __init__(self, model, data_set, log_dir=None, save_name=None, monitoring_size=10000):
    """
    Class to train a predefined model.

    :param model: DefaultModel obj
    :param data_set: TFData obj
    :param log_dir: str, directory name of tensorboard logging
    :param save_name: str, path and name for saving the weightfiles
    :param monitoring_size: int, number of events of training fraction used for monitoring
    """
    # wall-clock time stamp, refreshed at the end of every epoch
    self._time = time.time()

    # the network model and the dataset it is trained on
    self.model = model
    self.data_set = data_set
    self.model.initialize(data_set)

    # number of events used for per-epoch monitoring
    self.monitoring_size = monitoring_size

    # directory for tensorboard logging (None disables logging)
    self.log_dir = log_dir

    # early stopping is delegated to the model's termination criterion
    self.termination_criterion = self.model.termination_criterion

    # initialise epoch counters
    self.current_epoch = 0
    self.best_epoch = -np.inf

    if log_dir is not None:
        self._prepare_tensorboard(log_dir)

    # fall back to a timestamped file in the current working directory
    if save_name is None:
        time_str = time.strftime("%Y%m%d-%H%M%S")
        save_name = os.path.join(os.getcwd(), '_'.join([time_str, 'model']))

    # path and name for saving the weightfiles
    self.save_name = save_name

    self._prepare_monitoring()

Member Function Documentation

◆ _closing_ops()

def _closing_ops (   self)
protected
closing operations

Definition at line 625 of file tensorflow_dnn_model.py.

def _closing_ops(self):
    """
    Closing operations: release the tensorboard writers,
    but only if logging was enabled at construction time.
    """
    if self.log_dir is None:
        return
    self.train_writer.close()
    self.valid_writer.close()
633

◆ _prepare_monitoring()

def _prepare_monitoring (   self)
protected
checking dataset sizes for evaluation. These samples are used after each epoch to collect
summary statistics and test early stopping criteria.

Definition at line 508 of file tensorflow_dnn_model.py.

def _prepare_monitoring(self):
    """
    Checking dataset sizes for evaluation. These samples are used after each
    epoch to collect summary statistics and test early stopping criteria.

    If a sample holds no more than ``monitoring_size`` events, the whole
    sample is monitored; otherwise only the first ``monitoring_size`` events
    are used.
    """
    # None as a slice bound (x[:None]) selects the full sample. The previous
    # sentinel, -1, silently dropped the last event when used as x[:-1].
    self.train_monitor = None
    self.valid_monitor = None

    if self.data_set.train_events > self.monitoring_size:
        self.train_monitor = self.monitoring_size

    if self.data_set.valid_events > self.monitoring_size:
        self.valid_monitor = self.monitoring_size
    return
523

◆ _prepare_tensorboard()

def _prepare_tensorboard (   self,
  log_dir 
)
protected
prepare tensorboard

Definition at line 524 of file tensorflow_dnn_model.py.

def _prepare_tensorboard(self, log_dir):
    """
    Prepare tensorboard: one summary file writer per sample, placed in the
    'train' and 'valid' subdirectories of `log_dir`.
    """
    # tf.summary.writer for training
    self.train_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'train'))

    # tf.summary.writer for validation
    self.valid_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'valid'))
537

◆ _save_best_state()

def _save_best_state (   self,
  cross_entropy 
)
protected
save model as a checkpoint only if a global minimum is reached on validation sample
:return:

Definition at line 607 of file tensorflow_dnn_model.py.

def _save_best_state(self, cross_entropy):
    """
    Save model as a checkpoint only if a global minimum is reached on the
    validation sample.

    :param cross_entropy: validation cross entropy of the current epoch
    :return: None
    """
    checkpoint = tf.train.Checkpoint(self.model)

    # always persist the current state - do we need this?
    checkpoint.save(self.save_name.replace('model', 'model_current'))

    # best value still unset (np.inf): nothing to compare against yet
    # NOTE(review): presumably best_value is maintained by the model /
    # termination criterion elsewhere — confirm before changing this.
    if self.model.best_value == np.inf:
        return

    # new global minimum on the validation sample: remember it and save
    if cross_entropy < self.model.best_value:
        self.best_epoch = self.current_epoch
        checkpoint.save(self.save_name)
624

◆ _train_epoch()

def _train_epoch (   self,
  current_epoch 
)
protected
train epoch

Definition at line 538 of file tensorflow_dnn_model.py.

def _train_epoch(self, current_epoch):
    """
    Train one full pass over the training sample, then evaluate the
    monitoring slices of the training and validation samples.

    :param current_epoch: int, index of the epoch being trained
    :return: cross entropy on the validation monitoring sample
    """
    # set optimizer for this epoch
    self.optimizer = self.model.get_optimizer(current_epoch)

    batch_iter = self.data_set.batch_iterator()

    progress = trange(self.data_set.batches)
    progress.set_description(f'Epoch {current_epoch:4d}')
    for _ in progress:
        batch = next(batch_iter)
        batch_x, batch_y = batch[0], batch[1]

        # forward pass and gradients of the loss w.r.t. the weights
        with tf.GradientTape() as tape:
            loss, _ = self.model.loss(self.model(batch_x), batch_y)
            grads = tape.gradient(loss, self.model.trainable_variables)

        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

        # write the learning rate and momentum to the tensorboard log
        if self.log_dir is not None:
            with self.train_writer.as_default():
                tf.summary.scalar('learning_rate', self.model._get_learning_rate(), step=self.model.global_step)
                tf.summary.scalar('momentum', self.model._get_momentum(), step=self.model.global_step)
                self.train_writer.flush()

        self.model.global_step.assign_add(1)

    # monitoring slices of the training and validation samples
    train_x = self.data_set.train_x[:self.train_monitor]
    train_y = self.data_set.train_y[:self.train_monitor]
    valid_x = self.data_set.valid_x[:self.valid_monitor]
    valid_y = self.data_set.valid_y[:self.valid_monitor]

    # run the training and validation samples to collect statistics
    train_loss, train_cross_entropy = self.model.loss(self.model(train_x), train_y)
    valid_loss, valid_cross_entropy = self.model.loss(self.model(valid_x), valid_y)

    # if we have a log_dir set write extra summary information
    if self.log_dir is not None:
        with self.train_writer.as_default():
            tf.summary.scalar('loss', train_loss, step=current_epoch)
            tf.summary.scalar('cross_entropy', train_cross_entropy, step=current_epoch)

            # this is now at the end of each epoch
            tf.summary.scalar('epoch_learning_rate', self.model._get_learning_rate(), step=current_epoch)
            tf.summary.scalar('epoch_momentum', self.model._get_momentum(), step=current_epoch)
            self.train_writer.flush()

        # write all the model parameters to the summary file too
        self.model.mlp.variables_to_writer(current_epoch, self.train_writer)

        with self.valid_writer.as_default():
            tf.summary.scalar('loss', valid_loss, step=current_epoch)
            tf.summary.scalar('cross_entropy', valid_cross_entropy, step=current_epoch)
            tf.summary.scalar('best_epoch', self.best_epoch, step=current_epoch)
            self.valid_writer.flush()

    # update time
    self._time = time.time()
    self.current_epoch += 1

    return valid_cross_entropy
606

◆ train_model()

def train_model (   self)
train model

Definition at line 634 of file tensorflow_dnn_model.py.

def train_model(self):
    """
    Train the model for at most ``model.max_epochs`` epochs.

    After every epoch the best state is checkpointed and the termination
    criterion is consulted for early stopping; writers are closed at the end.
    """
    for epoch_idx in range(self.model.max_epochs):
        epoch_cross_entropy = self._train_epoch(epoch_idx)

        # checkpoint if this epoch is a new global minimum on validation
        self._save_best_state(epoch_cross_entropy)

        # early stopping
        if self.termination_criterion(epoch_cross_entropy, epoch_idx):
            break

    self._closing_ops()

Member Data Documentation

◆ _time

_time
protected

current time

Definition at line 471 of file tensorflow_dnn_model.py.

◆ best_epoch

best_epoch

initialise best epoch

Definition at line 493 of file tensorflow_dnn_model.py.

◆ current_epoch

current_epoch

initialise current epoch

Definition at line 490 of file tensorflow_dnn_model.py.

◆ data_set

data_set

dataset

Definition at line 477 of file tensorflow_dnn_model.py.

◆ log_dir

log_dir

log dir

Definition at line 484 of file tensorflow_dnn_model.py.

◆ model

model

model

Definition at line 474 of file tensorflow_dnn_model.py.

◆ monitoring_size

monitoring_size

monitoring size

Definition at line 481 of file tensorflow_dnn_model.py.

◆ optimizer

optimizer

set optimizer for this epoch

Definition at line 543 of file tensorflow_dnn_model.py.

◆ save_name

save_name

set the path and name for saving the weightfiles

Definition at line 503 of file tensorflow_dnn_model.py.

◆ termination_criterion

termination_criterion

termination criterion

Definition at line 487 of file tensorflow_dnn_model.py.

◆ train_monitor

train_monitor

train_monitor

Definition at line 514 of file tensorflow_dnn_model.py.

◆ train_writer

train_writer

tf.summary.writer for training

Definition at line 532 of file tensorflow_dnn_model.py.

◆ valid_monitor

valid_monitor

validation monitor

Definition at line 516 of file tensorflow_dnn_model.py.

◆ valid_writer

valid_writer

tf.summary.writer for validation

Definition at line 535 of file tensorflow_dnn_model.py.


The documentation for this class was generated from the following file: