Belle II Software development
metrics.py
1
8
9
10import torch
11from torch_scatter import scatter
12from ignite.metrics import Metric
13from ignite.exceptions import NotComputableError
14from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced
15
16
class PerfectLCA(Metric):
    """
    Computes the rate of perfectly predicted LCAS matrices over a batch.

    ``output_transform`` should return the following items: ``(edge_pred, edge_y, edge_index, u_y, batch, num_graphs)``.

    * ``edge_pred`` must contain edge prediction logits and have shape (num_edges_in_batch, edge_classes);
    * ``edge_y`` must contain edge ground-truth class indices and have shape (num_edges_in_batch,),
      i.e. the same shape as ``edge_pred.argmax(dim=1)``;
    * ``edge_index`` maps edges to their nodes; its first row (``edge_index[0]``) assigns each edge to a node;
    * ``u_y`` is the signal/background class (always 1 in the current setting; not used by this metric);
    * ``batch`` maps nodes to their graph;
    * ``num_graphs`` is the number of graphs in a batch (could be derived from ``batch`` also),
      given as a one-element tensor or a Python int.

    A graph counts as perfect when, for every node, every edge listed for it in ``edge_index[0]``
    is either predicted correctly or has a target in ``ignore_index``.

    .. seealso::
        `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_

    :param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
    :type ignore_index: int or list[int]
    :param output_transform: Function to transform engine's output to desired output.
    :type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
    :param device: Device passed to the ignite ``Metric`` base class, e.g. ``'cpu'`` or ``'cuda'``.
    :type device: str
    """

    def __init__(self, ignore_index, output_transform, device='cpu'):
        """
        Initialization.
        """
        ## Ignore index.
        # Accept a single class or any list/tuple/set of classes; always store a list.
        self.ignore_index = list(ignore_index) if isinstance(ignore_index, (list, tuple, set)) else [ignore_index]
        ## CPU or GPU.
        self.device = device
        ## Good samples.
        self._per_corrects = None
        ## Total samples.
        self._num_examples = None

        # ignite's Metric.__init__ is expected to call reset(), which sets the counters to 0.
        super().__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        """
        Resets counters.
        """
        self._per_corrects = 0
        self._num_examples = 0

        super().reset()

    def _correct_or_ignored(self, winners, target):
        """
        Flag entries whose prediction matches the target or whose target is an ignored class.

        :param winners: Predicted class indices, same shape as ``target``.
        :param target: Ground-truth class indices.
        :return: Long tensor with 1 for correct or ignored entries and 0 otherwise.
        """
        # Built with zeros_like so the mask lives on the target's device (not necessarily self.device).
        ignored = torch.zeros_like(target, dtype=torch.bool)
        for ig_class in self.ignore_index:
            ignored |= target == ig_class
        return ((winners == target) | ignored).long()

    @reinit__is_reduced
    def update(self, output):
        """
        Updates counts.
        """
        edge_pred, edge_y, edge_index, u_y, batch, num_graphs = output

        winners = torch.softmax(edge_pred, dim=1).argmax(dim=1)

        assert winners.shape == edge_y.shape, 'Edge predictions shape does not match target shape'

        # (N_edges) 1 where the edge is correct or ignored (e.g. padding), 0 otherwise
        truth = self._correct_or_ignored(winners, edge_y)
        # Per node: minimum over its edges. dim_size keeps one entry per node even when the last
        # nodes have no edges, so the tensor lines up with ``batch`` below.
        # NOTE(review): torch_scatter fills empty groups with 0, so an edgeless node marks its
        # graph as imperfect (as non-trailing edgeless nodes already did) — confirm this is intended.
        truth = scatter(truth, edge_index[0], dim_size=batch.size(0), reduce="min")
        # Per graph: minimum over its nodes
        truth = scatter(truth, batch, reduce="min")

        # Number of graphs with zero wrong predictions in this batch
        self._per_corrects += truth.sum().item()
        # int() accepts both a one-element tensor and a plain Python int
        self._num_examples += int(num_graphs)

    @sync_all_reduce("_per_corrects", "_num_examples")
    def compute(self):
        """
        Final result.

        :return: Fraction of graphs with a perfectly predicted LCAS matrix since the last reset.
        :raises NotComputableError: If no example has been seen since the last reset.
        """
        if self._num_examples == 0:
            raise NotComputableError(
                f"{self.__class__.__name__} must have at least one example before it can be computed."
            )
        return self._per_corrects / self._num_examples
110
111
class PerfectMasses(Metric):
    """
    Computes the rate of events with perfectly predicted mass hypotheses over a batch.

    ``output_transform`` should return the following items: ``(x_pred, x_y, u_y, batch, num_graphs)``.

    * ``x_pred`` must contain node prediction logits and have shape (num_nodes_in_batch, node_classes);
    * ``x_y`` must contain node ground-truth class indices and have shape (num_nodes_in_batch,),
      i.e. the same shape as ``x_pred.argmax(dim=1)``;
    * ``u_y`` is the signal/background class (always 1 in the current setting; not used by this metric);
    * ``batch`` maps nodes to their graph;
    * ``num_graphs`` is the number of graphs in a batch (could be derived from ``batch`` also),
      given as a one-element tensor or a Python int.

    A graph counts as perfect when every one of its nodes is either predicted correctly or has a
    target in ``ignore_index``.

    .. seealso::
        `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_

    :param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
    :type ignore_index: int or list[int]
    :param output_transform: Function to transform engine's output to desired output.
    :type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
    :param device: Device passed to the ignite ``Metric`` base class, e.g. ``'cpu'`` or ``'cuda'``.
    :type device: str
    """

    def __init__(self, ignore_index, output_transform, device='cpu'):
        """
        Initialization.
        """
        ## Ignore index.
        # Accept a single class or any list/tuple/set of classes; always store a list.
        self.ignore_index = list(ignore_index) if isinstance(ignore_index, (list, tuple, set)) else [ignore_index]
        ## CPU or GPU.
        self.device = device
        ## Good samples.
        self._per_corrects = None
        ## Total samples.
        self._num_examples = None

        # ignite's Metric.__init__ is expected to call reset(), which sets the counters to 0.
        super().__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        """
        Resets counts.
        """
        self._per_corrects = 0
        self._num_examples = 0

        super().reset()

    def _correct_or_ignored(self, winners, target):
        """
        Flag entries whose prediction matches the target or whose target is an ignored class.

        :param winners: Predicted class indices, same shape as ``target``.
        :param target: Ground-truth class indices.
        :return: Long tensor with 1 for correct or ignored entries and 0 otherwise.
        """
        # Built with zeros_like so the mask lives on the target's device (not necessarily self.device).
        ignored = torch.zeros_like(target, dtype=torch.bool)
        for ig_class in self.ignore_index:
            ignored |= target == ig_class
        return ((winners == target) | ignored).long()

    @reinit__is_reduced
    def update(self, output):
        """
        Updates counts.
        """
        x_pred, x_y, u_y, batch, num_graphs = output

        winners = torch.softmax(x_pred, dim=1).argmax(dim=1)

        assert winners.shape == x_y.shape, 'Mass predictions shape does not match target shape'

        # (N_nodes) 1 where the mass hypothesis is correct or ignored (e.g. padding), 0 otherwise
        truth = self._correct_or_ignored(winners, x_y)
        # Per graph: minimum over its nodes
        truth = scatter(truth, batch, reduce="min")

        # Number of graphs with zero wrong predictions in this batch
        self._per_corrects += truth.sum().item()
        # int() accepts both a one-element tensor and a plain Python int
        self._num_examples += int(num_graphs)

    @sync_all_reduce("_per_corrects", "_num_examples")
    def compute(self):
        """
        Final computation.

        :return: Fraction of graphs with perfectly predicted mass hypotheses since the last reset.
        :raises NotComputableError: If no example has been seen since the last reset.
        """
        if self._num_examples == 0:
            raise NotComputableError(
                f"{self.__class__.__name__} must have at least one example before it can be computed."
            )
        return self._per_corrects / self._num_examples
203
204
class PerfectEvent(Metric):
    """
    Computes the rate of events with perfectly predicted mass hypotheses and LCAS matrices over a batch.

    ``output_transform`` should return the following items:
    ``(x_pred, x_y, edge_pred, edge_y, edge_index, u_y, batch, num_graphs)``.

    * ``x_pred`` must contain node prediction logits and have shape (num_nodes_in_batch, node_classes);
    * ``x_y`` must contain node ground-truth class indices and have shape (num_nodes_in_batch,),
      i.e. the same shape as ``x_pred.argmax(dim=1)``;
    * ``edge_pred`` must contain edge prediction logits and have shape (num_edges_in_batch, edge_classes);
    * ``edge_y`` must contain edge ground-truth class indices and have shape (num_edges_in_batch,),
      i.e. the same shape as ``edge_pred.argmax(dim=1)``;
    * ``edge_index`` maps edges to their nodes; its first row (``edge_index[0]``) assigns each edge to a node;
    * ``u_y`` is the signal/background class (always 1 in the current setting; not used by this metric);
    * ``batch`` maps nodes to their graph;
    * ``num_graphs`` is the number of graphs in a batch (could be derived from ``batch`` also),
      given as a one-element tensor or a Python int.

    A graph counts as perfect only when both its mass hypotheses and its LCAS matrix are perfect
    (every node/edge predicted correctly or having a target in ``ignore_index``).

    .. seealso::
        `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_

    :param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
    :type ignore_index: int or list[int]
    :param output_transform: Function to transform engine's output to desired output.
    :type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
    :param device: Device passed to the ignite ``Metric`` base class, e.g. ``'cpu'`` or ``'cuda'``.
    :type device: str
    """

    def __init__(self, ignore_index, output_transform, device='cpu'):
        """
        Initialization.
        """
        ## Ignore index.
        # Accept a single class or any list/tuple/set of classes; always store a list.
        self.ignore_index = list(ignore_index) if isinstance(ignore_index, (list, tuple, set)) else [ignore_index]
        ## CPU or GPU.
        self.device = device
        ## Good samples.
        self._per_corrects = None
        ## Total samples.
        self._num_examples = None

        # ignite's Metric.__init__ is expected to call reset(), which sets the counters to 0.
        super().__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        """
        Resets counts.
        """
        self._per_corrects = 0
        self._num_examples = 0

        super().reset()

    def _correct_or_ignored(self, winners, target):
        """
        Flag entries whose prediction matches the target or whose target is an ignored class.

        :param winners: Predicted class indices, same shape as ``target``.
        :param target: Ground-truth class indices.
        :return: Long tensor with 1 for correct or ignored entries and 0 otherwise.
        """
        # Built with zeros_like so the mask lives on the target's device (not necessarily self.device).
        ignored = torch.zeros_like(target, dtype=torch.bool)
        for ig_class in self.ignore_index:
            ignored |= target == ig_class
        return ((winners == target) | ignored).long()

    @reinit__is_reduced
    def update(self, output):
        """
        Updates counts.
        """
        x_pred, x_y, edge_pred, edge_y, edge_index, u_y, batch, num_graphs = output

        x_winners = torch.softmax(x_pred, dim=1).argmax(dim=1)
        edge_winners = torch.softmax(edge_pred, dim=1).argmax(dim=1)

        assert x_winners.shape == x_y.shape, 'Mass predictions shape does not match target shape'
        assert edge_winners.shape == edge_y.shape, 'Edge predictions shape does not match target shape'

        # Masses: per node correct-or-ignored, then minimum over the nodes of each graph
        x_truth = self._correct_or_ignored(x_winners, x_y)
        x_truth = scatter(x_truth, batch, reduce="min")

        # Edges: per edge correct-or-ignored, minimum per node, then minimum per graph.
        # dim_size keeps one entry per node even when the last nodes have no edges, so the tensor
        # lines up with ``batch``.
        # NOTE(review): torch_scatter fills empty groups with 0, so an edgeless node marks its
        # graph as imperfect (as non-trailing edgeless nodes already did) — confirm this is intended.
        edge_truth = self._correct_or_ignored(edge_winners, edge_y)
        edge_truth = scatter(edge_truth, edge_index[0], dim_size=batch.size(0), reduce="min")
        edge_truth = scatter(edge_truth, batch, reduce="min")

        # A graph is perfect only if both its masses and its edges are perfect
        truth = x_truth.bool() & edge_truth.bool()
        self._per_corrects += truth.sum().item()
        # int() accepts both a one-element tensor and a plain Python int
        self._num_examples += int(num_graphs)

    @sync_all_reduce("_per_corrects", "_num_examples")
    def compute(self):
        """
        Final computation.

        :return: Fraction of perfectly predicted graphs (masses and LCAS) since the last reset.
        :raises NotComputableError: If no example has been seen since the last reset.
        """
        if self._num_examples == 0:
            raise NotComputableError(
                f"{self.__class__.__name__} must have at least one example before it can be computed."
            )
        return self._per_corrects / self._num_examples
ignore_index
Ignore index.
Definition: metrics.py:237
_num_examples
Total samples.
Definition: metrics.py:243
def reset(self)
Definition: metrics.py:248
def compute(self)
Definition: metrics.py:304
_per_corrects
Good samples.
Definition: metrics.py:241
def __init__(self, ignore_index, output_transform, device='cpu')
Definition: metrics.py:232
def update(self, output)
Definition: metrics.py:258
device
CPU or GPU.
Definition: metrics.py:239
ignore_index
Ignore index.
Definition: metrics.py:46
_num_examples
Total samples.
Definition: metrics.py:52
def reset(self)
Definition: metrics.py:57
def compute(self)
Definition: metrics.py:101
_per_corrects
Good samples.
Definition: metrics.py:50
def __init__(self, ignore_index, output_transform, device='cpu')
Definition: metrics.py:41
def update(self, output)
Definition: metrics.py:67
device
CPU or GPU.
Definition: metrics.py:48
ignore_index
Ignore index.
Definition: metrics.py:140
_num_examples
Total samples.
Definition: metrics.py:146
def compute(self)
Definition: metrics.py:194
_per_corrects
Good samples.
Definition: metrics.py:144
def __init__(self, ignore_index, output_transform, device='cpu')
Definition: metrics.py:135
def update(self, output)
Definition: metrics.py:161
device
CPU or GPU.
Definition: metrics.py:142