Belle II Software development
preprocess.py
import re
import random
import string
import awkward as ak
import numpy as np
import pandas as pd
from collections import defaultdict

from smartBKG import PREPROC_CONFIG, TOKENIZE_DICT, LIST_FIELDS


def check_status_bit(status_bit):
    """
    Check whether the corresponding particle is usable according to its status bit,
    i.e. not virtual, not initial, and not an ISR or FSR photon.

    Arguments:
        status_bit (int): status flags of the particle encoded as a bit field.
            More details in `mdst/dataobjects/include/MCParticle.h`.

    Returns:
        bool: whether the conditions are satisfied (not an unusable particle).
    """
    return (
        (status_bit & 1 << 4 == 0) &  # IsVirtual
        (status_bit & 1 << 5 == 0) &  # Initial
        (status_bit & 1 << 6 == 0) &  # ISRPhoton
        (status_bit & 1 << 7 == 0)    # FSRPhoton
    )
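
# Illustrative check (not part of the original module): a status word with
# none of the veto bits set passes, while one with the ISRPhoton flag
# (bit 6) set is rejected.
# >>> check_status_bit(0)
# True
# >>> check_status_bit(1 << 6)
# False
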
def load_particle_list(mcplist, **meta_kwargs):
    """
    Collect variables from the MC particle list.

    Arguments:
        mcplist (Belle2.PyStoreArray("MCParticles")): MC particle list in the Belle II data store.
        meta_kwargs: extra event-level variables that will be copied through the particle list.

    Returns:
        pandas dataframe: particle list containing all the necessary information.
    """
    particle_dict = defaultdict(list)
    root_prodTime = defaultdict(list)
    # Create particle features
    for mcp in mcplist:
        prodTime = mcp.getProductionTime()
        # Collect indices for graph building
        arrayIndex = mcp.getArrayIndex()
        mother = mcp.getMother()
        if mother:
            motherArrayIndex = mother.getArrayIndex()
            # pass down the production time of the root particle for the jitter correction
            root_prodTime[arrayIndex] = root_prodTime[motherArrayIndex]
            if mother.isVirtual():
                motherArrayIndex = arrayIndex
        else:
            motherArrayIndex = arrayIndex
            # record the production time of the root particle for the jitter correction
            root_prodTime[arrayIndex] = prodTime

        if mcp.isPrimaryParticle() and check_status_bit(mcp.getStatus()):
            four_vec = mcp.get4Vector()
            prod_vec = mcp.getProductionVertex()
            # indices
            particle_dict['arrayIndex'].append(arrayIndex)
            # features
            particle_dict['PDG'].append(mcp.getPDG())
            particle_dict['mass'].append(mcp.getMass())
            particle_dict['charge'].append(mcp.getCharge())
            particle_dict['energy'].append(mcp.getEnergy())
            particle_dict['prodTime'].append(prodTime - root_prodTime[arrayIndex])
            particle_dict['x'].append(prod_vec.x())
            particle_dict['y'].append(prod_vec.y())
            particle_dict['z'].append(prod_vec.z())
            particle_dict['px'].append(four_vec.Px())
            particle_dict['py'].append(four_vec.Py())
            particle_dict['pz'].append(four_vec.Pz())
            particle_dict['motherIndex'].append(motherArrayIndex)
    particle_dict.update(meta_kwargs)
    return pd.DataFrame(particle_dict)
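
# Usage sketch (illustrative, inside the event() method of a basf2 Python
# module; `self.evt_num` and the label flag are hypothetical event-level
# variables passed through via meta_kwargs):
#
#   from ROOT import Belle2
#   mcplist = Belle2.PyStoreArray("MCParticles")
#   df = load_particle_list(mcplist, evtNum=self.evt_num, label=True)
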
def ak_from_df(
    df,
    decorr_df=None,
    columns=None,
    missing_values=False,
    convert_to_32=True,
):
    """
    Load a pandas dataframe into an awkward array.

    Particle-level quantities will be lists of variable length, the other
    variables are assumed to be event-level. Grouping will be done based on the
    `evtNum` column.

    Arguments:
        df (pandas dataframe): particle-level information.
        decorr_df (pandas dataframe): event-level information.
        columns (list): read only the listed columns (None for all).
        missing_values (bool): if False, assume there are no missing values in the particle
            lists and drop the masks. For the event-level quantities, replace missing
            values with nan. Avoiding option types might speed up the subsequent processing.
        convert_to_32 (bool): convert int64 to int32 and float64 to float32.

    Returns:
        Awkward array with particle quantities as lists and event-level quantities as flat arrays.

    Example:
        >>> import pandas as pd
        >>> df = pd.DataFrame({
        ...     "x": [1, 2, 3, 4],
        ...     "y": [5, 6, 7, 8],
        ...     "label": [True, True, False, False],
        ...     "evtNum": [0, 0, 1, 1],
        ... })
        >>> ak_array = ak_from_df(df)
        >>> ak_array.particles.x
        <Array [[1, 2], [3, 4]] type='2 * var * int32'>
        >>> ak_array.label
        <Array [True, False] type='2 * bool'>
    """
    # to group by events (unflatten), we need to find the particle count for each event
    evt_nums = df.evtNum.values
    df = df.set_index('evtNum')
    if decorr_df is not None:
        decorr_df = decorr_df.set_index('evtNum')
        df = df.join(other=decorr_df, on='evtNum', how='inner')
        df = df.reset_index().set_index(['label', 'evtNum', 'arrayIndex'])
    unique, indices, counts = np.unique(evt_nums, return_index=True, return_counts=True)
    # reverse the sorting that np.unique did
    counts = counts[np.argsort(indices)]
    ak_array = ak.unflatten(ak.Array(df.to_records()), counts)

    out = {"particles": {}}
    for field in ak_array.fields:
        if field not in LIST_FIELDS:
            # for event-level quantities we assume that the first entry is equal to all entries
            out[field] = ak_array[field][:, 0]
            if convert_to_32:
                out[field] = values_as_32(out[field])
        else:
            # particle-level quantities stay lists
            out["particles"][field] = ak_array[field]
            if convert_to_32:
                out["particles"][field] = values_as_32(out["particles"][field])
    out["particles"] = ak.zip(out["particles"])
    out = ak.zip(out, depth_limit=1)

    if not missing_values:
        out = remove_masks(out)

    return out
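
# Illustration (not part of the original module) of the grouping trick used
# above: np.unique returns sorted values, so the counts are put back into
# first-appearance order before unflattening.
# >>> evt = np.array([7, 7, 3, 3, 3])
# >>> _, indices, counts = np.unique(evt, return_index=True, return_counts=True)
# >>> counts[np.argsort(indices)]
# array([2, 3])
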
def values_as_32(array):
    """
    Convert int64 to int32 and float64 to float32 in the given array for processing in PyTorch.

    Arguments:
        array (awkward array): any.

    Returns:
        awkward array: the converted array.
    """
    ak_type = ak.type(array.layout)
    while not isinstance(ak_type, ak.types.PrimitiveType):
        ak_type = ak_type.type
    dtype = ak_type.dtype
    if dtype == "int64":
        return ak.values_astype(array, np.int32)
    if dtype == "float64":
        return ak.values_astype(array, np.float32)
    return array
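
# Minimal illustration (assumes awkward 1.x, whose layout/type API this
# module uses):
# >>> values_as_32(ak.Array([[1.0, 2.0], [3.0]]))
# <Array [[1, 2], [3]] type='2 * var * float32'>
# >>> values_as_32(ak.Array([True, False]))  # other dtypes pass through
# <Array [True, False] type='2 * bool'>
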
def remove_masks(array):
    """
    Drop masks for particle-level quantities and replace missing values by nan for event-level quantities.

    Arguments:
        array (awkward array): any.

    Returns:
        awkward array: the processed array.
    """
    out = array[:]
    for field in out.fields:
        if field == "particles":
            continue
        out[field] = ak.fill_none(out[field], np.nan)
    for field in out.particles.fields:
        masked = out.particles[field].layout
        if not isinstance(masked, ak.layout.ListOffsetArray64):
            raise TypeError(
                "Wrong type - this method only works with ListOffsetArray for the particle fields"
            )
        out["particles", field] = ak.Array(
            # explicitly construct a ListOffsetArray (without the mask)
            ak.layout.ListOffsetArray64(masked.offsets, masked.content)
        )
    return out
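
# Minimal illustration (assumes awkward 1.x; `weight` is a hypothetical
# event-level field): a missing event-level value becomes nan, while particle
# lists keep their offsets but lose any option-type mask.
# >>> arr = ak.zip({
# ...     "weight": ak.Array([1.0, None]),
# ...     "particles": ak.zip({"x": ak.Array([[1.0], [2.0, 3.0]])}),
# ... }, depth_limit=1)
# >>> remove_masks(arr).weight
# <Array [1, nan] type='2 * float64'>
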
def mapped_mother_index_flat(array_indices_flat, mother_indices_flat, total, sizes, dict_size):
    """
    Map mother indices for particle arrays to handle removed mothers.

    Arguments:
        array_indices_flat (array): flat array indices of the retained particles from the MC particle list.
        mother_indices_flat (array): flat array indices of the mother particles of the retained particles.
        total (int): total number of particles in all the events.
        sizes (array): number of particles in each event.
        dict_size (int): maximum number of different indices.

    Returns:
        array: flat mother indices after correction.
    """
    out = np.empty(total, dtype=np.int32)
    i = 0
    idx_dict = np.empty(dict_size, dtype=np.int32)
    start = 0
    for size in sizes:
        stop = start + size
        array_indices = array_indices_flat[start:stop]
        mother_indices = mother_indices_flat[start:stop]
        # fill idx_dict
        for original_index in mother_indices:
            # default -1 (will represent mothers that have been removed)
            idx_dict[original_index] = -1
        for mapped_index, original_index in enumerate(array_indices):
            # indices of still existing mothers
            idx_dict[original_index] = mapped_index
        # remap
        for mother_index in mother_indices:
            out[i] = idx_dict[mother_index]
            i += 1
        start = stop
    return out
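
# Worked example (illustrative): particles with original indices 0, 2, 3 are
# retained and their mothers are 0, 0, 1. Particle 1 was removed, so its
# daughter gets mother index -1; the others are remapped to positions in the
# retained list.
# >>> mapped_mother_index_flat(
# ...     np.array([0, 2, 3]), np.array([0, 0, 1]),
# ...     total=3, sizes=np.array([3]), dict_size=4,
# ... )
# array([ 0,  0, -1], dtype=int32)
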
def mapped_mother_index(array_indices, mother_indices):
    """
    Map mother indices for particle arrays to handle removed mothers for awkward arrays.

    Arguments:
        array_indices (awkward array): array indices of the retained particles from the MC particle list.
        mother_indices (awkward array): array indices of the mother particles of the retained particles.

    Returns:
        awkward array: mother indices after correction.
    """
    max_dict_index = max(ak.max(array_indices), ak.max(mother_indices))
    dict_size = max_dict_index + 1
    flat = mapped_mother_index_flat(
        ak.to_numpy(ak.flatten(ak.fill_none(array_indices, -1))),
        ak.to_numpy(ak.flatten(ak.fill_none(mother_indices, -1))),
        sizes=ak.num(array_indices),
        total=ak.sum(ak.num(array_indices)),
        dict_size=dict_size,
    )
    return ak.unflatten(flat, ak.num(array_indices))
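
# Jagged version of the worked example above (illustrative):
# >>> mapped_mother_index(
# ...     ak.Array([[0, 2, 3], [0, 1]]),
# ...     ak.Array([[0, 0, 1], [0, 0]]),
# ... )
# <Array [[0, 0, -1], [0, 0]] type='2 * var * int32'>
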
def map_np(array, mapping):
    """
    Map the values of a flat array to tokens, e.g. PDG IDs via TOKENIZE_DICT.

    Arguments:
        array (np.ndarray): values to be mapped.
        mapping (dict): mapping from original values to tokens.

    Returns:
        np.ndarray: array after the mapping.
    """
    unique, inv = np.unique(array, return_inverse=True)
    np_mapping = np.array([mapping[x] for x in unique])
    return np_mapping[inv]
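
# Minimal illustration (hypothetical two-entry mapping):
# >>> map_np(np.array([211, 22, 211]), {22: 0, 211: 1})
# array([1, 0, 1])
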
def mapped_pdg_id(pdg):
    """
    Map PDG IDs to tokens for awkward arrays.

    Arguments:
        pdg (awkward array): PDG IDs.

    Returns:
        awkward array: awkward array after PDG ID mapping.
    """
    return ak.unflatten(
        map_np(ak.to_numpy(ak.flatten(pdg)), TOKENIZE_DICT), ak.num(pdg)
    )
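
# Tokenization sketch (the concrete token values depend on TOKENIZE_DICT,
# which comes from the smartBKG package):
#
#   tokens = mapped_pdg_id(array.particles.PDG)
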
def evaluate_query(array, query):
    """
    Evaluate a query on the awkward array, `pd.DataFrame.eval`-style.
    A callable that takes an awkward array and returns an awkward array mask can also be passed.

    Arguments:
        array (awkward array): any.
        query (list of str or callable): queries for particle selection; strings are joined with `&`.

    Returns:
        awkward array: boolean mask used for particle selection.
    """
    if callable(query):
        return query(array)

    # merge event-level and particle-level quantities
    # such that queries can directly access both,
    # e.g. "x > 5" instead of "particles.x > 5"
    array_dict = {
        **dict(zip(array.fields, ak.unzip(array))),
        **dict(zip(array.particles.fields, ak.unzip(array.particles)))
    }

    # replace quoted fieldnames - e.g. `nParticlesInList(B0:feiHadronic)` - by random strings
    # such that ak.numexpr.evaluate can handle them
    joined_query = " & ".join(f"({q})" for q in query)
    quoted_fieldnames = set(re.findall("`[^`]*`", joined_query))
    name_mapping = {k: "".join(random.choices(string.ascii_lowercase, k=20)) for k in quoted_fieldnames}
    quoted_fields = {name_mapping[field]: array_dict[field.replace("`", "")] for field in quoted_fieldnames}
    for field, rnd_name in name_mapping.items():
        joined_query = joined_query.replace(field, rnd_name)

    return ak.numexpr.evaluate(joined_query, local_dict={**array_dict, **quoted_fields})
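
# Minimal illustration (requires the numexpr package; the field names are
# hypothetical):
# >>> arr = ak.zip({
# ...     "label": ak.Array([True, False]),
# ...     "particles": ak.zip({"energy": ak.Array([[0.2, 0.05], [1.0]])}),
# ... }, depth_limit=1)
# >>> evaluate_query(arr, ["energy > 0.1"])
# <Array [[True, False], [True]] type='2 * var * bool'>
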
def preprocessed(df, decorr_df=None, particle_selection=PREPROC_CONFIG['cuts']):
    """
    Preprocess the input dataframe and return an awkward array that is ready for graph building.

    Arguments:
        df (pandas dataframe): containing particle-level information.
        decorr_df (pandas dataframe): containing event-level information.
        particle_selection (list of str): queries for particle selection, by default the cuts in PREPROC_CONFIG.

    Returns:
        awkward array: awkward array after preprocessing.
    """
    array = ak_from_df(df, decorr_df)[:]
    array["particles"] = array.particles[evaluate_query(array, particle_selection)]
    array["particles", "PDG"] = mapped_pdg_id(array.particles.PDG)
    array["particles", "motherIndex"] = mapped_mother_index(
        array.particles.arrayIndex,
        array.particles.motherIndex,
    )
    return array
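
# End-to-end usage sketch (illustrative; the parquet file names are
# placeholders for the outputs of the steering script):
#
#   df = pd.read_parquet("particles.parquet")
#   decorr_df = pd.read_parquet("event_level.parquet")
#   array = preprocessed(df, decorr_df)
#   array.particles.PDG          # tokenized PDG codes
#   array.particles.motherIndex  # remapped; -1 marks a removed mother
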