def forward(self, x, context=None):
    b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads

    # Project the feature map to queries, keys and values.
    q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))

    # Flatten the spatial dims: (b, heads * dim_head, h, w) -> (b, heads, dim_head, h * w).
    q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))

    # Scale queries and keys by key_dim ** -0.25 each.
    q, k = map(lambda t: t * (self.key_dim ** -0.25), (q, k))

    if context is not None:
        # Append extra context tokens as additional key/value positions.
        context = context.reshape(b, c, 1, -1)
        ck, cv = self.to_k(context), self.to_v(context)
        ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
        k = torch.cat((k, ck), dim=3)
        v = torch.cat((v, cv), dim=3)

    # Linear attention: normalise keys over positions (and optionally queries over
    # the key dimension) instead of forming an (n x n) attention matrix.
    k = k.softmax(dim=-1)

    if self.norm_queries:
        q = q.softmax(dim=-2)

    # Aggregate values into a per-head (key_dim x value_dim) summary, then read it
    # out for every query position.
    context = torch.einsum("bhdn,bhen->bhde", k, v)
    out = torch.einsum("bhdn,bhde->bhen", q, context)
    out = out.reshape(b, -1, h, w)
    out = self.to_out(out)
    return out

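# Hedged shape walkthrough (illustrative sketch, not part of the original module):
# the two einsums in forward() implement linear attention, whose cost is linear in
# the number of positions n = h * w rather than quadratic. All sizes below are
# assumptions chosen only to mirror the reshaped q/k/v tensors above.
if __name__ == "__main__":
    import torch

    b, heads, k_dim, v_dim, n = 2, 8, 32, 32, 16 * 16    # n = h * w flattened positions
    q = torch.randn(b, heads, k_dim, n).softmax(dim=-2)   # queries, normalised over k_dim (norm_queries path)
    k = torch.randn(b, heads, k_dim, n).softmax(dim=-1)   # keys, normalised over positions
    v = torch.randn(b, heads, v_dim, n)

    # Aggregate values into a per-head (k_dim x v_dim) summary, then read it out per query.
    summary = torch.einsum("bhdn,bhen->bhde", k, v)
    out = torch.einsum("bhdn,bhde->bhen", q, summary)
    print(out.shape)  # torch.Size([2, 8, 32, 256])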