Belle II Software  light-2403-persian
MultiTrainLoss Class Reference

Public Member Functions

def __init__ (self, alpha_mass=0, ignore_index=-1, reduction="mean")
 
def forward (self, x_input, x_target, edge_input, edge_target, u_input, u_target)
 

Public Attributes

 alpha_mass
 Parameter controlling the importance of the mass term in the loss.
 
 LCA_CE
 LCA cross-entropy.
 
 mass_CE
 Mass cross-entropy.
 

Detailed Description

Weighted sum of cross-entropies for training against the LCAS and mass hypotheses: the LCAS cross-entropy plus the mass cross-entropy scaled by alpha_mass.

Args:
    alpha_mass (float): Weight of the mass cross-entropy term in the loss; must be non-negative.
    ignore_index (int): Target index to ignore in both cross-entropies (e.g. padding).
    reduction (str): Type of reduction to apply over the batch (``sum`` or ``mean``).

Definition at line 13 of file multiTrain.py.
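Read together, the arguments suggest the following form of the loss. This is a sketch that assumes the LCAS term enters with unit weight (only the mass term has a weight parameter) and that both terms follow the standard unweighted `nn.CrossEntropyLoss` definition, with logits $z$, class-index targets $y$ and ignored entries excluded:

$$
\mathcal{L} = \mathrm{CE}_{\mathrm{LCA}} + \alpha_{\mathrm{mass}}\,\mathrm{CE}_{\mathrm{mass}}, \qquad \alpha_{\mathrm{mass}} \ge 0
$$

$$
\mathrm{CE}(z, y) = \frac{1}{|V|} \sum_{i \in V} \left( -\log \frac{\exp z_{i, y_i}}{\sum_{c} \exp z_{i, c}} \right), \qquad V = \{\, i : y_i \neq \texttt{ignore\_index} \,\}
$$

With ``reduction="sum"`` the factor $1/|V|$ is dropped.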

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  alpha_mass = 0,
  ignore_index = -1,
  reduction = "mean" 
)
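Initialization: stores alpha_mass and creates the LCA and mass cross-entropy losses, both configured with the given ignore_index and reduction.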

Definition at line 23 of file multiTrain.py.

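import torch.nn as nn


class MultiTrainLoss(nn.Module):
    def __init__(
        self,
        alpha_mass=0,
        ignore_index=-1,
        reduction="mean",
    ):
        """
        Initialization
        """
        super().__init__()

        # Parameter controlling the importance of the mass term in the loss
        self.alpha_mass = alpha_mass

        # LCA cross-entropy
        self.LCA_CE = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction
        )
        # Mass cross-entropy
        self.mass_CE = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction
        )

        # Zero (the default) is allowed, so the check is for non-negativity
        assert alpha_mass >= 0, "alpha_mass should be non-negative"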
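Both loss modules are plain `nn.CrossEntropyLoss` instances, so `ignore_index` and `reduction` carry their standard PyTorch meaning. The sketch below mirrors that behaviour in plain Python for unweighted class-index targets without label smoothing; `cross_entropy` is a hypothetical helper written for illustration, not part of the Belle II code.

```python
import math


def cross_entropy(logits, targets, ignore_index=-1, reduction="mean"):
    """Plain-Python mirror of nn.CrossEntropyLoss for class-index targets (hypothetical helper)."""
    losses = []
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # padded entries contribute nothing to the loss
        # Numerically stable log-sum-exp, then negative log-softmax of the target class
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        losses.append(log_sum_exp - row[target])
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        # Averaged over non-ignored entries only; NaN if everything is ignored (as in PyTorch)
        return sum(losses) / len(losses) if losses else float("nan")
    # reduction="none" (per-element losses) is not covered by this sketch
    raise ValueError(f"unsupported reduction: {reduction!r}")


logits = [[2.0, 0.0], [0.0, 0.0], [5.0, -5.0]]
targets = [0, 1, -1]  # the last entry is padding
mean_loss = cross_entropy(logits, targets)  # averaged over the 2 non-ignored entries
sum_loss = cross_entropy(logits, targets, reduction="sum")
```

With ``reduction="mean"`` the divisor is the number of non-ignored entries, so padding targets with -1 does not dilute the loss.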

Member Function Documentation

◆ forward()

def forward (   self,
  x_input,
  x_target,
  edge_input,
  edge_target,
  u_input,
  u_target 
)
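Called internally by PyTorch when the loss module is invoked; returns the loss computed from the predictions (the *_input arguments) and their targets (the *_target arguments).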

Definition at line 48 of file multiTrain.py.
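The body of forward is not reproduced on this page. Going by the class description and the argument names, one plausible reading is that the LCAS cross-entropy is computed on the edge predictions, the mass cross-entropy on the node (x_*) predictions, and the two are combined as LCAS + alpha_mass * mass. The sketch below implements that reading under those assumptions only: `MultiTrainLossSketch` is a hypothetical name, the u_* arguments are accepted but unused, and the tensor shapes are invented for illustration.

```python
import torch
import torch.nn as nn


class MultiTrainLossSketch(nn.Module):
    """Hypothetical re-creation of MultiTrainLoss, for illustration only."""

    def __init__(self, alpha_mass=0.0, ignore_index=-1, reduction="mean"):
        super().__init__()
        assert alpha_mass >= 0, "alpha_mass should be non-negative"
        self.alpha_mass = alpha_mass
        self.LCA_CE = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
        self.mass_CE = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)

    def forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target):
        # Assumption: LCAS classes are predicted per edge, mass hypotheses per node (x);
        # the global u_* arguments are accepted but not used in this sketch.
        lca_loss = self.LCA_CE(edge_input, edge_target)
        mass_loss = self.mass_CE(x_input, x_target)
        return lca_loss + self.alpha_mass * mass_loss


# Made-up shapes: 5 edges with 4 LCAS classes, 3 nodes with 2 mass classes
torch.manual_seed(0)
edge_logits = torch.randn(5, 4)
edge_labels = torch.tensor([0, 3, -1, 2, 1])  # -1 marks a padded edge
node_logits = torch.randn(3, 2)
node_labels = torch.tensor([1, 0, 1])

loss_fn = MultiTrainLossSketch(alpha_mass=0.5)
# Calling the module runs forward(); the result is a scalar tensor
loss = loss_fn(node_logits, node_labels, edge_logits, edge_labels, None, None)
```

Calling the module instance rather than forward directly lets PyTorch run any registered hooks, which is why forward is described as being called internally.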


The documentation for this class was generated from the following file: