Belle II Software  light-2403-persian
lca_to_adjacency.py
1 
8 
9 
10 import torch as t
11 import numpy as np
12 from collections import Counter
13 from itertools import permutations
14 from .tree_utils import is_valid_tree
15 
16 
class InvalidLCAMatrix(Exception):
    """
    Raised when an LCA matrix is malformed (for example not two-dimensional, not square or not
    symmetric) or when its entries do not describe a tree.
    """
22 
23 
class Node:
    """
    Class to hold levels of nodes in the tree.

    Args:
        level (int): Level in the tree (``_reconstruct`` creates leaves at level 1).
        children (list[Node]): Children of the node. The list object is stored as given, not copied.
        lca_index (int): Index of the node's row/column in the LCAS matrix, ``None`` for nodes that
            have no row of their own (such as the internal nodes created by ``_reconstruct``).
        lcas_level (int): Level in the LCAS matrix.
    """

    def __init__(self, level, children, lca_index=None, lcas_level=0):
        """
        Initialization
        """
        # Level in the tree
        self.level = level
        # Children of the node
        self.children = children
        # Index in the LCAS matrix (None for nodes without a matrix row)
        self.lca_index = lca_index
        # Level in the LCAS matrix
        self.lcas_level = lcas_level
        # Parent node; None until the node is attached below another node
        self.parent = None
        # Position in the breadth-first enumeration; -1 until it is assigned
        self.bfs_index = -1

    def __repr__(self):
        """Debug representation listing the node's identifying attributes."""
        return (
            f"{type(self).__name__}(level={self.level!r}, lca_index={self.lca_index!r}, "
            f"lcas_level={self.lcas_level!r}, bfs_index={self.bfs_index!r})"
        )
52 
53 
54 def _get_ancestor(node):
55  """
56  Trail search for the highest ancestor of a node.
57  """
58  ancestor = node
59 
60  while ancestor.parent is not None:
61  ancestor = ancestor.parent
62 
63  return ancestor
64 
65 
66 def _nodes_in_ancestors_children(parent, node1, node2):
67  """
68  Checks if any node in parent's line of descent is also an ancestor of both node1 and node2.
69  """
70  for child in parent.children:
71  if (node1 in child.children) and (node2 in child.children):
72  return True
73  else:
74  _nodes_in_ancestors_children(child, node1, node2)
75 
76  return False
77 
78 
79 def _pull_down(node):
80  """
81  Works up the node's history, pulling down a level any nodes
82  whose children are all more than one level below.
83 
84  Performs the operation in place.
85  """
86  # First check the children
87  if len(node.children) > 0:
88  highest_child = max([c.level for c in node.children])
89  node.level = highest_child + 1
90 
91  # Then move on to the parent
92  if node.parent is not None:
93  _pull_down(node.parent)
94 
95  return
96 
97 
98 def _breadth_first_enumeration(root, queue, adjacency_matrix):
99  """
100  Enumerates the tree breadth-first into a queue.
101  """
102  # Insert current root node into the queue
103  level = root.level
104  queue.setdefault(level, []).append(root)
105 
106  # Enumerate the children
107  for child in root.children:
108  _breadth_first_enumeration(child, queue, adjacency_matrix)
109 
110  return queue
111 
112 
def _breadth_first_adjacency(root, adjacency_matrix):
    """
    Number the nodes of the tree below ``root`` level by level and write their links into
    ``adjacency_matrix``.

    Nodes are numbered from the highest populated level down to the lowest. Within a level they keep
    the order produced by ``_breadth_first_enumeration``. The number is stored in ``node.bfs_index``.
    Every parent-child link is then written symmetrically as a 1.

    Args:
        root (Node): Root of the tree.
        adjacency_matrix: Square matrix indexed by ``bfs_index``, modified in place. It must have at
            least one row per node reachable from ``root``; otherwise indexing raises IndexError.
    """
    queue = _breadth_first_enumeration(root, {}, adjacency_matrix)

    # Iterate over the levels that are actually populated, top-down. Trees built by _reconstruct are
    # pulled down, so every level between 1 and root.level is present and the order matches a plain
    # countdown from root.level. Using the existing keys avoids a KeyError if a level is missing.
    ordered = [node for level in sorted(queue, reverse=True) for node in queue[level]]

    # First pass: assign every index, so that each child has one before the links are written
    for index, node in enumerate(ordered):
        node.bfs_index = index

    # Second pass: write the undirected parent-child links
    for node in ordered:
        for child in node.children:
            adjacency_matrix[node.bfs_index, child.bfs_index] = 1
            adjacency_matrix[child.bfs_index, node.bfs_index] = 1
