12from collections
import Counter
13from itertools
import permutations
14from .tree_utils
import is_valid_tree
class InvalidLCAMatrix(Exception):
    """
    Specialized Exception sub-class raised for malformed LCA matrices or LCA matrices not encoding trees.

    Raised both without arguments (``raise InvalidLCAMatrix``) and with an explanatory message.
    """
class Node:
    """
    Class to hold levels of nodes in the tree.

    Args:
        level (int): Level in the tree.
        children (list[Node]): Children of the node.
        lca_index (int): Index in the LCAS matrix.
        lcas_level (int): Level in the LCAS matrix.
    """

    def __init__(self, level, children, lca_index=None, lcas_level=0):
        """
        Initialise a node that is not yet attached to any parent.

        Args:
            level (int): Level in the tree.
            children (list[Node]): Children of the node. The list object is stored as-is
                (not copied), so appending to it later changes this node's children.
            lca_index (int or None): Row/column index in the LCAS matrix. In this module
                only leaves are created with an index.
            lcas_level (int): LCAS value this node stands for (0 unless given).
        """
        # Level in the tree
        self.level = level
        # Child nodes; extended in place when sub-trees are attached
        self.children = children
        # Index in the LCAS matrix (None for internal nodes)
        self.lca_index = lca_index
        # LCAS value associated with this node
        self.lcas_level = lcas_level
        # Parent node; None until the node is attached under another node
        self.parent = None
        # Breadth-first position; placeholder until the adjacency builder assigns it
        self.bfs_index = -1
def _get_ancestor(node):
    """
    Trail search for the highest ancestor of a node.

    Follows ``parent`` links upwards until a node without a parent is reached.

    Args:
        node (Node): Node to start from.

    Returns:
        Node: The top-most ancestor, which is ``node`` itself if it has no parent.
    """
    ancestor = node
    while ancestor.parent is not None:
        ancestor = ancestor.parent
    return ancestor
def _nodes_in_ancestors_children(parent, node1, node2):
    """
    Checks if any node in parent's line of descent is also an ancestor of both node1 and node2.

    Concretely: searches every descendant of ``parent`` (depth-first, ``parent`` itself
    excluded) for a node that has both ``node1`` and ``node2`` among its direct children.

    Args:
        parent (Node): Root of the sub-tree to search.
        node1 (Node): First node.
        node2 (Node): Second node.

    Returns:
        bool: True if such a descendant exists, False otherwise.
    """
    for child in parent.children:
        if node1 in child.children and node2 in child.children:
            return True
        # Propagate matches found deeper in the sub-tree. Previously the recursive
        # result was discarded, so only parent's direct children were ever examined.
        if _nodes_in_ancestors_children(child, node1, node2):
            return True
    return False
def _pull_down(node):
    """
    Works up the node's history, pulling down a level any nodes
    whose children are all more than one level below.

    Starting at ``node`` and walking up through each ``parent``, every node that has
    children has its ``level`` reset to one above its highest child.

    Performs the operation in place.

    Args:
        node (Node): Node to start from (typically a leaf).
    """
    current = node
    while current is not None:
        if current.children:
            current.level = max(child.level for child in current.children) + 1
        current = current.parent
def _breadth_first_enumeration(root, queue, adjacency_matrix):
    """
    Enumerates the tree breadth-first into a queue.

    The traversal itself is depth-first (pre-order). Nodes are grouped by ``level``,
    so the caller can then walk the result level by level. Within one level, nodes
    appear in visit order.

    Args:
        root (Node): Root of the (sub-)tree to enumerate.
        queue (dict[int, list[Node]]): Mapping level -> nodes; filled in place.
        adjacency_matrix: Not used here; only forwarded to the recursive calls.

    Returns:
        dict[int, list[Node]]: The same ``queue`` object.
    """
    queue.setdefault(root.level, []).append(root)
    for child in root.children:
        _breadth_first_enumeration(child, queue, adjacency_matrix)
    return queue
def _breadth_first_adjacency(root, adjacency_matrix):
    """
    Numbers the tree's nodes level by level and writes its edges into ``adjacency_matrix``.

    Nodes are indexed from the root's level down to level 1, in the order produced by
    ``_breadth_first_enumeration``. Each parent/child pair is then marked with 1 in both
    directions, so the result is symmetric.

    Args:
        root (Node): Root of the tree.
        adjacency_matrix: Square matrix supporting ``m[i, j] = 1`` assignment, at least as
            large as the number of nodes; modified in place.
    """
    queue = _breadth_first_enumeration(root, {}, adjacency_matrix)

    # First pass: give every node a unique, consecutive index.
    index = 0
    for level in range(root.level, 0, -1):
        for node in queue[level]:
            node.bfs_index = index
            index += 1

    # Second pass: fill in the (undirected) parent/child adjacencies.
    for level in range(root.level, 0, -1):
        for node in queue[level]:
            for child in node.children:
                adjacency_matrix[node.bfs_index, child.bfs_index] = 1
                adjacency_matrix[child.bfs_index, node.bfs_index] = 1
