Belle II Software light-2406-ragdoll
TfDataBasf2 Class Reference

Public Member Functions

def __init__ (self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True)
 
def sanitize_labels (self)
 
def batch_iterator (self)
 

Public Attributes

 train_x
 training features
 
 train_y
 training targets
 
 valid_x
 validation features
 
 valid_y
 validation targets
 
 batch_size
 batch size
 
 seed
 random generator seed
 
 epoch_random_shuffle
 bool, enables shuffling
 
 train_events
 number of training events
 
 valid_events
 number of validation events
 
 feature_number
 number of features
 
 batches
 number of batches
 
 train_idx
 indices required for shuffling
 
 batch_train_x
 np ndarray for training batch features
 
 batch_train_y
 np ndarray for training batch of targets
 
 random_state
 random number generator (np.random.RandomState) used for shuffling
 

Detailed Description

Handles the data necessary for the training.

Definition at line 15 of file TfData.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  train_x,
  train_y,
  valid_x,
  valid_y,
  batch_size,
  seed = None,
  epoch_random_shuffle = True 
)
declaration of class variables

Definition at line 20 of file TfData.py.

def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True):
    """
    Store the training/validation data and prepare the batching state.

    All four arrays are converted with ``astype(np.float32)``, which
    returns copies, so the in-place label rewriting done by
    ``sanitize_labels()`` does not touch the arrays passed in by the caller.

    :param train_x: training features; ``shape[0]`` is used as the number
        of events and ``shape[1]`` as the number of features
    :param train_y: training targets
    :param valid_x: validation features; ``shape[0]`` is used as the number
        of validation events
    :param valid_y: validation targets
    :param batch_size: number of training events per batch, must be positive
    :param seed: seed passed to ``np.random.RandomState``
    :param epoch_random_shuffle: if true, ``batch_iterator()`` shuffles the
        training indices each time it is started
    :raises ValueError: if ``batch_size`` is not positive
    """
    # Fail early with a clear message instead of a ZeroDivisionError from
    # the batch-count computation below.
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))

    # training features
    self.train_x = train_x.astype(np.float32)
    # training targets
    self.train_y = train_y.astype(np.float32)
    # validation features
    self.valid_x = valid_x.astype(np.float32)
    # validation targets
    self.valid_y = valid_y.astype(np.float32)

    # batch size
    self.batch_size = batch_size
    # random generator seed
    self.seed = seed
    # bool, enables shuffling
    self.epoch_random_shuffle = epoch_random_shuffle

    # number of training events
    self.train_events = self.train_x.shape[0]
    # number of validation events
    self.valid_events = self.valid_x.shape[0]
    # number of features
    self.feature_number = self.train_x.shape[1]

    # number of complete batches; a trailing partial batch is not counted
    self.batches = self.train_events // self.batch_size

    # indices required for shuffling; starts as the identity order so the
    # array holds valid integer indices even before the first epoch
    self.train_idx = np.arange(self.train_events)

    # placeholder for a batch of training features, laid out like the
    # batches produced by batch_iterator(): (batch_size, feature_number)
    self.batch_train_x = np.zeros((self.batch_size, self.feature_number), dtype=np.float32)

    # placeholder for a batch of training targets; keeps any trailing
    # dimensions of train_y so it matches what batch_iterator() yields
    self.batch_train_y = np.zeros((self.batch_size,) + self.train_y.shape[1:], dtype=np.float32)

    # random generator used for shuffling
    self.random_state = np.random.RandomState(seed)

    # must run last: it validates and rewrites self.train_y / self.valid_y
    self.sanitize_labels()
65

Member Function Documentation

◆ batch_iterator()

def batch_iterator (   self)
iterator to provide training batches

Definition at line 90 of file TfData.py.

def batch_iterator(self):
    """
    Generate the training batches for one epoch.

    The index array is reset to the identity order and, when
    ``epoch_random_shuffle`` is set, shuffled in place with
    ``random_state``. Exactly ``self.batches`` batches of ``batch_size``
    events are produced; events beyond ``batches * batch_size`` are not
    yielded.

    :return: yields ``(batch_train_x, batch_train_y)`` tuples; the same
        arrays are also stored on the instance under those names
    """
    event_count = len(self.train_idx)
    self.train_idx = np.arange(event_count)

    if self.epoch_random_shuffle:
        self.random_state.shuffle(self.train_idx)

    for batch_no in range(self.batches):
        first = batch_no * self.batch_size
        last = first + self.batch_size
        # one index selection shared by features and targets
        selected = self.train_idx[first:last]

        self.batch_train_x = self.train_x[selected]
        self.batch_train_y = self.train_y[selected]

        yield self.batch_train_x, self.batch_train_y
104
105

◆ sanitize_labels()

def sanitize_labels (   self)
checks for a binary classification problem
transforms the two class labels to {0,1}

Definition at line 66 of file TfData.py.

def sanitize_labels(self):
    """
    Check for a binary classification problem and map the labels to {0, 1}.

    The smaller of the two class labels becomes 0 and the larger becomes 1.
    ``train_y`` and ``valid_y`` are modified in place.

    :raises ValueError: if ``train_y`` does not contain exactly two distinct
        labels, or if ``valid_y`` does not contain the same set of labels
    """
    train_labels = np.unique(self.train_y)

    # Explicit exceptions rather than assert statements, so the checks are
    # still performed when Python runs with -O.
    if len(train_labels) != 2:
        raise ValueError(
            "binary classification requires exactly two class labels in train_y, got {}".format(
                len(train_labels)))
    if not np.array_equal(train_labels, np.unique(self.valid_y)):
        raise ValueError("train_y and valid_y must contain the same two class labels")

    # np.unique returns sorted values: first the lower, then the higher label
    low, high = train_labels

    for targets in (self.train_y, self.valid_y):
        # Build both masks before writing anything: for labels such as
        # {-1, 0} or {1, 2} one original label equals one of the target
        # values, and writing first would merge the two classes.
        is_low = targets == low
        is_high = targets == high
        targets[is_low] = 0
        targets[is_high] = 1
89

Member Data Documentation

◆ batch_size

batch_size

batch size

Definition at line 34 of file TfData.py.

◆ batch_train_x

batch_train_x

np ndarray for training batch features

Definition at line 55 of file TfData.py.

◆ batch_train_y

batch_train_y

np ndarray for training batch of targets

Definition at line 58 of file TfData.py.

◆ batches

batches

number of batches

Definition at line 49 of file TfData.py.

◆ epoch_random_shuffle

epoch_random_shuffle

bool, enables shuffling

Definition at line 38 of file TfData.py.

◆ feature_number

feature_number

number of features

Definition at line 46 of file TfData.py.

◆ random_state

random_state

random number generator (np.random.RandomState) used for shuffling

Definition at line 61 of file TfData.py.

◆ seed

seed

random generator seed

Definition at line 36 of file TfData.py.

◆ train_events

train_events

number of training events

Definition at line 41 of file TfData.py.

◆ train_idx

train_idx

indices required for shuffling

Definition at line 52 of file TfData.py.

◆ train_x

train_x

training features

Definition at line 25 of file TfData.py.

◆ train_y

train_y

training targets

Definition at line 27 of file TfData.py.

◆ valid_events

valid_events

number of validation events

Definition at line 43 of file TfData.py.

◆ valid_x

valid_x

validation features

Definition at line 29 of file TfData.py.

◆ valid_y

valid_y

validation targets

Definition at line 31 of file TfData.py.


The documentation for this class was generated from the following file: