Belle II Software development
lca_to_adjacency.py
1
8
9
10import torch as t
11import numpy as np
12from collections import Counter
13from itertools import permutations
14from .tree_utils import is_valid_tree
15
16
class InvalidLCAMatrix(Exception):
    """
    Error signalling an LCA matrix that is malformed or that does not describe a tree.

    Raised by the LCA-to-tree conversion helpers in this module whenever the input cannot be
    turned into a consistent tree structure.
    """
22
23
class Node:
    """
    A single vertex of the tree rebuilt from an LCA matrix.

    Args:
        level (int): Position of the node in the tree.
        children (list[Node]): Nodes directly below this one. The list object is stored as given (not copied).
        lca_index (int): Row/column of the LCA matrix this node corresponds to; ``None`` for nodes without one.
        lcas_level (int): LCA matrix value associated with this node (defaults to 0).
    """

    def __init__(self, level, children, lca_index=None, lcas_level=0):
        """
        Record the node's position, its children and its LCA bookkeeping; the parent link and the
        enumeration index start out unset.
        """
        # Tree position and direct descendants (shared reference to the caller's list).
        self.level = level
        self.children = children

        # Link back to the LCA matrix.
        self.lca_index = lca_index
        self.lcas_level = lcas_level

        # Filled in later: parent node (None while this node is a root) and enumeration index (-1 = unassigned).
        self.parent = None
        self.bfs_index = -1
52
53
54def _get_ancestor(node):
55 """
56 Trail search for the highest ancestor of a node.
57 """
58 ancestor = node
59
60 while ancestor.parent is not None:
61 ancestor = ancestor.parent
62
63 return ancestor
64
65
66def _nodes_in_ancestors_children(parent, node1, node2):
67 """
68 Checks if any node in parent's line of descent is also an ancestor of both node1 and node2.
69 """
70 for child in parent.children:
71 if (node1 in child.children) and (node2 in child.children):
72 return True
73 else:
74 _nodes_in_ancestors_children(child, node1, node2)
75
76 return False
77
78
79def _pull_down(node):
80 """
81 Works up the node's history, pulling down a level any nodes
82 whose children are all more than one level below.
83
84 Performs the operation in place.
85 """
86 # First check the children
87 if len(node.children) > 0:
88 highest_child = max([c.level for c in node.children])
89 node.level = highest_child + 1
90
91 # Then move on to the parent
92 if node.parent is not None:
93 _pull_down(node.parent)
94
95 return
96
97
98def _breadth_first_enumeration(root, queue, adjacency_matrix):
99 """
100 Enumerates the tree breadth-first into a queue.
101 """
102 # Insert current root node into the queue
103 level = root.level
104 queue.setdefault(level, []).append(root)
105
106 # Enumerate the children
107 for child in root.children:
108 _breadth_first_enumeration(child, queue, adjacency_matrix)
109
110 return queue
111
112
def _breadth_first_adjacency(root, adjacency_matrix):
    """
    Number the tree's nodes level by level and write its edges into ``adjacency_matrix``.

    Nodes are grouped by level with ``_breadth_first_enumeration``. Consecutive ``bfs_index``
    values are then assigned starting at ``root.level`` and going down to level 1 (within a level,
    in the grouping order). Finally every parent/child pair is marked with 1 in both directions of
    ``adjacency_matrix``, which is modified in place. Returns None.
    """
    by_level = _breadth_first_enumeration(root, {}, adjacency_matrix)
    top_down = range(root.level, 0, -1)

    # First pass: hand out indices, highest level first.
    next_index = 0
    for level in top_down:
        for node in by_level[level]:
            node.bfs_index = next_index
            next_index += 1

    # Second pass: record every edge symmetrically.
    for level in top_down:
        for node in by_level[level]:
            for child in node.children:
                adjacency_matrix[node.bfs_index, child.bfs_index] = 1
                adjacency_matrix[child.bfs_index, node.bfs_index] = 1
