def forward(self, x, return_attention=False):
    """Apply multi-head self-attention to a batch of sequences.

    Args:
        x: input tensor of shape (batch, seq_length, embed_dim) — shape
            inferred from the unpacking of ``x.size()`` below.
        return_attention: if True, also return the attention weights
            produced by ``scaled_dot_product``.

    Returns:
        The projected attention output of shape (batch, seq_length,
        embed_dim), or a ``(output, attention)`` tuple when
        ``return_attention`` is True.

    NOTE(review): assumes ``embed_dim == self.num_heads * self.head_dim``
    (required for the final reshape) — confirm against the constructor.
    """
    batch, seq_len, dim = x.size()

    # One fused linear produces Q, K and V for every head at once;
    # reshape/permute to [batch, heads, seq_len, 3 * head_dim].
    heads = (
        self.qkv_proj(x)
        .reshape(batch, seq_len, self.num_heads, 3 * self.head_dim)
        .permute(0, 2, 1, 3)
    )
    q, k, v = heads.chunk(3, dim=-1)

    # Per-head attention (helper defined elsewhere in this file).
    attn_values, attn_weights = scaled_dot_product(q, k, v)

    # Merge heads back: [batch, seq_len, heads, head_dim] -> (batch, seq_len, dim).
    merged = attn_values.permute(0, 2, 1, 3).reshape(batch, seq_len, dim)
    out = self.o_proj(merged)

    return (out, attn_weights) if return_attention else out
931
932