This module implements the IEA-GAN generator model.

import functools
import math

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.nn import Parameter as P, init
31 "use_multiepoch_sampler": false,
32 "model": "BigGAN_deep",
46 "cross_replica": false,
48 "G_activation": "inplace_relu",
49 "D_activation": "inplace_relu",
64 "num_G_accumulations": 1,
66 "num_D_accumulations": 1,
72 "D_mixed_precision": false,
73 "G_mixed_precision": false,
74 "accumulate_stats": false,
75 "num_standing_accumulations": 16,
96 "sv_log_interval": 10,
100 "run_name": "BGd_140",
103 "latent_reg_weight": 300,
108 "conditional_strategy": "Contra",
109 "hypersphere_dim": 1024,
110 "pos_collected_numerator": false,
111 "nonlinear_embed": false,
112 "normalize_embed": true,
113 "inv_stereographic" :false,
114 "contra_lambda": 1.0,
119 "Uniformity_loss": true,
127 "normalized_proxy_G": false,
133 "sched_version": "default",
135 "truncated_threshold": 1.0,
144 "stop_after": 100000,
147 "metric_log_name": "metric_log.jsonl",
148 "reinitialize_metric_logs": false,
149 "reinitialize_parameter_logs": false,
150 "num_incep_images": 16000,
151 "load_optim": true}"""

def proj(x, y):
    """
    Projection of x onto y
    """
    return torch.mm(y, x.t()) * y / torch.mm(y, y.t())

def gram_schmidt(x, ys):
    """
    Orthogonalize x wrt list of vectors ys
    """
    for y in ys:
        x = x - proj(x, y)
    return x


def power_iteration(W, u_, update=True, eps=1e-12):
    """Apply one step of the power method per vector in u_ to estimate the top singular values of W."""
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        with torch.no_grad():
            v = torch.matmul(u, W)
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            vs += [v]
            u = torch.matmul(v, W.t())
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            us += [u]
            if update:
                u_[i][:] = u
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
    return svs, us, vs
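
# Illustrative usage sketch, not part of the original module: estimate the top
# singular value of a random matrix with power_iteration and compare it against
# torch.linalg.svdvals. The helper name `_demo_power_iteration` is ours, and it
# assumes update=True writes the refined vector back into u_ (as in BigGAN's
# layers.py, on which this code is based).
def _demo_power_iteration():
    W = torch.randn(64, 128)
    u_ = [torch.randn(1, 64)]  # one left singular-vector estimate
    svs = None
    for _ in range(100):  # repeated single steps refine the estimate
        svs, _, _ = power_iteration(W, u_, update=True)
    # The two numbers should agree to a few decimal places.
    print(float(svs[0]), float(torch.linalg.svdvals(W)[0]))
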

def groupnorm(x, norm_style):
    """
    Simple function to handle groupnorm norm stylization
    """
    # If number of channels specified in norm_style:
    if "ch" in norm_style:
        ch = int(norm_style.split("_")[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If number of groups specified in norm_style
    elif "grp" in norm_style:
        groups = int(norm_style.split("_")[-1])
    # If neither, default to groups = 16
    else:
        groups = 16
    return F.group_norm(x, groups)
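
# Illustrative sketch (our helper name): how the two norm_style spellings are
# parsed. "ch_8" requests groups of 8 channels each, "grp_8" requests 8 groups,
# so for a 64-channel input both calls below normalize with 8 groups.
def _demo_groupnorm():
    x = torch.randn(2, 64, 8, 8)
    a = groupnorm(x, "ch_8")   # 64 channels // 8 channels per group = 8 groups
    b = groupnorm(x, "grp_8")  # 8 groups requested directly
    print(torch.allclose(a, b))  # True: identical grouping either way
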
    Convenience passthrough function
    Spectral normalization base class

    This base class expects subclasses to have a learnable weight parameter
    (`self.weight`) as in `nn.Linear` or `nn.Conv2d`. It provides a method
    to apply spectral normalization to that weight.

    Attributes:
        num_svs (int): Number of singular values.
        num_itrs (int): Number of power iterations per step.
        transpose (bool): Whether to transpose the weight matrix.
        eps (float): Small constant to avoid divide-by-zero.
        u (list[Tensor]): Registered left singular vectors (buffers).
        sv (list[Tensor]): Registered singular values (buffers).
        training (bool): Inherited from nn.Module. True if in training mode.
    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        ## Number of power iterations per step
        self.num_itrs = num_itrs
        ## Number of singular values
        self.num_svs = num_svs
        ## Transpose?
        self.transpose = transpose
        ## Epsilon value for avoiding divide-by-0
        self.eps = eps
        # Register a singular vector for each sv
        for i in range(self.num_svs):
            self.register_buffer(f"u{i:d}", torch.randn(1, num_outputs))
            self.register_buffer(f"sv{i:d}", torch.ones(1))
        ## Training mode flag (inherited from nn.Module). True if the module is in training mode.
    ## Singular vectors (u side)
    @property
    def u(self):
        return [getattr(self, f"u{i:d}") for i in range(self.num_svs)]

    ## Singular values; these buffers are just for logging and are not used in training.
    @property
    def sv(self):
        return [getattr(self, f"sv{i:d}") for i in range(self.num_svs)]
    def W_(self):
        """
        Compute the spectrally-normalized weight
        """
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        # Apply num_itrs power iterations
        for _ in range(self.num_itrs):
            svs, _, _ = power_iteration(
                W_mat, self.u, update=self.training, eps=self.eps
            )
        # Update the singular-value buffers.
        # Make sure to do this in a no_grad() context or you'll get memory leaks! # noqa
        with torch.no_grad():
            for i, sv in enumerate(svs):
                self.sv[i][:] = sv
        return self.weight / svs[0]

class SNConv2d(nn.Conv2d, SN):
    """
    2D Conv layer with spectral norm
    """
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        # \cond false positive doxygen warning
        return F.conv2d(x, self.W_(), self.bias, self.stride, self.padding, self.dilation, self.groups)

class SNLinear(nn.Linear, SN):
    """
    Linear layer with spectral norm
    """
        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        # \cond false positive doxygen warning
        return F.linear(x, self.W_(), self.bias)
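
# Illustrative sketch (our helper name, assuming the BigGAN-style constructor
# defaults of one singular value and one power iteration): after a few forward
# passes in training mode, the effective weight W_() has spectral norm close to
# 1, since it is the raw weight divided by the estimated top singular value.
def _demo_snlinear():
    layer = SNLinear(32, 16)
    layer.train()
    x = torch.randn(4, 32)
    for _ in range(20):  # each call runs a power-iteration step and updates u
        _ = layer(x)
    print(float(torch.linalg.svdvals(layer.W_())[0]))  # approximately 1.0
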

def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    """Fused batchnorm op"""
    # Apply scale and shift--if gain and bias are provided, fuse them here
    # Prepare scale
    scale = torch.rsqrt(var + eps)
    # If a gain is provided, use it
    if gain is not None:
        scale = scale * gain
    # Prepare shift
    shift = mean * scale
    # If bias is provided, use it
    if bias is not None:
        shift = shift - bias
    return x * scale - shift
    # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. # noqa

def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    """
    Calculate means and variances using mean-of-squares minus mean-squared
    """
    # Cast x to float32 if necessary
    float_x = x.float()
    # Calculate expected value of x (m) and expected value of x**2 (m2)
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    m2 = torch.mean(float_x**2, [0, 2, 3], keepdim=True)
    # Calculate variance as mean of squared minus mean squared.
    var = m2 - m**2
    # Cast back to float 16 if necessary
    var = var.type(x.type())
    m = m.type(x.type())
    # Return mean and variance for updating stored mean/var if requested
    if return_mean_var:
        return (
            fused_bn(x, m, var, gain, bias, eps),
            m.squeeze(),
            var.squeeze(),
        )
    return fused_bn(x, m, var, gain, bias, eps)
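
# Illustrative cross-check (our helper name): for a 4D input with per-channel
# gain and bias, manual_bn should match F.batch_norm evaluated with batch
# statistics, up to floating-point tolerance.
def _demo_manual_bn():
    x = torch.randn(8, 3, 5, 5)
    gain = torch.rand(1, 3, 1, 1) + 0.5
    bias = torch.randn(1, 3, 1, 1)
    out = manual_bn(x, gain, bias, eps=1e-5)
    ref = F.batch_norm(
        x, None, None, weight=gain.view(-1), bias=bias.view(-1),
        training=True, momentum=0.0, eps=1e-5,
    )
    print(torch.allclose(out, ref, atol=1e-4))  # True
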

class myBN(nn.Module):
    """
    My batchnorm, supports standing stats
    """

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        ## momentum for updating running stats
        self.momentum = momentum
        ## epsilon to avoid dividing by 0
        self.eps = eps
        ## Register buffers for the running statistics
        self.register_buffer("stored_mean", torch.zeros(num_channels))
        self.register_buffer("stored_var", torch.ones(num_channels))
        self.register_buffer("accumulation_counter", torch.zeros(1))
        ## Accumulate running means and vars
        self.accumulate_standing = False
        ## Training mode flag (inherited from nn.Module). True if the module is in training mode.

    ## reset standing stats
    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0
    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(
                x, gain, bias, return_mean_var=True, eps=self.eps
            )
            # If accumulating standing stats, increment them
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # If not accumulating standing stats, take running averages
            else:
                self.stored_mean[:] = (
                    self.stored_mean * (1 - self.momentum) + mean * self.momentum
                )
                self.stored_var[:] = (
                    self.stored_var * (1 - self.momentum) + var * self.momentum
                )
            return out
        # If not in training mode, use the stored statistics
        mean = self.stored_mean.view(1, -1, 1, 1)
        var = self.stored_var.view(1, -1, 1, 1)
        # If using standing stats, divide them by the accumulation counter
        if self.accumulate_standing:
            mean = mean / self.accumulation_counter
            var = var / self.accumulation_counter
        return fused_bn(x, mean, var, gain, bias, self.eps)
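
# Illustrative sketch (our helper name) of the standing-statistics protocol:
# accumulate batch means/vars over several training-mode passes, then an
# eval-mode pass divides the accumulated sums by the accumulation counter
# (cf. "num_standing_accumulations" in the CONFIG string above).
def _demo_standing_stats():
    layer = myBN(3)
    gain = torch.ones(1, 3, 1, 1)
    bias = torch.zeros(1, 3, 1, 1)
    layer.reset_stats()
    layer.accumulate_standing = True
    layer.train()
    for _ in range(16):
        _ = layer(torch.randn(8, 3, 4, 4), gain, bias)
    layer.eval()  # now normalizes with stored_mean / accumulation_counter
    out = layer(torch.randn(8, 3, 4, 4), gain, bias)
    print(out.shape)  # torch.Size([8, 3, 4, 4])
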
class bn(nn.Module):
    """
    Normal, non-class-conditional BN
    """
        super(bn, self).__init__()
        ## output size
        self.output_size = output_size
        ## Prepare gain and bias layers
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        ## epsilon to avoid dividing by 0
        self.eps = eps
        ## Momentum
        self.momentum = momentum
        ## Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        ## Use my batchnorm?
        self.mybn = mybn
        if self.cross_replica or self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        # Register buffers if neither of the above
        else:
            ## Running mean buffer, updated during training
            self.register_buffer("stored_mean", torch.zeros(output_size))
            ## Running variance buffer, updated during training
            self.register_buffer("stored_var", torch.ones(output_size))
        ## Training mode flag (inherited from nn.Module). True if the module is in training mode.
    def forward(self, x):
        if self.cross_replica or self.mybn:
            gain = self.gain.view(1, -1, 1, 1)
            bias = self.bias.view(1, -1, 1, 1)
            return self.bn(x, gain=gain, bias=bias)

class ccbn(nn.Module):
    """
    Class-conditional batchnorm;
    output size is the number of channels, input size is for the linear layers
    Andy's Note: this class feels messy but I'm not really sure how to clean it up # noqa
    Suggestions welcome! (By which I mean, refactor this and make a merge request
    if you want to make this more readable/usable).
    """
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        ## Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        ## epsilon to avoid dividing by 0
        self.eps = eps
        ## Momentum
        self.momentum = momentum
        ## Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        ## Use my batchnorm?
        self.mybn = mybn
        ## Norm style?
        self.norm_style = norm_style
        if self.mybn or self.cross_replica:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ["bn", "in"]:
            self.register_buffer("stored_mean", torch.zeros(output_size))
            self.register_buffer("stored_var", torch.ones(output_size))
    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # If using my batchnorm
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)
        else:
            if self.norm_style == "bn":
                out = F.batch_norm(
                    x, self.stored_mean, self.stored_var, None, None,
                    self.training, 0.1, self.eps,
                )
            elif self.norm_style == "in":
                out = F.instance_norm(
                    x, self.stored_mean, self.stored_var, None, None,
                    self.training, 0.1, self.eps,
                )
            elif self.norm_style == "gn":
                out = groupnorm(x, self.norm_style)
            elif self.norm_style == "nonorm":
                out = x
            return out * gain + bias
    def extra_repr(self):
        s = "out: {output_size}, in: {input_size},"
        s += " cross_replica={cross_replica}"
        # \cond false positive doxygen warning
        return s.format(**self.__dict__)

class ILA(nn.Module):
    """
    Image_Linear_Attention
    """
        chan_out = chan if chan_out is None else chan_out
        ## key dimension
        self.key_dim = key_dim
        ## value dimension
        self.value_dim = value_dim
        ## number of attention heads
        self.heads = heads
        ## normalize queries?
        self.norm_queries = norm_queries

        conv_kwargs = {"padding": padding, "stride": stride}
        self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)

        out_conv_kwargs = {"padding": padding}
        self.to_out = nn.Conv2d(
            value_dim * heads, chan_out, kernel_size, **out_conv_kwargs
        )
    def forward(self, x, context=None):
        b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads

        q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))

        q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))

        q, k = map(lambda x: x * (self.key_dim**-0.25), (q, k))

        if context is not None:
            context = context.reshape(b, c, 1, -1)
            ck, cv = self.to_k(context), self.to_v(context)
            ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
            k = torch.cat((k, ck), dim=3)
            v = torch.cat((v, cv), dim=3)

        k = k.softmax(dim=-1)

        if self.norm_queries:
            q = q.softmax(dim=-2)

        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhdn,bhde->bhen", q, context)
        out = out.reshape(b, -1, h, w)
        out = self.to_out(out)
        return out

class CBAM_attention(nn.Module):
    """CBAM (channel + spatial) attention block"""
        attention_kernel_size=3,
    ):
        super(CBAM_attention, self).__init__()
        ## average pooling for the channel gate
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        ## max pooling for the channel gate
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        ## channel-gate squeeze conv
        self.fc1 = which_conv(
            channels, channels // reduction, kernel_size=1, padding=0
        )
        ## relu
        self.relu = nn.ReLU(inplace=True)
        ## channel-gate excite conv
        self.fc2 = which_conv(
            channels // reduction, channels, kernel_size=1, padding=0
        )
        ## channel-gate sigmoid
        self.sigmoid_channel = nn.Sigmoid()
        ## convolution after concatenation
        self.conv_after_concat = which_conv(
            2,
            1,
            kernel_size=attention_kernel_size,
            stride=1,
            padding=attention_kernel_size // 2,
        )
        ## spatial-gate sigmoid
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        # Channel attention module
        module_input = x
        avg = self.avg_pool(x)
        mx = self.max_pool(x)
        avg = self.fc2(self.relu(self.fc1(avg)))
        mx = self.fc2(self.relu(self.fc1(mx)))
        x = avg + mx
        x = self.sigmoid_channel(x)
        # Spatial attention module
        x = module_input * x
        module_input = x
        # b, c, h, w = x.size()
        avg = torch.mean(x, 1, True)
        mx, _ = torch.max(x, 1, True)
        x = torch.cat((avg, mx), 1)
        x = self.conv_after_concat(x)
        x = self.sigmoid_spatial(x)
        return module_input * x

class Attention(nn.Module):
    """Self-attention block (SA-GAN style)"""

    def __init__(self, ch, which_conv=SNConv2d):
        super(Attention, self).__init__()
        ## Channel multiplier
        self.ch = ch
        ## which convolution to use
        self.which_conv = which_conv
        self.theta = self.which_conv(
            self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
        )
        self.phi = self.which_conv(
            self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
        )
        self.g = self.which_conv(
            self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
        )
        self.o = self.which_conv(
            self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
        )
        ## Learnable gain parameter
        self.gamma = P(torch.tensor(0.0), requires_grad=True)

    def forward(self, x):
        # Apply convs
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o(
            torch.bmm(g, beta.transpose(1, 2)).view(
                -1, self.ch // 2, x.shape[2], x.shape[3]
            )
        )
        return self.gamma * o + x
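
# Illustrative sketch (our helper name, assuming the default which_conv):
# the self-attention block is shape-preserving and, because gamma starts at 0,
# acts as the identity mapping at initialization.
def _demo_attention():
    attn = Attention(64)
    x = torch.randn(2, 64, 16, 16)
    y = attn(x)
    print(y.shape == x.shape, torch.allclose(y, x))  # True True at init
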

class SNEmbedding(nn.Embedding, SN):
    """
    Embedding layer with spectral norm
    We use num_embeddings as the dim instead of embedding_dim here
    """
        scale_grad_by_freq=False,
        nn.Embedding.__init__(
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        return F.embedding(x, self.W_())

def scaled_dot_product(q, k, v):
    """Scaled dot-product attention"""
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
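
# Illustrative sketch (our helper name): the attention weights returned by
# scaled_dot_product are a softmax over the key axis, so they sum to 1.
def _demo_scaled_dot_product():
    q = torch.randn(2, 4, 10, 8)  # [batch, heads, seq_len, head_dim]
    k = torch.randn(2, 4, 10, 8)
    v = torch.randn(2, 4, 10, 8)
    values, attention = scaled_dot_product(q, k, v)
    print(values.shape)  # torch.Size([2, 4, 10, 8])
    print(torch.allclose(attention.sum(-1), torch.ones(2, 4, 10)))  # True
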

class MultiheadAttention(nn.Module):
    """MultiheadAttention"""

    def __init__(self, input_dim, embed_dim, num_heads, which_linear):
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), "Embedding dimension must be 0 modulo number of heads."

        ## embedding dimension
        self.embed_dim = embed_dim
        ## number of heads
        self.num_heads = num_heads
        ## per-head dimension
        self.head_dim = embed_dim // num_heads
        ## which linear layer to use
        self.which_linear = which_linear

        # Stack all weight matrices 1...h together for efficiency
        self.qkv_proj = self.which_linear(input_dim, 3 * embed_dim)
        ## output projection
        self.o_proj = self.which_linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, return_attention=False):
        batch_size, seq_length, embed_dim = x.size()
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs
        values, attention = scaled_dot_product(q, k, v)
        values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        return o

class EncoderBlock(nn.Module):
    """EncoderBlock"""

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout, which_linear):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__()
        ## which linear layer to use
        self.which_linear = which_linear
        ## Attention layer
        self.self_attn = MultiheadAttention(
            input_dim, input_dim, num_heads, which_linear
        )
        ## Two-layer MLP
        self.linear_net = nn.Sequential(
            self.which_linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            self.which_linear(dim_feedforward, input_dim),
        )
        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Attention part
        x_pre1 = self.norm1(x)
        attn_out = self.self_attn(x_pre1)
        x = x + self.dropout(attn_out)

        # MLP part
        x_pre2 = self.norm2(x)
        linear_out = self.linear_net(x_pre2)
        x = x + self.dropout(linear_out)

        return x

class RelationalReasoning(nn.Module):
    """RelationalReasoning"""

    def __init__(self, num_layers, hidden_dim, **block_args):
        super().__init__()
        ## encoder blocks
        self.layers = nn.ModuleList(
            [EncoderBlock(**block_args) for _ in range(num_layers)]
        )
        ## final layer norm
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return x

    ## get attention maps
    def get_attention_maps(self, x):
        attention_maps = []
        for layer in self.layers:
            _, attn_map = layer.self_attn(x, return_attention=True)
            attention_maps.append(attn_map)
            x = layer(x)
        return attention_maps
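
# Illustrative sketch (our helper name; the layer sizes, head count and dropout
# below are arbitrary, not the training configuration): run the relational
# reasoning module over a batch of 40 proxy embeddings, the way
# Generator.forward feeds it y.unsqueeze(0) further down.
def _demo_relational_reasoning():
    rrm = RelationalReasoning(
        num_layers=1,
        hidden_dim=32,
        input_dim=32,
        num_heads=4,
        dim_feedforward=64,
        dropout=0.0,
        which_linear=nn.Linear,
    )
    y = torch.randn(40, 32)               # 40 class/sensor embeddings
    out = rrm(y.unsqueeze(0)).squeeze(0)  # contextualized embeddings
    print(out.shape)                      # torch.Size([40, 32])
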

class GBlock(nn.Module):
        which_conv=SNConv2d,
        super(GBlock, self).__init__()
        ## input and output channels
        self.in_channels, self.out_channels = in_channels, out_channels
        ## hidden channels
        self.hidden_channels = self.in_channels // channel_ratio
        ## which convolution
        self.which_conv, self.which_bn = which_conv, which_bn
        ## the activation function
        self.activation = activation
        # Conv layers
        self.conv1 = self.which_conv(
            self.in_channels, self.hidden_channels, kernel_size=1, padding=0
        )
        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv4 = self.which_conv(
            self.hidden_channels, self.out_channels, kernel_size=1, padding=0
        )
        # Batchnorm layers
        self.bn1 = self.which_bn(self.in_channels)
        self.bn2 = self.which_bn(self.hidden_channels)
        self.bn3 = self.which_bn(self.hidden_channels)
        self.bn4 = self.which_bn(self.hidden_channels)
        ## upsample layer
        self.upsample = upsample
    def forward(self, x, y):
        # Project down to channel ratio
        h = self.conv1(self.activation(self.bn1(x, y)))
        # Apply next BN-ReLU
        h = self.activation(self.bn2(h, y))
        # Drop channels in x if necessary
        if self.in_channels != self.out_channels:
            x = x[:, : self.out_channels]
        # Upsample both h and x at this point
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        # 3x3 convs
        h = self.conv2(h)
        h = self.conv3(self.activation(self.bn3(h, y)))
        # Final 1x1 conv
        h = self.conv4(self.activation(self.bn4(h, y)))
        return h + x

def G_arch(ch=64, attention="64"):
    arch = {}
    arch[512] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
        "upsample": [True] * 7,
        "resolution": [8, 16, 32, 64, 128, 256, 512],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 10)
        },
    }
    arch[256] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1]],
        "upsample": [True] * 6,
        "resolution": [8, 16, 32, 64, 128, 256],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 9)
        },
    }
    arch[128] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 4, 2, 1]],
        "upsample": [True] * 5,
        "resolution": [8, 16, 32, 64, 128],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 8)
        },
    }
    arch[96] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4]],
        "out_channels": [ch * item for item in [16, 8, 4, 2]],
        "upsample": [True] * 4,
        "resolution": [12, 24, 48, 96],
        "attention": {
            12 * 2**i: (6 * 2**i in [int(item) for item in attention.split("_")])
            for i in range(0, 4)
        },
    }
    arch[64] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4]],
        "out_channels": [ch * item for item in [16, 8, 4, 2]],
        "upsample": [True] * 4,
        "resolution": [8, 16, 32, 64],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 7)
        },
    }
    arch[32] = {
        "in_channels": [ch * item for item in [4, 4, 4]],
        "out_channels": [ch * item for item in [4, 4, 4]],
        "upsample": [True] * 3,
        "resolution": [8, 16, 32],
        "attention": {
            2**i: (2**i in [int(item) for item in attention.split("_")])
            for i in range(3, 6)
        },
    }
    return arch
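
# Illustrative sketch (our helper name): the Generator indexes the returned
# dict by output resolution, e.g. resolution 256 with self-attention at 64x64.
# The printed values follow from the channel multipliers listed above.
def _demo_g_arch():
    arch = G_arch(ch=64, attention="64")[256]
    print(arch["in_channels"])    # [1024, 1024, 512, 512, 256, 128]
    print(arch["out_channels"])   # [1024, 512, 512, 256, 128, 64]
    print(arch["attention"][64])  # True: attention enabled at 64x64
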

class Generator(nn.Module):
    def __init__(
        self,
        cross_replica=False,
        G_activation=nn.ReLU(inplace=False),
        G_mixed_precision=False,
        sched_version="default",
        **kwargs,
    ):
        super(Generator, self).__init__()
        ## Channel width multiplier
        self.ch = G_ch
        ## Number of resblocks per stage
        self.G_depth = G_depth
        ## Dimensionality of the latent space
        self.dim_z = dim_z
        ## The initial spatial dimensions
        self.bottom_width = bottom_width
        ## The initial horizontal dimension
        self.H_base = H_base
        ## Resolution of the output
        self.resolution = resolution
        ## Kernel size?
        self.kernel_size = G_kernel_size
        ## Attention?
        self.attention = G_attn
        ## number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        ## Use shared embeddings?
        self.G_shared = G_shared
        ## Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        ## Hierarchical latent space?
        self.hier = hier
        ## Cross replica batchnorm?
        self.cross_replica = cross_replica
        ## Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        if G_activation == "inplace_relu":
            ## activation function
            self.activation = torch.nn.ReLU(inplace=True)
        elif G_activation == "relu":
            self.activation = torch.nn.ReLU(inplace=False)
        elif G_activation == "leaky_relu":
            self.activation = torch.nn.LeakyReLU(0.2, inplace=False)
        else:
            raise NotImplementedError("activation function not implemented")
        ## Initialization style
        self.init = G_init
        ## Parameterization style
        self.G_param = G_param
        ## Normalization style
        self.norm_style = norm_style
        ## Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        ## Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        ## Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]
        ## Use a relational reasoning module (RRM) on the proxy embeddings?
        self.RRM_prx_G = RRM_prx_G
        ## Number of attention heads in the RRM
        self.n_head_G = n_head_G

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == "SN":
            self.which_conv = functools.partial(
                SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                SNLinear,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear
        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G # noqa
        self.which_embedding = nn.Embedding
        bn_linear = (
            functools.partial(self.which_linear, bias=False)
            if self.G_shared
            else self.which_embedding
        )
        ## which batchnorm
        self.which_bn = functools.partial(
            ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=(
                self.shared_dim + self.dim_z if self.G_shared else self.n_classes
            ),
            norm_style=self.norm_style,
            eps=self.BN_eps,
        )
        ## shared class embedding
        self.shared = (
            self.which_embedding(n_classes, self.shared_dim)
            if G_shared
            else identity()
        )

        ## RRM on proxy embeddings
        self.RR_G = RelationalReasoning(
            dim_feedforward=128,
            which_linear=nn.Linear,
            num_heads=self.n_head_G,
        )

        ## First linear layer
        self.linear = self.which_linear(
            self.dim_z + self.shared_dim,
            self.arch["in_channels"][0] * ((self.bottom_width**2) * self.H_base),
        )
        # self.blocks is a doubly-nested list of modules, the outer loop intended # noqa
        # to be over blocks at a given resolution (resblocks and/or self-attention) # noqa
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    GBlock(
                        in_channels=self.arch["in_channels"][index],
                        out_channels=self.arch["in_channels"][index]
                        if g_index < self.G_depth - 1
                        else self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        which_bn=self.which_bn,
                        activation=self.activation,
                        upsample=(
                            functools.partial(F.interpolate, scale_factor=2)
                            if self.arch["upsample"][index]
                            and g_index == (self.G_depth - 1)
                            else None
                        ),
                    )
                ]
                for g_index in range(self.G_depth)
            ]
            # If attention on this block, attach it to the end
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    f"Adding attention layer in G at resolution {self.arch['resolution'][index]:d}"
                )
                if attn_type == "sa":
                    self.blocks[-1] += [
                        Attention(self.arch["out_channels"][index], self.which_conv)
                    ]
                elif attn_type == "cbam":
                    self.blocks[-1] += [
                        CBAM_attention(
                            self.arch["out_channels"][index], self.which_conv
                        )
                    ]
                elif attn_type == "ila":
                    self.blocks[-1] += [ILA(self.arch["out_channels"][index])]

        # Turn self.blocks into a ModuleList so that it's all properly registered. # noqa
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(
            bn(
                self.arch["out_channels"][-1],
                cross_replica=self.cross_replica,
                mybn=self.mybn,
            ),
            self.activation,
            self.which_conv(self.arch["out_channels"][-1], 1),
        )
        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2 = G_lr, G_B1, G_B2
        self.adam_eps = adam_eps
        if G_mixed_precision:
            print("Using fp16 adam in G...")
            self.optim = utils.Adam16(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                eps=self.adam_eps,
            )
        else:
            self.optim = optim.Adam(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                eps=self.adam_eps,
            )

        # LR scheduling
        if sched_version == "default":
            ## learning rate scheduler
            self.lr_sched = None
        elif sched_version == "CosAnnealLR":
            self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
                self.optim,
                T_max=kwargs["num_epochs"],
                eta_min=self.lr / 4,
            )
        elif sched_version == "CosAnnealWarmRes":
            self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optim, T_0=10, T_mult=2, eta_min=self.lr / 4
            )
        else:
            self.lr_sched = None
    ## Initialize weights
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (
                isinstance(module, nn.Conv2d)
                or isinstance(module, nn.Linear)
                or isinstance(module, nn.Embedding)
            ):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    [p.data.nelement() for p in module.parameters()]
                )
        print(f"Param count for G's initialized parameters: {self.param_count}")
    def forward(self, z, y):
        # If relational embedding
        if self.RRM_prx_G:
            y = self.RR_G(y.unsqueeze(0)).squeeze(0)
        # y = F.normalize(y, dim=1)
        # If hierarchical, concatenate zs and ys
        if self.hier:  # y and z are [bs,128] dimensional
            z = torch.cat([y, z], 1)
            y = z
        # First linear layer
        h = self.linear(z)  # ([bs,256]-->[bs,24576])
        # Reshape
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width * self.H_base)
        # Loop over blocks
        for _, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, y)
        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))

class Model(Generator):
    """Generator model, default initializing with the CONFIG dict"""

    def __init__(self):
        super().__init__(**CONFIG)

def generate(model: nn.Module):
    """
    Run inference with the provided Generator model

    Parameters:
        model (nn.Module): Generator model

    Returns:
        torch.Tensor: batch of 40 PXD images
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        latents = torch.randn(40, 128, device=device)
        labels = torch.tensor(list(range(40)), dtype=torch.long, device=device)
        imgs = model(latents, labels).detach().cpu()
    # Cut the noise below 7 ADU
    imgs = F.threshold(imgs, -0.26, -1)
    # center range [-1, 1] to [0, 1]
    imgs = imgs.mul_(0.5).add_(0.5)
    # renormalize and convert to uint8
    imgs = torch.pow(256, imgs).add_(-1).clamp_(0, 255).to(torch.uint8)
    # flatten channel dimension and crop 256 to 250
    imgs = imgs[:, 0, 3:-3, :]
    return imgs
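
# Illustrative usage sketch (our helper name). It assumes the CONFIG defaults
# at the top of this file are sufficient to build the model and that a CPU run
# is acceptable; it is not part of the original module.
def _demo_generate():
    gen = Model()   # Generator initialized with the CONFIG dict
    gen.eval()
    images = generate(gen)
    print(images.shape, images.dtype)  # 40 uint8 images, 250 pixels high
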