132
133
def _reconstruct(lca_matrix):
    """
    Build the tree encoded by an LCA matrix and return its root together with the node count.

    Does the actual heavy lifting of the adjacency matrix reconstruction. Traverses the LCA matrix level-by-level,
    starting from the smallest non-zero value. For each distinct value, new nodes are inserted where needed and
    connected to the lower leaves or sub-graphs. This function may produce reconstructions that are valid graphs,
    but not trees; the caller is responsible for validating the result.

    Leaves are created at level 1. The ``idx``-th smallest distinct non-zero LCA value (1-based) corresponds to tree
    level ``idx + 1``, so gaps between LCA values are ignored. A final ``_pull_down`` pass lowers every ancestor to one
    level above its highest child.

    Args:
        lca_matrix: Square LCA matrix. Only entries strictly below the diagonal are inspected in the main loop.

    Returns:
        tuple: ``(root, total_nodes)``. ``root`` is the topmost ancestor of leaf 0. ``total_nodes`` counts the leaves
        plus every internal node created here. Leaves not connected to ``root`` are still counted.

    Raises:
        InvalidLCAMatrix: If an inspected off-diagonal entry is <= 0 or the entries are structurally inconsistent.
        ValueError: From ``levels.remove(0)`` if the matrix contains no 0 entry at all (e.g. an empty matrix).
    """
    n = lca_matrix.shape[0]
    total_nodes = n
    # Distinct LCA values in ascending order; the value 0 marks a leaf paired with itself and is skipped
    levels = sorted(lca_matrix.unique().tolist())
    levels.remove(0)

    # Create nodes for all leaves
    leaves = [Node(1, [], lca_index=i) for i in range(n)]

    # Iterate level-by-level through the matrix, starting from immediate connections.
    # current_level is the actual LCA value to look for; idx (1-based rank among the non-zero values)
    # gives the corrected tree level idx + 1, which closes any gaps between LCA values.
    for idx, current_level in enumerate(levels, 1):
        # Iterate through each leaf in the LCA matrix
        for column in range(n):
            # The LCA matrix is symmetric, hence only entries below the diagonal are checked
            for row in range(column + 1, n):
                # Off-diagonal entries must be positive; skip entries not in the current level
                if lca_matrix[row, column] <= 0:
                    raise InvalidLCAMatrix
                elif lca_matrix[row, column] != current_level:
                    continue

                # Get the nodes
                a_node = leaves[column]
                another_node = leaves[row]

                # Determine the current topmost ancestors of both nodes
                an_ancestor = _get_ancestor(a_node)
                a_level = an_ancestor.level

                another_ancestor = _get_ancestor(another_node)
                another_level = another_ancestor.level

                # The nodes both already have an ancestor at that level: confirm it's the same one, and check that
                # no node below that common ancestor is already a direct parent of both leaves
                if a_level == another_level == (idx + 1):
                    if (
                        an_ancestor is not another_ancestor
                        or _nodes_in_ancestors_children(
                            an_ancestor, a_node, another_node
                        )
                    ):
                        raise InvalidLCAMatrix
                # An ancestor above the level currently being processed is treated as an invalid matrix
                elif a_level > idx + 1 or another_level > idx + 1:
                    raise InvalidLCAMatrix

                # Neither node has an ancestor at the level we're inspecting:
                # create one and connect both current top ancestors to it.
                # NOTE(review): if an_ancestor is another_ancestor here (both leaves already joined below this
                # level), the new parent gets the same child twice; presumably this is rejected later by the
                # caller's IndexError / is_valid_tree checks — TODO confirm.
                elif a_level < idx + 1 and another_level < idx + 1:
                    parent = Node(idx + 1, [an_ancestor, another_ancestor], lcas_level=current_level)
                    an_ancestor.parent = parent
                    another_ancestor.parent = parent
                    total_nodes += 1

                # The left node's ancestor already sits at this level: attach the right node's top ancestor to it
                elif another_level < idx + 1 and a_level == idx + 1:
                    another_ancestor.parent = an_ancestor
                    an_ancestor.children.append(another_ancestor)

                # Same for right
                elif a_level < idx + 1 and another_level == idx + 1:
                    an_ancestor.parent = another_ancestor
                    another_ancestor.children.append(an_ancestor)

                # Defensive: every combination of numeric levels is handled by the branches above
                else:
                    raise InvalidLCAMatrix

    # Levels were assigned as idx + 1, so a node may sit more than one level above its highest child;
    # pull such nodes down so that each is exactly one level above its highest child
    for leaf in leaves:
        _pull_down(leaf)

    # The tree structure is complete; its root is the topmost ancestor of the first leaf
    root = _get_ancestor(leaves[0])

    return root, total_nodes
