import itertools
from pathlib import Path

import numpy as np
import torch
import uproot
from torch_geometric.data import Data, InMemoryDataset

from .dataset_utils import populate_avail_samples, preload_root_data
from .edge_features import compute_edge_features
from .normalize_features import normalize_features
from .tree_utils import masses_to_classes
# NOTE(review): extraction-damaged fragment of a `_preload` helper (original
# lines 24-96; the `def` line and many interior lines were lost, so the text
# below is NOT valid Python as-is). Comments only — no code token changed.
# Purpose (from visible code): scan ROOT files, discover node features from
# the first file, build feature-name lists, preload data, index samples.
24 Creates graph objects and stores them into a python list.
# Collect every *.root file under self.root (recursive glob), sorted for a
# deterministic ordering.
28 self.x_files = sorted(self.root.glob("**/*.root"))
# Optionally truncate the file list to self.n_files, warning when the
# request exceeds what is actually available.
31 if self.n_files
is not None:
32 if self.n_files > len(self.x_files):
34 f
"WARNING: n_files specified ({self.n_files}) is greater than files in path given ({len(self.x_files)})"
37 self.x_files = self.x_files[: self.n_files]
39 if len(self.x_files) == 0:
40 raise RuntimeError(f
"No files found in {self.root}")
# Peek at the first file's "Tree" to discover the node features (branches
# prefixed "feat_") and the B_reco flag stored in the "isB" branch.
43 with uproot.open(self.x_files[0])[
"Tree"]
as t:
45 self.features = [f
for f
in t.keys()
if f.startswith(
"feat_")]
47 self.B_reco = int(t[
"isB"].array(library=
"np")[0])
48 assert self.B_reco
in [0, 1, 2],
"B_reco should be 0, 1 or 2, something went wrong"
# Split discovered features into the requested set (self.node_features) and
# the discarded remainder. NOTE(review): `not x in y` should be the
# idiomatic `x not in y` (same result).
52 f
for f
in self.features
if not f[f.find(
"_") + 1:]
in self.node_features
55 f
"feat_{f}" for f
in self.node_features
if f
"feat_{f}" in self.features
58 print(f
"Input node features: {self.features}")
59 print(f
"Discarded node features: {self.discarded}")
# Prefix edge/global feature names to match the branch naming scheme
# ("edge_*", "glob_*"); global features may legitimately be empty.
62 self.edge_features = [f
"edge_{f}" for f
in self.edge_features]
64 self.global_features = [f
"glob_{f}" for f
in self.global_features]
if self.global_features
else []
65 print(f
"Input edge features: {self.edge_features}")
66 print(f
"Input global features: {self.global_features}")
# Preload file contents and enumerate usable samples. The argument lists
# were lost in extraction — see dataset_utils for the real signatures.
69 self.x, self.y = preload_root_data(
76 self.avail_samples = populate_avail_samples(
# Optionally down-sample to self.samples events without replacement; warn
# when more samples are requested than were loaded.
83 if self.samples
and self.samples < len(self.avail_samples):
84 print(f
"Selecting random subset of {self.samples} samples")
85 self.avail_samples = [
87 for i
in np.random.choice(
88 len(self.avail_samples), self.samples, replace=
False
91 elif self.samples
and (self.samples >= len(self.avail_samples)):
93 f
"WARNING: No. samples specified ({self.samples}) exceeds number of samples loaded ({len(self.avail_samples)})"
# Returns the number of available samples.
96 return len(self.avail_samples)
# NOTE(review): extraction-damaged fragment of `_process_graph` (original
# lines 99-249; many interior lines — including the normalize_features call
# arguments, the `locs` ordering computation, the `x_edges`/`x_global`
# assignment targets and the tail of the Data(...) constructor — were lost,
# so the text below is NOT valid Python as-is). Comments only.
99def _process_graph(self, idx):
101 Actually builds the graph object.
104 idx (int): Index of training example to be processed.
107 torch_geometric.data.Data: Graph object to be used in training.
# Resolve the sample to (file, event, B-candidate index).
110 file_id, evt, p_index = self.avail_samples[idx]
112 x_item = self.x[file_id]
113 y_item = self.y[file_id][p_index]
# Per-event arrays: B association, leaf ids, primary flags, LCA targets.
115 evt_b_index = x_item["b_index"][evt]
116 evt_leaves = x_item[
"leaves"][evt]
117 evt_primary = x_item[
"primary"][evt]
119 y_leaves = y_item[
"LCA_leaves"][evt]
121 n_LCA = y_item[
"n_LCA"][evt]
# Row selection: without B_reco keep all matched rows (b_index != -1);
# with B_reco keep only rows matched to this B candidate.
125 x_rows = (evt_b_index != -1)
if not self.B_reco
else evt_b_index == int(p_index)
128 unmatched_rows = evt_b_index == -1
# In B_reco mode, mix in a random ~half of the unmatched rows — presumably
# as background/augmentation; TODO confirm intent against original file.
130 if np.any(unmatched_rows)
and self.B_reco:
132 rand_mask = np.random.choice(a=[
False,
True], size=unmatched_rows.size)
135 unmatched_rows = np.logical_and(unmatched_rows, rand_mask)
138 x_rows = np.logical_or(x_rows, unmatched_rows)
# Allocate node-feature matrices (selected rows x features) and fill them
# column by column. NOTE(review): the loop variable `idx` shadows the
# function parameter `idx` (unused after line 110, but confusing).
143 x = np.empty((x_rows.sum(), len(self.features)))
144 x_dis = np.empty((x_rows.sum(), len(self.discarded)))
147 for idx, feat
in enumerate(self.features):
148 x[:, idx] = x_item[
"features"][feat][evt][x_rows]
149 for idx, feat
in enumerate(self.discarded):
150 x_dis[:, idx] = x_item[
"discarded"][feat][evt][x_rows]
# Edge features computed from kept + discarded node features combined.
# NOTE(review): `self.edge_features is not []` is ALWAYS True — `is not`
# checks identity against a freshly created list. Should be truthiness or
# `!= []` (contrast with the correct `!= []` used below for globals).
154 compute_edge_features(
156 self.features + self.discarded,
157 np.concatenate([x, x_dis], axis=1),
159 if self.edge_features
is not []
# Global features for this B candidate ("glob_*_{p_index}" branches).
166 x_item[
"global"][feat + f
"_{p_index}"][evt]
167 for feat
in self.global_features
171 if self.global_features != []
175 x_leaves = evt_leaves[x_rows]
# Sanitize NaNs in place in all three feature blocks.
178 np.nan_to_num(x, copy=
False)
179 np.nan_to_num(x_edges, copy=
False)
180 np.nan_to_num(x_global, copy=
False)
# Optional feature normalization (call arguments lost in extraction).
183 if self.normalize
is not None:
190 self.global_features,
# Map each selected leaf to its position in y_leaves (0 if absent).
200 np.where(y_leaves == i)[0].item()
if (i
in y_leaves)
else 0
# Build the LCA target matrix, reorder it to match the node ordering, zero
# out secondary rows/columns, and mark the diagonal with -1 (ignored).
209 y_edge = y_item[
"LCA"][evt].reshape((n_LCA, n_LCA)).astype(int)
211 y_mass = masses_to_classes(x_item[
"mc_pdg"][evt][x_rows])
214 y_edge = y_edge[locs, :][:, locs]
218 y_edge = np.where(evt_primary[x_rows], y_edge, 0)
219 y_edge = np.where(evt_primary[x_rows][:,
None], y_edge, 0)
222 np.fill_diagonal(y_edge, -1)
# Tensors for the fully-connected (all ordered pairs) graph: off-diagonal
# LCA entries as edge labels, permutations as edge_index.
227 edge_y = torch.tensor(
228 y_edge[np.eye(n_nodes) == 0],
232 edge_index = torch.tensor(
233 list(itertools.permutations(range(n_nodes), 2)),
239 [[1]], dtype=torch.float
243 x_y = torch.tensor(y_mass, dtype=torch.long)
# Assemble the torch_geometric Data object (constructor tail lost).
246 x=torch.tensor(x, dtype=torch.float),
247 edge_index=edge_index.t().contiguous(),
248 edge_attr=torch.tensor(x_edges, dtype=torch.float),
249 u=torch.tensor(x_global, dtype=torch.float),
Dataset handler for converting Belle II data to a PyTorch Geometric InMemoryDataset.

The ROOT format expects the tree in every file to be named ``Tree``,
and all node features to have the format ``feat_FEATNAME``.

.. note:: This expects the files under root to have the structure ``root/**/<file_name>.root``
    where the root path is different for train and val.
    The ``**/`` is to handle subdirectories, e.g. ``sub00``.

Args:
    root (str): Path to ROOT files.
    n_files (int): Load only ``n_files`` files.
    samples (int): Load only ``samples`` events.
    features (list): List of node feature names.
    edge_features (list): List of edge feature names.
    global_features (list): List of global feature names.
    normalize (bool): Whether to normalize input features.
# NOTE(review): extraction-damaged fragment of the dataset `__init__`
# (original lines 295-325; the `def` line — displaced to the end of this
# dump — the attribute assignments and the `for f in files:` loop header
# are missing). Comments only — no code token changed.
# Tail of an `assert isinstance(features, list)`-style check (head lost).
295 ), f
'Argument "features" must be a list and not {type(features)}'
296 assert len(features) > 0,
"You need to use at least one node feature"
# Delete any previously processed *.pt files so PyG regenerates them.
316 file_path = Path(self.
root,
"processed")
317 files = list(file_path.glob(
"*.pt"))
# Presumably the body of `for f in files:` (loop header lost) — TODO
# confirm against the original file.
319 f.unlink(missing_ok=
True)
# InMemoryDataset(root, transform, pre_transform, pre_filter): the three
# None arguments disable transforms and filtering. This triggers process().
322 super().
__init__(root,
None,
None,
None)
# Load the collated (data, slices) pair produced by process().
325 self.data, self.
slices = torch.load(self.processed_paths[0])
# NOTE(review): body of `processed_file_names` — its `def` line (and likely
# a @property decorator, as PyG's API expects — TODO confirm) was displaced
# to the end of this dump. Returns the single processed-file name PyG
# checks for before re-running process(). Comments only.
330 Name of processed file.
332 return [
"processed_data.pt"]
def process(self):
    """
    Processes the data to create graph objects and stores them in
    ``root/processed/processed_data.pt``, where the root path is different
    for train and val.

    Called internally by PyTorch.
    """
    # Build one torch_geometric.data.Data object per available sample.
    num_samples = _preload(self)
    data_list = [_process_graph(self, i) for i in range(num_samples)]

    # Collate into the single (data, slices) pair InMemoryDataset expects
    # and persist it at the path processed_file_names points to.
    data, slices = self.collate(data_list)
    torch.save((data, slices), self.processed_paths[0])

    # Free the preloaded arrays and the graph list; everything needed from
    # here on has been written to disk.
    del self.x, self.y, self.avail_samples, data_list, data, slices
# NOTE(review): trailing lines displaced out of order by the extraction;
# they belong elsewhere in the file. Comments only — no code token changed.
# Signature of the property whose body appears earlier (original ~line 329).
def processed_file_names(self)
# Fragments of attribute descriptions from the class docstring.
global_features
Global features.
edge_features
Edge features.
node_features
Node features.
# Constructor signature (body fragment appears earlier). NOTE(review): the
# mutable default arguments (features=[], edge_features=[],
# global_features=[]) are a Python anti-pattern — prefer None sentinels —
# though harmless here as far as the visible code shows, since the lists
# are not mutated before being reassigned.
def __init__(self, root, n_files=None, samples=None, features=[], edge_features=[], global_features=[], normalize=None, **kwargs)