Sum of cross-entropies for training against LCAS and mass hypotheses.
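Reading this description together with the `alpha_mass` parameter gives the total loss below. It is an interpretation of the text, not a quoted formula. It assumes the LCAS term has unit weight. The symbols $\hat{y}$ (predicted scores) and $y$ (targets) are notation introduced here. With the default `alpha_mass=0`, only the LCAS term remains.

```latex
\mathcal{L} = \mathrm{CE}\left(\hat{y}_{\mathrm{LCAS}},\, y_{\mathrm{LCAS}}\right) + \alpha_{\mathrm{mass}}\,\mathrm{CE}\left(\hat{y}_{\mathrm{mass}},\, y_{\mathrm{mass}}\right), \qquad \alpha_{\mathrm{mass}} \ge 0
```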
alpha_mass (float): Weight of the mass cross-entropy term in the loss.
ignore_index (int): Target index to ignore in the computation (e.g. padding).
reduction (str): Type of reduction to apply over the batch (``sum`` or ``mean``).
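Each term is a standard cross-entropy controlled by `ignore_index` and `reduction`. Below is a minimal pure-Python sketch of those semantics. It assumes raw (unnormalized) scores per entry, as PyTorch's `nn.CrossEntropyLoss` expects. The function name `cross_entropy` is hypothetical and is not part of this module.

```python
import math
from typing import List, Sequence


def cross_entropy(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    ignore_index: int = -1,
    reduction: str = "mean",
) -> float:
    """Hypothetical helper: cross-entropy over raw scores, skipping ignored targets."""
    if reduction not in ("mean", "sum"):
        raise ValueError(f"unsupported reduction: {reduction!r}")
    losses: List[float] = []
    for scores, target in zip(logits, targets):
        if target == ignore_index:
            continue  # e.g. padding: this entry contributes nothing
        # Numerically stable log-sum-exp, giving -log softmax(scores)[target]
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        losses.append(log_norm - scores[target])
    total = sum(losses)
    if reduction == "sum":
        return total
    # "mean" averages over the non-ignored entries only
    return total / len(losses) if losses else float("nan")
```

With ``mean``, the average is taken only over non-ignored entries, so padding does not dilute the loss. If every entry is ignored, the sketch returns NaN, because a mean over an empty set is undefined.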
Both cross-entropy terms are constructed with the same options:
ignore_index=ignore_index, reduction=reduction
assert alpha_mass >= 0, "Alpha should be non-negative"
def forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target)
Computes the loss from the inputs and targets; called internally by PyTorch when the module instance is called.
alpha_mass
Parameter controlling the importance of the mass term in the loss.
def __init__(self, alpha_mass=0, ignore_index=-1, reduction="mean")
mass_CE
Mass cross-entropy.
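The body of `forward` is not shown here. The sketch below shows one way the two terms could be combined. It assumes the loss is the LCAS cross-entropy plus `alpha_mass` times the mass cross-entropy. The helper `combined_loss` and its callable arguments are hypothetical. Passing each term as a zero-argument callable lets the mass term be skipped entirely when `alpha_mass` is 0.

```python
from typing import Callable


def combined_loss(
    lcas_term: Callable[[], float],
    mass_term: Callable[[], float],
    alpha_mass: float = 0.0,
) -> float:
    """Hypothetical helper (not this module's API): LCAS CE + alpha_mass * mass CE."""
    # Same constraint as the assert in the constructor: zero is allowed
    if alpha_mass < 0:
        raise ValueError("alpha_mass must be non-negative")
    loss = lcas_term()
    # With alpha_mass == 0 the mass term cannot contribute,
    # so it is never evaluated
    if alpha_mass == 0:
        return loss
    return loss + alpha_mass * mass_term()
```

For example, `combined_loss(lambda: 1.0, lambda: 2.0, alpha_mass=0.5)` gives `2.0`. Under the default `alpha_mass=0`, the mass callable is never invoked.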