Skip to content

Commit c8ae164

Browse files
committed
Rename rope to rotary_position_embedding
1 parent 72d552c commit c8ae164

8 files changed

Lines changed: 46 additions & 28 deletions

File tree

compare_code_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ def _key_from_kernel_name(path, kernel_name):
131131
return str(path / f"{kernel_name}.py").removeprefix(str(_PARENT_PATH))[1:]
132132

133133
data = {
134-
f"{_BACKSLASH_CHAR}texttt{{{kernel_name.replace('scaled_dot_product_attention', 'sdpa').replace('_', f'{_BACKSLASH_CHAR}_')}}}": {
134+
f"{_BACKSLASH_CHAR}texttt{{{kernel_name.replace('scaled_dot_product_attention', 'sdpa').replace('rotary_position_embedding', 'rope').replace('_', f'{_BACKSLASH_CHAR}_')}}}": {
135135
"Triton": {
136136
metric_name: data[
137137
_key_from_kernel_name(_TRITON_KERNELS_PATH, kernel_name)

infer.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from linear import Linear, bmm_backend
1010
from scaled_dot_product_attention import (
1111
Attention,
12-
rope_backend,
12+
rotary_position_embedding_backend,
1313
scaled_dot_product_attention_backend,
1414
)
1515
from silu import SiLU, silu_backend
@@ -94,7 +94,7 @@
9494
with (
9595
bmm_backend(backend),
9696
rms_norm_backend(backend),
97-
rope_backend(backend),
97+
rotary_position_embedding_backend(backend),
9898
scaled_dot_product_attention_backend(backend),
9999
silu_backend(backend),
100100
):

ops/ninetoothed/kernels/rope.py renamed to ops/ninetoothed/kernels/rotary_position_embedding.py

File renamed without changes.

ops/ninetoothed/torch.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import ops.ninetoothed.kernels.fused_rms_norm
1010
import ops.ninetoothed.kernels.mm
1111
import ops.ninetoothed.kernels.rms_norm
12-
import ops.ninetoothed.kernels.rope
12+
import ops.ninetoothed.kernels.rotary_position_embedding
1313
import ops.ninetoothed.kernels.scaled_dot_product_attention
1414
import ops.ninetoothed.kernels.silu
1515
import ops.ninetoothed.kernels.softmax
@@ -92,14 +92,16 @@ def rms_norm(input, eps=None):
9292
return output
9393

9494

95-
def rope(input, sin_table, cos_table, interleaved=True):
95+
def rotary_position_embedding(input, sin_table, cos_table, interleaved=True):
9696
batch_size, _, num_heads, _ = input.shape
9797

9898
output = input.clone()
9999
sin_table = sin_table[None, :, None, :].expand(batch_size, -1, num_heads, -1)
100100
cos_table = cos_table[None, :, None, :].expand(batch_size, -1, num_heads, -1)
101101

102-
ops.ninetoothed.kernels.rope.kernel(output, sin_table, cos_table, interleaved)
102+
ops.ninetoothed.kernels.rotary_position_embedding.kernel(
103+
output, sin_table, cos_table, interleaved
104+
)
103105

104106
return output
105107

ops/triton/kernels/rope.py renamed to ops/triton/kernels/rotary_position_embedding.py

File renamed without changes.

ops/triton/torch.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
import ops.triton.kernels.fused_rms_norm
1111
import ops.triton.kernels.mm
1212
import ops.triton.kernels.rms_norm
13-
import ops.triton.kernels.rope
13+
import ops.triton.kernels.rotary_position_embedding
1414
import ops.triton.kernels.scaled_dot_product_attention
1515
import ops.triton.kernels.silu
1616
import ops.triton.kernels.softmax
@@ -195,14 +195,16 @@ def rms_norm(input, eps=None):
195195
return output
196196

197197

198-
def rope(input, sin_table, cos_table, interleaved=True):
198+
def rotary_position_embedding(input, sin_table, cos_table, interleaved=True):
199199
batch_size, seq_len, num_heads, emb_dim = input.shape
200200

201201
BLOCK_SIZE = triton.next_power_of_2(emb_dim // 2)
202202

203203
output = input.clone()
204204

205-
ops.triton.kernels.rope.kernel[(batch_size, seq_len, num_heads)](
205+
ops.triton.kernels.rotary_position_embedding.kernel[
206+
(batch_size, seq_len, num_heads)
207+
](
206208
output,
207209
sin_table,
208210
cos_table,

rope.py renamed to rotary_position_embedding.py

Lines changed: 18 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import ops.triton.torch
66

77

8-
def torch_rope(input, sin_table, cos_table, interleaved=True):
8+
def torch_rotary_position_embedding(input, sin_table, cos_table, interleaved=True):
99
batch_size, seq_len, num_heads, emb_dim = input.shape
1010

1111
assert emb_dim % 2 == 0, "The embedding dimension must be even."
@@ -55,11 +55,15 @@ def _generate_sin_and_cos_tables(
5555
sin_table, cos_table = _generate_sin_and_cos_tables(seq_len, emb_dim)
5656
x = torch.randn(batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device)
5757

58-
ninetoothed_output = ops.ninetoothed.torch.rope(
58+
ninetoothed_output = ops.ninetoothed.torch.rotary_position_embedding(
59+
x, sin_table, cos_table, interleaved=False
60+
)
61+
torch_output = torch_rotary_position_embedding(
62+
x, sin_table, cos_table, interleaved=False
63+
)
64+
triton_output = ops.triton.torch.rotary_position_embedding(
5965
x, sin_table, cos_table, interleaved=False
6066
)
61-
torch_output = torch_rope(x, sin_table, cos_table, interleaved=False)
62-
triton_output = ops.triton.torch.rope(x, sin_table, cos_table, interleaved=False)
6367

6468
print(ninetoothed_output)
6569
print(torch_output)
@@ -83,7 +87,7 @@ def _generate_sin_and_cos_tables(
8387
line_names=["NineToothed", "PyTorch", "Triton"],
8488
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
8589
ylabel="ms",
86-
plot_name="rope-performance",
90+
plot_name="rotary_position_embedding-performance",
8791
args={},
8892
)
8993
)
@@ -98,13 +102,19 @@ def benchmark(seq_len, provider):
98102

99103
if provider == "ninetoothed":
100104
ms = triton.testing.do_bench(
101-
lambda: ops.ninetoothed.torch.rope(x, sin_table, cos_table)
105+
lambda: ops.ninetoothed.torch.rotary_position_embedding(
106+
x, sin_table, cos_table
107+
)
102108
)
103109
elif provider == "torch":
104-
ms = triton.testing.do_bench(lambda: torch_rope(x, sin_table, cos_table))
110+
ms = triton.testing.do_bench(
111+
lambda: torch_rotary_position_embedding(x, sin_table, cos_table)
112+
)
105113
elif provider == "triton":
106114
ms = triton.testing.do_bench(
107-
lambda: ops.triton.torch.rope(x, sin_table, cos_table)
115+
lambda: ops.triton.torch.rotary_position_embedding(
116+
x, sin_table, cos_table
117+
)
108118
)
109119

110120
return ms

scaled_dot_product_attention.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,13 @@
88

99
import ops.ninetoothed.torch
1010
import ops.triton.torch
11-
from rope import torch_rope
11+
from rotary_position_embedding import torch_rotary_position_embedding
1212

1313

1414
class Attention(nn.Module):
1515
scaled_dot_product_attention = None
1616

17-
rope = None
17+
rotary_position_embedding = None
1818

1919
def __init__(self, other):
2020
super().__init__()
@@ -41,8 +41,12 @@ def forward(
4141
sin_table = sin_table[0]
4242
cos_table = cos_table[0]
4343

44-
query_states = type(self).rope(query_states, sin_table, cos_table)
45-
key_states = type(self).rope(key_states, sin_table, cos_table)
44+
query_states = type(self).rotary_position_embedding(
45+
query_states, sin_table, cos_table
46+
)
47+
key_states = type(self).rotary_position_embedding(
48+
key_states, sin_table, cos_table
49+
)
4650

4751
query_states = query_states.transpose(1, 2)
4852
key_states = key_states.transpose(1, 2)
@@ -94,24 +98,24 @@ def scaled_dot_product_attention_backend(backend_name):
9498

9599

96100
@contextmanager
97-
def rope_backend(backend_name):
98-
_prev_impl = Attention.rope
101+
def rotary_position_embedding_backend(backend_name):
102+
_prev_impl = Attention.rotary_position_embedding
99103

100104
if backend_name == "ninetoothed":
101-
impl = ops.ninetoothed.torch.rope
105+
impl = ops.ninetoothed.torch.rotary_position_embedding
102106
elif backend_name == "triton":
103-
impl = ops.triton.torch.rope
107+
impl = ops.triton.torch.rotary_position_embedding
104108
elif backend_name == "torch":
105-
impl = torch_rope
109+
impl = torch_rotary_position_embedding
106110
else:
107111
raise ValueError(f"unknown backend: `{backend_name}`")
108112

109-
Attention.rope = impl
113+
Attention.rotary_position_embedding = impl
110114

111115
try:
112116
yield
113117
finally:
114-
Attention.rope = _prev_impl
118+
Attention.rotary_position_embedding = _prev_impl
115119

116120

117121
if __name__ == "__main__":

0 commit comments

Comments
 (0)