# Belle II Software  light-2403-persian
# geometric_datasets.py
import itertools
from pathlib import Path

import numpy as np
import torch
import uproot
from torch_geometric.data import Data, InMemoryDataset

from .dataset_utils import populate_avail_samples, preload_root_data
from .edge_features import compute_edge_features
from .normalize_features import normalize_features
from .tree_utils import masses_to_classes


def _preload(self):
    """
    Loads the ROOT data into memory, builds the list of available samples
    and returns their number.
    """
    self.x_files = sorted(self.root.glob("**/*.root"))

    # Select the first N files (useful for testing)
    if self.n_files is not None:
        if self.n_files > len(self.x_files):
            print(
                f"WARNING: n_files specified ({self.n_files}) is greater than files in path given ({len(self.x_files)})"
            )

        self.x_files = self.x_files[: self.n_files]

    if len(self.x_files) == 0:
        raise RuntimeError(f"No files found in {self.root}")

    # Inspect the first file to find the available node features
    with uproot.open(self.x_files[0])["Tree"] as t:
        self.features = [f for f in t.keys() if f.startswith("feat_")]
        self.B_reco = int(t["isB"].array(library="np")[0])
        assert self.B_reco in [0, 1, 2], "B_reco should be 0, 1 or 2, something went wrong"

    # Split the available node features into requested and discarded
    self.discarded = [
        f for f in self.features if f[f.find("_") + 1:] not in self.node_features
    ]
    self.features = [
        f"feat_{f}" for f in self.node_features if f"feat_{f}" in self.features
    ]
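    # Illustrative example (hypothetical branch names): if the file contains
    # feat_px, feat_py and feat_dz while node_features = ["px", "py"], then
    # self.features == ["feat_px", "feat_py"] and self.discarded == ["feat_dz"]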

    print(f"Input node features: {self.features}")
    print(f"Discarded node features: {self.discarded}")

    # Prepend the branch prefixes to the edge and global feature names
    self.edge_features = [f"edge_{f}" for f in self.edge_features]
    self.global_features = [f"glob_{f}" for f in self.global_features] if self.global_features else []
    print(f"Input edge features: {self.edge_features}")
    print(f"Input global features: {self.global_features}")

    self.x, self.y = preload_root_data(
        self.x_files,
        self.features,
        self.discarded,
    )

    self.avail_samples = populate_avail_samples(
        self.x,
        self.y,
        self.B_reco,
    )
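    # Each entry of avail_samples is a (file_id, evt, p_index) tuple that
    # identifies one decay tree to be turned into a graph by _process_graph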

    # Select a random subset of the available samples if requested
    if self.samples and self.samples < len(self.avail_samples):
        print(f"Selecting random subset of {self.samples} samples")
        self.avail_samples = [
            self.avail_samples[i]
            for i in np.random.choice(
                len(self.avail_samples), self.samples, replace=False
            )
        ]
    elif self.samples and (self.samples >= len(self.avail_samples)):
        print(
            f"WARNING: No. samples specified ({self.samples}) exceeds number of samples loaded ({len(self.avail_samples)})"
        )

    return len(self.avail_samples)


def _process_graph(self, idx):
    """
    Actually builds the graph object.

    Args:
        idx (int): Index of training example to be processed.

    Returns:
        torch_geometric.data.Data: Graph object to be used in training.
    """
    file_id, evt, p_index = self.avail_samples[idx]

    x_item = self.x[file_id]
    y_item = self.y[file_id][p_index]

    evt_b_index = x_item["b_index"][evt]
    evt_leaves = x_item["leaves"][evt]
    evt_primary = x_item["primary"][evt]

    y_leaves = y_item["LCA_leaves"][evt]
    # Use this to correctly reshape the LCA (might be able to just use the shape of y_leaves?)
    n_LCA = y_item["n_LCA"][evt]

    # Get the rows of the X inputs to fetch
    # This is a boolean numpy array
    x_rows = (evt_b_index != -1) if not self.B_reco else evt_b_index == int(p_index)

    # Find the unmatched particles
    unmatched_rows = evt_b_index == -1

    if np.any(unmatched_rows) and self.B_reco:
        # Create a random boolean array the same size as the number of leaves
        rand_mask = np.random.choice(a=[False, True], size=unmatched_rows.size)
        # AND the mask with the unmatched leaves
        # This selects a random subset of the unmatched leaves
        unmatched_rows = np.logical_and(unmatched_rows, rand_mask)

        # Add the selected unmatched rows to the current decay's rows
        x_rows = np.logical_or(x_rows, unmatched_rows)
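        # Illustrative draw: with unmatched_rows = [F, T, F, T], a sampled
        # rand_mask = [T, T, F, F] keeps only the second particle, so on
        # average half of the unmatched particles are attached to this decay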

    # Here we actually load the data

    # Initialise the event's X array
    x = np.empty((x_rows.sum(), len(self.features)))
    x_dis = np.empty((x_rows.sum(), len(self.discarded)))

    # And populate it
    for i, feat in enumerate(self.features):
        x[:, i] = x_item["features"][feat][evt][x_rows]
    for i, feat in enumerate(self.discarded):
        x_dis[:, i] = x_item["discarded"][feat][evt][x_rows]

    # Same for the edge and global features
    x_edges = (
        compute_edge_features(
            self.edge_features,
            self.features + self.discarded,
            np.concatenate([x, x_dis], axis=1),
        )
        if self.edge_features != []
        else []
    )
    x_global = (
        np.array(
            [
                [
                    x_item["global"][feat + f"_{p_index}"][evt]
                    for feat in self.global_features
                ]
            ]
        )
        if self.global_features != []
        else []
    )

    x_leaves = evt_leaves[x_rows]

    # Set NaNs to zero; this is a surrogate value and may change in future
    np.nan_to_num(x, copy=False)
    np.nan_to_num(x_edges, copy=False)
    np.nan_to_num(x_global, copy=False)

    # Normalize any features that should be normalized
    if self.normalize is not None:
        normalize_features(
            self.normalize,
            self.features,
            x,
            self.edge_features,
            x_edges,
            self.global_features,
            x_global,
        )

    # Reorder the LCA

    # Get the LCA indices in the order that the leaves appear in the reconstructed particles
    # Secondaries aren't in the LCA leaves list, so they get a 0
    locs = np.array(
        [
            np.where(y_leaves == i)[0].item() if (i in y_leaves) else 0
            for i in x_leaves
        ]
    )
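    # Illustrative values only: if y_leaves = [11, 13, 17] and
    # x_leaves = [13, 11, 42], then locs = [1, 0, 0]; the secondary leaf 42
    # gets a dummy index 0 and its rows/columns are zeroed out below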

    # Get the LCA in the correct subset order
    # If we're not allowing secondaries this is all we need
    # If we are, it will contain duplicates (since secondary locs are set to 0)
    # We can't load the first locs directly (i.e. y_item[locs, :]) because locs is (intentionally) unsorted
    y_edge = y_item["LCA"][evt].reshape((n_LCA, n_LCA)).astype(int)
    # Get the true mcPDG of FSPs
    y_mass = masses_to_classes(x_item["mc_pdg"][evt][x_rows])

    # Get the specified rows/cols, this inserts dummy rows/cols for secondaries
    y_edge = y_edge[locs, :][:, locs]
    # Set everything that's not primary (unmatched and secondaries) rows/cols to 0
    # Note we only consider the subset of leaves that made it into x_rows
    y_edge = np.where(evt_primary[x_rows], y_edge, 0)  # Set the rows
    y_edge = np.where(evt_primary[x_rows][:, None], y_edge, 0)  # Set the columns

    # Set the diagonal to -1 (not strictly necessary, but explicit)
    np.fill_diagonal(y_edge, -1)

    n_nodes = x.shape[0]

    # Target edge tensor: shape [E]
    edge_y = torch.tensor(
        y_edge[np.eye(n_nodes) == 0],
        dtype=torch.long,
    )
    # Fill the tensor with edge indices: shape [N*(N-1), 2] == [E, 2]
    edge_index = torch.tensor(
        list(itertools.permutations(range(n_nodes), 2)),
        dtype=torch.long,
    )
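    # The two orderings agree: for n_nodes = 3 the off-diagonal mask flattens
    # row-major to entries (0,1), (0,2), (1,0), (1,2), (2,0), (2,1), which is
    # exactly the order produced by itertools.permutations(range(3), 2)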

    # Target global tensor: shape [B, F_g]
    u_y = torch.tensor(
        [[1]], dtype=torch.float
    )

    # Target node tensor: shape [N]
    x_y = torch.tensor(y_mass, dtype=torch.long)

    g = Data(
        x=torch.tensor(x, dtype=torch.float),
        edge_index=edge_index.t().contiguous(),
        edge_attr=torch.tensor(x_edges, dtype=torch.float),
        u=torch.tensor(x_global, dtype=torch.float),
        x_y=x_y,
        edge_y=edge_y,
        u_y=u_y,
    )

    return g


class GraphDataSet(InMemoryDataset):
    """
    Dataset handler for converting Belle II data to a PyTorch Geometric InMemoryDataset.

    The ROOT format expects the tree in every file to be named ``Tree``,
    and all node features to have the format ``feat_FEATNAME``.

    .. note:: This expects the files under root to have the structure ``root/**/<file_name>.root``
        where the root path is different for train and val.
        The ``**/`` is to handle subdirectories, e.g. ``sub00``.

    Args:
        root (str): Path to ROOT files.
        n_files (int): Load only ``n_files`` files.
        samples (int): Load only ``samples`` events.
        features (list): List of node feature names.
        edge_features (list): List of edge feature names.
        global_features (list): List of global feature names.
        normalize (bool): Whether to normalize input features.
    """

    def __init__(
        self,
        root,
        n_files=None,
        samples=None,
        features=[],
        edge_features=[],
        global_features=[],
        normalize=None,
        **kwargs,
    ):
        """
        Initialization.
        """
        assert isinstance(
            features, list
        ), f'Argument "features" must be a list and not {type(features)}'
        assert len(features) > 0, "You need to use at least one node feature"

        self.root = Path(root)
        self.normalize = normalize
        self.n_files = n_files
        self.node_features = features
        self.edge_features = edge_features
        self.global_features = global_features
        self.samples = samples

        # Delete previously processed files, in case they are stale
        file_path = Path(self.root, "processed")
        files = list(file_path.glob("*.pt"))
        for f in files:
            f.unlink(missing_ok=True)

        # Needs to be called after having assigned all attributes
        super().__init__(root, None, None, None)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """
        Name of the processed file.
        """
        return ["processed_data.pt"]

    def process(self):
        """
        Processes the data to create graph objects and stores them in ``root/processed/processed_data.pt``,
        where the root path is different for train and val.

        Called internally by PyTorch Geometric.
        """
        num_samples = _preload(self)
        data_list = [_process_graph(self, i) for i in range(num_samples)]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

        del self.x, self.y, self.avail_samples, data_list, data, slices
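

if __name__ == "__main__":
    # Minimal usage sketch. The path and feature names below are hypothetical
    # placeholders: point root at a directory of ROOT files containing a tree
    # named "Tree" with feat_* (and optionally edge_*/glob_*) branches
    dataset = GraphDataSet(
        root="data/train",
        n_files=1,
        features=["px", "py", "pz"],
        edge_features=[],
        global_features=[],
    )
    print(f"Loaded {len(dataset)} graphs")
    print(dataset[0])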