Belle II Software development
MultiheadAttention Class Reference
Inheritance diagram for MultiheadAttention:

Public Member Functions

def __init__ (self, input_dim, embed_dim, num_heads, which_linear)
 Constructor.
 
def forward (self, x, return_attention=False)
 Forward pass of the multi-head self-attention.
 

Public Attributes

 embed_dim
 embedding dimension
 
 num_heads
 number of heads
 
 head_dim
 per-head dimension (embed_dim // num_heads)
 
 which_linear
 callable used to construct the linear projection layers
 
 qkv_proj
 fused linear projection producing queries, keys, and values
 
 o_proj
 output linear projection applied after the heads are recombined
 

Protected Member Functions

def _reset_parameters (self)
 reset the projection weights (Xavier uniform) and zero the biases
 

Detailed Description

Multi-head self-attention module. The input sequence is projected to queries, keys, and values for several attention heads, scaled dot-product attention is computed per head, and the concatenated head outputs are mapped back to the embedding dimension by a final output projection.

Definition at line 848 of file ieagan.py.
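
The class is used like any other torch.nn.Module. A minimal usage sketch with hypothetical sizes, assuming the class is importable from ieagan.py and that which_linear is simply torch.nn.Linear:

import torch
import torch.nn as nn

# Hypothetical sizes; which_linear=nn.Linear is an assumption for illustration.
mha = MultiheadAttention(input_dim=64, embed_dim=64, num_heads=4, which_linear=nn.Linear)

x = torch.randn(8, 16, 64)                    # [Batch, SeqLen, InputDim]
out = mha(x)                                  # [Batch, SeqLen, EmbedDim]
out, attn = mha(x, return_attention=True)     # attn: [Batch, Head, SeqLen, SeqLen]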

Constructor & Destructor Documentation

◆ __init__()

def __init__ (self, input_dim, embed_dim, num_heads, which_linear)

Constructor.

Definition at line 852 of file ieagan.py.

852 def __init__(self, input_dim, embed_dim, num_heads, which_linear):
853     super().__init__()
854     assert (
855         embed_dim % num_heads == 0
856     ), "Embedding dimension must be 0 modulo number of heads."

859     self.embed_dim = embed_dim

861     self.num_heads = num_heads

863     self.head_dim = embed_dim // num_heads

865     self.which_linear = which_linear

867     # Stack all weight matrices 1...h together for efficiency

869     self.qkv_proj = self.which_linear(input_dim, 3 * embed_dim)

871     self.o_proj = self.which_linear(embed_dim, embed_dim)

873     self._reset_parameters()
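
The single qkv_proj layer stacks the query, key, and value projections of all heads into one matrix multiply; splitting its output afterwards is equivalent to applying three separate linear layers. A small sketch of that equivalence with hypothetical sizes (not code from ieagan.py):

import torch
import torch.nn as nn

input_dim, embed_dim = 64, 64
fused = nn.Linear(input_dim, 3 * embed_dim)   # plays the role of qkv_proj

x = torch.randn(2, 10, input_dim)
q, k, v = fused(x).chunk(3, dim=-1)           # one matmul, then split into Q, K, V

# The query part equals a separate projection built from the first embed_dim rows.
q_alt = x @ fused.weight[:embed_dim].T + fused.bias[:embed_dim]
assert torch.allclose(q, q_alt, atol=1e-6)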

Member Function Documentation

◆ _reset_parameters()

def _reset_parameters (self)
protected

Reset the projection weights with Xavier-uniform initialization and set the biases to zero.

Definition at line 876 of file ieagan.py.

876 def _reset_parameters(self):
877     # Original Transformer initialization, see PyTorch documentation
878     nn.init.xavier_uniform_(self.qkv_proj.weight)
879     self.qkv_proj.bias.data.fill_(0)
880     nn.init.xavier_uniform_(self.o_proj.weight)
881     self.o_proj.bias.data.fill_(0)
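
nn.init.xavier_uniform_ draws each weight from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)); together with zeroed biases this reproduces the initialization referenced by the source comment. A short, self-contained check of that bound on a layer with hypothetical qkv_proj-like dimensions:

import math
import torch.nn as nn

layer = nn.Linear(64, 192)                 # same shape as a hypothetical qkv_proj
nn.init.xavier_uniform_(layer.weight)
layer.bias.data.fill_(0)

bound = math.sqrt(6.0 / (64 + 192))        # a = sqrt(6 / (fan_in + fan_out))
assert layer.weight.abs().max() <= bound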

◆ forward()

def forward (self, x, return_attention=False)

Forward pass: projects the input to queries, keys, and values, applies scaled dot-product attention per head, concatenates the heads, and applies the output projection. If return_attention is True, the attention weights are returned together with the output.

Definition at line 884 of file ieagan.py.

884 def forward(self, x, return_attention=False):
885     batch_size, seq_length, embed_dim = x.size()
886     qkv = self.qkv_proj(x)

888     # Separate Q, K, V from linear output
889     qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
890     qkv = qkv.permute(0, 2, 1, 3)  # [Batch, Head, SeqLen, Dims]
891     q, k, v = qkv.chunk(3, dim=-1)

893     # Determine value outputs
894     values, attention = scaled_dot_product(q, k, v)
895     values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
896     values = values.reshape(batch_size, seq_length, embed_dim)
897     o = self.o_proj(values)

899     if return_attention:
900         return o, attention
901     else:
902         return o
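
forward delegates the attention computation to a module-level scaled_dot_product helper that is defined elsewhere in ieagan.py and not reproduced on this page. A minimal sketch of the standard computation such a helper is assumed to perform (softmax of Q·K^T / sqrt(d_k), then a weighted sum over V):

import math
import torch
import torch.nn.functional as F

def scaled_dot_product(q, k, v):
    # Sketch of standard scaled dot-product attention, not the ieagan.py source.
    d_k = q.size(-1)
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    attention = F.softmax(attn_logits, dim=-1)   # [Batch, Head, SeqLen, SeqLen]
    values = torch.matmul(attention, v)          # [Batch, Head, SeqLen, HeadDim]
    return values, attention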

Member Data Documentation

◆ embed_dim

embed_dim

embedding dimension

Definition at line 859 of file ieagan.py.

◆ head_dim

head_dim

Dimension of a single attention head, embed_dim // num_heads.

Definition at line 863 of file ieagan.py.

◆ num_heads

num_heads

number of heads

Definition at line 861 of file ieagan.py.

◆ o_proj

o_proj

Output linear projection applied after the attention heads are concatenated.

Definition at line 871 of file ieagan.py.

◆ qkv_proj

qkv_proj

Fused linear projection producing queries, keys, and values for all heads in a single matrix multiplication.

Definition at line 869 of file ieagan.py.

◆ which_linear

which_linear

Callable used to construct the linear projection layers (passed in by the caller).

Definition at line 865 of file ieagan.py.


The documentation for this class was generated from the following file: ieagan.py