Belle II Software light-2406-ragdoll
TfData.py
1#!/usr/bin/env python
2
3
10
11
12import numpy as np
13
14
16 """
17 handles data, necessary for the training
18 """
19
def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True):
    """
    Store the training and validation data and set up the batching state.

    All four arrays are converted to float32 copies, the batch bookkeeping
    is derived from the shape of ``train_x`` and the labels are finally
    normalised through ``sanitize_labels``.

    @param train_x training features; ``shape[0]`` is used as the number of
        events and ``shape[1]`` as the number of features
    @param train_y training targets
    @param valid_x validation features
    @param valid_y validation targets
    @param batch_size number of events per training batch
    @param seed random generator seed (``None`` leaves the generator unseeded)
    @param epoch_random_shuffle bool, enables shuffling
    """
    # astype() returns a new array, so the in-place label rewriting done by
    # sanitize_labels() does not modify the arrays owned by the caller.
    #: training features
    self.train_x = train_x.astype(np.float32)
    #: training targets
    self.train_y = train_y.astype(np.float32)
    #: validation features
    self.valid_x = valid_x.astype(np.float32)
    #: validation targets
    self.valid_y = valid_y.astype(np.float32)

    #: batch size
    self.batch_size = batch_size
    #: random generator seed
    self.seed = seed
    #: bool, enables shuffling
    self.epoch_random_shuffle = epoch_random_shuffle

    #: number of training events
    self.train_events = self.train_x.shape[0]
    #: number of validation events
    self.valid_events = self.valid_x.shape[0]
    #: number of features
    self.feature_number = self.train_x.shape[1]

    # Floor division: events that do not fill a complete batch are not
    # counted.
    #: number of batches
    self.batches = self.train_events // self.batch_size

    # An integer identity permutation, so the attribute is a usable index
    # array even before the first call to batch_iterator().
    #: indices required for shuffling
    self.train_idx = np.arange(self.train_events)

    # Buffers are laid out as (events, features) and typed float32 so they
    # match the batches that batch_iterator() slices out of train_x/train_y.
    #: np ndarray for training batch features
    self.batch_train_x = np.zeros((self.batch_size, self.feature_number), dtype=np.float32)
    #: np ndarray for training batch of targets
    self.batch_train_y = np.zeros(self.batch_size, dtype=np.float32)

    #: set random generator
    self.random_state = np.random.RandomState(seed)

    self.sanitize_labels()
65
def sanitize_labels(self):
    """
    Check for a binary classification problem and map the labels to {0, 1}.

    The smaller of the two class labels becomes 0 and the larger becomes 1.
    ``self.train_y`` and ``self.valid_y`` are rewritten in place.

    @raises AssertionError if ``train_y`` does not contain exactly two
        distinct labels, or if ``valid_y`` does not contain the same two
    """
    # np.unique returns the distinct values in ascending order.
    classes = np.unique(self.train_y)

    # Raised explicitly instead of via ``assert`` so that the checks are not
    # stripped when Python runs with -O; the exception type is unchanged.
    if len(classes) != 2:
        raise AssertionError(
            f"expected exactly two class labels in train_y, found {len(classes)}")
    if not np.array_equal(classes, np.unique(self.valid_y)):
        raise AssertionError("train_y and valid_y do not contain the same class labels")

    upper = classes[1]
    for labels in (self.train_y, self.valid_y):
        # The comparison is evaluated completely before the slice assignment
        # writes it back, so the array is safely updated in place.
        labels[:] = labels == upper
89
def batch_iterator(self):
    """
    Iterator to provide training batches.

    Every call starts a new pass over the training data: the index array is
    reset to the identity order and, if ``self.epoch_random_shuffle`` is
    set, shuffled with ``self.random_state``. ``self.batches`` batches of
    ``self.batch_size`` events are then yielded; events beyond the last
    complete batch are not yielded.

    The current batch is also stored in ``self.batch_train_x`` and
    ``self.batch_train_y``.

    @return generator of ``(batch_train_x, batch_train_y)`` tuples
    """
    self.train_idx = np.arange(len(self.train_x))

    # Honour the flag stored by the constructor; without shuffling the
    # events are served in their stored order.
    if self.epoch_random_shuffle:
        self.random_state.shuffle(self.train_idx)

    for i in range(self.batches):
        # Select the indices of this batch once and reuse them for features
        # and targets so both always refer to the same events.
        batch_idx = self.train_idx[i * self.batch_size: (i + 1) * self.batch_size]
        self.batch_train_x = self.train_x[batch_idx]
        self.batch_train_y = self.train_y[batch_idx]

        yield self.batch_train_x, self.batch_train_y
104
105
107 """
108 stub class just for initializing in basf2 begin_run
109 """
110
def __init__(self, batch_size, feature_number, event_number, train_fraction):
    """
    Declare the batch parameters required for initialization.

    No data is held; only the bookkeeping numbers are derived from the
    expected total number of events and the training fraction.

    @param batch_size number of events per batch
    @param feature_number number of features
    @param event_number total number of events (training plus validation)
    @param train_fraction fraction of ``event_number`` used for training
    """
    #: batch size
    self.batch_size = batch_size

    #: feature number
    self.feature_number = feature_number

    #: number of training events
    self.train_events = int(train_fraction * event_number)

    # Derived from the integer event count so the result is an int (usable
    # in range()) rather than a float when train_fraction is a float.
    #: number of batches
    self.batches = self.train_events // self.batch_size

    # Taken as the remainder instead of int((1 - train_fraction) * N), which
    # can lose one event to floating-point truncation (e.g. 0.9 of 100).
    #: number of validation events
    self.valid_events = int(event_number) - self.train_events
train_events
number of training events
Definition: TfData.py:125
def __init__(self, batch_size, feature_number, event_number, train_fraction)
Definition: TfData.py:111
feature_number
feature number
Definition: TfData.py:119
valid_events
number of validation events
Definition: TfData.py:128
batches
number of batches
Definition: TfData.py:122
def batch_iterator(self)
Definition: TfData.py:90
epoch_random_shuffle
bool, enables shuffling
Definition: TfData.py:38
random_state
set random generator
Definition: TfData.py:61
valid_y
validation targets
Definition: TfData.py:31
train_events
number of training events
Definition: TfData.py:41
def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True)
Definition: TfData.py:20
seed
random generator seed
Definition: TfData.py:36
def sanitize_labels(self)
Definition: TfData.py:66
train_y
training targets
Definition: TfData.py:27
batch_size
batch size
Definition: TfData.py:34
batch_train_y
np ndarray for training batch of targets
Definition: TfData.py:58
valid_x
validation features
Definition: TfData.py:29
train_idx
indices required for shuffling
Definition: TfData.py:52
feature_number
number of features
Definition: TfData.py:46
valid_events
number of validation events
Definition: TfData.py:43
train_x
training features
Definition: TfData.py:25
batches
number of batches
Definition: TfData.py:49
batch_train_x
np ndarray for training batch features
Definition: TfData.py:55
Definition: __init__.py:1