We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
x
input
torch_rotary_position_embedding
1 parent 3a39afe commit ae14296Copy full SHA for ae14296
1 file changed
rotary_position_embedding.py
@@ -21,8 +21,8 @@ def torch_rotary_position_embedding(input, sin_table, cos_table, interleaved=Tru
21
22
return torch.stack((input_0_rotated, input_1_rotated), dim=-1).view(input.shape)
23
else:
24
- input_0 = x[..., : x.shape[-1] // 2]
25
- input_1 = x[..., x.shape[-1] // 2 :]
+ input_0 = input[..., : input.shape[-1] // 2]
+ input_1 = input[..., input.shape[-1] // 2 :]
26
input_0_rotated = input_0 * cos_table - input_1 * sin_table
27
input_1_rotated = input_0 * sin_table + input_1 * cos_table
28
0 commit comments