Belle II Software development
torch_tcce.py
1#!/usr/bin/env python3
2
3
10
11import torch
12import torch.nn as nn
13import numpy as np
14
15
def tcc(y: torch.Tensor, y_true: torch.LongTensor, n: int) -> torch.Tensor:
    """
    Calculates loss using the required number of Taylor terms of cross entropy loss.

    For the predicted probability p of the true class, cross entropy is -log(p),
    whose Taylor expansion around p = 1 is sum_{i=1}^{inf} (1 - p)**i / i.
    This function truncates that series after ``n`` terms, or evaluates the
    exact -log(p) when ``n == 0``.

    Parameters:
        y(torch.Tensor): The probabilities predicted by ML model, indexed as
            y[sample, class].
        y_true(torch.LongTensor): The truth values provided for training purposes (1D tensor).
        n(int): Number of terms to be taken for the Taylor series (must be >= 0).

    Returns:
        A torch tensor with the value of the calculated Taylor cross entropy loss,
        summed over the samples and divided by len(y). It lives on the same
        device as ``y``.

    Raises:
        ValueError: If ``n`` is negative.

    Note:
        With n = 0, this returns the regular cross entropy loss.
    """
    # A negative term count would make the series empty and silently yield 0.
    if n < 0:
        raise ValueError(f"Number of Taylor terms must be non-negative, got {n}")

    # Probability assigned to the true class of each sample. The row index is
    # built on y_true's device so both advanced-index tensors agree.
    rows = torch.arange(len(y_true), device=y_true.device)
    prob_true = y[rows, y_true]

    if n == 0:
        loss = -torch.log(prob_true)
    else:
        # Accumulate on prob_true's own device/dtype rather than forcing CUDA
        # whenever it is available (which broke CPU inputs on GPU machines and
        # downcast float64 inputs to float32).
        loss = torch.zeros_like(prob_true)
        for i in range(1, n + 1):
            loss = loss + torch.pow(1 - prob_true, i) / i

    return torch.sum(loss) / len(y)
43
44
class TCCE(nn.Module):
    """
    Class for calculation of Taylor cross entropy loss.

    Wraps the ``tcc`` function as a ``torch.nn.Module`` so it can be used like
    any other PyTorch loss module.

    Attributes:
        n (int): Number of Taylor series terms to be used for loss calculation
            (0 selects the exact cross entropy, -log p).
    """

    def __init__(self, n: int = 0):
        """
        Initialize the loss class.

        Parameters:
            n (int)(optional): Number of Taylor series terms to be used for loss
                calculation. Must be non-negative; defaults to 0 (regular cross
                entropy).

        Raises:
            ValueError: If ``n`` is negative.
        """
        super().__init__()
        # Fail at construction rather than silently producing a zero loss later
        # (the Taylor sum over range(1, n + 1) is empty for n < 0).
        if n < 0:
            raise ValueError(f"Number of Taylor terms must be non-negative, got {n}")
        self.n = n

    def extra_repr(self) -> str:
        """Show the number of Taylor terms when the module is printed."""
        return f"n={self.n}"

    def forward(self, y: torch.Tensor, y_true: torch.LongTensor) -> torch.Tensor:
        """
        Calculates the Taylor categorical cross entropy loss.

        Parameters:
            y(torch.Tensor): Tensor containing the output of the model
                (per-class probabilities).
            y_true(torch.LongTensor): 1D tensor containing the truth value for a
                given set of features.

        Returns:
            The calculated loss as a torch tensor.
        """
        return tcc(y, y_true, self.n)
def __init__(self, int n=0)
Definition: torch_tcce.py:54
torch.tensor forward(self, torch.tensor y, torch.LongTensor y_true)
Definition: torch_tcce.py:66
n
Number of Taylor terms.
Definition: torch_tcce.py:64