def forward(self, x, return_attention=False):
    """Apply multi-head self-attention to *x*.

    Args:
        x: Input tensor of shape ``(batch, seq_len, embed_dim)``.
        return_attention: If ``True``, also return the attention weights
            produced by ``scaled_dot_product``.

    Returns:
        Output of shape ``(batch, seq_len, embed_dim)``; when
        ``return_attention`` is set, a ``(output, attention)`` tuple.
    """
    batch, seq_len, dim = x.size()

    # Single linear layer emits Q, K and V stacked along the last axis.
    projected = self.qkv_proj(x)
    # -> [batch, num_heads, seq_len, 3 * head_dim]
    projected = projected.reshape(
        batch, seq_len, self.num_heads, 3 * self.head_dim
    ).permute(0, 2, 1, 3)
    # Carve the stacked projection into three head_dim-wide slices.
    q, k, v = projected.split(self.head_dim, dim=-1)

    context, attn_weights = scaled_dot_product(q, k, v)
    # Fold the heads back into a single embedding dimension.
    context = context.permute(0, 2, 1, 3).reshape(batch, seq_len, dim)
    output = self.o_proj(context)

    return (output, attn_weights) if return_attention else output
903
904