Belle II Software development
Generator Class Reference
Inheritance diagram for Generator:
Model

Public Member Functions

 __init__ (self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=256, G_kernel_size=3, G_attn="64", n_classes=40, H_base=1, num_G_SVs=1, num_G_SV_itrs=1, attn_type="sa", G_shared=True, shared_dim=128, hier=True, cross_replica=False, mybn=False, G_activation=nn.ReLU(inplace=False), G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8, BN_eps=1e-5, SN_eps=1e-12, G_init="ortho", G_mixed_precision=False, G_fp16=False, skip_init=False, no_optim=False, sched_version="default", RRM_prx_G=True, n_head_G=2, G_param="SN", norm_style="bn", **kwargs)
 Constructor.
 
 init_weights (self)
 Initialize weights.
 
 forward (self, z, y)
 Forward pass.
 

Public Attributes

 ch = G_ch
 Channel width multiplier.
 
 G_depth = G_depth
 Number of resblocks per stage.
 
 dim_z = dim_z
 Dimensionality of the latent space.
 
 bottom_width = bottom_width
 The initial spatial dimensions.
 
 H_base = H_base
 The initial horizontal dimension.
 
 resolution = resolution
 Resolution of the output.
 
 kernel_size = G_kernel_size
 Kernel size?
 
 attention = G_attn
 Attention?
 
 n_classes = n_classes
 Number of classes, for use in categorical conditional generation.
 
 G_shared = G_shared
 Use shared embeddings?
 
int shared_dim = shared_dim if shared_dim > 0 else dim_z
 Dimensionality of the shared embedding?
 
 hier = hier
 Hierarchical latent space?
 
 cross_replica = cross_replica
 Cross replica batchnorm?
 
 mybn = mybn
 Use my batchnorm?
 
 activation = torch.nn.ReLU(inplace=True)
 Activation function used in the residual blocks.
 
str init = G_init
 Initialization style.
 
str G_param = G_param
 Parameterization style.
 
 norm_style = norm_style
 Normalization style.
 
 BN_eps = BN_eps
 Epsilon for BatchNorm?
 
 SN_eps = SN_eps
 Epsilon for Spectral Norm?
 
 fp16 = G_fp16
 fp16?
 
 arch = G_arch(self.ch, self.attention)[resolution]
 Architecture dict.
 
 RRM_prx_G = RRM_prx_G
 Use the relational reasoning module (RRM) on proxy embeddings?
 
 n_head_G = n_head_G
 Number of attention heads in the RRM.
 
 which_conv
 Convolution layer constructor.
 
 which_linear
 Linear layer constructor.
 
 which_embedding = nn.Embedding
 Embedding layer constructor.
 
 which_bn
 Class-conditional batchnorm constructor.
 
 shared
 Shared class embedding.
 
 RR_G
 RRM on proxy embeddings.
 
 linear
 First linear layer.
 
list blocks = []
 Nested list of generator blocks (resblocks and optional attention).
 
 output_layer
 Output layer (batchnorm-relu-conv).
 
 lr = G_lr
 Learning rate.
 
 B1 = G_B1
 Adam beta1.
 
 B2 = G_B2
 Adam beta2.
 
 adam_eps = adam_eps
 Adam epsilon.
 
 optim
 Optimizer.
 
 lr_sched = None
 Learning-rate scheduler.
 
int param_count = 0
 Parameter count.
 

Detailed Description

Generator

Definition at line 1155 of file ieagan.py.
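A minimal usage sketch (illustrative, not part of ieagan.py; it assumes Generator and its dependencies have been imported from ieagan.py). Note that the constructor branches on a string value for G_activation, so the string form is passed explicitly here:

import torch

# Hypothetical example with otherwise default settings (resolution=256, n_classes=40).
gen = Generator(G_activation="relu")

bs = 8
z = torch.randn(bs, gen.dim_z)              # latent noise, [bs, 128]
y = torch.randint(0, gen.n_classes, (bs,))  # class labels, [bs]
with torch.no_grad():
    images = gen(z, y)                      # single-channel output in [-1, 1] (tanh)
print(images.shape)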

Constructor & Destructor Documentation

◆ __init__()

__init__ ( self,
G_ch = 64,
G_depth = 2,
dim_z = 128,
bottom_width = 4,
resolution = 256,
G_kernel_size = 3,
G_attn = "64",
n_classes = 40,
H_base = 1,
num_G_SVs = 1,
num_G_SV_itrs = 1,
attn_type = "sa",
G_shared = True,
shared_dim = 128,
hier = True,
cross_replica = False,
mybn = False,
G_activation = nn.ReLU(inplace=False),
G_lr = 5e-5,
G_B1 = 0.0,
G_B2 = 0.999,
adam_eps = 1e-8,
BN_eps = 1e-5,
SN_eps = 1e-12,
G_init = "ortho",
G_mixed_precision = False,
G_fp16 = False,
skip_init = False,
no_optim = False,
sched_version = "default",
RRM_prx_G = True,
n_head_G = 2,
G_param = "SN",
norm_style = "bn",
** kwargs )

Constructor.

Definition at line 1159 of file ieagan.py.

1196 ):
1197 super(Generator, self).__init__()
1198
1199 self.ch = G_ch
1200
1201 self.G_depth = G_depth
1202
1203 self.dim_z = dim_z
1204
1205 self.bottom_width = bottom_width
1206
1207 self.H_base = H_base
1208
1209 self.resolution = resolution
1210
1211 self.kernel_size = G_kernel_size
1212
1213 self.attention = G_attn
1214
1215 self.n_classes = n_classes
1216
1217 self.G_shared = G_shared
1218
1219 self.shared_dim = shared_dim if shared_dim > 0 else dim_z
1220
1221 self.hier = hier
1222
1223 self.cross_replica = cross_replica
1224
1225 self.mybn = mybn
1226 # nonlinearity for residual blocks
1227 if G_activation == "inplace_relu":
1228
1229 self.activation = torch.nn.ReLU(inplace=True)
1230 elif G_activation == "relu":
1231 self.activation = torch.nn.ReLU(inplace=False)
1232 elif G_activation == "leaky_relu":
1233 self.activation = torch.nn.LeakyReLU(0.2, inplace=False)
1234 else:
1235 raise NotImplementedError("activation function not implemented")
1236
1237 self.init = G_init
1238
1239 self.G_param = G_param
1240
1241 self.norm_style = norm_style
1242
1243 self.BN_eps = BN_eps
1244
1245 self.SN_eps = SN_eps
1246
1247 self.fp16 = G_fp16
1248
1249 self.arch = G_arch(self.ch, self.attention)[resolution]
1250
1251 self.RRM_prx_G = RRM_prx_G
1252
1253 self.n_head_G = n_head_G
1254
1255 # Which convs, batchnorms, and linear layers to use
1256 if self.G_param == "SN":
1257
1258 self.which_conv = functools.partial(
1259 SNConv2d,
1260 kernel_size=3,
1261 padding=1,
1262 num_svs=num_G_SVs,
1263 num_itrs=num_G_SV_itrs,
1264 eps=self.SN_eps,
1265 )
1266
1267 self.which_linear = functools.partial(
1268 SNLinear,
1269 num_svs=num_G_SVs,
1270 num_itrs=num_G_SV_itrs,
1271 eps=self.SN_eps,
1272 )
1273 else:
1274 self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
1275 self.which_linear = nn.Linear
1276
1277 # We use a non-spectral-normed embedding here regardless;
1278 # For some reason applying SN to G's embedding seems to randomly cripple G # noqa
1279
1280 self.which_embedding = nn.Embedding
1281 bn_linear = (
1282 functools.partial(self.which_linear, bias=False)
1283 if self.G_shared
1284 else self.which_embedding
1285 )
1286
1287 self.which_bn = functools.partial(
1288 ccbn,
1289 which_linear=bn_linear,
1290 cross_replica=self.cross_replica,
1291 mybn=self.mybn,
1292 input_size=(
1293 self.shared_dim + self.dim_z if self.G_shared else self.n_classes
1294 ),
1295 norm_style=self.norm_style,
1296 eps=self.BN_eps,
1297 )
1298
1299 self.shared = (
1300 self.which_embedding(n_classes, self.shared_dim)
1301 if G_shared
1302 else identity()
1303 )
1304
1305 if self.RRM_prx_G:
1306
1307 self.RR_G = RelationalReasoning(
1308 num_layers=1,
1309 input_dim=128,
1310 dim_feedforward=128,
1311 which_linear=nn.Linear,
1312 num_heads=self.n_head_G,
1313 dropout=0.0,
1314 hidden_dim=128,
1315 )
1316
1317
1318 self.linear = self.which_linear(
1319 self.dim_z + self.shared_dim,
1320 self.arch["in_channels"][0] * ((self.bottom_width**2) * self.H_base),
1321 )
1322
1323 # self.blocks is a doubly-nested list of modules, the outer loop intended # noqa
1324 # to be over blocks at a given resolution (resblocks and/or self-attention) # noqa
1325 # while the inner loop is over a given block
1326
1327 self.blocks = []
1328 for index in range(len(self.arch["out_channels"])):
1329 self.blocks += [
1330 [
1331 GBlock(
1332 in_channels=self.arch["in_channels"][index],
1333 out_channels=self.arch["in_channels"][index]
1334 if g_index == 0
1335 else self.arch["out_channels"][index],
1336 which_conv=self.which_conv,
1337 which_bn=self.which_bn,
1338 activation=self.activation,
1339 upsample=(
1340 functools.partial(F.interpolate, scale_factor=2)
1341 if self.arch["upsample"][index]
1342 and g_index == (self.G_depth - 1)
1343 else None
1344 ),
1345 )
1346 ]
1347 for g_index in range(self.G_depth)
1348 ]
1349
1350 # If attention on this block, attach it to the end
1351 if self.arch["attention"][self.arch["resolution"][index]]:
1352 print(
1353 f"Adding attention layer in G at resolution {self.arch['resolution'][index]:d}"
1354 )
1355
1356 if attn_type == "sa":
1357 self.blocks[-1] += [
1358 Attention(self.arch["out_channels"][index], self.which_conv)
1359 ]
1360 elif attn_type == "cbam":
1361 self.blocks[-1] += [
1362 CBAM_attention(
1363 self.arch["out_channels"][index], self.which_conv
1364 )
1365 ]
1366 elif attn_type == "ila":
1367 self.blocks[-1] += [ILA(self.arch["out_channels"][index])]
1368
1369 # Turn self.blocks into a ModuleList so that it's all properly registered. # noqa
1370 self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
1371
1372 # output layer: batchnorm-relu-conv.
1373 # Consider using a non-spectral conv here
1374
1375 self.output_layer = nn.Sequential(
1376 bn(
1377 self.arch["out_channels"][-1],
1378 cross_replica=self.cross_replica,
1379 mybn=self.mybn,
1380 ),
1381 self.activation,
1382 self.which_conv(self.arch["out_channels"][-1], 1),
1383 )
1384
1385 # Initialize weights. Optionally skip init for testing.
1386 if not skip_init:
1387 self.init_weights()
1388
1389 # Set up optimizer
1390 # If this is an EMA copy, no need for an optim, so just return now
1391 if no_optim:
1392 return
1393
1394 self.lr = G_lr
1395
1396 self.B1 = G_B1
1397
1398 self.B2 = G_B2
1399
1400 self.adam_eps = adam_eps
1401 if G_mixed_precision:
1402 print("Using fp16 adam in G...")
1403 import utils
1404
1405 self.optim = utils.Adam16(
1406 params=self.parameters(),
1407 lr=self.lr,
1408 betas=(self.B1, self.B2),
1409 weight_decay=0,
1410 eps=self.adam_eps,
1411 )
1412
1413
1414 self.optim = optim.Adam(
1415 params=self.parameters(),
1416 lr=self.lr,
1417 betas=(self.B1, self.B2),
1418 weight_decay=0,
1419 eps=self.adam_eps,
1420 )
1421 # LR scheduling
1422 if sched_version == "default":
1423
1424 self.lr_sched = None
1425 elif sched_version == "CosAnnealLR":
1426 self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
1427 self.optim,
1428 T_max=kwargs["num_epochs"],
1429 eta_min=self.lr / 4,
1430 last_epoch=-1,
1431 )
1432 elif sched_version == "CosAnnealWarmRes":
1433 self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
1434 self.optim, T_0=10, T_mult=2, eta_min=self.lr / 4
1435 )
1436 else:
1437 self.lr_sched = None
1438

Member Function Documentation

◆ forward()

forward ( self,
z,
y )

Forward pass.

Definition at line 1463 of file ieagan.py.

1463 def forward(self, z, y):
1464 y = self.shared(y)
1465 # If relational embedding
1466 if self.RRM_prx_G:
1467 y = self.RR_G(y.unsqueeze(0)).squeeze(0)
1468 # y = F.normalize(y, dim=1)
1469 # If hierarchical, concatenate zs and ys
1470 if self.hier: # y and z are [bs,128] dimensional
1471 z = torch.cat([y, z], 1)
1472 y = z
1473 # First linear layer
1474 h = self.linear(z) # ([bs,256]-->[bs,24576])
1475 # Reshape
1476 h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width * self.H_base)
1477 # Loop over blocks
1478 for _, blocklist in enumerate(self.blocks):
1479 # Second inner loop in case block has multiple layers
1480 for block in blocklist:
1481 h = block(h, y)
1482
1483 # Apply batchnorm-relu-conv-tanh at output
1484 return torch.tanh(self.output_layer(h))
1485
1486
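A step-by-step shape sketch for the default configuration (dim_z = shared_dim = 128, bottom_width = 4, H_base = 1, hier = True), assuming a Generator instance named gen; the 24576 figure is taken from the inline comment on the first linear layer above:

import torch

bs = 4
z = torch.randn(bs, 128)                    # latent noise, [bs, 128]
y = torch.randint(0, gen.n_classes, (bs,))  # class labels, [bs]

y_emb = gen.shared(y)                       # shared embedding, [bs, 128]
if gen.RRM_prx_G:
    y_emb = gen.RR_G(y_emb.unsqueeze(0)).squeeze(0)  # relational reasoning over the batch
zy = torch.cat([y_emb, z], 1)               # hierarchical concatenation, [bs, 256]
h = gen.linear(zy)                          # first linear layer, [bs, 24576]
h = h.view(bs, -1, gen.bottom_width, gen.bottom_width * gen.H_base)  # [bs, 1536, 4, 4]
# Each blocklist in gen.blocks then upsamples h conditioned on zy, and
# gen.output_layer maps the result to a single-channel image via tanh.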

◆ init_weights()

init_weights ( self)

Initialize weights.

Definition at line 1440 of file ieagan.py.

1440 def init_weights(self):
1441
1442 self.param_count = 0
1443 for module in self.modules():
1444 if (
1445 isinstance(module, nn.Conv2d)
1446 or isinstance(module, nn.Linear)
1447 or isinstance(module, nn.Embedding)
1448 ):
1449 if self.init == "ortho":
1450 init.orthogonal_(module.weight)
1451 elif self.init == "N02":
1452 init.normal_(module.weight, 0, 0.02)
1453 elif self.init in ["glorot", "xavier"]:
1454 init.xavier_uniform_(module.weight)
1455 else:
1456 print("Init style not recognized...")
1457 self.param_count += sum(
1458 [p.data.nelement() for p in module.parameters()]
1459 )
1460 print(f"Param count for G's initialized parameters: {self.param_count}")
1461
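Since param_count sums the parameters of every conv, linear, and embedding module, it can be cross-checked with a manual count (a sketch, assuming a Generator instance gen constructed with skip_init=False so that init_weights() has already run):

import torch.nn as nn

manual_count = sum(
    p.numel()
    for m in gen.modules()
    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding))
    for p in m.parameters()
)
assert manual_count == gen.param_count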

Member Data Documentation

◆ activation

activation = torch.nn.ReLU(inplace=True)

Activation function used in the residual blocks.

Definition at line 1229 of file ieagan.py.

◆ adam_eps

adam_eps = adam_eps

Adam epsilon.

Definition at line 1400 of file ieagan.py.

◆ arch

arch = G_arch(self.ch, self.attention)[resolution]

Architecture dict.

Definition at line 1249 of file ieagan.py.

◆ attention

attention = G_attn

Attention?

Definition at line 1213 of file ieagan.py.

◆ B1

B1 = G_B1

Adam beta1.

Definition at line 1396 of file ieagan.py.

◆ B2

B2 = G_B2

Adam beta2.

Definition at line 1398 of file ieagan.py.

◆ blocks

blocks = []

Nested list of generator blocks (resblocks and optional attention).

Definition at line 1327 of file ieagan.py.

◆ BN_eps

BN_eps = BN_eps

Epsilon for BatchNorm?

Definition at line 1243 of file ieagan.py.

◆ bottom_width

bottom_width = bottom_width

The initial spatial dimensions.

Definition at line 1205 of file ieagan.py.

◆ ch

ch = G_ch

Channel width multiplier.

Definition at line 1199 of file ieagan.py.

◆ cross_replica

cross_replica = cross_replica

Cross replica batchnorm?

Definition at line 1223 of file ieagan.py.

◆ dim_z

dim_z = dim_z

Dimensionality of the latent space.

Definition at line 1203 of file ieagan.py.

◆ fp16

fp16 = G_fp16

fp16?

Definition at line 1247 of file ieagan.py.

◆ G_depth

G_depth = G_depth

Number of resblocks per stage.

Definition at line 1201 of file ieagan.py.

◆ G_param

str G_param = G_param

Parameterization style.

Definition at line 1239 of file ieagan.py.

◆ G_shared

G_shared = G_shared

Use shared embeddings?

Definition at line 1217 of file ieagan.py.

◆ H_base

H_base = H_base

The initial horizontal dimension.

Definition at line 1207 of file ieagan.py.

◆ hier

hier = hier

Hierarchical latent space?

Definition at line 1221 of file ieagan.py.

◆ init

str init = G_init

Initialization style.

Definition at line 1237 of file ieagan.py.

◆ kernel_size

kernel_size = G_kernel_size

Kernel size?

Definition at line 1211 of file ieagan.py.

◆ linear

linear
Initial value:
= self.which_linear(
self.dim_z + self.shared_dim,
self.arch["in_channels"][0] * ((self.bottom_width**2) * self.H_base),
)

First linear layer.

Definition at line 1318 of file ieagan.py.

◆ lr

lr = G_lr

Learning rate.

Definition at line 1394 of file ieagan.py.

◆ lr_sched

lr_sched = None

Learning-rate scheduler.

Definition at line 1424 of file ieagan.py.

◆ mybn

mybn = mybn

Use my batchnorm?

Definition at line 1225 of file ieagan.py.

◆ n_classes

n_classes = n_classes

Number of classes, for use in categorical conditional generation.

Definition at line 1215 of file ieagan.py.

◆ n_head_G

n_head_G = n_head_G

Number of attention heads in the RRM.

Definition at line 1253 of file ieagan.py.

◆ norm_style

norm_style = norm_style

Normalization style.

Definition at line 1241 of file ieagan.py.

◆ optim

optim
Initial value:
= utils.Adam16(
params=self.parameters(),
lr=self.lr,
betas=(self.B1, self.B2),
weight_decay=0,
eps=self.adam_eps,
)

Optimizer.

Definition at line 1405 of file ieagan.py.

◆ output_layer

output_layer
Initial value:
= nn.Sequential(
bn(
self.arch["out_channels"][-1],
cross_replica=self.cross_replica,
mybn=self.mybn,
),
self.activation,
self.which_conv(self.arch["out_channels"][-1], 1),
)

Output layer (batchnorm-relu-conv).

Definition at line 1375 of file ieagan.py.

◆ param_count

int param_count = 0

Parameter count.

Definition at line 1442 of file ieagan.py.

◆ resolution

resolution = resolution

Resolution of the output.

Definition at line 1209 of file ieagan.py.

◆ RR_G

RR_G
Initial value:
= RelationalReasoning(
num_layers=1,
input_dim=128,
dim_feedforward=128,
which_linear=nn.Linear,
num_heads=self.n_head_G,
dropout=0.0,
hidden_dim=128,
)

RRM on proxy embeddings.

Definition at line 1307 of file ieagan.py.
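As used in forward() above, the module processes the whole batch of proxy embeddings jointly by adding and then removing a leading dimension (a sketch, assuming a Generator instance gen with RRM_prx_G enabled):

import torch

y = torch.randint(0, gen.n_classes, (4,))  # class labels
y_emb = gen.shared(y)                      # proxy embeddings, [bs, 128]
y_emb = gen.RR_G(y_emb.unsqueeze(0))       # batch treated as one group, [1, bs, 128]
y_emb = y_emb.squeeze(0)                   # back to [bs, 128]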

◆ RRM_prx_G

RRM_prx_G = RRM_prx_G

Use the relational reasoning module (RRM) on proxy embeddings?

Definition at line 1251 of file ieagan.py.

◆ shared

shared
Initial value:
= (
self.which_embedding(n_classes, self.shared_dim)
if G_shared
else identity()
)

Shared class embedding.

Definition at line 1299 of file ieagan.py.

◆ shared_dim

int shared_dim = shared_dim if shared_dim > 0 else dim_z

Dimensionality of the shared embedding?

Unused if not using G_shared

Definition at line 1219 of file ieagan.py.

◆ SN_eps

SN_eps = SN_eps

Epsilon for Spectral Norm?

Definition at line 1245 of file ieagan.py.

◆ which_bn

which_bn
Initial value:
= functools.partial(
ccbn,
which_linear=bn_linear,
cross_replica=self.cross_replica,
mybn=self.mybn,
input_size=(
self.shared_dim + self.dim_z if self.G_shared else self.n_classes
),
norm_style=self.norm_style,
eps=self.BN_eps,
)

Class-conditional batchnorm constructor.

Definition at line 1287 of file ieagan.py.

◆ which_conv

which_conv
Initial value:
= functools.partial(
SNConv2d,
kernel_size=3,
padding=1,
num_svs=num_G_SVs,
num_itrs=num_G_SV_itrs,
eps=self.SN_eps,
)

Convolution layer constructor.

Definition at line 1258 of file ieagan.py.
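Because which_conv is a functools.partial, call sites only supply the input and output channel counts while the settings above (kernel size, padding, spectral-norm parameters) are baked in. The same pattern is shown below with the non-spectral-norm branch of the constructor, which uses a plain nn.Conv2d (a self-contained sketch):

import functools
import torch
import torch.nn as nn

# Equivalent to the G_param != "SN" branch of the constructor.
which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)

conv = which_conv(64, 128)        # nn.Conv2d(64, 128, kernel_size=3, padding=1)
x = torch.randn(2, 64, 32, 32)
print(conv(x).shape)              # torch.Size([2, 128, 32, 32])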

◆ which_embedding

which_embedding = nn.Embedding

Embedding layer constructor.

Definition at line 1280 of file ieagan.py.

◆ which_linear

which_linear
Initial value:
= functools.partial(
SNLinear,
num_svs=num_G_SVs,
num_itrs=num_G_SV_itrs,
eps=self.SN_eps,
)

Linear layer constructor.

Definition at line 1267 of file ieagan.py.


The documentation for this class was generated from the following file: