Belle II Software light-2601-hyperion
batch_generator Class Reference

Public Member Functions

 __init__ (self, parquet_path, variables, target, batch_size, chunk_size)
 
 __len__ (self)
 
 __getitem__ (self, idx)
 

Public Attributes

 variables = variables
 List of input variable names.
 
 target = target
 Name of target variable.
 
 batch_size = batch_size
 Batch size of the model.
 
 pf = pq.ParquetFile(parquet_path)
 Open Parquet file handle, used for metadata and row-group reads.
 
 n_chunks = self.pf.num_row_groups
 Number of chunks in the data file.
 
int current_chunk_idx = 0
 Index of chunk currently in use.
 
 dataset_length
 Number of rows in the data file.
 
 chunk_lock = threading.Lock()
 Lock that guards the chunk hand-off between threads to avoid race conditions.
 
bool chunk_ready = False
 Flag that indicates whether the new chunk has finished loading into memory.
 
 loader_thread = None
 Thread that loads new chunk while main thread is training.
 
 chunk_in_use = self.chunk_in_waiting
 Chunk currently in use.
 
 max_batches = self.max_batches_next
 Number of batches in the chunk currently in use.
 
int current_batch = 0
 Index of current batch in current chunk.
 
tuple chunk_in_waiting = (X, y)
 Next chunk.
 
 max_batches_next = max_batches
 Number of batches in the next chunk.
 

Protected Member Functions

 _load_chunk (self)
 
 _start_async_load (self)
 
 _wait_for_chunk (self)
 
 _get_batch (self, batch_idx)
 

Detailed Description

Generator that reads the input data into memory in chunks.

Definition at line 18 of file fitter.py.

Constructor & Destructor Documentation

◆ __init__()

__init__ ( self,
parquet_path,
variables,
target,
batch_size,
chunk_size )
Prepare all variables and prefetch 2 chunks.

Definition at line 23 of file fitter.py.

23     def __init__(self, parquet_path, variables, target, batch_size, chunk_size):
24         """
25         Prepare all variables and prefetch 2 chunks.
26         """
27         super().__init__(workers=1, use_multiprocessing=False, max_queue_size=10)
28
29         self.variables = variables
30
31         self.target = target
32
33         self.batch_size = batch_size
34
35         self.pf = pq.ParquetFile(parquet_path)
36
37         self.n_chunks = self.pf.num_row_groups
38
39         self.current_chunk_idx = 0
40
41         self.dataset_length = sum(
42             self.pf.metadata.row_group(i).num_rows for i in range(self.n_chunks)
43         )
44
45         # Multithreading
46
47         self.chunk_lock = threading.Lock()
48
49         self.chunk_ready = False
50
51         self.loader_thread = None
52
53         # Prefetch first chunk
54         self._start_async_load()
55         self._wait_for_chunk()  # ensure first chunk is ready
56
57         self.chunk_in_use = self.chunk_in_waiting
58
59         self.max_batches = self.max_batches_next
60
61         # Prepare next chunk
62         self._start_async_load()
63
64
65         self.current_batch = 0
66

Member Function Documentation

◆ __getitem__()

__getitem__ ( self,
idx )
Returns the next batch used in training. The idx argument is ignored; batches are served sequentially from the chunk currently in use.

Definition at line 73 of file fitter.py.

73     def __getitem__(self, idx):
74         """
75         Returns the next batch used in training
76         """
77         if self.current_batch >= self.max_batches:
78             self._wait_for_chunk()
79
80             with self.chunk_lock:
81                 self.chunk_in_use = self.chunk_in_waiting
82                 self.max_batches = self.max_batches_next
83                 self.current_batch = 0
84
85             self._start_async_load()
86
87         X, y = self._get_batch(self.current_batch)
88         self.current_batch += 1
89         return X, y
90
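
The chunk hand-off in __getitem__() follows a double-buffering pattern: one chunk is consumed while the next one is loaded in the background. The following is a minimal sketch of that pattern using only the standard library. The class DoubleBuffer, its method names, and the in-memory list of chunks are hypothetical stand-ins and not part of fitter.py; the sketch also waits with Thread.join() instead of the chunk_ready flag.

```python
import threading


class DoubleBuffer:
    """Hypothetical double-buffered reader over an in-memory list of chunks."""

    def __init__(self, chunks):
        self._chunks = chunks
        self._next_idx = 0      # index of the chunk the next load will read
        self._waiting = None    # chunk loaded in the background
        self._thread = None
        self._start_load()      # load the first chunk
        self._swap()            # promote it and prefetch the second one

    def _load(self):
        # Stand-in for reading and shuffling one row group.
        self._waiting = list(self._chunks[self._next_idx])
        self._next_idx = (self._next_idx + 1) % len(self._chunks)

    def _start_load(self):
        self._thread = threading.Thread(target=self._load, daemon=True)
        self._thread.start()

    def _swap(self):
        self._thread.join()          # wait until the background load has finished
        self.in_use = self._waiting  # promote the waiting chunk
        self._pos = 0
        self._start_load()           # immediately prefetch the following chunk

    def next_item(self):
        if self._pos >= len(self.in_use):
            self._swap()
        item = self.in_use[self._pos]
        self._pos += 1
        return item


buf = DoubleBuffer([[1, 2], [3, 4], [5]])
items = [buf.next_item() for _ in range(7)]
print(items)
```

Because the loader only writes the waiting slot and the consumer only reads it after join(), the sketch needs no explicit lock; the documented class uses chunk_lock for the same hand-off.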

◆ __len__()

__len__ ( self)
Returns the number of batches in the dataset.

Definition at line 67 of file fitter.py.

67     def __len__(self):
68         """
69         Returns number of batches in dataset
70         """
71         return self.dataset_length // self.batch_size
72
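
The value returned by __len__() is computed from the total row count, whereas _load_chunk() computes the number of batches per chunk. The two counts can differ when the row-group sizes are not multiples of the batch size. The arithmetic below illustrates this for a hypothetical file; the row-group sizes and the helper function names are assumptions made for the example.

```python
def batches_reported(row_group_sizes, batch_size):
    """Same arithmetic as __len__: total rows floor-divided by the batch size."""
    return sum(row_group_sizes) // batch_size


def batches_served_per_pass(row_group_sizes, batch_size):
    """Same arithmetic as _load_chunk: each row group yields len(chunk) // batch_size."""
    return sum(n // batch_size for n in row_group_sizes)


# Hypothetical file: three row groups of 1050 rows each, batch size 100.
sizes = [1050, 1050, 1050]
reported = batches_reported(sizes, 100)       # 3150 // 100
served = batches_served_per_pass(sizes, 100)  # 3 * (1050 // 100)
print(reported, served)
```

In this example one pass over the file serves one batch fewer than __len__() reports, so the last batch of an epoch would already come from the first row group again.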

◆ _get_batch()

_get_batch ( self,
batch_idx )
protected
Extracts the batch with index batch_idx from the chunk currently in use.

Definition at line 130 of file fitter.py.

130     def _get_batch(self, batch_idx):
131         '''
132         Extract next batch from chunk
133         '''
134         X, y = self.chunk_in_use
135         i0 = batch_idx * self.batch_size
136         i1 = i0 + self.batch_size
137         return X[i0:i1], y[i0:i1]
138
139
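
The slicing in _get_batch() can be tried on plain lists. This sketch uses a hypothetical free function get_batch() with the same index arithmetic and made-up data of 10 rows; it shows that rows beyond max_batches * batch_size are never served when only batch indices below max_batches are requested.

```python
def get_batch(X, y, batch_idx, batch_size):
    """Slice one batch out of a chunk, using the same indices as _get_batch."""
    i0 = batch_idx * batch_size
    i1 = i0 + batch_size
    return X[i0:i1], y[i0:i1]


# Made-up chunk with 10 rows and a batch size of 4.
X = list(range(10))
y = [v % 2 for v in X]
batch_size = 4
max_batches = len(X) // batch_size  # full batches only

batches = [get_batch(X, y, b, batch_size) for b in range(max_batches)]
leftover = get_batch(X, y, max_batches, batch_size)  # short slice, not an error
print(batches, leftover)
```

The leftover rows (here the last two) are dropped for this pass over the chunk; because the chunk is reshuffled on every load, different rows are dropped each time.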

◆ _load_chunk()

_load_chunk ( self)
protected
Loads the next chunk (one Parquet row group) from the data file and shuffles it.

Definition at line 91 of file fitter.py.

91     def _load_chunk(self):
92         """
93         Load next chunk from datafile and shuffle it
94         """
95         rg = self.current_chunk_idx
96         table = self.pf.read_row_group(rg)
97         # Shuffle data
98         df = table.to_pandas().sample(frac=1).reset_index(drop=True)
99
100         X = df[self.variables].to_numpy()
101         y = df[self.target].to_numpy()
102         max_batches = len(df) // self.batch_size
103
104         # Publish chunk
105         with self.chunk_lock:
106
107             self.chunk_in_waiting = (X, y)
108
109             self.max_batches_next = max_batches
110             self.chunk_ready = True
111
112         # Move to next row group
113         self.current_chunk_idx = (self.current_chunk_idx + 1) % self.n_chunks
114
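
Two details of _load_chunk() are easy to check in isolation: features and targets are shuffled together, so each row keeps its label, and the row-group index wraps around at the end of the file. The sketch below reproduces both with the standard library only. The helper names and the toy data are hypothetical, and random.Random stands in for the pandas-based shuffle used in fitter.py.

```python
import random


def shuffled_chunk(features, targets, rng):
    """Shuffle rows with one shared permutation so features and targets stay aligned."""
    order = list(range(len(features)))
    rng.shuffle(order)
    return [features[i] for i in order], [targets[i] for i in order]


def next_row_group(idx, n_chunks):
    """Advance the row-group index, wrapping around after the last row group."""
    return (idx + 1) % n_chunks


# Toy chunk: the first feature of each row equals its label, which makes alignment checkable.
features = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
targets = [0, 1, 2, 3]
X, y = shuffled_chunk(features, targets, random.Random(0))

# Visit order of row groups for a hypothetical file with three row groups.
visited = []
idx = 0
for _ in range(5):
    visited.append(idx)
    idx = next_row_group(idx, 3)
print(visited)
```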

◆ _start_async_load()

_start_async_load ( self)
protected
Starts a new thread that loads the next chunk.

Definition at line 115 of file fitter.py.

115     def _start_async_load(self):
116         '''
117         Start new thread to load new chunk
118         '''
119         self.chunk_ready = False
120         self.loader_thread = threading.Thread(target=self._load_chunk, daemon=True)
121         self.loader_thread.start()
122

◆ _wait_for_chunk()

_wait_for_chunk ( self)
protected
Sleeps until the loader thread has finished loading the next chunk, polling the chunk_ready flag every 5 seconds.

Definition at line 123 of file fitter.py.

123     def _wait_for_chunk(self):
124         '''
125         Sleep until second thread is finished with loading the next chunk
126         '''
127         while not self.chunk_ready:
128             time.sleep(5)
129
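
Polling with a fixed sleep means the caller can wait up to one full interval after the chunk has become available. A common alternative, shown here only as a sketch and not as what fitter.py does, is to signal readiness with threading.Event so the waiter wakes as soon as the loader is done. The class ChunkLoader, its methods, and slow_load() are hypothetical names introduced for the example.

```python
import threading
import time


class ChunkLoader:
    """Hypothetical loader that signals readiness with an Event instead of a polled flag."""

    def __init__(self, load_fn):
        self._load_fn = load_fn          # callable returning the next chunk
        self._ready = threading.Event()  # set once a chunk has been published
        self._lock = threading.Lock()
        self._chunk = None

    def _load(self):
        chunk = self._load_fn()
        with self._lock:
            self._chunk = chunk
        self._ready.set()                # wake any waiter immediately

    def start_async_load(self):
        self._ready.clear()
        threading.Thread(target=self._load, daemon=True).start()

    def wait_for_chunk(self, timeout=None):
        # Event.wait blocks until set() is called and returns False only on timeout.
        if not self._ready.wait(timeout):
            raise TimeoutError("chunk was not loaded in time")
        with self._lock:
            return self._chunk


def slow_load():
    time.sleep(0.05)  # pretend to read a row group
    return [1, 2, 3]


loader = ChunkLoader(slow_load)
loader.start_async_load()
chunk = loader.wait_for_chunk(timeout=5)
print(chunk)
```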

Member Data Documentation

◆ batch_size

batch_size = batch_size

Batch size of the model.

Definition at line 33 of file fitter.py.

◆ chunk_in_use

chunk_in_use = self.chunk_in_waiting

Chunk currently in use.

Definition at line 57 of file fitter.py.

◆ chunk_in_waiting

tuple chunk_in_waiting = (X, y)

Next chunk.

Definition at line 107 of file fitter.py.

◆ chunk_lock

chunk_lock = threading.Lock()

Lock that guards the chunk hand-off between threads to avoid race conditions.

Definition at line 47 of file fitter.py.

◆ chunk_ready

bool chunk_ready = False

Flag that indicates whether the new chunk has finished loading into memory.

Definition at line 49 of file fitter.py.

◆ current_batch

int current_batch = 0

Index of current batch in current chunk.

Definition at line 65 of file fitter.py.

◆ current_chunk_idx

int current_chunk_idx = 0

Index of chunk currently in use.

Definition at line 39 of file fitter.py.

◆ dataset_length

dataset_length
Initial value:
= sum(
    self.pf.metadata.row_group(i).num_rows for i in range(self.n_chunks)
)

Number of rows in the data file.

Definition at line 41 of file fitter.py.

◆ loader_thread

loader_thread = None

Thread that loads new chunk while main thread is training.

Definition at line 51 of file fitter.py.

◆ max_batches

max_batches = self.max_batches_next

Number of batches in the chunk currently in use.

Definition at line 59 of file fitter.py.

◆ max_batches_next

max_batches_next = max_batches

Number of batches in the next chunk.

Definition at line 109 of file fitter.py.

◆ n_chunks

n_chunks = self.pf.num_row_groups

Number of chunks in the data file.

Definition at line 37 of file fitter.py.

◆ pf

pf = pq.ParquetFile(parquet_path)

Open Parquet file handle, used for metadata and row-group reads.

Definition at line 35 of file fitter.py.

◆ target

target = target

Name of target variable.

Definition at line 31 of file fitter.py.

◆ variables

variables = variables

List of input variable names.

Definition at line 29 of file fitter.py.


The documentation for this class was generated from the following file: