Belle II Software light-2406-ragdoll
PerfectMasses Class Reference

Public Member Functions

def __init__ (self, ignore_index, output_transform, device='cpu')
 
def reset (self)
 
def update (self, output)
 
def compute (self)
 

Public Attributes

 ignore_index
 List of class indices excluded from the computation.
 
 device
 Device used for the metric's tensors (e.g. ``cpu`` or ``cuda``).
 

Protected Attributes

 _per_corrects
 Number of events with all mass hypotheses predicted correctly.
 
 _num_examples
 Total number of events processed.
 

Detailed Description

Computes the fraction of events in which the mass hypotheses of all (non-ignored) nodes are predicted correctly, accumulated over the batches seen since the last reset.

``output_transform`` should return the following items: ``(x_pred, x_y, u_y, batch, num_graphs)``.

* ``x_pred`` must contain node prediction logits and have shape (num_nodes_in_batch, node_classes);
* ``x_y`` must contain node ground-truth class indices and have shape (num_nodes_in_batch,), the same shape as the ``argmax`` of ``x_pred`` along dimension 1;
* ``u_y`` is the signal/background class (always 1 in the current setting); it is not used by this metric;
* ``batch`` maps each node to its graph;
* ``num_graphs`` is the number of graphs in the batch, given as a one-element tensor (it could also be derived from ``batch``).
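To make the expected layout concrete, here is a plain-Python sketch (lists instead of tensors) of how per-event data could be flattened into ``x_pred``, ``x_y``, ``batch`` and ``num_graphs``. The helper ``collate_events`` and the ``Event`` type are hypothetical; in practice a graph data loader (for example PyTorch Geometric's batching, which produces a similar ``batch`` vector) builds these tensors. ``u_y`` is left out because ``update()`` does not use it.

```python
from typing import List, Sequence, Tuple

# Hypothetical per-event input: (node logits, node class indices).
Event = Tuple[Sequence[Sequence[float]], Sequence[int]]


def collate_events(events: Sequence[Event]):
    """Flatten a list of events into node-level lists plus a ``batch`` vector.

    Returns (x_pred, x_y, batch, num_graphs): one logit row and one target
    class per node, with batch[i] giving the event (graph) of node i.
    """
    x_pred: List[List[float]] = []
    x_y: List[int] = []
    batch: List[int] = []
    for graph_idx, (logits, labels) in enumerate(events):
        assert len(logits) == len(labels), "expected one label per node"
        x_pred.extend(list(row) for row in logits)
        x_y.extend(labels)
        # Every node of this event is tagged with the event's index.
        batch.extend([graph_idx] * len(labels))
    return x_pred, x_y, batch, len(events)


events = [
    ([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]], [1, 0]),  # event 0: two nodes
    ([[0.0, 0.1, 3.0]], [2]),                       # event 1: one node
]
x_pred, x_y, batch, num_graphs = collate_events(events)
print(batch, num_graphs)  # [0, 0, 1] 2
```

Node ``i`` belongs to event ``batch[i]``, so the two nodes of the first event share index 0.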

.. seealso::
    `Ignite metrics <https://pytorch.org/ignite/metrics.html>`_

:param ignore_index: Class or list of classes to ignore during the computation (e.g. padding).
:type ignore_index: int or list[int]
:param output_transform: Function to transform engine's output to desired output.
:type output_transform: `function <https://docs.python.org/3/glossary.html#term-function>`_
:param device: Device for the metric's tensors, e.g. ``cpu`` or ``cuda``.
:type device: str

Definition at line 112 of file metrics.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (self, ignore_index, output_transform, device='cpu')
Initialization.

Definition at line 135 of file metrics.py.

def __init__(self, ignore_index, output_transform, device='cpu'):
    """
    Initialization.
    """

    self.ignore_index = ignore_index if isinstance(ignore_index, list) else [ignore_index]

    self.device = device

    self._per_corrects = None

    self._num_examples = None

    super(PerfectMasses, self).__init__(output_transform=output_transform, device=device)
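The constructor accepts either a single class index or a list of indices. The sketch below isolates that normalization rule in a hypothetical helper, ``normalize_ignore_index``, so its edge cases are easy to see.

```python
from typing import List, Union


def normalize_ignore_index(ignore_index: Union[int, List[int]]) -> list:
    """Same rule as in __init__: wrap anything that is not a list."""
    return ignore_index if isinstance(ignore_index, list) else [ignore_index]


print(normalize_ignore_index(0))       # [0]
print(normalize_ignore_index([0, 5]))  # [0, 5]
print(normalize_ignore_index((0, 5)))  # [(0, 5)]  (a tuple is wrapped, not converted)
```

Only ``list`` instances are kept as they are, so a tuple such as ``(0, 5)`` is wrapped as a single element instead of being treated as two classes; passing a list avoids this.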

Member Function Documentation

◆ compute()

def compute (self)
Final computation. Returns the fraction of events with perfectly predicted mass hypotheses.

Definition at line 194 of file metrics.py.

def compute(self):
    """
    Final computation.
    """
    if self._num_examples == 0:
        raise NotComputableError(
            "PerfectMasses must have at least one example before it can be computed."
        )
    return self._per_corrects / self._num_examples
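Because ``compute()`` divides the accumulated counts, the result is weighted by the number of events in each batch rather than being a mean of per-batch rates. A minimal plain-Python sketch of the difference, assuming the per-batch ``(batch_perfect, num_graphs)`` pairs are available as a list; both helper functions are hypothetical, and ``ValueError`` stands in for Ignite's ``NotComputableError``:

```python
from typing import Sequence, Tuple


def accumulated_rate(batches: Sequence[Tuple[int, int]]) -> float:
    """Ratio computed by compute(): total perfect events / total events.

    ``batches`` holds hypothetical (batch_perfect, num_graphs) pairs, i.e. the
    amounts update() adds to _per_corrects and _num_examples per batch.
    """
    per_corrects = sum(perfect for perfect, _ in batches)
    num_examples = sum(n for _, n in batches)
    if num_examples == 0:
        # compute() raises NotComputableError here; ValueError is a stand-in.
        raise ValueError("at least one example is needed")
    return per_corrects / num_examples


def mean_of_batch_rates(batches: Sequence[Tuple[int, int]]) -> float:
    """For comparison only: unweighted average of per-batch rates."""
    return sum(perfect / n for perfect, n in batches) / len(batches)


batches = [(3, 4), (1, 1)]
print(accumulated_rate(batches))     # 0.8   (4 / 5)
print(mean_of_batch_rates(batches))  # 0.875 ((0.75 + 1.0) / 2)
```

The two values agree whenever all batches contain the same number of events; with unequal batch sizes they can differ, as here.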

◆ reset()

def reset (self)
Resets counts.

Definition at line 151 of file metrics.py.

def reset(self):
    """
    Resets counts.
    """
    self._per_corrects = 0
    self._num_examples = 0

    super(PerfectMasses, self).reset()
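``reset()`` matters because ``update()`` only ever adds to the counters. The hypothetical class below mimics that counter lifecycle in plain Python, again with ``ValueError`` standing in for ``NotComputableError``. With Ignite's default epoch-wise usage, an attached metric is reset at the start of every epoch; the loop calls ``reset()`` explicitly to imitate that.

```python
class PerfectMassesCounter:
    """Hypothetical plain-Python stand-in for the counters of PerfectMasses."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Same initial values as reset(): both counters start at zero.
        self._per_corrects = 0
        self._num_examples = 0

    def update(self, batch_perfect: int, num_graphs: int) -> None:
        self._per_corrects += batch_perfect
        self._num_examples += num_graphs

    def compute(self) -> float:
        if self._num_examples == 0:
            raise ValueError("at least one example is needed")
        return self._per_corrects / self._num_examples


counter = PerfectMassesCounter()
for epoch_batches in ([(3, 4), (1, 1)], [(0, 2)]):
    counter.reset()  # start of epoch
    for batch_perfect, num_graphs in epoch_batches:
        counter.update(batch_perfect, num_graphs)
    print(counter.compute())
# 0.8 for the first epoch, 0.0 for the second
```

Without the ``reset()`` call, the second epoch would report 4/7 (about 0.571), mixing in the first epoch's counts.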

◆ update()

def update (self, output)
Updates counts.

Definition at line 161 of file metrics.py.

def update(self, output):
    """
    Updates counts.
    """
    x_pred, x_y, u_y, batch, num_graphs = output

    num_graphs = num_graphs.item()

    probs = torch.softmax(x_pred, dim=1)
    winners = probs.argmax(dim=1)

    assert winners.shape == x_y.shape, 'Mass predictions shape does not match target shape'

    # Create a mask that is 0 for nodes whose target is in ignore_index (e.g. padded entries)
    mask = torch.ones(x_y.size(), dtype=torch.long, device=self.device)
    for ig_class in self.ignore_index:
        mask &= (x_y != ig_class)

    # Zero the respective entries in the predictions and in the targets
    y_pred_mask = winners * mask
    y_mask = x_y * mask

    # (N) compare the masked predictions with the target; ignored entries are equal due to masking
    truth = y_pred_mask.eq(y_mask) + 0  # +0 so it's not bool but 0 and 1
    # Per-graph minimum: 1 only if every node of the graph is correct
    truth = scatter(truth, batch, reduce="min")

    # Count the graphs with no wrong predictions in the batch
    batch_perfect = truth.sum().item()

    self._per_corrects += batch_perfect
    self._num_examples += num_graphs
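The counting step of ``update()`` can be followed on toy data with a plain-Python re-implementation, where lists replace tensors and an explicit per-graph minimum replaces ``scatter(..., reduce="min")``. This is an illustrative sketch, not the Belle II code; ``argmax`` and ``count_perfect_graphs`` are hypothetical names.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    # Softmax is monotonic, so the argmax of the logits equals the argmax of the probabilities.
    return max(range(len(row)), key=lambda i: row[i])


def count_perfect_graphs(
    x_pred: Sequence[Sequence[float]],
    x_y: Sequence[int],
    batch: Sequence[int],
    num_graphs: int,
    ignore_index: List[int],
) -> int:
    winners = [argmax(row) for row in x_pred]
    assert len(winners) == len(x_y), "Mass predictions shape does not match target shape"
    # graph_ok[g] tracks the minimum of the 0/1 "correct" flags of graph g's nodes,
    # like scatter(truth, batch, reduce="min"); it stays 1 until a node is wrong.
    graph_ok = [1] * num_graphs
    for pred, target, graph in zip(winners, x_y, batch):
        # Ignored nodes are masked to 0 on both sides in update(), so they always match.
        correct = 1 if target in ignore_index else int(pred == target)
        graph_ok[graph] = min(graph_ok[graph], correct)
    return sum(graph_ok)


x_pred = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 3.0], [2.0, 0.0, 0.0]]
x_y = [1, 0, 2, 1]
batch = [0, 0, 1, 1]
print(count_perfect_graphs(x_pred, x_y, batch, 2, ignore_index=[]))   # 1
print(count_perfect_graphs(x_pred, x_y, batch, 2, ignore_index=[1]))  # 2
```

In the second call class 1 is ignored, so the wrong prediction on node 3 (predicted 0, target 1) no longer disqualifies event 1.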

Member Data Documentation

◆ _num_examples

_num_examples (protected)

Total number of events processed since the last reset.

Definition at line 146 of file metrics.py.

◆ _per_corrects

_per_corrects (protected)

Number of events with all mass hypotheses predicted correctly since the last reset.

Definition at line 144 of file metrics.py.

◆ device

device

Device used for the metric's tensors (e.g. ``cpu`` or ``cuda``).

Definition at line 142 of file metrics.py.

◆ ignore_index

ignore_index

List of class indices excluded from the computation.

Definition at line 140 of file metrics.py.


The documentation for this class was generated from the following file: