Belle II Software development
ILA Class Reference

Public Member Functions

 __init__ (self, chan, chan_out=None, kernel_size=1, padding=0, stride=1, key_dim=32, value_dim=64, heads=8, norm_queries=True)
 Constructor.
 
 forward (self, x, context=None)
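 Apply multi-head linear attention to the feature map x; if context is given, its positions are appended to the keys and values.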
 

Public Attributes

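 chan = chan
 number of input channels
 
 key_dim = key_dim
 query/key dimension per head
 
 value_dim = value_dim
 value dimension per head
 
 heads = heads
 number of attention heads
 
 norm_queries = norm_queries
 whether forward applies a softmax to the queries over the key dimension
 
 to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
 convolution producing the queries of all heads
 
 to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
 convolution producing the keys of all heads
 
 to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)
 convolution producing the values of all heads
 
 to_out = nn.Conv2d(value_dim * heads, chan_out, kernel_size, **out_conv_kwargs)
 convolution mapping the concatenated head outputs to chan_out channels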
 

Detailed Description

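Image_Linear_Attention: multi-head linear attention over a 2D feature map. Queries, keys and values are produced by convolutions of the input. Keys are first combined with values into a key_dim × value_dim context per head, so the cost grows linearly with the number of pixels h * w instead of quadratically.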

Definition at line 635 of file ieagan.py.
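As a reading aid, the computation in forward() can be written per head as below, with the batch index dropped. The symbols Q, K, V, C, O, n = h * w, d = key_dim and e = value_dim are introduced here for illustration only and do not appear in ieagan.py. When norm_queries is False, \hat{Q} is just \tilde{Q}; when context is given, K and V simply gain extra columns.

```latex
% one head; Q, K \in \mathbb{R}^{d \times n}, V \in \mathbb{R}^{e \times n}
\tilde{Q} = d^{-1/4} Q, \qquad \tilde{K} = d^{-1/4} K

% keys: softmax over positions (dim=-1); queries: softmax over the key dimension (dim=-2)
\hat{K}_{im} = \frac{\exp \tilde{K}_{im}}{\sum_{m'} \exp \tilde{K}_{im'}}, \qquad
\hat{Q}_{ij} = \frac{\exp \tilde{Q}_{ij}}{\sum_{i'} \exp \tilde{Q}_{i'j}}

% "bhdn,bhen->bhde" and "bhdn,bhde->bhen"
C = \hat{K} V^{\top} \in \mathbb{R}^{d \times e}, \qquad
O = C^{\top} \hat{Q} = V \left( \hat{K}^{\top} \hat{Q} \right) \in \mathbb{R}^{e \times n}
```

Because C has only d × e entries, the ordering used in forward() costs O(d e n) per head, while forming the n × n matrix \hat{K}^{\top} \hat{Q} explicitly would cost O(d n^2).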

Constructor & Destructor Documentation

◆ __init__()

__init__ ( self,
chan,
chan_out = None,
kernel_size = 1,
padding = 0,
stride = 1,
key_dim = 32,
value_dim = 64,
heads = 8,
norm_queries = True )

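Constructor.

 chan: number of input channels.
 chan_out: number of output channels of to_out; defaults to chan.
 kernel_size: kernel size of to_q, to_k, to_v and to_out.
 padding: padding of to_q, to_k, to_v and to_out.
 stride: stride of to_q, to_k and to_v; to_out always uses stride 1.
 key_dim: query/key dimension per head.
 value_dim: value dimension per head.
 heads: number of attention heads.
 norm_queries: if True, forward applies a softmax to the queries over the key dimension.

forward() reshapes the projections back to the h * w pixels of its input, so kernel_size, padding and stride are expected to keep the spatial size (the defaults do).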

Definition at line 641 of file ieagan.py.

    def __init__(self, chan, chan_out=None, kernel_size=1, padding=0, stride=1,
                 key_dim=32, value_dim=64, heads=8, norm_queries=True):
        super().__init__()

        self.chan = chan
        chan_out = chan if chan_out is None else chan_out

        self.key_dim = key_dim
        self.value_dim = value_dim
        self.heads = heads
        self.norm_queries = norm_queries

        # the q/k/v projections share padding and stride
        conv_kwargs = {"padding": padding, "stride": stride}
        self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)

        # the output projection uses the same padding but the default stride of 1
        out_conv_kwargs = {"padding": padding}
        self.to_out = nn.Conv2d(
            value_dim * heads, chan_out, kernel_size, **out_conv_kwargs
        )
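The constructor only fixes channel counts; the spatial size of each projection follows from kernel_size, padding and stride. The sketch below is a hypothetical helper (not part of ieagan.py) that does this bookkeeping with the output-size formula documented for torch.nn.Conv2d, assuming dilation 1 and square kernels, and checks whether the q/k/v projections keep the h × w grid that forward() reshapes to.

```python
# Hypothetical shape bookkeeping for ILA's four convolutions (not part of ieagan.py).
# Uses the nn.Conv2d output-size formula with dilation=1:
#   out = floor((size + 2 * padding - kernel_size) / stride) + 1


def conv_out_size(size, kernel_size, padding, stride):
    """Spatial output size of a square-kernel Conv2d with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def ila_projection_shapes(h, w, chan, chan_out=None, kernel_size=1, padding=0,
                          stride=1, key_dim=32, value_dim=64, heads=8):
    """(channels, height, width) produced by to_q, to_k, to_v and to_out."""
    chan_out = chan if chan_out is None else chan_out
    qh = conv_out_size(h, kernel_size, padding, stride)
    qw = conv_out_size(w, kernel_size, padding, stride)
    # to_out runs with stride 1 on a tensor that forward() reshaped back to (h, w)
    oh = conv_out_size(h, kernel_size, padding, 1)
    ow = conv_out_size(w, kernel_size, padding, 1)
    return {
        "to_q": (key_dim * heads, qh, qw),
        "to_k": (key_dim * heads, qh, qw),
        "to_v": (value_dim * heads, qh, qw),
        "to_out": (chan_out, oh, ow),
    }


def preserves_spatial_size(h, w, **kwargs):
    """forward() flattens q, k, v to h * w positions, which assumes this is True."""
    _, qh, qw = ila_projection_shapes(h, w, **kwargs)["to_q"]
    return (qh, qw) == (h, w)


print(ila_projection_shapes(16, 16, chan=64))
print(preserves_spatial_size(16, 16, chan=64, stride=2))  # False: 16x16 -> 8x8
```

With the defaults every projection keeps 16 × 16, and kernel_size=3 with padding=1 does as well; stride=2 shrinks the grid, which the reshape to h * w in forward() does not account for.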

Member Function Documentation

◆ forward()

forward ( self,
x,
context = None )

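Apply multi-head linear attention to the feature map x of shape (b, chan, h, w). If context is given, its positions are appended to the keys and values. With spatial-size-preserving kernel_size, padding and stride, the result has shape (b, chan_out, h, w).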

Definition at line 683 of file ieagan.py.

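    def forward(self, x, context=None):
        # x: (b, chan, h, w)
        b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads

        q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))

        # split channels into heads and flatten the pixels:
        # q, k -> (b, heads, key_dim, h*w), v -> (b, heads, value_dim, h*w)
        q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))

        # scale q and k by key_dim**-0.25 each, i.e. 1/sqrt(key_dim) overall
        q, k = map(lambda t: t * (self.key_dim**-0.25), (q, k))

        if context is not None:
            # context must have c channels; its positions extend the keys and values
            context = context.reshape(b, c, 1, -1)
            ck, cv = self.to_k(context), self.to_v(context)
            # keys have key_dim rows per head, values have value_dim rows per head
            ck = ck.reshape(b, heads, k_dim, -1)
            cv = cv.reshape(b, heads, self.value_dim, -1)
            k = torch.cat((k, ck), dim=3)
            v = torch.cat((v, cv), dim=3)

        # keys: softmax over positions
        k = k.softmax(dim=-1)

        if self.norm_queries:
            # queries: softmax over the key dimension
            q = q.softmax(dim=-2)

        # per-head context of shape (b, heads, key_dim, value_dim)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhdn,bhde->bhen", q, context)
        # merge heads back into channels: (b, heads * value_dim, h, w)
        out = out.reshape(b, -1, h, w)
        out = self.to_out(out)
        return out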
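The two einsum calls rely on associativity: contracting keys with values first gives the same result as building the full pixel-by-pixel attention matrix, only cheaper. The pure-Python sketch below checks this for a single head with the batch dimension dropped, norm_queries=True and no context. All function names are hypothetical and nested lists stand in for tensors; it illustrates the arithmetic, not the module's API.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax of a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def normalize(q, k):
    """Scale q, k (d x n) by d**-0.25; softmax k over positions, q over the key dimension."""
    d, n = len(q), len(q[0])
    scale = d ** -0.25
    k_hat = [softmax([x * scale for x in row]) for row in k]          # like k.softmax(dim=-1)
    q_cols = [softmax([q[i][j] * scale for i in range(d)]) for j in range(n)]
    q_hat = [[q_cols[j][i] for j in range(n)] for i in range(d)]     # like q.softmax(dim=-2)
    return q_hat, k_hat


def linear_attention(q, k, v):
    """Order used by forward(): context = K V^T (d x e), then out = context^T Q (e x n)."""
    q_hat, k_hat = normalize(q, k)
    d, n, e = len(q_hat), len(q_hat[0]), len(v)
    context = [[sum(k_hat[i][m] * v[c][m] for m in range(n)) for c in range(e)]
               for i in range(d)]                                     # "bhdn,bhen->bhde"
    return [[sum(q_hat[i][j] * context[i][c] for i in range(d)) for j in range(n)]
            for c in range(e)]                                        # "bhdn,bhde->bhen"


def quadratic_attention(q, k, v):
    """Same product, but forming the n x n matrix A = K^T Q first."""
    q_hat, k_hat = normalize(q, k)
    d, n, e = len(q_hat), len(q_hat[0]), len(v)
    a = [[sum(k_hat[i][m] * q_hat[i][j] for i in range(d)) for j in range(n)]
         for m in range(n)]
    return [[sum(v[c][m] * a[m][j] for m in range(n)) for j in range(n)]
            for c in range(e)]


if __name__ == "__main__":
    random.seed(0)
    d, e, n = 4, 3, 6  # key_dim, value_dim, number of pixels h * w
    q = [[random.gauss(0, 1) for _ in range(n)] for _ in range(d)]
    k = [[random.gauss(0, 1) for _ in range(n)] for _ in range(d)]
    v = [[random.gauss(0, 1) for _ in range(n)] for _ in range(e)]
    lin = linear_attention(q, k, v)
    quad = quadratic_attention(q, k, v)
    diff = max(abs(x - y) for rl, rq in zip(lin, quad) for x, y in zip(rl, rq))
    print(f"max difference: {diff:.2e}")
```

The differences are at floating-point rounding level. The linear ordering never stores anything larger than key_dim × value_dim per head, which is why the module scales to large feature maps.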

Member Data Documentation

◆ chan

chan = chan

number of input channels

Definition at line 655 of file ieagan.py.

◆ heads

heads = heads

number of attention heads

Definition at line 663 of file ieagan.py.

◆ key_dim

key_dim = key_dim

query/key dimension per head

Definition at line 659 of file ieagan.py.

◆ norm_queries

norm_queries = norm_queries

whether forward applies a softmax to the queries over the key dimension

Definition at line 666 of file ieagan.py.

◆ to_k

to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)

convolution producing the keys of all heads (key_dim * heads channels)

Definition at line 672 of file ieagan.py.

◆ to_out

to_out = nn.Conv2d(value_dim * heads, chan_out, kernel_size, **out_conv_kwargs)

convolution mapping the concatenated head outputs (value_dim * heads channels) to chan_out channels

Definition at line 678 of file ieagan.py.

◆ to_q

to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)

convolution producing the queries of all heads (key_dim * heads channels)

Definition at line 670 of file ieagan.py.

◆ to_v

to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)

convolution producing the values of all heads (value_dim * heads channels)

Definition at line 674 of file ieagan.py.
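The four convolutions above hold all of ILA's trainable parameters. The sketch below counts them, assuming the nn.Conv2d defaults that the listing does not override (groups=1, bias=True); the helper names are hypothetical and not part of ieagan.py.

```python
def conv2d_params(in_ch, out_ch, kernel_size, bias=True):
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * kernel_size * kernel_size + (out_ch if bias else 0)


def ila_param_counts(chan, chan_out=None, kernel_size=1, key_dim=32, value_dim=64, heads=8):
    """Trainable parameters of to_q, to_k, to_v and to_out (groups=1, bias=True)."""
    chan_out = chan if chan_out is None else chan_out
    counts = {
        "to_q": conv2d_params(chan, key_dim * heads, kernel_size),
        "to_k": conv2d_params(chan, key_dim * heads, kernel_size),
        "to_v": conv2d_params(chan, value_dim * heads, kernel_size),
        "to_out": conv2d_params(value_dim * heads, chan_out, kernel_size),
    }
    counts["total"] = sum(counts.values())
    return counts


print(ila_param_counts(64))
# {'to_q': 16640, 'to_k': 16640, 'to_v': 33280, 'to_out': 32832, 'total': 99392}
```

For a constructed module the same total can be cross-checked in PyTorch with sum(p.numel() for p in module.parameters()).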

◆ value_dim

value_dim = value_dim

value dimension per head

Definition at line 661 of file ieagan.py.
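key_dim and value_dim are independent, and the defaults (32 and 64) differ, so any reshape of value tensors has to use value_dim. This matters in the context branch of forward(): torch.cat((v, cv), dim=3) requires cv to have value_dim rows per head, and reshaping it with key_dim instead fails whenever the two differ. The sketch below models only the shapes involved, for one sample; the function names are hypothetical.

```python
def reshape_rows(numel, heads, rows):
    """Shape that t.reshape(b, heads, rows, -1) gives one sample with `numel` elements.
    Raises, as torch.reshape would, if the elements do not divide evenly."""
    if numel % (heads * rows) != 0:
        raise ValueError(f"cannot reshape {numel} elements into ({heads}, {rows}, -1)")
    return (heads, rows, numel // (heads * rows))


def context_values_concatenate(value_dim, heads, n_pixels, n_context, cv_rows):
    """True if torch.cat((v, cv), dim=3) would accept the shapes."""
    v = (heads, value_dim, n_pixels)  # v after the main reshape
    # to_v(context) has value_dim * heads channels over n_context positions
    cv = reshape_rows(value_dim * heads * n_context, heads, cv_rows)
    # concatenating along the position axis needs heads and rows to agree
    return v[:2] == cv[:2]


# defaults key_dim=32, value_dim=64, heads=8; 16x16 image, 10 context positions
print(context_values_concatenate(64, 8, 256, 10, cv_rows=32))  # reshaped with key_dim
print(context_values_concatenate(64, 8, 256, 10, cv_rows=64))  # reshaped with value_dim
```

With cv_rows=32 the reshape itself succeeds (64 * 8 * 10 elements split into (8, 32, 20)), so the mismatch only surfaces at torch.cat; with cv_rows=64 the shapes line up.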


The documentation for this class was generated from the following file: