Belle II Software  release-05-02-19
TfData.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 
4 
13 
14 from __future__ import division, print_function, generators
15 
16 import numpy as np
17 
18 
class TfDataBasf2():
    """
    handles data, necessary for the training

    Keeps the training and validation arrays, maps the two class labels to
    {0, 1} in place and serves the training sample in batches of equal size.
    """

    def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True):
        """
        declaration of class variables

        :param train_x: training features, first axis runs over events, second axis over features
        :param train_y: training targets with exactly two distinct values, modified in place
        :param valid_x: validation features, first axis runs over events
        :param valid_y: validation targets holding the same two values as train_y, modified in place
        :param batch_size: number of events per training batch
        :param seed: seed handed to np.random.RandomState
        :param epoch_random_shuffle: if True the event order is shuffled at the start of every epoch
        :raises AssertionError: if the targets do not describe a binary classification problem
        """

        #: training features
        self.train_x = train_x
        #: training targets
        self.train_y = train_y
        #: validation features
        self.valid_x = valid_x
        #: validation targets
        self.valid_y = valid_y

        #: batch size
        self.batch_size = batch_size
        #: random generator seed
        self.seed = seed
        #: bool, enables shuffling
        self.epoch_random_shuffle = epoch_random_shuffle

        #: number of training events
        self.train_events = self.train_x.shape[0]
        #: number of validation events
        self.valid_events = self.valid_x.shape[0]

        #: number of features
        self.feature_number = self.train_x.shape[1]

        #: number of batches, events beyond the last full batch are not served
        self.batches = self.train_events // self.batch_size

        #: indices required for shuffling (integer dtype, so they are usable as indices right away)
        self.train_idx = np.arange(self.train_events)

        #: np ndarray for training batch features, shaped like the batches served by batch_iterator
        self.batch_train_x = np.zeros((self.batch_size, self.feature_number))

        #: np ndarray for training batch of targets
        self.batch_train_y = np.zeros(self.batch_size)

        #: set random generator
        self.random_state = np.random.RandomState(seed)

        self.sanitize_labels()

    def sanitize_labels(self):
        """
        checks for a binary classification problem
        transforms the two class labels to {0,1}

        The smaller label becomes 0, the larger one becomes 1. train_y and
        valid_y are modified in place; they are left untouched if the labels
        are already {0, 1}.

        :raises AssertionError: if train_y does not hold exactly two distinct
            values or valid_y holds a different set of values
        """
        classes = np.unique(self.train_y)

        # Raised explicitly instead of using assert statements: the checks
        # have to survive `python -O`, which strips asserts.
        if len(classes) != 2:
            raise AssertionError(
                "expected exactly two class labels in train_y, found {}".format(len(classes)))
        if not np.array_equal(classes, np.unique(self.valid_y)):
            raise AssertionError("train_y and valid_y do not contain the same class labels")

        # np.unique returns its values sorted
        low, high = classes
        if low == 0 and high == 1:
            return

        for labels in (self.train_y, self.valid_y):
            # The mask is built before anything is written, so the first
            # assignment can never make the two classes collide.
            is_high = labels == high
            labels[~is_high] = 0
            labels[is_high] = 1

    def batch_iterator(self):
        """
        iterator to provide training batches

        One pass over the returned generator is one epoch. It yields
        self.batches tuples (batch_train_x, batch_train_y) holding
        batch_size events each; events left over after the last full batch
        are skipped in that epoch.
        """
        # Restart from the identity order, the permutation of an epoch then
        # depends on the state of the random generator only.
        self.train_idx = np.arange(len(self.train_idx))

        if self.epoch_random_shuffle:
            self.random_state.shuffle(self.train_idx)

        for i in range(self.batches):
            batch_idx = self.train_idx[i * self.batch_size: (i + 1) * self.batch_size]
            self.batch_train_x = self.train_x[batch_idx]
            self.batch_train_y = self.train_y[batch_idx]

            yield self.batch_train_x, self.batch_train_y
