12from collections 
import Counter
 
   13from itertools 
import permutations
 
   14from .tree_utils 
import is_valid_tree
 
   19    Specialized Exception sub-class raised for malformed LCA matrices or LCA matrices not encoding trees. 
 
   26    Class to hold levels of nodes in the tree. 
   29        level (int): Level in the tree. 
   30        children (list[Node]): Children of the node. 
   31        lca_index (int): Index in the LCAS matrix. 
   32        lcas_level (int): Level in the LCAS matrix. 
   35    def __init__(self, level, children, lca_index=None, lcas_level=0):
 
 
 
   54def _get_ancestor(node):
 
   56    Walks up the chain of parents to find the highest ancestor of a node. 
   60    while ancestor.parent 
is not None:
 
   61        ancestor = ancestor.parent
 
   66def _nodes_in_ancestors_children(parent, node1, node2):
 
   68    Checks if any node in parent's line of descent is also an ancestor of both node1 and node2. 
   70    for child 
in parent.children:
 
   71        if (node1 
in child.children) 
and (node2 
in child.children):
 
   74            _nodes_in_ancestors_children(child, node1, node2)
 
   81    Works up the node's ancestry, pulling down the level of any node 
   82    whose children are all more than one level below it. 
   84    Performs the operation in place. 
   87    if len(node.children) > 0:
 
   88        highest_child = max([c.level 
for c 
in node.children])
 
   89        node.level = highest_child + 1
 
   92    if node.parent 
is not None:
 
   93        _pull_down(node.parent)
 
   98def _breadth_first_enumeration(root, queue, adjacency_matrix):
 
  100    Enumerates the tree breadth-first into a queue. 
  104    queue.setdefault(level, []).append(root)
 
  107    for child 
in root.children:
 
  108        _breadth_first_enumeration(child, queue, adjacency_matrix)
 
  113def _breadth_first_adjacency(root, adjacency_matrix):
 
  115    Enumerates the tree breadth-first and fills the adjacency matrix with the resulting parent-child links. 
  117    queue = _breadth_first_enumeration(root, {}, adjacency_matrix)
 
  121    for i 
in range(root.level, 0, -1):
 
  122        for node 
in queue[i]:
 
  123            node.bfs_index = index
 
  127    for i 
in range(root.level, 0, -1):
 
  128        for node 
in queue[i]:
 
  129            for child 
in node.children:
 
  130                adjacency_matrix[node.bfs_index, child.bfs_index] = 1
 
  131                adjacency_matrix[child.bfs_index, node.bfs_index] = 1
 
  134def _reconstruct(lca_matrix):
 
  136    Does the actual heavy lifting of the adjacency matrix reconstruction. Traverses the LCA matrix level-by-level, 
  137    starting at one. For each level, new nodes have to be inserted into the adjacency matrix if an LCA matrix entry with this 
  138    level number exists. The newly created node(s) will then be connected to the lower leaves or 
  139    sub-graphs, respectively. This function may produce reconstructions that are valid graphs, but not trees. 
  141    n = lca_matrix.shape[0]
 
  144    levels = sorted(lca_matrix.unique().tolist())
 
  149    leaves = [
Node(1, [], lca_index=i) 
for i 
in range(n)]
 
  156    for idx, current_level 
in enumerate(levels, 1):
 
  158        for column 
in range(n):
 
  161            for row 
in range(column + 1, n):
 
  163                if lca_matrix[row, column] <= 0:
 
  164                    raise InvalidLCAMatrix
 
  165                elif lca_matrix[row, column] != current_level:
 
  169                a_node = leaves[column]
 
  170                another_node = leaves[row]
 
  173                an_ancestor = _get_ancestor(a_node)
 
  174                a_level = an_ancestor.level
 
  176                another_ancestor = _get_ancestor(another_node)
 
  177                another_level = another_ancestor.level
 
  181                if a_level == another_level == (idx + 1):
 
  183                        an_ancestor 
