Belle II Software  light-2403-persian
tree_utils.py
1 
8 
9 
10 import torch as t
11 import numpy as np
12 
13 
def masses_to_classes(array):
    r"""
    Map PDG mass-hypothesis codes to the integer classes used in cross-entropy computation.

    The sign of every code is discarded before matching, so a code and its negative
    end up in the same class:

    .. math::
        e \to 1\\
        \mu \to 2\\
        \pi \to 3\\
        K \to 4\\
        p \to 5\\
        \gamma \to 6\\
        \text{others} \to 0

    Args:
        array (numpy.ndarray): Array containing PDG mass codes.

    Returns:
        numpy.ndarray: Newly allocated array of the same shape holding the class of each
        entry; the input array itself is not modified.
    """
    # Negating the magnitudes moves every code to zero or below, so the positive class
    # labels written in the loop can never be mistaken for a code still to be matched.
    codes = -1 * np.abs(array)
    pdg_to_class = {
        -11: 1,  # electrons
        -13: 2,  # muons
        -211: 3,  # pions
        -321: 4,  # kaons
        -2212: 5,  # protons
        -22: 6,  # photons
    }
    for pdg_code, label in pdg_to_class.items():
        codes[codes == pdg_code] = label
    # Anything still non-positive matched none of the codes above.
    codes[codes <= 0] = 0
    return codes
45 
46 
def _check_undirected(adjacency_matrix):
    """
    Report whether an adjacency matrix-encoded graph is undirected, i.e. square and symmetric.

    Returns:
        False for a non-square matrix; otherwise the element-wise comparison of the
        matrix with its transpose, reduced with ``.all()``.
    """
    rows, cols = adjacency_matrix.shape
    # ``and`` short-circuits, so the transpose comparison only runs for square matrices.
    return rows == cols and (adjacency_matrix == adjacency_matrix.T).all()
56 
57 
def _connectedness_dfs(adjacency_matrix, index, reached):
    """
    Depth-first search used by the connectedness check.

    Starting from the node ``index``, every node reachable through non-zero entries of
    ``adjacency_matrix`` is marked with 1 in the ``reached`` mask, which is modified in
    place. If the graph is not connected, ``reached`` still contains zero elements afterwards.

    An explicit stack is used instead of recursion so that long chains of nodes cannot
    exceed Python's recursion limit.
    """
    n = adjacency_matrix.shape[0]
    reached[index] = 1
    stack = [index]

    while stack:
        node = stack.pop()
        for column in range(n):
            # Follow edges to neighbours not marked yet; marking on push guarantees each
            # node enters the stack at most once.
            if adjacency_matrix[node, column] != 0 and not reached[column]:
                reached[column] = 1
                stack.append(column)


def _check_connectedness(adjacency_matrix, allow_disconnected_leaves=False):
    """
    Checks whether the graph encoded by an adjacency matrix is connected, i.e. whether every
    node can be reached from a single start node by following non-zero entries.

    Args:
        adjacency_matrix: 2-dimensional (N, N) matrix; non-zero entries mark linked nodes.
        allow_disconnected_leaves (bool): If True, nodes whose row has no non-zero entry
            (isolated nodes) are ignored; only the remaining nodes must be connected.

    Returns:
        False for a non-square matrix; otherwise a 0-dim ``torch`` tensor that is truthy
        if the graph is connected. A graph without nodes counts as connected.
    """
    n, m = adjacency_matrix.shape
    if n != m:
        return False

    reached = t.zeros(n, dtype=t.uint8)
    if n == 0:
        # No node can be unreachable; this also avoids indexing node 0 of an empty mask.
        return reached.all()

    start = 0
    isolated = None
    if allow_disconnected_leaves:
        # Count linked entries instead of summing values, so rows whose non-zero entries
        # cancel out are not mistaken for isolated nodes. ``as_tensor`` keeps this usable
        # with ``t.logical_or`` when a NumPy matrix is passed.
        isolated = t.as_tensor((adjacency_matrix != 0).sum(axis=1) == 0)
        linked = t.nonzero(~isolated).flatten()
        # Start at a node that has edges: starting at an isolated node 0 would leave the
        # rest of the graph unreached and wrongly report it as disconnected.
        if linked.numel() > 0:
            start = int(linked[0])

    _connectedness_dfs(adjacency_matrix, start, reached)

    if isolated is not None:
        reached = t.logical_or(reached, isolated)

    return reached.all()
90 
91 
def _acyclic_dfs(adjacency_matrix, index, parent, reached):
    """
    Depth-first search for cycles in the part of the graph reachable from the node ``index``.

    Visited nodes are marked with 1 in the ``reached`` mask, which is modified in place.
    Reaching an already-marked node through any edge other than the one back to the node
    it was discovered from indicates a cycle. Self-references (diagonal entries) are ignored.

    An explicit stack is used instead of recursion so that long chains of nodes cannot
    exceed Python's recursion limit.

    Args:
        adjacency_matrix: 2-dimensional (N, N) matrix; non-zero entries mark linked nodes.
        index (int): Node to start the search from.
        parent (int): Node that ``index`` was reached from, or -1 for a root.
        reached: Mask of length N, updated in place.

    Returns:
        bool: False if a cycle was found, True otherwise.
    """
    n = adjacency_matrix.shape[0]
    reached[index] = 1
    # Each stack entry is a node together with the node it was discovered from.
    stack = [(index, parent)]

    while stack:
        node, node_parent = stack.pop()
        for row in range(n):
            # the passed adjacency matrix may contain self-references;
            # while technically not acyclic, these are allowed
            if row == node or adjacency_matrix[node, row] == 0:
                continue
            if not reached[row]:
                # Newly discovered node: continue the search from it.
                reached[row] = 1
                stack.append((row, node))
            elif row != node_parent:
                # Already-visited node reached via an edge other than the parent link: cycle.
                return False
    return True


def _check_acyclic(adjacency_matrix):
    """
    Checks whether the graph encoded by the passed adjacency matrix is acyclic, i.e. all non-empty trails in the graph
    do not contain repetitions. Node self-references are legal and simply ignored.

    A new search is started from every node not reached by an earlier one, so cycles are
    also found in parts of the graph not linked to node 0. A graph without nodes is acyclic.

    Returns:
        bool: False for a non-square matrix or if a cycle is found, True otherwise.
    """
    n, m = adjacency_matrix.shape
    if n != m:
        return False

    reached = t.zeros(n, dtype=t.uint8)
    for root in range(n):
        if not reached[root] and not _acyclic_dfs(adjacency_matrix, root, -1, reached):
            return False
    return True
130 
131 
def is_valid_tree(adjacency_matrix):
    """
    Checks whether the graph encoded by the passed adjacency matrix encodes a valid tree,
    i.e. an undirected, acyclic and connected graph.

    Args:
        adjacency_matrix (numpy.ndarray): 2-dimensional matrix (N, N) encoding the graph's node adjacencies.
            Linked nodes should have value unequal to zero.

    Returns:
        bool: True if the encoded graph is a tree, False otherwise.
    """
    # The helpers return a mix of Python, NumPy and torch booleans; ``bool`` makes sure
    # callers always get a plain bool as documented. ``and`` short-circuits, so the
    # Python-level graph searches only run after the vectorised symmetry check passed.
    return bool(
        _check_undirected(adjacency_matrix)
        and _check_connectedness(adjacency_matrix)
        and _check_acyclic(adjacency_matrix)
    )