def __init__(self, ignore_index, output_transform, device='cpu'):
    """Configure which target classes are excluded and initialise the parent metric.

    Args:
        ignore_index: A single class index, or a list / tuple / set of class
            indices. Target entries equal to any of these classes are masked
            out before predictions are compared against targets.
        output_transform: Forwarded unchanged to the parent metric's
            ``__init__``.
        device: Forwarded unchanged to the parent metric's ``__init__``;
            defaults to ``'cpu'``.
    """
    # Always store a list so the ignored classes can be iterated uniformly.
    # A list argument is kept as the very same object (unchanged behaviour).
    # Other collections are expanded element-wise. Wrapping a tuple/set as a
    # single element would make the mask compare targets against the whole
    # collection instead of against each class.
    if isinstance(ignore_index, list):
        self.ignore_index = ignore_index
    elif isinstance(ignore_index, (tuple, set, frozenset)):
        self.ignore_index = list(ignore_index)
    else:
        self.ignore_index = [ignore_index]
    # Set attributes before calling the parent constructor: the parent
    # presumably calls reset() during __init__ (ignite.metrics.Metric does),
    # so keep this ordering.
    super(PerfectEvent, self).__init__(output_transform=output_transform, device=device)
262 x_pred, x_y, edge_pred, edge_y, edge_index, u_y, batch, num_graphs = output
264 num_graphs = num_graphs.item()
266 x_probs = torch.softmax(x_pred, dim=1)
267 x_winners = x_probs.argmax(dim=1)
268 edge_probs = torch.softmax(edge_pred, dim=1)
269 edge_winners = edge_probs.argmax(dim=1)
271 assert x_winners.shape == x_y.shape,
'Mass predictions shape does not match target shape'
272 assert edge_winners.shape == edge_y.shape,
'Edge predictions shape does not match target shape'
275 x_mask = torch.ones(x_y.size(), dtype=torch.long, device=self.
device)
276 edge_mask = torch.ones(edge_y.size(), dtype=torch.long, device=self.
device)
278 x_mask &= (x_y != ig_class)
279 edge_mask &= (edge_y != ig_class)
282 x_pred_mask = x_winners * x_mask
283 x_mask = x_y * x_mask
284 edge_pred_mask = edge_winners * edge_mask
285 edge_mask = edge_y * edge_mask
289 x_truth = x_pred_mask.eq(x_mask) + 0
290 x_truth = scatter(x_truth, batch, reduce=
"min")
292 edge_truth = edge_pred_mask.eq(edge_mask) + 0
293 edge_truth = scatter(edge_truth, edge_index[0], reduce=
"min")
294 edge_truth = scatter(edge_truth, batch, reduce=
"min")
297 truth = x_truth.bool() & edge_truth.bool()
298 batch_perfect = (truth + 0).sum().item()