Belle II Software  light-2403-persian
metrics.py
1 
8 
9 
10 import torch
11 from torch_scatter import scatter
12 from ignite.metrics import Metric
13 from ignite.exceptions import NotComputableError
14 from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced
15 
16 
class PerfectLCA(Metric, object):
    """
    Computes the rate of perfectly predicted LCAS matrices over a batch.

    ``output_transform`` should return the following items: ``(edge_pred, edge_y, edge_index, u_y, batch, num_graphs)``.

    * ``edge_pred`` must contain edge prediction logits and have shape (num_edges_in_batch, edge_classes);
    * ``edge_y`` must contain edge ground-truth class indices and have shape (num_edges_in_batch,)
      (the same shape as the argmax of ``edge_pred``);
    * ``edge_index`` maps edges to their nodes; ``edge_index[0]`` is used to assign each edge to a node,
      and through ``batch`` to a graph;
    * ``u_y`` is the signal/background class (always 1 in the current setting); it is not used here;
    * ``batch`` maps nodes to their graph;
    * ``num_graphs`` is the number of graphs in a batch (could be derived from ``batch`` also),
      given either as a Python ``int`` or as a one-element tensor.

    An event counts as perfect when none of its non-ignored edges is mispredicted
    (so an event without any edge is counted as perfect).

    .. seealso::
        `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_

    :param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
    :type ignore_index: int or list[int]
    :param output_transform: Function to transform engine's output to desired output.
    :type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
    :param device: torch device specification passed to the ignite ``Metric`` (e.g. ``'cpu'`` or ``'cuda'``).
    :type device: str
    """

    def __init__(self, ignore_index, output_transform, device='cpu'):
        """
        Initialization.
        """
        # Ignore index (always stored as a list)
        self.ignore_index = ignore_index if isinstance(ignore_index, list) else [ignore_index]
        # CPU or GPU; kept as a public attribute, masks are created on the targets' device
        self.device = device
        # Good samples
        self._per_corrects = None
        # Total samples
        self._num_examples = None

        # ignite's Metric.__init__ calls reset(), which initialises the counters above
        super(PerfectLCA, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        """
        Resets counters.
        """
        self._per_corrects = 0
        self._num_examples = 0

        super(PerfectLCA, self).reset()

    @reinit__is_reduced
    def update(self, output):
        """
        Updates counts.

        :param output: ``(edge_pred, edge_y, edge_index, u_y, batch, num_graphs)``, see the class docstring.
        :raises ValueError: if the edge predictions and targets do not have the same shape.
        """
        edge_pred, edge_y, edge_index, _u_y, batch, num_graphs = output

        # int() accepts both a plain int and a one-element tensor
        num_graphs = int(num_graphs)

        # argmax of the logits equals argmax of their softmax, so no softmax is needed
        winners = edge_pred.argmax(dim=1)

        if winners.shape != edge_y.shape:
            raise ValueError('Edge predictions shape does not match target shape')

        # Entries whose target is an ignored class (e.g. padding) can never count as wrong
        keep = torch.ones_like(edge_y, dtype=torch.bool)
        for ig_class in self.ignore_index:
            keep &= (edge_y != ig_class)
        wrong = (winners != edge_y) & keep

        # Assign every edge directly to its graph and count wrong edges per graph.
        # dim_size guarantees exactly one entry per graph, even for graphs (or
        # trailing nodes) without edges.
        edge_graph = batch[edge_index[0]]
        n_wrong = scatter(wrong.long(), edge_graph, dim=0, dim_size=num_graphs, reduce="sum")

        # A graph is perfect when it has zero wrong predictions
        self._per_corrects += (n_wrong == 0).sum().item()
        self._num_examples += num_graphs

    @sync_all_reduce("_per_corrects", "_num_examples")
    def compute(self):
        """
        Final result.

        In distributed runs the two counters are summed over all processes first.

        :return: Fraction of events whose LCAS matrix was predicted perfectly.
        :raises NotComputableError: if no example was seen since the last reset.
        """
        if self._num_examples == 0:
            raise NotComputableError(
                "PerfectLCA must have at least one example before it can be computed."
            )
        return self._per_corrects / self._num_examples
110 
111 
class PerfectMasses(Metric, object):
    """
    Computes the rate of events with perfectly predicted mass hypotheses over a batch.

    ``output_transform`` should return the following items: ``(x_pred, x_y, u_y, batch, num_graphs)``.

    * ``x_pred`` must contain node prediction logits and have shape (num_nodes_in_batch, node_classes);
    * ``x_y`` must contain node ground-truth class indices and have shape (num_nodes_in_batch,)
      (the same shape as the argmax of ``x_pred``);
    * ``u_y`` is the signal/background class (always 1 in the current setting); it is not used here;
    * ``batch`` maps nodes to their graph;
    * ``num_graphs`` is the number of graphs in a batch (could be derived from ``batch`` also),
      given either as a Python ``int`` or as a one-element tensor.

    An event counts as perfect when none of its non-ignored nodes is mispredicted.

    .. seealso::
        `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_

    :param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
    :type ignore_index: int or list[int]
    :param output_transform: Function to transform engine's output to desired output.
    :type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
    :param device: torch device specification passed to the ignite ``Metric`` (e.g. ``'cpu'`` or ``'cuda'``).
    :type device: str
    """

    def __init__(self, ignore_index, output_transform, device='cpu'):
        """
        Initialization.
        """
        # Ignore index (always stored as a list)
        self.ignore_index = ignore_index if isinstance(ignore_index, list) else [ignore_index]
        # CPU or GPU; kept as a public attribute, masks are created on the targets' device
        self.device = device
        # Good samples
        self._per_corrects = None
        # Total samples
        self._num_examples = None

        # ignite's Metric.__init__ calls reset(), which initialises the counters above
        super(PerfectMasses, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        """
        Resets counts.
        """
        self._per_corrects = 0
        self._num_examples = 0

        super(PerfectMasses, self).reset()

    @reinit__is_reduced
    def update(self, output):
        """
        Updates counts.

        :param output: ``(x_pred, x_y, u_y, batch, num_graphs)``, see the class docstring.
        :raises ValueError: if the mass predictions and targets do not have the same shape.
        """
        x_pred, x_y, _u_y, batch, num_graphs = output

        # int() accepts both a plain int and a one-element tensor
        num_graphs = int(num_graphs)

        # argmax of the logits equals argmax of their softmax, so no softmax is needed
        winners = x_pred.argmax(dim=1)

        if winners.shape != x_y.shape:
            raise ValueError('Mass predictions shape does not match target shape')

        # Entries whose target is an ignored class (e.g. padding) can never count as wrong
        keep = torch.ones_like(x_y, dtype=torch.bool)
        for ig_class in self.ignore_index:
            keep &= (x_y != ig_class)
        wrong = (winners != x_y) & keep

        # Count wrong nodes per graph; dim_size guarantees exactly one entry per graph
        n_wrong = scatter(wrong.long(), batch, dim=0, dim_size=num_graphs, reduce="sum")

        # A graph is perfect when it has zero wrong predictions
        self._per_corrects += (n_wrong == 0).sum().item()
        self._num_examples += num_graphs

    @sync_all_reduce("_per_corrects", "_num_examples")
    def compute(self):
        """
        Final computation.

        In distributed runs the two counters are summed over all processes first.

        :return: Fraction of events whose mass hypotheses were predicted perfectly.
        :raises NotComputableError: if no example was seen since the last reset.
        """
        if self._num_examples == 0:
            raise NotComputableError(
                "PerfectMasses must have at least one example before it can be computed."
            )
        return self._per_corrects / self._num_examples
203 
204 
class PerfectEvent(Metric, object):
    """
    Computes the rate of events with perfectly predicted mass hypotheses and LCAS matrices over a batch.

    ``output_transform`` should return the following items:
    ``(x_pred, x_y, edge_pred, edge_y, edge_index, u_y, batch, num_graphs)``.

    * ``x_pred`` must contain node prediction logits and have shape (num_nodes_in_batch, node_classes);
    * ``x_y`` must contain node ground-truth class indices and have shape (num_nodes_in_batch,);
    * ``edge_pred`` must contain edge prediction logits and have shape (num_edges_in_batch, edge_classes);
    * ``edge_y`` must contain edge ground-truth class indices and have shape (num_edges_in_batch,);
    * ``edge_index`` maps edges to their nodes; ``edge_index[0]`` is used to assign each edge to a node,
      and through ``batch`` to a graph;
    * ``u_y`` is the signal/background class (always 1 in the current setting); it is not used here;
    * ``batch`` maps nodes to their graph;
    * ``num_graphs`` is the number of graphs in a batch (could be derived from ``batch`` also),
      given either as a Python ``int`` or as a one-element tensor.

    An event counts as perfect when none of its non-ignored nodes and none of its
    non-ignored edges is mispredicted. The same ``ignore_index`` applies to nodes and edges.

    .. seealso::
        `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_

    :param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
    :type ignore_index: int or list[int]
    :param output_transform: Function to transform engine's output to desired output.
    :type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
    :param device: torch device specification passed to the ignite ``Metric`` (e.g. ``'cpu'`` or ``'cuda'``).
    :type device: str
    """

    def __init__(self, ignore_index, output_transform, device='cpu'):
        """
        Initialization.
        """
        # Ignore index (always stored as a list)
        self.ignore_index = ignore_index if isinstance(ignore_index, list) else [ignore_index]
        # CPU or GPU; kept as a public attribute, masks are created on the targets' device
        self.device = device
        # Good samples
        self._per_corrects = None
        # Total samples
        self._num_examples = None

        # ignite's Metric.__init__ calls reset(), which initialises the counters above
        super(PerfectEvent, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        """
        Resets counts.
        """
        self._per_corrects = 0
        self._num_examples = 0

        super(PerfectEvent, self).reset()

    @reinit__is_reduced
    def update(self, output):
        """
        Updates counts.

        :param output: ``(x_pred, x_y, edge_pred, edge_y, edge_index, u_y, batch, num_graphs)``,
            see the class docstring.
        :raises ValueError: if node or edge predictions do not match their targets' shape.
        """
        x_pred, x_y, edge_pred, edge_y, edge_index, _u_y, batch, num_graphs = output

        # int() accepts both a plain int and a one-element tensor
        num_graphs = int(num_graphs)

        # argmax of the logits equals argmax of their softmax, so no softmax is needed
        x_winners = x_pred.argmax(dim=1)
        edge_winners = edge_pred.argmax(dim=1)

        if x_winners.shape != x_y.shape:
            raise ValueError('Mass predictions shape does not match target shape')
        if edge_winners.shape != edge_y.shape:
            raise ValueError('Edge predictions shape does not match target shape')

        # Entries whose target is an ignored class (e.g. padding) can never count as wrong
        x_keep = torch.ones_like(x_y, dtype=torch.bool)
        edge_keep = torch.ones_like(edge_y, dtype=torch.bool)
        for ig_class in self.ignore_index:
            x_keep &= (x_y != ig_class)
            edge_keep &= (edge_y != ig_class)
        x_wrong = (x_winners != x_y) & x_keep
        edge_wrong = (edge_winners != edge_y) & edge_keep

        # Count wrong nodes and wrong edges per graph. dim_size makes both tensors
        # exactly num_graphs long, so they line up graph by graph.
        n_wrong_x = scatter(x_wrong.long(), batch, dim=0, dim_size=num_graphs, reduce="sum")
        edge_graph = batch[edge_index[0]]
        n_wrong_edge = scatter(edge_wrong.long(), edge_graph, dim=0, dim_size=num_graphs, reduce="sum")

        # A graph is perfect when neither its masses nor its LCAS entries contain an error
        perfect = (n_wrong_x == 0) & (n_wrong_edge == 0)
        self._per_corrects += perfect.sum().item()
        self._num_examples += num_graphs

    @sync_all_reduce("_per_corrects", "_num_examples")
    def compute(self):
        """
        Final computation.

        In distributed runs the two counters are summed over all processes first.

        :return: Fraction of events with perfectly predicted masses and LCAS matrix.
        :raises NotComputableError: if no example was seen since the last reset.
        """
        if self._num_examples == 0:
            raise NotComputableError(
                "PerfectEvent must have at least one example before it can be computed."
            )
        return self._per_corrects / self._num_examples
ignore_index
Ignore index.
Definition: metrics.py:237
_num_examples
Total samples.
Definition: metrics.py:243
def reset(self)
Definition: metrics.py:248
def compute(self)
Definition: metrics.py:304
_per_corrects
Good samples.
Definition: metrics.py:241
def __init__(self, ignore_index, output_transform, device='cpu')
Definition: metrics.py:232
def update(self, output)
Definition: metrics.py:258
device
CPU or GPU.
Definition: metrics.py:239
ignore_index
Ignore index.
Definition: metrics.py:46
_num_examples
Total samples.
Definition: metrics.py:52
def reset(self)
Definition: metrics.py:57
def compute(self)
Definition: metrics.py:101
_per_corrects
Good samples.
Definition: metrics.py:50
def __init__(self, ignore_index, output_transform, device='cpu')
Definition: metrics.py:41
def update(self, output)
Definition: metrics.py:67
device
CPU or GPU.
Definition: metrics.py:48
ignore_index
Ignore index.
Definition: metrics.py:140
_num_examples
Total samples.
Definition: metrics.py:146
def compute(self)
Definition: metrics.py:194
_per_corrects
Good samples.
Definition: metrics.py:144
def __init__(self, ignore_index, output_transform, device='cpu')
Definition: metrics.py:135
def update(self, output)
Definition: metrics.py:161
device
CPU or GPU.
Definition: metrics.py:142