Belle II Software  release-08-01-10
TfDataBasf2 Class Reference

Public Member Functions

def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True)

def sanitize_labels(self)

def batch_iterator(self)

Public Attributes

 train_x
 training features
 
 train_y
 training targets
 
 valid_x
 validation features
 
 valid_y
 validation targets
 
 batch_size
 batch size
 
 seed
 random generator seed
 
 epoch_random_shuffle
 bool; if True, the training events are reshuffled for each epoch
 
 train_events
 number of training events
 
 valid_events
 number of validation events
 
 feature_number
 number of features
 
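 batches
 number of complete training batches per epoch (train_events // batch_size)
 
 train_idx
 training-event indices used for shuffling
 
 batch_train_x
 NumPy ndarray holding the features of the current training batch
 
 batch_train_y
 NumPy ndarray holding the targets of the current training batch
 
 random_state
 NumPy random generator (np.random.RandomState) initialized with seed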
 

Detailed Description

Handles the data needed for training.

Definition at line 15 of file TfData.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True)
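Declares and initializes the instance attributes: casts the training and validation arrays to float32, derives the event, feature and batch counts, allocates the batch arrays, seeds the random generator and finally calls sanitize_labels().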

Definition at line 20 of file TfData.py.

def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True):
    """
    declaration of class variables
    """
    # training features
    self.train_x = train_x.astype(np.float32)
    # training targets
    self.train_y = train_y.astype(np.float32)
    # validation features
    self.valid_x = valid_x.astype(np.float32)
    # validation targets
    self.valid_y = valid_y.astype(np.float32)

    # batch size
    self.batch_size = batch_size
    # random generator seed
    self.seed = seed
    # bool, enables shuffling
    self.epoch_random_shuffle = epoch_random_shuffle

    # number of training events
    self.train_events = self.train_x.shape[0]
    # number of validation events
    self.valid_events = self.valid_x.shape[0]

    # number of features
    self.feature_number = self.train_x.shape[1]

    # number of complete batches per epoch (floor division)
    self.batches = self.train_x.shape[0] // self.batch_size

    # indices required for shuffling
    self.train_idx = np.zeros(self.train_x.shape[0])

    # zero-initialized array for the training batch features
    self.batch_train_x = np.zeros((self.feature_number, self.batch_size))

    # zero-initialized array for the training batch targets
    self.batch_train_y = np.zeros(self.batch_size)

    # random generator, seeded with seed
    self.random_state = np.random.RandomState(seed)

    # check for a binary problem and map the labels to {0, 1}
    self.sanitize_labels()
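The batch bookkeeping in the constructor is plain integer arithmetic. The sketch below mirrors it on a toy array; derived_quantities is a hypothetical helper for illustration, not part of TfDataBasf2. It shows that batches is computed with floor division, so when train_events is not a multiple of batch_size the remainder (3 events in this example) does not fill a complete batch.

```python
import numpy as np


def derived_quantities(train_x, batch_size):
    """Reproduce the bookkeeping done in __init__ for a feature matrix.

    Hypothetical helper for illustration; not part of TfDataBasf2.
    """
    x = np.asarray(train_x).astype(np.float32)  # same cast as self.train_x
    train_events = x.shape[0]                   # rows = events
    feature_number = x.shape[1]                 # columns = features
    batches = train_events // batch_size        # floor division, as in self.batches
    leftover = train_events - batches * batch_size  # events not covered by a full batch
    return x.dtype, train_events, feature_number, batches, leftover


# Toy input: 1003 events with 4 features, batch size 100.
print(*derived_quantities(np.ones((1003, 4)), 100))
# float32 1003 4 10 3
```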

Member Function Documentation

◆ batch_iterator()

def batch_iterator(self)
Iterator that provides the training batches of batch_size events.

Definition at line 90 of file TfData.py.
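The body of batch_iterator is not reproduced on this page. A minimal sketch of an epoch iterator consistent with the attributes above might look as follows. It assumes that each call draws a fresh permutation from random_state when shuffling is enabled, uses the identity order otherwise, and yields only complete (features, targets) batches; the function name batch_iterator_sketch and its signature are hypothetical.

```python
import numpy as np


def batch_iterator_sketch(train_x, train_y, batch_size, random_state, shuffle=True):
    """Yield (features, targets) batches for one epoch.

    Hypothetical stand-in for TfDataBasf2.batch_iterator; the real body is not shown.
    """
    n_events = train_x.shape[0]
    n_batches = n_events // batch_size  # same floor division as self.batches
    # New random order per call (i.e. per epoch) if shuffling is enabled.
    idx = random_state.permutation(n_events) if shuffle else np.arange(n_events)
    for i in range(n_batches):
        sel = idx[i * batch_size:(i + 1) * batch_size]
        yield train_x[sel], train_y[sel]


rs = np.random.RandomState(42)
x = np.arange(20, dtype=np.float32).reshape(10, 2)  # 10 events, 2 features
y = np.arange(10, dtype=np.float32)
for bx, by in batch_iterator_sketch(x, y, 3, rs):
    print(bx.shape, by.shape)
```

Indexing features and targets with the same index slice keeps each event's features paired with its target after shuffling; with 10 events and a batch size of 3, one event is left out of each epoch.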

◆ sanitize_labels()

def sanitize_labels(self)
Checks that the task is a binary classification problem and maps the two class labels to {0, 1}.

Definition at line 66 of file TfData.py.
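The page does not show how sanitize_labels is implemented. One way to perform the check and the mapping described above is sketched below. It assumes that the smaller of the two labels becomes 0 and the larger becomes 1, and that validation labels must come from the training classes; the function name sanitize_labels_sketch and these rules are assumptions, and unlike the method it returns new arrays instead of modifying attributes in place.

```python
import numpy as np


def sanitize_labels_sketch(train_y, valid_y):
    """Map two arbitrary class labels to {0, 1}.

    Hypothetical stand-in for TfDataBasf2.sanitize_labels.
    """
    classes = np.unique(train_y)  # sorted distinct labels
    if classes.size != 2:
        raise ValueError(f"expected a binary problem, found {classes.size} classes")
    if not np.all(np.isin(valid_y, classes)):
        raise ValueError("validation labels contain a class not seen in training")
    # The smaller label becomes 0 and the larger one becomes 1.
    train01 = (train_y == classes[1]).astype(np.float32)
    valid01 = (valid_y == classes[1]).astype(np.float32)
    return train01, valid01


t, v = sanitize_labels_sketch(np.array([-1.0, 1.0, 1.0, -1.0]), np.array([1.0, -1.0]))
print(t, v)
```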

