Belle II Software  light-2403-persian
multiTrain.py
1 
8 
9 
10 from torch import nn
11 
12 
class MultiTrainLoss(nn.Module):
    """
    Sum of cross-entropies for training against LCAS and mass hypotheses.

    The total loss is ``LCA_CE(edge_input, edge_target)
    + alpha_mass * mass_CE(x_input, x_target)``. The mass term is only
    evaluated when ``alpha_mass > 0``.

    Args:
        alpha_mass (float): Weight of mass cross-entropy term in the loss.
            Must be non-negative; ``0`` disables the mass term.
        ignore_index (int): Target value to ignore in the computation (e.g. padding).
        reduction (str): Type of reduction to be applied on the batch (``sum`` or ``mean``).

    Raises:
        ValueError: If ``alpha_mass`` is negative (or NaN).
    """

    def __init__(
        self,
        alpha_mass=0,
        ignore_index=-1,
        reduction="mean",
    ):
        """
        Initialization
        """
        super().__init__()

        # Validate explicitly rather than with ``assert``: assertions are
        # stripped under ``python -O``. Checking before the loss modules are
        # built avoids doing work for an invalid configuration. Written as
        # ``not >=`` so that NaN is rejected too, as the original check did.
        if not alpha_mass >= 0:
            raise ValueError(
                f"alpha_mass should be non-negative, got {alpha_mass!r}"
            )

        ## Parameter controlling the importance of mass term in loss.
        self.alpha_mass = alpha_mass

        ## LCA cross-entropy.
        self.LCA_CE = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction
        )

        ## Mass cross-entropy.
        self.mass_CE = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction
        )

    def forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target):
        """
        Called internally by PyTorch to propagate the input.

        Args:
            x_input: Mass-hypothesis predictions, passed as the input of ``mass_CE``
                (unnormalized scores, as ``nn.CrossEntropyLoss`` expects).
            x_target: Mass-hypothesis class targets for ``x_input``.
            edge_input: LCAS predictions, passed as the input of ``LCA_CE``.
            edge_target: LCAS class targets for ``edge_input``.
            u_input: Unused by this loss.
            u_target: Unused by this loss.

        Returns:
            The combined loss tensor (a scalar for ``mean``/``sum`` reduction).
        """
        LCA_loss = self.LCA_CE(edge_input, edge_target)

        # When the mass term carries no weight, skip it entirely so that
        # x_input / x_target are never evaluated in that case.
        if self.alpha_mass > 0:
            mass_loss = self.mass_CE(x_input, x_target)
        else:
            mass_loss = 0

        return LCA_loss + self.alpha_mass * mass_loss
def forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target)
Definition: multiTrain.py:48
LCA_CE
LCA cross-entropy.
Definition: multiTrain.py:38
alpha_mass
Parameter controlling the importance of mass term in loss.
Definition: multiTrain.py:35
def __init__(self, alpha_mass=0, ignore_index=-1, reduction="mean")
Definition: multiTrain.py:28
mass_CE
Mass cross-entropy.
Definition: multiTrain.py:42