From fdf7d89c1094749e5ae1ed2114ee6397a866c93b Mon Sep 17 00:00:00 2001
From: edge-observer <2192672599@qq.com>
Date: Fri, 15 Nov 2024 10:18:11 +0800
Subject: [PATCH] Modified the forward function of Attention so that it can
 operate on batched inputs with an optional attention mask

---
 vit_pytorch/vit.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/vit_pytorch/vit.py b/vit_pytorch/vit.py
index 5b34a44..bfcad3c 100644
--- a/vit_pytorch/vit.py
+++ b/vit_pytorch/vit.py
@@ -48,19 +48,26 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         ) if project_out else nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x, mask = None):
+        assert x.dim() == 3, "Input x must have three dimensions: (batch_size, sequence_length, embedding_dim)"
+
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
+        if mask is not None:
+            # boolean mask of shape (batch_size, seq_len, seq_len); True marks positions that may be attended to
+            mask = mask.unsqueeze(1).expand(dots.size(0), self.heads, dots.size(2), dots.size(3))
+            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
+
         attn = self.attend(dots)
         attn = self.dropout(attn)
 
         out = torch.matmul(attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
 class Transformer(nn.Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
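
A minimal usage sketch of the patched forward (not part of the patch itself). The hyperparameters,
sequence lengths, and mask construction below are illustrative assumptions; the patch only requires
a boolean mask of shape (batch_size, seq_len, seq_len) where True marks attendable positions.

import torch
from vit_pytorch.vit import Attention

# Illustrative hyperparameters (assumptions, not dictated by the patch).
attn = Attention(dim = 256, heads = 8, dim_head = 64, dropout = 0.)

batch_size, seq_len, dim = 4, 65, 256   # e.g. 64 patch tokens + 1 class token
x = torch.randn(batch_size, seq_len, dim)

# Per-sequence padding pattern: True = real token, False = padding.
lengths = torch.tensor([65, 65, 40, 10])
keep = torch.arange(seq_len)[None, :] < lengths[:, None]   # (batch_size, seq_len)

# Pairwise (batch_size, seq_len, seq_len) boolean mask expected by the patched
# forward: a query may attend to a key only if both positions are real tokens.
mask = keep[:, :, None] & keep[:, None, :]

out = attn(x, mask = mask)   # (batch_size, seq_len, dim)
print(out.shape)             # torch.Size([4, 65, 256])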