229
230
def lca_to_adjacency(lca_matrix):
    """
    Converts a tree's LCA matrix representation into the corresponding adjacency matrix.

    The LCA matrix is a square matrix (M, M) where each row/column corresponds to a leaf of the tree and each
    off-diagonal entry is the level of the lowest-common-ancestor (LCA) of the two leaves. Diagonal entries must be 0.
    Larger values denote ancestors higher up in the tree. The result is the two-dimensional adjacency matrix (N, N),
    with N >= M. Its nodes are numbered level by level, starting from the root.

    .. seealso::
        The pseudocode for LCA to tree conversion is described in
        `Kahn et al <https://iopscience.iop.org/article/10.1088/2632-2153/ac8de0>`_.

    :param lca_matrix: 2-dimensional LCA matrix (M, M). Non-tensor inputs are converted with ``torch.Tensor``.
    :type lca_matrix: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_

    :return: 2-dimensional matrix (N, N) encoding the graph's node adjacencies. Linked nodes have values unequal to zero.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_

    Raises:
        InvalidLCAMatrix: If passed LCA matrix is malformed (e.g. not 2d, not square, empty, not symmetric, or with a
            non-zero diagonal) or does not encode a tree.
        TypeError: If the input cannot be converted to a torch Tensor (a message is printed before re-raising).
    """

    # Ensure input is torch tensor or can be converted to it
    if not isinstance(lca_matrix, t.Tensor):
        try:
            lca_matrix = t.Tensor(lca_matrix)
        except TypeError as err:
            print(f"Input type must be compatible with torch Tensor: {err}")
            raise

    # Ensure two dimensions
    if lca_matrix.dim() != 2:
        raise InvalidLCAMatrix

    # Ensure that it is square and not empty (an empty matrix encodes no tree)
    n, m = lca_matrix.shape
    if n != m or n == 0:
        raise InvalidLCAMatrix

    # Check symmetry
    if not (lca_matrix == lca_matrix.T).all():
        raise InvalidLCAMatrix

    # Each leaf paired with itself must be 0. Without this check, a nonzero diagonal either slipped through
    # unnoticed, or made the reconstruction fail with a ValueError when no 0 entry was present at all.
    if (lca_matrix.diagonal() != 0).any():
        raise InvalidLCAMatrix

    try:
        root, total_nodes = _reconstruct(lca_matrix)
    except IndexError as err:
        raise InvalidLCAMatrix from err

    # Allocate the adjacency matrix
    adjacency_matrix = t.zeros((total_nodes, total_nodes), dtype=t.int64)
    try:
        _breadth_first_adjacency(root, adjacency_matrix)
    except IndexError as err:
        # More nodes were enumerated than counted: the structure is not a proper tree
        raise InvalidLCAMatrix from err

    # Check whether what we reconstructed is actually a tree - might be a regular graph for example
    if not is_valid_tree(adjacency_matrix):
        raise InvalidLCAMatrix

    return adjacency_matrix
290
291
292def _get_fsps_of_node(node):
293 """
294 Given a node, finds all the final state particles connected to it and get their indices in the LCA.
295
296 Args:
297 node (Node): Node to be inspected.
298
299 Returns:
300 indices (list): List of final state particles' indices in the LCA matrix connected to node.
301 """
302 indices = []
303
304 if node.lca_index is not None: # If you simply use 'if node.lca_index:' you will always miss the first fsp
305 indices.append(node.lca_index)
306 else:
307 for child in node.children:
308 indices.extend(_get_fsps_of_node(child))
309
310 return list(set(indices))
311
312
def select_good_decay(predicted_lcas, predicted_masses, sig_side_lcas=None, sig_side_masses=None):
    """
    Checks if given LCAS matrix is found in reconstructed LCAS matrix and mass hypotheses are correct.

    .. warning:: You have to make sure to call this function only for valid tree structures encoded in ``predicted_lcas``,
        otherwise it will throw an exception.

    Mass hypotheses are indicated by letters. The following convention is used:

    .. math::
        'e' \\to e \\\\
        'i' \\to \\pi \\\\
        'k' \\to K \\\\
        'p' \\to p \\\\
        'm' \\to \\mu \\\\
        'g' \\to \\gamma \\\\
        'o' \\to \\text{others}

    .. warning:: The order of mass hypotheses should match that of the final state particles in the LCAS.

    A root with LCAS level 5 (B) is accepted as-is. A root with level 6 (Ups) is checked for a B whose subtree
    matches the chosen signal side, up to a reordering of its final state particles. Any other root level is rejected.

    :param predicted_lcas: LCAS matrix.
    :type predicted_lcas: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    :param predicted_masses: Predicted mass classes, one per FSP. They are indexed with a list of indices and compared
        element-wise with the integer classes this function derives from ``sig_side_masses``
        (e -> 1, m -> 2, i -> 3, k -> 4, p -> 5, g -> 6, o -> 0). Presumably an integer Tensor.
    :type predicted_masses: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    :param sig_side_lcas: LCAS matrix of your signal-side.
    :type sig_side_lcas: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    :param sig_side_masses: List of mass hypotheses for your FSPs.
    :type sig_side_masses: list[str]

    Returns:
        bool, int, list: True if LCAS and masses match, LCAS level of root node,
        LCA indices of FSPs belonging to the signal side ([-1] if LCAS does not match decay string).

    Raises:
        InvalidLCAMatrix: If the chosen signal-side LCAS matrix is invalid, if its size does not match the number of
            mass hypotheses, or if a mass hypothesis letter is not allowed.
    """

    # Reconstruct decay chain
    root, _ = _reconstruct(predicted_lcas)

    # If root is neither Ups nor B then decay is not good
    if root.lcas_level not in [5, 6]:
        return (False, root.lcas_level, [-1])

    # If root is B don't go any further (function is supposed to check whether signal-side on Ups decay is good)
    if root.lcas_level == 5:
        return (True, 5, list(range(predicted_lcas.shape[0])))

    # If chosen LCAS or masses are None then decay is not good
    if sig_side_lcas is None or sig_side_masses is None:
        return (False, root.lcas_level, [-1])

    # Check if the LCA matrix you chose is valid.
    # item() refuses inputs with more than one element. The exception type varies:
    # NumPy and older PyTorch raise ValueError, recent PyTorch raises RuntimeError.
    try:
        single_entry = sig_side_lcas.item()
    except (ValueError, RuntimeError):
        try:
            lca_to_adjacency(sig_side_lcas)
        except InvalidLCAMatrix as err:
            raise InvalidLCAMatrix("You chose an invalid LCA matrix") from err
        more_fsps = True
    else:
        if single_entry != 0:
            raise InvalidLCAMatrix("If you have only one sig-side FSP, the LCA matrix should be [[0]]")
        more_fsps = False

    # Check if the number of FSPs in the LCA is the same as the number of mass hypotheses
    if sig_side_lcas.shape[0] != len(sig_side_masses):
        raise InvalidLCAMatrix("The dimension of the LCA matrix you chose does not match with the number of mass hypotheses")

    # Mass hypothesis letter -> integer mass class
    mass_classes = {"o": 0, "e": 1, "m": 2, "i": 3, "k": 4, "p": 5, "g": 6}

    # Check if mass hypotheses are allowed
    if any(mass not in mass_classes for mass in sig_side_masses):
        # Not strictly an LCA matrix problem, but there is no dedicated exception for it
        raise InvalidLCAMatrix("Allowed mass hypotheses are 'i', 'o', 'g', 'k', 'm', 'e', 'p'")

    # Convert mass hypotheses to integer classes
    sig_side_masses = t.from_numpy(np.array([mass_classes[mass] for mass in sig_side_masses], dtype=int))

    # Let's start the proper decay check
    # Case 1: only one FSP in the signal-side
    if not more_fsps:
        # There should be two nodes: one '5' and one '0'
        if Counter([child.lcas_level for child in root.children]) != Counter({5: 1, 0: 1}):
            return (False, root.lcas_level, [-1])

        # Get FSP index in LCA
        fsp_idx = root.children[0].lca_index if root.children[0].lcas_level == 0 else root.children[1].lca_index

        # Check mass hypothesis
        if predicted_masses[fsp_idx] != sig_side_masses[0]:
            return (False, root.lcas_level, [-1])

        # All checks passed, decay is good
        return (True, root.lcas_level, [fsp_idx])

    # Case 2: more FSPs in the signal-side
    else:
        # There should be two nodes labelled as '5'
        if Counter([child.lcas_level for child in root.children]) != Counter({5: 2}):
            return (False, root.lcas_level, [-1])

        # If there are two '5', at least one of them should decay into the nodes given by the chosen LCAS/masses
        # Step 1: get LCA indices of both Bs
        B1_indices = _get_fsps_of_node(root.children[0])
        B2_indices = _get_fsps_of_node(root.children[1])

        # Step 2: Loop over the two Bs and select LCA sub-matrix and sub-masses
        for indices in [B1_indices, B2_indices]:
            # Step 3: check whether number of FSPs in the chosen sig-side corresponds to that of one of the B's
            if sig_side_lcas.shape[0] != len(indices):
                continue

            sub_lca = predicted_lcas[indices][:, indices]
            sub_masses = predicted_masses[indices]

            # Step 4: your chosen sig-side LCAS/masses could have different ordering,
            # we have to check all possible permutations
            for permutation in permutations(list(range(len(sub_lca)))):
                permutation = list(permutation)
                permuted_sig_side_lca = sig_side_lcas[permutation][:, permutation]
                permuted_sig_side_masses = sig_side_masses[permutation]
                # Step 5: if one of the permutations works decay is good
                if (permuted_sig_side_lca == sub_lca).all() and (permuted_sig_side_masses == sub_masses).all():
                    return (True, root.lcas_level, indices)

        # If we get here decay is not good
        return (False, root.lcas_level, [-1])
def __init__(self, level, children, lca_index=None, lcas_level=0)