Belle II Software development
SN Class Reference
Inheritance diagram for SN:
Derived classes: SNConv2d, SNEmbedding, SNLinear

Public Member Functions

def __init__ (self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12)
 
def u (self)
 
def sv (self)
 
def W_ (self)
 

Public Attributes

 num_itrs
 Number of power iterations per step.
 
 num_svs
 Number of singular values.
 
 transpose
 Whether the reshaped weight matrix is transposed before power iteration.
 
 eps
 Small epsilon used to avoid division by zero.
 

Detailed Description

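Spectral normalization base class. SN is a mixin rather than a standalone module: it calls self.register_buffer and reads self.weight and self.training, so it is meant to be combined with a torch.nn.Module subclass that provides them, as in the derived classes SNConv2d, SNLinear, and SNEmbedding.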

Definition at line 226 of file ieagan.py.
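For reference, the standard definition of spectral normalization divides a weight matrix by its largest singular value, which power iteration estimates cheaply. Below, W is the reshaped m-by-n matrix W_mat and u is a 1-by-m row vector, matching the (1, num_outputs) shape of the u{i} buffers. The iteration shown is the standard scheme; since the body of power_iteration is not shown on this page, treating it as this exact update is an assumption.

```latex
\sigma(W) = \max_{h \neq 0} \frac{\lVert W h \rVert_2}{\lVert h \rVert_2},
\qquad
W_{\mathrm{SN}} = \frac{W}{\sigma(W)}

v \leftarrow \frac{u W}{\lVert u W \rVert_2},
\qquad
u \leftarrow \frac{v W^{\top}}{\lVert v W^{\top} \rVert_2},
\qquad
\sigma(W) \approx v W^{\top} u^{\top}
```

At convergence u and v approach the leading left and right singular vectors, so v W^T u^T approaches the largest singular value; this is the quantity W_() divides by (svs[0]).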

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  num_svs,
  num_itrs,
  num_outputs,
  transpose = False,
  eps = 1e-12 
)
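Constructor. Stores the power-iteration settings and, for each of the num_svs singular values, registers a buffer u{i} initialised to a random row vector of shape (1, num_outputs) and a buffer sv{i} initialised to one.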

Reimplemented in SNConv2d, SNLinear, and SNEmbedding.

Definition at line 232 of file ieagan.py.

def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
    """constructor"""
    # Number of power iterations per step
    self.num_itrs = num_itrs
    # Number of singular values
    self.num_svs = num_svs
    # Whether to transpose the reshaped weight matrix (see W_)
    self.transpose = transpose
    # Epsilon value for avoiding divide-by-0
    self.eps = eps
    # Register a singular vector for each sv
    for i in range(self.num_svs):
        self.register_buffer(f"u{i:d}", torch.randn(1, num_outputs))
        self.register_buffer(f"sv{i:d}", torch.ones(1))
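Because the constructor calls self.register_buffer, SN only works when mixed into a class that already provides that method (in practice a torch.nn.Module subclass whose own __init__ has run first). The sketch below illustrates that pattern with hypothetical pure-Python stand-ins (FakeModule, SNMixin, FakeSNLinear) instead of PyTorch, using lists in place of tensors. The real SNConv2d, SNLinear, and SNEmbedding constructors are not shown on this page, so the initialisation order here is an assumption.

```python
# Hypothetical stand-ins: FakeModule mimics only the part of torch.nn.Module
# that SN.__init__ relies on (register_buffer); it is NOT the PyTorch class.
import random


class FakeModule:
    def __init__(self):
        self._buffers = {}
        self.training = True

    def register_buffer(self, name, value):
        # Record the buffer and expose it as an attribute, as nn.Module does
        self._buffers[name] = value
        setattr(self, name, value)


class SNMixin:
    """Simplified copy of SN.__init__ using plain lists instead of tensors."""

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        self.num_itrs = num_itrs
        self.num_svs = num_svs
        self.transpose = transpose
        self.eps = eps
        for i in range(self.num_svs):
            # Random vector with num_outputs entries, plus a logged sv of 1
            u = [random.gauss(0.0, 1.0) for _ in range(num_outputs)]
            self.register_buffer(f"u{i:d}", u)
            self.register_buffer(f"sv{i:d}", [1.0])


class FakeSNLinear(FakeModule, SNMixin):
    def __init__(self, in_features, out_features, num_svs=1, num_itrs=1):
        # The host module is initialised first so register_buffer is usable
        FakeModule.__init__(self)
        # Weight stored as out_features rows of in_features entries
        self.weight = [[0.0] * in_features for _ in range(out_features)]
        SNMixin.__init__(self, num_svs, num_itrs, out_features)


layer = FakeSNLinear(in_features=4, out_features=3, num_svs=2)
print(sorted(layer._buffers))  # ['sv0', 'sv1', 'u0', 'u1']
print(len(layer.u0))           # 3
```

Calling SNMixin.__init__ before FakeModule.__init__ would fail here because _buffers would not exist yet; real nn.Module similarly refuses to register buffers before its own __init__ has been called.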

Member Function Documentation

◆ sv()

def sv (   self)
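Singular values, returned as a list of the sv{i} buffers. It is accessed as an attribute (W_() indexes self.sv[i]), i.e. as a property. These buffers are only used for logging and play no part in training.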

Definition at line 256 of file ieagan.py.

@property
def sv(self):
    """
    Singular values
    note that these buffers are just for logging and are not used in training.
    """
    return [getattr(self, f"sv{i:d}") for i in range(self.num_svs)]
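Because sv (like u) builds a new Python list on every access, writing self.sv[i] = value would only rebind an entry of that throwaway list. W_() instead uses self.sv[i][:] = sv, which writes into the existing buffer object. A minimal pure-Python analogue, with one-element lists standing in for the tensor buffers (an assumption made so the sketch runs without PyTorch; LoggedSVs is a hypothetical name):

```python
# Pure-Python analogue of the sv property and the in-place update in W_().


class LoggedSVs:
    def __init__(self, num_svs):
        self.num_svs = num_svs
        for i in range(num_svs):
            setattr(self, f"sv{i:d}", [1.0])

    @property
    def sv(self):
        # A fresh list on each access, but its items ARE the stored buffers
        return [getattr(self, f"sv{i:d}") for i in range(self.num_svs)]


m = LoggedSVs(num_svs=2)
new_values = [3.0, 0.5]

# Rebinding an entry only changes the temporary list; the buffer is untouched
m.sv[0] = [new_values[0]]
print(m.sv0)  # [1.0]

# Slice assignment writes into the buffer itself, as self.sv[i][:] = sv does
for i, value in enumerate(new_values):
    m.sv[i][:] = [value]
print(m.sv0, m.sv1)  # [3.0] [0.5]
```

For a torch tensor, t[:] = value likewise copies into the existing storage, which is why the update in W_() reaches the registered buffers.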

◆ u()

def u (   self)
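Singular vectors (u side): a list of the u{i} buffers, one per singular value, accessed as an attribute (self.u), i.e. as a property. W_() passes this list to power_iteration with update=self.training, so (judging by that flag) the stored vectors are only updated in training mode.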

Definition at line 249 of file ieagan.py.

@property
def u(self):
    """
    Singular vectors (u side)
    """
    return [getattr(self, f"u{i:d}") for i in range(self.num_svs)]
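When num_svs > 1, u holds one vector per singular value. A common way to estimate several singular values at once is power iteration with Gram-Schmidt deflation: each new vector is orthogonalised against the ones already found, so later vectors converge to later singular directions. The pure-Python sketch below illustrates that idea. All function names are hypothetical, and because power_iteration is not documented on this page, its exact update rule is an assumption.

```python
# Sketch of power iteration over several u vectors with Gram-Schmidt
# deflation. ASSUMPTION: a standard scheme, not the ieagan.py implementation.
import math


def row_times(x, W):
    """Row vector x (len m) times matrix W (m x n) -> row vector of len n."""
    return [sum(x[r] * W[r][c] for r in range(len(W))) for c in range(len(W[0]))]


def row_times_t(x, W):
    """Row vector x (len n) times W transposed -> row vector of len m."""
    return [sum(x[c] * W[r][c] for c in range(len(W[0]))) for r in range(len(W))]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def normalize(x, eps):
    norm = math.sqrt(dot(x, x))
    return [xi / max(norm, eps) for xi in x]


def gram_schmidt(x, basis):
    # Remove the components of x along previously found vectors
    for y in basis:
        coeff = dot(x, y) / dot(y, y)
        x = [xi - coeff * yi for xi, yi in zip(x, y)]
    return x


def power_iteration_sketch(W, us, eps=1e-12):
    found_u, found_v, svs = [], [], []
    for i, u in enumerate(us):
        v = normalize(gram_schmidt(row_times(u, W), found_v), eps)
        u = normalize(gram_schmidt(row_times_t(v, W), found_u), eps)
        found_v.append(v)
        found_u.append(u)
        us[i][:] = u  # store the refined vector in place
        svs.append(dot(row_times_t(v, W), u))  # sigma_i ~ v W^T u^T
    return svs


W = [[3.0, 0.0], [0.0, 1.0]]
us = [[1.0, 0.2], [0.3, 1.0]]
for _ in range(20):
    svs = power_iteration_sketch(W, us)
print([round(s, 4) for s in svs])  # [3.0, 1.0]
```

Each outer call performs one iteration per vector and keeps the refined vectors in us, so estimates improve across calls, much as the u{i} buffers persist between training steps.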

◆ W_()

def W_ (   self)
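Compute the spectrally-normalized weight. The weight is reshaped to a 2-D matrix of shape (weight.size(0), -1), optionally transposed, and num_itrs power iterations estimate its singular values; the weight is then divided by the leading estimate svs[0]. In training mode the logged sv{i} buffers are also updated, inside a torch.no_grad() block.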

Definition at line 263 of file ieagan.py.

def W_(self):
    """
    Compute the spectrally-normalized weight
    """
    # Keep dimension 0 and flatten the rest into a 2-D matrix
    W_mat = self.weight.view(self.weight.size(0), -1)
    if self.transpose:
        W_mat = W_mat.t()
    # Apply num_itrs power iterations
    # (power_iteration is a helper defined outside this class, not shown here)
    for _ in range(self.num_itrs):
        svs, _, _ = power_iteration(
            W_mat, self.u, update=self.training, eps=self.eps
        )
    # Update the svs
    if self.training:
        # Make sure to do this in a no_grad() context or you'll get memory leaks! # noqa
        with torch.no_grad():
            for i, sv in enumerate(svs):
                self.sv[i][:] = sv
    # Divide by the leading singular value estimate
    return self.weight / svs[0]
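The steps of W_() can be traced with a pure-Python sketch for a single singular value. Assumptions: the weight is already a 2-D list of rows (the view(size(0), -1) step is skipped), power iteration follows the standard v = uW, u = vW^T scheme, and spectral_normalize is a hypothetical name, not part of ieagan.py.

```python
# Pure-Python sketch of the W_() logic for one singular value (assumptions
# as stated above; lists stand in for tensors).
import math


def normalize(x, eps):
    norm = math.sqrt(sum(xi * xi for xi in x))
    return [xi / max(norm, eps) for xi in x]


def spectral_normalize(weight, u, num_itrs=1, transpose_w=False, eps=1e-12, training=True):
    # Optional transpose, mirroring self.transpose
    W_mat = [list(col) for col in zip(*weight)] if transpose_w else weight
    rows, cols = len(W_mat), len(W_mat[0])
    assert len(u) == rows, "u needs one entry per row of W_mat"
    assert num_itrs >= 1, "at least one iteration is needed for an estimate"
    for _ in range(num_itrs):
        # v = u W, then u = v W^T, each normalised
        v = normalize([sum(u[r] * W_mat[r][c] for r in range(rows)) for c in range(cols)], eps)
        new_u = normalize([sum(v[c] * W_mat[r][c] for c in range(cols)) for r in range(rows)], eps)
        if training:
            u[:] = new_u  # persist the estimate, like update=self.training
    # sigma ~ v W^T u^T, using the freshest u
    vWt = [sum(v[c] * W_mat[r][c] for c in range(cols)) for r in range(rows)]
    sigma = sum(a * b for a, b in zip(vWt, new_u))
    # Divide the original (untransposed) weight by the leading estimate
    return [[w / sigma for w in row] for row in weight], sigma


weight = [[3.0, 0.0], [0.0, 1.0]]
u = [1.0, 0.5]
w_sn, sigma = spectral_normalize(weight, u, num_itrs=30)
print(round(sigma, 6))  # 3.0
```

Because u is updated in place only when training=True, repeated calls with num_itrs=1 keep refining the same estimate, while evaluation-mode calls leave it unchanged.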

Member Data Documentation

◆ eps

eps

Small epsilon used to avoid division by zero; it is passed to power_iteration in W_().

Definition at line 242 of file ieagan.py.
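eps is forwarded to power_iteration. Assuming it is used the way torch.nn.functional.normalize uses its own eps argument (dividing by max(norm, eps)), it bounds the denominator when a vector is normalised. A minimal sketch of that guard:

```python
import math


def normalize(x, eps=1e-12):
    # Divide by max(||x||, eps), the rule torch.nn.functional.normalize uses
    norm = math.sqrt(sum(xi * xi for xi in x))
    return [xi / max(norm, eps) for xi in x]


print(normalize([3.0, 4.0]))  # [0.6, 0.8]
print(normalize([0.0, 0.0]))  # [0.0, 0.0] instead of a ZeroDivisionError
```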

◆ num_itrs

num_itrs

Number of power iterations per step.

Definition at line 236 of file ieagan.py.

◆ num_svs

num_svs

Number of singular values.

Definition at line 238 of file ieagan.py.

◆ transpose

transpose

Whether the reshaped weight matrix is transposed before power iteration.

Definition at line 240 of file ieagan.py.
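The effect of transpose on the matrix that power iteration sees follows from the view(size(0), -1) and .t() calls in W_(). The helper below (w_mat_shape is a hypothetical name) computes that shape; the example uses PyTorch's Conv2d weight layout (out_channels, in_channels, kH, kW). Since each u{i} buffer has num_outputs entries, num_outputs presumably has to match the row count of this matrix; that constraint is an assumption, because power_iteration is not shown on this page.

```python
from math import prod


def w_mat_shape(weight_shape, transpose=False):
    # view(size(0), -1): keep dimension 0, flatten the remaining ones
    rows, cols = weight_shape[0], prod(weight_shape[1:])
    # .t() swaps the two dimensions
    return (cols, rows) if transpose else (rows, cols)


print(w_mat_shape((64, 3, 3, 3)))                  # (64, 27)
print(w_mat_shape((64, 3, 3, 3), transpose=True))  # (27, 64)
print(w_mat_shape((10, 128)))                      # (10, 128)
```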


The documentation for this class was generated from the following file: