Belle II Software development
SN Class Reference
Classes derived from SN: SNConv2d, SNEmbedding, SNLinear

Public Member Functions

 __init__ (self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12)
 
 u (self)
 
 sv (self)
 
 W_ (self)
 

Public Attributes

 num_itrs = num_itrs
 Number of power iterations per step.
 
 num_svs = num_svs
 Number of singular values.
 
 transpose = transpose
 Whether to transpose the reshaped weight matrix before the power iteration.
 
 eps = eps
 Small constant (epsilon) that guards against division by zero; passed to power_iteration.
 
bool training
 Training mode flag (inherited from nn.Module).
 

Detailed Description

Spectral normalization base class

This base class expects subclasses to have a learnable weight parameter
(`self.weight`) as in `nn.Linear` or `nn.Conv2d`. Its W_() method returns
that weight divided by its leading singular value, estimated by power iteration.

Attributes:
    num_svs (int): Number of singular values.
    num_itrs (int): Number of power iterations per step.
    transpose (bool): Whether to transpose the weight matrix.
    eps (float): Small constant to avoid divide-by-zero.
    u (list[Tensor]): Registered left singular vectors (buffers).
    sv (list[Tensor]): Registered singular values (buffers), used only for logging.
    training (bool): Inherited from nn.Module. True if in training mode.

Definition at line 226 of file ieagan.py.
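As a rough illustration of the contract described above, the sketch below pairs a hypothetical mixin (`SNMixinSketch`) with `nn.Linear`, which supplies `self.weight`, and normalizes the weight in `forward`. It is not the ieagan.py implementation. It tracks a single singular value and inlines a simplified power iteration instead of calling `power_iteration`. All class and method names other than `W_` are made up for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SNMixinSketch:
    """Hypothetical, simplified stand-in for SN that tracks one singular value."""

    def _init_sn(self, num_outputs, num_itrs=1, eps=1e-12):
        self.num_itrs = num_itrs
        self.eps = eps
        # Persistent estimate of the leading left singular vector
        self.register_buffer("u0", torch.randn(1, num_outputs))

    def W_(self):
        W_mat = self.weight.view(self.weight.size(0), -1)
        u = self.u0
        with torch.no_grad():  # u and v are treated as constants
            for _ in range(self.num_itrs):  # assumes num_itrs >= 1
                v = F.normalize(u @ W_mat, dim=1, eps=self.eps)
                u = F.normalize(v @ W_mat.t(), dim=1, eps=self.eps)
            if self.training:
                self.u0.copy_(u)  # carry the estimate over to the next call
        sigma = (v @ W_mat.t() @ u.t()).squeeze()  # largest singular value estimate
        return self.weight / sigma


class SNLinearSketch(nn.Linear, SNMixinSketch):
    """Hypothetical linear layer whose weight is spectrally normalized on each call."""

    def __init__(self, in_features, out_features, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features)
        self._init_sn(out_features, num_itrs, eps)

    def forward(self, x):
        return F.linear(x, self.W_(), self.bias)
```

Calling `SNLinearSketch(8, 4)(torch.randn(2, 8))` returns a `(2, 4)` tensor. In training mode each call also refines `u0`.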

Constructor & Destructor Documentation

◆ __init__()

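__init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12)
Constructor. Stores the settings and, for each of the num_svs singular values, registers a buffer u<i> of shape (1, num_outputs) drawn from a standard normal distribution and a buffer sv<i> initialized to 1.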

Definition at line 244 of file ieagan.py.

def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
    """constructor"""

    super().__init__()
    ## Number of power iterations per step
    self.num_itrs = num_itrs
    ## Number of singular values
    self.num_svs = num_svs
    ## Transposed?
    self.transpose = transpose
    ## Epsilon value for avoiding divide-by-0
    self.eps = eps
    # Register a singular vector for each sv
    for i in range(self.num_svs):
        self.register_buffer(f"u{i:d}", torch.randn(1, num_outputs))
        self.register_buffer(f"sv{i:d}", torch.ones(1))
    ## Training mode flag (inherited from nn.Module). True if the module is in training mode.
    self.training: bool
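The loop at the end of the constructor relies on `register_buffer`. Buffers are saved in `state_dict()` and move with the module on `.to()` or `.cuda()`, but `parameters()` does not return them, so the optimizer never updates them. The hypothetical module `BufferDemo` below registers buffers with the same naming scheme and makes this visible; it is an illustration, not part of ieagan.py.

```python
import torch
import torch.nn as nn


class BufferDemo(nn.Module):
    """Hypothetical module that registers buffers the same way SN.__init__ does."""

    def __init__(self, num_svs, num_outputs):
        super().__init__()
        self.num_svs = num_svs
        for i in range(num_svs):
            # One random vector and one singular-value slot per tracked value
            self.register_buffer(f"u{i:d}", torch.randn(1, num_outputs))
            self.register_buffer(f"sv{i:d}", torch.ones(1))


demo = BufferDemo(num_svs=2, num_outputs=3)
print(sorted(demo.state_dict().keys()))  # ['sv0', 'sv1', 'u0', 'u1']
print(len(list(demo.parameters())))      # 0: buffers are not trainable parameters
```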

Member Function Documentation

◆ sv()

sv(self)
Singular values, returned as the list of sv<i> buffers (a property: W_() reads it as self.sv).
Note that these buffers are only for logging and are not used in training.

Definition at line 271 of file ieagan.py.

@property
def sv(self):
    """
    Singular values
    note that these buffers are just for logging and are not used in training.
    """
    return [getattr(self, f"sv{i:d}") for i in range(self.num_svs)]
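Because `getattr` returns the registered buffer tensor itself rather than a copy, the slice assignment `self.sv[i][:] = sv` in `W_()` writes into the buffer in place. A plain assignment `self.sv[i] = sv` would only replace an element of a temporary list. The hypothetical `SVDemo` below reproduces this property pattern in isolation.

```python
import torch
import torch.nn as nn


class SVDemo(nn.Module):
    """Hypothetical module mirroring the sv property of SN."""

    def __init__(self, num_svs):
        super().__init__()
        self.num_svs = num_svs
        for i in range(num_svs):
            self.register_buffer(f"sv{i:d}", torch.ones(1))

    @property
    def sv(self):
        # getattr returns the registered buffer object itself, not a copy
        return [getattr(self, f"sv{i:d}") for i in range(self.num_svs)]


m = SVDemo(num_svs=2)
with torch.no_grad():
    m.sv[0][:] = 2.5  # slice assignment writes into the buffer in place
print(m.sv0.item())   # 2.5
print(m.sv1.item())   # 1.0
```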

◆ u()

u(self)
Singular vectors (u side), returned as the list of u<i> buffers (a property: W_() reads it as self.u).

Definition at line 264 of file ieagan.py.

@property
def u(self):
    """
    Singular vectors (u side)
    """
    return [getattr(self, f"u{i:d}") for i in range(self.num_svs)]
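Each `u<i>` buffer has shape `(1, num_outputs)`, so it lives on the output side (the rows) of the reshaped weight, which is why it is called the u side. The body of `power_iteration` is not shown on this page. Assuming it alternates multiplications by `W_mat` and its transpose in the usual way, repeated steps drive `u` toward the leading left singular vector. The sketch below builds a 3×5 matrix with known singular values 3, 1 and 0.5 and checks this. `Q1`, `Q2` and the loop are illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Build W_mat = Q1 @ diag(3, 1, 0.5) @ Q2^T so that its singular values are known
Q1, _ = torch.linalg.qr(torch.randn(3, 3))  # 3x3 orthogonal
Q2, _ = torch.linalg.qr(torch.randn(5, 3))  # 5x3 with orthonormal columns
W_mat = Q1 @ torch.diag(torch.tensor([3.0, 1.0, 0.5])) @ Q2.t()  # shape (3, 5)

u = torch.randn(1, 3)  # stands in for a u buffer of shape (1, num_outputs)
for _ in range(20):
    v = F.normalize(u @ W_mat, dim=1)      # (1, 5): input side
    u = F.normalize(v @ W_mat.t(), dim=1)  # (1, 3): output ("u") side

sigma = (v @ W_mat.t() @ u.t()).squeeze()      # estimate of the largest singular value
alignment = torch.abs(u.squeeze() @ Q1[:, 0])  # |cos| to the true left singular vector
print(sigma.item(), alignment.item())
```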

◆ W_()

W_(self)
Compute the spectrally normalized weight: self.weight divided by svs[0], the leading singular value estimate returned by power_iteration after num_itrs iterations. In training mode the estimates are also copied into the sv buffers for logging. num_itrs must be at least 1; otherwise svs is never assigned.

Definition at line 278 of file ieagan.py.

def W_(self):
    """
    Compute the spectrally-normalized weight
    """
    W_mat = self.weight.view(self.weight.size(0), -1)
    if self.transpose:
        W_mat = W_mat.t()
    # Apply num_itrs power iterations
    for _ in range(self.num_itrs):
        svs, _, _ = power_iteration(
            W_mat, self.u, update=self.training, eps=self.eps
        )
    # Update the svs
    if self.training:
        # Make sure to do this in a no_grad() context or you'll get memory leaks! # noqa
        with torch.no_grad():
            for i, sv in enumerate(svs):
                self.sv[i][:] = sv
    return self.weight / svs[0]
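The steps of `W_()` are: flatten the weight to a matrix, run power iterations, and divide the original weight by the leading singular value. They can be sketched as a standalone function. `normalize_weight_sketch` is hypothetical and differs from `W_()` in three ways. It tracks one singular value. It inlines a simplified power iteration instead of calling `power_iteration`. It returns the updated vector instead of writing to a buffer. The example uses a 4-D convolution weight to show what `self.weight.view(self.weight.size(0), -1)` does. A `(out, in, kh, kw)` tensor becomes an `(out, in*kh*kw)` matrix, and dividing the whole tensor by sigma scales that matrix to a spectral norm of about 1.

```python
import torch
import torch.nn.functional as F


def normalize_weight_sketch(weight, u, num_itrs=1, eps=1e-12):
    """Hypothetical one-singular-value version of W_(); returns (weight / sigma, updated u)."""
    W_mat = weight.view(weight.size(0), -1)  # (out, in, kh, kw) -> (out, in*kh*kw)
    with torch.no_grad():  # the power-iteration vectors carry no gradient
        for _ in range(num_itrs):  # assumes num_itrs >= 1, as W_() does
            v = F.normalize(u @ W_mat, dim=1, eps=eps)
            u = F.normalize(v @ W_mat.t(), dim=1, eps=eps)
    sigma = (v @ W_mat.t() @ u.t()).squeeze()  # gradient flows through W_mat only
    return weight / sigma, u


torch.manual_seed(0)
weight = torch.randn(4, 2, 3, 3, requires_grad=True)  # conv weight: 4 outputs, 2 inputs, 3x3
u0 = torch.randn(1, 4)
W_sn, u1 = normalize_weight_sketch(weight, u0, num_itrs=200)
y = F.conv2d(torch.randn(1, 2, 8, 8), W_sn)  # the normalized tensor is a drop-in conv weight
```

Treating `u` and `v` as constants while letting gradients flow through `W_mat` in `sigma` follows the usual power-iteration recipe for spectral normalization.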

Member Data Documentation

◆ eps

eps = eps

Small constant (epsilon) that guards against division by zero; passed to power_iteration.

Definition at line 255 of file ieagan.py.
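This page does not show how `power_iteration` uses `eps`. Assume it plays the same role as the `eps` argument of `torch.nn.functional.normalize`, which divides by `max(||x||, eps)`. Under that assumption, the difference it makes on an all-zero vector looks like this:

```python
import torch
import torch.nn.functional as F

zero = torch.zeros(1, 4)
naive = zero / zero.norm(dim=1, keepdim=True)  # 0 / 0 gives nan in every entry
safe = F.normalize(zero, dim=1, eps=1e-12)     # divides by max(0, 1e-12) instead
print(torch.isnan(naive).all().item())       # True
print(torch.equal(safe, torch.zeros(1, 4)))  # True
```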

◆ num_itrs

num_itrs = num_itrs

Number of power iterations per step.

Definition at line 249 of file ieagan.py.
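Assume that `update=self.training` makes `power_iteration` write the new vectors back into the `u` buffers, as the argument name suggests. The iteration is then warm-started on every forward pass, and its progress accumulates across training steps even when `num_itrs` is small. The hypothetical helper `power_step_sketch` below carries the vector in a local variable instead of a buffer. It runs one iteration per simulated step on a matrix whose largest singular value is 2.

```python
import torch
import torch.nn.functional as F


def power_step_sketch(W_mat, u, num_itrs=1, eps=1e-12):
    """Hypothetical helper: run num_itrs power iterations and return (sigma, updated u)."""
    for _ in range(num_itrs):
        v = F.normalize(u @ W_mat, dim=1, eps=eps)
        u = F.normalize(v @ W_mat.t(), dim=1, eps=eps)
    sigma = (v @ W_mat.t() @ u.t()).squeeze()
    return sigma, u


torch.manual_seed(0)
W_mat = torch.diag(torch.tensor([2.0, 1.5, 0.3]))  # largest singular value is 2
u = torch.randn(1, 3)
history = []
for step in range(40):                      # 40 simulated training steps
    sigma, u = power_step_sketch(W_mat, u)  # one iteration per step; u carries over
    history.append(sigma.item())
print(history[0], history[-1])
```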

◆ num_svs

num_svs = num_svs

Number of singular values.

Definition at line 251 of file ieagan.py.
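With `num_svs > 1`, `W_()` still divides by `svs[0]` only, and the remaining values are recorded in the `sv` buffers for logging. One way to estimate several singular values at once is power iteration with Gram-Schmidt orthogonalization against the vectors already found. The sketch below illustrates that technique. `gram_schmidt_sketch` and `top_singular_values_sketch` are hypothetical and may differ from what `power_iteration` in ieagan.py does.

```python
import torch
import torch.nn.functional as F


def gram_schmidt_sketch(x, ys):
    """Remove from row vector x its components along the unit row vectors in ys."""
    for y in ys:
        x = x - (x @ y.t()) * y
    return x


def top_singular_values_sketch(W_mat, us, num_itrs=50, eps=1e-12):
    """Hypothetical: estimate the len(us) leading singular values of W_mat."""
    for _ in range(num_itrs):
        new_us, vs, svs = [], [], []
        for u in us:
            # Orthogonalize against the vectors already found in this sweep
            v = F.normalize(gram_schmidt_sketch(u @ W_mat, vs), dim=1, eps=eps)
            vs.append(v)
            u = F.normalize(gram_schmidt_sketch(v @ W_mat.t(), new_us), dim=1, eps=eps)
            new_us.append(u)
            svs.append((v @ W_mat.t() @ u.t()).squeeze())
        us = new_us
    return svs, us


torch.manual_seed(0)
Q1, _ = torch.linalg.qr(torch.randn(3, 3))
Q2, _ = torch.linalg.qr(torch.randn(6, 3))
W_mat = Q1 @ torch.diag(torch.tensor([3.0, 1.0, 0.5])) @ Q2.t()  # singular values 3, 1, 0.5
svs, us = top_singular_values_sketch(W_mat, [torch.randn(1, 3), torch.randn(1, 3)])
print([s.item() for s in svs])
```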

◆ training

bool training

Training mode flag (inherited from nn.Module).

True if the module is in training mode.

Definition at line 261 of file ieagan.py.
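`training` is the standard `nn.Module` flag. `train()` and `eval()` toggle it, and the change propagates to submodules. `W_()` passes `update=self.training` and refreshes the `sv` buffers only in training mode, so switching a model to evaluation mode freezes the stored estimates. A minimal check with a stock layer:

```python
import torch.nn as nn

layer = nn.Linear(3, 2)
model = nn.Sequential(layer)
print(layer.training)  # True: modules are created in training mode
model.eval()           # recursively sets training=False on every submodule
print(layer.training)  # False
model.train()
print(layer.training)  # True
```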

◆ transpose

transpose = transpose

Whether to transpose the reshaped weight matrix before the power iteration.

Definition at line 253 of file ieagan.py.
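Transposing `W_mat` does not change its singular values, so the normalization constant is the same either way. What changes is the length of the vector that the power iteration multiplies from the left. Assume `power_iteration` computes `u @ W_mat` first; its body is not shown here. The `u` buffers are created with length `num_outputs`, so with `transpose=True` that length must match the number of columns of the untransposed matrix rather than its rows. The shapes involved, with illustrative names:

```python
import torch

torch.manual_seed(0)
W_mat = torch.randn(4, 6)  # e.g. weight.view(weight.size(0), -1) with 4 outputs

# Same singular values with or without the transpose
same = torch.allclose(torch.linalg.svdvals(W_mat), torch.linalg.svdvals(W_mat.t()), atol=1e-5)

u_plain = torch.randn(1, W_mat.size(0))  # length 4 fits W_mat
u_trans = torch.randn(1, W_mat.size(1))  # length 6 fits W_mat.t()
print(same, (u_plain @ W_mat).shape, (u_trans @ W_mat.t()).shape)
```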


The documentation for this class was generated from the following file: