def forward(self, x, context=None):
    b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads

    q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))

    # flatten the spatial grid into a sequence of h * w positions per head
    q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))

    # split the usual 1/sqrt(d) scaling evenly between queries and keys
    q, k = map(lambda t: t * (self.key_dim ** -0.25), (q, k))

    if context is not None:
        # append extra key/value tokens derived from the context tensor
        context = context.reshape(b, c, 1, -1)
        ck, cv = self.to_k(context), self.to_v(context)
        ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
        k = torch.cat((k, ck), dim=3)
        v = torch.cat((v, cv), dim=3)

    # normalize keys over positions (linear / efficient attention)
    k = k.softmax(dim=-1)

    if self.norm_queries:
        q = q.softmax(dim=-2)

    # aggregate the key/value outer product first, then let queries read from it
    context = torch.einsum("bhdn,bhen->bhde", k, v)
    out = torch.einsum("bhdn,bhde->bhen", q, context)
    out = out.reshape(b, -1, h, w)
    out = self.to_out(out)
    return out
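
# --- Hedged standalone sketch (not part of the original file) ----------------
# Illustrates the linear-attention step used in forward(): keys are
# softmax-normalized over positions, the key/value outer product is aggregated
# first into a fixed-size context, and queries then read from that context,
# keeping the cost linear in the number of positions. All shapes and dims
# below are illustrative assumptions, not the repository's configuration.
import torch

b, heads, k_dim, v_dim, n = 2, 8, 32, 64, 256   # n = h * w flattened positions

q = torch.randn(b, heads, k_dim, n) * (k_dim ** -0.25)
k = torch.randn(b, heads, k_dim, n) * (k_dim ** -0.25)
v = torch.randn(b, heads, v_dim, n)

k = k.softmax(dim=-1)                            # normalize keys over positions

# contract over the n positions before touching the queries
context = torch.einsum("bhdn,bhen->bhde", k, v)  # (b, heads, k_dim, v_dim)
out = torch.einsum("bhdn,bhde->bhen", q, context)
print(out.shape)                                 # torch.Size([2, 8, 64, 256])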