132 
133 
def _reconstruct(lca_matrix):
    """
    Does the actual heavy lifting of the adjacency matrix reconstruction. Traverses the LCA matrix level-by-level,
    starting at one. For each level new nodes have to be inserted into the adjacency matrix, if a LCA matrix with this
    level number exists. The newly created node(s) will then be connected to the lower leaves, respectively,
    sub-graphs. This function may produce reconstructions that are valid graphs, but not trees.

    Missing intermediate levels are tolerated. The distinct non-zero matrix values are processed in ascending
    order and mapped to consecutive tree levels. Afterwards the node levels are compacted with ``_pull_down``.

    Args:
        lca_matrix: Square LCA matrix. ``shape``, ``unique()`` and ``tolist()`` are used on it, so presumably a
            torch Tensor. The entries below the diagonal drive the reconstruction.

    Returns:
        tuple(Node, int): The topmost ancestor of the first leaf, and the total number of nodes created
        (one leaf per row plus every internal node).

    Raises:
        InvalidLCAMatrix: If the matrix contains no 0, has a non-positive entry below the diagonal, or its
            entries are inconsistent with a tree.
    """
    n = lca_matrix.shape[0]
    total_nodes = n

    # Distinct values present in the matrix, in ascending order
    levels = sorted(lca_matrix.unique().tolist())
    # 0 is the LCA value of a leaf with itself and does not create a node. A matrix without any 0
    # (including an empty one) cannot be an LCA matrix: raise the documented exception instead of
    # the ValueError list.remove would produce.
    if 0 not in levels:
        raise InvalidLCAMatrix
    levels.remove(0)

    # Nested Python lists: element access inside the loops below is far cheaper than tensor indexing
    entries = lca_matrix.tolist()

    # Create nodes for all leaves
    leaves = [Node(1, [], lca_index=i) for i in range(n)]

    # idx enumerates the distinct levels (so skipped values do not leave gaps); current_level is the
    # raw matrix value used to select the entries, node_level the tree level assigned to new nodes
    for idx, current_level in enumerate(levels, 1):
        node_level = idx + 1
        for column in range(n):
            # The LCA matrix is symmetric, hence only the entries below the diagonal are inspected
            for row in range(column + 1, n):
                value = entries[row][column]
                if value <= 0:
                    raise InvalidLCAMatrix
                if value != current_level:
                    # Entry belongs to a different level
                    continue

                a_node = leaves[column]
                another_node = leaves[row]

                # Determine the current topmost ancestors of both leaves
                an_ancestor = _get_ancestor(a_node)
                a_level = an_ancestor.level
                another_ancestor = _get_ancestor(another_node)
                another_level = another_ancestor.level

                if a_level == another_level == node_level:
                    # Both already hang below a node of this level. It must be the same node, and no
                    # descendant of it may already join the two leaves (that would be a lower LCA).
                    if an_ancestor is not another_ancestor or _nodes_in_ancestors_children(
                        an_ancestor, a_node, another_node
                    ):
                        raise InvalidLCAMatrix
                elif a_level > node_level or another_level > node_level:
                    # An ancestor above the level being processed means the entries are inconsistent
                    raise InvalidLCAMatrix
                elif a_level < node_level and another_level < node_level:
                    # Neither sub-tree has a node at this level yet: create one joining both
                    parent = Node(node_level, [an_ancestor, another_ancestor], lcas_level=current_level)
                    an_ancestor.parent = parent
                    another_ancestor.parent = parent
                    total_nodes += 1
                elif another_level < node_level and a_level == node_level:
                    # The left sub-tree already has a node at this level: attach the right sub-tree to it
                    another_ancestor.parent = an_ancestor
                    an_ancestor.children.append(another_ancestor)
                elif a_level < node_level and another_level == node_level:
                    # Same as above with the roles swapped
                    an_ancestor.parent = another_ancestor
                    another_ancestor.children.append(an_ancestor)
                else:
                    raise InvalidLCAMatrix

    # The LCAs aren't guaranteed to actually be "lowest" ancestors, so compact the levels: after this
    # every node sits exactly one level above its highest child
    for leaf in leaves:
        _pull_down(leaf)

    root = _get_ancestor(leaves[0])

    return root, total_nodes
