This module implements the IEA-GAN generator model.

import functools
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.nn import Parameter as P
from torch.nn import init
31 "use_multiepoch_sampler": false,
32 "model":
"BigGAN_deep",
46 "cross_replica": false,
48 "G_activation":
"inplace_relu",
49 "D_activation":
"inplace_relu",
64 "num_G_accumulations": 1,
66 "num_D_accumulations": 1,
72 "D_mixed_precision": false,
73 "G_mixed_precision": false,
74 "accumulate_stats": false,
75 "num_standing_accumulations": 16,
96 "sv_log_interval": 10,
100 "run_name":
"BGd_140",
103 "latent_reg_weight": 300,
108 "conditional_strategy":
"Contra",
109 "hypersphere_dim": 1024,
110 "pos_collected_numerator": false,
111 "nonlinear_embed": false,
112 "normalize_embed": true,
113 "inv_stereographic" :false,
114 "contra_lambda": 1.0,
119 "Uniformity_loss": true,
127 "normalized_proxy_G": false,
133 "sched_version":
"default",
135 "truncated_threshold": 1.0,
144 "stop_after": 100000,
147 "metric_log_name":
"metric_log.jsonl",
148 "reinitialize_metric_logs": false,
149 "reinitialize_parameter_logs": false,
150 "num_incep_images": 16000,
151 "load_optim": true}
"""
def proj(x, y):
    """Projection of x onto y."""
    return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
def gram_schmidt(x, ys):
    """Orthogonalize x wrt list of vectors ys."""
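# A minimal sketch of the orthogonalization step, assuming the usual
# BigGAN-style implementation (the helper name is hypothetical): subtract from
# x its projection onto every vector already collected in ys.
def gram_schmidt_sketch(x, ys):
    for y in ys:
        x = x - proj(x, y)  # proj() as defined above
    return x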
def power_iteration(W, u_, update=True, eps=1e-12):
    """Apply num_itrs steps of the power method to estimate top N singular values."""
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        with torch.no_grad():
            v = torch.matmul(u, W)
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            u = torch.matmul(v, W.t())
            u = F.normalize(gram_schmidt(u, us), eps=eps)
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
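# Standalone illustration of the estimate tracked above: a few power-method
# steps on a random matrix converge to its top singular value (the spectral
# norm), which is what the SN layers below divide their weights by.
def _power_iteration_demo():
    W = torch.randn(64, 128)
    u = torch.randn(1, 64)
    for _ in range(50):
        with torch.no_grad():
            v = F.normalize(torch.matmul(u, W), eps=1e-12)      # right singular vector estimate
            u = F.normalize(torch.matmul(v, W.t()), eps=1e-12)  # left singular vector estimate
    sigma = torch.matmul(torch.matmul(v, W.t()), u.t()).squeeze()
    return float(sigma), float(torch.linalg.svdvals(W)[0])      # nearly identical values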
def groupnorm(x, norm_style):
    """Simple function to handle the groupnorm norm style."""
    # If the number of channels per group is specified ("ch_<n>"):
    if "ch" in norm_style:
        ch = int(norm_style.split("_")[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If the number of groups is specified ("grp_<n>"):
    elif "grp" in norm_style:
        groups = int(norm_style.split("_")[-1])
    return F.group_norm(x, groups)
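# Usage illustration (standalone): "ch_<n>" fixes roughly n channels per group,
# while "grp_<n>" fixes the number of groups directly.
def _groupnorm_style_demo():
    x = torch.randn(4, 64, 8, 8)
    groups_from_ch = max(int(x.shape[1]) // int("ch_16".split("_")[-1]), 1)  # -> 4 groups
    groups_from_grp = int("grp_8".split("_")[-1])                            # -> 8 groups
    return F.group_norm(x, groups_from_ch).shape, F.group_norm(x, groups_from_grp).shape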
Convenience passthrough function
Spectral normalization base class

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        # Register a persistent u vector and a logged singular value per sv
        for i in range(num_svs):
            self.register_buffer(f"u{i:d}", torch.randn(1, num_outputs))
            self.register_buffer(f"sv{i:d}", torch.ones(1))

    @property
    def u(self):
        """Singular vectors (u side)."""
        return [getattr(self, f"u{i:d}") for i in range(self.num_svs)]

    @property
    def sv(self):
        """Singular values; note that these buffers are just for logging and are not used in training."""
        return [getattr(self, f"sv{i:d}") for i in range(self.num_svs)]

    def W_(self):
        """Compute the spectrally-normalized weight."""
        W_mat = self.weight.view(self.weight.size(0), -1)
        svs, _, _ = power_iteration(
            W_mat, self.u, update=self.training, eps=self.eps
        )
        # Update the logged singular values without tracking gradients
        with torch.no_grad():
            for i, sv in enumerate(svs):
                self.sv[i][:] = sv
        return self.weight / svs[0]
2D Conv layer with spectral norm

        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
Linear layer with spectral norm

        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        return F.linear(x, self.W_(), self.bias)
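# Usage sketch (assuming SNLinear as defined in this module): it behaves like
# nn.Linear, except that each forward pass divides the weight by its estimated
# top singular value, keeping the layer's Lipschitz constant near 1.
def _snlinear_demo():
    layer = SNLinear(128, 256, num_svs=1, num_itrs=1)
    layer.train()
    for _ in range(20):                      # a few passes refine the u estimate
        _ = layer(torch.randn(8, 128))
    top_sv = torch.linalg.svdvals(layer.W_().detach())[0]
    return float(top_sv)                     # close to 1: the weight is spectrally normalized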
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    """Fused batchnorm op"""
    scale = torch.rsqrt(var + eps)
    if gain is not None:
        scale = scale * gain
    shift = mean * scale if bias is None else mean * scale - bias
    return x * scale - shift
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    """Calculate means and variances using mean-of-squares minus mean-squared."""
    float_x = x.float()
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    m2 = torch.mean(float_x**2, [0, 2, 3], keepdim=True)
    var = m2 - m**2
    var = var.type(x.type())
    m = m.type(x.type())
    if return_mean_var:
        return (
            fused_bn(x, m, var, gain, bias, eps),
            m.squeeze(),
            var.squeeze(),
        )
    return fused_bn(x, m, var, gain, bias, eps)
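# Standalone check of the statistics used above: variance computed as
# E[x^2] - E[x]^2 plus the fused scale/shift reproduces F.batch_norm with
# batch statistics up to floating-point error.
def _manual_bn_check():
    x = torch.randn(8, 16, 4, 4)
    m = x.mean([0, 2, 3], keepdim=True)
    var = (x ** 2).mean([0, 2, 3], keepdim=True) - m ** 2
    out = x * torch.rsqrt(var + 1e-5) - m * torch.rsqrt(var + 1e-5)
    ref = F.batch_norm(x, None, None, training=True, eps=1e-5)
    return torch.allclose(out, ref, atol=1e-4)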
My batchnorm, supports standing stats

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        self.register_buffer("stored_mean", torch.zeros(num_channels))
        self.register_buffer("stored_var", torch.ones(num_channels))
        self.register_buffer("accumulation_counter", torch.zeros(1))

    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(
                x, gain, bias, return_mean_var=True, eps=self.eps
            )
            # If accumulating standing stats, increment them
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # Otherwise take running averages
            else:
                self.stored_mean[:] = (
                    self.stored_mean * (1 - self.momentum) + mean * self.momentum
                )
                self.stored_var[:] = (
                    self.stored_var * (1 - self.momentum) + var * self.momentum
                )
            return out
        # Not in training mode: use the stored statistics
        mean = self.stored_mean.view(1, -1, 1, 1)
        var = self.stored_var.view(1, -1, 1, 1)
        # If using standing stats, divide them by the accumulation counter
        if self.accumulate_standing:
            mean = mean / self.accumulation_counter
            var = var / self.accumulation_counter
        return fused_bn(x, mean, var, gain, bias, self.eps)
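# Illustration of the standing-statistics idea used above: accumulate per-batch
# means/variances over several forward passes, then normalize at eval time with
# their running average instead of a single batch's statistics.
def _standing_stats_demo(num_passes=8):
    stored_mean, stored_var, counter = torch.zeros(16), torch.zeros(16), 0.0
    for _ in range(num_passes):
        x = torch.randn(32, 16, 4, 4)
        m = x.mean([0, 2, 3])
        stored_mean += m
        stored_var += (x ** 2).mean([0, 2, 3]) - m ** 2
        counter += 1.0
    return stored_mean / counter, stored_var / counter  # used in place of batch stats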
Normal, non-class-conditional BN

        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)

        self.register_buffer("stored_mean", torch.zeros(output_size))
        self.register_buffer("stored_var", torch.ones(output_size))

        gain = self.gain.view(1, -1, 1, 1)
        bias = self.bias.view(1, -1, 1, 1)
        return self.bn(x, gain=gain, bias=bias)
Class-conditional BN: output size is the number of channels, input size is for the linear layers.
    Andy's Note: this class feels messy but I'm not really sure how to clean it up.
    Suggestions welcome! (By which I mean, refactor this and make a merge request
    if you want to make this more readable/usable.)

        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)

        self.register_buffer("stored_mean", torch.zeros(output_size))
        self.register_buffer("stored_var", torch.ones(output_size))

        # Class-conditional gains and biases computed from the embedding y
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # Cross-replica / my-BN path: gain and bias are fused into the norm op
        return self.bn(x, gain=gain, bias=bias)

        # Plain-norm path: normalize first, then apply the conditional gain/bias
        out = F.instance_norm(
            x, self.stored_mean, self.stored_var, None, None, self.training, 0.1, self.eps
        )
        out = groupnorm(x, self.norm_style)
        return out * gain + bias
    def extra_repr(self):
        s = "out: {output_size}, in: {input_size},"
        s += " cross_replica={cross_replica}"
        return s.format(**self.__dict__)
Image_Linear_Attention

        chan_out = chan if chan_out is None else chan_out

        conv_kwargs = {"padding": padding, "stride": stride}
        self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)

        out_conv_kwargs = {"padding": padding}
        self.to_out = nn.Conv2d(
            value_dim * heads, chan_out, kernel_size, **out_conv_kwargs
        )

        b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads

        q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))
        q, k = map(lambda x: x * (self.key_dim**-0.25), (q, k))

        if context is not None:
            context = context.reshape(b, c, 1, -1)
            ck, cv = self.to_k(context), self.to_v(context)
            ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
            k = torch.cat((k, ck), dim=3)
            v = torch.cat((v, cv), dim=3)

        k = k.softmax(dim=-1)
        q = q.softmax(dim=-2)

        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhdn,bhde->bhen", q, context)
        out = out.reshape(b, -1, h, w)
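# Standalone shape check for the einsum update above: softmax over the keys
# lets the (d x e) context be formed first, so the cost scales with n = H * W
# rather than n^2 as in standard dot-product attention.
def _linear_attention_demo():
    b, heads, d, e, n = 2, 8, 32, 64, 16 * 16      # d: key dim, e: value dim, n: H*W
    q = torch.randn(b, heads, d, n).softmax(dim=-2)
    k = torch.randn(b, heads, d, n).softmax(dim=-1)
    v = torch.randn(b, heads, e, n)
    context = torch.einsum("bhdn,bhen->bhde", k, v)    # (b, heads, d, e)
    out = torch.einsum("bhdn,bhde->bhen", q, context)  # (b, heads, e, n)
    return out.shape                                   # reshaped to (b, heads * e, H, W) above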
        attention_kernel_size=3,

        super(CBAM_attention, self).__init__()
        # Channel attention: 1x1 convs squeezing channels -> channels // reduction -> channels
            channels, channels // reduction, kernel_size=1, padding=0
        self.relu = nn.ReLU(inplace=True)
            channels // reduction, channels, kernel_size=1, padding=0
        # Spatial attention: conv over the concatenated average- and max-pooled maps
            kernel_size=attention_kernel_size,
            padding=attention_kernel_size // 2,

        avg = torch.mean(x, 1, True)
        mx, _ = torch.max(x, 1, True)
        x = torch.cat((avg, mx), 1)
        # 1x1 convolutions for the theta / phi / g / o projections
        self.theta = self.which_conv(
            self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
        )
        self.phi = self.which_conv(
            self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
        )
        self.g = self.which_conv(
            self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
        )
        self.o = self.which_conv(
            self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
        )
        # Learnable gain parameter
        self.gamma = P(torch.tensor(0.0), requires_grad=True)

        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])

        theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)

        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        o = self.o(
            torch.bmm(g, beta.transpose(1, 2)).view(
                -1, self.ch // 2, x.shape[2], x.shape[3]
            )
        )
        return self.gamma * o + x
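# Standalone shape walk-through of the self-attention ("sa") block above for an
# input of shape (B, C, H, W); random tensors stand in for the 1x1-conv outputs.
def _sa_shape_demo():
    B, C, H, W = 2, 64, 8, 8
    theta = torch.randn(B, C // 8, H * W)             # self.theta(x), flattened
    phi = torch.randn(B, C // 8, (H * W) // 4)        # self.phi(x) after 2x2 max-pool
    g = torch.randn(B, C // 2, (H * W) // 4)          # self.g(x) after 2x2 max-pool
    beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)   # (B, HW, HW/4)
    o = torch.bmm(g, beta.transpose(1, 2)).view(B, C // 2, H, W)  # (B, C/2, H, W)
    return o.shape                                    # self.o then maps C/2 back to C channels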
Embedding layer with spectral norm.
We use num_embeddings as the dim instead of embedding_dim here.

        scale_grad_by_freq=False,

        nn.Embedding.__init__(
            ...
        )
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        return F.embedding(x, self.W_())
def scaled_dot_product(q, k, v):
    d_k = q.size(-1)
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
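# Usage illustration for the helper above, with the (batch, heads, seq, head_dim)
# layout used by the multi-head attention class that follows.
def _scaled_dot_product_demo():
    q = k = v = torch.randn(2, 4, 10, 32)
    values, attention = scaled_dot_product(q, k, v)
    return values.shape, attention.shape   # (2, 4, 10, 32) and (2, 4, 10, 10)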
849 """MultiheadAttention"""
852 def __init__(self, input_dim, embed_dim, num_heads, which_linear):
855 embed_dim % num_heads == 0
856 ),
"Embedding dimension must be 0 modulo number of heads."
878 nn.init.xavier_uniform_(self.
qkv_proj.weight)
880 nn.init.xavier_uniform_(self.
o_proj.weight)
881 self.
o_proj.bias.data.fill_(0)
885 batch_size, seq_length, embed_dim = x.size()
890 qkv = qkv.permute(0, 2, 1, 3)
891 q, k, v = qkv.chunk(3, dim=-1)
894 values, attention = scaled_dot_product(q, k, v)
895 values = values.permute(0, 2, 1, 3)
896 values = values.reshape(batch_size, seq_length, embed_dim)
    def __init__(self, input_dim, num_heads, dim_feedforward, dropout, which_linear):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        # Attention layer
        self.self_attn = MultiheadAttention(
            input_dim, input_dim, num_heads, which_linear
        )
        # Two-layer feed-forward MLP
            nn.ReLU(inplace=True),
        # Pre-layer-norm and dropout layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

        # Pre-norm residual structure: attention block, then MLP block
        x_pre1 = self.norm1(x)
        x_pre2 = self.norm2(x)
        x = x + self.dropout(linear_out)
960 """RelationalReasoning"""
963 def __init__(self, num_layers, hidden_dim, **block_args):
970 self.
norm = nn.LayerNorm(hidden_dim)
984 _, attn_map = layer.self_attn(x, return_attention=
True)
985 attention_maps.append(attn_map)
987 return attention_maps
def G_arch(ch=64, attention="64"):
    arch = {}
    arch[512] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
        "upsample": [True] * 7,
        "resolution": [8, 16, 32, 64, 128, 256, 512],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 10)
        },
    }
    arch[256] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1]],
        "upsample": [True] * 6,
        "resolution": [8, 16, 32, 64, 128, 256],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 9)
        },
    }
    arch[128] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 4, 2, 1]],
        "upsample": [True] * 5,
        "resolution": [8, 16, 32, 64, 128],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 8)
        },
    }
    arch[96] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4]],
        "out_channels": [ch * item for item in [16, 8, 4, 2]],
        "upsample": [True] * 4,
        "resolution": [12, 24, 48, 96],
        "attention": {
            12 * 2**i: (6 * 2**i in [int(item) for item in attention.split("_")])
            for i in range(0, 4)
        },
    }
    arch[64] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4]],
        "out_channels": [ch * item for item in [16, 8, 4, 2]],
        "upsample": [True] * 4,
        "resolution": [8, 16, 32, 64],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 7)
        },
    }
    arch[32] = {
        "in_channels": [ch * item for item in [4, 4, 4]],
        "out_channels": [ch * item for item in [4, 4, 4]],
        "upsample": [True] * 3,
        "resolution": [8, 16, 32],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 6)
        },
    }
    return arch
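# Illustration: with the default ch=64, the six-stage branch above (output
# resolutions 8 through 256) expands to these concrete channel widths.
def _g_arch_widths_demo(ch=64):
    in_channels = [ch * item for item in [16, 16, 8, 8, 4, 2]]   # [1024, 1024, 512, 512, 256, 128]
    out_channels = [ch * item for item in [16, 8, 8, 4, 2, 1]]   # [1024, 512, 512, 256, 128, 64]
    return in_channels, out_channels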
        cross_replica=False,
        G_activation=nn.ReLU(inplace=False),
        G_mixed_precision=False,
        sched_version="default",

        # Select the generator activation function
        if G_activation == "inplace_relu":
            self.activation = torch.nn.ReLU(inplace=True)
        elif G_activation == "relu":
            self.activation = torch.nn.ReLU(inplace=False)
        elif G_activation == "leaky_relu":
            self.activation = torch.nn.LeakyReLU(0.2, inplace=False)
        else:
            raise NotImplementedError("activation function not implemented")
            num_itrs=num_G_SV_itrs,

            num_itrs=num_G_SV_itrs,

        self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)

            which_linear=bn_linear,

            dim_feedforward=128,
            which_linear=nn.Linear,

        # Build the generator blocks: G_depth residual blocks per stage
        for index in range(len(self.arch["out_channels"])):
                    in_channels=self.arch["in_channels"][index],
                    out_channels=self.arch["in_channels"][index]
                    if g_index == 0
                    else self.arch["out_channels"][index],
                    upsample=(
                        functools.partial(F.interpolate, scale_factor=2)
                        if self.arch["upsample"][index]
                        and g_index == (self.G_depth - 1)
                        else None
                    ),
                for g_index in range(self.G_depth)

            # If attention on this block, attach it to the end
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    f"Adding attention layer in G at resolution {self.arch['resolution'][index]:d}"
                )
                if attn_type == "sa":
                    ...
                elif attn_type == "cbam":
                    ...
                elif attn_type == "ila":
                    ...

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # Output layer
            self.arch["out_channels"][-1],
        # Prepare the optimizer; optionally use a fp16-aware Adam
        if G_mixed_precision:
            print("Using fp16 adam in G...")
                params=self.parameters(),
                betas=(self.B1, self.B2),
        else:
            self.optim = optim.Adam(
                params=self.parameters(),
                betas=(self.B1, self.B2),
            )

        # Learning-rate schedule
        if sched_version == "default":
            ...
        elif sched_version == "CosAnnealLR":
            self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
                self.optim,
                T_max=kwargs["num_epochs"],
                eta_min=self.lr / 4,
            )
        elif sched_version == "CosAnnealWarmRes":
            self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optim, T_0=10, T_mult=2, eta_min=self.lr / 4
            )
        # Initialize the model weights
        for module in self.modules():
            if (
                isinstance(module, nn.Conv2d)
                or isinstance(module, nn.Linear)
                or isinstance(module, nn.Embedding)
            ):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    [p.data.nelement() for p in module.parameters()]
                )
        print(f"Param count for G's initialized parameters: {self.param_count}")
        # Feed the class embeddings through the relational reasoning module
        y = self.RR_G(y.unsqueeze(0)).squeeze(0)

        z = torch.cat([y, z], 1)

        # Loop over blocks
        for _, blocklist in enumerate(self.blocks):
            for block in blocklist:
By default, initialize with the CONFIG dict.
def generate(model: nn.Module):
    """
    Run inference with the provided Generator model.

    Args:
        model (nn.Module): Generator model

    Returns:
        torch.Tensor: batch of 40 PXD images
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        latents = torch.randn(40, 128, device=device)
        labels = torch.tensor(list(range(40)), dtype=torch.long, device=device)
        imgs = model(latents, labels).detach().cpu()
    # Suppress activations below the threshold to the background value (-1)
    imgs = F.threshold(imgs, -0.26, -1)
    # Map from [-1, 1] to [0, 1]
    imgs = imgs.mul_(0.5).add_(0.5)
    # Map [0, 1] exponentially to integer amplitudes in [0, 255]
    imgs = torch.pow(256, imgs).add_(-1).clamp_(0, 255).to(torch.uint8)
    # Keep the single channel and crop three padded rows at the top and bottom
    imgs = imgs[:, 0, 3:-3, :]
    return imgs
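# Recap of the post-processing in generate() above: the network output in
# [-1, 1] is thresholded, shifted to [0, 1], and mapped through 256**x - 1 to
# an 8-bit amplitude; e.g. x = -1 -> 0, x = 0 -> 15, x = 1 -> 255.
def _amplitude_mapping_demo():
    x = torch.tensor([-1.0, 0.0, 1.0])
    x = x.mul(0.5).add(0.5)                                         # -> [0.0, 0.5, 1.0]
    return torch.pow(256, x).add(-1).clamp(0, 255).to(torch.uint8)  # -> [0, 15, 255]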
def __init__(self, ch, which_conv=SNConv2d)
Constructor.
def forward(self, x)
Forward pass.
gamma
Learnable gain parameter.
def forward(self, x)
Forward pass.
sigmoid_spatial
Sigmoid for the spatial attention map.
sigmoid_channel
Sigmoid for the channel attention map.
def __init__(self, channels, which_conv=SNConv2d, reduction=8, attention_kernel_size=3)
Constructor.
conv_after_concat
convolution after concatenation
def forward(self, x)
Forward pass.
def __init__(self, input_dim, num_heads, dim_feedforward, dropout, which_linear)
Constructor.
self_attn
Attention layer.
out_channels
Output channels.
which_bn
Which batchnorm layer to use.
def forward(self, x, y)
Forward pass.
def __init__(self, in_channels, out_channels, which_conv=SNConv2d, which_bn=bn, activation=None, upsample=None, channel_ratio=4)
Constructor.
hidden_channels
hidden channels
param_count
parameter count
bottom_width
The initial spatial dimensions.
ch
Channel width multiplier.
G_param
Parameterization style.
which_embedding
which embedding
linear
First linear layer.
shared_dim
Dimensionality of the shared embedding? Unused if not using G_shared.
norm_style
Normalization style.
H_base
The initial horizontal dimension.
n_classes
number of classes, for use in categorical conditional generation
BN_eps
Epsilon for BatchNorm?
hier
Hierarchical latent space?
def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=256, G_kernel_size=3, G_attn="64", n_classes=40, H_base=1, num_G_SVs=1, num_G_SV_itrs=1, attn_type="sa", G_shared=True, shared_dim=128, hier=True, cross_replica=False, mybn=False, G_activation=nn.ReLU(inplace=False), G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8, BN_eps=1e-5, SN_eps=1e-12, G_init="ortho", G_mixed_precision=False, G_fp16=False, skip_init=False, no_optim=False, sched_version="default", RRM_prx_G=True, n_head_G=2, G_param="SN", norm_style="bn", **kwargs)
Constructor.
def init_weights(self)
Initialize the model weights.
init
Initialization style.
def forward(self, z, y)
Forward pass.
G_shared
Use shared embeddings?
G_depth
Number of resblocks per stage.
cross_replica
Cross replica batchnorm?
dim_z
Dimensionality of the latent space.
resolution
Resolution of the output.
RR_G
RRM on proxy embeddings.
SN_eps
Epsilon for Spectral Norm?
def forward(self, x, context=None)
Forward pass.
def __init__(self, chan, chan_out=None, kernel_size=1, padding=0, stride=1, key_dim=32, value_dim=64, heads=8, norm_queries=True)
Constructor.
def __init__(self)
Constructor.
def __init__(self, input_dim, embed_dim, num_heads, which_linear)
Constructor.
def _reset_parameters(self)
reset parameters
embed_dim
embedding dimension
def forward(self, x, return_attention=False)
Forward pass.
def forward(self, x)
Forward pass.
def __init__(self, num_layers, hidden_dim, **block_args)
Constructor.
def get_attention_maps(self, x)
Return the attention map of each encoder layer.
def forward(self, x)
Forward pass.
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, num_svs=1, num_itrs=1, eps=1e-12)
Constructor.
def forward(self, x)
Forward pass.
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False, _weight=None, num_svs=1, num_itrs=1, eps=1e-12)
Constructor.
def forward(self, x)
Forward pass.
def __init__(self, in_features, out_features, bias=True, num_svs=1, num_itrs=1, eps=1e-12)
Constructor.
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12)
num_itrs
Number of power iterations per step.
num_svs
Number of singular values.
eps
Epsilon value for avoiding divide-by-0.
def forward(self, x)
Forward pass.
def __init__(self, output_size, eps=1e-5, momentum=0.1, cross_replica=False, mybn=False)
Constructor.
gain
Prepare gain and bias layers.
cross_replica
Use cross-replica batchnorm?
eps
epsilon to avoid dividing by 0
def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1, cross_replica=False, mybn=False, norm_style="bn")
Constructor.
def forward(self, x, y)
Forward pass.
def extra_repr(self)
Extra representation string (output/input sizes and cross_replica flag).
gain
Prepare gain and bias layers.
cross_replica
Use cross-replica batchnorm?
eps
epsilon to avoid dividing by 0
def forward(self, tensor: torch.Tensor)
Forward pass.
def forward(self, x, gain, bias)
Forward pass.
def __init__(self, num_channels, eps=1e-5, momentum=0.1)
Constructor.
accumulate_standing
Accumulate running means and vars.
def reset_stats(self)
Reset the standing statistics.
momentum
momentum for updating running stats
eps
epsilon to avoid dividing by 0