Constructor.
):
    super(Generator, self).__init__()

    # Channel width multiplier and number of GBlocks per stage
    self.ch = G_ch
    self.G_depth = G_depth
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # Base spatial size of the initial feature map and its aspect multiplier
    self.bottom_width = bottom_width
    self.H_base = H_base
    # Output resolution and convolution kernel size
    self.resolution = resolution
    self.kernel_size = G_kernel_size
    # Attention setting, number of classes, and shared-embedding options
    self.attention = G_attn
    self.n_classes = n_classes
    self.G_shared = G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent flag and batch-norm variants
    self.hier = hier
    self.cross_replica = cross_replica
    self.mybn = mybn

    # Nonlinearity
    if G_activation == "inplace_relu":
        self.activation = torch.nn.ReLU(inplace=True)
    elif G_activation == "relu":
        self.activation = torch.nn.ReLU(inplace=False)
    elif G_activation == "leaky_relu":
        self.activation = torch.nn.LeakyReLU(0.2, inplace=False)
    else:
        raise NotImplementedError("activation function not implemented")

    # Initialization style, parameterization, and normalization settings
    self.init = G_init
    self.G_param = G_param
    self.norm_style = norm_style
    self.BN_eps = BN_eps
    self.SN_eps = SN_eps
    self.fp16 = G_fp16
    # Architecture dictionary for the requested resolution
    self.arch = G_arch(self.ch, self.attention)[resolution]
    # Relational-reasoning proxy settings
    self.RRM_prx_G = RRM_prx_G
    self.n_head_G = n_head_G
    # Which conv and linear layers to use (spectrally normalized or plain)
    if self.G_param == "SN":
        self.which_conv = functools.partial(
            SNConv2d,
            kernel_size=3,
            padding=1,
            num_svs=num_G_SVs,
            num_itrs=num_G_SV_itrs,
            eps=self.SN_eps,
        )
        self.which_linear = functools.partial(
            SNLinear,
            num_svs=num_G_SVs,
            num_itrs=num_G_SV_itrs,
            eps=self.SN_eps,
        )
    else:
        self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
        self.which_linear = nn.Linear

    # Class-conditional batch norm: with a shared embedding, the conditioning is a
    # linear projection of [shared embedding, z]; otherwise each class gets its own
    # embedding of size n_classes.
    self.which_embedding = nn.Embedding
    bn_linear = (
        functools.partial(self.which_linear, bias=False)
        if self.G_shared
        else self.which_embedding
    )
    self.which_bn = functools.partial(
        ccbn,
        which_linear=bn_linear,
        cross_replica=self.cross_replica,
        mybn=self.mybn,
        input_size=(
            self.shared_dim + self.dim_z if self.G_shared else self.n_classes
        ),
        norm_style=self.norm_style,
        eps=self.BN_eps,
    )

    # Shared class embedding (identity when G_shared is disabled)
    self.shared = (
        self.which_embedding(n_classes, self.shared_dim)
        if G_shared
        else identity()
    )
    # Optional relational-reasoning module applied on the generator side
    if self.RRM_prx_G:
        self.RR_G = RelationalReasoning(
            num_layers=1,
            input_dim=128,
            dim_feedforward=128,
            which_linear=nn.Linear,
            num_heads=self.n_head_G,
            dropout=0.0,
            hidden_dim=128,
        )

    # First linear layer: maps the concatenated latent and shared embedding to the
    # flattened initial feature map.
    self.linear = self.which_linear(
        self.dim_z + self.shared_dim,
        self.arch["in_channels"][0] * ((self.bottom_width**2) * self.H_base),
    )
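    # Sketch (assumption, not part of this listing): forward() typically reshapes
    # the output of self.linear into the initial feature map before the blocks,
    # along the lines of
    #     h = self.linear(torch.cat([z, self.shared(y)], dim=1))
    #     h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width * self.H_base)
    # Whether H_base stretches the height or the width depends on the actual
    # forward() implementation, which lies outside this constructor.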
    # Build the generator blocks: G_depth GBlocks per stage, with upsampling only
    # on the last block of each stage, plus an optional attention layer.
    self.blocks = []
    for index in range(len(self.arch["out_channels"])):
        self.blocks += [
            [
                GBlock(
                    in_channels=self.arch["in_channels"][index],
                    out_channels=(
                        self.arch["in_channels"][index]
                        if g_index == 0
                        else self.arch["out_channels"][index]
                    ),
                    which_conv=self.which_conv,
                    which_bn=self.which_bn,
                    activation=self.activation,
                    upsample=(
                        functools.partial(F.interpolate, scale_factor=2)
                        if self.arch["upsample"][index]
                        and g_index == (self.G_depth - 1)
                        else None
                    ),
                )
            ]
            for g_index in range(self.G_depth)
        ]

        # Attach an attention layer at the resolutions requested in the arch dict
        if self.arch["attention"][self.arch["resolution"][index]]:
            print(
                f"Adding attention layer in G at resolution {self.arch['resolution'][index]:d}"
            )
            if attn_type == "sa":
                self.blocks[-1] += [
                    Attention(self.arch["out_channels"][index], self.which_conv)
                ]
            elif attn_type == "cbam":
                self.blocks[-1] += [
                    CBAM_attention(
                        self.arch["out_channels"][index], self.which_conv
                    )
                ]
            elif attn_type == "ila":
                self.blocks[-1] += [ILA(self.arch["out_channels"][index])]

    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
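    # Sketch (assumption, not shown here): the nested ModuleLists built above are
    # usually walked stage by stage in forward(), with every GBlock of a stage
    # sharing that stage's conditioning vector:
    #     for index, blocklist in enumerate(self.blocks):
    #         for block in blocklist:
    #             h = block(h, ys[index])
    # Appended attention layers would need to accept (and may ignore) the second
    # argument, or be dispatched separately; that detail lives in forward().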
    # Output layer: batch norm, activation, and a 3x3 conv to a single channel
    self.output_layer = nn.Sequential(
        bn(
            self.arch["out_channels"][-1],
            cross_replica=self.cross_replica,
            mybn=self.mybn,
        ),
        self.activation,
        self.which_conv(self.arch["out_channels"][-1], 1),
    )

    # Initialize weights unless explicitly skipped
    if not skip_init:
        self.init_weights()

    # Set up the optimizer; skip if this module is not optimized directly
    if no_optim:
        return
    self.lr = G_lr
    self.B1 = G_B1
    self.B2 = G_B2
    self.adam_eps = adam_eps
    # Optimizer: fp16 Adam (Adam16) for mixed precision, otherwise standard Adam
    if G_mixed_precision:
        print("Using fp16 adam in G...")
        import utils

        self.optim = utils.Adam16(
            params=self.parameters(),
            lr=self.lr,
            betas=(self.B1, self.B2),
            weight_decay=0,
            eps=self.adam_eps,
        )
    else:
        self.optim = optim.Adam(
            params=self.parameters(),
            lr=self.lr,
            betas=(self.B1, self.B2),
            weight_decay=0,
            eps=self.adam_eps,
        )
    if sched_version == "default":
        self.lr_sched = None
    elif sched_version == "CosAnnealLR":
        self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
            self.optim,
            T_max=kwargs["num_epochs"],
            eta_min=self.lr / 4,
            last_epoch=-1,
        )
    elif sched_version == "CosAnnealWarmRes":
        self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optim, T_0=10, T_mult=2, eta_min=self.lr / 4
        )
    else:
        self.lr_sched = None
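    # Usage sketch (hypothetical argument values, not taken from this repository):
    # the constructor is normally driven by a config dict, e.g.
    #     G = Generator(G_ch=64, G_depth=2, dim_z=128, bottom_width=4, H_base=1,
    #                   resolution=128, G_attn="0", n_classes=10, G_shared=True,
    #                   shared_dim=128, G_lr=1e-4, G_B1=0.0, G_B2=0.999,
    #                   adam_eps=1e-8, sched_version="default")
    #     z = torch.randn(8, G.dim_z)
    #     y = G.shared(torch.randint(0, G.n_classes, (8,)))
    # Argument names mirror the attributes assigned above; any parameters omitted
    # here are assumed to have defaults in the real signature.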