Belle II Software release-09-00-03


Public Member Functions
| def | __init__ (self, n_class=6, n_detector=6, const_init=1) |
| def | forward (self, x) |
| def | get_weights (self, to_numpy=True) |
| def | const_init (self, const) |
| def | random_init (self, mean=1.0, std=0.5) |
| def | kill_unused (self, only) |
Public Attributes
| n_class | number of particle types |
| n_detector | number of detectors |
| fcs | linear layers for each particle type |
PyTorch architecture for training calibration weights.
Definition at line 62 of file pidTrainWeights.py.
def __init__(self, n_class=6, n_detector=6, const_init=1)
Initialize the network for training.
Args:
n_class (int, optional): Number of classification classes (particle
types). Defaults to 6.
n_detector (int, optional): Number of detectors. Defaults to 6.
const_init (int, optional): Constant value to initialize all
weights. If None, PyTorch's default random initialization is
used instead. Defaults to 1.
Definition at line 65 of file pidTrainWeights.py.
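The parameter layout implied by the fcs attribute (linear layers for each particle type, each acting on that type's n_detector log-likelihoods) can be pictured as one n_class-by-n_detector matrix. The NumPy sketch below is a hypothetical stand-in, not the class's PyTorch code. The helper name `init_weights` and the row-per-particle-type orientation are assumptions. The `const_init=None` branch only mimics the default weight range PyTorch uses for `torch.nn.Linear`.

```python
import numpy as np


def init_weights(n_class=6, n_detector=6, const_init=1, rng=None):
    """Hypothetical NumPy stand-in for the weights held in the fcs layers.

    Assumption: one bias-free, single-output linear layer per particle type,
    each with n_detector inputs, so row i holds the weights for type i.
    """
    if const_init is not None:
        # Every weight starts at the same constant value.
        return np.full((n_class, n_detector), float(const_init))
    # const_init=None: mimic torch.nn.Linear's default weight range,
    # U(-1/sqrt(fan_in), 1/sqrt(fan_in)) with fan_in = n_detector.
    rng = np.random.default_rng() if rng is None else rng
    bound = 1.0 / np.sqrt(n_detector)
    return rng.uniform(-bound, bound, size=(n_class, n_detector))


w = init_weights()
print(w.shape)  # (6, 6)
print(w[0])     # [1. 1. 1. 1. 1. 1.]
```

With the default const_init=1, every detector starts with equal weight.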
def const_init(self, const)
Fill all the weights with the given value.
Args:
const (float): Constant value to fill all weights with.
Definition at line 125 of file pidTrainWeights.py.
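The following is a minimal illustration of the in-place semantics, with a NumPy array standing in for the layer weights. The helper name `const_fill` is hypothetical. With real PyTorch parameters, an in-place write to a tensor that requires gradients is normally wrapped in `torch.no_grad()`; otherwise autograd raises an error for in-place operations on leaf tensors.

```python
import numpy as np


def const_fill(weights: np.ndarray, const: float) -> np.ndarray:
    """Overwrite every entry of `weights` in place with `const`."""
    weights.fill(const)  # in place: the caller's array is modified
    return weights


w = np.random.default_rng(0).normal(size=(6, 6))
const_fill(w, 0.5)
print(np.unique(w))  # [0.5]
```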
def forward(self, x)
The network's forward method. For each particle type, it sums the detector log-likelihoods, each scaled by its calibration weight, and then computes the likelihood ratios from these weighted sums.
Args:
x (torch.Tensor): Input detector log-likelihood data. Should be of
shape (N, n_detector * n_class), where N is the number of samples.
Returns:
torch.Tensor: Weighted likelihood ratios.
Definition at line 92 of file pidTrainWeights.py.
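The computation described for forward can be sketched in NumPy under assumptions that the docstring does not state. First, the input columns are grouped by particle type (all detectors for type 0, then type 1, and so on). Second, the weights form an (n_class, n_detector) matrix. Third, "likelihood ratio" means R_i = L_i / Σ_j L_j. With s_i = Σ_d w_id · log L_id, this gives R_i = exp(s_i) / Σ_j exp(s_j), which is a softmax. The function name is hypothetical.

```python
import numpy as np


def weighted_likelihood_ratios(x, weights):
    """Sketch: weighted per-type sums of log-likelihoods, then softmax.

    x       : (N, n_class * n_detector) log-likelihoods, assumed grouped by
              particle type.
    weights : (n_class, n_detector) calibration weights.
    Returns : (N, n_class) likelihood ratios that sum to 1 per sample.
    """
    n_class, n_detector = weights.shape
    if x.shape[1] != n_class * n_detector:
        raise ValueError("x must have n_class * n_detector columns")
    # View the input as (N, n_class, n_detector).
    logl = x.reshape(x.shape[0], n_class, n_detector)
    # Weighted sum over detectors gives one log-likelihood per particle type.
    combined = np.einsum("ncd,cd->nc", logl, weights)
    # Softmax; subtracting the row maximum avoids overflow in exp.
    shifted = combined - combined.max(axis=1, keepdims=True)
    expd = np.exp(shifted)
    return expd / expd.sum(axis=1, keepdims=True)


x = np.log(np.full((2, 36), 0.5))  # every detector reports log(0.5)
r = weighted_likelihood_ratios(x, np.ones((6, 6)))
print(r.shape)                 # (2, 6)
print(np.allclose(r, 1 / 6))   # True
```

Setting a weight to zero removes that detector's contribution for that particle type. Setting all weights to one reduces the combination to the plain sum of log-likelihoods.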
def get_weights(self, to_numpy=True)
Returns the weights as a six-by-six array or tensor (one entry per particle-type/detector pair, with the defaults n_class = n_detector = 6).
Args:
to_numpy (bool, optional): Whether to return the weights as a numpy
array (True) or torch tensor (False). Defaults to True.
Returns:
np.array or torch.Tensor: The six-by-six matrix containing the
weights.
Definition at line 108 of file pidTrainWeights.py.
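This sketch assumes that each per-particle-type layer in fcs holds a (1, n_detector) weight row. Under that assumption, the matrix can be assembled by concatenating those rows, and the NumPy code below shows that step with a hypothetical helper, `stack_layer_weights`. With real PyTorch tensors, producing a NumPy array generally requires `tensor.detach().cpu().numpy()`, because `.numpy()` refuses tensors that require gradients or live on a GPU.

```python
import numpy as np


def stack_layer_weights(layer_weights):
    """Concatenate per-type (1, n_detector) rows into an (n_class, n_detector) matrix."""
    return np.concatenate(layer_weights, axis=0)


rows = [np.full((1, 6), float(i)) for i in range(6)]  # row i filled with i
W = stack_layer_weights(rows)
print(W.shape)   # (6, 6)
print(W[:, 0])   # [0. 1. 2. 3. 4. 5.]
```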
def kill_unused(self, only)
Kills weights corresponding to unused particle types.
Args:
only (list(str) or None): List of allowed particle types. The
weights corresponding to any particle types that are _not_ in
this list will be filled with zero and frozen (i.e. their gradients
will not be computed or updated).
Definition at line 150 of file pidTrainWeights.py.
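NumPy has no autograd, so this sketch models freezing as a boolean mask that blocks gradient updates. In PyTorch, the corresponding step would presumably be setting `requires_grad = False` on the affected layer weights. Several details are hypothetical: the particle-type names in `PARTICLE_TYPES` and their row order, the treatment of `only=None` as "keep everything trainable", and the plain SGD update.

```python
import numpy as np

# Hypothetical labels for the weight rows; the real names and their order
# are not given in this documentation.
PARTICLE_TYPES = ["e", "mu", "pi", "K", "p", "d"]


def kill_unused_rows(weights, only, names=PARTICLE_TYPES):
    """Zero the rows of particle types not in `only`; return a trainable mask."""
    trainable = np.ones(weights.shape, dtype=bool)
    if only is None:
        return trainable  # assumed: None means every type stays trainable
    for i, name in enumerate(names):
        if name not in only:
            weights[i, :] = 0.0      # "kill": fill the row with zero
            trainable[i, :] = False  # "freeze": exclude the row from updates
    return trainable


def masked_sgd_step(weights, grad, trainable, lr=0.1):
    """Plain gradient-descent step that leaves frozen entries unchanged."""
    weights -= lr * np.where(trainable, grad, 0.0)
    return weights


w = np.ones((6, 6))
mask = kill_unused_rows(w, only=["pi", "K"])
masked_sgd_step(w, np.ones_like(w), mask)
print(w[:, 0].tolist())  # [0.0, 0.0, 0.9, 0.9, 0.0, 0.0]
```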
def random_init(self, mean=1.0, std=0.5)
Fill all the weights with values sampled from a Normal distribution
with the given mean and standard deviation.
Args:
mean (float, optional): The mean of the Normal distribution.
Defaults to 1.0.
std (float, optional): The standard deviation of the Normal
distribution. Defaults to 0.5.
Definition at line 135 of file pidTrainWeights.py.
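Here is a NumPy sketch of this initialization that overwrites the weights in place; the helper name `random_fill` is hypothetical. Note that `std` is a standard deviation, not a variance. PyTorch's `torch.nn.init.normal_(tensor, mean, std)` uses the same convention.

```python
import numpy as np


def random_fill(weights, mean=1.0, std=0.5, rng=None):
    """Overwrite `weights` in place with samples from Normal(mean, std**2)."""
    rng = np.random.default_rng() if rng is None else rng
    weights[...] = rng.normal(loc=mean, scale=std, size=weights.shape)
    return weights


w = random_fill(np.empty((6, 6)), rng=np.random.default_rng(42))
print(w.shape)  # (6, 6)
```

The defaults center the random weights on 1, the same value the constructor uses by default for const_init.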
fcs
Linear layers for each particle type.
Definition at line 85 of file pidTrainWeights.py.
n_class
Number of particle types.
Definition at line 79 of file pidTrainWeights.py.
n_detector
Number of detectors.
Definition at line 82 of file pidTrainWeights.py.