Belle II Software light-2601-hyperion
merge_samples_to_parquet.py
#!/usr/bin/env python3


import os
import glob
import uproot
import numpy as np
import pyarrow.parquet as pq
import pyarrow as pa
import pandas as pd
import argparse
import tflat.utils as utils


def merge_root_to_parquet(root_dir, parquet_dir, mask_value, uniqueIdentifier, tree_name="tflat_variables"):
    '''
    Merges the sampled tflat ROOT files into one parquet file.
    '''
    files = sorted(glob.glob(os.path.join(root_dir, f"{uniqueIdentifier}_training_data*.root")))
    writer = None
    print("Merging root files into parquet")
    for i in range(len(files)):
        print(f"\r{i+1}/{len(files)}", end="", flush=True)
        f = files[i]
        with uproot.open(f)[tree_name] as tree:
            df = tree.arrays(library="pd")

        # Map the two-class target variable from {-1, 1} to {0, 1}
        m = df["qrCombined"].min()
        if m != 0:
            df["qrCombined"] = df["qrCombined"].where(df["qrCombined"] != m, 0)

        # Verify the two-class output
        assert len(df["qrCombined"].unique()) == 2

        # Mask NaN values
        df = df.fillna(mask_value)
        # Remove columns containing meta variables (names of the form __name__)
        for column in df.columns:
            if column.startswith('__') and column.endswith('__'):
                df = df.drop(column, axis=1)

        table = pa.Table.from_pandas(df)
        if writer is None:
            writer = pq.ParquetWriter(
                os.path.join(parquet_dir, f'{uniqueIdentifier}_samples_merged.parquet'),
                table.schema, compression='snappy')
        writer.write_table(table)

    writer.close()


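# Note (illustrative sketch, not part of the original pipeline): each write_table()
# call above typically produces one row group in the merged file (pyarrow may split
# very large tables), so the output can later be read back row group by row group.
# A minimal way to inspect the merged file, assuming it was written by
# merge_root_to_parquet:
def _inspect_merged_parquet(parquet_dir, uniqueIdentifier):
    '''Illustrative helper: print the row-group layout of the merged parquet file.'''
    pf = pq.ParquetFile(os.path.join(parquet_dir, f'{uniqueIdentifier}_samples_merged.parquet'))
    print(f"row groups: {pf.num_row_groups}, "
          f"rows: {pf.metadata.num_rows}, "
          f"columns: {pf.metadata.num_columns}")

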
def create_dataset(pf, parquet_path, index, chunk_size, n_rowgroups, rowgroup_edges):
    '''
    Picks rows from the parquet file according to the given index array.
    The created parquet file is segmented into row groups with a maximum size given by chunk_size.
    '''
    writer = None
    n_chunks = (len(index)-1)//chunk_size + 1
    for chunk in range(n_chunks):
        print(f"\r{chunk+1}/{n_chunks}", end="", flush=True)
        chunk_df = pd.DataFrame()
        # Handle the special case of the last chunk having fewer rows
        if (chunk+1)*chunk_size > len(index):
            index_chunks = index[chunk*chunk_size:]
        else:
            index_chunks = index[chunk*chunk_size:(chunk+1)*chunk_size]
        lower_edge = 0
        for rowgroup in range(n_rowgroups):
            upper_edge = rowgroup_edges[rowgroup]
            # Select the rows of this row group that belong to the current chunk
            rows_to_fetch = index_chunks[(index_chunks >= lower_edge) & (index_chunks < upper_edge)]
            if len(rows_to_fetch) > 0:
                table = pf.read_row_group(rowgroup)
                df = table.to_pandas()
                # Convert global row indices to positions local to this row group
                df = df.iloc[(rows_to_fetch - lower_edge)]
                chunk_df = pd.concat([chunk_df, df])
            lower_edge = upper_edge
        table = pa.Table.from_pandas(chunk_df)
        if writer is None:
            writer = pq.ParquetWriter(parquet_path, table.schema, compression='NONE')
        writer.write_table(table)
    writer.close()


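# Toy example (assumed numbers, illustration only) of the index bookkeeping in
# create_dataset(): the cumulative row counts in rowgroup_edges locate each
# shuffled global row index in a row group, and subtracting lower_edge turns it
# into a position local to that row group.
def _toy_index_mapping():
    rowgroup_edges = [4, 7, 10]          # three row groups with 4, 3 and 3 rows
    index_chunks = np.array([8, 2, 5])   # shuffled global row indices of one chunk
    lower_edge = 0
    for rowgroup, upper_edge in enumerate(rowgroup_edges):
        rows = index_chunks[(index_chunks >= lower_edge) & (index_chunks < upper_edge)]
        print(rowgroup, rows - lower_edge)  # prints 0 [2], 1 [1], 2 [1]
        lower_edge = upper_edge

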
def shuffle_and_chunk_parquet(parquet_dir, val_split, chunk_size, uniqueIdentifier):
    '''
    Splits the single merged parquet file into a training and a validation parquet file.
    The data contained in the resulting files is shuffled and segmented into chunks.
    '''
    pf = pq.ParquetFile(os.path.join(parquet_dir, f'{uniqueIdentifier}_samples_merged.parquet'))
    # Cumulative row counts per row group, used to map global row indices to row groups
    rowgroup_edges = []
    n_rows = 0
    n_rowgroups = pf.num_row_groups
    for i in range(n_rowgroups):
        n = pf.metadata.row_group(i).num_rows
        n_rows += n
        rowgroup_edges.append(n_rows)
    # Shuffle all rows, then split them: val_split is the fraction used for training
    index = np.arange(n_rows)
    np.random.shuffle(index)
    n_training_samples = int(n_rows*val_split)
    index_training = index[:n_training_samples]
    index_validation = index[n_training_samples:]
    print('\nCreating training dataset')
    create_dataset(
        pf,
        os.path.join(parquet_dir, f'{uniqueIdentifier}_training_samples.parquet'),
        index_training,
        chunk_size,
        n_rowgroups,
        rowgroup_edges)
    print('\nCreating validation dataset')
    create_dataset(
        pf,
        os.path.join(parquet_dir, f'{uniqueIdentifier}_validation_samples.parquet'),
        index_validation,
        chunk_size,
        n_rowgroups,
        rowgroup_edges)


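# Sketch (hypothetical consumer, not used by this script): since create_dataset()
# writes one table per chunk, the training and validation files can be streamed
# chunk by chunk instead of being loaded into memory at once.
def _iterate_chunks(parquet_path):
    '''Illustrative helper: yield one pandas DataFrame per row group / chunk.'''
    pf = pq.ParquetFile(parquet_path)
    for rowgroup in range(pf.num_row_groups):
        yield pf.read_row_group(rowgroup).to_pandas()

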
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Merge sampled TFlat root files into shuffled parquet datasets')
    parser.add_argument(
        '--root_dir',
        dest='root_dir',
        type=str,
        help='Path to the directory where the sampled root files are stored'
    )
    parser.add_argument(
        '--parquet_dir',
        dest='parquet_dir',
        type=str,
        help='Path to the directory where the parquet files are saved to'
    )
    parser.add_argument(  # input parser
        '--uniqueIdentifier',
        metavar='uniqueIdentifier',
        dest='uniqueIdentifier',
        type=str,
        default="TFlaT_MC16rd_light_2601_hyperion",
        help='Name of both the config .yaml to be used and the produced weightfile'
    )
    args = parser.parse_args()
    root_dir = args.root_dir
    parquet_dir = args.parquet_dir
    uniqueIdentifier = args.uniqueIdentifier
    os.makedirs(parquet_dir, exist_ok=True)

    # The config is expected to provide train_valid_fraction, chunk_size and parameters/mask_value
    config = utils.load_config(uniqueIdentifier)
    val_split = config['train_valid_fraction']
    chunk_size = config['chunk_size']
    mask_value = config['parameters']['mask_value']

    merge_root_to_parquet(
        root_dir=root_dir,
        parquet_dir=parquet_dir,
        mask_value=mask_value,
        uniqueIdentifier=uniqueIdentifier
    )
    shuffle_and_chunk_parquet(
        parquet_dir=parquet_dir,
        val_split=val_split,
        chunk_size=chunk_size,
        uniqueIdentifier=uniqueIdentifier
    )
    print('\nDone!')
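
# Example invocation (the paths below are placeholders):
#   python3 merge_samples_to_parquet.py \
#       --root_dir /path/to/sampled_root_files \
#       --parquet_dir /path/to/parquet_output \
#       --uniqueIdentifier TFlaT_MC16rd_light_2601_hyperion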