Belle II Software  release-08-01-10
torch_tcce.py
1 #!/usr/bin/env python3
2 
3 
10 
11 import torch
12 import torch.nn as nn
13 import numpy as np
14 
15 
def tcc(y: torch.tensor, y_true: torch.LongTensor, n: int) -> torch.tensor:
    """
    Calculates loss using the required number of Taylor terms of cross entropy loss.

    For every sample, p is the probability that y assigns to the true class.
    The per-sample loss is sum_{i=1}^{n} (1 - p)^i / i, the n-term Taylor
    expansion of -log(p) around p = 1. The per-sample losses are summed and
    divided by the number of samples.

    Parameters:
        y(torch.tensor): The probabilities predicted by ML model, indexed as
            y[sample, class]. Values are used directly in log / powers, so these
            are expected to be probabilities, not logits.
        y_true(torch.LongTensor): The truth values (class indices) provided for
            training purposes (1D tensor).
        n(int): Number of terms to be taken for the Taylor series (must be >= 0).

    Returns:
        A scalar torch tensor with the value of the calculated Taylor cross
        entropy loss, on the same device and with the same dtype as y.

    Raises:
        ValueError: If n is negative.

    Note:
        With n = 0, this returns the regular cross entropy loss.

    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    # Row indices live on y_true's device so both advanced-index tensors agree.
    rows = torch.arange(len(y_true), device=y_true.device)
    prob_true = y[rows, y_true]

    if n == 0:
        per_sample = -torch.log(prob_true)
    else:
        complement = 1 - prob_true
        # Accumulate on prob_true's device/dtype (i.e. y's), instead of a fixed
        # float32 buffer moved to "cuda" whenever CUDA exists, which broke CPU
        # inputs on CUDA machines and downcast float64 inputs.
        per_sample = torch.zeros_like(prob_true)
        for i in range(1, n + 1):
            per_sample = per_sample + torch.pow(complement, i) / i

    return torch.sum(per_sample) / len(y)
43 
44 
class TCCE(nn.Module):
    """
    Class for calculation of Taylor cross entropy loss.

    The forward pass delegates to ``tcc``; with ``n = 0`` that is the regular
    cross entropy loss.

    Attributes:
        nn (int): Number of Taylor series terms used for loss calculation
            (passed to ``tcc`` as its ``n`` argument).

    """

    def __init__(self, n: int = 0):
        """
        Initialize the loss class.

        Parameters:
            n (int)(optional): Number of Taylor series terms to be used for loss
                calculation. Defaults to 0 (regular cross entropy). Stored as
                ``self.nn``.

        """
        super().__init__()

        self.nn = n

    def extra_repr(self) -> str:
        """
        Return the extra representation shown by ``repr()``, e.g. ``TCCE(n=2)``.
        """
        return f"n={self.nn}"

    def forward(self, y: torch.tensor, y_true: torch.LongTensor) -> torch.tensor:
        """
        Calculates the Taylor categorical cross entropy loss.

        Parameters:
            y(torch.tensor): Tensor containing the output of the model, indexed as
                y[sample, class]; expected to be probabilities (not logits).
            y_true(torch.LongTensor): 1D tensor containing the true class index
                for each sample.

        Returns:
            The calculated loss as a torch tensor.
        """
        return tcc(y, y_true, self.nn)
def __init__(self, int n=0)
Definition: torch_tcce.py:54
torch.tensor forward(self, torch.tensor y, torch.LongTensor y_true)
Definition: torch_tcce.py:66
n
Number of Taylor terms.
Definition: torch_tcce.py:64