15 Sum of cross-entropies for training against LCAS and mass hypotheses.
18 alpha_mass (float): Weight of mass cross-entropy term in the loss.
19 ignore_index (int): Index to ignore in the computation (e.g. padding).
20 reduction (str): Type of reduction to be applied on the batch (``sum`` or ``mean``).
38 self.LCA_CE = nn.CrossEntropyLoss(
39 ignore_index=ignore_index, reduction=reduction
42 self.mass_CE = nn.CrossEntropyLoss(
43 ignore_index=ignore_index, reduction=reduction
46 assert alpha_mass >= 0, "Alpha should be positive"
48 def forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target):
50 Called internally by PyTorch to propagate the input.
53 LCA_loss = self.LCA_CE(
67 return LCA_loss + self.alpha_mass * mass_loss
def forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target)
alpha_mass
Parameter controlling the importance of mass term in loss.
def __init__(self, alpha_mass=0, ignore_index=-1, reduction="mean")
mass_CE
Mass cross-entropy.