229 
230 
def lca_to_adjacency(lca_matrix):
    """
    Converts a tree's LCA matrix representation, i.e. a square matrix (M, M) where each row/column corresponds to
    a leaf of the tree and each matrix entry is the level of the lowest-common-ancestor (LCA) of the two leaves, into
    the corresponding two-dimensional adjacency matrix (N, N), with M <= N. N counts the leaves plus the
    reconstructed internal nodes. Rows/columns of the adjacency matrix are numbered level by level from the root
    (index 0) down to the leaves.

    .. seealso::
        The pseudocode for LCA to tree conversion is described in
        `Kahn et al <https://iopscience.iop.org/article/10.1088/2632-2153/ac8de0>`_.

    :param lca_matrix: 2-dimensional LCA matrix (M, M).
    :type lca_matrix: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_

    :return: 2-dimensional matrix (N, N) encoding the graph's node adjacencies. Linked nodes have values unequal to zero.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_

    Raises:
        InvalidLCAMatrix: If passed LCA matrix is malformed (e.g. not 2d, not square, not symmetric) or does not
            encode a tree.
        TypeError: If the input cannot be converted to a torch Tensor.
    """

    # Ensure input is torch tensor or can be converted to it
    if not isinstance(lca_matrix, t.Tensor):
        try:
            lca_matrix = t.Tensor(lca_matrix)
        except TypeError as err:
            raise TypeError(f"Input type must be compatible with torch Tensor: {err}") from err

    # Ensure two dimensions
    if lca_matrix.dim() != 2:
        raise InvalidLCAMatrix

    # Ensure that it is square
    n, m = lca_matrix.shape
    if n != m:
        raise InvalidLCAMatrix

    # Check symmetry
    if not (lca_matrix == lca_matrix.T).all():
        raise InvalidLCAMatrix

    try:
        root, total_nodes = _reconstruct(lca_matrix)
    except (IndexError, ValueError) as err:
        # ValueError: the level bookkeeping fails for matrices without any 0 entry (e.g. empty ones)
        raise InvalidLCAMatrix from err

    # Allocate the adjacency matrix
    adjacency_matrix = t.zeros((total_nodes, total_nodes), dtype=t.int64)
    try:
        _breadth_first_adjacency(root, adjacency_matrix)
    except (IndexError, KeyError) as err:
        # KeyError: a tree level without any node during the level-by-level traversal
        raise InvalidLCAMatrix from err

    # Check whether what we reconstructed is actually a tree - might be a regular graph for example
    if not is_valid_tree(adjacency_matrix):
        raise InvalidLCAMatrix

    return adjacency_matrix
290 
291 
292 def _get_fsps_of_node(node):
293  """
294  Given a node, finds all the final state particles connected to it and get their indices in the LCA.
295 
296  Args:
297  node (Node): Node to be inspected.
298 
299  Returns:
300  indices (list): List of final state particles' indices in the LCA matrix connected to node.
301  """
302  indices = []
303 
304  if node.lca_index is not None: # If you simply use 'if node.lca_index:' you will always miss the first fsp
305  indices.append(node.lca_index)
306  else:
307  for child in node.children:
308  indices.extend(_get_fsps_of_node(child))
309 
310  return list(set(indices))
311 
312 
# Mass hypothesis letter -> integer mass class used by select_good_decay
_MASS_HYPOTHESIS_CLASSES = {"o": 0, "e": 1, "m": 2, "i": 3, "k": 4, "p": 5, "g": 6}


