Belle II Software
development
TfData.py
1
#!/usr/bin/env python
2
3
10
11
12
import
numpy
as
np
13
14
15
class TfDataBasf2():
    """
    handles data, necessary for the training

    Holds the training and validation samples of a binary classification
    problem, maps the two class labels to {0, 1} and provides an iterator
    over fixed-size, optionally shuffled, training batches.
    """

    def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True):
        """
        declaration of class variables

        :param train_x: training features, indexed as (event, feature)
        :param train_y: training targets, must contain exactly two distinct values
        :param valid_x: validation features, indexed as (event, feature)
        :param valid_y: validation targets, must use the same two values as ``train_y``
        :param batch_size: number of events per training batch
        :param seed: seed for the shuffling random generator
        :param epoch_random_shuffle: reshuffle the training events on every ``batch_iterator`` call
        :raises ValueError: if the targets do not describe a binary classification problem
        """
        # astype() returns copies, so relabelling below never modifies the caller's arrays
        #: training features
        self.train_x = train_x.astype(np.float32)
        #: training targets
        self.train_y = train_y.astype(np.float32)
        #: validation features
        self.valid_x = valid_x.astype(np.float32)
        #: validation targets
        self.valid_y = valid_y.astype(np.float32)

        #: batch size
        self.batch_size = batch_size
        #: random generator seed
        self.seed = seed
        #: bool, enables shuffling
        self.epoch_random_shuffle = epoch_random_shuffle

        #: number of training events
        self.train_events = self.train_x.shape[0]
        #: number of validation events
        self.valid_events = self.valid_x.shape[0]
        #: number of features
        self.feature_number = self.train_x.shape[1]

        # only complete batches are produced; leftover training events are not yielded
        #: number of batches
        self.batches = self.train_events // self.batch_size

        #: indices required for shuffling
        self.train_idx = np.arange(self.train_events)

        # placeholders, laid out like the batches yielded by batch_iterator
        #: np ndarray for training batch features
        self.batch_train_x = np.zeros((self.batch_size, self.feature_number), dtype=np.float32)
        #: np ndarray for training batch of targets
        self.batch_train_y = np.zeros(self.batch_size, dtype=np.float32)

        #: set random generator
        self.random_state = np.random.RandomState(seed)

        self.sanitize_labels()

    def sanitize_labels(self):
        """
        checks for a binary classification problem
        transforms the two class labels to {0,1}

        The smaller label is mapped to 0 and the larger one to 1, in both the
        training and the validation targets.

        :raises ValueError: if the training targets do not contain exactly two
            distinct values, or the validation targets use a different set of values
        """
        train_classes = np.unique(self.train_y)
        # not binary
        if len(train_classes) != 2:
            raise ValueError(
                "expected exactly two distinct training labels, got %d" % len(train_classes))
        # different classes
        if not np.array_equal(train_classes, np.unique(self.valid_y)):
            raise ValueError("training and validation labels use different class values")

        # np.unique returns the values sorted, so the last one is the larger label
        high = train_classes[-1]
        self.train_y = (self.train_y == high).astype(np.float32)
        self.valid_y = (self.valid_y == high).astype(np.float32)

    def batch_iterator(self):
        """
        iterator to provide training batches

        Yields ``self.batches`` tuples ``(batch_x, batch_y)`` of exactly
        ``batch_size`` events each. Events that do not fill a complete batch
        are skipped. If ``epoch_random_shuffle`` is set, the event order is
        reshuffled on every call.
        """
        self.train_idx = np.arange(self.train_events)

        if self.epoch_random_shuffle:
            self.random_state.shuffle(self.train_idx)

        for i in range(self.batches):
            batch_idx = self.train_idx[i * self.batch_size:(i + 1) * self.batch_size]
            self.batch_train_x = self.train_x[batch_idx]
            self.batch_train_y = self.train_y[batch_idx]

            yield self.batch_train_x, self.batch_train_y
