Belle II Software  release-08-01-10
TfData.py
1 #!/usr/bin/env python
2 
3 
10 
11 
12 import numpy as np
13 
14 
class TfDataBasf2:
    """
    Handles the data necessary for the training.

    Holds float32 copies of the training and validation samples, maps the
    two class labels onto {0, 1} and serves the training sample in batches
    of a fixed size via ``batch_iterator``.
    """

    def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True):
        """
        Declaration of class variables.

        @param train_x training features; axis 0 indexes events, axis 1 features
        @param train_y training targets, must contain exactly two distinct values
        @param valid_x validation features
        @param valid_y validation targets, must contain the same two values as train_y
        @param batch_size number of events per training batch
        @param seed seed handed to numpy.random.RandomState (None: numpy picks one)
        @param epoch_random_shuffle if True the event order is reshuffled on every
               call of batch_iterator
        """
        # astype() returns new arrays, so the label sanitising below never
        # touches the arrays owned by the caller.

        #: training features
        self.train_x = train_x.astype(np.float32)
        #: training targets
        self.train_y = train_y.astype(np.float32)
        #: validation features
        self.valid_x = valid_x.astype(np.float32)
        #: validation targets
        self.valid_y = valid_y.astype(np.float32)

        #: batch size
        self.batch_size = batch_size
        #: random generator seed
        self.seed = seed
        #: bool, enables shuffling
        self.epoch_random_shuffle = epoch_random_shuffle

        #: number of training events
        self.train_events = self.train_x.shape[0]
        #: number of validation events
        self.valid_events = self.valid_x.shape[0]
        #: number of features
        self.feature_number = self.train_x.shape[1]

        #: number of batches (events that do not fill a whole batch are not served)
        self.batches = self.train_events // self.batch_size

        #: indices required for shuffling (integer dtype, they are used for indexing)
        self.train_idx = np.arange(self.train_events)

        # The two batch buffers are placeholders until batch_iterator is used;
        # they already have the layout and dtype of the batches it yields.

        #: np ndarray for training batch features
        self.batch_train_x = np.zeros((self.batch_size, self.feature_number), dtype=np.float32)
        #: np ndarray for training batch of targets
        self.batch_train_y = np.zeros(self.batch_size, dtype=np.float32)

        #: set random generator
        self.random_state = np.random.RandomState(seed)

        self.sanitize_labels()

    def sanitize_labels(self):
        """
        Checks for a binary classification problem and transforms the two
        class labels to {0, 1}: the larger label becomes 1, the smaller 0.

        Raises AssertionError if the training targets do not contain exactly
        two distinct values, or if the validation targets contain a different
        set of values.
        """
        train_classes = np.unique(self.train_y)
        assert len(train_classes) == 2, "training targets are not binary"
        assert np.array_equal(train_classes, np.unique(self.valid_y)), \
            "training and validation targets contain different classes"

        # Comparing against the larger label maps every possible pair of
        # values (including pairs that already contain 0 or 1) in one step,
        # without intermediate states in which both classes could collide.
        upper_label = train_classes[-1]
        for labels in (self.train_y, self.valid_y):
            # assign in place so the arrays stay float32
            labels[...] = labels == upper_label

    def batch_iterator(self):
        """
        Iterator to provide training batches.

        Yields ``self.batches`` tuples (batch_train_x, batch_train_y), each
        holding ``batch_size`` events. The event order is reset on every call
        and, if epoch_random_shuffle is set, shuffled with the instance's
        random generator.
        """
        self.train_idx = np.arange(len(self.train_idx))

        if self.epoch_random_shuffle:
            self.random_state.shuffle(self.train_idx)

        for batch in range(self.batches):
            first = batch * self.batch_size
            selected = self.train_idx[first:first + self.batch_size]

            self.batch_train_x = self.train_x[selected]
            self.batch_train_y = self.train_y[selected]

            yield self.batch_train_x, self.batch_train_y
104 
105 
107  """
108  stub class just for initializing in basf2 begin_run
109  """
110 
def __init__(self, batch_size, feature_number, event_number, train_fraction):
    """
    Declare the batch parameters required for initialization.

    No data is stored; the event counts are derived from the total number
    of events and the fraction of them used for training.

    @param batch_size number of events per training batch
    @param feature_number number of features
    @param event_number total number of events (training plus validation)
    @param train_fraction fraction of event_number used for training
    """
    #: batch size
    self.batch_size = batch_size

    #: feature number
    self.feature_number = feature_number

    #: number of training events
    self.train_events = int(train_fraction * event_number)

    #: number of validation events
    # NOTE(review): computed independently of train_events, so float rounding
    # can make the two counts sum to event_number - 1 (e.g. fraction 0.8 of
    # 1000 events gives 199 here) -- confirm against the code doing the split.
    self.valid_events = int((1 - train_fraction) * event_number)

    #: number of batches
    # Derived from the integer event count so the result is an int even for
    # a fractional train_fraction; the value itself is unchanged.
    self.batches = self.train_events // self.batch_size
train_events
number of training events
Definition: TfData.py:125
def __init__(self, batch_size, feature_number, event_number, train_fraction)
Definition: TfData.py:111
feature_number
feature number
Definition: TfData.py:119
valid_events
number of validation events
Definition: TfData.py:128
batches
number of batches
Definition: TfData.py:122
def batch_iterator(self)
Definition: TfData.py:90
epoch_random_shuffle
bool, enables shuffling
Definition: TfData.py:38
random_state
set random generator
Definition: TfData.py:61
valid_y
validation targets
Definition: TfData.py:31
train_events
number of training events
Definition: TfData.py:41
def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True)
Definition: TfData.py:20
seed
random generator seed
Definition: TfData.py:36
def sanitize_labels(self)
Definition: TfData.py:66
train_y
training targets
Definition: TfData.py:27
batch_size
batch size
Definition: TfData.py:34
batch_train_y
np ndarray for training batch of targets
Definition: TfData.py:58
valid_x
validation features
Definition: TfData.py:29
train_idx
indices required for shuffling
Definition: TfData.py:52
feature_number
number of features
Definition: TfData.py:46
valid_events
number of validation events
Definition: TfData.py:43
train_x
training features
Definition: TfData.py:25
batches
number of batches
Definition: TfData.py:49
batch_train_x
np ndarray for training batch features
Definition: TfData.py:55