Belle II Software development
geometric_datasets.py
import itertools
from pathlib import Path
import numpy as np
import torch
from .tree_utils import masses_to_classes
from .dataset_utils import populate_avail_samples, preload_root_data
from .edge_features import compute_edge_features
from .normalize_features import normalize_features
from torch_geometric.data import Data, InMemoryDataset
import uproot


def _preload(self):
    """
    Loads the ROOT data and builds the list of available samples.

    Returns:
        int: Number of available samples.
    """
    self.x_files = sorted(self.root.glob("**/*.root"))

    # Select the first N files (useful for testing)
    if self.n_files is not None:
        if self.n_files > len(self.x_files):
            print(
                f"WARNING: n_files specified ({self.n_files}) is greater than the number of files found in the given path ({len(self.x_files)})"
            )

        self.x_files = self.x_files[: self.n_files]

    if len(self.x_files) == 0:
        raise RuntimeError(f"No files found in {self.root}")

    # Save the features
    with uproot.open(self.x_files[0])["Tree"] as t:
        self.features = [f for f in t.keys() if f.startswith("feat_")]

        self.B_reco = int(t["isB"].array(library="np")[0])
        assert self.B_reco in [0, 1, 2], "B_reco should be 0, 1 or 2, something went wrong"

    self.discarded = [
        f for f in self.features if f[f.find("_") + 1:] not in self.node_features
    ]
    self.features = [
        f"feat_{f}" for f in self.node_features if f"feat_{f}" in self.features
    ]
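    # Illustration (hypothetical names): with node_features = ["px", "py"],
    # self.features becomes ["feat_px", "feat_py"] and every other feat_*
    # branch ends up in self.discarded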

    print(f"Input node features: {self.features}")
    print(f"Discarded node features: {self.discarded}")

    self.edge_features = [f"edge_{f}" for f in self.edge_features]
    self.global_features = (
        [f"glob_{f}" for f in self.global_features] if self.global_features else []
    )
    print(f"Input edge features: {self.edge_features}")
    print(f"Input global features: {self.global_features}")

    self.x, self.y = preload_root_data(
        self.x_files,
        self.features,
        self.discarded,
    )

    self.avail_samples = populate_avail_samples(
        self.x,
        self.y,
        self.B_reco,
    )

    # Select a subset of available samples if requested
    if self.samples and self.samples < len(self.avail_samples):
        print(f"Selecting random subset of {self.samples} samples")
        self.avail_samples = [
            self.avail_samples[i]
            for i in np.random.choice(
                len(self.avail_samples), self.samples, replace=False
            )
        ]
    elif self.samples and (self.samples >= len(self.avail_samples)):
        print(
            f"WARNING: No. of samples specified ({self.samples}) is not smaller than the number of samples loaded ({len(self.avail_samples)})"
        )

    return len(self.avail_samples)


def _process_graph(self, idx):
    """
    Actually builds the graph object.

    Args:
        idx (int): Index of training example to be processed.

    Returns:
        torch_geometric.data.Data: Graph object to be used in training.
    """
    file_id, evt, p_index = self.avail_samples[idx]

    x_item = self.x[file_id]
    y_item = self.y[file_id][p_index]

    evt_b_index = x_item["b_index"][evt]
    evt_leaves = x_item["leaves"][evt]
    evt_primary = x_item["primary"][evt]

    y_leaves = y_item["LCA_leaves"][evt]
    # Use this to correctly reshape LCA (might be able to just use shape of y_leaves?)
    n_LCA = y_item["n_LCA"][evt]

    # Get the rows of the X inputs to fetch
    # This is a boolean numpy array
    x_rows = (evt_b_index != -1) if not self.B_reco else evt_b_index == int(p_index)
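    # (If B_reco is 0 we keep every matched particle; otherwise we keep only
    # the particles assigned to this B candidate)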

    # Find the unmatched particles
    unmatched_rows = evt_b_index == -1

    if np.any(unmatched_rows) and self.B_reco:
        # Create a random boolean array the same size as the number of leaves
        rand_mask = np.random.choice(a=[False, True], size=unmatched_rows.size)
        # AND the mask with the unmatched leaves
        # This selects a random subset of the unmatched leaves
        unmatched_rows = np.logical_and(unmatched_rows, rand_mask)

    # Add the unmatched rows to the current decay's rows
    x_rows = np.logical_or(x_rows, unmatched_rows)

    # Here we actually load the data

    # Initialise the event's X array
    x = np.empty((x_rows.sum(), len(self.features)))
    x_dis = np.empty((x_rows.sum(), len(self.discarded)))

    # And populate it
    for i, feat in enumerate(self.features):
        x[:, i] = x_item["features"][feat][evt][x_rows]
    for i, feat in enumerate(self.discarded):
        x_dis[:, i] = x_item["discarded"][feat][evt][x_rows]

    # Same for edge and global features
    x_edges = (
        compute_edge_features(
            self.edge_features,
            self.features + self.discarded,
            np.concatenate([x, x_dis], axis=1),
        )
        if self.edge_features != []
        else []
    )
    x_global = (
        np.array(
            [
                [
                    x_item["global"][feat + f"_{p_index}"][evt]
                    for feat in self.global_features
                ]
            ]
        )
        if self.global_features != []
        else []
    )

    x_leaves = evt_leaves[x_rows]

    # Set NaNs to zero; this is a surrogate value and may change in the future
    np.nan_to_num(x, copy=False)
    np.nan_to_num(x_edges, copy=False)
    np.nan_to_num(x_global, copy=False)

    # Normalize any features that should be
    if self.normalize is not None:
        normalize_features(
            self.normalize,
            self.features,
            x,
            self.edge_features,
            x_edges,
            self.global_features,
            x_global,
        )

    # Reorder LCA

    # Get LCA indices in the order that the leaves appear in reconstructed particles
    # Secondaries aren't in the LCA leaves list so they get a 0
    locs = np.array(
        [
            np.where(y_leaves == i)[0].item() if (i in y_leaves) else 0
            for i in x_leaves
        ]
    )
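    # Worked example: if x_leaves = [5, 3, 9] and y_leaves = [3, 5],
    # then locs = [1, 0, 0] (leaf 9 is not in y_leaves, so it defaults to 0)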

    # Get the LCA in the correct subset order
    # If we're not allowing secondaries this is all we need
    # If we are, this will contain duplicates (since secondary locs are set to 0)
    # We can't load the first locs directly (i.e. y_item[locs, :]) because locs is (intentionally) unsorted
    y_edge = y_item["LCA"][evt].reshape((n_LCA, n_LCA)).astype(int)
    # Get the true mcPDG of the FSPs
    y_mass = masses_to_classes(x_item["mc_pdg"][evt][x_rows])

    # Get the specified rows/columns; this inserts dummy rows/columns for secondaries
    y_edge = y_edge[locs, :][:, locs]
    # if self.allow_secondaries:
    # Set everything that's not primary (unmatched and secondaries) rows/columns to 0
    # Note we only consider the subset of leaves that made it into x_rows
    y_edge = np.where(evt_primary[x_rows], y_edge, 0)  # Zeroes the columns
    y_edge = np.where(evt_primary[x_rows][:, None], y_edge, 0)  # Zeroes the rows

    # Set the diagonal to -1 (actually not necessary but ok...)
    np.fill_diagonal(y_edge, -1)

    n_nodes = x.shape[0]

    # Target edge tensor: shape [E]
    edge_y = torch.tensor(
        y_edge[np.eye(n_nodes) == 0],
        dtype=torch.long,
    )
    # Fill tensor with edge indices: shape [N*(N-1), 2] == [E, 2]
    edge_index = torch.tensor(
        list(itertools.permutations(range(n_nodes), 2)),
        dtype=torch.long,
    )
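    # Note: the off-diagonal mask above flattens y_edge in row-major order,
    # which matches the (row, col) pairs produced by itertools.permutations,
    # so edge_y[i] is the label of the edge in edge_index[i]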

    # Target global tensor: shape [B, F_g]
    u_y = torch.tensor([[1]], dtype=torch.float)

    # Target node tensor: shape [N]
    x_y = torch.tensor(y_mass, dtype=torch.long)

    g = Data(
        x=torch.tensor(x, dtype=torch.float),
        edge_index=edge_index.t().contiguous(),
        edge_attr=torch.tensor(x_edges, dtype=torch.float),
        u=torch.tensor(x_global, dtype=torch.float),
        x_y=x_y,
        edge_y=edge_y,
        u_y=u_y,
    )

    return g


class GraphDataSet(InMemoryDataset):
    """
    Dataset handler for converting Belle II data to a PyTorch Geometric InMemoryDataset.

    The ROOT format expects the tree in every file to be named ``Tree``,
    and all node features to have the format ``feat_FEATNAME``.

    .. note:: This expects the files under root to have the structure ``root/**/<file_name>.root``
        where the root path is different for train and val.
        The ``**/`` is to handle subdirectories, e.g. ``sub00``.

    Args:
        root (str): Path to ROOT files.
        n_files (int): Load only ``n_files`` files.
        samples (int): Load only ``samples`` events.
        features (list): List of node feature names.
        edge_features (list): List of edge feature names.
        global_features (list): List of global feature names.
        normalize (bool): Whether to normalize input features.
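
    Example (a minimal sketch; the path and feature names below are
    placeholders, not values shipped with the package)::

        from torch_geometric.loader import DataLoader

        dataset = GraphDataSet(
            root="path/to/train",
            features=["px", "py", "pz"],
            edge_features=["doca"],
        )
        loader = DataLoader(dataset, batch_size=32, shuffle=True)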
    """

    def __init__(
        self,
        root,
        n_files=None,
        samples=None,
        features=[],
        edge_features=[],
        global_features=[],
        normalize=None,
        **kwargs,
    ):
        """
        Initialization.
        """
        assert isinstance(
            features, list
        ), f'Argument "features" must be a list and not {type(features)}'
        assert len(features) > 0, "You need to use at least one node feature"

        self.root = Path(root)
        self.normalize = normalize

        self.n_files = n_files
        self.node_features = features
        self.edge_features = edge_features
        self.global_features = global_features
        self.samples = samples

        # Delete any leftover processed files
        file_path = Path(self.root, "processed")
        files = list(file_path.glob("*.pt"))
        for f in files:
            f.unlink(missing_ok=True)

        # Needs to be called after all attributes have been assigned
        super().__init__(root, None, None, None)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """
        Name of the processed file.
        """
        return ["processed_data.pt"]

    def process(self):
        """
        Processes the data to create graph objects and stores them in ``root/processed/processed_data.pt``,
        where the root path is different for train and val.

        Called internally by PyTorch Geometric.
        """
        num_samples = _preload(self)
        data_list = [_process_graph(self, i) for i in range(num_samples)]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

        del self.x, self.y, self.avail_samples, data_list, data, slices