diff --git a/labml_nn/transformers/rope/__init__.py b/labml_nn/transformers/rope/__init__.py index 7729cbd1..08092ac4 100644 --- a/labml_nn/transformers/rope/__init__.py +++ b/labml_nn/transformers/rope/__init__.py @@ -185,7 +185,7 @@ def forward(self, x: torch.Tensor): # \end{align} # # for $i \in {1, 2, ..., \frac{d}{2}}$ - x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]]) + x_rope = (x_rope * self.cos_cached[:x_rope.shape[0]]) + (neg_half_x * self.sin_cached[:x_rope.shape[0]]) # return torch.cat((x_rope, x_pass), dim=-1)