Belle II Software light-2406-ragdoll
WeightNet Class Reference

Public Member Functions

def __init__ (self, n_class=6, n_detector=6, const_init=1)
 
def forward (self, x)
 
def get_weights (self, to_numpy=True)
 
def const_init (self, const)
 
def random_init (self, mean=1.0, std=0.5)
 
def kill_unused (self, only)
 

Public Attributes

 n_class
 number of particle types
 
 n_detector
 number of detectors
 
 fcs
 linear layers for each particle type
 

Detailed Description

PyTorch architecture for training calibration weights.

Definition at line 62 of file pidTrainWeights.py.
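This page documents the network but not how it is trained. As a rough illustration only, the sketch below makes several assumptions, none of which come from pidTrainWeights.py. It uses a hypothetical `WeightNetSketch` that reproduces the documented forward pass, synthetic log-likelihoods in which the block of the true particle type is shifted upward, cross-entropy on the returned ratios as the loss, and plain SGD.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_class, n_detector, N = 6, 6, 256


class WeightNetSketch(nn.Module):
    """Hypothetical re-implementation of the documented forward pass."""

    def __init__(self):
        super().__init__()
        # one bias-free layer per particle type, all weights starting at 1
        self.fcs = nn.ModuleList(
            [nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)]
        )
        with torch.no_grad():
            for fc in self.fcs:
                fc.weight.fill_(1.0)

    def forward(self, x):
        n = n_detector
        outs = [self.fcs[i](x[:, i * n:(i + 1) * n]) for i in range(n_class)]
        return F.softmax(torch.cat(outs, dim=1), dim=1)


# synthetic data: the true particle type's block gets higher log-likelihoods
labels = torch.randint(0, n_class, (N,))
x = torch.randn(N, n_class * n_detector)
for i in range(n_class):
    x[labels == i, i * n_detector:(i + 1) * n_detector] += 0.5

net = WeightNetSketch()
opt = torch.optim.SGD(net.parameters(), lr=0.05)

losses = []
for _ in range(100):
    opt.zero_grad()
    probs = net(x)
    # assumed loss: cross-entropy on the likelihood ratios
    loss = F.nll_loss(torch.log(probs), labels)
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

Each output is linear in the weights before the softmax, so this particular loss is convex in the weights and a small enough step size lowers it steadily.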

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  n_class = 6,
  n_detector = 6,
  const_init = 1 
)
Initialize the network for training.

Args:
    n_class (int, optional): Number of classification classes (particle
        types). Defaults to 6.
    n_detector (int, optional): Number of detectors. Defaults to 6.
    const_init (float or None, optional): Constant value to initialize all
        weights. If None, PyTorch's default random initialization is
        used instead. Defaults to 1.

Definition at line 65 of file pidTrainWeights.py.

def __init__(self, n_class=6, n_detector=6, const_init=1):
    """Initialize the network for training.

    Args:
        n_class (int, optional): Number of classification classes (particle
            types). Defaults to 6.
        n_detector (int, optional): Number of detectors. Defaults to 6.
        const_init (float or None, optional): Constant value to initialize all
            weights. If None, PyTorch's default random initialization is
            used instead. Defaults to 1.
    """
    super().__init__()

    # number of particle types
    self.n_class = n_class

    # number of detectors
    self.n_detector = n_detector

    # linear layers for each particle type
    self.fcs = nn.ModuleList(
        [nn.Linear(self.n_detector, 1, bias=False) for _ in range(self.n_class)]
    )

    if const_init is not None:
        self.const_init(const_init)
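To make the parameter layout concrete, here is a minimal sketch that rebuilds `__init__` and `const_init` with plain `torch.nn`, so it runs without the Belle II software. `WeightNetSketch` is a hypothetical stand-in, not the class documented here. With the defaults it holds six bias-free `nn.Linear(6, 1)` layers, 36 trainable weights in total, all initialized to 1.

```python
import torch
import torch.nn as nn


class WeightNetSketch(nn.Module):
    """Hypothetical stand-in mirroring WeightNet.__init__ and const_init."""

    def __init__(self, n_class=6, n_detector=6, const_init=1):
        super().__init__()
        self.n_class = n_class
        self.n_detector = n_detector
        # one bias-free linear layer per particle type: n_detector inputs -> 1 output
        self.fcs = nn.ModuleList(
            [nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)]
        )
        if const_init is not None:
            self.const_init(const_init)

    def const_init(self, const):
        # in-place fill of parameters has to bypass autograd
        with torch.no_grad():
            for fc in self.fcs:
                fc.weight.fill_(const)


net = WeightNetSketch()
n_params = sum(p.numel() for p in net.parameters())
print(n_params)                 # 36
print(net.fcs[0].weight.shape)  # torch.Size([1, 6])
```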

Member Function Documentation

◆ const_init()

def const_init (   self,
  const 
)
Fill all the weights with the given value.

Args:
    const (float): Constant value to fill all weights with.

Definition at line 125 of file pidTrainWeights.py.

def const_init(self, const):
    """Fill all the weights with the given value.

    Args:
        const (float): Constant value to fill all weights with.
    """
    with torch.no_grad():
        for fc in self.fcs:
            fc.weight.fill_(const)
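`const_init` wraps the fill in `torch.no_grad()` because autograd refuses in-place writes to a leaf tensor that requires gradients. The short sketch below uses a standalone layer rather than WeightNet itself; it shows the failure outside `no_grad` and that the parameter stays trainable after the guarded fill.

```python
import torch
import torch.nn as nn

fc = nn.Linear(6, 1, bias=False)

# Outside no_grad, autograd forbids in-place writes to a leaf that requires grad.
try:
    fc.weight.fill_(2.0)
    raised = False
except RuntimeError:
    raised = True

# Inside no_grad the same in-place fill is allowed.
with torch.no_grad():
    fc.weight.fill_(2.0)

print(raised)                   # True
print(fc.weight.requires_grad)  # True
print(fc.weight.tolist())       # [[2.0, 2.0, 2.0, 2.0, 2.0, 2.0]]
```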

◆ forward()

def forward (   self,
  x 
)
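The network's forward method. Computes, for each particle type, a weighted
sum of that type's detector log-likelihoods, then applies a softmax across
particle types to obtain the likelihood ratios.

Args:
    x (torch.Tensor): Input detector log-likelihood data. Should be of
        shape (N, n_class * n_detector), where N is the number of samples;
        columns i * n_detector to (i + 1) * n_detector - 1 hold the
        log-likelihoods for particle type i.

Returns:
    torch.Tensor: Weighted likelihood ratios, of shape (N, n_class).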

Definition at line 92 of file pidTrainWeights.py.

def forward(self, x):
    """The network's forward method. Computes, for each particle type, a
    weighted sum of that type's detector log-likelihoods, then applies a
    softmax across particle types to obtain the likelihood ratios.

    Args:
        x (torch.Tensor): Input detector log-likelihood data. Should be of
            shape (N, n_class * n_detector), where N is the number of
            samples; columns i * n_detector to (i + 1) * n_detector - 1
            hold the log-likelihoods for particle type i.

    Returns:
        torch.Tensor: Weighted likelihood ratios, of shape (N, n_class).
    """
    n = self.n_detector
    # slice out the n_detector columns of particle type i and apply its layer
    outs = [self.fcs[i](x[:, i * n: (i + 1) * n]) for i in range(self.n_class)]
    out = torch.cat(outs, dim=1)
    return F.softmax(out, dim=1)
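Summing the detector log-likelihoods of particle type i gives the log of its combined likelihood L_i (the product of the per-detector likelihoods), and the softmax then returns L_i / sum_j L_j, the likelihood ratio. The sketch below checks this numerically for unit weights, using a standalone copy of the slicing logic and made-up log-likelihood values; the column layout is read off the slicing in the listing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_class, n_detector, N = 6, 6, 4

fcs = nn.ModuleList([nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)])
with torch.no_grad():
    for fc in fcs:
        fc.weight.fill_(1.0)


def forward(x):
    # columns [i*n, (i+1)*n) hold the n_detector log-likelihoods of particle type i
    n = n_detector
    outs = [fcs[i](x[:, i * n:(i + 1) * n]) for i in range(n_class)]
    return F.softmax(torch.cat(outs, dim=1), dim=1)


# made-up detector log-likelihoods, shape (N, n_class * n_detector)
x = -torch.rand(N, n_class * n_detector) * 2

with torch.no_grad():
    probs = forward(x)

# direct likelihood ratio: L_i = exp(sum_d logL_{i,d}), ratio_i = L_i / sum_j L_j
log_L = x.view(N, n_class, n_detector).sum(dim=2)
ratio = torch.exp(log_L) / torch.exp(log_L).sum(dim=1, keepdim=True)

print(probs.shape)  # torch.Size([4, 6])
print(torch.allclose(probs, ratio, rtol=1e-4, atol=1e-6))  # True
```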

◆ get_weights()

def get_weights (   self,
  to_numpy = True 
)
Returns the weights as an (n_class, n_detector) array or tensor
(six-by-six with the default settings).

Args:
    to_numpy (bool, optional): Whether to return the weights as a numpy
        array (True) or torch tensor (False). Defaults to True.

Returns:
    np.ndarray or torch.Tensor: The weight matrix; row i holds the
        weights for particle type i.

Definition at line 108 of file pidTrainWeights.py.

def get_weights(self, to_numpy=True):
    """Returns the weights as an (n_class, n_detector) array or tensor
    (six-by-six with the default settings).

    Args:
        to_numpy (bool, optional): Whether to return the weights as a numpy
            array (True) or torch tensor (False). Defaults to True.

    Returns:
        np.ndarray or torch.Tensor: The weight matrix; row i holds the
            weights for particle type i.
    """
    # each fc.weight has shape (1, n_detector); stack the rows along dim 0
    weights = torch.cat([fc.weight.detach() for fc in self.fcs])
    if to_numpy:
        return weights.cpu().numpy()
    else:
        return weights
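A small sketch of what `get_weights` assembles, using non-default sizes so that the general (n_class, n_detector) shape is visible; the standalone `fcs` list stands in for the attribute of the same name. Because `torch.cat` allocates a new tensor, the returned matrix is a snapshot that does not follow later training updates.

```python
import numpy as np
import torch
import torch.nn as nn

n_class, n_detector = 3, 4  # non-default sizes, to show the general shape
fcs = nn.ModuleList([nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)])

# each (1, n_detector) weight becomes one row of the matrix
weights = torch.cat([fc.weight.detach() for fc in fcs])
weights_np = weights.cpu().numpy()

# modify a layer after the fact: the snapshot does not change
before = weights_np.copy()
with torch.no_grad():
    fcs[0].weight.add_(1.0)

print(weights.shape)                       # torch.Size([3, 4])
print(isinstance(weights_np, np.ndarray))  # True
print(np.array_equal(weights_np, before))  # True
```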

◆ kill_unused()

def kill_unused (   self,
  only 
)
Kills weights corresponding to unused particle types.

Args:
    only (list(int) or None): PDG codes of the particle types to train.
        The weights of every particle type _not_ in this list are filled
        with one, so that type keeps its unweighted log-likelihood sum,
        and are frozen (i.e. their gradients are no longer computed or
        updated). If None, nothing is changed.

Definition at line 150 of file pidTrainWeights.py.

def kill_unused(self, only):
    """Kills weights corresponding to unused particle types.

    Args:
        only (list(int) or None): PDG codes of the particle types to train.
            The weights of every particle type _not_ in this list are
            filled with one, so that type keeps its unweighted
            log-likelihood sum, and are frozen (i.e. their gradients are
            no longer computed or updated). If None, nothing is changed.
    """
    if only is not None:
        # particle types that are not being trained...
        # set to one (unweighted) and freeze
        for i, pdg in enumerate(PDG_CODES):
            if pdg in only:
                continue
            self.fcs[i].weight.requires_grad = False
            self.fcs[i].weight.fill_(1)
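The sketch below mimics `kill_unused` on a standalone set of layers and checks that frozen layers receive no gradient and are left unchanged by an optimizer step. The contents and order of `PDG_CODES` are not shown on this page; the list used here (electron, muon, pion, kaon, proton, deuteron) is an assumption, as is the toy loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# ASSUMPTION: hypothetical ordering (e, mu, pi, K, p, deuteron); the real
# PDG_CODES constant is defined elsewhere in pidTrainWeights.py
PDG_CODES = [11, 13, 211, 321, 2212, 1000010020]
n_class = n_detector = 6

fcs = nn.ModuleList([nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)])


def kill_unused(only):
    """Freeze the layers of particle types whose PDG code is not in `only`."""
    if only is None:
        return
    for i, pdg in enumerate(PDG_CODES):
        if pdg in only:
            continue
        fcs[i].weight.requires_grad = False
        fcs[i].weight.fill_(1)  # allowed: the tensor no longer requires grad


kill_unused(only=[211, 321])  # train only the pion and kaon weights

n = n_detector
x = torch.randn(8, n_class * n)
logits = torch.cat([fcs[i](x[:, i * n:(i + 1) * n]) for i in range(n_class)], dim=1)
# hypothetical loss: every sample is labelled as a pion (index 2)
loss = -F.log_softmax(logits, dim=1)[:, 2].mean()
loss.backward()

# SGD skips parameters whose .grad is None, so the frozen layers are not updated
opt = torch.optim.SGD(fcs.parameters(), lr=0.1)
opt.step()

print([fc.weight.grad is None for fc in fcs])
# [True, True, False, False, True, True]
print(bool(torch.all(fcs[0].weight == 1)))  # True
```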

◆ random_init()

def random_init (   self,
  mean = 1.0,
  std = 0.5 
)
Fill all the weights with values sampled from a Normal distribution
with given mean and standard deviation.

Args:
    mean (float, optional): The mean of the Normal distribution.
        Defaults to 1.0.
    std (float, optional): The standard deviation of the Normal
        distribution. Defaults to 0.5.

Definition at line 135 of file pidTrainWeights.py.

def random_init(self, mean=1.0, std=0.5):
    """Fill all the weights with values sampled from a Normal distribution
    with given mean and standard deviation.

    Args:
        mean (float, optional): The mean of the Normal distribution.
            Defaults to 1.0.
        std (float, optional): The standard deviation of the Normal
            distribution. Defaults to 0.5.
    """
    with torch.no_grad():
        for fc in self.fcs:
            fc.weight.fill_(0)
            fc.weight.add_(torch.normal(mean=mean, std=std, size=fc.weight.size()))
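`random_init` zeroes each weight and then adds a normal sample, which amounts to drawing the weights directly from N(mean, std^2). A possible alternative, sketched here on a single wide standalone layer, is the in-place `Tensor.normal_`; the sample mean and standard deviation should land close to the requested 1.0 and 0.5.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mean, std = 1.0, 0.5

# a wide layer so that the sample statistics are meaningful
fc = nn.Linear(10_000, 1, bias=False)

with torch.no_grad():
    # in-place draw from N(mean, std**2), same distribution as fill_(0) + add_(torch.normal(...))
    fc.weight.normal_(mean=mean, std=std)

w = fc.weight.detach()
print(round(w.mean().item(), 3), round(w.std().item(), 3))
```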

Member Data Documentation

◆ fcs

fcs

linear layers for each particle type

Definition at line 85 of file pidTrainWeights.py.

◆ n_class

n_class

number of particle types

Definition at line 79 of file pidTrainWeights.py.

◆ n_detector

n_detector

number of detectors

Definition at line 82 of file pidTrainWeights.py.


The documentation for this class was generated from the following file: