"""This module implements the IEA-GAN generator model."""
import functools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.nn import Parameter as P, init
31 "use_multiepoch_sampler": false,
32 "model": "BigGAN_deep",
46 "cross_replica": false,
48 "G_activation": "inplace_relu",
49 "D_activation": "inplace_relu",
64 "num_G_accumulations": 1,
66 "num_D_accumulations": 1,
72 "D_mixed_precision": false,
73 "G_mixed_precision": false,
74 "accumulate_stats": false,
75 "num_standing_accumulations": 16,
96 "sv_log_interval": 10,
100 "run_name": "BGd_140",
103 "latent_reg_weight": 300,
108 "conditional_strategy": "Contra",
109 "hypersphere_dim": 1024,
110 "pos_collected_numerator": false,
111 "nonlinear_embed": false,
112 "normalize_embed": true,
113 "inv_stereographic" :false,
114 "contra_lambda": 1.0,
119 "Uniformity_loss": true,
127 "normalized_proxy_G": false,
133 "sched_version": "default",
135 "truncated_threshold": 1.0,
144 "stop_after": 100000,
147 "metric_log_name": "metric_log.jsonl",
148 "reinitialize_metric_logs": false,
149 "reinitialize_parameter_logs": false,
150 "num_incep_images": 16000,
151 "load_optim": true}"""
    Projection of x onto y
    """
    return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
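# Quick numeric check of the projection helper above (a sketch, not part of
# the module): projecting x = [3, 4] onto y = [2, 0] keeps only the component
# along y.
#
#   >>> y = torch.tensor([[2.0, 0.0]])
#   >>> x = torch.tensor([[3.0, 4.0]])
#   >>> torch.mm(y, x.t()) * y / torch.mm(y, y.t())
#   tensor([[3., 0.]])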
def gram_schmidt(x, ys):
    """
    Orthogonalize x wrt list of vectors ys
    """
def power_iteration(W, u_, update=True, eps=1e-12):
    """
    Apply num_itrs steps of the power method to estimate top N singular values.
    """
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        with torch.no_grad():
            v = torch.matmul(u, W)
            # Run Gram-Schmidt to subtract components of the other singular vectors
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            vs += [v]
            u = torch.matmul(v, W.t())
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            us += [u]
            if update:
                u_[i][:] = u
        # Compute this singular value and add it to the list
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
    return svs, us, vs
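# Hedged usage sketch: with a persistent singular-vector buffer u, each call
# performs one power-method step per vector and typically converges to the top
# singular value of W (compare against torch.linalg.svdvals):
#
#   >>> W = torch.randn(8, 8)
#   >>> u = [torch.randn(1, 8)]
#   >>> for _ in range(50):
#   ...     svs, _, _ = power_iteration(W, u)
#   >>> torch.allclose(svs[0], torch.linalg.svdvals(W)[0], rtol=1e-3)
#   True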
def groupnorm(x, norm_style):
    """
    Simple function to handle groupnorm norm stylization
    """
    # If the number of channels per group is specified, e.g. "ch_16"
    if "ch" in norm_style:
        ch = int(norm_style.split("_")[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If the number of groups is specified, e.g. "grp_8"
    elif "grp" in norm_style:
        groups = int(norm_style.split("_")[-1])
    # Default number of groups
    else:
        groups = 16
    return F.group_norm(x, groups)
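# Hedged sketch of the two norm_style spellings handled above ("ch_<n>" fixes
# the channels per group, "grp_<n>" fixes the number of groups):
#
#   >>> x = torch.randn(4, 64, 8, 8)
#   >>> groupnorm(x, "ch_16").shape    # 64 channels / 16 per group -> 4 groups
#   torch.Size([4, 64, 8, 8])
#   >>> groupnorm(x, "grp_8").shape    # 8 groups directly
#   torch.Size([4, 64, 8, 8])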
class identity(nn.Module):
    """
    Convenience passthrough function
    """

    def forward(self, input):
        return input
class SN:
    """
    Spectral normalization base class
    """

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        ## Number of power iterations per step
        self.num_itrs = num_itrs
        ## Number of singular values
        self.num_svs = num_svs
        ## Transpose the weight matrix?
        self.transpose = transpose
        ## Epsilon value for avoiding divide-by-0
        self.eps = eps
        # Register a singular vector for each sv
        for i in range(self.num_svs):
            self.register_buffer(f"u{i:d}", torch.randn(1, num_outputs))
            self.register_buffer(f"sv{i:d}", torch.ones(1))

    @property
    def u(self):
        """
        Singular vectors (u side)
        """
        return [getattr(self, f"u{i:d}") for i in range(self.num_svs)]

    @property
    def sv(self):
        """
        Singular values;
        note that these buffers are just for logging and are not used in training.
        """
        return [getattr(self, f"sv{i:d}") for i in range(self.num_svs)]

    def W_(self):
        """
        Compute the spectrally-normalized weight
        """
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        # Apply num_itrs power iterations
        for _ in range(self.num_itrs):
            svs, _, _ = power_iteration(
                W_mat, self.u, update=self.training, eps=self.eps
            )
        # Update the logged singular values
        # Make sure to do this in a no_grad() context or you'll get memory leaks! # noqa
        with torch.no_grad():
            for i, sv in enumerate(svs):
                self.sv[i][:] = sv
        return self.weight / svs[0]
class SNConv2d(nn.Conv2d, SN):
    """
    2D Conv layer with spectral norm
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size,
                           stride, padding, dilation, groups, bias)
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        return F.conv2d(x, self.W_(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
class SNLinear(nn.Linear, SN):
    """
    Linear layer with spectral norm
    """

    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        return F.linear(x, self.W_(), self.bias)
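# Hedged usage sketch: an SN layer behaves like its nn counterpart but divides
# its weight by the estimated top singular value on every forward pass, e.g.
#
#   >>> layer = SNLinear(16, 8)
#   >>> layer(torch.randn(4, 16)).shape
#   torch.Size([4, 8])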
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    """Fused batchnorm op"""
    # Apply scale and shift--if gain and bias are provided, fuse them here
    scale = torch.rsqrt(var + eps)
    # If a gain is provided, use it
    if gain is not None:
        scale = scale * gain
    # Prepare the shift from the mean (and bias, if given)
    shift = mean * scale
    # If bias is provided, use it
    if bias is not None:
        shift = shift - bias
    return x * scale - shift
    # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. # noqa
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    """
    Calculate means and variances using mean-of-squares minus mean-squared
    """
    # Cast x to float32 if necessary
    float_x = x.float()
    # Calculate expected value of x (m) and expected value of x**2 (m2)
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    m2 = torch.mean(float_x**2, [0, 2, 3], keepdim=True)
    # Calculate variance as mean of squared minus mean squared.
    var = m2 - m**2
    # Cast back to float 16 if necessary
    var = var.type(x.type())
    m = m.type(x.type())
    # Return mean and variance for updating stored mean/var if requested
    if return_mean_var:
        return (
            fused_bn(x, m, var, gain, bias, eps),
            torch.squeeze(m),
            torch.squeeze(var),
        )
    return fused_bn(x, m, var, gain, bias, eps)
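# Hedged numeric check (not part of the module) of the mean-of-squares minus
# mean-squared identity used above, against the population variance:
#
#   >>> x = torch.randn(8, 3, 4, 4)
#   >>> m = x.mean([0, 2, 3])
#   >>> m2 = (x ** 2).mean([0, 2, 3])
#   >>> torch.allclose(m2 - m ** 2, x.var([0, 2, 3], unbiased=False), atol=1e-5)
#   True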
class myBN(nn.Module):
    """
    My batchnorm, supports standing stats
    """

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        ## momentum for updating running stats
        self.momentum = momentum
        ## epsilon to avoid dividing by 0
        self.eps = eps
        ## Register buffers for the running statistics
        self.register_buffer("stored_mean", torch.zeros(num_channels))
        self.register_buffer("stored_var", torch.ones(num_channels))
        self.register_buffer("accumulation_counter", torch.zeros(1))
        ## Accumulate running means and vars
        self.accumulate_standing = False
    ## reset standing stats
    def reset_stats(self):
        # pylint: disable=no-member
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0
    def forward(self, x, gain, bias):
        # pylint: disable=no-member
        if self.training:
            out, mean, var = manual_bn(
                x, gain, bias, return_mean_var=True, eps=self.eps
            )
            # If accumulating standing stats, increment them
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # If not accumulating standing stats, take running averages
            else:
                self.stored_mean[:] = (
                    self.stored_mean * (1 - self.momentum) + mean * self.momentum
                )
                self.stored_var[:] = (
                    self.stored_var * (1 - self.momentum) + var * self.momentum
                )
            return out
        # If not in training mode, use the stored statistics
        mean = self.stored_mean.view(1, -1, 1, 1)
        var = self.stored_var.view(1, -1, 1, 1)
        # If using standing stats, divide them by the accumulation counter
        if self.accumulate_standing:
            mean = mean / self.accumulation_counter
            var = var / self.accumulation_counter
        return fused_bn(x, mean, var, gain, bias, self.eps)
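# Hedged sketch of the standing-statistics workflow myBN supports: accumulate
# statistics over a few training-mode passes, then switch to eval so the stored
# (counter-averaged) statistics are used.
#
#   >>> bn_layer = myBN(16)
#   >>> bn_layer.accumulate_standing = True
#   >>> bn_layer.reset_stats()
#   >>> for _ in range(8):
#   ...     _ = bn_layer(torch.randn(4, 16, 8, 8), gain=None, bias=None)
#   >>> _ = bn_layer.eval()                  # forward now uses stored_mean/var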
class bn(nn.Module):
    """
    Normal, non-class-conditional BN
    """

    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False):
        super(bn, self).__init__()
        self.output_size = output_size
        ## Prepare gain and bias layers
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        ## epsilon to avoid dividing by 0
        self.eps = eps
        ## Momentum
        self.momentum = momentum
        ## Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        ## Use my batchnorm?
        self.mybn = mybn

        if self.cross_replica or self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        # Register buffers if neither of the above
        else:
            self.register_buffer("stored_mean", torch.zeros(output_size))
            self.register_buffer("stored_var", torch.ones(output_size))

    def forward(self, x):
        if self.cross_replica or self.mybn:
            gain = self.gain.view(1, -1, 1, 1)
            bias = self.bias.view(1, -1, 1, 1)
            return self.bn(x, gain=gain, bias=bias)
        return F.batch_norm(
            x, self.stored_mean, self.stored_var, self.gain, self.bias,
            self.training, self.momentum, self.eps,
        )
class ccbn(nn.Module):
    """
    output size is the number of channels, input size is for the linear layers
    Andy's Note: this class feels messy but I'm not really sure how to clean it up # noqa
    Suggestions welcome! (By which I mean, refactor this and make a merge request
    if you want to make this more readable/usable).
    """

    def __init__(self, output_size, input_size, which_linear, eps=1e-5,
                 momentum=0.1, cross_replica=False, mybn=False, norm_style="bn"):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        ## Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        ## epsilon to avoid dividing by 0
        self.eps = eps
        ## Momentum
        self.momentum = momentum
        ## Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        ## Use my batchnorm?
        self.mybn = mybn
        ## Normalization style
        self.norm_style = norm_style

        if self.cross_replica or self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ["bn", "in"]:
            self.register_buffer("stored_mean", torch.zeros(output_size))
            self.register_buffer("stored_var", torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # If using my batchnorm
        if self.cross_replica or self.mybn:
            return self.bn(x, gain=gain, bias=bias)
        if self.norm_style == "bn":
            out = F.batch_norm(
                x, self.stored_mean, self.stored_var, None, None,
                self.training, 0.1, self.eps,
            )
        elif self.norm_style == "in":
            out = F.instance_norm(
                x, self.stored_mean, self.stored_var, None, None,
                self.training, 0.1, self.eps,
            )
        elif self.norm_style == "gn":
            out = groupnorm(x, self.norm_style)
        elif self.norm_style == "nonorm":
            out = x
        return out * gain + bias
    def extra_repr(self):
        s = "out: {output_size}, in: {input_size},"
        s += " cross_replica={cross_replica}"
        return s.format(**self.__dict__)
class ILA(nn.Module):
    """
    Image_Linear_Attention
    """

    def __init__(self, chan, chan_out=None, kernel_size=1, padding=0, stride=1,
                 key_dim=64, value_dim=64, heads=8, norm_queries=True):
        super().__init__()
        self.chan = chan
        chan_out = chan if chan_out is None else chan_out

        self.key_dim = key_dim
        self.value_dim = value_dim
        self.heads = heads

        self.norm_queries = norm_queries

        conv_kwargs = {"padding": padding, "stride": stride}
        self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)

        out_conv_kwargs = {"padding": padding}
        self.to_out = nn.Conv2d(
            value_dim * heads, chan_out, kernel_size, **out_conv_kwargs
        )

    def forward(self, x, context=None):
        b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads

        q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))

        q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))

        q, k = map(lambda x: x * (self.key_dim**-0.25), (q, k))

        if context is not None:
            context = context.reshape(b, c, 1, -1)
            ck, cv = self.to_k(context), self.to_v(context)
            ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
            k = torch.cat((k, ck), dim=3)
            v = torch.cat((v, cv), dim=3)

        k = k.softmax(dim=-1)

        if self.norm_queries:
            q = q.softmax(dim=-2)

        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhdn,bhde->bhen", q, context)
        out = out.reshape(b, -1, h, w)
        out = self.to_out(out)
        return out
class CBAM_attention(nn.Module):
    def __init__(self, channels, which_conv=SNConv2d, reduction=8,
                 attention_kernel_size=3):
        super(CBAM_attention, self).__init__()
        ## channel attention: pooled descriptors passed through a shared MLP
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = which_conv(
            channels, channels // reduction, kernel_size=1, padding=0
        )
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = which_conv(
            channels // reduction, channels, kernel_size=1, padding=0
        )
        self.sigmoid_channel = nn.Sigmoid()
        ## convolution after concatenation
        self.conv_after_concat = which_conv(
            2,
            1,
            kernel_size=attention_kernel_size,
            stride=1,
            padding=attention_kernel_size // 2,
        )
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        # Channel attention module
        module_input = x
        avg = self.avg_pool(x)
        mx = self.max_pool(x)
        avg = self.fc1(avg)
        mx = self.fc1(mx)
        avg = self.relu(avg)
        mx = self.relu(mx)
        avg = self.fc2(avg)
        mx = self.fc2(mx)
        x = avg + mx
        x = self.sigmoid_channel(x)
        x = module_input * x
        # Spatial attention module
        module_input = x
        # b, c, h, w = x.size()
        avg = torch.mean(x, 1, True)
        mx, _ = torch.max(x, 1, True)
        x = torch.cat((avg, mx), 1)
        x = self.conv_after_concat(x)
        x = self.sigmoid_spatial(x)
        x = module_input * x
        return x
class Attention(nn.Module):
    def __init__(self, ch, which_conv=SNConv2d):
        super(Attention, self).__init__()
        ## Channel multiplier
        self.ch = ch
        self.which_conv = which_conv
        self.theta = self.which_conv(
            self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
        )
        self.phi = self.which_conv(
            self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
        )
        self.g = self.which_conv(
            self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
        )
        self.o = self.which_conv(
            self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
        )
        ## Learnable gain parameter
        self.gamma = P(torch.tensor(0.0), requires_grad=True)

    def forward(self, x):
        # Apply convs
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o(
            torch.bmm(g, beta.transpose(1, 2)).view(
                -1, self.ch // 2, x.shape[2], x.shape[3]
            )
        )
        return self.gamma * o + x
class SNEmbedding(nn.Embedding, SN):
    """
    Embedding layer with spectral norm
    We use num_embeddings as the dim instead of embedding_dim here
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None, num_svs=1, num_itrs=1, eps=1e-12):
        nn.Embedding.__init__(
            self, num_embeddings, embedding_dim, padding_idx, max_norm,
            norm_type, scale_grad_by_freq, sparse, _weight,
        )
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        return F.embedding(x, self.W_())
def scaled_dot_product(q, k, v):
    # Dimension of the keys/queries, used to scale the logits
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
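# Hedged shape sketch for the helper above ([batch, heads, seq_len, head_dim]):
#
#   >>> q = k = v = torch.randn(1, 2, 5, 4)
#   >>> values, attention = scaled_dot_product(q, k, v)
#   >>> values.shape, attention.shape
#   (torch.Size([1, 2, 5, 4]), torch.Size([1, 2, 5, 5]))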
class MultiheadAttention(nn.Module):
    """MultiheadAttention"""

    def __init__(self, input_dim, embed_dim, num_heads, which_linear):
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), "Embedding dimension must be 0 modulo number of heads."
        ## embedding dimension
        self.embed_dim = embed_dim
        ## number of heads
        self.num_heads = num_heads
        ## dimension per head
        self.head_dim = embed_dim // num_heads
        ## linear layer constructor
        self.which_linear = which_linear

        # Stack all weight matrices 1...h together for efficiency
        self.qkv_proj = self.which_linear(input_dim, 3 * embed_dim)
        self.o_proj = self.which_linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, return_attention=False):
        batch_size, seq_length, embed_dim = x.size()
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs
        values, attention = scaled_dot_product(q, k, v)
        values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        return o
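# Hedged usage sketch (embed_dim 64 split across 4 heads of size 16):
#
#   >>> mha = MultiheadAttention(64, 64, num_heads=4, which_linear=nn.Linear)
#   >>> mha(torch.randn(2, 10, 64)).shape      # [batch, seq_len, embed_dim]
#   torch.Size([2, 10, 64])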
class EncoderBlock(nn.Module):
    """EncoderBlock"""

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout, which_linear):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__()
        self.which_linear = which_linear

        # Attention layer
        self.self_attn = MultiheadAttention(
            input_dim, input_dim, num_heads, which_linear
        )

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            self.which_linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            self.which_linear(dim_feedforward, input_dim),
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Attention part (pre-LayerNorm)
        x_pre1 = self.norm1(x)
        attn_out = self.self_attn(x_pre1)
        x = x + self.dropout(attn_out)

        # MLP part (pre-LayerNorm)
        x_pre2 = self.norm2(x)
        linear_out = self.linear_net(x_pre2)
        x = x + self.dropout(linear_out)

        return x
class RelationalReasoning(nn.Module):
    """RelationalReasoning"""

    def __init__(self, num_layers, hidden_dim, **block_args):
        super().__init__()
        self.layers = nn.ModuleList(
            [EncoderBlock(**block_args) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)

    ## get attention maps
    def get_attention_maps(self, x):
        attention_maps = []
        for layer in self.layers:
            _, attn_map = layer.self_attn(x, return_attention=True)
            attention_maps.append(attn_map)
            x = layer(x)
        return attention_maps
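# Hedged sketch of how the relational reasoning module is applied to a batch
# of proxy/class embeddings treated as a single sequence (cf. Generator.forward,
# which unsqueezes y to [1, batch, dim] before the RRM); the hyperparameters
# below are illustrative, not the training configuration:
#
#   >>> rrm = RelationalReasoning(
#   ...     num_layers=1, hidden_dim=128, input_dim=128, num_heads=4,
#   ...     dim_feedforward=128, dropout=0.0, which_linear=nn.Linear)
#   >>> y = torch.randn(40, 128)               # one embedding per image
#   >>> rrm(y.unsqueeze(0)).squeeze(0).shape
#   torch.Size([40, 128])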
class GBlock(nn.Module):
    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, which_bn=bn,
                 activation=None, upsample=None, channel_ratio=4):
        super(GBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.hidden_channels = self.in_channels // channel_ratio
        ## which convolution
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        # 1x1 bottleneck conv in, two 3x3 convs, 1x1 conv out
        self.conv1 = self.which_conv(
            self.in_channels, self.hidden_channels, kernel_size=1, padding=0
        )
        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv4 = self.which_conv(
            self.hidden_channels, self.out_channels, kernel_size=1, padding=0
        )
        # Batchnorm layers
        self.bn1 = self.which_bn(self.in_channels)
        self.bn2 = self.which_bn(self.hidden_channels)
        self.bn3 = self.which_bn(self.hidden_channels)
        self.bn4 = self.which_bn(self.hidden_channels)
        # upsample layer
        self.upsample = upsample

    def forward(self, x, y):
        # Project down to channel ratio
        h = self.conv1(self.activation(self.bn1(x, y)))
        # Apply next BN-ReLU
        h = self.activation(self.bn2(h, y))
        # Drop channels in x if necessary
        if self.in_channels != self.out_channels:
            x = x[:, : self.out_channels]
        # Upsample both h and x at this point
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        # 3x3 convs
        h = self.conv2(h)
        h = self.conv3(self.activation(self.bn3(h, y)))
        # Final 1x1 conv
        h = self.conv4(self.activation(self.bn4(h, y)))
        return h + x
def G_arch(ch=64, attention="64"):
    arch = {}
    arch[512] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
        "upsample": [True] * 7,
        "resolution": [8, 16, 32, 64, 128, 256, 512],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 10)
        },
    }
    arch[256] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1]],
        "upsample": [True] * 6,
        "resolution": [8, 16, 32, 64, 128, 256],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 9)
        },
    }
    arch[128] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 4, 2, 1]],
        "upsample": [True] * 5,
        "resolution": [8, 16, 32, 64, 128],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 8)
        },
    }
    arch[96] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4]],
        "out_channels": [ch * item for item in [16, 8, 4, 2]],
        "upsample": [True] * 4,
        "resolution": [12, 24, 48, 96],
        "attention": {
            12 * 2**i: (6 * 2**i in [int(item) for item in attention.split("_")])
            for i in range(0, 4)
        },
    }
    arch[64] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4]],
        "out_channels": [ch * item for item in [16, 8, 4, 2]],
        "upsample": [True] * 4,
        "resolution": [8, 16, 32, 64],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 7)
        },
    }
    arch[32] = {
        "in_channels": [ch * item for item in [4, 4, 4]],
        "out_channels": [ch * item for item in [4, 4, 4]],
        "upsample": [True] * 3,
        "resolution": [8, 16, 32],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 6)
        },
    }
    return arch
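# Hedged sketch: the Generator below picks one entry of this dict by output
# resolution, e.g. a 256-pixel architecture with attention at resolution 64:
#
#   >>> arch = G_arch(ch=128, attention="64")[256]
#   >>> len(arch["in_channels"]), arch["attention"][64]
#   (6, True)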
class Generator(nn.Module):
    def __init__(
        self,
        # ...
        cross_replica=False,
        G_activation=nn.ReLU(inplace=False),
        G_mixed_precision=False,
        sched_version="default",
        # ...
        **kwargs,
    ):
        super(Generator, self).__init__()
        ## Channel width multiplier
        self.ch = G_ch
        ## Number of resblocks per stage
        self.G_depth = G_depth
        ## Dimensionality of the latent space
        self.dim_z = dim_z
        ## The initial spatial dimensions
        self.bottom_width = bottom_width
        ## The initial horizontal dimension
        self.H_base = H_base
        ## Resolution of the output
        self.resolution = resolution
        self.kernel_size = G_kernel_size
        self.attention = G_attn
        ## number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        ## Use shared embeddings?
        self.G_shared = G_shared
        ## Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        ## Hierarchical latent space?
        self.hier = hier
        ## Cross replica batchnorm?
        self.cross_replica = cross_replica
        ## Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        if G_activation == "inplace_relu":
            self.activation = torch.nn.ReLU(inplace=True)
        elif G_activation == "relu":
            self.activation = torch.nn.ReLU(inplace=False)
        elif G_activation == "leaky_relu":
            self.activation = torch.nn.LeakyReLU(0.2, inplace=False)
        else:
            raise NotImplementedError("activation function not implemented")
        ## Initialization style
        self.init = G_init
        ## Parameterization style
        self.G_param = G_param
        ## Normalization style
        self.norm_style = norm_style
        ## Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        ## Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        ## Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]
        ## Use the relational reasoning module on the proxy embeddings?
        self.RRM_prx_G = RRM_prx_G
        ## Number of attention heads in the RRM
        self.n_head_G = n_head_G
        # Which convs, batchnorms, and linear layers to use
        if self.G_param == "SN":
            self.which_conv = functools.partial(
                SNConv2d, kernel_size=3, padding=1,
                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                SNLinear,
                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, eps=self.SN_eps,
            )
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G # noqa
        self.which_embedding = nn.Embedding
        bn_linear = (
            functools.partial(self.which_linear, bias=False)
            if self.G_shared
            else self.which_embedding
        )
        self.which_bn = functools.partial(
            ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=(
                self.shared_dim + self.dim_z if self.G_shared else self.n_classes
            ),
            norm_style=self.norm_style,
            eps=self.BN_eps,
        )
        self.shared = (
            self.which_embedding(n_classes, self.shared_dim) if G_shared else identity()
        )
        ## RRM on proxy embeddings
        self.RR_G = RelationalReasoning(
            # ...
            dim_feedforward=128,
            which_linear=nn.Linear,
            num_heads=self.n_head_G,
            # ...
        )

        ## First linear layer
        self.linear = self.which_linear(
            self.dim_z + self.shared_dim,
            self.arch["in_channels"][0] * ((self.bottom_width**2) * self.H_base),
        )
        # self.blocks is a doubly-nested list of modules, the outer loop intended # noqa
        # to be over blocks at a given resolution (resblocks and/or self-attention) # noqa
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    GBlock(
                        in_channels=self.arch["in_channels"][index],
                        out_channels=self.arch["in_channels"][index]
                        if g_index < (self.G_depth - 1)
                        else self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        which_bn=self.which_bn,
                        activation=self.activation,
                        upsample=(
                            functools.partial(F.interpolate, scale_factor=2)
                            if self.arch["upsample"][index]
                            and g_index == (self.G_depth - 1)
                            else None
                        ),
                    )
                    for g_index in range(self.G_depth)
                ]
            ]

            # If attention on this block, attach it to the end
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    f"Adding attention layer in G at resolution {self.arch['resolution'][index]:d}"
                )
                if attn_type == "sa":
                    self.blocks[-1] += [
                        Attention(self.arch["out_channels"][index], self.which_conv)
                    ]
                elif attn_type == "cbam":
                    self.blocks[-1] += [
                        CBAM_attention(
                            self.arch["out_channels"][index], self.which_conv
                        )
                    ]
                elif attn_type == "ila":
                    self.blocks[-1] += [ILA(self.arch["out_channels"][index])]

        # Turn self.blocks into a ModuleList so that it's all properly registered. # noqa
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(
            bn(self.arch["out_channels"][-1],
               cross_replica=self.cross_replica, mybn=self.mybn),
            self.activation,
            self.which_conv(self.arch["out_channels"][-1], 1),
        )
        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print("Using fp16 adam in G...")
            self.optim = utils.Adam16(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
        else:
            self.optim = optim.Adam(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )

        # LR scheduling
        if sched_version == "default":
            self.lr_sched = None
        elif sched_version == "CosAnnealLR":
            self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
                self.optim,
                T_max=kwargs["num_epochs"],
                eta_min=self.lr / 4,
            )
        elif sched_version == "CosAnnealWarmRes":
            self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optim, T_0=10, T_mult=2, eta_min=self.lr / 4
            )
        else:
            self.lr_sched = None
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (
                isinstance(module, nn.Conv2d)
                or isinstance(module, nn.Linear)
                or isinstance(module, nn.Embedding)
            ):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    [p.data.nelement() for p in module.parameters()]
                )
        print(f"Param count for G's initialized parameters: {self.param_count}")
    def forward(self, z, y):
        # Map class labels to the shared embedding space
        y = self.shared(y)
        # If relational embedding
        if self.RRM_prx_G:
            y = self.RR_G(y.unsqueeze(0)).squeeze(0)
        # y = F.normalize(y, dim=1)
        # If hierarchical, concatenate zs and ys
        if self.hier:  # y and z are [bs,128] dimensional
            z = torch.cat([y, z], 1)

        # First linear layer
        h = self.linear(z)  # ([bs,256]-->[bs,24576])
        # Reshape to the initial bottom_width x (bottom_width * H_base) grid
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width * self.H_base)
        # Loop over blocks
        for _, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, y)

        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))
class Model(Generator):
    """
    default initializing with CONFIG dict
    """

    def __init__(self):
        super().__init__(**CONFIG)
def generate(model: nn.Module):
    """
    Run inference with the provided Generator model

    Args:
        model (nn.Module): Generator model

    Returns:
        torch.Tensor: batch of 40 PXD images
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        latents = torch.randn(40, 128, device=device)
        labels = torch.tensor(list(range(40)), dtype=torch.long, device=device)
        imgs = model(latents, labels).detach().cpu()
    # Cut the noise below 7 ADU
    imgs = F.threshold(imgs, -0.26, -1)
    # center range [-1, 1] to [0, 1]
    imgs = imgs.mul_(0.5).add_(0.5)
    # renormalize and convert to uint8
    imgs = torch.pow(256, imgs).add_(-1).clamp_(0, 255).to(torch.uint8)
    # flatten channel dimension and crop 256 to 250
    imgs = imgs[:, 0, 3:-3, :]
    return imgs
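# Hedged usage sketch (with untrained weights this only checks the shape and
# dtype of the batch of 40 PXD images described in the docstring above):
#
#   >>> model = Model().eval()
#   >>> imgs = generate(model)
#   >>> imgs.shape[0], imgs.dtype
#   (40, torch.uint8)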