11from torch_scatter
import scatter
12from ignite.metrics
import Metric
13from ignite.exceptions
import NotComputableError
14from ignite.metrics.metric
import sync_all_reduce, reinit__is_reduced
19 Computes the rate of perfectly predicted LCAS matrices over a batch.
21 ``output_transform`` should return the following items: ``(edge_pred, edge_y, edge_index, u_y, batch, num_graphs)``.
23 * ``edge_pred`` must contain edge prediction logits and have shape (num_edges_in_batch, edge_classes);
24 * ``edge_y`` must contain edge ground-truth class indices and have shape (num_edges_in_batch,), i.e. the shape of ``edge_pred.argmax(dim=1)``;
25 * ``edge_index`` maps edges to their nodes;
26 * ``u_y`` is the signal/background class (always 1 in the current setting);
27 * ``batch`` maps nodes to their graph;
28 * ``num_graphs`` is the number of graphs in the batch, given as a one-element tensor (could also be derived from ``batch``).
31 `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_
33 :param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
34 :type ignore_index: int or list[int]
35 :param output_transform: Function to transform the engine's output into the desired output.
36 :type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
37 :param device: Torch device to run on, e.g. ``cpu`` or ``cuda``.
41 def __init__(self, ignore_index, output_transform, device='cpu'):
46 self.ignore_index = ignore_index if isinstance(ignore_index, list)
else [ignore_index]
54 super(PerfectLCA, self).
__init__(output_transform=output_transform, device=device)
64 super(PerfectLCA, self).reset()
71 edge_pred, edge_y, edge_index, u_y, batch, num_graphs = output
73 num_graphs = num_graphs.item()
75 probs = torch.softmax(edge_pred, dim=1)
76 winners = probs.argmax(dim=1)
78 assert winners.shape == edge_y.shape,
'Edge predictions shape does not match target shape'
81 mask = torch.ones(edge_y.size(), dtype=torch.long, device=self.
device)
83 mask &= (edge_y != ig_class)
86 y_pred_mask = winners * mask
87 y_mask = edge_y * mask
90 truth = y_pred_mask.eq(y_mask) + 0
91 truth = scatter(truth, edge_index[0], reduce=
"min")
92 truth = scatter(truth, batch, reduce=
"min")
95 batch_perfect = truth.sum().item()
100 @sync_all_reduce("_perfectLCA")
106 raise NotComputableError(
107 "CustomAccuracy must have at least one example before it can be computed."
114 Computes the rate of events with perfectly predicted mass hypotheses over a batch.
116 ``output_transform`` should return the following items: ``(x_pred, x_y, u_y, batch, num_graphs)``.
118 * ``x_pred`` must contain node prediction logits and have shape (num_nodes_in_batch, node_classes);
119 * ``x_y`` must contain node ground-truth class indices and have shape (num_nodes_in_batch,), i.e. the shape of ``x_pred.argmax(dim=1)``;
120 * ``u_y`` is the signal/background class (always 1 in the current setting);
121 * ``batch`` maps nodes to their graph;
122 * ``num_graphs`` is the number of graphs in the batch, given as a one-element tensor (could also be derived from ``batch``).
125 `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_
127 :param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
128 :type ignore_index: int or list[int]
129 :param output_transform: Function to transform the engine's output into the desired output.
130 :type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
131 :param device: Torch device to run on, e.g. ``cpu`` or ``cuda``.
135 def __init__(self, ignore_index, output_transform, device='cpu'):
140 self.ignore_index = ignore_index if isinstance(ignore_index, list)
else [ignore_index]
148 super(PerfectMasses, self).
__init__(output_transform=output_transform, device=device)
158 super(PerfectMasses, self).reset()
165 x_pred, x_y, u_y, batch, num_graphs = output
167 num_graphs = num_graphs.item()
169 probs = torch.softmax(x_pred, dim=1)
170 winners = probs.argmax(dim=1)
172 assert winners.shape == x_y.shape,
'Mass predictions shape does not match target shape'
175 mask = torch.ones(x_y.size(), dtype=torch.long, device=self.
device)
177 mask &= (x_y != ig_class)
180 y_pred_mask = winners * mask
184 truth = y_pred_mask.eq(y_mask) + 0
185 truth = scatter(truth, batch, reduce=
"min")
188 batch_perfect = truth.sum().item()
193 @sync_all_reduce("_perfectMasses")
199 raise NotComputableError(
200 "CustomAccuracy must have at least one example before it can be computed."
207 Computes the rate of events with perfectly predicted mass hypotheses
and LCAS matrices over a batch.
209 ``output_transform`` should return the following items:
210 ``(x_pred, x_y, edge_pred, edge_y, edge_index, u_y, batch, num_graphs)``.
212 * ``x_pred`` must contain node prediction logits and have shape (num_nodes_in_batch, node_classes);
213 * ``x_y`` must contain node ground-truth class indices and have shape (num_nodes_in_batch,), i.e. the shape of ``x_pred.argmax(dim=1)``;
214 * ``edge_pred`` must contain edge prediction logits and have shape (num_edges_in_batch, edge_classes);
215 * ``edge_y`` must contain edge ground-truth class indices and have shape (num_edges_in_batch,), i.e. the shape of ``edge_pred.argmax(dim=1)``;
216 * ``edge_index`` maps edges to their nodes;
217 * ``u_y`` is the signal/background class (always 1 in the current setting);
218 * ``batch`` maps nodes to their graph;
219 * ``num_graphs`` is the number of graphs in the batch, given as a one-element tensor (could also be derived from ``batch``).
222 `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_
224 :param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
225 :type ignore_index: int or list[int]
226 :param output_transform: Function to transform the engine's output into the desired output.
227 :type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
228 :param device: Torch device to run on, e.g. ``cpu`` or ``cuda``.
232 def __init__(self, ignore_index, output_transform, device='cpu'):
237 self.ignore_index = ignore_index if isinstance(ignore_index, list)
else [ignore_index]
245 super(PerfectEvent, self).
__init__(output_transform=output_transform, device=device)
255 super(PerfectEvent, self).reset()
262 x_pred, x_y, edge_pred, edge_y, edge_index, u_y, batch, num_graphs = output
264 num_graphs = num_graphs.item()
266 x_probs = torch.softmax(x_pred, dim=1)
267 x_winners = x_probs.argmax(dim=1)
268 edge_probs = torch.softmax(edge_pred, dim=1)
269 edge_winners = edge_probs.argmax(dim=1)
271 assert x_winners.shape == x_y.shape,
'Mass predictions shape does not match target shape'
272 assert edge_winners.shape == edge_y.shape,
'Edge predictions shape does not match target shape'
275 x_mask = torch.ones(x_y.size(), dtype=torch.long, device=self.
device)
276 edge_mask = torch.ones(edge_y.size(), dtype=torch.long, device=self.
device)
278 x_mask &= (x_y != ig_class)
279 edge_mask &= (edge_y != ig_class)
282 x_pred_mask = x_winners * x_mask
283 x_mask = x_y * x_mask
284 edge_pred_mask = edge_winners * edge_mask
285 edge_mask = edge_y * edge_mask
289 x_truth = x_pred_mask.eq(x_mask) + 0
290 x_truth = scatter(x_truth, batch, reduce=
"min")
292 edge_truth = edge_pred_mask.eq(edge_mask) + 0
293 edge_truth = scatter(edge_truth, edge_index[0], reduce=
"min")
294 edge_truth = scatter(edge_truth, batch, reduce=
"min")
297 truth = x_truth.bool() & edge_truth.bool()
298 batch_perfect = (truth + 0).sum().item()
303 @sync_all_reduce("_perfectEvent")
309 raise NotComputableError(
310 "CustomAccuracy must have at least one example before it can be computed."
ignore_index
Class index or list of class indices to ignore during the computation.
_num_examples
Total samples.
_per_corrects
Good samples.
def __init__(self, ignore_index, output_transform, device='cpu')
ignore_index
Class index or list of class indices to ignore during the computation.
_num_examples
Total samples.
_per_corrects
Good samples.
def __init__(self, ignore_index, output_transform, device='cpu')
ignore_index
Class index or list of class indices to ignore during the computation.
_num_examples
Total samples.
_per_corrects
Good samples.
def __init__(self, ignore_index, output_transform, device='cpu')