104
105
106
class TfDataBasf2Stub():
    """
    stub class just for initializing in basf2 begin_run

    Carries only the batch bookkeeping attributes of ``TfDataBasf2``,
    derived from the total event count and the training fraction, without
    holding any data.
    """

    def __init__(self, batch_size, feature_number, event_number, train_fraction):
        """
        declare for initialization required batch parameters

        :param batch_size: number of events per training batch
        :param feature_number: number of input features
        :param event_number: total number of events (training + validation)
        :param train_fraction: fraction of ``event_number`` used for training
        """
        #: batch size
        self.batch_size = batch_size
        #: feature number
        self.feature_number = feature_number
        #: number of training training events
        self.train_events = int(train_fraction * event_number)
        # kept integral (like TfDataBasf2.batches) so it can be used with range();
        # only complete batches count
        #: number of batches
        self.batches = self.train_events // self.batch_size
        # complement of the training events; int((1 - train_fraction) * event_number)
        # can drop an event to float rounding (e.g. 0.8 of 1000 gives 199).
        # NOTE(review): confirm this matches how the real train/valid split is made.
        #: number of validation events
        self.valid_events = int(event_number - self.train_events)
dft.TfData.TfDataBasf2Stub
Definition
TfData.py:106
dft.TfData.TfDataBasf2Stub.train_events
train_events
number of training events
Definition
TfData.py:125
dft.TfData.TfDataBasf2Stub.batches
batches
number of batches
Definition
TfData.py:122
dft.TfData.TfDataBasf2Stub.__init__
__init__(self, batch_size, feature_number, event_number, train_fraction)
Definition
TfData.py:111
dft.TfData.TfDataBasf2Stub.batch_size
batch_size
batch size
Definition
TfData.py:116
dft.TfData.TfDataBasf2Stub.feature_number
feature_number
feature number
Definition
TfData.py:119
dft.TfData.TfDataBasf2Stub.valid_events
valid_events
number of validation events
Definition
TfData.py:128
dft.TfData.TfDataBasf2
Definition
TfData.py:15
dft.TfData.TfDataBasf2.epoch_random_shuffle
epoch_random_shuffle
bool, enables shuffling
Definition
TfData.py:38
dft.TfData.TfDataBasf2.random_state
random_state
random number generator
Definition
TfData.py:61
dft.TfData.TfDataBasf2.valid_y
int valid_y
validation targets
Definition
TfData.py:31
dft.TfData.TfDataBasf2.train_events
train_events
number of training events
Definition
TfData.py:41
dft.TfData.TfDataBasf2.batch_train_y
int batch_train_y
np ndarray for training batch of targets
Definition
TfData.py:58
dft.TfData.TfDataBasf2.__init__
__init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True)
Definition
TfData.py:20
dft.TfData.TfDataBasf2.seed
seed
random generator seed
Definition
TfData.py:36
dft.TfData.TfDataBasf2.train_y
int train_y
training targets
Definition
TfData.py:27
dft.TfData.TfDataBasf2.batch_size
batch_size
batch size
Definition
TfData.py:34
dft.TfData.TfDataBasf2.valid_x
valid_x
validation features
Definition
TfData.py:29
dft.TfData.TfDataBasf2.batch_iterator
batch_iterator(self)
Definition
TfData.py:90
dft.TfData.TfDataBasf2.train_idx
train_idx
indices required for shuffling
Definition
TfData.py:52
dft.TfData.TfDataBasf2.feature_number
feature_number
number of features
Definition
TfData.py:46
dft.TfData.TfDataBasf2.valid_events
valid_events
number of validation events
Definition
TfData.py:43
dft.TfData.TfDataBasf2.sanitize_labels
sanitize_labels(self)
Definition
TfData.py:66
dft.TfData.TfDataBasf2.train_x
train_x
training features
Definition
TfData.py:25
dft.TfData.TfDataBasf2.batches
batches
number of batches
Definition
TfData.py:49
dft.TfData.TfDataBasf2.batch_train_x
batch_train_x
np ndarray for training batch features
Definition
TfData.py:55
analysis
scripts
dft
TfData.py
Generated on Mon Sep 1 2025 02:45:54 for Belle II Software by
1.13.2