1010@triton .jit
1111def _page_io (
1212 mem_index_ptr ,
13+ token_num ,
14+ page_write_head_num ,
1315 k_page_ptr ,
1416 k_page_stride_size ,
1517 k_page_stride_layer_num ,
@@ -45,88 +47,91 @@ def _page_io(
4547 k_stride_size = tl .cast (k_stride_size , dtype = tl .int64 )
4648 v_stride_size = tl .cast (v_stride_size , dtype = tl .int64 )
4749
48- tid = tl .program_id (0 )
49- kv_head_id = tl .program_id (1 )
50- page_head_id = page_head_start + kv_head_id
50+ start_index = tl .program_id (0 )
51+ grid_num = tl .num_programs (0 )
5152
52- mem_index = tl .load (mem_index_ptr + tid )
53- off_dim = tl .arange (0 , HEAD_DIM_BLOCK )
54- if NEED_MASK :
55- mask = off_dim < head_dim
56- else :
57- mask = None
53+ for tid in tl .range (start_index , token_num , step = grid_num ):
54+ for kv_head_id in tl .range (page_write_head_num ):
5855
59- for layer_index in tl .range (layer_num , num_stages = 3 ):
60- if IS_WRITE :
61- k_tensor = tl .load (
62- k_ptr
63- + layer_index * k_stride_layer_num
64- + mem_index * k_stride_size
65- + kv_head_id * k_stride_head
66- + off_dim * k_stride_dim ,
67- mask = mask ,
68- )
69- v_tensor = tl .load (
70- v_ptr
71- + layer_index * v_stride_layer_num
72- + mem_index * v_stride_size
73- + kv_head_id * v_stride_head
74- + off_dim * v_stride_dim ,
75- mask = mask ,
76- )
77- tl .store (
78- k_page_ptr
79- + tid * k_page_stride_size
80- + layer_index * k_page_stride_layer_num
81- + page_head_id * k_page_stride_head
82- + off_dim * k_page_stride_dim ,
83- k_tensor ,
84- mask = mask ,
85- )
86- tl .store (
87- v_page_ptr
88- + tid * v_page_stride_size
89- + layer_index * v_page_stride_layer_num
90- + page_head_id * v_page_stride_head
91- + off_dim * v_page_stride_dim ,
92- v_tensor ,
93- mask = mask ,
94- )
95- else :
96- k_page_tensor = tl .load (
97- k_page_ptr
98- + tid * k_page_stride_size
99- + layer_index * k_page_stride_layer_num
100- + page_head_id * k_page_stride_head
101- + off_dim * k_page_stride_dim ,
102- mask = mask ,
103- )
104- v_page_tensor = tl .load (
105- v_page_ptr
106- + tid * v_page_stride_size
107- + layer_index * v_page_stride_layer_num
108- + page_head_id * v_page_stride_head
109- + off_dim * v_page_stride_dim ,
110- mask = mask ,
111- )
112- tl .store (
113- k_ptr
114- + layer_index * k_stride_layer_num
115- + mem_index * k_stride_size
116- + kv_head_id * k_stride_head
117- + off_dim * k_stride_dim ,
118- k_page_tensor ,
119- mask = mask ,
120- )
121- tl .store (
122- v_ptr
123- + layer_index * v_stride_layer_num
124- + mem_index * v_stride_size
125- + kv_head_id * v_stride_head
126- + off_dim * v_stride_dim ,
127- v_page_tensor ,
128- mask = mask ,
129- )
56+ page_head_id = page_head_start + kv_head_id
57+ mem_index = tl .load (mem_index_ptr + tid )
58+ off_dim = tl .arange (0 , HEAD_DIM_BLOCK )
59+ if NEED_MASK :
60+ mask = off_dim < head_dim
61+ else :
62+ mask = None
63+
64+ for layer_index in tl .range (layer_num , num_stages = 3 ):
65+ if IS_WRITE :
66+ k_tensor = tl .load (
67+ k_ptr
68+ + layer_index * k_stride_layer_num
69+ + mem_index * k_stride_size
70+ + kv_head_id * k_stride_head
71+ + off_dim ,
72+ mask = mask ,
73+ )
74+ v_tensor = tl .load (
75+ v_ptr
76+ + layer_index * v_stride_layer_num
77+ + mem_index * v_stride_size
78+ + kv_head_id * v_stride_head
79+ + off_dim ,
80+ mask = mask ,
81+ )
82+ tl .store (
83+ k_page_ptr
84+ + tid * k_page_stride_size
85+ + layer_index * k_page_stride_layer_num
86+ + page_head_id * k_page_stride_head
87+ + off_dim ,
88+ k_tensor ,
89+ mask = mask ,
90+ )
91+ tl .store (
92+ v_page_ptr
93+ + tid * v_page_stride_size
94+ + layer_index * v_page_stride_layer_num
95+ + page_head_id * v_page_stride_head
96+ + off_dim ,
97+ v_tensor ,
98+ mask = mask ,
99+ )
100+ else :
101+ k_page_tensor = tl .load (
102+ k_page_ptr
103+ + tid * k_page_stride_size
104+ + layer_index * k_page_stride_layer_num
105+ + page_head_id * k_page_stride_head
106+ + off_dim ,
107+ mask = mask ,
108+ )
109+ v_page_tensor = tl .load (
110+ v_page_ptr
111+ + tid * v_page_stride_size
112+ + layer_index * v_page_stride_layer_num
113+ + page_head_id * v_page_stride_head
114+ + off_dim ,
115+ mask = mask ,
116+ )
117+ tl .store (
118+ k_ptr
119+ + layer_index * k_stride_layer_num
120+ + mem_index * k_stride_size
121+ + kv_head_id * k_stride_head
122+ + off_dim ,
123+ k_page_tensor ,
124+ mask = mask ,
125+ )
126+ tl .store (
127+ v_ptr
128+ + layer_index * v_stride_layer_num
129+ + mem_index * v_stride_size
130+ + kv_head_id * v_stride_head
131+ + off_dim ,
132+ v_page_tensor ,
133+ mask = mask ,
134+ )
130135 return
131136
132137
@@ -169,10 +174,17 @@ def page_io(
169174 page_head_start = tp_index * (page_write_head_num )
170175
171176 token_num = len (mem_indexes )
172- grid = (token_num , page_write_head_num )
177+ grid = (128 ,)
178+
179+ assert k_page_tensor .stride (3 ) == 1
180+ assert v_page_tensor .stride (3 ) == 1
181+ assert k_buffer .stride (3 ) == 1
182+ assert v_buffer .stride (3 ) == 1
173183
174184 _page_io [grid ](
175185 mem_index_ptr = mem_indexes ,
186+ token_num = token_num ,
187+ page_write_head_num = page_write_head_num ,
176188 k_page_ptr = k_page_tensor ,
177189 k_page_stride_size = k_page_tensor .stride (0 ),
178190 k_page_stride_layer_num = k_page_tensor .stride (1 ),
@@ -207,6 +219,7 @@ def page_io(
207219@triton .jit
208220def _mla_page_io (
209221 mem_index_ptr ,
222+ token_num ,
210223 page_ptr ,
211224 page_stride_size ,
212225 page_stride_layer_num ,
@@ -227,52 +240,54 @@ def _mla_page_io(
227240 kv_stride_layer_num = tl .cast (kv_stride_layer_num , dtype = tl .int64 )
228241 kv_stride_size = tl .cast (kv_stride_size , dtype = tl .int64 )
229242
230- tid = tl .program_id (0 )
243+ start_index = tl .program_id (0 )
244+ grid_num = tl .num_programs (0 )
231245
232- mem_index = tl .load (mem_index_ptr + tid )
233- off_dim = tl .arange (0 , HEAD_DIM_BLOCK )
234- if NEED_MASK :
235- mask = off_dim < head_dim
236- else :
237- mask = None
238-
239- for layer_index in tl .range (layer_num , num_stages = 3 ):
240- if IS_WRITE :
241- kv_tensor = tl .load (
242- kv_ptr
243- + layer_index * kv_stride_layer_num
244- + mem_index * kv_stride_size
245- + 0 * kv_stride_head
246- + off_dim * kv_stride_dim ,
247- mask = mask ,
248- )
249- tl .store (
250- page_ptr
251- + tid * page_stride_size
252- + layer_index * page_stride_layer_num
253- + 0 * page_stride_head
254- + off_dim * page_stride_dim ,
255- kv_tensor ,
256- mask = mask ,
257- )
246+ for tid in tl .range (start_index , token_num , step = grid_num ):
247+ mem_index = tl .load (mem_index_ptr + tid )
248+ off_dim = tl .arange (0 , HEAD_DIM_BLOCK )
249+ if NEED_MASK :
250+ mask = off_dim < head_dim
258251 else :
259- page_tensor = tl .load (
260- page_ptr
261- + tid * page_stride_size
262- + layer_index * page_stride_layer_num
263- + 0 * page_stride_head
264- + off_dim * page_stride_dim ,
265- mask = mask ,
266- )
267- tl .store (
268- kv_ptr
269- + layer_index * kv_stride_layer_num
270- + mem_index * kv_stride_size
271- + 0 * kv_stride_head
272- + off_dim * kv_stride_dim ,
273- page_tensor ,
274- mask = mask ,
275- )
252+ mask = None
253+
254+ for layer_index in tl .range (layer_num , num_stages = 3 ):
255+ if IS_WRITE :
256+ kv_tensor = tl .load (
257+ kv_ptr
258+ + layer_index * kv_stride_layer_num
259+ + mem_index * kv_stride_size
260+ + 0 * kv_stride_head
261+ + off_dim * kv_stride_dim ,
262+ mask = mask ,
263+ )
264+ tl .store (
265+ page_ptr
266+ + tid * page_stride_size
267+ + layer_index * page_stride_layer_num
268+ + 0 * page_stride_head
269+ + off_dim * page_stride_dim ,
270+ kv_tensor ,
271+ mask = mask ,
272+ )
273+ else :
274+ page_tensor = tl .load (
275+ page_ptr
276+ + tid * page_stride_size
277+ + layer_index * page_stride_layer_num
278+ + 0 * page_stride_head
279+ + off_dim * page_stride_dim ,
280+ mask = mask ,
281+ )
282+ tl .store (
283+ kv_ptr
284+ + layer_index * kv_stride_layer_num
285+ + mem_index * kv_stride_size
286+ + 0 * kv_stride_head
287+ + off_dim * kv_stride_dim ,
288+ page_tensor ,
289+ mask = mask ,
290+ )
276291 return
277292
278293
@@ -290,10 +305,11 @@ def mla_page_io(mem_indexes: torch.Tensor, page_tensor: torch.Tensor, kv_buffer:
290305 assert page_head_num == kv_head_num == 1
291306
292307 token_num = len (mem_indexes )
293- grid = (token_num ,)
308+ grid = (64 ,)
294309
295310 _mla_page_io [grid ](
296311 mem_index_ptr = mem_indexes ,
312+ token_num = token_num ,
297313 page_ptr = page_tensor ,
298314 page_stride_size = page_tensor .stride (0 ),
299315 page_stride_layer_num = page_tensor .stride (1 ),
0 commit comments