Belle II Software development
ieagan.py
1
8"""
9This module implements the IEA-GAN generator model.
10"""
11
12import functools
13import json
14import math
15
16import torch
17from torch import nn
18from torch import optim
19from torch.nn import init
20from torch.nn import Parameter as P
21import torch.nn.functional as F
22
23
24CONFIG = json.loads(
25 """{
26 "num_workers": 8,
27 "seed": 415,
28 "pin_memory": false,
29 "shuffle": true,
30 "augment": 0,
31 "use_multiepoch_sampler": false,
32 "model": "BigGAN_deep",
33 "G_ch": 32,
34 "G_param" : "SN",
35 "D_param" : "SN",
36 "D_ch": 32,
37 "G_depth": 2,
38 "D_depth": 2,
39 "H_base": 3,
40 "D_wide": true,
41 "G_shared": true,
42 "shared_dim": 128,
43 "dim_z": 128,
44 "z_var": 1.0,
45 "hier": true,
46 "cross_replica": false,
47 "mybn": false,
48 "G_activation": "inplace_relu",
49 "D_activation": "inplace_relu",
50 "G_attn": "0",
51 "D_attn": "0",
52 "norm_style": "bn",
53 "G_init": "ortho",
54 "D_init": "ortho",
55 "skip_init": false,
56 "G_lr": 5e-05,
57 "D_lr": 5e-05,
58 "G_B1": 0.0,
59 "D_B1": 0.0,
60 "G_B2": 0.999,
61 "D_B2": 0.999,
62 "batch_size": 40,
63 "G_batch_size": 0,
64 "num_G_accumulations": 1,
65 "num_D_steps": 1,
66 "num_D_accumulations": 1,
67 "split_D": true,
68 "num_epochs": 4,
69 "parallel": false,
70 "G_fp16": false,
71 "D_fp16": false,
72 "D_mixed_precision": false,
73 "G_mixed_precision": false,
74 "accumulate_stats": false,
75 "num_standing_accumulations": 16,
76 "G_eval_mode": true,
77 "save_every": 1000,
78 "test_every": 1000,
79 "num_save_copies": 2,
80 "num_best_copies": 2,
81 "ema": true,
82 "ema_decay": 0.9999,
83 "use_ema": true,
84 "ema_start": 10000,
85 "adam_eps": 1e-06,
86 "BN_eps": 1e-05,
87 "SN_eps": 1e-06,
88 "num_G_SVs": 1,
89 "num_D_SVs": 1,
90 "num_G_SV_itrs": 1,
91 "num_D_SV_itrs": 1,
92 "G_ortho": 0.0001,
93 "D_ortho": 0.0,
94 "toggle_grads": true,
95 "logstyle": "%3.3e",
96 "sv_log_interval": 10,
97 "log_interval": 100,
98 "resolution": 256,
99 "n_classes": 40,
100 "run_name": "BGd_140",
101 "resume": false,
102 "latent_op": false,
103 "latent_reg_weight": 300,
104 "bottom_width": 4,
105 "add_blur" : false,
106 "add_noise": true,
107 "add_style": false,
108 "conditional_strategy": "Contra",
109 "hypersphere_dim": 1024,
110 "pos_collected_numerator": false,
111 "nonlinear_embed": false,
112 "normalize_embed": true,
113 "inv_stereographic" :false,
114 "contra_lambda": 1.0,
115 "Angle": false,
116 "angle_lambda": 1.0,
117 "IEA_loss": true,
118 "IEA_lambda": 1.0,
119 "Uniformity_loss": true,
120 "unif_lambda": 0.1,
121 "diff_aug": true,
122 "Con_reg": false,
123 "cr_lambda": 10,
124 "pixel_reg": false,
125 "px_lambda": 1.0,
126 "RRM_prx_G": true,
127 "normalized_proxy_G": false,
128 "RRM_prx_D": false,
129 "RRM_embed": true,
130 "n_head_G": 2,
131 "n_head": 4,
132 "attn_type": "sa",
133 "sched_version": "default",
134 "z_dist": "normal",
135 "truncated_threshold": 1.0,
136 "clip_norm": "None",
137 "amsgrad": false,
138 "arch": "None",
139 "G_kernel_size": 3,
140 "D_kernel_size": 3,
141 "ada_belief": false,
142 "pbar": "tqdm",
143 "which_best": "FID",
144 "stop_after": 100000,
145 "trunc_z": 0.5,
146 "denoise": false,
147 "metric_log_name": "metric_log.jsonl",
148 "reinitialize_metric_logs": false,
149 "reinitialize_parameter_logs": false,
150 "num_incep_images": 16000,
151 "load_optim": true}"""
152)
153
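# Illustrative sketch (not part of the original module): CONFIG is a plain
# dict, so individual hyperparameters can be inspected or overridden before a
# generator is built; the wider-channel override below is hypothetical.
def _example_config_override():
    cfg = dict(CONFIG, G_ch=64)  # hypothetical override of the channel width
    return cfg["G_ch"], cfg["resolution"], cfg["n_classes"]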
154
155def proj(x, y):
156 """
157 Projection of x onto y
158 """
159 return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
160
161
162def gram_schmidt(x, ys):
163 """
164 Orthogonalize x wrt list of vectors ys
165 """
166 for y in ys:
167 x = x - proj(x, y)
168 return x
169
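# A minimal usage sketch (not used elsewhere in this module): proj and
# gram_schmidt operate on 1xD row vectors, as in power_iteration below; the
# result is orthogonal to the vector it was last projected against.
def _example_gram_schmidt():
    x = torch.randn(1, 8)
    y = F.normalize(torch.randn(1, 8))
    x_orth = gram_schmidt(x, [y])
    return torch.mm(x_orth, y.t()).item()  # ~0 up to float precision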
170
171def power_iteration(W, u_, update=True, eps=1e-12):
172 """
 173 Run one step of the power method for each vector in u_ to estimate the top singular values of W.
174 """
175 # Lists holding singular vectors and values
176 us, vs, svs = [], [], []
177 for i, u in enumerate(u_):
178 # Run one step of the power iteration
179 with torch.no_grad():
180 v = torch.matmul(u, W)
181 # Run Gram-Schmidt to subtract components of all other singular vectors # noqa
182 v = F.normalize(gram_schmidt(v, vs), eps=eps)
183 # Add to the list
184 vs += [v]
185 # Update the other singular vector
186 u = torch.matmul(v, W.t())
187 # Run Gram-Schmidt to subtract components of all other singular vectors # noqa
188 u = F.normalize(gram_schmidt(u, us), eps=eps)
189 # Add to the list
190 us += [u]
191 if update:
192 u_[i][:] = u
193 # Compute this singular value and add it to the list
194 svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
195 # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
196 return svs, us, vs
197
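# Illustrative sketch (not used elsewhere in this module): repeated calls to
# power_iteration converge towards the largest singular value of W; the list
# u_ acts as persistent state between calls, exactly like the u buffers in
# the SN base class below.
def _example_power_iteration():
    W = torch.randn(16, 32)
    u_ = [torch.randn(1, 16)]
    for _ in range(100):
        svs, _, _ = power_iteration(W, u_, update=True)
    return svs[0]  # approaches torch.linalg.svdvals(W)[0]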
198
199def groupnorm(x, norm_style):
200 """
 201 Apply group normalization, reading the channel or group count from norm_style
202 """
203 # If number of channels specified in norm_style:
204 if "ch" in norm_style:
205 ch = int(norm_style.split("_")[-1])
206 groups = max(int(x.shape[1]) // ch, 1)
207 # If number of groups specified in norm style
208 elif "grp" in norm_style:
209 groups = int(norm_style.split("_")[-1])
210 # If neither, default to groups = 16
211 else:
212 groups = 16
213 return F.group_norm(x, groups)
214
215
216class identity(nn.Module):
217 """
218 Convenience passthrough function
219 """
220
221
222 def forward(self, tensor: torch.Tensor):
223 return tensor
224
225
226class SN(nn.Module):
227 """
228 Spectral normalization base class
229
230 This base class expects subclasses to have a learnable weight parameter
231 (`self.weight`) as in `nn.Linear` or `nn.Conv2d`. It provides a method
232 to apply spectral normalization to that weight.
233
234 Attributes:
235 num_svs (int): Number of singular values.
236 num_itrs (int): Number of power iterations per step.
237 transpose (bool): Whether to transpose the weight matrix.
238 eps (float): Small constant to avoid divide-by-zero.
239 u (list[Tensor]): Registered left singular vectors (buffers).
240 sv (list[Tensor]): Registered singular values (buffers).
241 training (bool): Inherited from nn.Module. True if in training mode.
242 """
243
244 def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
245 """constructor"""
246
247 super().__init__()
248 ## Number of power iterations per step
249 self.num_itrs = num_itrs
250 ## Number of singular values
251 self.num_svs = num_svs
252 ## Transposed?
253 self.transpose = transpose
254 ## Epsilon value for avoiding divide-by-0
255 self.eps = eps
256 # Register a singular vector for each sv
257 for i in range(self.num_svs):
258 self.register_buffer(f"u{i:d}", torch.randn(1, num_outputs))
259 self.register_buffer(f"sv{i:d}", torch.ones(1))
260 ## Training mode flag (inherited from nn.Module). True if the module is in training mode.
261 self.training: bool
262
263 @property
264 def u(self):
265 """
266 Singular vectors (u side)
267 """
268 return [getattr(self, f"u{i:d}") for i in range(self.num_svs)]
269
270 @property
271 def sv(self):
272 """
273 Singular values
274 note that these buffers are just for logging and are not used in training.
275 """
276 return [getattr(self, f"sv{i:d}") for i in range(self.num_svs)]
277
278 def W_(self):
279 """
280 Compute the spectrally-normalized weight
281 """
282 W_mat = self.weight.view(self.weight.size(0), -1)
283 if self.transpose:
284 W_mat = W_mat.t()
285 # Apply num_itrs power iterations
286 for _ in range(self.num_itrs):
287 svs, _, _ = power_iteration(
288 W_mat, self.u, update=self.training, eps=self.eps
289 )
290 # Update the svs
291 if self.training:
292 # Make sure to do this in a no_grad() context or you'll get memory leaks! # noqa
293 with torch.no_grad():
294 for i, sv in enumerate(svs):
295 self.sv[i][:] = sv
296 return self.weight / svs[0]
297
298
299class SNConv2d(nn.Conv2d, SN):
300 """
301 2D Conv layer with spectral norm
302 """
303
304 ## Constructor
305 def __init__(
306 self,
307 in_channels,
308 out_channels,
309 kernel_size,
310 stride=1,
311 padding=0,
312 dilation=1,
313 groups=1,
314 bias=True,
315 num_svs=1,
316 num_itrs=1,
317 eps=1e-12,
318 ):
319 nn.Conv2d.__init__(
320 self,
321 in_channels,
322 out_channels,
323 kernel_size,
324 stride,
325 padding,
326 dilation,
327 groups,
328 bias,
329 )
330 SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
331
332 ## forward
333 def forward(self, x):
334 return F.conv2d(
335 x,
336 self.W_(),
337 # \cond false positive doxygen warning
338 self.bias,
339 self.stride,
340 self.padding,
341 self.dilation,
342 self.groups,
343 # \endcond
344 )
345
346
347class SNLinear(nn.Linear, SN):
348 """
349 Linear layer with spectral norm
350 """
351
352 ## Constructor
353 def __init__(
354 self,
355 in_features,
356 out_features,
357 bias=True,
358 num_svs=1,
359 num_itrs=1,
360 eps=1e-12,
361 ):
362 nn.Linear.__init__(self, in_features, out_features, bias)
363 SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
364
365 ## forward
366 def forward(self, x):
367 # \cond false positive doxygen warning
368 return F.linear(x, self.W_(), self.bias)
369 # \endcond
370
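# A minimal usage sketch (not used elsewhere in this module): SNConv2d and
# SNLinear are drop-in replacements for their torch.nn counterparts; W_()
# re-estimates the top singular value on every forward call, and the u0/sv0
# buffers are only updated in training mode.
def _example_snlinear():
    layer = SNLinear(64, 32, num_svs=1, num_itrs=1)
    x = torch.randn(8, 64)
    out = layer(x)  # uses the spectrally normalized weight
    return out.shape, layer.sv0  # torch.Size([8, 32]), logged singular value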
371
372def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
373 """Fused batchnorm op"""
374
375 # Apply scale and shift--if gain and bias are provided, fuse them here
376 # Prepare scale
377 scale = torch.rsqrt(var + eps)
378 # If a gain is provided, use it
379 if gain is not None:
380 scale = scale * gain
381 # Prepare shift
382 shift = mean * scale
383 # If bias is provided, use it
384 if bias is not None:
385 shift = shift - bias
386 return x * scale - shift
387 # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. # noqa
388
389
390def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
391 """
392 Manual BN
393 Calculate means and variances using mean-of-squares minus mean-squared
394 """
395
396 # Cast x to float32 if necessary
397 float_x = x.float()
398 # Calculate expected value of x (m) and expected value of x**2 (m2)
399 # Mean of x
400 m = torch.mean(float_x, [0, 2, 3], keepdim=True)
401 # Mean of x squared
402 m2 = torch.mean(float_x**2, [0, 2, 3], keepdim=True)
403 # Calculate variance as mean of squared minus mean squared.
404 var = m2 - m**2
405 # Cast back to float 16 if necessary
406 var = var.type(x.type())
407 m = m.type(x.type())
408 # Return mean and variance for updating stored mean/var if requested
409 if return_mean_var:
410 return (
411 fused_bn(x, m, var, gain, bias, eps),
412 m.squeeze(),
413 var.squeeze(),
414 )
415 else:
416 return fused_bn(x, m, var, gain, bias, eps)
417
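# Illustrative sketch (not used elsewhere in this module): manual_bn
# normalizes over the (N, H, W) axes per channel; with unit gain and zero
# bias the output is approximately standardized, matching F.batch_norm in
# training mode.
def _example_manual_bn():
    x = torch.randn(4, 3, 8, 8)
    gain = torch.ones(1, 3, 1, 1)
    bias = torch.zeros(1, 3, 1, 1)
    out = manual_bn(x, gain, bias)
    return out.mean().item(), out.std().item()  # roughly 0 and 1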
418
419class myBN(nn.Module):
420 """
421 My batchnorm, supports standing stats
422 """
423
424 ## Constructor
425 def __init__(self, num_channels, eps=1e-5, momentum=0.1):
426 super(myBN, self).__init__()
427 ## momentum for updating running stats
428 self.momentum = momentum
429 ## epsilon to avoid dividing by 0
430 self.eps = eps
433 # Register buffers
434 self.register_buffer("stored_mean", torch.zeros(num_channels))
435 self.register_buffer("stored_var", torch.ones(num_channels))
436 self.register_buffer("accumulation_counter", torch.zeros(1))
437 ## Accumulate running means and vars
438 self.accumulate_standing = False
439 ## Training mode flag (inherited from nn.Module). True if the module is in training mode.
440 self.training: bool
441
442 ## reset standing stats
443 def reset_stats(self):
444 self.stored_mean[:] = 0
445 self.stored_var[:] = 0
446 self.accumulation_counter[:] = 0
447
448 ## forward
449 def forward(self, x, gain, bias):
450 if self.training:
451 out, mean, var = manual_bn(
452 x, gain, bias, return_mean_var=True, eps=self.eps
453 )
454 # If accumulating standing stats, increment them
455 if self.accumulate_standing:
456 self.stored_mean[:] = self.stored_mean + mean.data
457 self.stored_var[:] = self.stored_var + var.data
458 self.accumulation_counter += 1.0
459 # If not accumulating standing stats, take running averages
460 else:
461 self.stored_mean[:] = (
462 self.stored_mean * (1 - self.momentum) + mean * self.momentum
463 )
464 self.stored_var[:] = (
465 self.stored_var * (1 - self.momentum) + var * self.momentum
466 )
467 return out
468 # If not in training mode, use the stored statistics
469 else:
470 mean = self.stored_mean.view(1, -1, 1, 1)
471 var = self.stored_var.view(1, -1, 1, 1)
472 # If using standing stats, divide them by the accumulation counter
473 if self.accumulate_standing:
474 mean = mean / self.accumulation_counter
475 var = var / self.accumulation_counter
476 return fused_bn(x, mean, var, gain, bias, self.eps)
477
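# Illustrative sketch (not used elsewhere in this module, sizes arbitrary):
# with accumulate_standing=True, myBN sums batch statistics during training
# and divides by the accumulation counter at eval time ("standing stats")
# instead of keeping an exponential moving average.
def _example_standing_stats():
    bn_layer = myBN(3)
    bn_layer.accumulate_standing = True
    gain = torch.ones(1, 3, 1, 1)
    bias = torch.zeros(1, 3, 1, 1)
    for _ in range(4):
        bn_layer(torch.randn(8, 3, 16, 16), gain, bias)
    bn_layer.eval()
    return bn_layer(torch.randn(8, 3, 16, 16), gain, bias).shape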
478
479class bn(nn.Module):
480 """
481 Normal, non-class-conditional BN
482 """
483
484 ## Constructor
485 def __init__(
486 self,
487 output_size,
488 eps=1e-5,
489 momentum=0.1,
490 cross_replica=False,
491 mybn=False,
492 ):
493 super(bn, self).__init__()
494 ## output size
495 self.output_size = output_size
496 ## Prepare gain and bias layers
497 self.gain = P(torch.ones(output_size), requires_grad=True)
498 ## bias
499 self.bias = P(torch.zeros(output_size), requires_grad=True)
500 ## epsilon to avoid dividing by 0
501 self.eps = eps
502 ## Momentum
503 self.momentum = momentum
504 ## Use cross-replica batchnorm?
505 self.cross_replica = cross_replica
506 ## Use my batchnorm?
507 self.mybn = mybn
508
509 if mybn:
510 self.bn = myBN(output_size, self.eps, self.momentum)
511 # Register buffers if neither of the above
512 else:
 513 ## Running mean buffer, updated during training
 515 self.register_buffer("stored_mean", torch.zeros(output_size))
 516 ## Running variance buffer, updated during training
 518 self.register_buffer("stored_var", torch.ones(output_size))
519
520 ## Training mode flag (inherited from nn.Module). True if the module is in training mode.
521 self.training: bool
522
523 ## forward
524 def forward(self, x):
525 if self.mybn:
526 gain = self.gain.view(1, -1, 1, 1)
527 bias = self.bias.view(1, -1, 1, 1)
528 return self.bn(x, gain=gain, bias=bias)
529 else:
530 return F.batch_norm(
531 x,
532 self.stored_mean,
533 self.stored_var,
534 self.gain,
535 self.bias,
536 self.training,
537 self.momentum,
538 self.eps,
539 )
540
541
542class ccbn(nn.Module):
543 """
544 Class-conditional bn
545 output size is the number of channels, input size is for the linear layers
546 Andy's Note: this class feels messy but I'm not really sure how to clean it up # noqa
547 Suggestions welcome! (By which I mean, refactor this and make a merge request
548 if you want to make this more readable/usable).
549 """
550
551 ## Constructor
552 def __init__(
553 self,
554 output_size,
555 input_size,
556 which_linear,
557 eps=1e-5,
558 momentum=0.1,
559 cross_replica=False,
560 mybn=False,
561 norm_style="bn",
562 ):
563 super(ccbn, self).__init__()
564 ## output size
565 self.output_size, self.input_size = output_size, input_size
566 ## Prepare gain and bias layers
567 self.gain = which_linear(input_size, output_size)
568 ## bias
569 self.bias = which_linear(input_size, output_size)
570 ## epsilon to avoid dividing by 0
571 self.eps = eps
572 ## Momentum
573 self.momentum = momentum
574 ## Use cross-replica batchnorm?
575 self.cross_replica = cross_replica
576 ## Use my batchnorm?
577 self.mybn = mybn
578 ## Norm style?
579 self.norm_style = norm_style
580
581 if self.mybn:
582 ## bn
583 self.bn = myBN(output_size, self.eps, self.momentum)
584 elif self.norm_style in ["bn", "in"]:
585 self.register_buffer("stored_mean", torch.zeros(output_size))
586 self.register_buffer("stored_var", torch.ones(output_size))
587
588 ## forward
589 def forward(self, x, y):
590 # Calculate class-conditional gains and biases
591 gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
592 bias = self.bias(y).view(y.size(0), -1, 1, 1)
593 # If using my batchnorm
594 if self.mybn:
595 return self.bn(x, gain=gain, bias=bias)
596 # else:
597 else:
598 if self.norm_style == "bn":
599 out = F.batch_norm(
600 x,
601 self.stored_mean,
602 self.stored_var,
603 None,
604 None,
605 self.training,
606 0.1,
607 self.eps,
608 )
609 elif self.norm_style == "in":
610 out = F.instance_norm(
611 x,
612 self.stored_mean,
613 self.stored_var,
614 None,
615 None,
616 self.training,
617 0.1,
618 self.eps,
619 )
620 elif self.norm_style == "gn":
 621 out = groupnorm(x, self.norm_style)
622 elif self.norm_style == "nonorm":
623 out = x
624 return out * gain + bias
625
626 ## extra_repr
627 def extra_repr(self):
628 s = "out: {output_size}, in: {input_size},"
629 s += " cross_replica={cross_replica}"
630 # \cond false positive doxygen warning
631 return s.format(**self.__dict__)
632 # \endcond
633
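# A minimal usage sketch (not used elsewhere in this module): ccbn predicts a
# per-sample gain and bias from the conditioning vector y; a plain nn.Linear
# stands in here for the spectrally normalized which_linear used in Generator.
def _example_ccbn():
    layer = ccbn(output_size=16, input_size=32, which_linear=nn.Linear)
    x = torch.randn(4, 16, 8, 8)
    y = torch.randn(4, 32)
    return layer(x, y).shape  # torch.Size([4, 16, 8, 8])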
634
635class ILA(nn.Module):
636 """
637 Image_Linear_Attention
638 """
639
640 ## Constructor
641 def __init__(
642 self,
643 chan,
644 chan_out=None,
645 kernel_size=1,
646 padding=0,
647 stride=1,
648 key_dim=32,
649 value_dim=64,
650 heads=8,
651 norm_queries=True,
652 ):
653 super().__init__()
654 ## chan
655 self.chan = chan
656 chan_out = chan if chan_out is None else chan_out
657
658 ## key dimension
659 self.key_dim = key_dim
660 ## value dimension
661 self.value_dim = value_dim
662 ## heads
663 self.heads = heads
664
665 ## norm queries
666 self.norm_queries = norm_queries
667
668 conv_kwargs = {"padding": padding, "stride": stride}
669 ## q
670 self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
671 ## k
672 self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
673 ## v
674 self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)
675
676 out_conv_kwargs = {"padding": padding}
677 ## to out
678 self.to_out = nn.Conv2d(
679 value_dim * heads, chan_out, kernel_size, **out_conv_kwargs
680 )
681
682 ## forward
683 def forward(self, x, context=None):
684 b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads
685
686 q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))
687
688 q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))
689
690 q, k = map(lambda x: x * (self.key_dim**-0.25), (q, k))
691
692 if context is not None:
693 context = context.reshape(b, c, 1, -1)
694 ck, cv = self.to_k(context), self.to_v(context)
695 ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
696 k = torch.cat((k, ck), dim=3)
697 v = torch.cat((v, cv), dim=3)
698
699 k = k.softmax(dim=-1)
700
701 if self.norm_queries:
702 q = q.softmax(dim=-2)
703
704 context = torch.einsum("bhdn,bhen->bhde", k, v)
705 out = torch.einsum("bhdn,bhde->bhen", q, context)
706 out = out.reshape(b, -1, h, w)
707 out = self.to_out(out)
708 return out
709
710
711class CBAM_attention(nn.Module):
712 """CBAM attention"""
713
714 ## Constructor
715 def __init__(
716 self,
717 channels,
718 which_conv=SNConv2d,
719 reduction=8,
720 attention_kernel_size=3,
721 ):
722 super(CBAM_attention, self).__init__()
723 ## average pooling
724 self.avg_pool = nn.AdaptiveAvgPool2d(1)
725 ## max pooling
726 self.max_pool = nn.AdaptiveMaxPool2d(1)
 727 ## fc1
728 self.fc1 = which_conv(
729 channels, channels // reduction, kernel_size=1, padding=0
730 )
731 ## relu
732 self.relu = nn.ReLU(inplace=True)
 733 ## fc2
734 self.fc2 = which_conv(
735 channels // reduction, channels, kernel_size=1, padding=0
736 )
737 ## sigmoid channel
738 self.sigmoid_channel = nn.Sigmoid()
739 ## convolution after concatenation
740 self.conv_after_concat = which_conv(
741 2,
742 1,
743 kernel_size=attention_kernel_size,
744 stride=1,
745 padding=attention_kernel_size // 2,
746 )
747 ## sigmoid_spatial
748 self.sigmoid_spatial = nn.Sigmoid()
749
750 ## forward
751 def forward(self, x):
752 # Channel attention module
753 module_input = x
754 avg = self.avg_pool(x)
755 mx = self.max_pool(x)
756 avg = self.fc1(avg)
757 mx = self.fc1(mx)
758 avg = self.relu(avg)
759 mx = self.relu(mx)
760 avg = self.fc2(avg)
761 mx = self.fc2(mx)
762 x = avg + mx
763 x = self.sigmoid_channel(x)
764 # Spatial attention module
765 x = module_input * x
766 module_input = x
767 # b, c, h, w = x.size()
768 avg = torch.mean(x, 1, True)
769 mx, _ = torch.max(x, 1, True)
770 x = torch.cat((avg, mx), 1)
771 x = self.conv_after_concat(x)
772 x = self.sigmoid_spatial(x)
773 x = module_input * x
774 return x
775
776
777class Attention(nn.Module):
778 """Attention"""
779
780 ## Constructor
781 def __init__(self, ch, which_conv=SNConv2d):
782 super(Attention, self).__init__()
783 ## Channel multiplier
784 self.ch = ch
785 ## which_conv
786 self.which_conv = which_conv
787 ## theta
788 self.theta = self.which_conv(
789 self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
790 )
791 ## phi
792 self.phi = self.which_conv(
793 self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
794 )
795 ## g
796 self.g = self.which_conv(
797 self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
798 )
799 ## o
800 self.o = self.which_conv(
801 self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
802 )
803 ## Learnable gain parameter
804 self.gamma = P(torch.tensor(0.0), requires_grad=True)
805
806 ## forward
807 def forward(self, x):
808 # Apply convs
809 theta = self.theta(x)
810 phi = F.max_pool2d(self.phi(x), [2, 2])
811 g = F.max_pool2d(self.g(x), [2, 2])
812 # Perform reshapes
813 theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
814 phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
815 g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
816 # Matmul and softmax to get attention maps
817 beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
818 # Attention map times g path
819 o = self.o(
820 torch.bmm(g, beta.transpose(1, 2)).view(
821 -1, self.ch // 2, x.shape[2], x.shape[3]
822 )
823 )
824 return self.gamma * o + x
825
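# Illustrative sketch (not used elsewhere in this module): the "sa" attention
# block preserves the input shape and adds a gamma-scaled attention output to
# its input, so it can be appended to any stage of the generator.
def _example_attention():
    attn = Attention(ch=64)
    x = torch.randn(2, 64, 32, 32)
    return attn(x).shape  # torch.Size([2, 64, 32, 32])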
826
827class SNEmbedding(nn.Embedding, SN):
828 """
829 Embedding layer with spectral norm
830 We use num_embeddings as the dim instead of embedding_dim here
831 for convenience sake
832 """
833
834 ## Constructor
835 def __init__(
836 self,
837 num_embeddings,
838 embedding_dim,
839 padding_idx=None,
840 max_norm=None,
841 norm_type=2,
842 scale_grad_by_freq=False,
843 sparse=False,
844 _weight=None,
845 num_svs=1,
846 num_itrs=1,
847 eps=1e-12,
848 ):
849 nn.Embedding.__init__(
850 self,
851 num_embeddings,
852 embedding_dim,
853 padding_idx,
854 max_norm,
855 norm_type,
856 scale_grad_by_freq,
857 sparse,
858 _weight,
859 )
860 SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
861
862 ## forward
863 def forward(self, x):
864 return F.embedding(x, self.W_())
865
866
 867def scaled_dot_product(q, k, v):
 """Compute scaled dot-product attention; return (values, attention weights)."""
868 d_k = q.size()[-1]
869 attn_logits = torch.matmul(q, k.transpose(-2, -1))
870 attn_logits = attn_logits / math.sqrt(d_k)
871 attention = F.softmax(attn_logits, dim=-1)
872 values = torch.matmul(attention, v)
873 return values, attention
874
875
876class MultiheadAttention(nn.Module):
877 """MultiheadAttention"""
878
879 ## Constructor
880 def __init__(self, input_dim, embed_dim, num_heads, which_linear):
881 super().__init__()
 882 assert (
 883 embed_dim % num_heads == 0
 884 ), "Embedding dimension must be divisible by the number of heads."
885
886 ## embedding dimension
887 self.embed_dim = embed_dim
888 ## number of heads
889 self.num_heads = num_heads
890 ## head dimension
891 self.head_dim = embed_dim // num_heads
892 ## which linear
893 self.which_linear = which_linear
894
895 # Stack all weight matrices 1...h together for efficiency
896 ## qkv projection
897 self.qkv_proj = self.which_linear(input_dim, 3 * embed_dim)
898 ## o projection
899 self.o_proj = self.which_linear(embed_dim, embed_dim)
900
901 self._reset_parameters()
902
903 ## reset parameters
904 def _reset_parameters(self):
905 # Original Transformer initialization, see PyTorch documentation
906 nn.init.xavier_uniform_(self.qkv_proj.weight)
907 self.qkv_proj.bias.data.fill_(0)
908 nn.init.xavier_uniform_(self.o_proj.weight)
909 self.o_proj.bias.data.fill_(0)
910
911 ## forward
912 def forward(self, x, return_attention=False):
913 batch_size, seq_length, embed_dim = x.size()
914 qkv = self.qkv_proj(x)
915
916 # Separate Q, K, V from linear output
917 qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
918 qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
919 q, k, v = qkv.chunk(3, dim=-1)
920
921 # Determine value outputs
922 values, attention = scaled_dot_product(q, k, v)
923 values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
924 values = values.reshape(batch_size, seq_length, embed_dim)
925 o = self.o_proj(values)
926
927 if return_attention:
928 return o, attention
929 else:
930 return o
931
932
933class EncoderBlock(nn.Module):
934 """EncoderBlock"""
935
936 ## Constructor
937 def __init__(self, input_dim, num_heads, dim_feedforward, dropout, which_linear):
938 """
939 Inputs:
940 input_dim - Dimensionality of the input
941 num_heads - Number of heads to use in the attention block
942 dim_feedforward - Dimensionality of the hidden layer in the MLP
943 dropout - Dropout probability to use in the dropout layers
944 """
945 super().__init__()
946
947 ## which linear
948 self.which_linear = which_linear
949 ## Attention layer
950 self.self_attn = MultiheadAttention(
951 input_dim, input_dim, num_heads, which_linear
952 )
953
954 ## Two-layer MLP
955 self.linear_net = nn.Sequential(
956 self.which_linear(input_dim, dim_feedforward),
957 nn.Dropout(dropout),
958 nn.ReLU(inplace=True),
959 self.which_linear(dim_feedforward, input_dim),
960 )
961
962 # Layers to apply in between the main layers
963 ## norm1
964 self.norm1 = nn.LayerNorm(input_dim)
965 ## norm2
966 self.norm2 = nn.LayerNorm(input_dim)
967 ## dropout
968 self.dropout = nn.Dropout(dropout)
969
970 ## forward
971 def forward(self, x):
972 # Attention part
973 x_pre1 = self.norm1(x)
974 attn_out = self.self_attn(x_pre1)
975 x = x + self.dropout(attn_out)
976 # x = self.norm1(x)
977
978 # MLP part
979 x_pre2 = self.norm2(x)
980 linear_out = self.linear_net(x_pre2)
981 x = x + self.dropout(linear_out)
982 # x = self.norm2(x)
983
984 return x
985
986
987class RelationalReasoning(nn.Module):
988 """RelationalReasoning"""
989
990 ## Constructor
991 def __init__(self, num_layers, hidden_dim, **block_args):
992 super().__init__()
993 ## layers
994 self.layers = nn.ModuleList(
995 [EncoderBlock(**block_args) for _ in range(num_layers)]
996 )
997 ## normalization
998 self.norm = nn.LayerNorm(hidden_dim)
999
1000 ## forward
1001 def forward(self, x):
1002 for layer in self.layers:
1003 x = layer(x)
1004
1005 x = self.norm(x)
1006 return x
1007
1008 ## get attention maps
1009 def get_attention_maps(self, x):
1010 attention_maps = []
1011 for layer in self.layers:
1012 _, attn_map = layer.self_attn(x, return_attention=True)
1013 attention_maps.append(attn_map)
1014 x = layer(x)
1015 return attention_maps
1016
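# Illustrative sketch (not used elsewhere in this module): the relational
# reasoning module treats the whole batch as one sequence, so self-attention
# runs across the 40 per-sensor class embeddings of one event; this mirrors
# how Generator.forward calls it via y.unsqueeze(0).
def _example_rrm():
    rrm = RelationalReasoning(
        num_layers=1,
        hidden_dim=128,
        input_dim=128,
        num_heads=2,
        dim_feedforward=128,
        dropout=0.0,
        which_linear=nn.Linear,
    )
    y = torch.randn(40, 128)  # one embedding per PXD sensor / class
    return rrm(y.unsqueeze(0)).squeeze(0).shape  # torch.Size([40, 128])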
1017
1018class GBlock(nn.Module):
1019 """GBlock"""
1020
1021 ## Constructor
1022 def __init__(
1023 self,
1024 in_channels,
1025 out_channels,
1026 which_conv=SNConv2d,
1027 which_bn=bn,
1028 activation=None,
1029 upsample=None,
1030 channel_ratio=4,
1031 ):
1032 super(GBlock, self).__init__()
1033
1034 ## input channels
1035 self.in_channels, self.out_channels = in_channels, out_channels
1036 ## hidden channels
1037 self.hidden_channels = self.in_channels // channel_ratio
1038 ## which convolution
1039 self.which_conv, self.which_bn = which_conv, which_bn
1040 ## activation
1041 self.activation = activation
1042 # Conv layers
1043 ## conv1
1044 self.conv1 = self.which_conv(
1045 self.in_channels, self.hidden_channels, kernel_size=1, padding=0
1046 )
1047 ## conv2
1048 self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
1049 ## conv3
1050 self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
1051 ## conv4
1052 self.conv4 = self.which_conv(
1053 self.hidden_channels, self.out_channels, kernel_size=1, padding=0
1054 )
1055 # Batchnorm layers
1056 ## bn1
1057 self.bn1 = self.which_bn(self.in_channels)
1058 ## bn2
1059 self.bn2 = self.which_bn(self.hidden_channels)
1060 ## bn3
1061 self.bn3 = self.which_bn(self.hidden_channels)
1062 ## bn4
1063 self.bn4 = self.which_bn(self.hidden_channels)
1064 ## upsample layers
1065 self.upsample = upsample
1066
1067 ## forward
1068 def forward(self, x, y):
1069 # Project down to channel ratio
1070 h = self.conv1(self.activation(self.bn1(x, y)))
1071 # Apply next BN-ReLU
1072 h = self.activation(self.bn2(h, y))
1073 # Drop channels in x if necessary
1074 if self.in_channels != self.out_channels:
1075 x = x[:, : self.out_channels]
1076 # Upsample both h and x at this point
1077 if self.upsample:
1078 h = self.upsample(h)
1079 x = self.upsample(x)
1080 # 3x3 convs
1081 h = self.conv2(h)
1082 h = self.conv3(self.activation(self.bn3(h, y)))
1083 # Final 1x1 conv
1084 h = self.conv4(self.activation(self.bn4(h, y)))
1085 return h + x
1086
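# Illustrative sketch (not used elsewhere in this module): a GBlock is a
# bottleneck residual block with class-conditional BN. The partials below
# mirror how Generator builds which_conv/which_bn; the sizes are arbitrary.
def _example_gblock():
    which_conv = functools.partial(SNConv2d, kernel_size=3, padding=1)
    which_bn = functools.partial(ccbn, which_linear=nn.Linear, input_size=128)
    block = GBlock(
        64,
        32,
        which_conv=which_conv,
        which_bn=which_bn,
        activation=nn.ReLU(inplace=False),
        upsample=functools.partial(F.interpolate, scale_factor=2),
    )
    x, y = torch.randn(2, 64, 8, 8), torch.randn(2, 128)
    return block(x, y).shape  # torch.Size([2, 32, 16, 16])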
1087
 1088def G_arch(ch=64, attention="64"):
 """Return per-resolution architecture dictionaries for the generator, scaled by the base channel width ch."""
1089 arch = {}
1090 arch[512] = {
1091 "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
1092 "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
1093 "upsample": [True] * 7,
1094 "resolution": [8, 16, 32, 64, 128, 256, 512],
1095 "attention": {
1096 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1097 for i in range(3, 10)
1098 },
1099 }
1100 arch[256] = {
1101 "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2]],
1102 "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1]],
1103 "upsample": [True] * 6,
1104 "resolution": [8, 16, 32, 64, 128, 256],
1105 "attention": {
1106 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1107 for i in range(3, 9)
1108 },
1109 }
1110 arch[128] = {
1111 "in_channels": [ch * item for item in [16, 16, 8, 4, 2]],
1112 "out_channels": [ch * item for item in [16, 8, 4, 2, 1]],
1113 "upsample": [True] * 5,
1114 "resolution": [8, 16, 32, 64, 128],
1115 "attention": {
1116 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1117 for i in range(3, 8)
1118 },
1119 }
1120 arch[96] = {
1121 "in_channels": [ch * item for item in [16, 16, 8, 4]],
1122 "out_channels": [ch * item for item in [16, 8, 4, 2]],
1123 "upsample": [True] * 4,
1124 "resolution": [12, 24, 48, 96],
1125 "attention": {
1126 12 * 2**i: (6 * 2 ** i in [int(item) for item in attention.split("_")])
1127 for i in range(0, 4)
1128 },
1129 }
1130
1131 arch[64] = {
1132 "in_channels": [ch * item for item in [16, 16, 8, 4]],
1133 "out_channels": [ch * item for item in [16, 8, 4, 2]],
1134 "upsample": [True] * 4,
1135 "resolution": [8, 16, 32, 64],
1136 "attention": {
1137 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1138 for i in range(3, 7)
1139 },
1140 }
1141 arch[32] = {
1142 "in_channels": [ch * item for item in [4, 4, 4]],
1143 "out_channels": [ch * item for item in [4, 4, 4]],
1144 "upsample": [True] * 3,
1145 "resolution": [8, 16, 32],
1146 "attention": {
1147 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1148 for i in range(3, 6)
1149 },
1150 }
1151
1152 return arch
1153
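# Illustrative sketch (not used elsewhere in this module): with the CONFIG
# above (G_ch=32, resolution=256, G_attn="0") the generator runs through six
# upsampling stages, from 512 channels at the first stage down to 32 output
# channels at full resolution, with no attention stage enabled.
def _example_g_arch():
    arch = G_arch(ch=32, attention="0")[256]
    return arch["in_channels"], arch["resolution"]
    # ([512, 512, 256, 256, 128, 64], [8, 16, 32, 64, 128, 256])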
1154
1155class Generator(nn.Module):
1156 """Generator"""
1157
1158 ## Constructor
1159 def __init__(
1160 self,
1161 G_ch=64,
1162 G_depth=2,
1163 dim_z=128,
1164 bottom_width=4,
1165 resolution=256,
1166 G_kernel_size=3,
1167 G_attn="64",
1168 n_classes=40,
1169 H_base=1,
1170 num_G_SVs=1,
1171 num_G_SV_itrs=1,
1172 attn_type="sa",
1173 G_shared=True,
1174 shared_dim=128,
1175 hier=True,
1176 cross_replica=False,
1177 mybn=False,
1178 G_activation=nn.ReLU(inplace=False),
1179 G_lr=5e-5,
1180 G_B1=0.0,
1181 G_B2=0.999,
1182 adam_eps=1e-8,
1183 BN_eps=1e-5,
1184 SN_eps=1e-12,
1185 G_init="ortho",
1186 G_mixed_precision=False,
1187 G_fp16=False,
1188 skip_init=False,
1189 no_optim=False,
1190 sched_version="default",
1191 RRM_prx_G=True,
1192 n_head_G=2,
1193 G_param="SN",
1194 norm_style="bn",
1195 **kwargs
1196 ):
1197 super(Generator, self).__init__()
1198 ## Channel width multiplier
1199 self.ch = G_ch
1200 ## Number of resblocks per stage
1201 self.G_depth = G_depth
1202 ## Dimensionality of the latent space
1203 self.dim_z = dim_z
1204 ## The initial spatial dimensions
1205 self.bottom_width = bottom_width
 1206 ## The initial horizontal dimension
1207 self.H_base = H_base
1208 ## Resolution of the output
1209 self.resolution = resolution
1210 ## Kernel size?
1211 self.kernel_size = G_kernel_size
1212 ## Attention?
1213 self.attention = G_attn
1214 ## number of classes, for use in categorical conditional generation
1215 self.n_classes = n_classes
1216 ## Use shared embeddings?
1217 self.G_shared = G_shared
1218 ## Dimensionality of the shared embedding? Unused if not using G_shared
1219 self.shared_dim = shared_dim if shared_dim > 0 else dim_z
1220 ## Hierarchical latent space?
1221 self.hier = hier
1222 ## Cross replica batchnorm?
1223 self.cross_replica = cross_replica
1224 ## Use my batchnorm?
1225 self.mybn = mybn
1226 # nonlinearity for residual blocks
1227 if G_activation == "inplace_relu":
1228 ## activation
1229 self.activation = torch.nn.ReLU(inplace=True)
1230 elif G_activation == "relu":
1231 self.activation = torch.nn.ReLU(inplace=False)
1232 elif G_activation == "leaky_relu":
1233 self.activation = torch.nn.LeakyReLU(0.2, inplace=False)
1234 else:
1235 raise NotImplementedError("activation function not implemented")
1236 ## Initialization style
1237 self.init = G_init
1238 ## Parameterization style
1239 self.G_param = G_param
1240 ## Normalization style
1241 self.norm_style = norm_style
1242 ## Epsilon for BatchNorm?
1243 self.BN_eps = BN_eps
1244 ## Epsilon for Spectral Norm?
1245 self.SN_eps = SN_eps
1246 ## fp16?
1247 self.fp16 = G_fp16
1248 ## Architecture dict
1249 self.arch = G_arch(self.ch, self.attention)[resolution]
1250 ## RRM_prx_G
1251 self.RRM_prx_G = RRM_prx_G
1252 ## n_head_G
1253 self.n_head_G = n_head_G
1254
1255 # Which convs, batchnorms, and linear layers to use
1256 if self.G_param == "SN":
1257 ## which conv
1258 self.which_conv = functools.partial(
1259 SNConv2d,
1260 kernel_size=3,
1261 padding=1,
1262 num_svs=num_G_SVs,
1263 num_itrs=num_G_SV_itrs,
1264 eps=self.SN_eps,
1265 )
1266 ## which linear
1267 self.which_linear = functools.partial(
1268 SNLinear,
1269 num_svs=num_G_SVs,
1270 num_itrs=num_G_SV_itrs,
1271 eps=self.SN_eps,
1272 )
1273 else:
1274 self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
1275 self.which_linear = nn.Linear
1276
1277 # We use a non-spectral-normed embedding here regardless;
1278 # For some reason applying SN to G's embedding seems to randomly cripple G # noqa
1279 ## which embedding
1280 self.which_embedding = nn.Embedding
1281 bn_linear = (
1282 functools.partial(self.which_linear, bias=False)
1283 if self.G_shared
1284 else self.which_embedding
1285 )
1286 ## which bn
1287 self.which_bn = functools.partial(
1288 ccbn,
1289 which_linear=bn_linear,
1290 cross_replica=self.cross_replica,
1291 mybn=self.mybn,
1292 input_size=(
1293 self.shared_dim + self.dim_z if self.G_shared else self.n_classes
1294 ),
1295 norm_style=self.norm_style,
1296 eps=self.BN_eps,
1297 )
1298 ## shared
1299 self.shared = (
1300 self.which_embedding(n_classes, self.shared_dim)
1301 if G_shared
1302 else identity()
1303 )
1304
1305 if self.RRM_prx_G:
1306 ## RRM on proxy embeddings
1307 self.RR_G = RelationalReasoning(
1308 num_layers=1,
1309 input_dim=128,
1310 dim_feedforward=128,
1311 which_linear=nn.Linear,
1312 num_heads=self.n_head_G,
1313 dropout=0.0,
1314 hidden_dim=128,
1315 )
1316
1317 ## First linear layer
1318 self.linear = self.which_linear(
1319 self.dim_z + self.shared_dim,
1320 self.arch["in_channels"][0] * ((self.bottom_width**2) * self.H_base),
1321 )
1322
1323 # self.blocks is a doubly-nested list of modules, the outer loop intended # noqa
1324 # to be over blocks at a given resolution (resblocks and/or self-attention) # noqa
1325 # while the inner loop is over a given block
1326 ## blocks
1327 self.blocks = []
1328 for index in range(len(self.arch["out_channels"])):
1329 self.blocks += [
1330 [
1331 GBlock(
1332 in_channels=self.arch["in_channels"][index],
1333 out_channels=self.arch["in_channels"][index]
1334 if g_index == 0
1335 else self.arch["out_channels"][index],
1336 which_conv=self.which_conv,
1337 which_bn=self.which_bn,
1338 activation=self.activation,
1339 upsample=(
1340 functools.partial(F.interpolate, scale_factor=2)
1341 if self.arch["upsample"][index]
1342 and g_index == (self.G_depth - 1)
1343 else None
1344 ),
1345 )
1346 ]
1347 for g_index in range(self.G_depth)
1348 ]
1349
1350 # If attention on this block, attach it to the end
1351 if self.arch["attention"][self.arch["resolution"][index]]:
1352 print(
1353 f"Adding attention layer in G at resolution {self.arch['resolution'][index]:d}"
1354 )
1355
1356 if attn_type == "sa":
1357 self.blocks[-1] += [
1358 Attention(self.arch["out_channels"][index], self.which_conv)
1359 ]
1360 elif attn_type == "cbam":
1361 self.blocks[-1] += [
1362 CBAM_attention(
1363 self.arch["out_channels"][index], self.which_conv
1364 )
1365 ]
1366 elif attn_type == "ila":
1367 self.blocks[-1] += [ILA(self.arch["out_channels"][index])]
1368
1369 # Turn self.blocks into a ModuleList so that it's all properly registered. # noqa
1370 self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
1371
1372 # output layer: batchnorm-relu-conv.
1373 # Consider using a non-spectral conv here
1374 ## output layer
1375 self.output_layer = nn.Sequential(
1376 bn(
1377 self.arch["out_channels"][-1],
1378 cross_replica=self.cross_replica,
1379 mybn=self.mybn,
1380 ),
1381 self.activation,
1382 self.which_conv(self.arch["out_channels"][-1], 1),
1383 )
1384
1385 # Initialize weights. Optionally skip init for testing.
1386 if not skip_init:
1387 self.init_weights()
1388
1389 # Set up optimizer
1390 # If this is an EMA copy, no need for an optim, so just return now
1391 if no_optim:
1392 return
1393 ## lr
1394 self.lr = G_lr
1395 ## B1
1396 self.B1 = G_B1
1397 ## B2
1398 self.B2 = G_B2
1399 ## adam_eps
1400 self.adam_eps = adam_eps
1401 if G_mixed_precision:
1402 print("Using fp16 adam in G...")
1403 import utils
1404
1405 self.optim = utils.Adam16(
1406 params=self.parameters(),
1407 lr=self.lr,
1408 betas=(self.B1, self.B2),
1409 weight_decay=0,
1410 eps=self.adam_eps,
1411 )
1412
 else:
 1413 ## optim
 1414 self.optim = optim.Adam(
1415 params=self.parameters(),
1416 lr=self.lr,
1417 betas=(self.B1, self.B2),
1418 weight_decay=0,
1419 eps=self.adam_eps,
1420 )
1421 # LR scheduling
1422 if sched_version == "default":
1423 ## lr sched
1424 self.lr_sched = None
1425 elif sched_version == "CosAnnealLR":
1426 self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
1427 self.optim,
1428 T_max=kwargs["num_epochs"],
1429 eta_min=self.lr / 4,
1430 last_epoch=-1,
1431 )
1432 elif sched_version == "CosAnnealWarmRes":
1433 self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
1434 self.optim, T_0=10, T_mult=2, eta_min=self.lr / 4
1435 )
1436 else:
1437 self.lr_sched = None
1438
1439 ## Initialize
1440 def init_weights(self):
1441 ## parameter count
1442 self.param_count = 0
1443 for module in self.modules():
1444 if (
1445 isinstance(module, nn.Conv2d)
1446 or isinstance(module, nn.Linear)
1447 or isinstance(module, nn.Embedding)
1448 ):
1449 if self.init == "ortho":
1450 init.orthogonal_(module.weight)
1451 elif self.init == "N02":
1452 init.normal_(module.weight, 0, 0.02)
1453 elif self.init in ["glorot", "xavier"]:
1454 init.xavier_uniform_(module.weight)
1455 else:
1456 print("Init style not recognized...")
1457 self.param_count += sum(
1458 [p.data.nelement() for p in module.parameters()]
1459 )
1460 print(f"Param count for G's initialized parameters: {self.param_count}")
1461
1462 ## forward
1463 def forward(self, z, y):
1464 y = self.shared(y)
1465 # If relational embedding
1466 if self.RRM_prx_G:
1467 y = self.RR_G(y.unsqueeze(0)).squeeze(0)
1468 # y = F.normalize(y, dim=1)
1469 # If hierarchical, concatenate zs and ys
1470 if self.hier: # y and z are [bs,128] dimensional
1471 z = torch.cat([y, z], 1)
1472 y = z
1473 # First linear layer
1474 h = self.linear(z) # ([bs,256]-->[bs,24576])
1475 # Reshape
1476 h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width * self.H_base)
1477 # Loop over blocks
1478 for _, blocklist in enumerate(self.blocks):
1479 # Second inner loop in case block has multiple layers
1480 for block in blocklist:
1481 h = block(h, y)
1482
1483 # Apply batchnorm-relu-conv-tanh at output
1484 return torch.tanh(self.output_layer(h))
1485
1486
1487class Model(Generator):
1488 """
1489 Generator subclass
 1490 initialized by default with the module-level CONFIG dict
1491 """
1492
1493 ## Constructor
1494 def __init__(self):
1495 super().__init__(**CONFIG)
1496
1497
1498def generate(model: nn.Module):
1499 """
1500 Run inference with the provided Generator model
1501
1502 Args:
1503 model (nn.Module): Generator model
1504
1505 Returns:
1506 torch.Tensor: batch of 40 PXD images
1507 """
1508 device = next(model.parameters()).device
1509 with torch.no_grad():
1510 latents = torch.randn(40, 128, device=device)
1511 labels = torch.tensor(list(range(40)), dtype=torch.long, device=device)
1512 imgs = model(latents, labels).detach().cpu()
1513 # Cut the noise below 7 ADU
1514 imgs = F.threshold(imgs, -0.26, -1)
1515 # center range [-1, 1] to [0, 1]
1516 imgs = imgs.mul_(0.5).add_(0.5)
1517 # renormalize and convert to uint8
1518 imgs = torch.pow(256, imgs).add_(-1).clamp_(0, 255).to(torch.uint8)
1519 # flatten channel dimension and crop 256 to 250
1520 imgs = imgs[:, 0, 3:-3, :]
1521 return imgs
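

# Illustrative end-to-end sketch (not part of the original module): with
# freshly initialized, untrained weights the output is noise; in real use,
# trained weights would be loaded into the model before calling generate
# (checkpoint handling is outside the scope of this module).
def _example_generate():
    model = Model()
    model.eval()
    imgs = generate(model)
    return imgs.shape  # torch.Size([40, 250, 768]), dtype torch.uint8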