Belle II Software development
multiTrain.py
1
8
9
10from torch import nn
11
12
class MultiTrainLoss(nn.Module):
    """
    Sum of cross-entropies for training against LCAS and mass hypotheses.

    The total loss is ``LCA_CE(edge_input, edge_target)
    + alpha_mass * mass_CE(x_input, x_target)``. The mass term is only
    evaluated when ``alpha_mass > 0``.

    Args:
        alpha_mass (float): Weight of mass cross-entropy term in the loss.
            Must be non-negative; ``0`` disables the mass term.
        ignore_index (int): Index to ignore in the computation (e.g. padding).
        reduction (str): Type of reduction to be applied on the batch (``sum`` or ``mean``).

    Raises:
        ValueError: If ``alpha_mass`` is negative (or NaN).
    """

    def __init__(
        self,
        alpha_mass=0,
        ignore_index=-1,
        reduction="mean",
    ):
        """
        Initialization
        """
        super().__init__()

        # Validate with an explicit exception rather than ``assert`` so the
        # check survives ``python -O``. It is done before building the
        # sub-modules so a bad value fails fast. ``not x >= 0`` (instead of
        # ``x < 0``) also rejects NaN, just as the original assertion did.
        if not alpha_mass >= 0:
            raise ValueError(
                f"alpha_mass should be non-negative, got {alpha_mass!r}"
            )

        #: Parameter controlling the importance of the mass term in the loss
        self.alpha_mass = alpha_mass

        #: LCA cross-entropy
        self.LCA_CE = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction
        )
        #: Mass cross-entropy
        self.mass_CE = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction
        )

    def forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target):
        """
        Called internally by PyTorch to propagate the input.

        Args:
            x_input: Logits fed to the mass cross-entropy (presumably
                per-node, given the ``x`` naming -- confirm with caller).
            x_target: Class targets for the mass cross-entropy.
            edge_input: Logits fed to the LCA cross-entropy.
            edge_target: Class targets for the LCA cross-entropy.
            u_input: Unused here; presumably accepted so every loss shares
                the same call signature.
            u_target: Unused here.

        Returns:
            ``LCA loss + alpha_mass * mass loss``. When ``alpha_mass`` is
            ``0`` the mass cross-entropy is not evaluated and contributes ``0``.
        """
        LCA_loss = self.LCA_CE(edge_input, edge_target)

        # Skip the mass cross-entropy entirely when it has no weight. This
        # saves work and avoids ``0 * nan`` if that term is undefined
        # (e.g. every mass target equals ``ignore_index`` with mean reduction).
        mass_loss = self.mass_CE(x_input, x_target) if self.alpha_mass > 0 else 0

        return LCA_loss + self.alpha_mass * mass_loss
def forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target)
Definition: multiTrain.py:48
LCA_CE
LCA cross-entropy.
Definition: multiTrain.py:38
alpha_mass
Parameter controlling the importance of the mass term in the loss.
Definition: multiTrain.py:35
def __init__(self, alpha_mass=0, ignore_index=-1, reduction="mean")
Definition: multiTrain.py:28
mass_CE
Mass cross-entropy.
Definition: multiTrain.py:42