Belle II Software light-2601-hyperion
fitter.py
#!/usr/bin/env python3

import keras
import ROOT
import threading
import pyarrow.parquet as pq
import time


class batch_generator(keras.utils.PyDataset):
    '''
    Generator that reads the input data into memory in chunks.
    '''

    def __init__(self, parquet_path, variables, target, batch_size, chunk_size):
        """
        Prepare all variables and prefetch 2 chunks.
        """
        super().__init__(workers=1, use_multiprocessing=False, max_queue_size=10)

        #: List of input variable names.
        self.variables = variables

        #: Name of target variable.
        self.target = target

        #: Batch size of the model.
        self.batch_size = batch_size

        #: Parquet metadata.
        self.pf = pq.ParquetFile(parquet_path)

        #: Number of chunks in the data file.
        self.n_chunks = self.pf.num_row_groups

        #: Index of chunk currently in use.
        self.current_chunk_idx = 0

        #: Number of rows in datafile.
        self.dataset_length = sum(
            self.pf.metadata.row_group(i).num_rows for i in range(self.n_chunks)
        )

        # Multithreading
        #: Chunk lock to avoid race conditions.
        self.chunk_lock = threading.Lock()

        #: Flag that indicates whether the new chunk is done loading into memory.
        self.chunk_ready = False

        #: Thread that loads a new chunk while the main thread is training.
        self.loader_thread = None

        # Prefetch first chunk
        self._start_chunk_load()
        self._wait_for_chunk()  # ensure first chunk is ready

        #: Chunk currently in use.
        self.chunk_in_use = self.chunk_in_waiting

        #: Number of batches in a chunk.
        self.max_batches = self.max_batches_next

        # Prepare next chunk
        self._start_chunk_load()

        #: Index of current batch in current chunk.
        self.current_batch = 0

    def __len__(self):
        """
        Returns number of batches in dataset
        """
        return self.dataset_length // self.batch_size

    def __getitem__(self, idx):
        """
        Returns the next batch used in training
        """
        if self.current_batch >= self.max_batches:
            self._wait_for_chunk()

            with self.chunk_lock:
                # Swap in the prefetched chunk
                self.chunk_in_use = self.chunk_in_waiting
                self.max_batches = self.max_batches_next
                self.current_batch = 0

            # Start loading the next chunk in the background
            self._start_chunk_load()

        X, y = self._get_batch(self.current_batch)
        self.current_batch += 1
        return X, y

    def _load_chunk(self):
        """
        Load next chunk from datafile and shuffle it
        """
        rg = self.current_chunk_idx
        table = self.pf.read_row_group(rg)
        # Shuffle data
        df = table.to_pandas().sample(frac=1).reset_index(drop=True)

        X = df[self.variables].to_numpy()
        y = df[self.target].to_numpy()
        max_batches = len(df) // self.batch_size

        # Publish chunk
        with self.chunk_lock:
            #: Next chunk, waiting to be swapped in.
            self.chunk_in_waiting = (X, y)

            #: Maximum number of batches in the waiting chunk.
            self.max_batches_next = max_batches
            self.chunk_ready = True

        # Move to next row group
        self.current_chunk_idx = (self.current_chunk_idx + 1) % self.n_chunks

    def _start_chunk_load(self):
        '''
        Start new thread to load the next chunk
        '''
        self.chunk_ready = False
        self.loader_thread = threading.Thread(target=self._load_chunk, daemon=True)
        self.loader_thread.start()

    def _wait_for_chunk(self):
        '''
        Sleep until the loader thread has finished loading the next chunk
        '''
        while not self.chunk_ready:
            time.sleep(5)

    def _get_batch(self, batch_idx):
        '''
        Extract next batch from chunk
        '''
        X, y = self.chunk_in_use
        i0 = batch_idx * self.batch_size
        i1 = i0 + self.batch_size
        return X[i0:i1], y[i0:i1]
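
# ----------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module): how this
# generator could be driven by hand. The parquet path, column names, and sizes
# below are hypothetical.
#
#   gen = batch_generator("train.parquet",           # file written with several row groups
#                         variables=["var_a", "var_b"],
#                         target="isSignal",
#                         batch_size=32,
#                         chunk_size=10000)
#   print(len(gen))         # number of batches per epoch
#   X, y = gen[0]           # keras calls __getitem__ like this during model.fit()
# ----------------------------------------------------------------------------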


def fit(model, train_file, val_file, treename, variables, target_variable, config, checkpoint_filepath):
    """
    Train the given keras model on the training file with early stopping and checkpointing.
    """
    variables = list(map(ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible, variables))
    batch_size = config['batch_size']
    chunk_size = config['chunk_size']

    train_ds = batch_generator(train_file, variables, target_variable, batch_size, chunk_size)
    val_ds = batch_generator(val_file, variables, target_variable, batch_size, chunk_size)

    # configure early stopping callback
    callbacks = [keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=0,
        patience=config['patience'],
        verbose=1,
        mode='auto',
        baseline=None,
        restore_best_weights=True)]

    # configure checkpointing callback
    model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        monitor='val_loss',
        mode='min',
        save_best_only=True)
    callbacks.append(model_checkpoint_callback)

    # perform fit() with callbacks
    model.fit(
        train_ds,
        validation_data=val_ds,
        steps_per_epoch=len(train_ds),
        validation_steps=len(val_ds),
        epochs=config['epochs'],
        callbacks=callbacks,
        verbose=2)

    print()
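

# Illustrative driver (an assumption for documentation purposes, not part of the
# original fitter.py): builds a small keras model and calls fit() above. The file
# names, column names, and config values are placeholders.
if __name__ == "__main__":
    example_variables = ["var_a", "var_b", "var_c"]          # hypothetical input columns
    example_config = {"batch_size": 64, "chunk_size": 10000,
                      "patience": 5, "epochs": 10}

    # A minimal binary classifier matching the number of input variables.
    example_model = keras.Sequential([
        keras.layers.Input(shape=(len(example_variables),)),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ])
    example_model.compile(optimizer="adam", loss="binary_crossentropy")

    fit(example_model,
        train_file="train.parquet",                          # hypothetical parquet files
        val_file="val.parquet",
        treename=None,                                       # accepted but not used by fit()
        variables=example_variables,
        target_variable="isSignal",                          # hypothetical target column
        config=example_config,
        checkpoint_filepath="best_model.keras")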