Constructor.
    ):
        super(Generator, self).__init__()

        # Architecture hyperparameters.
        self.ch = G_ch
        self.G_depth = G_depth
        self.dim_z = dim_z
        self.bottom_width = bottom_width
        self.H_base = H_base
        self.resolution = resolution
        self.kernel_size = G_kernel_size
        self.attention = G_attn
        self.n_classes = n_classes
        self.G_shared = G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        self.hier = hier
        self.cross_replica = cross_replica
        self.mybn = mybn

        # Nonlinearity used throughout the generator.
        if G_activation == "inplace_relu":
            self.activation = torch.nn.ReLU(inplace=True)
        elif G_activation == "relu":
            self.activation = torch.nn.ReLU(inplace=False)
        elif G_activation == "leaky_relu":
            self.activation = torch.nn.LeakyReLU(0.2, inplace=False)
        else:
            raise NotImplementedError("activation function not implemented")
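        # Note: ReLU(inplace=True) overwrites its input tensor to save memory;
        # PyTorch raises an error if that input is still needed for the backward
        # pass of an earlier op, which is presumably why non-inplace variants
        # are offered as well.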

        self.init = G_init
        self.G_param = G_param
        self.norm_style = norm_style
        self.BN_eps = BN_eps
        self.SN_eps = SN_eps
        self.fp16 = G_fp16
        self.arch = G_arch(self.ch, self.attention)[resolution]
        self.RRM_prx_G = RRM_prx_G
        self.n_head_G = n_head_G

        # Convolution and linear constructors: spectrally normalised variants
        # when G_param == "SN", plain nn.Conv2d / nn.Linear otherwise.
        if self.G_param == "SN":
            self.which_conv = functools.partial(
                SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                SNLinear,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        self.which_embedding = nn.Embedding
        bn_linear = (
            functools.partial(self.which_linear, bias=False)
            if self.G_shared
            else self.which_embedding
        )
        self.which_bn = functools.partial(
            ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=(
                self.shared_dim + self.dim_z if self.G_shared else self.n_classes
            ),
            norm_style=self.norm_style,
            eps=self.BN_eps,
        )
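        # ccbn computes per-channel BN gains and biases from a conditioning
        # vector. With a shared embedding that vector is presumably the class
        # embedding (shared_dim) concatenated with the noise (dim_z), hence
        # input_size = shared_dim + dim_z; without sharing, ccbn instead uses a
        # per-class nn.Embedding lookup, hence input_size = n_classes.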

        # Shared class embedding (identity when classes go directly to ccbn).
        self.shared = (
            self.which_embedding(n_classes, self.shared_dim)
            if G_shared
            else identity()
        )

        # Optional relational reasoning module used when RRM_prx_G is set.
        if self.RRM_prx_G:
            self.RR_G = RelationalReasoning(
                num_layers=1,
                input_dim=128,
                dim_feedforward=128,
                which_linear=nn.Linear,
                num_heads=self.n_head_G,
                dropout=0.0,
                hidden_dim=128,
            )

        # First linear layer: maps (noise, embedding) to the initial feature map.
        self.linear = self.which_linear(
            self.dim_z + self.shared_dim,
            self.arch["in_channels"][0] * ((self.bottom_width**2) * self.H_base),
        )
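        # Note: the linear layer outputs in_channels[0] * bottom_width**2 * H_base
        # units, which forward() presumably reshapes into a feature map with
        # in_channels[0] channels and bottom_width**2 * H_base spatial positions
        # before the first GBlock.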

        # Convolutional stages: each stage contributes G_depth GBlocks; only the
        # last GBlock of a stage upsamples when the architecture requests it.
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    GBlock(
                        in_channels=self.arch["in_channels"][index],
                        out_channels=self.arch["in_channels"][index]
                        if g_index == 0
                        else self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        which_bn=self.which_bn,
                        activation=self.activation,
                        upsample=(
                            functools.partial(F.interpolate, scale_factor=2)
                            if self.arch["upsample"][index]
                            and g_index == (self.G_depth - 1)
                            else None
                        ),
                    )
                ]
                for g_index in range(self.G_depth)
            ]

            # Optionally append an attention module to the last block of this stage.
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    f"Adding attention layer in G at resolution {self.arch['resolution'][index]:d}"
                )
                if attn_type == "sa":
                    self.blocks[-1] += [
                        Attention(self.arch["out_channels"][index], self.which_conv)
                    ]
                elif attn_type == "cbam":
                    self.blocks[-1] += [
                        CBAM_attention(
                            self.arch["out_channels"][index], self.which_conv
                        )
                    ]
                elif attn_type == "ila":
                    self.blocks[-1] += [ILA(self.arch["out_channels"][index])]

        # Register everything as nested ModuleLists so parameters are tracked.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
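        # Resulting structure, e.g. for a stage `index` with G_depth == 2,
        # upsampling enabled, and self-attention ("sa") at that resolution:
        #   self.blocks[2 * index]     -> ModuleList([GBlock(c_in -> c_in)])
        #   self.blocks[2 * index + 1] -> ModuleList([GBlock(c_in -> c_out, upsample x2),
        #                                             Attention(c_out)])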

        # Output layer: BatchNorm, activation, and a final 3x3 convolution
        # down to a single output channel.
        self.output_layer = nn.Sequential(
            bn(
                self.arch["out_channels"][-1],
                cross_replica=self.cross_replica,
                mybn=self.mybn,
            ),
            self.activation,
            self.which_conv(self.arch["out_channels"][-1], 1),
        )

        if not skip_init:
            self.init_weights()

        if no_optim:
            return

        # Optimiser hyperparameters and construction.
        self.lr = G_lr
        self.B1 = G_B1
        self.B2 = G_B2
        self.adam_eps = adam_eps
        if G_mixed_precision:
            print("Using fp16 adam in G...")
            import utils

            self.optim = utils.Adam16(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
        else:
            # Only build the standard Adam optimiser when mixed precision is
            # off; without this else branch the fp16 optimiser above would be
            # immediately overwritten.
            self.optim = optim.Adam(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )

        if sched_version == "default":
            self.lr_sched = None
        elif sched_version == "CosAnnealLR":
            self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
                self.optim,
                T_max=kwargs["num_epochs"],
                eta_min=self.lr / 4,
                last_epoch=-1,
            )
        elif sched_version == "CosAnnealWarmRes":
            self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optim, T_0=10, T_mult=2, eta_min=self.lr / 4
            )
        else:
            self.lr_sched = None
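For orientation, a minimal instantiation sketch follows. The constructor's full signature is not reproduced above, so the keyword names below simply mirror the variables used in the body, the values are illustrative rather than repository defaults, any argument not listed is assumed to have a default, and the trailing num_epochs keyword assumes the constructor accepts extra arguments via **kwargs, as the use of kwargs["num_epochs"] suggests.

gen = Generator(
    G_ch=64,                 # base channel multiplier (self.ch)
    G_depth=2,               # GBlocks per stage
    dim_z=120,               # latent (noise) dimensionality
    bottom_width=4,          # spatial size of the initial feature map
    H_base=1,                # extra factor on the initial spatial size (see self.linear)
    resolution=128,          # must be a key of the dict returned by G_arch
    G_attn="64",             # resolution(s) at which to insert attention
    n_classes=1000,
    G_shared=True,           # use a shared class embedding
    shared_dim=128,
    G_param="SN",            # spectrally normalised convs and linears
    G_activation="inplace_relu",
    attn_type="sa",
    sched_version="CosAnnealLR",
    num_epochs=100,          # consumed via kwargs by the cosine schedule
)
gen.optim.zero_grad()        # the Adam optimiser is attached by the constructor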