Belle II Software development
PerfectEvent Class Reference

Public Member Functions

 __init__ (self, ignore_index, output_transform, device='cpu')
 
 reset (self)
 
 update (self, output)
 
 compute (self)
 

Public Attributes

 ignore_index = ignore_index if isinstance(ignore_index, list) else [ignore_index]
 Class indices to ignore, always stored as a list.
 
 device = device
 Torch device used for the internal masks (e.g. ``'cpu'``).
 

Protected Attributes

int _per_corrects = None
 Number of perfectly predicted events.
 
int _num_examples = None
 Total number of events (graphs) seen.
 

Detailed Description

Computes the rate of events with perfectly predicted mass hypotheses and LCAS matrices over a batch.

``output_transform`` should return the following items: ``(x_pred, x_y, edge_pred, edge_y, edge_index, u_y, batch, num_graphs)``.

* ``x_pred`` must contain node prediction logits and have shape (num_nodes_in_batch, node_classes);
* ``x_y`` must contain node ground-truth class indices and have shape (num_nodes_in_batch,), matching the argmax of ``x_pred``;
* ``edge_pred`` must contain edge prediction logits and have shape (num_edges_in_batch, edge_classes);
* ``edge_y`` must contain edge ground-truth class indices and have shape (num_edges_in_batch,), matching the argmax of ``edge_pred``;
* ``edge_index`` maps edges to their nodes; its first row, ``edge_index[0]``, is used to assign each edge to a node and hence to a graph;
* ``u_y`` is the signal/background class (always 1 in the current setting);
* ``batch`` maps nodes to their graph;
* ``num_graphs`` is the number of graphs in a batch (it could also be derived from ``batch``).

An event counts as perfect only if every node and every edge whose target class is not in ``ignore_index`` is predicted correctly.

.. seealso:: `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_

:param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
:type ignore_index: list[int]
:param output_transform: Function to transform the engine's output to the desired output.
:type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
:param device: Torch device used for the internal masks, e.g. ``'cpu'`` or ``'cuda'``.
:type device: str
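The perfect-event criterion above can be illustrated without PyTorch. The following framework-free sketch assumes the logits have already been reduced to class indices with argmax, and it uses plain lists in place of tensors. The function name `perfect_event_rate` is hypothetical and is not part of the Belle II API. The sketch assigns each edge directly to the graph of its source node instead of using the two scatter-min steps in `update()`. Edge cases, such as nodes without outgoing edges, are therefore not guaranteed to match the original.

```python
from typing import Iterable, List, Sequence


def perfect_event_rate(
    x_winners: Sequence[int],
    x_y: Sequence[int],
    edge_winners: Sequence[int],
    edge_y: Sequence[int],
    edge_index: Sequence[Sequence[int]],
    batch: Sequence[int],
    num_graphs: int,
    ignore_index: Iterable[int] = (),
) -> float:
    """Hypothetical helper: fraction of graphs whose non-ignored node and edge predictions are all correct."""
    if num_graphs == 0:
        raise ValueError("at least one graph is required")
    ignored = set(ignore_index)
    perfect: List[bool] = [True] * num_graphs

    # A wrong, non-ignored node prediction spoils the graph the node belongs to.
    for node, (pred, true) in enumerate(zip(x_winners, x_y)):
        if true not in ignored and pred != true:
            perfect[batch[node]] = False

    # A wrong, non-ignored edge prediction spoils the graph of its source node.
    sources = edge_index[0]
    for edge, (pred, true) in enumerate(zip(edge_winners, edge_y)):
        if true not in ignored and pred != true:
            perfect[batch[sources[edge]]] = False

    return sum(perfect) / num_graphs


rate = perfect_event_rate(
    x_winners=[1, 2, 1, 1], x_y=[1, 2, 1, 2],
    edge_winners=[3, 3, 0, 0], edge_y=[3, 3, 0, 0],
    edge_index=[[0, 1, 2, 3], [1, 0, 3, 2]],
    batch=[0, 0, 1, 1], num_graphs=2,
)
print(rate)  # 0.5: node 3 is wrong, so only graph 0 is perfect
```

Passing `ignore_index=[2]` in the same call makes node 3 irrelevant, so both graphs count as perfect and the rate becomes 1.0.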

Definition at line 205 of file metrics.py.

Constructor & Destructor Documentation

◆ __init__()

__init__ ( self,
ignore_index,
output_transform,
device = 'cpu' )
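Initialization: normalizes ``ignore_index`` to a list, stores ``device``, sets both counters to ``None`` and calls the base-class constructor.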

Definition at line 232 of file metrics.py.

232 def __init__(self, ignore_index, output_transform, device='cpu'):
233     """
234     Initialization.
235     """
236
237     self.ignore_index = ignore_index if isinstance(ignore_index, list) else [ignore_index]
238
239     self.device = device
240
241     self._per_corrects = None
242
243     self._num_examples = None
244
245     super(PerfectEvent, self).__init__(output_transform=output_transform, device=device)
246
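You can check how the constructor normalizes ``ignore_index`` on its own. This minimal sketch copies the expression from line 237 into a hypothetical helper, `normalize_ignore_index`. Only a ``list`` passes through unchanged. Any other value is wrapped as a single element, and that includes a tuple of classes.

```python
from typing import Any, List


def normalize_ignore_index(ignore_index: Any) -> List[Any]:
    """Hypothetical helper using the same expression as PerfectEvent.__init__."""
    return ignore_index if isinstance(ignore_index, list) else [ignore_index]


print(normalize_ignore_index(-1))        # [-1]
print(normalize_ignore_index([-1, 5]))   # [-1, 5]
print(normalize_ignore_index((-1, 5)))   # [(-1, 5)]: the tuple is wrapped, not unpacked
```

This matches the documented ``list[int]`` type. When several classes must be ignored, pass them as a list rather than a tuple.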

Member Function Documentation

◆ compute()

compute ( self)
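Final computation: returns the fraction of events whose predictions were all correct, and raises ``NotComputableError`` if no events have been seen.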

Definition at line 304 of file metrics.py.

304 def compute(self):
305     """
306     Final computation.
307     """
308     if self._num_examples == 0:
309         raise NotComputableError(
310             "PerfectEvent must have at least one example before it can be computed."
311         )
312     return self._per_corrects / self._num_examples
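The counter logic behind ``reset()``, ``update()`` and ``compute()`` can be sketched as a small accumulator. This is a framework-free illustration, not the Ignite ``Metric`` interface. The class name ``PerfectRateAccumulator`` is hypothetical, and a plain ``RuntimeError`` stands in for Ignite's ``NotComputableError``.

```python
class PerfectRateAccumulator:
    """Hypothetical stand-in mirroring the counters of PerfectEvent."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Same starting point as PerfectEvent.reset(): both counters at zero.
        self._per_corrects = 0
        self._num_examples = 0

    def update(self, batch_perfect: int, num_graphs: int) -> None:
        # Accumulate perfect events and total events for one batch.
        self._per_corrects += batch_perfect
        self._num_examples += num_graphs

    def compute(self) -> float:
        # Refuse to divide by zero, as PerfectEvent.compute() does.
        if self._num_examples == 0:
            raise RuntimeError("at least one example is required")
        return self._per_corrects / self._num_examples


acc = PerfectRateAccumulator()
acc.update(batch_perfect=3, num_graphs=4)
acc.update(batch_perfect=1, num_graphs=4)
print(acc.compute())  # 0.5
```

The counters are summed across batches before dividing. The result is therefore the rate over all events, not the mean of per-batch rates, and the two differ when batches have different sizes. For example, batches of (1 perfect / 1 event) and (0 perfect / 3 events) give 0.25 here, while the per-batch mean would be 0.5.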

◆ reset()

reset ( self)
Resets both counters to zero and calls the base-class ``reset()``.

Definition at line 248 of file metrics.py.

248 def reset(self):
249     """
250     Resets counts.
251     """
252     self._per_corrects = 0
253     self._num_examples = 0
254
255     super(PerfectEvent, self).reset()
256

◆ update()

update ( self,
output )
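Updates counts: adds the number of perfectly predicted events in the batch to ``_per_corrects`` and the number of graphs in the batch to ``_num_examples``.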

Definition at line 258 of file metrics.py.

258 def update(self, output):
259     """
260     Updates counts.
261     """
262     x_pred, x_y, edge_pred, edge_y, edge_index, u_y, batch, num_graphs = output
263
264     num_graphs = num_graphs.item()
265
266     x_probs = torch.softmax(x_pred, dim=1)
267     x_winners = x_probs.argmax(dim=1)
268     edge_probs = torch.softmax(edge_pred, dim=1)
269     edge_winners = edge_probs.argmax(dim=1)
270
271     assert x_winners.shape == x_y.shape, 'Mass predictions shape does not match target shape'
272     assert edge_winners.shape == edge_y.shape, 'Edge predictions shape does not match target shape'
273
274     # Create masks that are 0 where the target is an ignored class (e.g. padding) and 1 elsewhere
275     x_mask = torch.ones(x_y.size(), dtype=torch.long, device=self.device)
276     edge_mask = torch.ones(edge_y.size(), dtype=torch.long, device=self.device)
277     for ig_class in self.ignore_index:
278         x_mask &= (x_y != ig_class)
279         edge_mask &= (edge_y != ig_class)
280
281     # Zero the ignored entries in both the predictions and the targets
282     x_pred_mask = x_winners * x_mask
283     x_mask = x_y * x_mask
284     edge_pred_mask = edge_winners * edge_mask
285     edge_mask = edge_y * edge_mask
286
287     # (N) compare the masked predictions with the targets; ignored entries compare equal due to masking
288     # Masses
289     x_truth = x_pred_mask.eq(x_mask) + 0  # +0 so it's not bool but 0 and 1
290     x_truth = scatter(x_truth, batch, reduce="min")
291     # Edges
292     edge_truth = edge_pred_mask.eq(edge_mask) + 0  # +0 so it's not bool but 0 and 1
293     edge_truth = scatter(edge_truth, edge_index[0], reduce="min")
294     edge_truth = scatter(edge_truth, batch, reduce="min")
295
296     # Count the events with no wrong predictions across the batch
297     truth = x_truth.bool() & edge_truth.bool()
298     batch_perfect = (truth + 0).sum().item()
299
300     self._per_corrects += batch_perfect
301     self._num_examples += num_graphs
302
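The masking and min-reduction steps in ``update()`` can be reproduced on plain lists. This sketch assumes class indices have already been obtained with argmax. The hypothetical helper ``segment_min`` plays the role of ``scatter(..., reduce="min")``. It assumes that segments receiving no entries are filled with 0, and it takes the output size explicitly. Both points are assumptions about the scatter call, not facts taken from this page.

```python
from typing import List, Optional, Sequence


def masked_equal(pred: Sequence[int], target: Sequence[int], ignore: Sequence[int]) -> List[int]:
    """1 where pred matches target or the target class is ignored, else 0."""
    mask = [0 if t in ignore else 1 for t in target]
    # Multiplying both sides by the mask turns every ignored entry into 0 == 0.
    return [int(p * m == t * m) for p, t, m in zip(pred, target, mask)]


def segment_min(values: Sequence[int], index: Sequence[int], size: int) -> List[int]:
    """Hypothetical per-segment minimum; empty segments become 0 (assumed scatter behaviour)."""
    out: List[Optional[int]] = [None] * size
    for value, seg in zip(values, index):
        current = out[seg]
        out[seg] = value if current is None else min(current, value)
    return [0 if v is None else v for v in out]


# Two graphs: nodes 0-1 belong to graph 0, nodes 2-3 to graph 1; class 0 is padding.
batch = [0, 0, 1, 1]
edge_source = [0, 1, 2, 3]  # plays the role of edge_index[0]
x_truth = masked_equal(pred=[1, 2, 1, 1], target=[1, 2, 0, 2], ignore=[0])
edge_truth = masked_equal(pred=[3, 1, 2, 2], target=[3, 1, 2, 2], ignore=[0])

x_per_graph = segment_min(x_truth, batch, size=2)             # [1, 0]
edge_per_node = segment_min(edge_truth, edge_source, size=4)  # [1, 1, 1, 1]
edge_per_graph = segment_min(edge_per_node, batch, size=2)    # [1, 1]

perfect = [a & b for a, b in zip(x_per_graph, edge_per_graph)]
print(sum(perfect))  # 1 perfect event out of 2
```

Node 2 is mispredicted, but its target is padding, so it still counts as correct. Node 3 is genuinely wrong, and the per-graph minimum propagates that single 0 to graph 1.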

Member Data Documentation

◆ _num_examples

int _num_examples = None
protected

Total number of events (graphs) seen.

Definition at line 243 of file metrics.py.

◆ _per_corrects

int _per_corrects = None
protected

Number of perfectly predicted events.

Definition at line 241 of file metrics.py.

◆ device

device = device

Torch device used for the internal masks (e.g. ``'cpu'``).

Definition at line 239 of file metrics.py.

◆ ignore_index

ignore_index = ignore_index if isinstance(ignore_index, list) else [ignore_index]

Class indices to ignore, always stored as a list.

Definition at line 237 of file metrics.py.


The documentation for this class was generated from the following file: