Belle II Software development
MultiheadAttention Class Reference
Inheritance diagram for MultiheadAttention:

Public Member Functions

 __init__ (self, input_dim, embed_dim, num_heads, which_linear)
 Constructor.
 
 forward (self, x, return_attention=False)
 Apply multi-head self-attention to the input; optionally also return the attention weights.
 

Public Attributes

 embed_dim = embed_dim
 embedding dimension
 
 num_heads = num_heads
 number of heads
 
 head_dim = embed_dim // num_heads
 head dimension
 
 which_linear = which_linear
 constructor used to build the internal linear layers
 
 qkv_proj = self.which_linear(input_dim, 3 * embed_dim)
 joint query/key/value projection layer
 
 o_proj = self.which_linear(embed_dim, embed_dim)
 output projection layer
 

Protected Member Functions

 _reset_parameters (self)
 Re-initialize the projection weights and biases.
 

Detailed Description

Multi-head self-attention module. The input is projected into per-head queries, keys, and values by a single stacked linear layer (qkv_proj), scaled dot-product attention is applied per head, and the concatenated head outputs are recombined by an output projection (o_proj).

Definition at line 876 of file ieagan.py.
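
The class is used like any other module in the file. A minimal usage sketch, assuming that MultiheadAttention derives from torch.nn.Module (as the super().__init__() call and the use of nn.init suggest) and that a plain torch.nn.Linear is passed for which_linear; all dimensions are illustrative:

import torch
import torch.nn as nn

# Illustrative sizes; nn.Linear as which_linear is an assumption.
mha = MultiheadAttention(input_dim=64, embed_dim=64, num_heads=4,
                         which_linear=nn.Linear)

x = torch.randn(8, 10, 64)                # [Batch, SeqLen, InputDim]
out, attn = mha(x, return_attention=True)
print(out.shape)                          # torch.Size([8, 10, 64])
print(attn.shape)                         # torch.Size([8, 4, 10, 10])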

Constructor & Destructor Documentation

◆ __init__()

__init__(self, input_dim, embed_dim, num_heads, which_linear)

Constructor.

Definition at line 880 of file ieagan.py.

880 def __init__(self, input_dim, embed_dim, num_heads, which_linear):
881 super().__init__()
882 assert (
883 embed_dim % num_heads == 0
884 ), "Embedding dimension must be 0 modulo number of heads."
885
886
887 self.embed_dim = embed_dim
888
889 self.num_heads = num_heads
890
891 self.head_dim = embed_dim // num_heads
892
893 self.which_linear = which_linear
894
895 # Stack all weight matrices 1...h together for efficiency
896
897 self.qkv_proj = self.which_linear(input_dim, 3 * embed_dim)
898
899 self.o_proj = self.which_linear(embed_dim, embed_dim)
900
901 self._reset_parameters()
902
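
The comment in the listing refers to the usual optimization of fusing the per-head query, key, and value weight matrices into one linear layer, so a single matrix multiplication produces all three. A small sketch of the idea, ignoring the per-head reshape that forward() performs afterwards (nn.Linear stands in for which_linear here):

import torch
import torch.nn as nn

input_dim, embed_dim = 16, 16
qkv_proj = nn.Linear(input_dim, 3 * embed_dim)  # one matmul instead of three
x = torch.randn(2, 5, input_dim)
q, k, v = qkv_proj(x).split(embed_dim, dim=-1)  # each [2, 5, embed_dim]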

Member Function Documentation

◆ _reset_parameters()

_reset_parameters(self)
protected

Re-initialize qkv_proj and o_proj with Xavier-uniform weights and zero biases (the original Transformer initialization).

Definition at line 904 of file ieagan.py.

904 def _reset_parameters(self):
905 # Original Transformer initialization, see PyTorch documentation
906 nn.init.xavier_uniform_(self.qkv_proj.weight)
907 self.qkv_proj.bias.data.fill_(0)
908 nn.init.xavier_uniform_(self.o_proj.weight)
909 self.o_proj.bias.data.fill_(0)
910
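
nn.init.xavier_uniform_ draws each weight from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)) (Glorot & Bengio, 2010). A sketch that mirrors what the PyTorch call computes for the default gain of 1 (illustrative only, not a replacement for nn.init):

import math
import torch

def xavier_uniform_sketch(weight):
    fan_out, fan_in = weight.shape            # nn.Linear stores [out, in]
    a = math.sqrt(6.0 / (fan_in + fan_out))   # bound of U(-a, a)
    with torch.no_grad():
        return weight.uniform_(-a, a)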

◆ forward()

forward(self, x, return_attention=False)

Apply multi-head self-attention to x and return the projected output; if return_attention is True, also return the attention weights.

Definition at line 912 of file ieagan.py.

912 def forward(self, x, return_attention=False):
913 batch_size, seq_length, embed_dim = x.size()
914 qkv = self.qkv_proj(x)
915
916 # Separate Q, K, V from linear output
917 qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
918 qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
919 q, k, v = qkv.chunk(3, dim=-1)
920
921 # Determine value outputs
922 values, attention = scaled_dot_product(q, k, v)
923 values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
924 values = values.reshape(batch_size, seq_length, embed_dim)
925 o = self.o_proj(values)
926
927 if return_attention:
928 return o, attention
929 else:
930 return o
931
932
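
forward() delegates the attention computation to scaled_dot_product, which is defined elsewhere in ieagan.py and is not shown on this page. A minimal sketch of the standard scaled dot-product attention it presumably implements, matching the [Batch, Head, SeqLen, Dims] layout and the (values, attention) return used above:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product(q, k, v):
    d_k = q.size(-1)
    # [Batch, Head, SeqLen, SeqLen] logits, scaled by sqrt(d_k)
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)       # [Batch, Head, SeqLen, Dims]
    return values, attention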

Member Data Documentation

◆ embed_dim

embed_dim = embed_dim

embedding dimension

Definition at line 887 of file ieagan.py.

◆ head_dim

head_dim = embed_dim // num_heads

head dimension

Definition at line 891 of file ieagan.py.

◆ num_heads

num_heads = num_heads

number of heads

Definition at line 889 of file ieagan.py.

◆ o_proj

o_proj = self.which_linear(embed_dim, embed_dim)

output projection layer, mapping the concatenated head outputs back to embed_dim

Definition at line 899 of file ieagan.py.

◆ qkv_proj

qkv_proj = self.which_linear(input_dim, 3 * embed_dim)

stacked linear layer producing queries, keys, and values in a single pass

Definition at line 897 of file ieagan.py.

◆ which_linear

which_linear = which_linear

constructor used to build the qkv_proj and o_proj layers

Definition at line 893 of file ieagan.py.


The documentation for this class was generated from the following file:

ieagan.py