Belle II Software development
tree_utils.py
1
8
9
10import torch as t
11import numpy as np
12
13
def masses_to_classes(array):
    r"""
    Converts mass hypotheses to classes used in cross-entropy computation.

    The sign of every PDG code is discarded first, so a particle and its
    antiparticle end up in the same class. Classes are:

    .. math::
        e \to 1\\
        \mu \to 2\\
        \pi \to 3\\
        K \to 4\\
        p \to 5\\
        \gamma \to 6\\
        \text{others} \to 0

    Args:
        array (numpy.ndarray): Array containing PDG mass codes.

    Returns:
        numpy.ndarray: A new array of the same shape holding the mass hypotheses
        converted to classes; the input array is left unmodified.
    """
    # (|PDG code|, class label) pairs, applied in this order.
    pdg_classes = (
        (11, 1),  # electron
        (13, 2),  # muon
        (211, 3),  # pion
        (321, 4),  # kaon
        (2212, 5),  # proton
        (22, 6),  # photon
    )

    # Work on the negated absolute codes: every raw code is now <= 0, so the
    # positive labels written below can never be confused with a code that has
    # not been converted yet.
    labels = -1 * np.abs(array)
    for pdg_code, class_label in pdg_classes:
        labels[labels == -pdg_code] = class_label

    # Anything still non-positive matched none of the known codes.
    labels[labels <= 0] = 0
    return labels
45
46
def _check_undirected(adjacency_matrix):
    """
    Tells whether an adjacency matrix-encoded graph is undirected, i.e. symmetric.

    A graph is undirected exactly when its adjacency matrix is square and equal to
    its own transpose.

    Returns:
        ``False`` for a non-square matrix (no entries are compared in that case),
        otherwise the result of ``.all()`` over the element-wise comparison with the
        transpose.
    """
    rows, cols = adjacency_matrix.shape
    # `and` short-circuits, so the transpose comparison only runs on square matrices.
    return rows == cols and (adjacency_matrix == adjacency_matrix.T).all()
56
57
def _connectedness_dfs(adjacency_matrix, index, reached):
    """
    Depth-first search of graph connectedness.

    Starting from the node ``index``, every node reachable by following non-zero
    entries ``adjacency_matrix[u, v]`` (from row ``u`` to column ``v``) is marked with
    ``1`` in the ``reached`` mask, which is modified in place. Neighbours already
    marked in ``reached`` are not traversed again. If the graph is not connected, the
    ``reached`` mask will still contain zero elements afterwards.

    The search uses an explicit stack instead of recursion, so long chains of nodes
    cannot exceed Python's recursion limit. It marks the same set of nodes as the
    recursive formulation.

    Args:
        adjacency_matrix: Square (N, N) matrix encoding node adjacencies, indexable
            as ``adjacency_matrix[row, column]``.
        index (int): Node the search starts from.
        reached: Mutable per-node mask of length N, updated in place.
    """
    n = adjacency_matrix.shape[0]
    reached[index] = 1
    pending = [index]

    while pending:
        node = pending.pop()
        for column in range(n):
            # Visit adjacent nodes that are not yet marked as reachable. Marking on push
            # guarantees that every node enters the stack at most once.
            if adjacency_matrix[node, column] != 0 and not reached[column]:
                reached[column] = 1
                pending.append(column)
72
73
def _check_connectedness(adjacency_matrix, allow_disconnected_leaves=False):
    """
    Checks whether the graph encoded by an adjacency matrix is connected, i.e. whether
    every node can be reached from node 0 by following non-zero entries.

    Args:
        adjacency_matrix: Square (N, N) matrix encoding node adjacencies (a numpy
            array or a torch tensor).
        allow_disconnected_leaves (bool): If True, nodes whose row sums to zero are
            accepted even when they cannot be reached from node 0.

    Returns:
        ``False`` for a non-square matrix, otherwise the torch boolean result of
        ``.all()`` over the per-node reachability mask.
    """
    n, m = adjacency_matrix.shape
    if n != m:
        return False

    reached = t.zeros(n, dtype=t.uint8)
    _connectedness_dfs(adjacency_matrix, 0, reached)

    if allow_disconnected_leaves:
        # Nodes whose row sums to zero count as disconnected leaves. torch.logical_or
        # only accepts tensors, so the mask is converted explicitly: with a numpy
        # adjacency matrix the comparison yields an ndarray, which would raise a TypeError.
        leaves = t.as_tensor(adjacency_matrix.sum(axis=1) == 0, device=reached.device)
        reached = t.logical_or(reached, leaves)

    return reached.all()