134def _reconstruct(lca_matrix):
# NOTE(review): incomplete extraction. The docstring quotes and several statements are missing
# from this text: nothing visible initialises or updates ``total_nodes`` (returned at the end),
# the ``elif ... != current_level`` branch has no body, and ``_pull_down`` (defined in this file)
# has no visible caller. Restore from version control before relying on this block.
# Returns (root, total_nodes): the top-most ancestor of leaf 0 and the node count used to size
# the adjacency matrix.
136 Does the actual heavy lifting of the adjacency matrix reconstruction. Traverses the LCA matrix level-by-level,
137 starting at one. For each level new nodes have to be inserted into the adjacency matrix, if a LCA matrix
with this
138 level number exists. The newly created node(s) will then be connected to the lower leaves, respectively,
139 sub-graphs. This function may produce reconstructions that are valid graphs, but
not trees.
# n: number of leaves, one per row/column of the LCA matrix.
141 n = lca_matrix.shape[0]
# Distinct LCA values in ascending order; each one is mapped to tree level idx + 1 below.
# NOTE(review): if the diagonal value 0 stays in this list, its pass does nothing and new
# nodes start at level 3; any line removing it is not visible here — confirm.
144 levels = sorted(lca_matrix.unique().tolist())
# One level-1 leaf Node per matrix row; lca_index records the row.
149 leaves = [
Node(1, [], lca_index=i)
for i
in range(n)]
# Process LCA values from lowest to highest; nodes created in this pass get level idx + 1.
156 for idx, current_level
in enumerate(levels, 1):
# Visit each unordered leaf pair once (strict lower triangle, row > column).
158 for column
in range(n):
161 for row
in range(column + 1, n):
# Off-diagonal entries must be positive; entries for other LCA values are handled in their own pass.
163 if lca_matrix[row, column] <= 0:
164 raise InvalidLCAMatrix
165 elif lca_matrix[row, column] != current_level:
# Current top-most ancestors of both leaves and their tree levels.
169 a_node = leaves[column]
170 another_node = leaves[row]
173 an_ancestor = _get_ancestor(a_node)
174 a_level = an_ancestor.level
176 another_ancestor = _get_ancestor(another_node)
177 another_level = another_ancestor.level
# Both already sit under a node at this level: it must be the same node, and no node
# below it may already have both leaves as direct children.
181 if a_level == another_level == (idx + 1):
183 an_ancestor
is not another_ancestor
184 or _nodes_in_ancestors_children(
185 an_ancestor, a_node, another_node
188 raise InvalidLCAMatrix
# An ancestor above the current level means the LCA values are inconsistent.
191 elif a_level > idx + 1
or another_level > idx + 1:
192 raise InvalidLCAMatrix
# Neither sub-tree has reached this level yet: join both under a new parent node.
196 elif a_level < idx + 1
and another_level < idx + 1:
197 parent =
Node(idx + 1, [an_ancestor, another_ancestor], lcas_level=current_level)
198 an_ancestor.parent = parent
199 another_ancestor.parent = parent
# One side already has an ancestor at this level: attach the other sub-tree to it.
204 elif another_level < idx + 1
and a_level == idx + 1:
208 another_ancestor.parent = an_ancestor
209 an_ancestor.children.append(another_ancestor)
212 elif a_level < idx + 1
and another_level == idx + 1:
213 an_ancestor.parent = another_ancestor
214 another_ancestor.children.append(an_ancestor)
# NOTE(review): the guard of this raise (presumably a final ``else:``) is missing from the text.
218 raise InvalidLCAMatrix
# The root is the top-most ancestor of leaf 0.
226 root = _get_ancestor(leaves[0])
228 return root, total_nodes
# NOTE(review): the ``def`` line of this public function, its docstring quotes and the ``try:``
# lines pairing with the ``except``/``raise`` statements below are missing from this text.
# The comments only describe the visible statements.
233 Converts a tree's LCA matrix representation, i.e. a square matrix (M, M) where each row/column corresponds to
234 a leaf of the tree and each matrix entry
is the level of the lowest-common-ancestor (LCA) of the two leaves, into
235 the corresponding two-dimension adjacency matrix (N,N),
with M < N. The levels are enumerated top-down
from the
239 The pseudocode
for LCA to tree conversion
is described
in
240 `Kahn et al <https://iopscience.iop.org/article/10.1088/2632-2153/ac8de0>`_.
242 :param lca_matrix: 2-dimensional LCA matrix (M, M).
243 :type lca_matrix: `Tensor <https://pytorch.org/docs/stable/tensors.html
245 :
return: 2-dimensional matrix (N, N) encoding the graph
's node adjacencies. Linked nodes have values unequal to zero.
246 :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html
249 InvalidLCAMatrix: If passed LCA matrix
is malformed (e.g.
not 2d
or not square)
or does
not encode a tree.
# Coerce non-tensor input via t.Tensor. On TypeError the visible code only prints a message;
# NOTE(review): what happens after the print is not visible — confirm the error is re-raised.
253 if not isinstance(lca_matrix, t.Tensor):
255 lca_matrix = t.Tensor(lca_matrix)
256 except TypeError
as err:
257 print(f
"Input type must be compatible with torch Tensor: {err}")
# The matrix must be 2-dimensional.
261 if len(lca_matrix.shape) != 2:
262 raise InvalidLCAMatrix
# presumably raises when n != m (the guarding ``if`` is missing here) — confirm.
265 n, m = lca_matrix.shape
267 raise InvalidLCAMatrix
# The matrix must be symmetric.
270 if not (lca_matrix == lca_matrix.T).all():
271 raise InvalidLCAMatrix
# Build the tree; the following raise presumably sits in an ``except`` handler that is missing here.
274 root, total_nodes = _reconstruct(lca_matrix)
276 raise InvalidLCAMatrix
# One row/column per reconstructed node, int64 zeros; edges are filled in as 1.
279 adjacency_matrix = t.zeros((total_nodes, total_nodes), dtype=t.int64)
281 _breadth_first_adjacency(root, adjacency_matrix)
283 raise InvalidLCAMatrix
# _reconstruct may yield valid graphs that are not trees; reject those here.
286 if not is_valid_tree(adjacency_matrix):
287 raise InvalidLCAMatrix
289 return adjacency_matrix
def _get_fsps_of_node(node):
    """
    Given a node, finds all the final state particles connected to it and get their indices in the LCA.

    Every node in the sub-tree rooted at ``node`` (``node`` included) that carries an
    ``lca_index`` contributes that index.

    Args:
        node (Node): Node to be inspected.

    Returns:
        indices (list): Sorted, duplicate-free list of the final state particles'
        indices in the LCA matrix connected to node.
    """
    indices = set()
    # Explicit stack instead of recursion; the order of visiting does not matter
    # because the result is de-duplicated and sorted.
    stack = [node]
    while stack:
        current = stack.pop()
        if current.lca_index is not None:
            indices.add(current.lca_index)
        stack.extend(current.children)
    # Sorted for a deterministic order (previously set-iteration order).
    return sorted(indices)
313def select_good_decay(predicted_lcas, predicted_masses, sig_side_lcas=None, sig_side_masses=None):
# NOTE(review): incomplete extraction. The docstring quotes, the ``try:`` pairing with
# ``except InvalidLCAMatrix:``, several branch guards and bodies are missing from this text.
# ``t`` and ``np`` are used below but do not appear in the visible import lines.
315 Checks if given LCAS matrix
is found
in reconstructed LCAS matrix
and mass hypotheses are correct.
317 .. warning:: You have to make sure to call this function only
for valid tree structures encoded
in ``predicted_lcas``,
318 otherwise it will throw an exception.
320 Mass hypotheses are indicated by letters. The following convention
is used:
328 'g' \\to \\gamma \\\\
329 'o' \\to \\text{others}
331 .. warning:: The order of mass hypotheses should match that of the final state particles
in the LCAS.
333 :param predicted_lcas: LCAS matrix.
334 :type predicted_lcas: `Tensor <https://pytorch.org/docs/stable/tensors.html
335 :param predicted_masses: List of predicted mass classes.
336 :type predicted_masses: list[str]
337 :param sig_side_lcas: LCAS matrix of your signal-side.
338 :type sig_side_lcas: `Tensor <https://pytorch.org/docs/stable/tensors.html
339 :param sig_side_masses: List of mass hypotheses
for your FSPs.
340 :type sig_side_masses: list[str]
343 bool, int, list:
True if LCAS
and masses match, LCAS level of root node,
344 LCA indices of FSPs belonging to the signal side ([-1]
if LCAS does
not match decay string).
# Rebuild the tree; _reconstruct raises InvalidLCAMatrix for inconsistent input (hence the warning above).
348 root, _ = _reconstruct(predicted_lcas)
# Only root LCAS values 5 and 6 are accepted; anything else is reported as a mismatch.
351 if root.lcas_level
not in [5, 6]:
352 return (
False, root.lcas_level, [-1])
# Root value 5: every FSP index is returned as signal side.
355 if root.lcas_level == 5:
356 return (
True, 5, [i
for i
in range(predicted_lcas.shape[0])])
# Root value 6: a signal-side pattern is required to compare against.
359 if sig_side_lcas
is None or sig_side_masses
is None:
360 return (
False, root.lcas_level, [-1])
# Single signal-side FSP handling; the message below requires the LCA matrix to be [[0]].
# NOTE(review): the lines between this check, the raise and the except are missing, so the
# exact control flow for the single-FSP case cannot be confirmed from this text.
364 if sig_side_lcas.item() == 0:
367 raise InvalidLCAMatrix(
"If you have only one sig-side FSP, the LCA matrix should be [[0]]")
372 except InvalidLCAMatrix:
# The signal-side LCA dimension must match the number of mass hypotheses.
376 if sig_side_lcas.shape[0] != len(sig_side_masses):
377 raise InvalidLCAMatrix(
"The dimension of the LCA matrix you chose does not match with the number of mass hypotheses")
# Only the documented hypothesis letters are accepted.
380 for e
in set(sig_side_masses):
381 if e
not in [
'i',
'o',
'g',
'k',
'm',
'e',
'p']:
383 raise InvalidLCAMatrix(
"Allowed mass hypotheses are 'i', 'o', 'g', 'k', 'm', 'e', 'p'")
# Map letters to integer class labels (i->3, k->4, p->5, e->1, m->2, g->6, o->0), then to a tensor.
386 for s, n
in zip([
"i",
"k",
"p",
"e",
"m",
"g",
"o"], [
"3",
"4",
"5",
"1",
"2",
"6",
"0"]):
387 sig_side_masses = list(map(
lambda x: x.replace(s, n), sig_side_masses))
388 sig_side_masses = t.from_numpy(np.array(sig_side_masses, dtype=int))
# Presumably the single-FSP branch (its guard is missing): the root must have exactly one child
# with LCAS 5 and one child with LCAS 0 (leaves keep the Node default lcas_level 0).
394 if Counter([child.lcas_level
for child
in root.children]) != Counter({5: 1, 0: 1}):
395 return (
False, root.lcas_level, [-1])
# The FSP is the LCAS-0 child of the root; its mass must match the single hypothesis.
398 fsp_idx = root.children[0].lca_index
if root.children[0].lcas_level == 0
else root.children[1].lca_index
401 if predicted_masses[fsp_idx] != sig_side_masses[0]:
402 return (
False, root.lcas_level, [-1])
405 return (
True, root.lcas_level, [fsp_idx])
# Presumably the multi-FSP branch: the root must have exactly two children with LCAS 5.
410 if Counter([child.lcas_level
for child
in root.children]) != Counter({5: 2}):
411 return (
False, root.lcas_level, [-1])
# Leaf indices under each LCAS-5 child of the root.
415 B1_indices = _get_fsps_of_node(root.children[0])
416 B2_indices = _get_fsps_of_node(root.children[1])
# Try each candidate sub-tree against the signal-side pattern.
# NOTE(review): the body of the size check below (presumably skipping this candidate) is missing.
419 for indices
in [B1_indices, B2_indices]:
421 if sig_side_lcas.shape[0] != len(indices):
# Restrict the predicted LCAS matrix and masses to this candidate's leaves.
# NOTE(review): indexing with a list, and comparing against the integer tensor built above,
# requires predicted_masses to be a tensor/array of class labels, while the docstring says
# list[str] — confirm.
424 sub_lca = predicted_lcas[indices][:, indices]
425 sub_masses = predicted_masses[indices]
# Brute-force every ordering of the signal-side pattern (factorial in the number of FSPs).
429 for permutation
in permutations(list(range(len(sub_lca)))):
430 permutation = list(permutation)
431 permuted_sig_side_lca = sig_side_lcas[permutation][:, permutation]
432 permuted_sig_side_masses = sig_side_masses[permutation]
434 if (permuted_sig_side_lca == sub_lca).all()
and (permuted_sig_side_masses == sub_masses).all():
435 return (
True, root.lcas_level, indices)
# No candidate sub-tree matched.
438 return (
False, root.lcas_level, [-1])
# NOTE(review): stray copy of the Node.__init__ signature with no colon and no body, outside
# any class; looks like extraction residue — confirm and delete.
def __init__(self, level, children, lca_index=None, lcas_level=0)