ieagan.py
1
8"""
9This module implements the IEA-GAN generator model.
10"""
11
12import functools
13import json
14import math
15
16import torch
17from torch import nn
18from torch import optim
19from torch.nn import init
20from torch.nn import Parameter as P
21import torch.nn.functional as F
22
23
24CONFIG = json.loads(
25 """{
26 "num_workers": 8,
27 "seed": 415,
28 "pin_memory": false,
29 "shuffle": true,
30 "augment": 0,
31 "use_multiepoch_sampler": false,
32 "model": "BigGAN_deep",
33 "G_ch": 32,
34 "G_param" : "SN",
35 "D_param" : "SN",
36 "D_ch": 32,
37 "G_depth": 2,
38 "D_depth": 2,
39 "H_base": 3,
40 "D_wide": true,
41 "G_shared": true,
42 "shared_dim": 128,
43 "dim_z": 128,
44 "z_var": 1.0,
45 "hier": true,
46 "cross_replica": false,
47 "mybn": false,
48 "G_activation": "inplace_relu",
49 "D_activation": "inplace_relu",
50 "G_attn": "0",
51 "D_attn": "0",
52 "norm_style": "bn",
53 "G_init": "ortho",
54 "D_init": "ortho",
55 "skip_init": false,
56 "G_lr": 5e-05,
57 "D_lr": 5e-05,
58 "G_B1": 0.0,
59 "D_B1": 0.0,
60 "G_B2": 0.999,
61 "D_B2": 0.999,
62 "batch_size": 40,
63 "G_batch_size": 0,
64 "num_G_accumulations": 1,
65 "num_D_steps": 1,
66 "num_D_accumulations": 1,
67 "split_D": true,
68 "num_epochs": 4,
69 "parallel": false,
70 "G_fp16": false,
71 "D_fp16": false,
72 "D_mixed_precision": false,
73 "G_mixed_precision": false,
74 "accumulate_stats": false,
75 "num_standing_accumulations": 16,
76 "G_eval_mode": true,
77 "save_every": 1000,
78 "test_every": 1000,
79 "num_save_copies": 2,
80 "num_best_copies": 2,
81 "ema": true,
82 "ema_decay": 0.9999,
83 "use_ema": true,
84 "ema_start": 10000,
85 "adam_eps": 1e-06,
86 "BN_eps": 1e-05,
87 "SN_eps": 1e-06,
88 "num_G_SVs": 1,
89 "num_D_SVs": 1,
90 "num_G_SV_itrs": 1,
91 "num_D_SV_itrs": 1,
92 "G_ortho": 0.0001,
93 "D_ortho": 0.0,
94 "toggle_grads": true,
95 "logstyle": "%3.3e",
96 "sv_log_interval": 10,
97 "log_interval": 100,
98 "resolution": 256,
99 "n_classes": 40,
100 "run_name": "BGd_140",
101 "resume": false,
102 "latent_op": false,
103 "latent_reg_weight": 300,
104 "bottom_width": 4,
105 "add_blur" : false,
106 "add_noise": true,
107 "add_style": false,
108 "conditional_strategy": "Contra",
109 "hypersphere_dim": 1024,
110 "pos_collected_numerator": false,
111 "nonlinear_embed": false,
112 "normalize_embed": true,
113 "inv_stereographic" :false,
114 "contra_lambda": 1.0,
115 "Angle": false,
116 "angle_lambda": 1.0,
117 "IEA_loss": true,
118 "IEA_lambda": 1.0,
119 "Uniformity_loss": true,
120 "unif_lambda": 0.1,
121 "diff_aug": true,
122 "Con_reg": false,
123 "cr_lambda": 10,
124 "pixel_reg": false,
125 "px_lambda": 1.0,
126 "RRM_prx_G": true,
127 "normalized_proxy_G": false,
128 "RRM_prx_D": false,
129 "RRM_embed": true,
130 "n_head_G": 2,
131 "n_head": 4,
132 "attn_type": "sa",
133 "sched_version": "default",
134 "z_dist": "normal",
135 "truncated_threshold": 1.0,
136 "clip_norm": "None",
137 "amsgrad": false,
138 "arch": "None",
139 "G_kernel_size": 3,
140 "D_kernel_size": 3,
141 "ada_belief": false,
142 "pbar": "tqdm",
143 "which_best": "FID",
144 "stop_after": 100000,
145 "trunc_z": 0.5,
146 "denoise": false,
147 "metric_log_name": "metric_log.jsonl",
148 "reinitialize_metric_logs": false,
149 "reinitialize_parameter_logs": false,
150 "num_incep_images": 16000,
151 "load_optim": true}"""
152)
153
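# Illustrative sketch (not part of the module API): CONFIG is consumed wholesale by
# the Model class at the bottom of this file, which forwards it to Generator.  A
# hypothetical helper for building a generator with a tweaked configuration:
def _example_custom_config(dim_z=64):
    cfg = dict(CONFIG)   # shallow copy is enough; CONFIG only holds scalars/strings
    cfg["dim_z"] = dim_z
    return cfg           # usable as Generator(**cfg)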
154
155def proj(x, y):
156 """
157 Projection of x onto y
158 """
159 return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
160
161
162def gram_schmidt(x, ys):
163 """
164 Orthogonalize x wrt list of vectors ys
165 """
166 for y in ys:
167 x = x - proj(x, y)
168 return x
169
170
171def power_iteration(W, u_, update=True, eps=1e-12):
172 """
 173 Apply one step of the power method per singular vector in u_ to estimate the top singular values of W.
174 """
175 # Lists holding singular vectors and values
176 us, vs, svs = [], [], []
177 for i, u in enumerate(u_):
178 # Run one step of the power iteration
179 with torch.no_grad():
180 v = torch.matmul(u, W)
181 # Run Gram-Schmidt to subtract components of all other singular vectors # noqa
182 v = F.normalize(gram_schmidt(v, vs), eps=eps)
183 # Add to the list
184 vs += [v]
185 # Update the other singular vector
186 u = torch.matmul(v, W.t())
187 # Run Gram-Schmidt to subtract components of all other singular vectors # noqa
188 u = F.normalize(gram_schmidt(u, us), eps=eps)
189 # Add to the list
190 us += [u]
191 if update:
192 u_[i][:] = u
193 # Compute this singular value and add it to the list
194 svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
195 # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
196 return svs, us, vs
197
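# Illustrative sketch (assumes torch.linalg.svdvals, available in recent PyTorch):
# repeated calls to power_iteration converge towards the largest singular value of W.
def _power_iteration_example(steps=50):
    W = torch.randn(64, 128)
    u = [torch.randn(1, 64)]   # one persistent estimate of the top left-singular vector
    for _ in range(steps):
        svs, _, _ = power_iteration(W, u, update=True)
    return float(svs[0]), float(torch.linalg.svdvals(W)[0])   # the two values nearly match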
198
199def groupnorm(x, norm_style):
200 """
 201 Apply group normalization to x, deriving the number of groups from the norm_style string
202 """
203 # If number of channels specified in norm_style:
204 if "ch" in norm_style:
205 ch = int(norm_style.split("_")[-1])
206 groups = max(int(x.shape[1]) // ch, 1)
207 # If number of groups specified in norm style
208 elif "grp" in norm_style:
209 groups = int(norm_style.split("_")[-1])
210 # If neither, default to groups = 16
211 else:
212 groups = 16
213 return F.group_norm(x, groups)
214
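# Examples of the recognised norm_style spellings (sketch):
#   groupnorm(x, "ch_16")  -> groups chosen so that each group has roughly 16 channels
#   groupnorm(x, "grp_8")  -> exactly 8 groups
#   groupnorm(x, "gn")     -> neither pattern matches, fall back to 16 groups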
215
216class identity(nn.Module):
217 """
 218 Convenience passthrough (identity) module
219 """
220
221
222 def forward(self, tensor: torch.Tensor):
223 return tensor
224
225
226class SN(object):
227 """
228 Spectral normalization base class
229 """
230 # pylint: disable=no-member
231
232 def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
233 """constructor"""
234
235 ## Number of power iterations per step
236 self.num_itrs = num_itrs
237 ## Number of singular values
238 self.num_svs = num_svs
239 ## Transposed?
240 self.transpose = transpose
241 ## Epsilon value for avoiding divide-by-0
242 self.eps = eps
243 # Register a singular vector for each sv
244 for i in range(self.num_svs):
245 self.register_buffer(f"u{i:d}", torch.randn(1, num_outputs))
246 self.register_buffer(f"sv{i:d}", torch.ones(1))
247
248 @property
249 def u(self):
250 """
251 Singular vectors (u side)
252 """
253 return [getattr(self, f"u{i:d}") for i in range(self.num_svs)]
254
255 @property
256 def sv(self):
257 """
258 Singular values
259 note that these buffers are just for logging and are not used in training.
260 """
261 return [getattr(self, f"sv{i:d}") for i in range(self.num_svs)]
262
263 def W_(self):
264 """
265 Compute the spectrally-normalized weight
266 """
267 W_mat = self.weight.view(self.weight.size(0), -1)
268 if self.transpose:
269 W_mat = W_mat.t()
270 # Apply num_itrs power iterations
271 for _ in range(self.num_itrs):
272 svs, _, _ = power_iteration(
273 W_mat, self.u, update=self.training, eps=self.eps
274 )
275 # Update the svs
276 if self.training:
277 # Make sure to do this in a no_grad() context or you'll get memory leaks! # noqa
278 with torch.no_grad():
279 for i, sv in enumerate(svs):
280 self.sv[i][:] = sv
281 return self.weight / svs[0]
282
283
284class SNConv2d(nn.Conv2d, SN):
285 """
286 2D Conv layer with spectral norm
287 """
288
289 ## Constructor
290 def __init__(
291 self,
292 in_channels,
293 out_channels,
294 kernel_size,
295 stride=1,
296 padding=0,
297 dilation=1,
298 groups=1,
299 bias=True,
300 num_svs=1,
301 num_itrs=1,
302 eps=1e-12,
303 ):
304 nn.Conv2d.__init__(
305 self,
306 in_channels,
307 out_channels,
308 kernel_size,
309 stride,
310 padding,
311 dilation,
312 groups,
313 bias,
314 )
315 SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
316
317 ## forward
318 def forward(self, x):
319 return F.conv2d(
320 x,
321 self.W_(),
322 self.bias,
323 self.stride,
324 self.padding,
325 self.dilation,
326 self.groups,
327 )
328
329
330class SNLinear(nn.Linear, SN):
331 """
332 Linear layer with spectral norm
333 """
334
335 ## Constructor
336 def __init__(
337 self,
338 in_features,
339 out_features,
340 bias=True,
341 num_svs=1,
342 num_itrs=1,
343 eps=1e-12,
344 ):
345 nn.Linear.__init__(self, in_features, out_features, bias)
346 SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
347
348 ## forward
349 def forward(self, x):
350 return F.linear(x, self.W_(), self.bias)
351
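# Illustrative sketch: in training mode every forward pass runs one power-iteration
# step, so the top singular value of the effective weight W_() is driven towards 1
# (assumes torch.linalg.svdvals for the check).
def _sn_linear_example(steps=100):
    layer = SNLinear(32, 16)
    layer.train()
    for _ in range(steps):
        layer(torch.randn(4, 32))
    return float(torch.linalg.svdvals(layer.W_())[0])   # approaches 1.0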
352
353def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
354 """Fused batchnorm op"""
355
356 # Apply scale and shift--if gain and bias are provided, fuse them here
357 # Prepare scale
358 scale = torch.rsqrt(var + eps)
359 # If a gain is provided, use it
360 if gain is not None:
361 scale = scale * gain
362 # Prepare shift
363 shift = mean * scale
364 # If bias is provided, use it
365 if bias is not None:
366 shift = shift - bias
367 return x * scale - shift
368 # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. # noqa
369
370
371def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
372 """
373 Manual BN
374 Calculate means and variances using mean-of-squares minus mean-squared
375 """
376
377 # Cast x to float32 if necessary
378 float_x = x.float()
379 # Calculate expected value of x (m) and expected value of x**2 (m2)
380 # Mean of x
381 m = torch.mean(float_x, [0, 2, 3], keepdim=True)
382 # Mean of x squared
383 m2 = torch.mean(float_x**2, [0, 2, 3], keepdim=True)
384 # Calculate variance as mean of squared minus mean squared.
385 var = m2 - m**2
386 # Cast back to float 16 if necessary
387 var = var.type(x.type())
388 m = m.type(x.type())
389 # Return mean and variance for updating stored mean/var if requested
390 if return_mean_var:
391 return (
392 fused_bn(x, m, var, gain, bias, eps),
393 m.squeeze(),
394 var.squeeze(),
395 )
396 else:
397 return fused_bn(x, m, var, gain, bias, eps)
398
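# Note on the algebra above (sketch): with scale = gain / sqrt(var + eps) and
# shift = mean * scale - bias, the fused expression x * scale - shift equals the
# usual (x - mean) / sqrt(var + eps) * gain + bias, and manual_bn obtains the
# variance from Var[x] = E[x**2] - E[x]**2.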
399
400class myBN(nn.Module):
401 """
402 My batchnorm, supports standing stats
403 """
404
405 ## Constructor
406 def __init__(self, num_channels, eps=1e-5, momentum=0.1):
407 super(myBN, self).__init__()
408 ## momentum for updating running stats
409 self.momentum = momentum
410 ## epsilon to avoid dividing by 0
411 self.eps = eps
412 # Momentum
413 self.momentum = momentum
414 # Register buffers
415 self.register_buffer("stored_mean", torch.zeros(num_channels))
416 self.register_buffer("stored_var", torch.ones(num_channels))
417 self.register_buffer("accumulation_counter", torch.zeros(1))
418 ## Accumulate running means and vars
419 self.accumulate_standing = False
420
421 ## reset standing stats
422 def reset_stats(self):
423 # pylint: disable=no-member
424 self.stored_mean[:] = 0
425 self.stored_var[:] = 0
426 self.accumulation_counter[:] = 0
427
428 ## forward
429 def forward(self, x, gain, bias):
430 # pylint: disable=no-member
431 if self.training:
432 out, mean, var = manual_bn(
433 x, gain, bias, return_mean_var=True, eps=self.eps
434 )
435 # If accumulating standing stats, increment them
436 if self.accumulate_standing:
437 self.stored_mean[:] = self.stored_mean + mean.data
438 self.stored_var[:] = self.stored_var + var.data
439 self.accumulation_counter += 1.0
440 # If not accumulating standing stats, take running averages
441 else:
442 self.stored_mean[:] = (
443 self.stored_mean * (1 - self.momentum) + mean * self.momentum
444 )
445 self.stored_var[:] = (
446 self.stored_var * (1 - self.momentum) + var * self.momentum
447 )
448 return out
449 # If not in training mode, use the stored statistics
450 else:
451 mean = self.stored_mean.view(1, -1, 1, 1)
452 var = self.stored_var.view(1, -1, 1, 1)
453 # If using standing stats, divide them by the accumulation counter
454 if self.accumulate_standing:
455 mean = mean / self.accumulation_counter
456 var = var / self.accumulation_counter
457 return fused_bn(x, mean, var, gain, bias, self.eps)
458
459
460class bn(nn.Module):
461 """
462 Normal, non-class-conditional BN
463 """
464
465 ## Constructor
466 def __init__(
467 self,
468 output_size,
469 eps=1e-5,
470 momentum=0.1,
471 cross_replica=False,
472 mybn=False,
473 ):
474 super(bn, self).__init__()
475 ## output size
476 self.output_size = output_size
477 ## Prepare gain and bias layers
478 self.gain = P(torch.ones(output_size), requires_grad=True)
479 ## bias
480 self.bias = P(torch.zeros(output_size), requires_grad=True)
481 ## epsilon to avoid dividing by 0
482 self.eps = eps
483 ## Momentum
484 self.momentum = momentum
485 ## Use cross-replica batchnorm?
486 self.cross_replica = cross_replica
487 ## Use my batchnorm?
488 self.mybn = mybn
489
490 if mybn:
491 self.bn = myBN(output_size, self.eps, self.momentum)
492 # Register buffers if neither of the above
493 else:
494 self.register_buffer("stored_mean", torch.zeros(output_size))
495 self.register_buffer("stored_var", torch.ones(output_size))
496
497 ## forward
498 def forward(self, x):
499 if self.mybn:
500 gain = self.gain.view(1, -1, 1, 1)
501 bias = self.bias.view(1, -1, 1, 1)
502 return self.bn(x, gain=gain, bias=bias)
503 else:
504 return F.batch_norm(
505 x,
506 self.stored_mean,
507 self.stored_var,
508 self.gain,
509 self.bias,
510 self.training,
511 self.momentum,
512 self.eps,
513 )
514
515
516class ccbn(nn.Module):
517 """
518 Class-conditional bn
 519 output_size is the number of channels; input_size is the dimension of the conditioning vector fed to the linear gain/bias layers.
523 """
524
525 ## Constructor
526 def __init__(
527 self,
528 output_size,
529 input_size,
530 which_linear,
531 eps=1e-5,
532 momentum=0.1,
533 cross_replica=False,
534 mybn=False,
535 norm_style="bn",
536 ):
537 super(ccbn, self).__init__()
538 ## output size
539 self.output_size, self.input_size = output_size, input_size
540 ## Prepare gain and bias layers
541 self.gain = which_linear(input_size, output_size)
542 ## bias
543 self.bias = which_linear(input_size, output_size)
544 ## epsilon to avoid dividing by 0
545 self.eps = eps
546 ## Momentum
547 self.momentum = momentum
548 ## Use cross-replica batchnorm?
549 self.cross_replica = cross_replica
550 ## Use my batchnorm?
551 self.mybn = mybn
552 ## Norm style?
553 self.norm_style = norm_style
554
555 if self.mybn:
556 ## bn
557 self.bn = myBN(output_size, self.eps, self.momentum)
558 elif self.norm_style in ["bn", "in"]:
559 self.register_buffer("stored_mean", torch.zeros(output_size))
560 self.register_buffer("stored_var", torch.ones(output_size))
561
562 ## forward
563 def forward(self, x, y):
564 # Calculate class-conditional gains and biases
565 gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
566 bias = self.bias(y).view(y.size(0), -1, 1, 1)
567 # If using my batchnorm
568 if self.mybn:
569 return self.bn(x, gain=gain, bias=bias)
570 # else:
571 else:
572 if self.norm_style == "bn":
573 out = F.batch_norm(
574 x,
575 self.stored_mean,
576 self.stored_var,
577 None,
578 None,
579 self.training,
580 0.1,
581 self.eps,
582 )
583 elif self.norm_style == "in":
584 out = F.instance_norm(
585 x,
586 self.stored_mean,
587 self.stored_var,
588 None,
589 None,
590 self.training,
591 0.1,
592 self.eps,
593 )
594 elif self.norm_style == "gn":
 595 out = groupnorm(x, self.norm_style)
596 elif self.norm_style == "nonorm":
597 out = x
598 return out * gain + bias
599
600 ## extra_repr
601 def extra_repr(self):
602 s = "out: {output_size}, in: {input_size},"
603 s += " cross_replica={cross_replica}"
604 return s.format(**self.__dict__)
605
606
607class ILA(nn.Module):
608 """
 609 Image Linear Attention (linear-complexity attention over spatial positions)
610 """
611
612 ## Constructor
613 def __init__(
614 self,
615 chan,
616 chan_out=None,
617 kernel_size=1,
618 padding=0,
619 stride=1,
620 key_dim=32,
621 value_dim=64,
622 heads=8,
623 norm_queries=True,
624 ):
625 super().__init__()
626 ## chan
627 self.chan = chan
628 chan_out = chan if chan_out is None else chan_out
629
630 ## key dimension
631 self.key_dim = key_dim
632 ## value dimension
633 self.value_dim = value_dim
634 ## heads
635 self.heads = heads
636
637 ## norm queries
638 self.norm_queries = norm_queries
639
640 conv_kwargs = {"padding": padding, "stride": stride}
641 ## q
642 self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
643 ## k
644 self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
645 ## v
646 self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)
647
648 out_conv_kwargs = {"padding": padding}
649 ## to out
650 self.to_out = nn.Conv2d(
651 value_dim * heads, chan_out, kernel_size, **out_conv_kwargs
652 )
653
654 ## forward
655 def forward(self, x, context=None):
656 b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads
657
658 q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))
659
660 q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))
661
662 q, k = map(lambda x: x * (self.key_dim**-0.25), (q, k))
663
664 if context is not None:
665 context = context.reshape(b, c, 1, -1)
666 ck, cv = self.to_k(context), self.to_v(context)
667 ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
668 k = torch.cat((k, ck), dim=3)
669 v = torch.cat((v, cv), dim=3)
670
671 k = k.softmax(dim=-1)
672
673 if self.norm_queries:
674 q = q.softmax(dim=-2)
675
676 context = torch.einsum("bhdn,bhen->bhde", k, v)
677 out = torch.einsum("bhdn,bhde->bhen", q, context)
678 out = out.reshape(b, -1, h, w)
679 out = self.to_out(out)
680 return out
681
682
683class CBAM_attention(nn.Module):
684 """CBAM attention"""
685
686 ## Constructor
687 def __init__(
688 self,
689 channels,
690 which_conv=SNConv2d,
691 reduction=8,
692 attention_kernel_size=3,
693 ):
694 super(CBAM_attention, self).__init__()
695 ## average pooling
696 self.avg_pool = nn.AdaptiveAvgPool2d(1)
697 ## max pooling
698 self.max_pool = nn.AdaptiveMaxPool2d(1)
699 ## fcl
700 self.fc1 = which_conv(
701 channels, channels // reduction, kernel_size=1, padding=0
702 )
703 ## relu
704 self.relu = nn.ReLU(inplace=True)
705 ## f2c
706 self.fc2 = which_conv(
707 channels // reduction, channels, kernel_size=1, padding=0
708 )
709 ## sigmoid channel
710 self.sigmoid_channel = nn.Sigmoid()
711 ## convolution after concatenation
712 self.conv_after_concat = which_conv(
713 2,
714 1,
715 kernel_size=attention_kernel_size,
716 stride=1,
717 padding=attention_kernel_size // 2,
718 )
719 ## sigmoid_spatial
720 self.sigmoid_spatial = nn.Sigmoid()
721
722 ## forward
723 def forward(self, x):
724 # Channel attention module
725 module_input = x
726 avg = self.avg_pool(x)
727 mx = self.max_pool(x)
728 avg = self.fc1(avg)
729 mx = self.fc1(mx)
730 avg = self.relu(avg)
731 mx = self.relu(mx)
732 avg = self.fc2(avg)
733 mx = self.fc2(mx)
734 x = avg + mx
735 x = self.sigmoid_channel(x)
736 # Spatial attention module
737 x = module_input * x
738 module_input = x
739 # b, c, h, w = x.size()
740 avg = torch.mean(x, 1, True)
741 mx, _ = torch.max(x, 1, True)
742 x = torch.cat((avg, mx), 1)
743 x = self.conv_after_concat(x)
744 x = self.sigmoid_spatial(x)
745 x = module_input * x
746 return x
747
748
749class Attention(nn.Module):
750 """Attention"""
751
752 ## Constructor
753 def __init__(self, ch, which_conv=SNConv2d):
754 super(Attention, self).__init__()
755 ## Channel multiplier
756 self.ch = ch
757 ## which_conv
758 self.which_conv = which_conv
759 ## theta
760 self.theta = self.which_conv(
761 self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
762 )
763 ## phi
764 self.phi = self.which_conv(
765 self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
766 )
767 ## g
768 self.g = self.which_conv(
769 self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
770 )
771 ## o
772 self.o = self.which_conv(
773 self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
774 )
775 ## Learnable gain parameter
776 self.gamma = P(torch.tensor(0.0), requires_grad=True)
777
778 ## forward
779 def forward(self, x):
780 # Apply convs
781 theta = self.theta(x)
782 phi = F.max_pool2d(self.phi(x), [2, 2])
783 g = F.max_pool2d(self.g(x), [2, 2])
784 # Perform reshapes
785 theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
786 phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
787 g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
788 # Matmul and softmax to get attention maps
789 beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
790 # Attention map times g path
791 o = self.o(
792 torch.bmm(g, beta.transpose(1, 2)).view(
793 -1, self.ch // 2, x.shape[2], x.shape[3]
794 )
795 )
796 return self.gamma * o + x
797
798
799class SNEmbedding(nn.Embedding, SN):
800 """
801 Embedding layer with spectral norm
802 We use num_embeddings as the dim instead of embedding_dim here
 803 for convenience's sake.
804 """
805
806 ## Constructor
807 def __init__(
808 self,
809 num_embeddings,
810 embedding_dim,
811 padding_idx=None,
812 max_norm=None,
813 norm_type=2,
814 scale_grad_by_freq=False,
815 sparse=False,
816 _weight=None,
817 num_svs=1,
818 num_itrs=1,
819 eps=1e-12,
820 ):
821 nn.Embedding.__init__(
822 self,
823 num_embeddings,
824 embedding_dim,
825 padding_idx,
826 max_norm,
827 norm_type,
828 scale_grad_by_freq,
829 sparse,
830 _weight,
831 )
832 SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
833
834 ## forward
835 def forward(self, x):
836 return F.embedding(x, self.W_())
837
838
839def scaled_dot_product(q, k, v):
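    """
    Plain scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.
    Returns the attended values and the attention weights.
    """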
840 d_k = q.size()[-1]
841 attn_logits = torch.matmul(q, k.transpose(-2, -1))
842 attn_logits = attn_logits / math.sqrt(d_k)
843 attention = F.softmax(attn_logits, dim=-1)
844 values = torch.matmul(attention, v)
845 return values, attention
846
847
848class MultiheadAttention(nn.Module):
849 """MultiheadAttention"""
850
851 ## Constructor
852 def __init__(self, input_dim, embed_dim, num_heads, which_linear):
853 super().__init__()
854 assert (
855 embed_dim % num_heads == 0
856 ), "Embedding dimension must be 0 modulo number of heads."
857
858 ## embedding dimension
859 self.embed_dim = embed_dim
860 ## number of heads
861 self.num_heads = num_heads
862 ## head dimension
863 self.head_dim = embed_dim // num_heads
864 ## which linear
865 self.which_linear = which_linear
866
867 # Stack all weight matrices 1...h together for efficiency
868 ## qkv projection
869 self.qkv_proj = self.which_linear(input_dim, 3 * embed_dim)
870 ## o projection
871 self.o_proj = self.which_linear(embed_dim, embed_dim)
872
873 self._reset_parameters()
874
875 ## reset parameters
876 def _reset_parameters(self):
877 # Original Transformer initialization, see PyTorch documentation
878 nn.init.xavier_uniform_(self.qkv_proj.weight)
879 self.qkv_proj.bias.data.fill_(0)
880 nn.init.xavier_uniform_(self.o_proj.weight)
881 self.o_proj.bias.data.fill_(0)
882
883 ## forward
884 def forward(self, x, return_attention=False):
885 batch_size, seq_length, embed_dim = x.size()
886 qkv = self.qkv_proj(x)
887
888 # Separate Q, K, V from linear output
889 qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
890 qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
891 q, k, v = qkv.chunk(3, dim=-1)
892
893 # Determine value outputs
894 values, attention = scaled_dot_product(q, k, v)
895 values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
896 values = values.reshape(batch_size, seq_length, embed_dim)
897 o = self.o_proj(values)
898
899 if return_attention:
900 return o, attention
901 else:
902 return o
903
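# Illustrative sketch of the shape contract (hypothetical sizes, plain nn.Linear
# instead of the spectrally normalised layers used in the real model):
def _mha_shape_example():
    mha = MultiheadAttention(input_dim=128, embed_dim=128, num_heads=4,
                             which_linear=nn.Linear)
    x = torch.randn(1, 40, 128)        # [batch, sequence of 40 class tokens, dim]
    out, attn = mha(x, return_attention=True)
    return out.shape, attn.shape       # (1, 40, 128) and (1, 4, 40, 40)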
904
905class EncoderBlock(nn.Module):
906 """EncoderBlock"""
907
908 ## Constructor
909 def __init__(self, input_dim, num_heads, dim_feedforward, dropout, which_linear):
910 """
911 Inputs:
912 input_dim - Dimensionality of the input
913 num_heads - Number of heads to use in the attention block
914 dim_feedforward - Dimensionality of the hidden layer in the MLP
915 dropout - Dropout probability to use in the dropout layers
916 """
917 super().__init__()
918
919 ## which linear
920 self.which_linear = which_linear
921 ## Attention layer
922 self.self_attn = MultiheadAttention(
923 input_dim, input_dim, num_heads, which_linear
924 )
925
926 ## Two-layer MLP
927 self.linear_net = nn.Sequential(
928 self.which_linear(input_dim, dim_feedforward),
929 nn.Dropout(dropout),
930 nn.ReLU(inplace=True),
931 self.which_linear(dim_feedforward, input_dim),
932 )
933
934 # Layers to apply in between the main layers
935 ## norm1
936 self.norm1 = nn.LayerNorm(input_dim)
937 ## norm2
938 self.norm2 = nn.LayerNorm(input_dim)
939 ## dropout
940 self.dropout = nn.Dropout(dropout)
941
942 ## forward
943 def forward(self, x):
944 # Attention part
945 x_pre1 = self.norm1(x)
946 attn_out = self.self_attn(x_pre1)
947 x = x + self.dropout(attn_out)
948 # x = self.norm1(x)
949
950 # MLP part
951 x_pre2 = self.norm2(x)
952 linear_out = self.linear_net(x_pre2)
953 x = x + self.dropout(linear_out)
954 # x = self.norm2(x)
955
956 return x
957
958
959class RelationalReasoning(nn.Module):
960 """RelationalReasoning"""
961
962 ## Constructor
963 def __init__(self, num_layers, hidden_dim, **block_args):
964 super().__init__()
965 ## layers
966 self.layers = nn.ModuleList(
967 [EncoderBlock(**block_args) for _ in range(num_layers)]
968 )
969 ## normalization
970 self.norm = nn.LayerNorm(hidden_dim)
971
972 ## forward
973 def forward(self, x):
974 for layer in self.layers:
975 x = layer(x)
976
977 x = self.norm(x)
978 return x
979
980 ## get attention maps
981 def get_attention_maps(self, x):
982 attention_maps = []
983 for layer in self.layers:
984 _, attn_map = layer.self_attn(x, return_attention=True)
985 attention_maps.append(attn_map)
986 x = layer(x)
987 return attention_maps
988
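# Illustrative sketch: the generator below feeds the 40 class embeddings through
# RelationalReasoning as a single sequence, so every proxy embedding can attend to
# all others before conditioning the batch-norm layers.
def _rrm_example():
    rrm = RelationalReasoning(num_layers=1, hidden_dim=128, input_dim=128,
                              num_heads=2, dim_feedforward=128, dropout=0.0,
                              which_linear=nn.Linear)
    proxies = torch.randn(40, 128)                 # one embedding per class
    refined = rrm(proxies.unsqueeze(0)).squeeze(0)
    return refined.shape                           # torch.Size([40, 128])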
989
990class GBlock(nn.Module):
991 """GBlock"""
992
993 ## Constructor
994 def __init__(
995 self,
996 in_channels,
997 out_channels,
998 which_conv=SNConv2d,
999 which_bn=bn,
1000 activation=None,
1001 upsample=None,
1002 channel_ratio=4,
1003 ):
1004 super(GBlock, self).__init__()
1005
1006 ## input channels
1007 self.in_channels, self.out_channels = in_channels, out_channels
1008 ## hidden channels
1009 self.hidden_channels = self.in_channels // channel_ratio
1010 ## which convolution
1011 self.which_conv, self.which_bn = which_conv, which_bn
1012 ## activation
1013 self.activation = activation
1014 # Conv layers
1015 ## conv1
1016 self.conv1 = self.which_conv(
1017 self.in_channels, self.hidden_channels, kernel_size=1, padding=0
1018 )
1019 ## conv2
1020 self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
1021 ## conv3
1022 self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
1023 ## conv4
1024 self.conv4 = self.which_conv(
1025 self.hidden_channels, self.out_channels, kernel_size=1, padding=0
1026 )
1027 # Batchnorm layers
1028 ## bn1
1029 self.bn1 = self.which_bn(self.in_channels)
1030 ## bn2
1031 self.bn2 = self.which_bn(self.hidden_channels)
1032 ## bn3
1033 self.bn3 = self.which_bn(self.hidden_channels)
1034 ## bn4
1035 self.bn4 = self.which_bn(self.hidden_channels)
1036 ## upsample layers
1037 self.upsample = upsample
1038
1039 ## forward
1040 def forward(self, x, y):
1041 # Project down to channel ratio
1042 h = self.conv1(self.activation(self.bn1(x, y)))
1043 # Apply next BN-ReLU
1044 h = self.activation(self.bn2(h, y))
1045 # Drop channels in x if necessary
1046 if self.in_channels != self.out_channels:
1047 x = x[:, : self.out_channels]
1048 # Upsample both h and x at this point
1049 if self.upsample:
1050 h = self.upsample(h)
1051 x = self.upsample(x)
1052 # 3x3 convs
1053 h = self.conv2(h)
1054 h = self.conv3(self.activation(self.bn3(h, y)))
1055 # Final 1x1 conv
1056 h = self.conv4(self.activation(self.bn4(h, y)))
1057 return h + x
1058
1059
1060def G_arch(ch=64, attention="64"):
1061 arch = {}
1062 arch[512] = {
1063 "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
1064 "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
1065 "upsample": [True] * 7,
1066 "resolution": [8, 16, 32, 64, 128, 256, 512],
1067 "attention": {
1068 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1069 for i in range(3, 10)
1070 },
1071 }
1072 arch[256] = {
1073 "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2]],
1074 "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1]],
1075 "upsample": [True] * 6,
1076 "resolution": [8, 16, 32, 64, 128, 256],
1077 "attention": {
1078 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1079 for i in range(3, 9)
1080 },
1081 }
1082 arch[128] = {
1083 "in_channels": [ch * item for item in [16, 16, 8, 4, 2]],
1084 "out_channels": [ch * item for item in [16, 8, 4, 2, 1]],
1085 "upsample": [True] * 5,
1086 "resolution": [8, 16, 32, 64, 128],
1087 "attention": {
1088 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1089 for i in range(3, 8)
1090 },
1091 }
1092 arch[96] = {
1093 "in_channels": [ch * item for item in [16, 16, 8, 4]],
1094 "out_channels": [ch * item for item in [16, 8, 4, 2]],
1095 "upsample": [True] * 4,
1096 "resolution": [12, 24, 48, 96],
1097 "attention": {
1098 12 * 2**i: (6 * 2 ** i in [int(item) for item in attention.split("_")])
1099 for i in range(0, 4)
1100 },
1101 }
1102
1103 arch[64] = {
1104 "in_channels": [ch * item for item in [16, 16, 8, 4]],
1105 "out_channels": [ch * item for item in [16, 8, 4, 2]],
1106 "upsample": [True] * 4,
1107 "resolution": [8, 16, 32, 64],
1108 "attention": {
1109 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1110 for i in range(3, 7)
1111 },
1112 }
1113 arch[32] = {
1114 "in_channels": [ch * item for item in [4, 4, 4]],
1115 "out_channels": [ch * item for item in [4, 4, 4]],
1116 "upsample": [True] * 3,
1117 "resolution": [8, 16, 32],
1118 "attention": {
1119 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1120 for i in range(3, 6)
1121 },
1122 }
1123
1124 return arch
1125
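# Sketch of the entry actually used with CONFIG above (G_ch=32, resolution=256,
# G_attn="0"): G_arch(32, "0")[256] describes six upsampling stages from 8x8 to
# 256x256 with output channel widths falling from 512 to 32, and "0" matches no
# resolution, so attention is disabled at every stage.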
1126
1127class Generator(nn.Module):
1128 """Generator"""
1129
1130 ## Constructor
1131 def __init__(
1132 self,
1133 G_ch=64,
1134 G_depth=2,
1135 dim_z=128,
1136 bottom_width=4,
1137 resolution=256,
1138 G_kernel_size=3,
1139 G_attn="64",
1140 n_classes=40,
1141 H_base=1,
1142 num_G_SVs=1,
1143 num_G_SV_itrs=1,
1144 attn_type="sa",
1145 G_shared=True,
1146 shared_dim=128,
1147 hier=True,
1148 cross_replica=False,
1149 mybn=False,
1150 G_activation=nn.ReLU(inplace=False),
1151 G_lr=5e-5,
1152 G_B1=0.0,
1153 G_B2=0.999,
1154 adam_eps=1e-8,
1155 BN_eps=1e-5,
1156 SN_eps=1e-12,
1157 G_init="ortho",
1158 G_mixed_precision=False,
1159 G_fp16=False,
1160 skip_init=False,
1161 no_optim=False,
1162 sched_version="default",
1163 RRM_prx_G=True,
1164 n_head_G=2,
1165 G_param="SN",
1166 norm_style="bn",
1167 **kwargs
1168 ):
1169 super(Generator, self).__init__()
1170 ## Channel width multiplier
1171 self.ch = G_ch
1172 ## Number of resblocks per stage
1173 self.G_depth = G_depth
1174 ## Dimensionality of the latent space
1175 self.dim_z = dim_z
1176 ## The initial spatial dimensions
1177 self.bottom_width = bottom_width
 1178 ## The initial horizontal dimension
1179 self.H_base = H_base
1180 ## Resolution of the output
1181 self.resolution = resolution
1182 ## Kernel size?
1183 self.kernel_size = G_kernel_size
1184 ## Attention?
1185 self.attention = G_attn
1186 ## number of classes, for use in categorical conditional generation
1187 self.n_classes = n_classes
1188 ## Use shared embeddings?
1189 self.G_shared = G_shared
1190 ## Dimensionality of the shared embedding? Unused if not using G_shared
1191 self.shared_dim = shared_dim if shared_dim > 0 else dim_z
1192 ## Hierarchical latent space?
1193 self.hier = hier
1194 ## Cross replica batchnorm?
1195 self.cross_replica = cross_replica
1196 ## Use my batchnorm?
1197 self.mybn = mybn
1198 # nonlinearity for residual blocks
1199 if G_activation == "inplace_relu":
1200 ## activation
1201 self.activation = torch.nn.ReLU(inplace=True)
1202 elif G_activation == "relu":
1203 self.activation = torch.nn.ReLU(inplace=False)
1204 elif G_activation == "leaky_relu":
1205 self.activation = torch.nn.LeakyReLU(0.2, inplace=False)
1206 else:
1207 raise NotImplementedError("activation function not implemented")
1208 ## Initialization style
1209 self.init = G_init
1210 ## Parameterization style
1211 self.G_param = G_param
1212 ## Normalization style
1213 self.norm_style = norm_style
1214 ## Epsilon for BatchNorm?
1215 self.BN_eps = BN_eps
1216 ## Epsilon for Spectral Norm?
1217 self.SN_eps = SN_eps
1218 ## fp16?
1219 self.fp16 = G_fp16
1220 ## Architecture dict
1221 self.arch = G_arch(self.ch, self.attention)[resolution]
1222 ## RRM_prx_G
1223 self.RRM_prx_G = RRM_prx_G
1224 ## n_head_G
1225 self.n_head_G = n_head_G
1226
1227 # Which convs, batchnorms, and linear layers to use
1228 if self.G_param == "SN":
1229 ## which conv
1230 self.which_conv = functools.partial(
1231 SNConv2d,
1232 kernel_size=3,
1233 padding=1,
1234 num_svs=num_G_SVs,
1235 num_itrs=num_G_SV_itrs,
1236 eps=self.SN_eps,
1237 )
1238 ## which linear
1239 self.which_linear = functools.partial(
1240 SNLinear,
1241 num_svs=num_G_SVs,
1242 num_itrs=num_G_SV_itrs,
1243 eps=self.SN_eps,
1244 )
1245 else:
1246 self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
1247 self.which_linear = nn.Linear
1248
1249 # We use a non-spectral-normed embedding here regardless;
1250 # For some reason applying SN to G's embedding seems to randomly cripple G # noqa
1251 ## which embedding
1252 self.which_embedding = nn.Embedding
1253 bn_linear = (
1254 functools.partial(self.which_linear, bias=False)
1255 if self.G_shared
1256 else self.which_embedding
1257 )
1258 ## which bn
1259 self.which_bn = functools.partial(
1260 ccbn,
1261 which_linear=bn_linear,
1262 cross_replica=self.cross_replica,
1263 mybn=self.mybn,
1264 input_size=(
1265 self.shared_dim + self.dim_z if self.G_shared else self.n_classes
1266 ),
1267 norm_style=self.norm_style,
1268 eps=self.BN_eps,
1269 )
1270 ## shared
1271 self.shared = (
1272 self.which_embedding(n_classes, self.shared_dim)
1273 if G_shared
1274 else identity()
1275 )
1276
1277 if self.RRM_prx_G:
1278 ## RRM on proxy embeddings
1279 self.RR_G = RelationalReasoning(
1280 num_layers=1,
1281 input_dim=128,
1282 dim_feedforward=128,
1283 which_linear=nn.Linear,
1284 num_heads=self.n_head_G,
1285 dropout=0.0,
1286 hidden_dim=128,
1287 )
1288
1289 ## First linear layer
1290 self.linear = self.which_linear(
1291 self.dim_z + self.shared_dim,
1292 self.arch["in_channels"][0] * ((self.bottom_width**2) * self.H_base),
1293 )
1294
1295 # self.blocks is a doubly-nested list of modules, the outer loop intended # noqa
1296 # to be over blocks at a given resolution (resblocks and/or self-attention) # noqa
1297 # while the inner loop is over a given block
1298 ## blocks
1299 self.blocks = []
1300 for index in range(len(self.arch["out_channels"])):
1301 self.blocks += [
1302 [
1303 GBlock(
1304 in_channels=self.arch["in_channels"][index],
1305 out_channels=self.arch["in_channels"][index]
1306 if g_index == 0
1307 else self.arch["out_channels"][index],
1308 which_conv=self.which_conv,
1309 which_bn=self.which_bn,
1310 activation=self.activation,
1311 upsample=(
1312 functools.partial(F.interpolate, scale_factor=2)
1313 if self.arch["upsample"][index]
1314 and g_index == (self.G_depth - 1)
1315 else None
1316 ),
1317 )
1318 ]
1319 for g_index in range(self.G_depth)
1320 ]
1321
1322 # If attention on this block, attach it to the end
1323 if self.arch["attention"][self.arch["resolution"][index]]:
1324 print(
1325 f"Adding attention layer in G at resolution {self.arch['resolution'][index]:d}"
1326 )
1327
1328 if attn_type == "sa":
1329 self.blocks[-1] += [
1330 Attention(self.arch["out_channels"][index], self.which_conv)
1331 ]
1332 elif attn_type == "cbam":
1333 self.blocks[-1] += [
1334 CBAM_attention(
1335 self.arch["out_channels"][index], self.which_conv
1336 )
1337 ]
1338 elif attn_type == "ila":
1339 self.blocks[-1] += [ILA(self.arch["out_channels"][index])]
1340
1341 # Turn self.blocks into a ModuleList so that it's all properly registered. # noqa
1342 self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
1343
1344 # output layer: batchnorm-relu-conv.
1345 # Consider using a non-spectral conv here
1346 ## output layer
1347 self.output_layer = nn.Sequential(
1348 bn(
1349 self.arch["out_channels"][-1],
1350 cross_replica=self.cross_replica,
1351 mybn=self.mybn,
1352 ),
1353 self.activation,
1354 self.which_conv(self.arch["out_channels"][-1], 1),
1355 )
1356
1357 # Initialize weights. Optionally skip init for testing.
1358 if not skip_init:
1359 self.init_weights()
1360
1361 # Set up optimizer
1362 # If this is an EMA copy, no need for an optim, so just return now
1363 if no_optim:
1364 return
1365 ## lr
1366 self.lr = G_lr
1367 ## B1
1368 self.B1 = G_B1
1369 ## B2
1370 self.B2 = G_B2
1371 ## adam_eps
1372 self.adam_eps = adam_eps
1373 if G_mixed_precision:
1374 print("Using fp16 adam in G...")
1375 import utils
1376
1377 self.optim = utils.Adam16(
1378 params=self.parameters(),
1379 lr=self.lr,
1380 betas=(self.B1, self.B2),
1381 weight_decay=0,
1382 eps=self.adam_eps,
1383 )
 1384 else:
 1385 ## optim
 1386 self.optim = optim.Adam(
 1387 params=self.parameters(),
 1388 lr=self.lr,
 1389 betas=(self.B1, self.B2),
 1390 weight_decay=0,
 1391 eps=self.adam_eps,
 1392 )
1393 # LR scheduling
1394 if sched_version == "default":
1395 ## lr sched
1396 self.lr_sched = None
1397 elif sched_version == "CosAnnealLR":
1398 self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
1399 self.optim,
1400 T_max=kwargs["num_epochs"],
1401 eta_min=self.lr / 4,
1402 last_epoch=-1,
1403 )
1404 elif sched_version == "CosAnnealWarmRes":
1405 self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
1406 self.optim, T_0=10, T_mult=2, eta_min=self.lr / 4
1407 )
1408 else:
1409 self.lr_sched = None
1410
1411 ## Initialize
1412 def init_weights(self):
1413 ## parameter count
1414 self.param_count = 0
1415 for module in self.modules():
1416 if (
1417 isinstance(module, nn.Conv2d)
1418 or isinstance(module, nn.Linear)
1419 or isinstance(module, nn.Embedding)
1420 ):
1421 if self.init == "ortho":
1422 init.orthogonal_(module.weight)
1423 elif self.init == "N02":
1424 init.normal_(module.weight, 0, 0.02)
1425 elif self.init in ["glorot", "xavier"]:
1426 init.xavier_uniform_(module.weight)
1427 else:
1428 print("Init style not recognized...")
1429 self.param_count += sum(
1430 [p.data.nelement() for p in module.parameters()]
1431 )
1432 print(f"Param count for G's initialized parameters: {self.param_count}")
1433
1434 ## forward
1435 def forward(self, z, y):
1436 y = self.shared(y)
1437 # If relational embedding
1438 if self.RRM_prx_G:
1439 y = self.RR_G(y.unsqueeze(0)).squeeze(0)
1440 # y = F.normalize(y, dim=1)
1441 # If hierarchical, concatenate zs and ys
1442 if self.hier: # y and z are [bs,128] dimensional
1443 z = torch.cat([y, z], 1)
1444 y = z
1445 # First linear layer
1446 h = self.linear(z) # ([bs,256]-->[bs,24576])
1447 # Reshape
1448 h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width * self.H_base)
1449 # Loop over blocks
1450 for _, blocklist in enumerate(self.blocks):
1451 # Second inner loop in case block has multiple layers
1452 for block in blocklist:
1453 h = block(h, y)
1454
1455 # Apply batchnorm-relu-conv-tanh at output
1456 return torch.tanh(self.output_layer(h))
1457
1458
1459class Model(Generator):
1460 """
1461 Generator subclass
 1462 initialized by default with the module-level CONFIG dict
1463 """
1464
1465 ## Constructor
1466 def __init__(self):
1467 super().__init__(**CONFIG)
1468
1469
1470def generate(model: nn.Module):
1471 """
1472 Run inference with the provided Generator model
1473
1474 Args:
1475 model (nn.Module): Generator model
1476
1477 Returns:
1478 torch.Tensor: batch of 40 PXD images
1479 """
1480 device = next(model.parameters()).device
1481 with torch.no_grad():
1482 latents = torch.randn(40, 128, device=device)
1483 labels = torch.tensor(list(range(40)), dtype=torch.long, device=device)
1484 imgs = model(latents, labels).detach().cpu()
1485 # Cut the noise below 7 ADU
1486 imgs = F.threshold(imgs, -0.26, -1)
1487 # center range [-1, 1] to [0, 1]
1488 imgs = imgs.mul_(0.5).add_(0.5)
1489 # renormalize and convert to uint8
1490 imgs = torch.pow(256, imgs).add_(-1).clamp_(0, 255).to(torch.uint8)
1491 # flatten channel dimension and crop 256 to 250
1492 imgs = imgs[:, 0, 3:-3, :]
1493 return imgs
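# Illustrative end-to-end usage sketch (the checkpoint name and loading step are
# hypothetical; obtaining trained weights is outside the scope of this module):
def _generate_example(checkpoint="generator.pth"):
    model = Model()                                   # built from CONFIG
    model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
    model.eval()                                      # use stored BN statistics
    return generate(model)                            # uint8 tensor of shape (40, 250, 256)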