108 
109 
class TfDataBasf2Stub():
    """
    stub class just for initializing in basf2 begin_run

    Carries only the batch bookkeeping numbers, no data arrays.
    """

    def __init__(self, batch_size, feature_number, event_number, train_fraction):
        """
        declare for initialization required batch parameters

        :param batch_size: number of events per training batch
        :param feature_number: number of features
        :param event_number: total number of events
        :param train_fraction: fraction of event_number used for training
        """

        #: batch size
        self.batch_size = batch_size

        #: feature number
        self.feature_number = feature_number

        #: number of training events
        self.train_events = int(train_fraction * event_number)

        #: number of batches, derived from the truncated training event count
        #: so that it stays an int even for a fractional train_fraction
        self.batches = self.train_events // self.batch_size

        #: number of validation events
        # NOTE(review): truncation can make train_events + valid_events fall
        # short of event_number for some fractions - confirm this matches the
        # split that is actually applied to the data.
        self.valid_events = int((1 - train_fraction) * event_number)
dft.TfData.TfDataBasf2.batches
batches
number of batches
Definition: TfData.py:53
dft.TfData.TfDataBasf2.valid_y
valid_y
validation targets
Definition: TfData.py:35
dft.TfData.TfDataBasf2Stub
Definition: TfData.py:110
dft.TfData.TfDataBasf2.train_events
train_events
number of training events
Definition: TfData.py:45
dft.TfData.TfDataBasf2.valid_x
valid_x
validation features
Definition: TfData.py:33
dft.TfData.TfDataBasf2Stub.valid_events
valid_events
number of validation events
Definition: TfData.py:132
dft.TfData.TfDataBasf2Stub.train_events
train_events
number of training events
Definition: TfData.py:129
dft.TfData.TfDataBasf2.sanitize_labels
def sanitize_labels(self)
Definition: TfData.py:70
dft.TfData.TfDataBasf2.__init__
def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True)
Definition: TfData.py:24
dft.TfData.TfDataBasf2.batch_size
batch_size
batch size
Definition: TfData.py:38
dft.TfData.TfDataBasf2
Definition: TfData.py:19
dft.TfData.TfDataBasf2.batch_train_x
batch_train_x
np ndarray for training batch features
Definition: TfData.py:59
dft.TfData.TfDataBasf2.epoch_random_shuffle
epoch_random_shuffle
bool, enables shuffling
Definition: TfData.py:42
dft.TfData.TfDataBasf2.train_idx
train_idx
indices required for shuffling
Definition: TfData.py:56
dft.TfData.TfDataBasf2Stub.__init__
def __init__(self, batch_size, feature_number, event_number, train_fraction)
Definition: TfData.py:115
dft.TfData.TfDataBasf2.valid_events
valid_events
number of validation events
Definition: TfData.py:47
dft.TfData.TfDataBasf2.batch_train_y
batch_train_y
np ndarray for training batch of targets
Definition: TfData.py:62
dft.TfData.TfDataBasf2.feature_number
feature_number
number of features
Definition: TfData.py:50
dft.TfData.TfDataBasf2.seed
seed
random generator seed
Definition: TfData.py:40
dft.TfData.TfDataBasf2Stub.feature_number
feature_number
feature number
Definition: TfData.py:123
dft.TfData.TfDataBasf2Stub.batch_size
batch_size
batch size
Definition: TfData.py:120
dft.TfData.TfDataBasf2.train_y
train_y
training targets
Definition: TfData.py:31
dft.TfData.TfDataBasf2.batch_iterator
def batch_iterator(self)
Definition: TfData.py:94
dft.TfData.TfDataBasf2Stub.batches
batches
number of batches
Definition: TfData.py:126
dft.TfData.TfDataBasf2.random_state
random_state
set random generator
Definition: TfData.py:65
dft.TfData.TfDataBasf2.train_x
train_x
training features
Definition: TfData.py:29