55import ops .triton .torch
66
77
8- def torch_rope (input , sin_table , cos_table , interleaved = True ):
8+ def torch_rotary_position_embedding (input , sin_table , cos_table , interleaved = True ):
99 batch_size , seq_len , num_heads , emb_dim = input .shape
1010
1111 assert emb_dim % 2 == 0 , "The embedding dimension must be even."
@@ -55,11 +55,15 @@ def _generate_sin_and_cos_tables(
5555 sin_table , cos_table = _generate_sin_and_cos_tables (seq_len , emb_dim )
5656 x = torch .randn (batch_size , seq_len , num_heads , emb_dim , dtype = dtype , device = device )
5757
58- ninetoothed_output = ops .ninetoothed .torch .rope (
58+ ninetoothed_output = ops .ninetoothed .torch .rotary_position_embedding (
59+ x , sin_table , cos_table , interleaved = False
60+ )
61+ torch_output = torch_rotary_position_embedding (
62+ x , sin_table , cos_table , interleaved = False
63+ )
64+ triton_output = ops .triton .torch .rotary_position_embedding (
5965 x , sin_table , cos_table , interleaved = False
6066 )
61- torch_output = torch_rope (x , sin_table , cos_table , interleaved = False )
62- triton_output = ops .triton .torch .rope (x , sin_table , cos_table , interleaved = False )
6367
6468 print (ninetoothed_output )
6569 print (torch_output )
@@ -83,7 +87,7 @@ def _generate_sin_and_cos_tables(
8387 line_names = ["NineToothed" , "PyTorch" , "Triton" ],
8488 styles = [("blue" , "-" ), ("green" , "-" ), ("orange" , "-" )],
8589 ylabel = "ms" ,
86- plot_name = "rope -performance" ,
90+ plot_name = "rotary_position_embedding -performance" ,
8791 args = {},
8892 )
8993 )
@@ -98,13 +102,19 @@ def benchmark(seq_len, provider):
98102
99103 if provider == "ninetoothed" :
100104 ms = triton .testing .do_bench (
101- lambda : ops .ninetoothed .torch .rope (x , sin_table , cos_table )
105+ lambda : ops .ninetoothed .torch .rotary_position_embedding (
106+ x , sin_table , cos_table
107+ )
102108 )
103109 elif provider == "torch" :
104- ms = triton .testing .do_bench (lambda : torch_rope (x , sin_table , cos_table ))
110+ ms = triton .testing .do_bench (
111+ lambda : torch_rotary_position_embedding (x , sin_table , cos_table )
112+ )
105113 elif provider == "triton" :
106114 ms = triton .testing .do_bench (
107- lambda : ops .triton .torch .rope (x , sin_table , cos_table )
115+ lambda : ops .triton .torch .rotary_position_embedding (
116+ x , sin_table , cos_table
117+ )
108118 )
109119
110120 return ms
0 commit comments