Belle II Software development
WeightNet Class Reference
Inheritance diagram for WeightNet:

Public Member Functions

def __init__ (self, n_class=6, n_detector=6, const_init=1)
 
def forward (self, x)
 
def get_weights (self, to_numpy=True)
 
def const_init (self, const)
 
def random_init (self, mean=1.0, std=0.5)
 
def kill_unused (self, only)
 

Public Attributes

 n_class
 number of particle types
 
 n_detector
 number of detectors
 
 fcs
 linear layers for each particle type
 

Detailed Description

PyTorch architecture for training calibration weights.

Definition at line 63 of file pidTrainWeights.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  n_class = 6,
  n_detector = 6,
  const_init = 1 
)
Initialize the network for training.

Args:
    n_class (int, optional): Number of classification classes (particle
        types). Defaults to 6.
    n_detector (int, optional): Number of detectors. Defaults to 6.
    const_init (int, optional): Constant value to initialize all
        weights. If None, PyTorch's default random initialization is
        used instead. Defaults to 1.

Definition at line 66 of file pidTrainWeights.py.

def __init__(self, n_class=6, n_detector=6, const_init=1):
    """Build one bias-free linear layer per particle type.

    Each layer maps the ``n_detector`` log-likelihoods belonging to a single
    particle type onto one weighted sum.

    Args:
        n_class (int, optional): Number of particle types (classification
            classes). Defaults to 6.
        n_detector (int, optional): Number of detectors contributing to
            each particle type. Defaults to 6.
        const_init (float or None, optional): Value written into every
            weight right after construction. Pass None to keep PyTorch's
            default random initialization of ``nn.Linear``. Defaults to 1.
    """
    super().__init__()

    ## number of particle types
    self.n_class = n_class
    ## number of detectors
    self.n_detector = n_detector
    ## linear layers for each particle type
    self.fcs = nn.ModuleList(
        [nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)]
    )

    # None means "leave PyTorch's default initialization in place".
    if const_init is None:
        return
    self.const_init(const_init)

Member Function Documentation

◆ const_init()

def const_init (   self,
  const 
)
Fill all the weights with the given value.

Args:
    const (float): Constant value to fill all weights with.

Definition at line 126 of file pidTrainWeights.py.

def const_init(self, const):
    """Overwrite every weight of every per-class layer with ``const``.

    The write happens under ``torch.no_grad()`` so autograd does not record it.

    Args:
        const (float): Value written into all weight entries.
    """
    with torch.no_grad():
        for layer in self.fcs:
            weight = layer.weight
            weight.fill_(const)

◆ forward()

def forward (   self,
  x 
)
Network's forward method. Computes a weighted sum of the detector log-likelihoods for each particle
type using the trained weights, then applies a softmax across particle types to obtain the likelihood ratios.

Args:
    x (torch.Tensor): Input detector log-likelihood data. Should be of
        shape (N, n_detector * n_class), where N is the number of samples.

Returns:
    torch.Tensor: Weighted likelihood ratios.

Definition at line 93 of file pidTrainWeights.py.

def forward(self, x):
    """Weight the detector log-likelihoods and normalise across particle types.

    Column block ``k`` of ``x`` (``n_detector`` columns wide) is fed through
    ``self.fcs[k]``, giving one weighted sum per particle type. The sums are
    concatenated along dimension 1 and turned into per-sample probabilities
    with a softmax.

    Args:
        x (torch.Tensor): Detector log-likelihoods of shape
            ``(N, n_detector * n_class)``, where ``N`` is the number of samples.

    Returns:
        torch.Tensor: Softmax-normalised weighted likelihood ratios, one
        column per particle type.
    """
    width = self.n_detector
    per_class = []
    for k in range(self.n_class):
        start = k * width
        per_class.append(self.fcs[k](x[:, start:start + width]))
    logits = torch.cat(per_class, dim=1)
    return F.softmax(logits, dim=1)

◆ get_weights()

def get_weights (   self,
  to_numpy = True 
)
Returns the weights as an (n_class, n_detector) array or tensor (six-by-six with the default arguments).

Args:
    to_numpy (bool, optional): Whether to return the weights as a numpy
        array (True) or torch tensor (False). Defaults to True.

Returns:
    np.ndarray or torch.Tensor: The (n_class, n_detector) matrix containing
        the weights (six-by-six with the default arguments).

Definition at line 109 of file pidTrainWeights.py.

def get_weights(self, to_numpy=True):
    """Collect the weights of all per-class layers into one matrix.

    Row ``i`` holds the weights of ``self.fcs[i]``. With layers built as
    ``nn.Linear(n_detector, 1)``, the matrix has shape
    ``(n_class, n_detector)``, which is six-by-six with the default
    constructor arguments. The weights are detached from the autograd graph.

    Args:
        to_numpy (bool, optional): If True (default), move the matrix to the
            CPU and return it as a NumPy array; otherwise return the detached
            torch tensor as-is.

    Returns:
        np.ndarray or torch.Tensor: The weight matrix.
    """
    rows = [layer.weight.detach() for layer in self.fcs]
    matrix = torch.cat(rows)
    if not to_numpy:
        return matrix
    return matrix.cpu().numpy()

◆ kill_unused()

def kill_unused (   self,
  only 
)
Kills weights corresponding to unused particle types.

Args:
    only (list(str) or None): List of allowed particle types. The
        weights corresponding to any particle types that are _not_ in
        this list will be filled with one and be frozen (i.e. gradients
        will not be computed/updated).

Definition at line 151 of file pidTrainWeights.py.

def kill_unused(self, only):
    """Freeze the weights of particle types that are not being trained.

    Each particle type in ``PDG_CODES`` that is *not* listed in ``only`` is
    handled the same way. The weights of its layer in ``self.fcs`` are
    filled with 1, and ``requires_grad`` is set to False, so autograd
    neither computes nor applies gradients for them. Layers of listed
    particle types are left untouched. Nothing happens when ``only`` is None.

    NOTE(review): earlier documentation said the frozen weights are set to
    *zero*, but this code writes 1. Confirm which value is intended before
    changing either.

    Args:
        only (list or None): Allowed particle types. Membership is tested
            with ``pdg in only`` against the entries of ``PDG_CODES``, so the
            entries must be of the same type as those codes. None disables
            the whole operation.
    """
    if only is None:
        return

    # Assumes the i-th entry of PDG_CODES corresponds to self.fcs[i] -- TODO
    # confirm against where PDG_CODES is defined.
    # no_grad keeps this in-place write out of autograd, matching const_init
    # and random_init.
    with torch.no_grad():
        for i, pdg in enumerate(PDG_CODES):
            if pdg in only:
                continue
            weight = self.fcs[i].weight
            weight.requires_grad = False
            weight.fill_(1)

◆ random_init()

def random_init (   self,
  mean = 1.0,
  std = 0.5 
)
Fill all the weights with values sampled from a Normal distribution
with given mean and standard deviation.

Args:
    mean (float, optional): The mean of the Normal distribution.
        Defaults to 1.0.
    std (float, optional): The standard deviation of the Normal
        distribution. Defaults to 0.5.

Definition at line 136 of file pidTrainWeights.py.

def random_init(self, mean=1.0, std=0.5):
    """Refill every weight with samples from a Normal distribution.

    Samples are drawn in place into each weight tensor, so they are produced
    on that tensor's own device and in its own dtype.

    Args:
        mean (float, optional): Mean of the Normal distribution.
            Defaults to 1.0.
        std (float, optional): Standard deviation of the Normal
            distribution. Defaults to 0.5.
    """
    with torch.no_grad():
        for fc in self.fcs:
            # Sample in place rather than adding a separately created
            # torch.normal(...) tensor: such a tensor is allocated on the
            # default (CPU) device and cannot be added to a CUDA weight.
            fc.weight.normal_(mean=mean, std=std)

Member Data Documentation

◆ fcs

fcs

linear layers for each particle type

Definition at line 86 of file pidTrainWeights.py.

◆ n_class

n_class

number of particle types

Definition at line 80 of file pidTrainWeights.py.

◆ n_detector

n_detector

number of detectors

Definition at line 83 of file pidTrainWeights.py.


The documentation for this class was generated from the following file: