def __init__(self, train_x, train_y, valid_x, valid_y, batch_size, seed=None, epoch_random_shuffle=True):
    """Store train/validation data as float32 and set up mini-batch state.

    Parameters
    ----------
    train_x : array-like exposing ``astype`` (e.g. ``np.ndarray``)
        Training features. Axis 0 is read as the event axis and axis 1 as
        the feature axis. Stored as a float32 copy in ``self.train_x``.
    train_y : array-like exposing ``astype``
        Training labels, stored as a float32 copy in ``self.train_y``.
    valid_x : array-like exposing ``astype``
        Validation features; axis 0 is read as the event axis. Stored as a
        float32 copy in ``self.valid_x``.
    valid_y : array-like exposing ``astype``
        Validation labels, stored as a float32 copy in ``self.valid_y``.
    batch_size : int
        Number of training events per batch. Must be positive.
    seed : optional
        Stored in ``self.seed`` and used to seed ``np.random.RandomState``.
    epoch_random_shuffle : bool
        Stored as-is in ``self.epoch_random_shuffle``. Presumably it toggles
        reshuffling between epochs; TODO confirm against the batching code.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    # Fail early with a clear message instead of a ZeroDivisionError below.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # astype() returns a new array, so the caller's arrays are left untouched.
    self.train_x = train_x.astype(np.float32)
    self.train_y = train_y.astype(np.float32)
    self.valid_x = valid_x.astype(np.float32)
    self.valid_y = valid_y.astype(np.float32)

    self.batch_size = batch_size
    self.seed = seed
    self.epoch_random_shuffle = epoch_random_shuffle

    self.train_events = self.train_x.shape[0]
    self.valid_events = self.valid_x.shape[0]
    self.feature_number = self.train_x.shape[1]

    # Floor division: a trailing partial batch (train_events % batch_size
    # events) is not counted as a batch.
    self.batches = self.train_events // self.batch_size

    # NOTE(review): float64 zero placeholder, one slot per training event.
    # Presumably it is overwritten with integer event indices elsewhere,
    # because float arrays cannot be used to index numpy arrays. Confirm.
    self.train_idx = np.zeros(self.train_events)

    # Batch buffers. batch_train_x is laid out (feature_number, batch_size),
    # which is transposed relative to train_x's (events, features) layout.
    # NOTE(review): confirm that transposition is intentional.
    self.batch_train_x = np.zeros((self.feature_number, self.batch_size))
    self.batch_train_y = np.zeros(self.batch_size)

    # Instance-local RNG seeded with `seed`; independent of numpy's global
    # random state.
    self.random_state = np.random.RandomState(seed)

    # Defined elsewhere in the class (outside this view); presumably cleans
    # up the label arrays stored above.
    self.sanitize_labels()