n_class: number of particle types
n_detector: number of detectors
fcs: linear layers, one for each particle type
PyTorch architecture for training calibration weights.
Definition at line 62 of file pidTrainWeights.py.
◆ __init__()
def __init__(self, n_class=6, n_detector=6, const_init=1)
Initialize the network for training.
Args:
n_class (int, optional): Number of classification classes (particle
types). Defaults to 6.
n_detector (int, optional): Number of detectors. Defaults to 6.
const_init (int, optional): Constant value to initialize all
weights. If None, PyTorch's default random initialization is
used instead. Defaults to 1.
Definition at line 65 of file pidTrainWeights.py.
def __init__(self, n_class=6, n_detector=6, const_init=1):
    """Initialize the network for training.

    Args:
        n_class (int, optional): Number of classification classes (particle
            types). Defaults to 6.
        n_detector (int, optional): Number of detectors. Defaults to 6.
        const_init (int, optional): Constant value to initialize all
            weights. If None, PyTorch's default random initialization is
            used instead. Defaults to 1.
    """
    # nn.Module must be initialized before any submodule can be registered
    super().__init__()

    #: number of particle types
    self.n_class = n_class

    #: number of detectors
    self.n_detector = n_detector

    #: linear layers for each particle type
    self.fcs = nn.ModuleList(
        [nn.Linear(self.n_detector, 1, bias=False) for _ in range(self.n_class)]
    )

    if const_init is not None:
        self.const_init(const_init)
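The layer layout that __init__ builds can be checked in isolation with the sketch below, assuming PyTorch is installed. The helper build_fcs is hypothetical (it is not part of pidTrainWeights.py) and reproduces only the nn.ModuleList construction from the listing above.

```python
from torch import nn


def build_fcs(n_class=6, n_detector=6):
    """Hypothetical helper that mirrors the layer construction in __init__."""
    # One bias-free linear layer per particle type; each maps that type's
    # n_detector log-likelihoods to a single score.
    return nn.ModuleList(
        [nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)]
    )


fcs = build_fcs()
print(len(fcs), tuple(fcs[0].weight.shape))  # 6 (1, 6)
# Trainable parameters: n_class * n_detector = 36, since there are no biases
print(sum(p.numel() for p in fcs.parameters()))  # 36
```

Using one small layer per particle type, rather than a single nn.Linear(n_detector * n_class, n_class), presumably keeps each particle type's score dependent only on its own n_detector inputs and limits the model to n_class * n_detector weights.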
◆ const_init()
def const_init(self, const)
Fill all the weights with the given value.
Args:
const (float): Constant value to fill all weights with.
Definition at line 125 of file pidTrainWeights.py.
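The documented behaviour of const_init can be sketched as follows, assuming the weights live in a ModuleList of bias-free nn.Linear layers as in the __init__ listing. The free function const_init(fcs, const) is a hypothetical stand-in for the method, not the file's actual code.

```python
import torch
from torch import nn


def const_init(fcs, const):
    """Sketch of const_init(): fill every weight of every layer with `const`."""
    # In-place writes to parameters that require grad must run under
    # no_grad(); otherwise autograd raises a RuntimeError.
    with torch.no_grad():
        for fc in fcs:
            fc.weight.fill_(const)


fcs = nn.ModuleList([nn.Linear(6, 1, bias=False) for _ in range(6)])
const_init(fcs, 0.5)
print(fcs[3].weight.flatten().tolist())  # [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
print(fcs[3].weight.requires_grad)  # True: the weights remain trainable
```

With the default const_init=1, every detector enters with weight one, which presumably makes the untrained network equivalent to a plain, unweighted sum of the detector log-likelihoods.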
◆ forward()
def forward(self, x)
Network's forward method. For each particle type, computes the weighted sum of
the detector log-likelihoods, then converts these per-type sums into likelihood
ratios.
Args:
x (torch.Tensor): Input detector log-likelihood data. Should be of
shape (N, n_detector * n_class), where N is the number of samples.
Returns:
torch.Tensor: Weighted likelihood ratios.
Definition at line 92 of file pidTrainWeights.py.
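A minimal sketch of the forward pass, under two assumptions this page does not confirm: (1) the input columns are grouped by particle type, so columns i*n_detector to (i+1)*n_detector - 1 hold the detector log-likelihoods l_{i,d} for particle type i; and (2) the likelihood ratios are computed as a softmax over the per-type weighted sums, r_i = exp(s_i) / sum_j exp(s_j) with s_i = sum_d w_{i,d} * l_{i,d}. The function forward(fcs, x, n_detector) is hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


def forward(fcs, x, n_detector):
    """Sketch of forward(): weighted log-likelihood sums -> likelihood ratios.

    Assumes x has shape (N, n_detector * n_class), with the n_detector
    columns of particle type i stored contiguously in block i.
    """
    # s_i = sum_d w[i, d] * x[:, i*n_detector + d]; one column per particle type
    scores = [
        fc(x[:, i * n_detector:(i + 1) * n_detector]) for i, fc in enumerate(fcs)
    ]
    s = torch.cat(scores, dim=1)  # shape (N, n_class)
    # exp(s_i) / sum_j exp(s_j), computed in a numerically stable way
    return F.softmax(s, dim=1)


n_class, n_detector = 6, 6
fcs = nn.ModuleList([nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)])
with torch.no_grad():
    for fc in fcs:
        fc.weight.fill_(1.0)  # all-ones weights: plain, unweighted sum

x = torch.randn(4, n_detector * n_class)  # 4 fake samples of log-likelihoods
ratios = forward(fcs, x, n_detector)
print(ratios.shape)  # torch.Size([4, 6])
print(ratios.sum(dim=1))  # each row sums to 1, up to rounding
```

Under the softmax assumption, all-ones weights reduce the sketch to the conventional combined likelihood ratio exp(sum_d l_{i,d}) / sum_j exp(sum_d l_{j,d}); training then moves individual detector weights away from one.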
◆ get_weights()
def get_weights(self, to_numpy=True)
Returns the weights as a matrix (six-by-six for the default
n_class = n_detector = 6), either as a NumPy array or as a torch tensor.
Args:
to_numpy (bool, optional): Whether to return the weights as a NumPy
array (True) or torch tensor (False). Defaults to True.
Returns:
np.ndarray or torch.Tensor: The weight matrix (six-by-six for the
default n_class and n_detector).
Definition at line 108 of file pidTrainWeights.py.
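One way to assemble the matrix is sketched below, assuming one bias-free nn.Linear(n_detector, 1) per particle type as in __init__ and assuming rows index particle types. The free function get_weights(fcs, to_numpy) is a hypothetical stand-in for the method.

```python
import numpy as np
import torch
from torch import nn


def get_weights(fcs, to_numpy=True):
    """Sketch of get_weights(): stack per-type weights into one matrix.

    Row i holds the n_detector weights of particle type i.
    """
    # Each fc.weight has shape (1, n_detector); concatenating along dim 0
    # gives (n_class, n_detector). detach() drops the autograd history.
    w = torch.cat([fc.weight.detach() for fc in fcs], dim=0)
    return w.cpu().numpy() if to_numpy else w


fcs = nn.ModuleList([nn.Linear(6, 1, bias=False) for _ in range(6)])
w = get_weights(fcs)
print(isinstance(w, np.ndarray), w.shape)  # True (6, 6)
```

For CPU tensors, both detach() and numpy() share memory with the original parameters, so the array returned by this sketch should be copied (w.copy()) before it is modified.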
◆ kill_unused()
def kill_unused(self, only)
Kills weights corresponding to unused particle types.
Args:
only (list(str) or None): List of allowed particle types. The
weights corresponding to any particle types that are _not_ in
this list will be filled with zero and frozen (i.e. their
gradients will be neither computed nor updated).
Definition at line 150 of file pidTrainWeights.py.
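The zero-and-freeze behaviour described above can be sketched as follows. The particle-type names, their ordering in PARTICLE_TYPES, and the treatment of only=None as "keep everything" are all hypothetical assumptions; this page does not show how pidTrainWeights.py maps names to layers.

```python
import torch
from torch import nn

# Hypothetical names and ordering: the mapping from particle type to layer
# index used by pidTrainWeights.py is not shown in this documentation.
PARTICLE_TYPES = ["e", "mu", "pi", "K", "p", "d"]


def kill_unused(fcs, only):
    """Sketch of kill_unused(): zero and freeze layers of unlisted types."""
    if only is None:
        return  # assumption: None means "keep every particle type"
    for name, fc in zip(PARTICLE_TYPES, fcs):
        if name in only:
            continue
        with torch.no_grad():
            fc.weight.zero_()  # filled with zero...
        fc.weight.requires_grad_(False)  # ...and frozen: no gradient computed


fcs = nn.ModuleList([nn.Linear(6, 1, bias=False) for _ in PARTICLE_TYPES])
kill_unused(fcs, only=["pi", "K"])
print([fc.weight.requires_grad for fc in fcs])
# [False, False, True, True, False, False]
```

Because a frozen weight never receives a .grad, the standard torch.optim optimizers skip it, so it stays at zero throughout training.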
◆ random_init()
def random_init(self, mean=1.0, std=0.5)
Fill all the weights with values sampled from a Normal distribution
with the given mean and standard deviation.
Args:
mean (float, optional): The mean of the Normal distribution.
Defaults to 1.0.
std (float, optional): The standard deviation of the Normal
distribution. Defaults to 0.5.
Definition at line 135 of file pidTrainWeights.py.
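A minimal sketch of random_init, assuming torch.nn.init.normal_ as the sampling primitive (the actual implementation may draw the values differently). The free function random_init(fcs, mean, std) is hypothetical, and the example uses 200 layers instead of n_class = 6 only so that the sample statistics land visibly close to the requested mean and standard deviation.

```python
import torch
from torch import nn


def random_init(fcs, mean=1.0, std=0.5):
    """Sketch of random_init(): draw every weight from Normal(mean, std)."""
    for fc in fcs:
        # nn.init.normal_ fills the tensor in place without recording
        # the operation in the autograd graph.
        nn.init.normal_(fc.weight, mean=mean, std=std)


torch.manual_seed(0)  # reproducible draw for this example
fcs = nn.ModuleList([nn.Linear(6, 1, bias=False) for _ in range(200)])
random_init(fcs)
w = torch.cat([fc.weight.detach() for fc in fcs])
print(w.mean().item(), w.std().item())  # close to 1.0 and 0.5
```

Centring the draw on 1.0 presumably keeps the random starting point near the all-ones weights produced by the default const_init=1.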