Belle II Software development
MultiTrainLoss Class Reference

Public Member Functions

def __init__ (self, alpha_mass=0, ignore_index=-1, reduction="mean")
 
def forward (self, x_input, x_target, edge_input, edge_target, u_input, u_target)
 

Public Attributes

 alpha_mass
 Parameter controlling the importance of the mass term in the loss.
 
 LCA_CE
 LCA cross-entropy.
 
 mass_CE
 Mass cross-entropy.
 

Detailed Description

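Weighted sum of cross-entropies for training against LCAS and mass hypotheses: the LCAS cross-entropy between ``edge_input`` and ``edge_target`` plus ``alpha_mass`` times the mass cross-entropy between ``x_input`` and ``x_target``.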
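Written as a formula, the quantity computed by forward() is sketched below. The symbols are our own notation, not taken from the source: $\alpha$ stands for alpha_mass, $z_{i,c}$ for the logit of entry $i$ and class $c$, $y_i$ for the target class of entry $i$, and CE for torch.nn.CrossEntropyLoss with class-index targets and no class weights, using reduction="mean".

$$
\mathcal{L} = \mathrm{CE}(\texttt{edge\_input}, \texttt{edge\_target}) + \alpha \, \mathrm{CE}(\texttt{x\_input}, \texttt{x\_target}), \qquad \alpha \ge 0
$$

$$
\mathrm{CE}(z, y) = \frac{1}{|S|} \sum_{i \in S} \left( \log \sum_{c} e^{z_{i,c}} - z_{i, y_i} \right), \qquad S = \{\, i : y_i \neq \texttt{ignore\_index} \,\}
$$

With reduction="sum" the $1/|S|$ factor is dropped. When $\alpha = 0$ the mass term is not evaluated at all.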

Args:
    alpha_mass (float): Weight of the mass cross-entropy term in the loss. Must be non-negative; if 0, the mass term is not computed.
    ignore_index (int): Target value that is ignored and does not contribute to the loss (e.g. padding).
    reduction (str): Type of reduction to be applied on the batch (``sum`` or ``mean``).

Definition at line 13 of file multiTrain.py.
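To make the ignore_index and reduction semantics concrete, here is a minimal pure-Python sketch of a single cross-entropy term. It assumes class-index targets and no class weights, mirrors what torch.nn.CrossEntropyLoss computes in that setting, and is not the Belle II implementation; the function name cross_entropy is hypothetical.

```python
import math


def cross_entropy(logits, targets, ignore_index=-1, reduction="mean"):
    """Hypothetical stand-in for nn.CrossEntropyLoss with class-index targets."""
    losses = []
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored entries (e.g. padding) contribute nothing
        # Stable log-sum-exp: subtract the row maximum before exponentiating.
        m = max(row)
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        # Negative log-softmax probability of the target class.
        losses.append(log_norm - row[target])
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        # Average over non-ignored entries only; NaN if everything is ignored.
        return sum(losses) / len(losses) if losses else float("nan")
    raise ValueError(f"unsupported reduction: {reduction}")


# Usage: the second entry is padding and is skipped.
logits = [[2.0, 0.0], [0.0, 0.0], [1.0, 1.0]]
targets = [0, -1, 1]
total = cross_entropy(logits, targets, reduction="sum")
avg = cross_entropy(logits, targets, reduction="mean")
```

Because the padded entry is skipped, the mean divides by 2 rather than 3, which is why ignore_index matters for variable-size inputs.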

Constructor & Destructor Documentation

◆ __init__()

def __init__ (self, alpha_mass = 0, ignore_index = -1, reduction = "mean")
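Initialization: stores alpha_mass, creates the LCAS and mass cross-entropy losses, and checks that alpha_mass is non-negative.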

Definition at line 23 of file multiTrain.py.

def __init__(
    self,
    alpha_mass=0,
    ignore_index=-1,
    reduction="mean",
):
    """
    Initialization
    """
    super().__init__()

    self.alpha_mass = alpha_mass

    self.LCA_CE = nn.CrossEntropyLoss(
        ignore_index=ignore_index, reduction=reduction
    )

    self.mass_CE = nn.CrossEntropyLoss(
        ignore_index=ignore_index, reduction=reduction
    )

    assert alpha_mass >= 0, "Alpha should be non-negative"

Member Function Documentation

◆ forward()

def forward (self, x_input, x_target, edge_input, edge_target, u_input, u_target)
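Computes the loss; called internally by PyTorch when the module is invoked. The mass term is evaluated only if alpha_mass > 0, and u_input and u_target are not used.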

Definition at line 48 of file multiTrain.py.

def forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target):
    """
    Called internally by PyTorch to propagate the input.
    """

    LCA_loss = self.LCA_CE(
        edge_input,
        edge_target,
    )

    mass_loss = (
        self.mass_CE(
            x_input,
            x_target,
        )
        if self.alpha_mass > 0
        else 0
    )

    return LCA_loss + self.alpha_mass * mass_loss
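The control flow of forward() can be mirrored in a short framework-free sketch. The function multi_train_loss and its callable parameters are hypothetical stand-ins for the module and its two nn.CrossEntropyLoss members, chosen so the example runs without PyTorch; unlike the class, which asserts in __init__, this sketch validates alpha_mass at call time.

```python
from typing import Any, Callable, Optional

# Hypothetical type for a loss callable: (prediction, target) -> scalar.
LossFn = Callable[[Any, Any], float]


def multi_train_loss(
    lca_ce: LossFn,
    mass_ce: LossFn,
    alpha_mass: float,
    x_input: Any,
    x_target: Any,
    edge_input: Any,
    edge_target: Any,
    u_input: Optional[Any] = None,
    u_target: Optional[Any] = None,
) -> float:
    """Sketch of MultiTrainLoss.forward() with the losses passed in explicitly."""
    if alpha_mass < 0:
        raise ValueError("alpha_mass must be non-negative")
    # The LCAS (edge) term is always evaluated.
    lca_loss = lca_ce(edge_input, edge_target)
    # The mass term is only evaluated when it carries a non-zero weight.
    mass_loss = mass_ce(x_input, x_target) if alpha_mass > 0 else 0
    # u_input and u_target are accepted but not used, as in the listing.
    return lca_loss + alpha_mass * mass_loss


# Usage with stub losses that record whether they were called.
calls = []


def stub_lca(pred, target):
    calls.append("lca")
    return 1.5


def stub_mass(pred, target):
    calls.append("mass")
    return 2.0


only_lca = multi_train_loss(stub_lca, stub_mass, 0.0, None, None, None, None)
weighted = multi_train_loss(stub_lca, stub_mass, 0.5, None, None, None, None)
```

Passing the two losses in as arguments keeps the sketch testable in isolation; in the class itself they are the LCA_CE and mass_CE attributes created in __init__.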

Member Data Documentation

◆ alpha_mass

alpha_mass

Parameter controlling the importance of the mass term in the loss.

Definition at line 35 of file multiTrain.py.

◆ LCA_CE

LCA_CE

LCA cross-entropy.

Definition at line 38 of file multiTrain.py.

◆ mass_CE

mass_CE

Mass cross-entropy.

Definition at line 42 of file multiTrain.py.

