Belle II Software  release-08-01-10
WeightNet Class Reference

Public Member Functions

def __init__(self, n_class=6, n_detector=6, const_init=1)
def forward(self, x)
def get_weights(self, to_numpy=True)
def const_init(self, const)
def random_init(self, mean=1.0, std=0.5)
def kill_unused(self, only)

Public Attributes

n_class: number of particle types
n_detector: number of detectors
fcs: bias-free linear layers, one per particle type

Detailed Description

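PyTorch network for training PID calibration weights. For each particle type, a bias-free linear layer forms a weighted sum of the per-detector log-likelihoods; the trained layer weights are the calibration weights.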

Definition at line 62 of file pidTrainWeights.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__(self, n_class=6, n_detector=6, const_init=1)
Initialize the network for training.

Args:
    n_class (int, optional): Number of classification classes (particle
        types). Defaults to 6.
    n_detector (int, optional): Number of detectors. Defaults to 6.
    const_init (int, optional): Constant value to initialize all
        weights. If None, PyTorch's default random initialization is
        used instead. Defaults to 1.

Definition at line 65 of file pidTrainWeights.py.

    def __init__(self, n_class=6, n_detector=6, const_init=1):
        """Initialize the network for training.

        Args:
            n_class (int, optional): Number of classification classes (particle
                types). Defaults to 6.
            n_detector (int, optional): Number of detectors. Defaults to 6.
            const_init (int, optional): Constant value to initialize all
                weights. If None, PyTorch's default random initialization is
                used instead. Defaults to 1.
        """
        super().__init__()

        # number of particle types
        self.n_class = n_class

        # number of detectors
        self.n_detector = n_detector

        # one bias-free linear layer per particle type: n_detector inputs -> 1 output
        self.fcs = nn.ModuleList(
            [nn.Linear(self.n_detector, 1, bias=False) for _ in range(self.n_class)]
        )

        if const_init is not None:
            self.const_init(const_init)
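The layer structure built by the constructor can be reproduced standalone to inspect the shapes involved. This is a minimal sketch; the helper name `build_layers` is hypothetical and only mirrors the `nn.ModuleList` expression above. Each particle type gets its own bias-free `nn.Linear(n_detector, 1)`, so the model holds exactly `n_class * n_detector` trainable weights.

```python
from torch import nn


def build_layers(n_class=6, n_detector=6):
    """Hypothetical helper mirroring the constructor: one bias-free linear
    layer per particle type, each mapping n_detector inputs to 1 output."""
    return nn.ModuleList(
        [nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)]
    )


fcs = build_layers()
# nn.Linear stores its weight as (out_features, in_features) = (1, n_detector).
print(fcs[0].weight.shape)  # torch.Size([1, 6])
# Total number of trainable weights: n_class * n_detector.
print(sum(p.numel() for p in fcs.parameters()))  # 36
```

Because there is no bias term, each layer's output is a pure weighted sum of its inputs.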

Member Function Documentation

◆ const_init()

def const_init(self, const)
Fill all the weights with the given value.

Args:
    const (float): Constant value to fill all weights with.

Definition at line 125 of file pidTrainWeights.py.
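A minimal sketch of constant filling, written as a free function over the `fcs` layer list. The free-function form is an assumption for illustration; the real method operates on `self.fcs` and may be implemented differently.

```python
import torch
from torch import nn


def const_init(fcs, const):
    """Sketch: fill every weight of every per-particle layer with `const`."""
    for fc in fcs:
        # nn.init.constant_ writes in place without recording gradients.
        nn.init.constant_(fc.weight, const)


fcs = nn.ModuleList([nn.Linear(6, 1, bias=False) for _ in range(6)])
const_init(fcs, 1.0)
print(all(torch.all(fc.weight == 1.0) for fc in fcs))  # True
```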

◆ forward()

def forward(self, x)
The network's forward method. For each particle type, it forms a weighted
sum of the detector log-likelihoods using that type's weights, then
computes the likelihood ratios from these sums.

Args:
    x (torch.Tensor): Input detector log-likelihood data. Should be of
        shape (N, n_detector * n_class), where N is the number of samples.

Returns:
    torch.Tensor: Weighted likelihood ratios.

Definition at line 92 of file pidTrainWeights.py.
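A minimal sketch of the forward pass. It rests on two assumptions not stated on this page: the input columns are grouped by particle type (columns `i * n_detector` through `(i + 1) * n_detector - 1` hold the detector log-likelihoods for particle type `i`), and the likelihood ratio for type `i` is taken as exp(S_i) / sum_j exp(S_j), where S_i is that type's weighted sum, i.e. a softmax over particle types. The free function `forward(fcs, x, n_detector)` is hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


def forward(fcs, x, n_detector):
    """Sketch of the forward pass under an assumed input layout.

    Assumption: columns [i * n_detector, (i + 1) * n_detector) of x hold the
    per-detector log-likelihoods for particle type i.
    """
    # Weighted sum of detector log-likelihoods per particle type: (N, 1) each.
    sums = [
        fc(x[:, i * n_detector:(i + 1) * n_detector]) for i, fc in enumerate(fcs)
    ]
    logl = torch.cat(sums, dim=1)  # (N, n_class)
    # Likelihood ratio exp(S_i) / sum_j exp(S_j), computed stably via softmax.
    return F.softmax(logl, dim=1)


n_class, n_detector = 6, 6
fcs = nn.ModuleList([nn.Linear(n_detector, 1, bias=False) for _ in range(n_class)])
for fc in fcs:
    nn.init.constant_(fc.weight, 1.0)

x = torch.randn(4, n_class * n_detector)
ratios = forward(fcs, x, n_detector)
print(ratios.shape)  # torch.Size([4, 6])
print(torch.allclose(ratios.sum(dim=1), torch.ones(4)))  # True

# With all weights equal to 1, the sketch reduces to the unweighted ratio.
unweighted = F.softmax(x.view(4, n_class, n_detector).sum(dim=2), dim=1)
print(torch.allclose(ratios, unweighted, atol=1e-5))  # True
```

Training then adjusts the per-detector weights away from 1, so each detector's contribution to each particle hypothesis can be scaled independently.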

◆ get_weights()

def get_weights(self, to_numpy=True)
Returns the weights as a matrix (six-by-six with the default arguments,
n_class = n_detector = 6), either as a NumPy array or as a torch tensor.

Args:
    to_numpy (bool, optional): Whether to return the weights as a numpy
        array (True) or torch tensor (False). Defaults to True.

Returns:
    np.array or torch.Tensor: The weight matrix (six-by-six with the
        default arguments).

Definition at line 108 of file pidTrainWeights.py.
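A minimal sketch of collecting the weights into one matrix, assuming one row per particle type (the actual row/column orientation used by the method is not stated here, and the free-function form is hypothetical). Since each `fc.weight` has shape `(1, n_detector)`, concatenating along dimension 0 yields an `(n_class, n_detector)` matrix.

```python
import numpy as np
import torch
from torch import nn


def get_weights(fcs, to_numpy=True):
    """Sketch: stack per-particle weight rows into an (n_class, n_detector) matrix."""
    # Each fc.weight is (1, n_detector); cat along dim 0 gives one row per particle type.
    weights = torch.cat([fc.weight for fc in fcs], dim=0)
    if to_numpy:
        # detach() drops the autograd graph so that .numpy() is allowed.
        return weights.detach().cpu().numpy()
    return weights


fcs = nn.ModuleList([nn.Linear(6, 1, bias=False) for _ in range(6)])
w = get_weights(fcs)
print(isinstance(w, np.ndarray), w.shape)  # True (6, 6)
```

Returning the tensor (`to_numpy=False`) keeps it attached to the autograd graph, which is useful if the weights enter a loss term; the NumPy form suits inspection and saving.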

◆ kill_unused()

def kill_unused(self, only)
Zeroes and freezes the weights corresponding to unused particle types.

Args:
    only (list(str) or None): List of allowed particle types. The
        weights corresponding to any particle types that are _not_ in
        this list are set to zero and frozen (i.e. gradients are not
        computed for them, so they are never updated).

Definition at line 150 of file pidTrainWeights.py.
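A minimal sketch of zeroing and freezing, assuming a hypothetical mapping from particle-type names to layer indices (`PARTICLES` below; the names and their order are placeholders, not taken from the Belle II code) and assuming that `only=None` leaves every layer untouched.

```python
import torch
from torch import nn

# Hypothetical particle-type order; the real mapping used by WeightNet may differ.
PARTICLES = ["e", "mu", "pi", "K", "p", "d"]


def kill_unused(fcs, only, particles=PARTICLES):
    """Sketch: zero and freeze the layers of particle types not listed in `only`."""
    if only is None:
        return  # assumption: None means every particle type stays trainable
    for name, fc in zip(particles, fcs):
        if name not in only:
            nn.init.zeros_(fc.weight)  # in-place, no gradient recorded
            # Exclude this layer's weights from gradient computation.
            fc.weight.requires_grad_(False)


fcs = nn.ModuleList([nn.Linear(6, 1, bias=False) for _ in range(6)])
kill_unused(fcs, only=["pi", "K"])
print([fc.weight.requires_grad for fc in fcs])  # [False, False, True, True, False, False]
```

Because each particle type has its own layer, an entire particle type can be frozen by switching off `requires_grad` on a single parameter. Standard optimizers such as `torch.optim.SGD` and `torch.optim.Adam` skip parameters whose `.grad` is `None`.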

◆ random_init()

def random_init(self, mean=1.0, std=0.5)
Fill all the weights with values sampled from a normal distribution
with the given mean and standard deviation.

Args:
    mean (float, optional): The mean of the Normal distribution.
        Defaults to 1.0.
    std (float, optional): The standard deviation of the Normal
        distribution. Defaults to 0.5.

Definition at line 135 of file pidTrainWeights.py.
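A minimal sketch of random initialization using `torch.nn.init.normal_`, written as a hypothetical free function over the layer list (the real method may be implemented differently). The usage deliberately uses wide layers only so that the sample mean and standard deviation are stable enough to compare against the requested values.

```python
import torch
from torch import nn


def random_init(fcs, mean=1.0, std=0.5):
    """Sketch: draw every weight independently from Normal(mean, std)."""
    for fc in fcs:
        nn.init.normal_(fc.weight, mean=mean, std=std)


torch.manual_seed(0)
# Wide layers (1000 inputs each) give 6000 samples for a quick statistics check.
fcs = nn.ModuleList([nn.Linear(1000, 1, bias=False) for _ in range(6)])
random_init(fcs)
w = torch.cat([fc.weight for fc in fcs]).detach()
print(w.mean().item(), w.std().item())  # close to 1.0 and 0.5
```

With the defaults, roughly 2.3% of the drawn weights come out negative, since zero lies two standard deviations below the mean.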