def select_good_decay(predicted_lcas, predicted_masses, sig_side_lcas=None, sig_side_masses=None):
    """
    Checks if given LCAS matrix is found in reconstructed LCAS matrix and mass hypotheses are correct.

    .. warning:: You have to make sure to call this function only for valid tree structures encoded in ``predicted_lcas``,
        otherwise it will throw an exception.

    Mass hypotheses are indicated by letters. The following convention is used:

    .. math::
        'e' \\to e \\\\
        'i' \\to \\pi \\\\
        'k' \\to K \\\\
        'p' \\to p \\\\
        'm' \\to \\mu \\\\
        'g' \\to \\gamma \\\\
        'o' \\to \\text{others}

    .. warning:: The order of mass hypotheses should match that of the final state particles in the LCAS.

    :param predicted_lcas: LCAS matrix.
    :type predicted_lcas: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    :param predicted_masses: Predicted mass classes, one per final state particle in the order of ``predicted_lcas``.
        It is indexed with lists of indices and compared element-wise with the integer classes
        ``o=0, e=1, m=2, i=3, k=4, p=5, g=6``. It is therefore expected to be a tensor (or array) of integer
        classes rather than a list of letters.
    :type predicted_masses: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    :param sig_side_lcas: LCAS matrix of your signal-side (``[[0]]`` for a single final state particle).
        Anything accepted by ``torch.as_tensor`` works.
    :type sig_side_lcas: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    :param sig_side_masses: List of mass hypotheses for your FSPs.
    :type sig_side_masses: list[str]

    Returns:
        bool, int, list: True if LCAS and masses match; LCAS level of root node (the matrix value, hence a float
        for float tensors); LCA indices of FSPs belonging to the signal side ([-1] if LCAS does not match decay
        string).

    Raises:
        InvalidLCAMatrix: If the chosen signal-side LCAS matrix or mass hypotheses are invalid.
    """

    # Reconstruct decay chain
    root, _ = _reconstruct(predicted_lcas)

    # If root is not Ups nor B then decay is not good
    if root.lcas_level not in [5, 6]:
        return (False, root.lcas_level, [-1])

    # If root is B don't go any further (function is supposed to check whether signal-side on Ups decay is good)
    if root.lcas_level == 5:
        return (True, 5, list(range(predicted_lcas.shape[0])))

    # If chosen LCAS or masses are None then decay is not good
    if sig_side_lcas is None or sig_side_masses is None:
        return (False, root.lcas_level, [-1])

    # Check if the LCA matrix you chose is valid. Dispatch on the element count: Tensor.item()
    # raises RuntimeError (not ValueError) for multi-element tensors, so it cannot be used as the test.
    sig_side_lcas = t.as_tensor(sig_side_lcas)
    if sig_side_lcas.numel() == 1:
        if sig_side_lcas.item() != 0:
            raise InvalidLCAMatrix("If you have only one sig-side FSP, the LCA matrix should be [[0]]")
        more_fsps = False
    else:
        try:
            lca_to_adjacency(sig_side_lcas)
        except InvalidLCAMatrix as err:
            raise InvalidLCAMatrix("You chose an invalid LCA matrix") from err
        more_fsps = True

    # Check if the number of FSPs in the LCA is the same as the number of mass hypotheses
    if sig_side_lcas.shape[0] != len(sig_side_masses):
        raise InvalidLCAMatrix("The dimension of the LCA matrix you chose does not match with the number of mass hypotheses")

    # Check if mass hypotheses are allowed
    if any(hypothesis not in _MASS_HYPOTHESIS_CLASSES for hypothesis in sig_side_masses):
        # Not properly an InvalidLCAMatrix case, but no dedicated exception exists
        raise InvalidLCAMatrix("Allowed mass hypotheses are 'i', 'o', 'g', 'k', 'm', 'e', 'p'")

    # Convert mass hypotheses to integer classes
    sig_side_masses = t.tensor([_MASS_HYPOTHESIS_CLASSES[h] for h in sig_side_masses], dtype=t.int64)

    # Case 1: only one FSP in the signal-side
    if not more_fsps:
        # There should be exactly two nodes below the root: one '5' (the other B) and one '0' (the FSP)
        if Counter([child.lcas_level for child in root.children]) != Counter({5: 1, 0: 1}):
            return (False, root.lcas_level, [-1])

        # Get FSP index in LCA
        fsp_node = root.children[0] if root.children[0].lcas_level == 0 else root.children[1]
        fsp_idx = fsp_node.lca_index

        # Check mass hypothesis
        if predicted_masses[fsp_idx] != sig_side_masses[0]:
            return (False, root.lcas_level, [-1])

        return (True, root.lcas_level, [fsp_idx])

    # Case 2: more FSPs in the signal-side.
    # There should be exactly two nodes labelled as '5' (the two Bs)
    if Counter([child.lcas_level for child in root.children]) != Counter({5: 2}):
        return (False, root.lcas_level, [-1])

    # At least one of the Bs should decay into the nodes given by the chosen LCAS/masses
    for b_node in root.children:
        indices = _get_fsps_of_node(b_node)

        # The chosen signal side must have as many FSPs as this B
        if sig_side_lcas.shape[0] != len(indices):
            continue

        sub_lca = predicted_lcas[indices][:, indices]
        sub_masses = predicted_masses[indices]

        # The chosen signal-side LCAS/masses could have a different ordering: try all permutations
        for permutation in permutations(range(len(indices))):
            permutation = list(permutation)
            permuted_sig_side_lca = sig_side_lcas[permutation][:, permutation]
            permuted_sig_side_masses = sig_side_masses[permutation]
            if (permuted_sig_side_lca == sub_lca).all() and (permuted_sig_side_masses == sub_masses).all():
                return (True, root.lcas_level, indices)

    # If we get here decay is not good
    return (False, root.lcas_level, [-1])
def __init__(self, level, children, lca_index=None, lcas_level=0)