14import pyarrow.parquet
as pq
20 Generator that reads the input data into memory in chunks.
23 def __init__(self, parquet_path, variables, target, batch_size, chunk_size):
25 Prepare all variables and prefetch 2 chunks.
27 super().
__init__(workers=1, use_multiprocessing=
False, max_queue_size=10)
35 self.
pf = pq.ParquetFile(parquet_path)
42 self.
pf.metadata.row_group(i).num_rows
for i
in range(self.
n_chunks)
69 Returns number of batches in dataset
75 Returns the next batch used in training
93 Load next chunk from datafile and shuffle it
96 table = self.
pf.read_row_group(rg)
98 df = table.to_pandas().sample(frac=1).reset_index(drop=
True)
101 y = df[self.
target].to_numpy()
117 Start new thread to load new chunk
125 Sleep until second thread is finished with loading the next chunk
132 Extract next batch from chunk
137 return X[i0:i1], y[i0:i1]
140def fit(model, train_file, val_file, treename, variables, target_variable, config, checkpoint_filepath):
141 variables = list(map(ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible, variables))
142 batch_size = config[
'batch_size']
143 chunk_size = config[
'chunk_size']
145 train_ds =
batch_generator(train_file, variables, target_variable, batch_size, chunk_size)
146 val_ds =
batch_generator(val_file, variables, target_variable, batch_size, chunk_size)
149 callbacks = [keras.callbacks.EarlyStopping(
152 patience=config[
'patience'],
156 restore_best_weights=
True)]
159 model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
160 filepath=checkpoint_filepath,
164 callbacks.append(model_checkpoint_callback)
169 validation_data=val_ds,
170 steps_per_epoch=len(train_ds),
171 validation_steps=len(val_ds),
172 epochs=config[
'epochs'],
_get_batch(self, batch_idx)
tuple chunk_in_waiting
Next chunk.
chunk_lock
Chunk lock to avoid race conditions.
target
Name of target variable.
variables
List of input variable names.
max_batches_next
Maximum number of batches in this chunk.
loader_thread
Thread that loads new chunk while main thread is training.
bool chunk_ready
Flag that indicates whether the new chunk has finished loading into memory.
int current_batch
Index of current batch in current chunk.
chunk_in_use
Chunk currently in use.
__init__(self, parquet_path, variables, target, batch_size, chunk_size)
n_chunks
Number of chunks in the data file.
batch_size
Batch size of the model.
int current_chunk_idx
Index of chunk currently in use.
dataset_length
Number of rows in datafile.
max_batches
Number of batches in a chunk.