Belle II Software development
ILA Class Reference

Public Member Functions

def __init__ (self, chan, chan_out=None, kernel_size=1, padding=0, stride=1, key_dim=32, value_dim=64, heads=8, norm_queries=True)
 Constructor.
 
def forward (self, x, context=None)
 Apply multi-head linear attention to the feature map x, optionally also attending to context.
 

Public Attributes

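 chan
 number of input channels
 
 key_dim
 key (and query) dimension per head
 
 value_dim
 value dimension per head
 
 heads
 number of attention heads
 
 norm_queries
 whether the queries are softmax-normalised over the key dimension
 
 to_q
 convolution projecting the input to queries
 
 to_k
 convolution projecting the input to keys
 
 to_v
 convolution projecting the input to values
 
 to_out
 convolution projecting the attended values to the output channels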
 

Detailed Description

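Image_Linear_Attention: multi-head linear attention over the spatial positions of a 2D feature map of shape (b, chan, h, w). Keys are softmax-normalised over positions and, optionally, queries over the key dimension, so each head builds a key_dim × value_dim summary instead of an (h·w) × (h·w) attention matrix and the cost grows linearly with h·w.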

Definition at line 607 of file ieagan.py.
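The core of the computation in forward (the softmax and einsum steps at lines 671 to 677 of the listing further down) can be reproduced outside PyTorch. The sketch below is a minimal NumPy re-implementation under stated assumptions: the projections are skipped and q, k, v are given directly, and softmax and linear_attention are illustrative helper names, not part of ieagan.py.

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along one axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def linear_attention(q, k, v, norm_queries=True):
    """q, k: (b, heads, key_dim, n); v: (b, heads, value_dim, n)."""
    key_dim = q.shape[2]
    # Same scaling as ILA.forward: key_dim**-0.25 on both q and k
    q = q * key_dim ** -0.25
    k = k * key_dim ** -0.25
    k = softmax(k, axis=-1)          # keys normalised over positions n
    if norm_queries:
        q = softmax(q, axis=-2)      # queries normalised over key_dim
    # Per-head (key_dim x value_dim) summary; cost is linear in n
    context = np.einsum("bhdn,bhen->bhde", k, v)
    return np.einsum("bhdn,bhde->bhen", q, context)


rng = np.random.default_rng(0)
b, heads, d, e, n = 2, 4, 8, 16, 50
q = rng.standard_normal((b, heads, d, n))
k = rng.standard_normal((b, heads, d, n))
v = rng.standard_normal((b, heads, e, n))
out = linear_attention(q, k, v)
print(out.shape)  # (2, 4, 16, 50)
```

Because matrix products are associative, contracting k with v first gives the same result as first forming the explicit n × n matrix of query/key products and applying it to v; the linear variant simply never materialises that matrix.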

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  chan,
  chan_out = None,
  kernel_size = 1,
  padding = 0,
  stride = 1,
  key_dim = 32,
  value_dim = 64,
  heads = 8,
  norm_queries = True 
)

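Constructor. Builds the query, key and value projections (nn.Conv2d layers with key_dim * heads, key_dim * heads and value_dim * heads output channels) and the output projection to chan_out channels (chan_out defaults to chan).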

Definition at line 613 of file ieagan.py.

613 def __init__(
614     self,
615     chan,
616     chan_out=None,
617     kernel_size=1,
618     padding=0,
619     stride=1,
620     key_dim=32,
621     value_dim=64,
622     heads=8,
623     norm_queries=True,
624 ):
625     super().__init__()
626
627     self.chan = chan
628     chan_out = chan if chan_out is None else chan_out
629
630
631     self.key_dim = key_dim
632
633     self.value_dim = value_dim
634
635     self.heads = heads
636
637
638     self.norm_queries = norm_queries
639
640     conv_kwargs = {"padding": padding, "stride": stride}
641
642     self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
643
644     self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
645
646     self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)
647
648     out_conv_kwargs = {"padding": padding}
649
650     self.to_out = nn.Conv2d(
651         value_dim * heads, chan_out, kernel_size, **out_conv_kwargs
652     )
653
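The constructor fixes the channel counts of the four convolutions but leaves their spatial output size to kernel_size, padding and stride. The plain-Python sketch below applies the output-size formula from the PyTorch nn.Conv2d documentation (assuming dilation 1) to list the shapes each projection produces; conv2d_out_size and ila_shapes are hypothetical helper names, not part of ieagan.py.

```python
def conv2d_out_size(size, kernel_size, padding=0, stride=1, dilation=1):
    # Output-size formula given in the torch.nn.Conv2d documentation
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def ila_shapes(b, h, w, chan_out, kernel_size=1, padding=0, stride=1,
               key_dim=32, value_dim=64, heads=8):
    """Shapes produced by the four ILA convolutions for a (b, chan, h, w) input."""
    hp = conv2d_out_size(h, kernel_size, padding, stride)
    wp = conv2d_out_size(w, kernel_size, padding, stride)
    return {
        "to_q": (b, key_dim * heads, hp, wp),
        "to_k": (b, key_dim * heads, hp, wp),
        "to_v": (b, value_dim * heads, hp, wp),
        # to_out reuses kernel_size and padding but not stride (lines 648 to 652)
        "to_out": (b, chan_out,
                   conv2d_out_size(hp, kernel_size, padding),
                   conv2d_out_size(wp, kernel_size, padding)),
    }


print(ila_shapes(2, 16, 16, chan_out=64))
print(ila_shapes(2, 16, 16, chan_out=64, kernel_size=3, padding=1)["to_q"])  # (2, 256, 16, 16)
print(ila_shapes(2, 16, 16, chan_out=64, kernel_size=3)["to_q"])             # (2, 256, 14, 14)
```

Since forward reshapes the projections with the h * w of its input (line 660), the configuration only works as intended when the projections keep the spatial size, that is stride=1 and, for an odd kernel_size, padding=(kernel_size - 1) // 2.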

Member Function Documentation

◆ forward()

def forward (   self,
  x,
  context = None 
)

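Applies multi-head linear attention to x, a feature map of shape (b, chan, h, w). If context is given (same batch size and chan channels; its spatial positions are flattened), its keys and values are appended to those computed from x, so the queries from x also attend to the context positions. Returns a tensor with chan_out channels; its spatial size equals (h, w) only when the convolutions preserve it (stride=1 and padding=(kernel_size - 1) // 2).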

Definition at line 655 of file ieagan.py.

655 def forward(self, x, context=None):
656     b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads
657
658     q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))
659
660     q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))
661
662     q, k = map(lambda x: x * (self.key_dim**-0.25), (q, k))
663
664     if context is not None:
665         context = context.reshape(b, c, 1, -1)
666         ck, cv = self.to_k(context), self.to_v(context)
667         ck, cv = ck.reshape(b, heads, k_dim, -1), cv.reshape(b, heads, self.value_dim, -1)
668         k = torch.cat((k, ck), dim=3)
669         v = torch.cat((v, cv), dim=3)
670
671     k = k.softmax(dim=-1)
672
673     if self.norm_queries:
674         q = q.softmax(dim=-2)
675
676     context = torch.einsum("bhdn,bhen->bhde", k, v)
677     out = torch.einsum("bhdn,bhde->bhen", q, context)
678     out = out.reshape(b, -1, h, w)
679     out = self.to_out(out)
680     return out
681
682
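In the context branch of forward, the keys and values computed from context are appended to those from x along the position axis, so the queries also attend to the context positions. Keys and values have different per-head sizes (key_dim and value_dim, 32 and 64 by default), so each needs its own reshape before the concatenation. The NumPy sketch below illustrates this step under the assumption kernel_size=1, where each convolution reduces to a channel-mixing matrix product (bias omitted for brevity); conv1x1 and keys_values_with_context are illustrative names only.

```python
import numpy as np


def conv1x1(x, weight):
    # 1x1 convolution without bias: mixes channels at each position.
    # x: (b, c, H, W), weight: (c_out, c)
    return np.einsum("oc,bchw->bohw", weight, x)


def keys_values_with_context(x, context, wk, wv, heads, key_dim, value_dim):
    b, c, h, w = x.shape
    k = conv1x1(x, wk).reshape(b, heads, key_dim, h * w)
    v = conv1x1(x, wv).reshape(b, heads, value_dim, h * w)
    # Flatten all context positions into a single row, as in forward
    ctx = context.reshape(b, c, 1, -1)
    ck = conv1x1(ctx, wk).reshape(b, heads, key_dim, -1)
    cv = conv1x1(ctx, wv).reshape(b, heads, value_dim, -1)  # value_dim, not key_dim
    # Context positions are appended along the position axis
    k = np.concatenate((k, ck), axis=3)
    v = np.concatenate((v, cv), axis=3)
    return k, v


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))        # 16 positions
context = rng.standard_normal((2, 3, 5, 5))  # 25 context positions
wk = rng.standard_normal((2 * 4, 3))         # heads=2, key_dim=4
wv = rng.standard_normal((2 * 8, 3))         # heads=2, value_dim=8
k, v = keys_values_with_context(x, context, wk, wv, heads=2, key_dim=4, value_dim=8)
print(k.shape, v.shape)  # (2, 2, 4, 41) (2, 2, 8, 41)
```

Reshaping the context values with key_dim instead would give them 4 rows per head while the values from x have 8, and the concatenation would fail.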

Member Data Documentation

◆ chan

chan

number of input channels

Definition at line 627 of file ieagan.py.

◆ heads

heads

number of attention heads

Definition at line 635 of file ieagan.py.

◆ key_dim

key_dim

key (and query) dimension per head

Definition at line 631 of file ieagan.py.
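In forward (line 662), both queries and keys are multiplied by key_dim**-0.25. Because the factor is applied once on each side, any product of a query entry and a key entry carries key_dim**-0.5, the 1/sqrt(key_dim) factor of scaled dot-product attention; reading the split as mirroring that convention is an interpretation, not something the source states. Since q and k are then softmax-normalised separately, the factor also acts as a temperature on each softmax. A quick check of the arithmetic for the default key_dim:

```python
import math

key_dim = 32                 # constructor default
scale = key_dim ** -0.25     # factor applied to q and, separately, to k
combined = scale * scale     # factor carried by a product q[d, n] * k[d, m]
print(combined, 1 / math.sqrt(key_dim))
```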

◆ norm_queries

norm_queries

if True, forward softmax-normalises the queries over the key dimension (dim=-2)

Definition at line 638 of file ieagan.py.
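With norm_queries=True, each query column becomes a probability distribution over the key dimension, while each key row is already one over positions (line 671). Together this makes every output position a convex combination of the value vectors. The NumPy sketch below checks this with constant values; softmax is a local helper, and the key_dim**-0.25 scaling is omitted because it does not change either normalisation.

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along one axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(1)
q = rng.standard_normal((1, 2, 4, 10))   # (b, heads, key_dim, n)
k = rng.standard_normal((1, 2, 4, 10))
v = np.ones((1, 2, 3, 10))               # constant values, value_dim = 3

k = softmax(k, axis=-1)                  # keys over positions
q_norm = softmax(q, axis=-2)             # queries over key_dim (norm_queries=True)
print(np.allclose(q_norm.sum(axis=-2), 1.0))  # True

context = np.einsum("bhdn,bhen->bhde", k, v)
out = np.einsum("bhdn,bhde->bhen", q_norm, context)
# Convex weights applied to constant values reproduce the constant
print(np.allclose(out, 1.0))             # True
```

Without the query softmax, each output is instead scaled by the raw column sum of q, so the result is no longer a weighted average of the values.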

◆ to_k

to_k

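convolution projecting the input (and, if given, the context) to keys, with key_dim * heads output channels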

Definition at line 644 of file ieagan.py.

◆ to_out

to_out

convolution mapping the value_dim * heads attended channels to chan_out output channels

Definition at line 650 of file ieagan.py.
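Before to_out is applied (line 678), the attention output of shape (b, heads, value_dim, h*w) is reshaped to (b, heads * value_dim, h, w). The reshape follows row-major element order, so value channel e of head hd arrives at input channel hd * value_dim + e of to_out; the heads are concatenated one after another. A NumPy check with small, arbitrary sizes (NumPy's reshape uses the same logical element order as torch.Tensor.reshape):

```python
import numpy as np

b, heads, value_dim, h, w = 1, 3, 2, 2, 2
# Stand-in for the second einsum's result: (b, heads, value_dim, h * w)
out = np.arange(b * heads * value_dim * h * w).reshape(b, heads, value_dim, h * w)
merged = out.reshape(b, -1, h, w)   # same reshape as in forward
print(merged.shape)  # (1, 6, 2, 2)

hd, e = 1, 1
# Channel hd * value_dim + e of the merged map is value channel e of head hd
print(np.array_equal(merged[0, hd * value_dim + e], out[0, hd, e].reshape(h, w)))  # True
```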

◆ to_q

to_q

convolution projecting the input to queries, with key_dim * heads output channels

Definition at line 642 of file ieagan.py.

◆ to_v

to_v

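convolution projecting the input (and, if given, the context) to values, with value_dim * heads output channels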

Definition at line 646 of file ieagan.py.

◆ value_dim

value_dim

value dimension per head

Definition at line 633 of file ieagan.py.


The documentation for this class was generated from the following file: