Belle II Software development
ieagan.py
1
8"""
9This module implements the IEA-GAN generator model.
10"""
11
12import functools
13import json
14import math
15
16import torch
17from torch import nn
18from torch import optim
19from torch.nn import init
20from torch.nn import Parameter as P
21import torch.nn.functional as F
22
23
24CONFIG = json.loads(
25 """{
26 "num_workers": 8,
27 "seed": 415,
28 "pin_memory": false,
29 "shuffle": true,
30 "augment": 0,
31 "use_multiepoch_sampler": false,
32 "model": "BigGAN_deep",
33 "G_ch": 32,
34 "G_param" : "SN",
35 "D_param" : "SN",
36 "D_ch": 32,
37 "G_depth": 2,
38 "D_depth": 2,
39 "H_base": 3,
40 "D_wide": true,
41 "G_shared": true,
42 "shared_dim": 128,
43 "dim_z": 128,
44 "z_var": 1.0,
45 "hier": true,
46 "cross_replica": false,
47 "mybn": false,
48 "G_activation": "inplace_relu",
49 "D_activation": "inplace_relu",
50 "G_attn": "0",
51 "D_attn": "0",
52 "norm_style": "bn",
53 "G_init": "ortho",
54 "D_init": "ortho",
55 "skip_init": false,
56 "G_lr": 5e-05,
57 "D_lr": 5e-05,
58 "G_B1": 0.0,
59 "D_B1": 0.0,
60 "G_B2": 0.999,
61 "D_B2": 0.999,
62 "batch_size": 40,
63 "G_batch_size": 0,
64 "num_G_accumulations": 1,
65 "num_D_steps": 1,
66 "num_D_accumulations": 1,
67 "split_D": true,
68 "num_epochs": 4,
69 "parallel": false,
70 "G_fp16": false,
71 "D_fp16": false,
72 "D_mixed_precision": false,
73 "G_mixed_precision": false,
74 "accumulate_stats": false,
75 "num_standing_accumulations": 16,
76 "G_eval_mode": true,
77 "save_every": 1000,
78 "test_every": 1000,
79 "num_save_copies": 2,
80 "num_best_copies": 2,
81 "ema": true,
82 "ema_decay": 0.9999,
83 "use_ema": true,
84 "ema_start": 10000,
85 "adam_eps": 1e-06,
86 "BN_eps": 1e-05,
87 "SN_eps": 1e-06,
88 "num_G_SVs": 1,
89 "num_D_SVs": 1,
90 "num_G_SV_itrs": 1,
91 "num_D_SV_itrs": 1,
92 "G_ortho": 0.0001,
93 "D_ortho": 0.0,
94 "toggle_grads": true,
95 "logstyle": "%3.3e",
96 "sv_log_interval": 10,
97 "log_interval": 100,
98 "resolution": 256,
99 "n_classes": 40,
100 "run_name": "BGd_140",
101 "resume": false,
102 "latent_op": false,
103 "latent_reg_weight": 300,
104 "bottom_width": 4,
105 "add_blur" : false,
106 "add_noise": true,
107 "add_style": false,
108 "conditional_strategy": "Contra",
109 "hypersphere_dim": 1024,
110 "pos_collected_numerator": false,
111 "nonlinear_embed": false,
112 "normalize_embed": true,
113 "inv_stereographic" :false,
114 "contra_lambda": 1.0,
115 "Angle": false,
116 "angle_lambda": 1.0,
117 "IEA_loss": true,
118 "IEA_lambda": 1.0,
119 "Uniformity_loss": true,
120 "unif_lambda": 0.1,
121 "diff_aug": true,
122 "Con_reg": false,
123 "cr_lambda": 10,
124 "pixel_reg": false,
125 "px_lambda": 1.0,
126 "RRM_prx_G": true,
127 "normalized_proxy_G": false,
128 "RRM_prx_D": false,
129 "RRM_embed": true,
130 "n_head_G": 2,
131 "n_head": 4,
132 "attn_type": "sa",
133 "sched_version": "default",
134 "z_dist": "normal",
135 "truncated_threshold": 1.0,
136 "clip_norm": "None",
137 "amsgrad": false,
138 "arch": "None",
139 "G_kernel_size": 3,
140 "D_kernel_size": 3,
141 "ada_belief": false,
142 "pbar": "tqdm",
143 "which_best": "FID",
144 "stop_after": 100000,
145 "trunc_z": 0.5,
146 "denoise": false,
147 "metric_log_name": "metric_log.jsonl",
148 "reinitialize_metric_logs": false,
149 "reinitialize_parameter_logs": false,
150 "num_incep_images": 16000,
151 "load_optim": true}"""
152)
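The dictionary above is the frozen training configuration of the IEA-GAN model; it is consumed further down by unpacking it into the Generator constructor, so keys that the constructor does not name are simply absorbed by **kwargs. A minimal sketch of that pattern (assuming an inference-only copy, so weight re-initialization and optimizer setup are skipped):

inference_cfg = dict(CONFIG, skip_init=True, no_optim=True)  # override selected entries
netG = Generator(**inference_cfg)                            # Generator is defined below in this module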
153
154
155def proj(x, y):
156 """
157 Projection of x onto y
158 """
159 return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
160
161
162def gram_schmidt(x, ys):
163 """
164 Orthogonalize x wrt list of vectors ys
165 """
166 for y in ys:
167 x = x - proj(x, y)
168 return x
169
170
171def power_iteration(W, u_, update=True, eps=1e-12):
172 """
173 Apply num_itrs steps of the power method to estimate top N singular values.
174 """
175 # Lists holding singular vectors and values
176 us, vs, svs = [], [], []
177 for i, u in enumerate(u_):
178 # Run one step of the power iteration
179 with torch.no_grad():
180 v = torch.matmul(u, W)
181 # Run Gram-Schmidt to subtract components of all other singular vectors # noqa
182 v = F.normalize(gram_schmidt(v, vs), eps=eps)
183 # Add to the list
184 vs += [v]
185 # Update the other singular vector
186 u = torch.matmul(v, W.t())
187 # Run Gram-Schmidt to subtract components of all other singular vectors # noqa
188 u = F.normalize(gram_schmidt(u, us), eps=eps)
189 # Add to the list
190 us += [u]
191 if update:
192 u_[i][:] = u
193 # Compute this singular value and add it to the list
194 svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
195 # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
196 return svs, us, vs
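Repeated calls refine the singular-vector buffers in place, so the estimate converges to the leading singular value; a quick self-contained check (the variable names below are illustrative only):

W = torch.randn(64, 32)
u = [torch.randn(1, 64)]              # one u-vector, as the SN class registers it
for _ in range(100):                  # each call performs one power-iteration step
    svs, _, _ = power_iteration(W, u, update=True)
print(float(svs[0]), float(torch.linalg.svdvals(W)[0]))  # the two values should nearly agree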
197
198
199def groupnorm(x, norm_style):
200 """
201 Simple function to handle groupnorm norm stylization
202 """
203 # If number of channels specified in norm_style:
204 if "ch" in norm_style:
205 ch = int(norm_style.split("_")[-1])
206 groups = max(int(x.shape[1]) // ch, 1)
207 # If number of groups specified in norm style
208 elif "grp" in norm_style:
209 groups = int(norm_style.split("_")[-1])
210 # If neither, default to groups = 16
211 else:
212 groups = 16
213 return F.group_norm(x, groups)
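The norm_style string therefore encodes either a fixed number of channels per group ("ch_<n>") or a fixed number of groups ("grp_<n>"), with 16 groups as the fallback; for example (illustrative values):

x = torch.randn(2, 64, 8, 8)
out_a = groupnorm(x, "ch_16")   # 64 channels / 16 per group -> 4 groups
out_b = groupnorm(x, "grp_8")   # exactly 8 groups, independent of channel count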
214
215
216class identity(nn.Module):
217 """
218 Convenience passthrough function
219 """
220
221
222 def forward(self, tensor: torch.Tensor):
223 return tensor
224
225
226class SN(object):
227 """
228 Spectral normalization base class
229 """
230 # pylint: disable=no-member
231
232 def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
233 """constructor"""
234
235
236 self.num_itrs = num_itrs
237
238 self.num_svs = num_svs
239
240 self.transpose = transpose
241
242 self.eps = eps
243 # Register a singular vector for each sv
244 for i in range(self.num_svs):
245 self.register_buffer(f"u{i:d}", torch.randn(1, num_outputs))
246 self.register_buffer(f"sv{i:d}", torch.ones(1))
247
248 @property
249 def u(self):
250 """
251 Singular vectors (u side)
252 """
253 return [getattr(self, f"u{i:d}") for i in range(self.num_svs)]
254
255 @property
256 def sv(self):
257 """
258 Singular values
259 note that these buffers are just for logging and are not used in training.
260 """
261 return [getattr(self, f"sv{i:d}") for i in range(self.num_svs)]
262
263 def W_(self):
264 """
265 Compute the spectrally-normalized weight
266 """
267 W_mat = self.weight.view(self.weight.size(0), -1)
268 if self.transpose:
269 W_mat = W_mat.t()
270 # Apply num_itrs power iterations
271 for _ in range(self.num_itrs):
272 svs, _, _ = power_iteration(
273 W_mat, self.u, update=self.training, eps=self.eps
274 )
275 # Update the svs
276 if self.training:
277 # Make sure to do this in a no_grad() context or you'll get memory leaks! # noqa
278 with torch.no_grad():
279 for i, sv in enumerate(svs):
280 self.sv[i][:] = sv
281 return self.weight / svs[0]
282
283
284class SNConv2d(nn.Conv2d, SN):
285 """
286 2D Conv layer with spectral norm
287 """
288
289
290 def __init__(
291 self,
292 in_channels,
293 out_channels,
294 kernel_size,
295 stride=1,
296 padding=0,
297 dilation=1,
298 groups=1,
299 bias=True,
300 num_svs=1,
301 num_itrs=1,
302 eps=1e-12,
303 ):
304 nn.Conv2d.__init__(
305 self,
306 in_channels,
307 out_channels,
308 kernel_size,
309 stride,
310 padding,
311 dilation,
312 groups,
313 bias,
314 )
315 SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
316
317
318 def forward(self, x):
319 return F.conv2d(
320 x,
321 self.W_(),
322 self.bias,
323 self.stride,
324 self.padding,
325 self.dilation,
326 self.groups,
327 )
328
329
330class SNLinear(nn.Linear, SN):
331 """
332 Linear layer with spectral norm
333 """
334
335
336 def __init__(
337 self,
338 in_features,
339 out_features,
340 bias=True,
341 num_svs=1,
342 num_itrs=1,
343 eps=1e-12,
344 ):
345 nn.Linear.__init__(self, in_features, out_features, bias)
346 SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
347
348
349 def forward(self, x):
350 return F.linear(x, self.W_(), self.bias)
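Both SN wrappers are drop-in replacements for their torch.nn counterparts: every forward pass divides the weight by the current top-singular-value estimate, and the estimate itself is only refreshed while the module is in training mode. A small usage sketch:

layer = SNLinear(128, 64)              # used like nn.Linear(128, 64)
out = layer(torch.randn(16, 128))      # weight is spectrally normalized on the fly
print(out.shape, layer.sv[0].item())   # sv buffers are kept for logging only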
351
352
353def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
354 """Fused batchnorm op"""
355
356 # Apply scale and shift--if gain and bias are provided, fuse them here
357 # Prepare scale
358 scale = torch.rsqrt(var + eps)
359 # If a gain is provided, use it
360 if gain is not None:
361 scale = scale * gain
362 # Prepare shift
363 shift = mean * scale
364 # If bias is provided, use it
365 if bias is not None:
366 shift = shift - bias
367 return x * scale - shift
368 # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. # noqa
369
370
371def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
372 """
373 Manual BN
374 Calculate means and variances using mean-of-squares minus mean-squared
375 """
376
377 # Cast x to float32 if necessary
378 float_x = x.float()
379 # Calculate expected value of x (m) and expected value of x**2 (m2)
380 # Mean of x
381 m = torch.mean(float_x, [0, 2, 3], keepdim=True)
382 # Mean of x squared
383 m2 = torch.mean(float_x**2, [0, 2, 3], keepdim=True)
384 # Calculate variance as mean of squared minus mean squared.
385 var = m2 - m**2
386 # Cast back to float 16 if necessary
387 var = var.type(x.type())
388 m = m.type(x.type())
389 # Return mean and variance for updating stored mean/var if requested
390 if return_mean_var:
391 return (
392 fused_bn(x, m, var, gain, bias, eps),
393 m.squeeze(),
394 var.squeeze(),
395 )
396 else:
397 return fused_bn(x, m, var, gain, bias, eps)
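With gain and bias left at None this reduces to plain batch normalization over the (N, H, W) axes, so it can be cross-checked against F.batch_norm in training mode (a numerical sanity check, not part of the module):

x = torch.randn(8, 16, 4, 4)
ref = F.batch_norm(x, None, None, training=True)
print(torch.allclose(manual_bn(x), ref, atol=1e-5))  # expected: True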
398
399
400class myBN(nn.Module):
401 """
402 My batchnorm, supports standing stats
403 """
404
405
406 def __init__(self, num_channels, eps=1e-5, momentum=0.1):
407 super(myBN, self).__init__()
408
409 self.momentum = momentum
410
411 self.eps = eps
412 # Momentum
413 self.momentum = momentum
414 # Register buffers
415 self.register_buffer("stored_mean", torch.zeros(num_channels))
416 self.register_buffer("stored_var", torch.ones(num_channels))
417 self.register_buffer("accumulation_counter", torch.zeros(1))
418
419 self.accumulate_standing = False
420
421
422 def reset_stats(self):
423 # pylint: disable=no-member
424 self.stored_mean[:] = 0
425 self.stored_var[:] = 0
426 self.accumulation_counter[:] = 0
427
428
429 def forward(self, x, gain, bias):
430 # pylint: disable=no-member
431 if self.training:
432 out, mean, var = manual_bn(
433 x, gain, bias, return_mean_var=True, eps=self.eps
434 )
435 # If accumulating standing stats, increment them
436 if self.accumulate_standing:
437 self.stored_mean[:] = self.stored_mean + mean.data
438 self.stored_var[:] = self.stored_var + var.data
439 self.accumulation_counter += 1.0
440 # If not accumulating standing stats, take running averages
441 else:
442 self.stored_mean[:] = (
443 self.stored_mean * (1 - self.momentum) + mean * self.momentum
444 )
445 self.stored_var[:] = (
446 self.stored_var * (1 - self.momentum) + var * self.momentum
447 )
448 return out
449 # If not in training mode, use the stored statistics
450 else:
451 mean = self.stored_mean.view(1, -1, 1, 1)
452 var = self.stored_var.view(1, -1, 1, 1)
453 # If using standing stats, divide them by the accumulation counter
454 if self.accumulate_standing:
455 mean = mean / self.accumulation_counter
456 var = var / self.accumulation_counter
457 return fused_bn(x, mean, var, gain, bias, self.eps)
458
459
460class bn(nn.Module):
461 """
462 Normal, non-class-conditional BN
463 """
464
465
466 def __init__(
467 self,
468 output_size,
469 eps=1e-5,
470 momentum=0.1,
471 cross_replica=False,
472 mybn=False,
473 ):
474 super(bn, self).__init__()
475
476 self.output_size = output_size
477
478 self.gain = P(torch.ones(output_size), requires_grad=True)
479
480 self.bias = P(torch.zeros(output_size), requires_grad=True)
481
482 self.eps = eps
483
484 self.momentum = momentum
485
486 self.cross_replica = cross_replica
487
488 self.mybn = mybn
489
490 if mybn:
491 self.bn = myBN(output_size, self.eps, self.momentum)
492 # Register buffers if neither of the above
493 else:
494 self.register_buffer("stored_mean", torch.zeros(output_size))
495 self.register_buffer("stored_var", torch.ones(output_size))
496
497
498 def forward(self, x):
499 if self.mybn:
500 gain = self.gain.view(1, -1, 1, 1)
501 bias = self.bias.view(1, -1, 1, 1)
502 return self.bn(x, gain=gain, bias=bias)
503 else:
504 return F.batch_norm(
505 x,
506 self.stored_mean,
507 self.stored_var,
508 self.gain,
509 self.bias,
510 self.training,
511 self.momentum,
512 self.eps,
513 )
514
515
516class ccbn(nn.Module):
517 """
518 Class-conditional bn
519 output size is the number of channels, input size is for the linear layers
520 Andy's Note: this class feels messy but I'm not really sure how to clean it up # noqa
521 Suggestions welcome! (By which I mean, refactor this and make a merge request
522 if you want to make this more readable/usable).
523 """
524
525
526 def __init__(
527 self,
528 output_size,
529 input_size,
530 which_linear,
531 eps=1e-5,
532 momentum=0.1,
533 cross_replica=False,
534 mybn=False,
535 norm_style="bn",
536 ):
537 super(ccbn, self).__init__()
538
539 self.output_size, self.input_size = output_size, input_size
540
541 self.gain = which_linear(input_size, output_size)
542
543 self.bias = which_linear(input_size, output_size)
544
545 self.eps = eps
546
547 self.momentum = momentum
548
549 self.cross_replica = cross_replica
550
551 self.mybn = mybn
552
553 self.norm_style = norm_style
554
555 if self.mybn:
556
557 self.bn = myBN(output_size, self.eps, self.momentum)
558 elif self.norm_style in ["bn", "in"]:
559 self.register_buffer("stored_mean", torch.zeros(output_size))
560 self.register_buffer("stored_var", torch.ones(output_size))
561
562
563 def forward(self, x, y):
564 # Calculate class-conditional gains and biases
565 gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
566 bias = self.bias(y).view(y.size(0), -1, 1, 1)
567 # If using my batchnorm
568 if self.mybn:
569 return self.bn(x, gain=gain, bias=bias)
570 # else:
571 else:
572 if self.norm_style == "bn":
573 out = F.batch_norm(
574 x,
575 self.stored_mean,
576 self.stored_var,
577 None,
578 None,
579 self.training,
580 0.1,
581 self.eps,
582 )
583 elif self.norm_style == "in":
584 out = F.instance_norm(
585 x,
586 self.stored_mean,
587 self.stored_var,
588 None,
589 None,
590 self.training,
591 0.1,
592 self.eps,
593 )
594 elif self.norm_style == "gn":
595 out = groupnorm(x, self.norm_style)
596 elif self.norm_style == "nonorm":
597 out = x
598 return out * gain + bias
599
600
601 def extra_repr(self):
602 s = "out: {output_size}, in: {input_size},"
603 s += " cross_replica={cross_replica}"
604 return s.format(**self.__dict__)
605
606
607class ILA(nn.Module):
608 """
609 Image_Linear_Attention
610 """
611
612
613 def __init__(
614 self,
615 chan,
616 chan_out=None,
617 kernel_size=1,
618 padding=0,
619 stride=1,
620 key_dim=32,
621 value_dim=64,
622 heads=8,
623 norm_queries=True,
624 ):
625 super().__init__()
626
627 self.chan = chan
628 chan_out = chan if chan_out is None else chan_out
629
630
631 self.key_dim = key_dim
632
633 self.value_dim = value_dim
634
635 self.heads = heads
636
637
638 self.norm_queries = norm_queries
639
640 conv_kwargs = {"padding": padding, "stride": stride}
641
642 self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
643
644 self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
645
646 self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)
647
648 out_conv_kwargs = {"padding": padding}
649
650 self.to_out = nn.Conv2d(
651 value_dim * heads, chan_out, kernel_size, **out_conv_kwargs
652 )
653
654
655 def forward(self, x, context=None):
656 b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads
657
658 q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))
659
660 q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))
661
662 q, k = map(lambda x: x * (self.key_dim**-0.25), (q, k))
663
664 if context is not None:
665 context = context.reshape(b, c, 1, -1)
666 ck, cv = self.to_k(context), self.to_v(context)
667 ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
668 k = torch.cat((k, ck), dim=3)
669 v = torch.cat((v, cv), dim=3)
670
671 k = k.softmax(dim=-1)
672
673 if self.norm_queries:
674 q = q.softmax(dim=-2)
675
676 context = torch.einsum("bhdn,bhen->bhde", k, v)
677 out = torch.einsum("bhdn,bhde->bhen", q, context)
678 out = out.reshape(b, -1, h, w)
679 out = self.to_out(out)
680 return out
681
682
683class CBAM_attention(nn.Module):
684 """CBAM attention"""
685
686
687 def __init__(
688 self,
689 channels,
690 which_conv=SNConv2d,
691 reduction=8,
692 attention_kernel_size=3,
693 ):
694 super(CBAM_attention, self).__init__()
695
696 self.avg_pool = nn.AdaptiveAvgPool2d(1)
697
698 self.max_pool = nn.AdaptiveMaxPool2d(1)
699
700 self.fc1 = which_conv(
701 channels, channels // reduction, kernel_size=1, padding=0
702 )
703
704 self.relu = nn.ReLU(inplace=True)
705
706 self.fc2 = which_conv(
707 channels // reduction, channels, kernel_size=1, padding=0
708 )
709
710 self.sigmoid_channel = nn.Sigmoid()
711
712 self.conv_after_concat = which_conv(
713 2,
714 1,
715 kernel_size=attention_kernel_size,
716 stride=1,
717 padding=attention_kernel_size // 2,
718 )
719
720 self.sigmoid_spatial = nn.Sigmoid()
721
722
723 def forward(self, x):
724 # Channel attention module
725 module_input = x
726 avg = self.avg_pool(x)
727 mx = self.max_pool(x)
728 avg = self.fc1(avg)
729 mx = self.fc1(mx)
730 avg = self.relu(avg)
731 mx = self.relu(mx)
732 avg = self.fc2(avg)
733 mx = self.fc2(mx)
734 x = avg + mx
735 x = self.sigmoid_channel(x)
736 # Spatial attention module
737 x = module_input * x
738 module_input = x
739 # b, c, h, w = x.size()
740 avg = torch.mean(x, 1, True)
741 mx, _ = torch.max(x, 1, True)
742 x = torch.cat((avg, mx), 1)
743 x = self.conv_after_concat(x)
744 x = self.sigmoid_spatial(x)
745 x = module_input * x
746 return x
747
748
749class Attention(nn.Module):
750 """Attention"""
751
752
753 def __init__(self, ch, which_conv=SNConv2d):
754 super(Attention, self).__init__()
755
756 self.ch = ch
757
758 self.which_conv = which_conv
759
760 self.theta = self.which_conv(
761 self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
762 )
763
764 self.phi = self.which_conv(
765 self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
766 )
767
768 self.g = self.which_conv(
769 self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
770 )
771
772 self.o = self.which_conv(
773 self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
774 )
775
776 self.gamma = P(torch.tensor(0.0), requires_grad=True)
777
778
779 def forward(self, x):
780 # Apply convs
781 theta = self.theta(x)
782 phi = F.max_pool2d(self.phi(x), [2, 2])
783 g = F.max_pool2d(self.g(x), [2, 2])
784 # Perform reshapes
785 theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
786 phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
787 g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
788 # Matmul and softmax to get attention maps
789 beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
790 # Attention map times g path
791 o = self.o(
792 torch.bmm(g, beta.transpose(1, 2)).view(
793 -1, self.ch // 2, x.shape[2], x.shape[3]
794 )
795 )
796 return self.gamma * o + x
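The block is the usual self-attention (non-local) layer with a learnable residual gate gamma initialized to zero, so at the start of training it acts as an identity; it preserves the input shape:

attn = Attention(64)
x = torch.randn(2, 64, 32, 32)
print(attn(x).shape)   # torch.Size([2, 64, 32, 32])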
797
798
799class SNEmbedding(nn.Embedding, SN):
800 """
801 Embedding layer with spectral norm
802 We use num_embeddings as the dim instead of embedding_dim here
803 for convenience sake
804 """
805
806
807 def __init__(
808 self,
809 num_embeddings,
810 embedding_dim,
811 padding_idx=None,
812 max_norm=None,
813 norm_type=2,
814 scale_grad_by_freq=False,
815 sparse=False,
816 _weight=None,
817 num_svs=1,
818 num_itrs=1,
819 eps=1e-12,
820 ):
821 nn.Embedding.__init__(
822 self,
823 num_embeddings,
824 embedding_dim,
825 padding_idx,
826 max_norm,
827 norm_type,
828 scale_grad_by_freq,
829 sparse,
830 _weight,
831 )
832 SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
833
834
835 def forward(self, x):
836 return F.embedding(x, self.W_())
837
838
839def scaled_dot_product(q, k, v):
840 d_k = q.size()[-1]
841 attn_logits = torch.matmul(q, k.transpose(-2, -1))
842 attn_logits = attn_logits / math.sqrt(d_k)
843 attention = F.softmax(attn_logits, dim=-1)
844 values = torch.matmul(attention, v)
845 return values, attention
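A shape-level sketch of this attention primitive as it is used by MultiheadAttention below, with batch and head dimensions folded into the two leading axes:

q = k = v = torch.randn(2, 4, 10, 16)          # [batch, heads, seq_len, head_dim]
values, attention = scaled_dot_product(q, k, v)
print(values.shape, attention.shape)           # [2, 4, 10, 16] and [2, 4, 10, 10]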
846
847
848class MultiheadAttention(nn.Module):
849 """MultiheadAttention"""
850
851
852 def __init__(self, input_dim, embed_dim, num_heads, which_linear):
853 super().__init__()
854 assert (
855 embed_dim % num_heads == 0
856 ), "Embedding dimension must be 0 modulo number of heads."
857
858
859 self.embed_dim = embed_dim
860
861 self.num_heads = num_heads
862
863 self.head_dim = embed_dim // num_heads
864
865 self.which_linear = which_linear
866
867 # Stack all weight matrices 1...h together for efficiency
868
869 self.qkv_proj = self.which_linear(input_dim, 3 * embed_dim)
870
871 self.o_proj = self.which_linear(embed_dim, embed_dim)
872
873 self._reset_parameters()
874
875
876 def _reset_parameters(self):
877 # Original Transformer initialization, see PyTorch documentation
878 nn.init.xavier_uniform_(self.qkv_proj.weight)
879 self.qkv_proj.bias.data.fill_(0)
880 nn.init.xavier_uniform_(self.o_proj.weight)
881 self.o_proj.bias.data.fill_(0)
882
883
884 def forward(self, x, return_attention=False):
885 batch_size, seq_length, embed_dim = x.size()
886 qkv = self.qkv_proj(x)
887
888 # Separate Q, K, V from linear output
889 qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
890 qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
891 q, k, v = qkv.chunk(3, dim=-1)
892
893 # Determine value outputs
894 values, attention = scaled_dot_product(q, k, v)
895 values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
896 values = values.reshape(batch_size, seq_length, embed_dim)
897 o = self.o_proj(values)
898
899 if return_attention:
900 return o, attention
901 else:
902 return o
903
904
905class EncoderBlock(nn.Module):
906 """EncoderBlock"""
907
908
909 def __init__(self, input_dim, num_heads, dim_feedforward, dropout, which_linear):
910 """
911 Inputs:
912 input_dim - Dimensionality of the input
913 num_heads - Number of heads to use in the attention block
914 dim_feedforward - Dimensionality of the hidden layer in the MLP
915 dropout - Dropout probability to use in the dropout layers
916 """
917 super().__init__()
918
919
920 self.which_linear = which_linear
921
922 self.self_attn = MultiheadAttention(
923 input_dim, input_dim, num_heads, which_linear
924 )
925
926
927 self.linear_net = nn.Sequential(
928 self.which_linear(input_dim, dim_feedforward),
929 nn.Dropout(dropout),
930 nn.ReLU(inplace=True),
931 self.which_linear(dim_feedforward, input_dim),
932 )
933
934 # Layers to apply in between the main layers
935
936 self.norm1 = nn.LayerNorm(input_dim)
937
938 self.norm2 = nn.LayerNorm(input_dim)
939
940 self.dropout = nn.Dropout(dropout)
941
942
943 def forward(self, x):
944 # Attention part
945 x_pre1 = self.norm1(x)
946 attn_out = self.self_attn(x_pre1)
947 x = x + self.dropout(attn_out)
948 # x = self.norm1(x)
949
950 # MLP part
951 x_pre2 = self.norm2(x)
952 linear_out = self.linear_net(x_pre2)
953 x = x + self.dropout(linear_out)
954 # x = self.norm2(x)
955
956 return x
957
958
959class RelationalReasoning(nn.Module):
960 """RelationalReasoning"""
961
962
963 def __init__(self, num_layers, hidden_dim, **block_args):
964 super().__init__()
965
966 self.layers = nn.ModuleList(
967 [EncoderBlock(**block_args) for _ in range(num_layers)]
968 )
969
970 self.norm = nn.LayerNorm(hidden_dim)
971
972
973 def forward(self, x):
974 for layer in self.layers:
975 x = layer(x)
976
977 x = self.norm(x)
978 return x
979
980
981 def get_attention_maps(self, x):
982 attention_maps = []
983 for layer in self.layers:
984 _, attn_map = layer.self_attn(x, return_attention=True)
985 attention_maps.append(attn_map)
986 x = layer(x)
987 return attention_maps
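This is the relational reasoning module (RRM) that the Generator applies to the class embeddings of one event; a shape sketch with the same hyperparameters as RR_G below (reading the 40 tokens as the per-sensor embeddings is an assumption based on n_classes=40):

rr = RelationalReasoning(num_layers=1, hidden_dim=128, input_dim=128, num_heads=2,
                         dim_feedforward=128, dropout=0.0, which_linear=nn.Linear)
tokens = torch.randn(1, 40, 128)   # [batch, tokens, features]
print(rr(tokens).shape)            # torch.Size([1, 40, 128])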
988
989
990class GBlock(nn.Module):
991 """GBlock"""
992
993
994 def __init__(
995 self,
996 in_channels,
997 out_channels,
998 which_conv=SNConv2d,
999 which_bn=bn,
1000 activation=None,
1001 upsample=None,
1002 channel_ratio=4,
1003 ):
1004 super(GBlock, self).__init__()
1005
1006
1007 self.in_channels, self.out_channels = in_channels, out_channels
1008
1009 self.hidden_channels = self.in_channels // channel_ratio
1010
1011 self.which_conv, self.which_bn = which_conv, which_bn
1012
1013 self.activation = activation
1014 # Conv layers
1015
1016 self.conv1 = self.which_conv(
1017 self.in_channels, self.hidden_channels, kernel_size=1, padding=0
1018 )
1019
1020 self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
1021
1022 self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
1023
1024 self.conv4 = self.which_conv(
1025 self.hidden_channels, self.out_channels, kernel_size=1, padding=0
1026 )
1027 # Batchnorm layers
1028
1029 self.bn1 = self.which_bn(self.in_channels)
1030
1031 self.bn2 = self.which_bn(self.hidden_channels)
1032
1033 self.bn3 = self.which_bn(self.hidden_channels)
1034
1035 self.bn4 = self.which_bn(self.hidden_channels)
1036
1037 self.upsample = upsample
1038
1039
1040 def forward(self, x, y):
1041 # Project down to channel ratio
1042 h = self.conv1(self.activation(self.bn1(x, y)))
1043 # Apply next BN-ReLU
1044 h = self.activation(self.bn2(h, y))
1045 # Drop channels in x if necessary
1046 if self.in_channels != self.out_channels:
1047 x = x[:, : self.out_channels]
1048 # Upsample both h and x at this point
1049 if self.upsample:
1050 h = self.upsample(h)
1051 x = self.upsample(x)
1052 # 3x3 convs
1053 h = self.conv2(h)
1054 h = self.conv3(self.activation(self.bn3(h, y)))
1055 # Final 1x1 conv
1056 h = self.conv4(self.activation(self.bn4(h, y)))
1057 return h + x
1058
1059
1060def G_arch(ch=64, attention="64"):
1061 arch = {}
1062 arch[512] = {
1063 "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
1064 "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
1065 "upsample": [True] * 7,
1066 "resolution": [8, 16, 32, 64, 128, 256, 512],
1067 "attention": {
1068 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1069 for i in range(3, 10)
1070 },
1071 }
1072 arch[256] = {
1073 "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2]],
1074 "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1]],
1075 "upsample": [True] * 6,
1076 "resolution": [8, 16, 32, 64, 128, 256],
1077 "attention": {
1078 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1079 for i in range(3, 9)
1080 },
1081 }
1082 arch[128] = {
1083 "in_channels": [ch * item for item in [16, 16, 8, 4, 2]],
1084 "out_channels": [ch * item for item in [16, 8, 4, 2, 1]],
1085 "upsample": [True] * 5,
1086 "resolution": [8, 16, 32, 64, 128],
1087 "attention": {
1088 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1089 for i in range(3, 8)
1090 },
1091 }
1092 arch[96] = {
1093 "in_channels": [ch * item for item in [16, 16, 8, 4]],
1094 "out_channels": [ch * item for item in [16, 8, 4, 2]],
1095 "upsample": [True] * 4,
1096 "resolution": [12, 24, 48, 96],
1097 "attention": {
1098 12 * 2**i: (6 * 2 ** i in [int(item) for item in attention.split("_")])
1099 for i in range(0, 4)
1100 },
1101 }
1102
1103 arch[64] = {
1104 "in_channels": [ch * item for item in [16, 16, 8, 4]],
1105 "out_channels": [ch * item for item in [16, 8, 4, 2]],
1106 "upsample": [True] * 4,
1107 "resolution": [8, 16, 32, 64],
1108 "attention": {
1109 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1110 for i in range(3, 7)
1111 },
1112 }
1113 arch[32] = {
1114 "in_channels": [ch * item for item in [4, 4, 4]],
1115 "out_channels": [ch * item for item in [4, 4, 4]],
1116 "upsample": [True] * 3,
1117 "resolution": [8, 16, 32],
1118 "attention": {
1119 2**i: (2 ** i in [int(item) for item in attention.split("_")])
1120 for i in range(3, 6)
1121 },
1122 }
1123
1124 return arch
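With the values in CONFIG (G_ch=32, resolution=256, G_attn="0") the generator selects the arch[256] entry: six upsampling stages and no attention resolution enabled; for example:

arch = G_arch(ch=32, attention="0")[256]
print(arch["in_channels"])              # [512, 512, 256, 256, 128, 64]
print(arch["out_channels"])             # [512, 256, 256, 128, 64, 32]
print(any(arch["attention"].values()))  # False -> no Attention blocks are inserted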
1125
1126
1127class Generator(nn.Module):
1128 """Generator"""
1129
1130
1131 def __init__(
1132 self,
1133 G_ch=64,
1134 G_depth=2,
1135 dim_z=128,
1136 bottom_width=4,
1137 resolution=256,
1138 G_kernel_size=3,
1139 G_attn="64",
1140 n_classes=40,
1141 H_base=1,
1142 num_G_SVs=1,
1143 num_G_SV_itrs=1,
1144 attn_type="sa",
1145 G_shared=True,
1146 shared_dim=128,
1147 hier=True,
1148 cross_replica=False,
1149 mybn=False,
1150 G_activation=nn.ReLU(inplace=False),
1151 G_lr=5e-5,
1152 G_B1=0.0,
1153 G_B2=0.999,
1154 adam_eps=1e-8,
1155 BN_eps=1e-5,
1156 SN_eps=1e-12,
1157 G_init="ortho",
1158 G_mixed_precision=False,
1159 G_fp16=False,
1160 skip_init=False,
1161 no_optim=False,
1162 sched_version="default",
1163 RRM_prx_G=True,
1164 n_head_G=2,
1165 G_param="SN",
1166 norm_style="bn",
1167 **kwargs
1168 ):
1169 super(Generator, self).__init__()
1170
1171 self.ch = G_ch
1172
1173 self.G_depth = G_depth
1174
1175 self.dim_z = dim_z
1176
1177 self.bottom_width = bottom_width
1178
1179 self.H_base = H_base
1180
1181 self.resolution = resolution
1182
1183 self.kernel_size = G_kernel_size
1184
1185 self.attention = G_attn
1186
1187 self.n_classes = n_classes
1188
1189 self.G_shared = G_shared
1190
1191 self.shared_dim = shared_dim if shared_dim > 0 else dim_z
1192
1193 self.hier = hier
1194
1195 self.cross_replica = cross_replica
1196
1197 self.mybn = mybn
1198 # nonlinearity for residual blocks
1199 if G_activation == "inplace_relu":
1200
1201 self.activation = torch.nn.ReLU(inplace=True)
1202 elif G_activation == "relu":
1203 self.activation = torch.nn.ReLU(inplace=False)
1204 elif G_activation == "leaky_relu":
1205 self.activation = torch.nn.LeakyReLU(0.2, inplace=False)
1206 else:
1207 raise NotImplementedError("activation function not implemented")
1208
1209 self.init = G_init
1210
1211 self.G_param = G_param
1212
1213 self.norm_style = norm_style
1214
1215 self.BN_eps = BN_eps
1216
1217 self.SN_eps = SN_eps
1218
1219 self.fp16 = G_fp16
1220
1221 self.arch = G_arch(self.ch, self.attention)[resolution]
1222
1223 self.RRM_prx_G = RRM_prx_G
1224
1225 self.n_head_G = n_head_G
1226
1227 # Which convs, batchnorms, and linear layers to use
1228 if self.G_param == "SN":
1229
1230 self.which_conv = functools.partial(
1231 SNConv2d,
1232 kernel_size=3,
1233 padding=1,
1234 num_svs=num_G_SVs,
1235 num_itrs=num_G_SV_itrs,
1236 eps=self.SN_eps,
1237 )
1238
1239 self.which_linear = functools.partial(
1240 SNLinear,
1241 num_svs=num_G_SVs,
1242 num_itrs=num_G_SV_itrs,
1243 eps=self.SN_eps,
1244 )
1245 else:
1246 self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
1247 self.which_linear = nn.Linear
1248
1249 # We use a non-spectral-normed embedding here regardless;
1250 # For some reason applying SN to G's embedding seems to randomly cripple G # noqa
1251
1252 self.which_embedding = nn.Embedding
1253 bn_linear = (
1254 functools.partial(self.which_linear, bias=False)
1255 if self.G_shared
1256 else self.which_embedding
1257 )
1258
1259 self.which_bn = functools.partial(
1260 ccbn,
1261 which_linear=bn_linear,
1262 cross_replica=self.cross_replica,
1263 mybn=self.mybn,
1264 input_size=(
1265 self.shared_dim + self.dim_z if self.G_shared else self.n_classes
1266 ),
1267 norm_style=self.norm_style,
1268 eps=self.BN_eps,
1269 )
1270
1271 self.shared = (
1272 self.which_embedding(n_classes, self.shared_dim)
1273 if G_shared
1274 else identity()
1275 )
1276
1277 if self.RRM_prx_G:
1278
1279 self.RR_G = RelationalReasoning(
1280 num_layers=1,
1281 input_dim=128,
1282 dim_feedforward=128,
1283 which_linear=nn.Linear,
1284 num_heads=self.n_head_G,
1285 dropout=0.0,
1286 hidden_dim=128,
1287 )
1288
1289
1290 self.linear = self.which_linear(
1291 self.dim_z + self.shared_dim,
1292 self.arch["in_channels"][0] * ((self.bottom_width**2) * self.H_base),
1293 )
1294
1295 # self.blocks is a doubly-nested list of modules, the outer loop intended # noqa
1296 # to be over blocks at a given resolution (resblocks and/or self-attention) # noqa
1297 # while the inner loop is over a given block
1298
1299 self.blocks = []
1300 for index in range(len(self.arch["out_channels"])):
1301 self.blocks += [
1302 [
1303 GBlock(
1304 in_channels=self.arch["in_channels"][index],
1305 out_channels=self.arch["in_channels"][index]
1306 if g_index == 0
1307 else self.arch["out_channels"][index],
1308 which_conv=self.which_conv,
1309 which_bn=self.which_bn,
1310 activation=self.activation,
1311 upsample=(
1312 functools.partial(F.interpolate, scale_factor=2)
1313 if self.arch["upsample"][index]
1314 and g_index == (self.G_depth - 1)
1315 else None
1316 ),
1317 )
1318 ]
1319 for g_index in range(self.G_depth)
1320 ]
1321
1322 # If attention on this block, attach it to the end
1323 if self.arch["attention"][self.arch["resolution"][index]]:
1324 print(
1325 f"Adding attention layer in G at resolution {self.arch['resolution'][index]:d}"
1326 )
1327
1328 if attn_type == "sa":
1329 self.blocks[-1] += [
1330 Attention(self.arch["out_channels"][index], self.which_conv)
1331 ]
1332 elif attn_type == "cbam":
1333 self.blocks[-1] += [
1334 CBAM_attention(
1335 self.arch["out_channels"][index], self.which_conv
1336 )
1337 ]
1338 elif attn_type == "ila":
1339 self.blocks[-1] += [ILA(self.arch["out_channels"][index])]
1340
1341 # Turn self.blocks into a ModuleList so that it's all properly registered. # noqa
1342 self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
1343
1344 # output layer: batchnorm-relu-conv.
1345 # Consider using a non-spectral conv here
1346
1347 self.output_layer = nn.Sequential(
1348 bn(
1349 self.arch["out_channels"][-1],
1350 cross_replica=self.cross_replica,
1351 mybn=self.mybn,
1352 ),
1353 self.activation,
1354 self.which_conv(self.arch["out_channels"][-1], 1),
1355 )
1356
1357 # Initialize weights. Optionally skip init for testing.
1358 if not skip_init:
1359 self.init_weights()
1360
1361 # Set up optimizer
1362 # If this is an EMA copy, no need for an optim, so just return now
1363 if no_optim:
1364 return
1365
1366 self.lr = G_lr
1367
1368 self.B1 = G_B1
1369
1370 self.B2 = G_B2
1371
1372 self.adam_eps = adam_eps
1373 if G_mixed_precision:
1374 print("Using fp16 adam in G...")
1375 import utils
1376
1377 self.optim = utils.Adam16(
1378 params=self.parameters(),
1379 lr=self.lr,
1380 betas=(self.B1, self.B2),
1381 weight_decay=0,
1382 eps=self.adam_eps,
1383 )
1384
1385
1386 self.optim = optim.Adam(
1387 params=self.parameters(),
1388 lr=self.lr,
1389 betas=(self.B1, self.B2),
1390 weight_decay=0,
1391 eps=self.adam_eps,
1392 )
1393 # LR scheduling
1394 if sched_version == "default":
1395
1396 self.lr_sched = None
1397 elif sched_version == "CosAnnealLR":
1398 self.lr_sched = optim.lr_scheduler.CosineAnnealingLR(
1399 self.optim,
1400 T_max=kwargs["num_epochs"],
1401 eta_min=self.lr / 4,
1402 last_epoch=-1,
1403 )
1404 elif sched_version == "CosAnnealWarmRes":
1405 self.lr_sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(
1406 self.optim, T_0=10, T_mult=2, eta_min=self.lr / 4
1407 )
1408 else:
1409 self.lr_sched = None
1410
1411
1412 def init_weights(self):
1413
1414 self.param_count = 0
1415 for module in self.modules():
1416 if (
1417 isinstance(module, nn.Conv2d)
1418 or isinstance(module, nn.Linear)
1419 or isinstance(module, nn.Embedding)
1420 ):
1421 if self.init == "ortho":
1422 init.orthogonal_(module.weight)
1423 elif self.init == "N02":
1424 init.normal_(module.weight, 0, 0.02)
1425 elif self.init in ["glorot", "xavier"]:
1426 init.xavier_uniform_(module.weight)
1427 else:
1428 print("Init style not recognized...")
1429 self.param_count += sum(
1430 [p.data.nelement() for p in module.parameters()]
1431 )
1432 print(f"Param count for G's initialized parameters: {self.param_count}")
1433
1434
1435 def forward(self, z, y):
1436 y = self.shared(y)
1437 # If relational embedding
1438 if self.RRM_prx_G:
1439 y = self.RR_G(y.unsqueeze(0)).squeeze(0)
1440 # y = F.normalize(y, dim=1)
1441 # If hierarchical, concatenate zs and ys
1442 if self.hier: # y and z are [bs,128] dimensional
1443 z = torch.cat([y, z], 1)
1444 y = z
1445 # First linear layer
1446 h = self.linear(z) # ([bs,256]-->[bs,24576])
1447 # Reshape
1448 h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width * self.H_base)
1449 # Loop over blocks
1450 for _, blocklist in enumerate(self.blocks):
1451 # Second inner loop in case block has multiple layers
1452 for block in blocklist:
1453 h = block(h, y)
1454
1455 # Apply batchnorm-relu-conv-tanh at output
1456 return torch.tanh(self.output_layer(h))
1457
1458
1459class Model(Generator):
1460 """
1461 Generator subclass
1462 default initializing with CONFIG dict
1463 """
1464
1465
1466 def __init__(self):
1467 super().__init__(**CONFIG)
1468
1469
1470def generate(model: nn.Module):
1471 """
1472 Run inference with the provided Generator model
1473
1474 Args:
1475 model (nn.Module): Generator model
1476
1477 Returns:
1478 torch.Tensor: batch of 40 PXD images
1479 """
1480 device = next(model.parameters()).device
1481 with torch.no_grad():
1482 latents = torch.randn(40, 128, device=device)
1483 labels = torch.tensor(list(range(40)), dtype=torch.long, device=device)
1484 imgs = model(latents, labels).detach().cpu()
1485 # Cut the noise below 7 ADU
1486 imgs = F.threshold(imgs, -0.26, -1)
1487 # center range [-1, 1] to [0, 1]
1488 imgs = imgs.mul_(0.5).add_(0.5)
1489 # renormalize and convert to uint8
1490 imgs = torch.pow(256, imgs).add_(-1).clamp_(0, 255).to(torch.uint8)
1491 # flatten channel dimension and crop 256 to 250
1492 imgs = imgs[:, 0, 3:-3, :]
1493 return imgs
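A minimal end-to-end sketch of how this module can be driven; the checkpoint filename and the plain state_dict layout are assumptions for illustration:

generator = Generator(**CONFIG)
generator.load_state_dict(torch.load("ieagan_generator.pth", map_location="cpu"))  # hypothetical file
generator.eval()
images = generate(generator)
print(images.shape, images.dtype)  # torch.Size([40, 250, 768]) torch.uint8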