is not another_ancestor
 
  184                        or _nodes_in_ancestors_children(
 
  185                            an_ancestor, a_node, another_node
 
  188                        raise InvalidLCAMatrix
 
  191                elif a_level > idx + 1 
or another_level > idx + 1:
 
  192                    raise InvalidLCAMatrix
 
  196                elif a_level < idx + 1 
and another_level < idx + 1:
 
  197                    parent = 
Node(idx + 1, [an_ancestor, another_ancestor], lcas_level=current_level)
 
  198                    an_ancestor.parent = parent
 
  199                    another_ancestor.parent = parent
 
  204                elif another_level < idx + 1 
and a_level == idx + 1:
 
  208                    another_ancestor.parent = an_ancestor
 
  209                    an_ancestor.children.append(another_ancestor)
 
  212                elif a_level < idx + 1 
and another_level == idx + 1:
 
  213                    an_ancestor.parent = another_ancestor
 
  214                    another_ancestor.children.append(an_ancestor)
 
  218                    raise InvalidLCAMatrix
 
  226    root = _get_ancestor(leaves[0])
 
  228    return root, total_nodes
 
  233    Converts a tree's LCA matrix representation, i.e. a square matrix (M, M) where each row/column corresponds to 
  234    a leaf of the tree and each matrix entry is the level of the lowest-common-ancestor (LCA) of the two leaves, into 
  235    the corresponding two-dimensional adjacency matrix (N, N), with M < N. The levels are enumerated top-down from the 
  239        The pseudocode for LCA to tree conversion is described in 
  240        `Kahn et al <https://iopscience.iop.org/article/10.1088/2632-2153/ac8de0>`_. 
  242    :param lca_matrix: 2-dimensional LCA matrix (M, M). 
  243    :type lca_matrix: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_ 
  245    :return: 2-dimensional matrix (N, N) encoding the graph's node adjacencies. Linked nodes have values unequal to zero. 
  246    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_ 
  249        InvalidLCAMatrix: If passed LCA matrix is malformed (e.g. not 2d or not square) or does not encode a tree. 
  253    if not isinstance(lca_matrix, t.Tensor):
 
  255            lca_matrix = t.Tensor(lca_matrix)
 
  256        except TypeError 
as err:
 
  257            print(f
"Input type must be compatible with torch Tensor: {err}")
 
  261    if len(lca_matrix.shape) != 2:
 
  262        raise InvalidLCAMatrix
 
  265    n, m = lca_matrix.shape
 
  267        raise InvalidLCAMatrix
 
  270    if not (lca_matrix == lca_matrix.T).all():
 
  271        raise InvalidLCAMatrix
 
  274        root, total_nodes = _reconstruct(lca_matrix)
 
  276        raise InvalidLCAMatrix
 
  279    adjacency_matrix = t.zeros((total_nodes, total_nodes), dtype=t.int64)
 
  281        _breadth_first_adjacency(root, adjacency_matrix)
 
  283        raise InvalidLCAMatrix
 
  286    if not is_valid_tree(adjacency_matrix):
 
  287        raise InvalidLCAMatrix
 
  289    return adjacency_matrix
 
  292def _get_fsps_of_node(node):
 
  294    Given a node, finds all the final state particles connected to it and gets their indices in the LCA. 
  297        node (Node): Node to be inspected. 
  300        indices (list): List of final state particles' indices in the LCA matrix connected to node. 
  304    if node.lca_index 
is not None:  
 
  305        indices.append(node.lca_index)
 
  307        for child 
in node.children:
 
  308            indices.extend(_get_fsps_of_node(child))
 
  310    return list(set(indices))
 
  313def select_good_decay(predicted_lcas, predicted_masses, sig_side_lcas=None, sig_side_masses=None):
 
  315    Checks if the given LCAS matrix is found in the reconstructed LCAS matrix and the mass hypotheses are correct. 
  317    .. warning:: You have to make sure to call this function only for valid tree structures encoded in ``predicted_lcas``, 
  318        otherwise it will throw an exception. 
  320    Mass hypotheses are indicated by letters. The following convention is used: 
  328        'g' \\to \\gamma \\\\ 
  329        'o' \\to \\text{others} 
  331    .. warning:: The order of mass hypotheses should match that of the final state particles in the LCAS. 
  333    :param predicted_lcas: LCAS matrix. 
  334    :type predicted_lcas: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_ 
  335    :param predicted_masses: List of predicted mass classes. 
  336    :type predicted_masses: list[str] 
  337    :param sig_side_lcas: LCAS matrix of your signal-side. 
  338    :type sig_side_lcas: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_ 
  339    :param sig_side_masses: List of mass hypotheses for your FSPs. 
  340    :type sig_side_masses: list[str] 
  343        bool, int, list: True if LCAS and masses match, LCAS level of root node, 
  344        LCA indices of FSPs belonging to the signal side ([-1] if LCAS does not match decay string). 
  348    root, _ = _reconstruct(predicted_lcas)
 
  351    if root.lcas_level 
not in [5, 6]:
 
  352        return (
False, root.lcas_level, [-1])
 
  355    if root.lcas_level == 5:
 
  356        return (
True, 5, [i 
for i 
in range(predicted_lcas.shape[0])])
 
  359    if sig_side_lcas 
is None or sig_side_masses 
is None:
 
  360        return (
False, root.lcas_level, [-1])
 
  364        if sig_side_lcas.item() == 0:
 
  367            raise InvalidLCAMatrix(
"If you have only one sig-side FSP, the LCA matrix should be [[0]]")
 
  372        except InvalidLCAMatrix:
 
  376    if sig_side_lcas.shape[0] != len(sig_side_masses):
 
  377        raise InvalidLCAMatrix(
"The dimension of the LCA matrix you chose does not match with the number of mass hypotheses")
 
  380    for e 
in set(sig_side_masses):
 
  381        if e 
not in [
'i', 
'o', 
'g', 
'k', 
'm', 
'e', 
'p']:
 
  383            raise InvalidLCAMatrix(
"Allowed mass hypotheses are 'i', 'o', 'g', 'k', 'm', 'e', 'p'")
 
  386    for s, n 
in zip([
"i", 
"k", 
"p", 
"e", 
"m", 
"g", 
"o"], [
"3", 
"4", 
"5", 
"1", 
"2", 
"6", 
"0"]):
 
  387        sig_side_masses = list(map(
lambda x: x.replace(s, n), sig_side_masses))
 
  388    sig_side_masses = t.from_numpy(np.array(sig_side_masses, dtype=int))
 
  394        if Counter([child.lcas_level 
for child 
in root.children]) != Counter({5: 1, 0: 1}):
 
  395            return (
False, root.lcas_level, [-1])
 
  398        fsp_idx = root.children[0].lca_index 
if root.children[0].lcas_level == 0 
else root.children[1].lca_index
 
  401        if predicted_masses[fsp_idx] != sig_side_masses[0]:
 
  402            return (
False, root.lcas_level, [-1])
 
  405        return (
True, root.lcas_level, [fsp_idx])
 
  410        if Counter([child.lcas_level 
for child 
in root.children]) != Counter({5: 2}):
 
  411            return (
False, root.lcas_level, [-1])
 
  415        B1_indices = _get_fsps_of_node(root.children[0])
 
  416        B2_indices = _get_fsps_of_node(root.children[1])
 
  419        for indices 
in [B1_indices, B2_indices]:
 
  421            if sig_side_lcas.shape[0] != len(indices):
 
  424            sub_lca = predicted_lcas[indices][:, indices]
 
  425            sub_masses = predicted_masses[indices]
 
  429            for permutation 
in permutations(list(range(len(sub_lca)))):
 
  430                permutation = list(permutation)
 
  431                permuted_sig_side_lca = sig_side_lcas[permutation][:, permutation]
 
  432                permuted_sig_side_masses = sig_side_masses[permutation]
 
  434                if (permuted_sig_side_lca == sub_lca).all() 
and (permuted_sig_side_masses == sub_masses).all():
 
  435                    return (
True, root.lcas_level, indices)
 
  438        return (
False, root.lcas_level, [-1])
 
__init__(self, level, children, lca_index=None, lcas_level=0)