Belle II Software
development
multiTrain.py
1
8
9
10
from
torch
import
nn
11
12
13
class MultiTrainLoss(nn.Module):
    """
    Sum of cross-entropies for training against LCAS and mass hypotheses.

    The returned loss is
    ``LCA_CE(edge_input, edge_target) + alpha_mass * mass_CE(x_input, x_target)``.
    The mass cross-entropy is only evaluated when ``alpha_mass > 0``.

    Args:
        alpha_mass (float): Weight of mass cross-entropy term in the loss. Must be non-negative.
        ignore_index (int): Index to ignore in the computation (e.g. padding).
        reduction (str): Type of reduction to be applied on the batch (``sum`` or ``mean``).

    Raises:
        ValueError: If ``alpha_mass`` is negative (or NaN).
    """

    def __init__(
        self,
        alpha_mass=0,
        ignore_index=-1,
        reduction="mean",
    ):
        """
        Initialization
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept a negative weight.
        # Written as ``not (x >= 0)`` so that NaN is rejected too, as before.
        if not (alpha_mass >= 0):
            raise ValueError(f"alpha_mass must be non-negative, got {alpha_mass}")

        super().__init__()

        #: Parameter controlling the importance of mass term in loss
        self.alpha_mass = alpha_mass

        #: LCA cross-entropy
        self.LCA_CE = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction
        )
        #: Mass cross-entropy
        self.mass_CE = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction
        )

    def forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target):
        """
        Called internally by PyTorch to propagate the input.

        Args:
            x_input: Logits for the mass cross-entropy, in the layout expected by
                ``nn.CrossEntropyLoss``. Only used when ``alpha_mass > 0``.
            x_target: Class-index targets matching ``x_input``.
            edge_input: Logits for the LCA cross-entropy.
            edge_target: Class-index targets matching ``edge_input``.
            u_input: Unused; accepted to keep the call signature.
            u_target: Unused; accepted to keep the call signature.

        Returns:
            The LCA loss plus ``alpha_mass`` times the mass loss.
        """
        LCA_loss = self.LCA_CE(edge_input, edge_target)

        # Skip the mass term entirely when it carries no weight, so that
        # x_input/x_target need not be valid in LCA-only training.
        mass_loss = (
            self.mass_CE(x_input, x_target)
            if self.alpha_mass > 0
            else 0
        )

        return LCA_loss + self.alpha_mass * mass_loss
multiTrain.MultiTrainLoss
Definition
multiTrain.py:13
multiTrain.MultiTrainLoss.__init__
__init__(self, alpha_mass=0, ignore_index=-1, reduction="mean")
Definition
multiTrain.py:28
multiTrain.MultiTrainLoss.forward
forward(self, x_input, x_target, edge_input, edge_target, u_input, u_target)
Definition
multiTrain.py:48
multiTrain.MultiTrainLoss.LCA_CE
LCA_CE
LCA cross-entropy.
Definition
multiTrain.py:38
multiTrain.MultiTrainLoss.alpha_mass
alpha_mass
Parameter controlling the importance of mass term in loss.
Definition
multiTrain.py:35
multiTrain.MultiTrainLoss.mass_CE
mass_CE
Mass cross-entropy.
Definition
multiTrain.py:42
analysis
scripts
grafei
model
multiTrain.py
Generated on Mon Sep 1 2025 02:45:57 for Belle II Software by
1.13.2