merge_samples_to_parquet.py
#!/usr/bin/env python3

import os
import glob
import uproot
import numpy as np
import pyarrow.parquet as pq
import pyarrow as pa
import pandas as pd
import argparse
import tflat.utils as utils

def merge_root_to_parquet(root_dir, parquet_dir, mask_value, uniqueIdentifier, tree_name="tflat_variables"):
    '''
    Merges the tflat sampled root files into a single parquet file.
    '''
    files = sorted(glob.glob(os.path.join(root_dir, f"{uniqueIdentifier}_training_data*.root")))
    writer = None
    n_rows = 0
    print("Merging root files into parquet")
    for i in range(len(files)):
        print(f"\r{i+1}/{len(files)}", end="", flush=True)
        f = files[i]
        with uproot.open(f) as rfile:
            df = rfile[tree_name].arrays(library="pd")
        n_rows += len(df)

        # Map the target variable from {-1, 1} to {0, 1}
        df["qrCombined"] = df["qrCombined"].where(df["qrCombined"] != -1, 0)

        # Verify the two-class output
        assert set(df["qrCombined"].unique()).issubset({0, 1})

        # Mask NaN values
        df = df.fillna(mask_value)
        # Remove columns containing meta variables (names of the form __name__)
        for column in df.columns:
            if column.startswith('__') and column.endswith('__'):
                df = df.drop(column, axis=1)

        # Each write_table() call appends one row group to the output file
        table = pa.Table.from_pandas(df)
        if writer is None:
            writer = pq.ParquetWriter(
                os.path.join(parquet_dir, f'{uniqueIdentifier}_samples_merged.parquet'),
                table.schema, compression='snappy')
        writer.write_table(table)

    writer.close()
    print(f'\nNumber of events: {n_rows}')

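
# A minimal sketch, not part of the original workflow (the helper name
# inspect_rowgroups is hypothetical): each write_table() call above appends
# one row group, and shuffle_and_chunk_parquet() below relies on exactly this
# metadata to address rows without loading the whole file.
def inspect_rowgroups(parquet_path):
    '''Prints the row-group layout of a parquet file.'''
    pf = pq.ParquetFile(parquet_path)
    for i in range(pf.num_row_groups):
        print(f"row group {i}: {pf.metadata.row_group(i).num_rows} rows")
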
def create_dataset(pf, parquet_path, index, chunk_size, n_rowgroups, rowgroup_edges):
    '''
    Picks rows from the parquet file according to the given index array.
    The created parquet file is segmented into row groups with a maximum size
    given by chunk_size.
    '''
    writer = None
    # Number of output chunks, rounded up to cover all indices
    n_chunks = (len(index)-1)//chunk_size + 1
    for chunk in range(n_chunks):
        print(f"\r{chunk+1}/{n_chunks}", end="", flush=True)
        chunk_df = pd.DataFrame()
        # Handle the special case of the last chunk having fewer rows
        if (chunk+1)*chunk_size > len(index):
            index_chunks = index[chunk*chunk_size:]
        else:
            index_chunks = index[chunk*chunk_size:(chunk+1)*chunk_size]
        lower_edge = 0
        for rowgroup in range(n_rowgroups):
            upper_edge = rowgroup_edges[rowgroup]
            # Select the rows of this row group that belong into the current chunk
            rows_to_fetch = index_chunks[(index_chunks >= lower_edge) & (index_chunks < upper_edge)]
            if len(rows_to_fetch) > 0:
                table = pf.read_row_group(rowgroup)
                df = table.to_pandas()
                # Convert global row indices to positions within the row group
                df = df.iloc[rows_to_fetch - lower_edge]
                chunk_df = pd.concat([chunk_df, df])
            lower_edge = upper_edge
        # Write the assembled chunk as one uncompressed row group
        table = pa.Table.from_pandas(chunk_df)
        if writer is None:
            writer = pq.ParquetWriter(parquet_path, table.schema, compression='NONE')
        writer.write_table(table)
    writer.close()

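
# A minimal sketch, illustration only (locate_row is a hypothetical helper):
# the same global-index -> (row group, local row) mapping that create_dataset()
# performs with explicit edge comparisons, expressed with np.searchsorted on
# the cumulative rowgroup_edges array.
def locate_row(global_index, rowgroup_edges):
    '''Returns the (row group, local row) position of a global row index.'''
    rowgroup = int(np.searchsorted(rowgroup_edges, global_index, side="right"))
    lower_edge = 0 if rowgroup == 0 else rowgroup_edges[rowgroup - 1]
    return rowgroup, global_index - lower_edge
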
def shuffle_and_chunk_parquet(parquet_dir, train_fraction, chunk_size, uniqueIdentifier):
    '''
    Splits the merged parquet file into a training and a validation parquet file.
    The data contained in the resulting files is shuffled and segmented into chunks.
    '''
    pf = pq.ParquetFile(os.path.join(parquet_dir, f'{uniqueIdentifier}_samples_merged.parquet'))
    # Cumulative number of rows at the end of each row group
    rowgroup_edges = []
    n_rows = 0
    n_rowgroups = pf.num_row_groups
    for i in range(n_rowgroups):
        n = pf.metadata.row_group(i).num_rows
        n_rows += n
        rowgroup_edges.append(n_rows)
    # Shuffle a global row index and split it into a training and a validation part
    index = np.arange(n_rows)
    np.random.shuffle(index)
    n_training_samples = int(n_rows*train_fraction)
    index_training = index[:n_training_samples]
    index_validation = index[n_training_samples:]
    print('Creating training dataset')
    create_dataset(
        pf,
        os.path.join(
            parquet_dir,
            f'{uniqueIdentifier}_training_samples.parquet'),
        index_training,
        chunk_size,
        n_rowgroups,
        rowgroup_edges)
    print('\nCreating validation dataset')
    create_dataset(
        pf,
        os.path.join(
            parquet_dir,
            f'{uniqueIdentifier}_validation_samples.parquet'),
        index_validation,
        chunk_size,
        n_rowgroups,
        rowgroup_edges)

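
# A minimal sketch, not used by this script (iter_chunks is a hypothetical
# helper): because create_dataset() writes every chunk as its own row group,
# a training loop can stream the shuffled samples one row group at a time.
def iter_chunks(parquet_path):
    '''Yields the shuffled samples one row-group-sized DataFrame at a time.'''
    pf = pq.ParquetFile(parquet_path)
    for rowgroup in range(pf.num_row_groups):
        yield pf.read_row_group(rowgroup).to_pandas()
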
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Merge tflat sampled root files into shuffled training and validation parquet files')
    parser.add_argument(
        '--root_dir',
        dest='root_dir',
        type=str,
        help='Path to the directory where the sampled root files are stored'
    )
    parser.add_argument(
        '--parquet_dir',
        dest='parquet_dir',
        type=str,
        help='Path to the directory where the parquet files are saved'
    )
    parser.add_argument(
        '--uniqueIdentifier',
        metavar='uniqueIdentifier',
        dest='uniqueIdentifier',
        type=str,
        default="TFlaT_MC16rd_light_2601_hyperion",
        help='Name of both the config .yaml file to be used and the produced weightfile'
    )
    args = parser.parse_args()
    root_dir = args.root_dir
    parquet_dir = args.parquet_dir
    uniqueIdentifier = args.uniqueIdentifier
    os.makedirs(parquet_dir, exist_ok=True)

    config = utils.load_config(uniqueIdentifier)
    train_fraction = config['train_valid_fraction']
    chunk_size = config['chunk_size']
    mask_value = config['parameters']['mask_value']

    merge_root_to_parquet(
        root_dir=root_dir,
        parquet_dir=parquet_dir,
        mask_value=mask_value,
        uniqueIdentifier=uniqueIdentifier
    )
    shuffle_and_chunk_parquet(
        parquet_dir=parquet_dir,
        train_fraction=train_fraction,
        chunk_size=chunk_size,
        uniqueIdentifier=uniqueIdentifier
    )
    print('\nDone!')
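
# Example invocation (directory paths are placeholders):
#   python3 merge_samples_to_parquet.py \
#       --root_dir ./sampled_root_files \
#       --parquet_dir ./parquet_files \
#       --uniqueIdentifier TFlaT_MC16rd_light_2601_hyperion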