90
91
def _acyclic_dfs(adjacency_matrix, index, parent, reached):
    """
    Depth-first search of graph cycles, starting from the node ``index``.

    Visited nodes are marked with ``1`` in the ``reached`` mask, which is modified in
    place. If a node is found in a trail that has been previously marked as reached,
    and the edge is not the one leading back to the current node's DFS parent, this
    indicates a cycle. Self-references (non-zero diagonal entries) are skipped: while
    technically not acyclic, they are allowed.

    The traversal replaces recursion with an explicit stack of
    (node, parent, remaining columns) frames. It inspects nodes and columns in exactly
    the same order as the recursive formulation, but long chains cannot exceed
    Python's recursion limit.

    Args:
        adjacency_matrix: Square (N, N) matrix; ``adjacency_matrix[u, v] != 0`` links ``u`` to ``v``.
        index (int): Node the search starts from.
        parent (int): DFS parent of ``index``; ``-1`` for the root of the search.
        reached: Mutable per-node mask of length N, updated in place.

    Returns:
        bool: False as soon as a cycle is found, True otherwise.
    """
    n = adjacency_matrix.shape[0]
    reached[index] = 1
    # Each frame holds a node, its DFS parent and an iterator over the columns of its
    # row that still have to be inspected. Resuming the iterator continues the scan
    # exactly where the recursive version would after returning from a child.
    stack = [(index, parent, iter(range(n)))]

    while stack:
        node, node_parent, columns = stack[-1]
        for column in columns:
            # the passed adjacency matrix may contain self-references
            # while technically not acyclic, these are allowed
            if column == node:
                continue

            if adjacency_matrix[node, column] != 0:
                if not reached[column]:
                    # descend into the unvisited neighbour before scanning further columns
                    reached[column] = 1
                    stack.append((column, node, iter(range(n))))
                    break
                if column != node_parent:
                    # an already reached node other than our parent closes a cycle
                    return False
        else:
            # every column of this node has been inspected without finding a cycle
            stack.pop()

    return True
116
117
def _check_acyclic(adjacency_matrix):
    """
    Tells whether the graph encoded by the passed adjacency matrix is acyclic, i.e. whether
    no non-empty trail in the graph contains repetitions. Node self-references are legal
    and simply ignored.

    The search starts at node 0, so only the nodes reachable from node 0 are inspected.
    A non-square matrix is rejected with ``False``.
    """
    num_rows, num_cols = adjacency_matrix.shape
    if num_rows == num_cols:
        # Fresh visitation mask; node 0 is the root of the search and has no parent (-1).
        visited = t.zeros(num_rows, dtype=t.uint8)
        return _acyclic_dfs(adjacency_matrix, 0, -1, visited)
    return False
130
131
def is_valid_tree(adjacency_matrix):
    """
    Checks whether the graph encoded by the passed adjacency matrix encodes a valid tree,
    i.e. an undirected, acyclic and connected graph.

    Node self-references (non-zero diagonal entries) are ignored by the acyclicity check.

    Args:
        adjacency_matrix (numpy.ndarray): 2-dimensional matrix (N, N) encoding the graph's node adjacencies. Linked nodes should have value unequal to zero.

    Returns:
        bool: True if the encoded graph is a tree, False otherwise.
    """
    # Checks are evaluated lazily, so the graph searches are skipped once a cheaper
    # property already fails. The helpers may return numpy/torch boolean scalars (e.g.
    # ``tensor(False)``); bool() turns the result into the plain bool the docstring promises.
    return bool(
        _check_undirected(adjacency_matrix)
        and _check_connectedness(adjacency_matrix)
        and _check_acyclic(adjacency_matrix)
    )
148