Belle II Software light-2406-ragdoll
Public Member Functions
def __init__(self, n_class=6, n_detector=6, const_init=1)
def forward(self, x)
def get_weights(self, to_numpy=True)
def const_init(self, const)
def random_init(self, mean=1.0, std=0.5)
def kill_unused(self, only)
Public Attributes
n_class: number of particle types
n_detector: number of detectors
fcs: linear layers for each particle type
PyTorch architecture for training calibration weights.
Definition at line 62 of file pidTrainWeights.py.
def __init__(self, n_class=6, n_detector=6, const_init=1)
Initialize the network for training.

Args:
    n_class (int, optional): Number of classification classes (particle types). Defaults to 6.
    n_detector (int, optional): Number of detectors. Defaults to 6.
    const_init (float or None, optional): Constant value used to initialize all weights. If None, PyTorch's default random initialization is used instead. Defaults to 1.
Definition at line 65 of file pidTrainWeights.py.
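The constructor description can be illustrated with a minimal sketch. It is not the code from pidTrainWeights.py: the class name CalibrationNetSketch is hypothetical, and the layout (one bias-free linear layer per particle type, each mapping n_detector inputs to a single output) is an assumption inferred from the fcs attribute description.

```python
import torch
from torch import nn


class CalibrationNetSketch(nn.Module):  # hypothetical name, not the documented class
    """Holds one linear layer per particle type (assumed layout)."""

    def __init__(self, n_class=6, n_detector=6, const_init=1):
        super().__init__()
        self.n_class = n_class
        self.n_detector = n_detector
        # Assumption: each particle type owns n_detector weights and no bias term.
        self.fcs = nn.ModuleList(
            [nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)]
        )
        if const_init is not None:
            # Overwrite PyTorch's default random initialization with a constant.
            with torch.no_grad():
                for fc in self.fcs:
                    fc.weight.fill_(const_init)


net = CalibrationNetSketch()
n_weights = sum(p.numel() for p in net.parameters())
```

Under these assumptions the defaults give n_class * n_detector = 36 trainable weights, all starting at 1.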
def const_init(self, const)
Fill all the weights with the given value.

Args:
    const (float): Constant value to fill all weights with.
Definition at line 125 of file pidTrainWeights.py.
def forward(self, x)
Network's forward method. Sums the detector log-likelihoods for each particle type, weighted by the calibration weights, then computes the likelihood ratios.

Args:
    x (torch.Tensor): Input detector log-likelihood data. Should be of shape (N, n_detector * n_class), where N is the number of samples.

Returns:
    torch.Tensor: Weighted likelihood ratios.
Definition at line 92 of file pidTrainWeights.py.
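The computation described for forward can be sketched as follows. This is an illustration under stated assumptions, not the method's actual body: the function name weighted_likelihood_ratios is hypothetical, the input columns are assumed to be grouped by particle type with the detector index varying fastest, and the likelihood ratio is assumed to be normalized over all particle hypotheses, which is what a softmax over the summed log-likelihoods yields.

```python
import torch


def weighted_likelihood_ratios(x, weights):
    """x: (N, n_class * n_detector) log-likelihoods; weights: (n_class, n_detector)."""
    n_class, n_detector = weights.shape
    # Assumption: columns are grouped by particle type, detectors varying fastest.
    logl = x.reshape(-1, n_class, n_detector)
    # Weighted sum over detectors: one combined log-likelihood per particle type.
    summed = (logl * weights).sum(dim=2)
    # softmax(log L_i) = L_i / sum_j L_j, the likelihood ratio of each hypothesis.
    return torch.softmax(summed, dim=1)


x = torch.randn(4, 36)  # 4 samples, 6 particle types times 6 detectors
ratios = weighted_likelihood_ratios(x, torch.ones(6, 6))
```

With all weights equal to 1 this reduces to the unweighted likelihood ratio; each output row sums to 1.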
def get_weights(self, to_numpy=True)
Returns the weights as a six-by-six array or tensor (for the default n_class = 6 and n_detector = 6).

Args:
    to_numpy (bool, optional): Whether to return the weights as a numpy array (True) or a torch tensor (False). Defaults to True.

Returns:
    np.array or torch.Tensor: The six-by-six matrix containing the weights.
Definition at line 108 of file pidTrainWeights.py.
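One way to assemble such a matrix from per-particle linear layers is sketched below. The orientation (rows as particle types, columns as detectors) and the bias-free single-output layers are assumptions; the page does not state them.

```python
import torch
from torch import nn

# Assumed layout: one bias-free layer per particle type, as in the fcs attribute.
fcs = nn.ModuleList([nn.Linear(6, 1, bias=False) for _ in range(6)])

# Each layer's weight has shape (1, n_detector); concatenating along dim 0
# stacks them into an (n_class, n_detector) matrix detached from autograd.
weights = torch.cat([fc.weight.detach() for fc in fcs], dim=0)

# Equivalent of to_numpy=True: move to CPU and convert.
weights_np = weights.cpu().numpy()
```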
def kill_unused(self, only)
Kills weights corresponding to unused particle types.

Args:
    only (list(str) or None): List of allowed particle types. The weights corresponding to any particle types that are _not_ in this list will be filled with zeros and frozen (i.e. gradients will not be computed or updated).
Definition at line 150 of file pidTrainWeights.py.
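Zeroing and freezing weights, as kill_unused is described to do, can be sketched like this. The particle labels in PARTICLES, the helper name freeze_unused, and the treatment of only=None as "keep everything" are hypothetical choices made for illustration; the real method may identify particle types differently.

```python
import torch
from torch import nn

PARTICLES = ["e", "mu", "pi", "K", "p", "d"]  # hypothetical labels, one per layer
fcs = nn.ModuleList([nn.Linear(6, 1, bias=False) for _ in PARTICLES])


def freeze_unused(fcs, only):
    """Zero and freeze the layers of particle types not listed in `only`."""
    if only is None:  # assumption: None leaves all weights trainable
        return
    for name, fc in zip(PARTICLES, fcs):
        if name in only:
            continue
        with torch.no_grad():
            fc.weight.fill_(0.0)
        # With requires_grad off, autograd computes no gradient for this weight.
        fc.weight.requires_grad_(False)


freeze_unused(fcs, only=["pi", "K"])
```

After the call, only the "pi" and "K" layers still require gradients, so an optimizer step leaves the other four at zero.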
def random_init(self, mean=1.0, std=0.5)
Fill all the weights with values sampled from a Normal distribution with given mean and standard deviation.

Args:
    mean (float, optional): The mean of the Normal distribution. Defaults to 1.0.
    std (float, optional): The standard deviation of the Normal distribution. Defaults to 0.5.
Definition at line 135 of file pidTrainWeights.py.
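A comparable initialization can be written with torch.nn.init.normal_, which fills a tensor in place with Normal samples. This is a sketch of the idea, not necessarily how random_init is implemented, and the layer layout is again an assumption.

```python
import torch
from torch import nn

torch.manual_seed(0)  # only for reproducibility of the illustration

# Assumed layout: one bias-free layer per particle type.
fcs = nn.ModuleList([nn.Linear(6, 1, bias=False) for _ in range(6)])

for fc in fcs:
    # In-place fill with samples from Normal(mean=1.0, std=0.5).
    nn.init.normal_(fc.weight, mean=1.0, std=0.5)

all_weights = torch.cat([fc.weight.detach() for fc in fcs], dim=0)
```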
fcs
linear layers for each particle type
Definition at line 85 of file pidTrainWeights.py.
n_class
number of particle types
Definition at line 79 of file pidTrainWeights.py.
n_detector
number of detectors
Definition at line 82 of file pidTrainWeights.py.