Belle II Software development
Generator Class Reference
Inheritance diagram for Generator: Model derives from Generator.

Public Member Functions

def __init__ (self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=256, G_kernel_size=3, G_attn="64", n_classes=40, H_base=1, num_G_SVs=1, num_G_SV_itrs=1, attn_type="sa", G_shared=True, shared_dim=128, hier=True, cross_replica=False, mybn=False, G_activation=nn.ReLU(inplace=False), G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8, BN_eps=1e-5, SN_eps=1e-12, G_init="ortho", G_mixed_precision=False, G_fp16=False, skip_init=False, no_optim=False, sched_version="default", RRM_prx_G=True, n_head_G=2, G_param="SN", norm_style="bn", **kwargs)
 Constructor.
 
def init_weights (self)
 Initialize weights and count initialized parameters.
 
def forward (self, z, y)
 Forward pass.
 

Public Attributes

 ch
 Channel width multiplier.
 
 G_depth
 Number of resblocks per stage.
 
 dim_z
 Dimensionality of the latent space.
 
 bottom_width
 The initial spatial dimensions.
 
 H_base
 The initial horizontal dimension.
 
 resolution
 Resolution of the output.
 
 kernel_size
 Convolution kernel size.
 
 attention
 Resolutions at which to apply self-attention.
 
 n_classes
 Number of classes, for use in categorical conditional generation.
 
 G_shared
 Use shared embeddings?
 
 shared_dim
 Dimensionality of the shared embedding; unused if G_shared is False.
 
 hier
 Hierarchical latent space?
 
 cross_replica
 Cross replica batchnorm?
 
 mybn
 Use the custom (non-PyTorch) batchnorm implementation?
 
 activation
 Nonlinearity used in the residual blocks.
 
 init
 Initialization style.
 
 G_param
 Parameterization style.
 
 norm_style
 Normalization style.
 
 BN_eps
 Epsilon for BatchNorm.
 
 SN_eps
 Epsilon for spectral norm.
 
 fp16
 Use fp16 parameters?
 
 arch
 Architecture dict.
 
 RRM_prx_G
 Apply a relational reasoning module (RRM) to the class proxy embeddings?
 
 n_head_G
 Number of attention heads in the RRM.
 
 which_conv
 Convolution layer constructor (spectrally normalized if G_param is "SN").
 
 which_linear
 Linear layer constructor.
 
 which_embedding
 Embedding layer constructor; spectral norm is never applied here.
 
 which_bn
 Class-conditional batchnorm constructor.
 
 shared
 Shared class embedding (identity if G_shared is False).
 
 RR_G
 RRM on proxy embeddings.
 
 linear
 First linear layer.
 
 blocks
 Generator blocks: resblocks plus optional attention at each resolution stage.
 
 output_layer
 Output layer: batchnorm, activation, and a final convolution.
 
 lr
 Generator learning rate.
 
 B1
 Adam beta1.
 
 B2
 Adam beta2.
 
 adam_eps
 Adam epsilon.
 
 optim
 Optimizer for the generator parameters.
 
 lr_sched
 Learning-rate scheduler (None for a constant rate).
 
 param_count
 Count of initialized parameters.
 

Detailed Description

Generator: maps latent vectors z and class labels y to generated images (see forward()).

Definition at line 1127 of file ieagan.py.
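
A minimal usage sketch (illustrative, not taken from the Belle II sources; the import path and argument choices are assumptions). Note that the constructor only accepts the string activation names "inplace_relu", "relu", and "leaky_relu", so G_activation is passed explicitly here:

import torch
from ieagan import Generator  # assumed import path

# Defaults: dim_z=128, n_classes=40, resolution=256, H_base=1
gen = Generator(G_activation="relu")
z = torch.randn(8, 128)         # latent vectors, shape [batch, dim_z]
y = torch.randint(0, 40, (8,))  # class labels for conditional generation
images = gen(z, y)              # tanh output in [-1, 1]; expected shape [8, 1, 256, 256]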

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  G_ch = 64,
  G_depth = 2,
  dim_z = 128,
  bottom_width = 4,
  resolution = 256,
  G_kernel_size = 3,
  G_attn = "64",
  n_classes = 40,
  H_base = 1,
  num_G_SVs = 1,
  num_G_SV_itrs = 1,
  attn_type = "sa",
  G_shared = True,
  shared_dim = 128,
  hier = True,
  cross_replica = False,
  mybn = False,
  G_activation = nn.ReLU(inplace=False),
  G_lr = 5e-5,
  G_B1 = 0.0,
  G_B2 = 0.999,
  adam_eps = 1e-8,
  BN_eps = 1e-5,
  SN_eps = 1e-12,
  G_init = "ortho",
  G_mixed_precision = False,
  G_fp16 = False,
  skip_init = False,
  no_optim = False,
  sched_version = "default",
  RRM_prx_G = True,
  n_head_G = 2,
  G_param = "SN",
  norm_style = "bn",
**  kwargs 
)

Constructor.

Reimplemented in Model.

Definition at line 1131 of file ieagan.py.

1168 ):
1169     super(Generator, self).__init__()
1170 
1171     self.ch = G_ch
1172 
1173     self.G_depth = G_depth
1174 
1175     self.dim_z = dim_z
1176 
1177     self.bottom_width = bottom_width
1178 
1179     self.H_base = H_base
1180 
1181     self.resolution = resolution
1182 
1183     self.kernel_size = G_kernel_size
1184 
1185     self.attention = G_attn
1186 
1187     self.n_classes = n_classes
1188 
1189     self.G_shared = G_shared
1190 
1191     self.shared_dim = shared_dim if shared_dim > 0 else dim_z
1192 
1193     self.hier = hier
1194 
1195     self.cross_replica = cross_replica
1196 
1197     self.mybn = mybn
1198     # nonlinearity for residual blocks
1199     if G_activation == "inplace_relu":
1200 
1201         self.activation = torch.nn.ReLU(inplace=True)
1202     elif G_activation == "relu":
1203         self.activation = torch.nn.ReLU(inplace=False)
1204     elif G_activation == "leaky_relu":
1205         self.activation = torch.nn.LeakyReLU(0.2, inplace=False)
1206     else:
1207         raise NotImplementedError("activation function not implemented")
1208 
1209     self.init = G_init
1210 
1211     self.G_param = G_param
1212 
1213     self.norm_style = norm_style
1214 
1215     self.BN_eps = BN_eps
1216 
1217     self.SN_eps = SN_eps
1218 
1219     self.fp16 = G_fp16
1220 
1221     self.arch = G_arch(self.ch, self.attention)[resolution]
1222 
1223     self.RRM_prx_G = RRM_prx_G
1224 
1225     self.n_head_G = n_head_G
1226 
1227     # Which convs, batchnorms, and linear layers to use
1228     if self.G_param == "SN":
1229 
1230         self.which_conv = functools.partial(
1231             SNConv2d,
1232             kernel_size=3,
1233             padding=1,
1234             num_svs=num_G_SVs,
1235             num_itrs=num_G_SV_itrs,
1236             eps=self.SN_eps,
1237         )
1238 
1239         self.which_linear = functools.partial(
1240             SNLinear,
1241             num_svs=num_G_SVs,
1242             num_itrs=num_G_SV_itrs,
1243             eps=self.SN_eps,
1244         )
1245     else:
1246         self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
1247         self.which_linear = nn.Linear
1248 
1249     # We use a non-spectral-normed embedding here regardless;
1250     # For some reason applying SN to G's embedding seems to randomly cripple G # noqa
1251 
1252     self.which_embedding = nn.Embedding
1253     bn_linear = (
1254         functools.partial(self.which_linear, bias=False)
1255         if self.G_shared
1256         else self.which_embedding
1257     )
1258 
1259     self.which_bn = functools.partial(
1260         ccbn,
1261         which_linear=bn_linear,
1262         cross_replica=self.cross_replica,
1263         mybn=self.mybn,
1264         input_size=(
1265             self.shared_dim + self.dim_z if self.G_shared else self.n_classes
1266         ),
1267         norm_style=self.norm_style,
1268         eps=self.BN_eps,
1269     )
1270 
1271     self.shared = (
1272         self.which_embedding(n_classes, self.shared_dim)
1273         if G_shared
1274         else identity()
1275     )
1276 
1277     if self.RRM_prx_G:
1278 
1279         self.RR_G = RelationalReasoning(
1280             num_layers=1,
1281             input_dim=128,
1282             dim_feedforward=128,
1283             which_linear=nn.Linear,
1284             num_heads=self.n_head_G,
1285             dropout=0.0,
1286             hidden_dim=128,
1287         )
1288 
1289 
1290     self.linear = self.which_linear(
1291         self.dim_z + self.shared_dim,
1292         self.arch["in_channels"][0] * ((self.bottom_width**2) * self.H_base),
1293     )
1294 
1295     # self.blocks is a doubly-nested list of modules, the outer loop intended # noqa
1296     # to be over blocks at a given resolution (resblocks and/or self-attention) # noqa
1297     # while the inner loop is over a given block
1298 
1299     self.blocks = []
1300     for index in range(len(self.arch["out_channels"])):
1301         self.blocks += [
1302             [
1303                 GBlock(
1304                     in_channels=self.arch["in_channels"][index],
1305                     out_channels=self.arch["in_channels"][index]
1306                     if g_index == 0
1307                     else self.arch["out_channels"][index],
1308                     which_conv=self.which_conv,
1309                     which_bn=self.which_bn,
1310                     activation=self.activation,
1311                     upsample=(
1312                         functools.partial(F.interpolate, scale_factor=2)
1313                         if self.arch["upsample"][index]
1314                         and g_index == (self.G_depth - 1)
1315                         else None
1316                     ),
1317                 )
1318             ]
1319             for g_index in range(self.G_depth)
1320         ]
1321 
1322         # If attention on this block, attach it to the end
1323         if self.arch["attention"][self.arch["resolution"][index]]:
1324             print(
1325                 f"Adding attention layer in G at resolution {self.arch['resolution'][index]:d}"
1326             )
1327 
1328             if attn_type == "sa":
1329                 self.blocks[-1] += [
1330                     Attention(self.arch["out_channels"][index], self.which_conv)
1331                 ]
1332             elif attn_type == "cbam":
1333                 self.blocks[-1] += [
1334                     CBAM_attention(
1335                         self.arch["out_channels"][index], self.which_conv
1336                     )
1337                 ]
1338             elif attn_type == "ila":
1339                 self.blocks[-1] += [ILA(self.arch["out_channels"][index])]
1340 
1341     # Turn self.blocks into a ModuleList so that it's all properly registered. # noqa
1342     self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
1343 
1344     # output layer: batchnorm-relu-conv.
1345     # Consider using a non-spectral conv here
1346 
1347     self.output_layer = nn.Sequential(
1348         bn(
1349             self.arch["out_channels"][-1],
1350             cross_replica=self.cross_replica,
1351             mybn=self.mybn,
1352         ),
1353         self.activation,
1354         self.which_conv(self.arch["out_channels"][-1], 1),
1355     )
1356 
1357     # Initialize weights. Optionally skip init for testing.
1358     if not skip_init:
1359         self.init_weights()
1360 
1361     # Set up optimizer
1362     # If this is an EMA copy, no need for an optim, so just return now
1363     if no_optim:
1364         return
1365 
1366     self.lr = G_lr
1367 
1368     self.B1 = G_B1
1369 
1370     self.B2 = G_B2
1371 
1372     self.adam_eps = adam_eps
1373     if G_mixed_precision:
1374         print("Using fp16 adam in G...")
1375         import utils
1376 
1377         self.optim = utils.Adam16(
1378             params=self.parameters(),
1379             lr=self.lr,
1380             betas=(self.B1, self.B2),
1381             weight_decay=0,
1382             eps=self.adam_eps,
1383         )
1384 
1385     else:
1386         self.optim = optim.Adam(
1387             params=self.parameters(),
1388             lr=self.lr,
1389             betas=(self.B1, self.B2),
1390             weight_decay=0,
1391             eps=self.adam_eps,
1392         )
1393     # LR scheduling
1394     if sched_version == "default":
1395 
1396         self.lr_sched = None
1397     elif sched_version == "CosAnnealLR":
1398         self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
1399             self.optim,
1400             T_max=kwargs["num_epochs"],
1401             eta_min=self.lr / 4,
1402             last_epoch=-1,
1403         )
1404     elif sched_version == "CosAnnealWarmRes":
1405         self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
1406             self.optim, T_0=10, T_mult=2, eta_min=self.lr / 4
1407         )
1408     else:
1409         self.lr_sched = None
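
The "CosAnnealLR" branch reads kwargs["num_epochs"] for T_max, so that value must be forwarded through **kwargs. A hedged sketch of driving the scheduler (training loop elided; argument choices assumed):

# Anneals the learning rate from G_lr down to G_lr / 4 over num_epochs epochs.
gen = Generator(G_activation="relu", sched_version="CosAnnealLR", num_epochs=100)

for epoch in range(100):
    ...  # one epoch of generator updates via gen.optim
    if gen.lr_sched is not None:
        gen.lr_sched.step()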

Member Function Documentation

◆ forward()

def forward (   self,
  z,
  y 
)

Forward pass: map a latent batch z and class labels y to generated images in [-1, 1].

Definition at line 1435 of file ieagan.py.

1435 def forward(self, z, y):
1436     y = self.shared(y)
1437     # If relational embedding
1438     if self.RRM_prx_G:
1439         y = self.RR_G(y.unsqueeze(0)).squeeze(0)
1440         # y = F.normalize(y, dim=1)
1441     # If hierarchical, concatenate zs and ys
1442     if self.hier:  # y and z are [bs,128] dimensional
1443         z = torch.cat([y, z], 1)
1444         y = z
1445     # First linear layer
1446     h = self.linear(z)  # ([bs,256]-->[bs,24576])
1447     # Reshape
1448     h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width * self.H_base)
1449     # Loop over blocks
1450     for _, blocklist in enumerate(self.blocks):
1451         # Second inner loop in case block has multiple layers
1452         for block in blocklist:
1453             h = block(h, y)
1454 
1455     # Apply batchnorm-relu-conv-tanh at output
1456     return torch.tanh(self.output_layer(h))
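
Shape walkthrough for the defaults (hier=True, G_shared=True, bottom_width=4, H_base=1), following the comments in the listing above; a hedged sketch continuing the instantiation example from the detailed description:

z = torch.randn(8, 128)         # [bs, dim_z]
y = torch.randint(0, 40, (8,))  # [bs] integer class labels
# Inside forward: shared(y) -> [bs, 128]; hierarchical concatenation -> z = [bs, 256];
# linear -> [bs, 24576]; view -> [bs, 1536, 4, 4]; blocks upsample to full resolution.
images = gen(z, y)              # [bs, 1, 256, 256] after tanh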

◆ init_weights()

def init_weights (   self)

Initialize weights and count initialized parameters.

Definition at line 1412 of file ieagan.py.

1412 def init_weights(self):
1413 
1414     self.param_count = 0
1415     for module in self.modules():
1416         if (
1417             isinstance(module, nn.Conv2d)
1418             or isinstance(module, nn.Linear)
1419             or isinstance(module, nn.Embedding)
1420         ):
1421             if self.init == "ortho":
1422                 init.orthogonal_(module.weight)
1423             elif self.init == "N02":
1424                 init.normal_(module.weight, 0, 0.02)
1425             elif self.init in ["glorot", "xavier"]:
1426                 init.xavier_uniform_(module.weight)
1427             else:
1428                 print("Init style not recognized...")
1429             self.param_count += sum(
1430                 [p.data.nelement() for p in module.parameters()]
1431             )
1432     print(f"Param count for G's initialized parameters: {self.param_count}")
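
The style is chosen once at construction via G_init ("ortho" by default, "N02", or "glorot"/"xavier"). A hedged sketch of deferring and then triggering initialization manually:

# skip_init=True skips the call in the constructor; init_weights() can then
# be invoked explicitly and prints the count of initialized parameters.
gen = Generator(G_activation="relu", G_init="N02", skip_init=True)
gen.init_weights()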

Member Data Documentation

◆ activation

activation

Nonlinearity used in the residual blocks.

Definition at line 1201 of file ieagan.py.

◆ adam_eps

adam_eps

Adam epsilon.

Definition at line 1372 of file ieagan.py.

◆ arch

arch

Architecture dict.

Definition at line 1221 of file ieagan.py.

◆ attention

attention

Resolutions at which to apply self-attention.

Definition at line 1185 of file ieagan.py.

◆ B1

B1

Adam beta1.

Definition at line 1368 of file ieagan.py.

◆ B2

B2

Adam beta2.

Definition at line 1370 of file ieagan.py.

◆ blocks

blocks

Generator blocks: resblocks plus optional attention at each resolution stage.

Definition at line 1299 of file ieagan.py.

◆ BN_eps

BN_eps

Epsilon for BatchNorm.

Definition at line 1215 of file ieagan.py.

◆ bottom_width

bottom_width

The initial spatial dimensions.

Definition at line 1177 of file ieagan.py.

◆ ch

ch

Channel width multiplier.

Definition at line 1171 of file ieagan.py.

◆ cross_replica

cross_replica

Cross replica batchnorm?

Definition at line 1195 of file ieagan.py.

◆ dim_z

dim_z

Dimensionality of the latent space.

Definition at line 1175 of file ieagan.py.

◆ fp16

fp16

Use fp16 parameters?

Definition at line 1219 of file ieagan.py.

◆ G_depth

G_depth

Number of resblocks per stage.

Definition at line 1173 of file ieagan.py.

◆ G_param

G_param

Parameterization style.

Definition at line 1211 of file ieagan.py.

◆ G_shared

G_shared

Use shared embeddings?

Definition at line 1189 of file ieagan.py.

◆ H_base

H_base

The initial horizontal dimension.

Definition at line 1179 of file ieagan.py.

◆ hier

hier

Hierarchical latent space?

Definition at line 1193 of file ieagan.py.

◆ init

init

Initialization style.

Definition at line 1209 of file ieagan.py.

◆ kernel_size

kernel_size

Convolution kernel size.

Definition at line 1183 of file ieagan.py.

◆ linear

linear

First linear layer.

Definition at line 1290 of file ieagan.py.

◆ lr

lr

Generator learning rate.

Definition at line 1366 of file ieagan.py.

◆ lr_sched

lr_sched

Learning-rate scheduler (None for a constant rate).

Definition at line 1396 of file ieagan.py.

◆ mybn

mybn

Use the custom (non-PyTorch) batchnorm implementation?

Definition at line 1197 of file ieagan.py.

◆ n_classes

n_classes

Number of classes, for use in categorical conditional generation.

Definition at line 1187 of file ieagan.py.

◆ n_head_G

n_head_G

Number of attention heads in the RRM.

Definition at line 1225 of file ieagan.py.

◆ norm_style

norm_style

Normalization style.

Definition at line 1213 of file ieagan.py.

◆ optim

optim

Optimizer for the generator parameters.

Definition at line 1377 of file ieagan.py.

◆ output_layer

output_layer

Output layer: batchnorm, activation, and a final convolution.

Definition at line 1347 of file ieagan.py.

◆ param_count

param_count

Count of initialized parameters.

Definition at line 1414 of file ieagan.py.

◆ resolution

resolution

Resolution of the output.

Definition at line 1181 of file ieagan.py.

◆ RR_G

RR_G

RRM on proxy embeddings.

Definition at line 1279 of file ieagan.py.

◆ RRM_prx_G

RRM_prx_G

Apply a relational reasoning module (RRM) to the class proxy embeddings?

Definition at line 1223 of file ieagan.py.

◆ shared

shared

Shared class embedding (identity if G_shared is False).

Definition at line 1271 of file ieagan.py.

◆ shared_dim

shared_dim

Dimensionality of the shared embedding; unused if G_shared is False.

Definition at line 1191 of file ieagan.py.

◆ SN_eps

SN_eps

Epsilon for spectral norm.

Definition at line 1217 of file ieagan.py.

◆ which_bn

which_bn

Class-conditional batchnorm constructor.

Definition at line 1259 of file ieagan.py.

◆ which_conv

which_conv

Convolution layer constructor (spectrally normalized if G_param is "SN").

Definition at line 1230 of file ieagan.py.
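
The constructor-partial pattern behind which_conv and which_linear: functools.partial pre-binds the kernel and spectral-norm settings so call sites only pass channel counts. A minimal sketch of the non-spectral-norm branch from the constructor (the SN branch uses SNConv2d the same way):

import functools
import torch.nn as nn

which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
conv = which_conv(256, 128)  # equivalent to nn.Conv2d(256, 128, kernel_size=3, padding=1)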

◆ which_embedding

which_embedding

Embedding layer constructor; spectral norm is never applied here.

Definition at line 1252 of file ieagan.py.

◆ which_linear

which_linear

Linear layer constructor.

Definition at line 1239 of file ieagan.py.


The documentation for this class was generated from the following file:

ieagan.py