1313import triton .language as tl
1414from triton .language .extra import libdevice # type: ignore[attr-defined]
1515
16- # Note: adding `bool` or `str` type annotations to these load/store helper functions doesn't work
17-
1816# FP8 representation on CUDA and ROCm
1917_FP8_DTYPES : Final = [torch .float8_e4m3fn , torch .float8_e4m3fnuz ]
2018
2119
@triton.jit  # type: ignore[misc]
def _load_2d_block_ptr(  # type: ignore[no-untyped-def]
    data_ptr: tl.tensor,
    mask_first_dim,
    mask_second_dim,
    padding_option,
) -> tl.tensor:
    """Load a 2D block pointer, boundary-checking only the dims that need it.

    ``mask_first_dim`` / ``mask_second_dim`` select which dimensions get a
    boundary check; out-of-bounds lanes are filled according to
    ``padding_option``. With neither flag set this is a plain ``tl.load``.
    """
    if mask_first_dim:
        if mask_second_dim:
            # Both dimensions may run past the tensor edge
            block = tl.load(data_ptr, boundary_check=(0, 1), padding_option=padding_option)
        else:
            # Only dimension 0 may run past the edge
            block = tl.load(data_ptr, boundary_check=(0,), padding_option=padding_option)
    elif mask_second_dim:
        # Only dimension 1 may run past the edge
        block = tl.load(data_ptr, boundary_check=(1,), padding_option=padding_option)
    else:
        # Fully in-bounds: no boundary check needed
        block = tl.load(data_ptr)

    return block
44-
45-
@triton.jit  # type: ignore[misc]
def _load(  # type: ignore[no-untyped-def]
    data_ptr: tl.tensor,
    use_mask,
    mask: tl.tensor,
    other,
) -> tl.tensor:
    """Load a 1D tensor, optionally masked.

    When ``use_mask`` is set, lanes where ``mask`` is false are filled with
    ``other``; otherwise the pointer is loaded unconditionally.
    """
    if not use_mask:
        # Unconditional load — every lane is known valid
        loaded = tl.load(data_ptr)
    else:
        # Masked load with fill value for invalid lanes
        loaded = tl.load(data_ptr, mask=mask, other=other)

    return loaded
62-
63-
@triton.jit  # type: ignore[misc]
def _store(  # type: ignore[no-untyped-def]
    data_ptr: tl.tensor,
    value: tl.tensor,
    use_mask,
    mask: tl.tensor,
) -> None:
    """Store ``value`` at ``data_ptr``, optionally masked.

    When ``use_mask`` is set, writes are suppressed for lanes where ``mask``
    is false; otherwise every lane is written.
    """
    if not use_mask:
        # Unconditional store — every lane is known valid
        tl.store(data_ptr, value)
    else:
        # Masked store: skip writes for invalid lanes
        tl.store(data_ptr, value, mask=mask)
78-
79-
8020@triton .jit # type: ignore[misc]
8121def _varlen_attention_compute_splits_kernel ( # noqa: PLR0913, PLR0915
8222 # Pointers to tensors
@@ -261,15 +201,11 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
261201 # Mask out query elements that are just for padding
262202 query_mask = query_split_group_seq_mask [:, None ] & query_split_group_head_mask [:, None ] & head_mask [None , :]
263203
264- # Determine whether or not we need masking for different dimensions
265- needs_query_split_mask = end_seqlen_q > this_query_length
266- needs_query_group_mask = query_group_size != cxpr_query_group_size_padded
267- needs_head_mask = head_size != cxpr_head_size_padded
268- needs_query_mask = (needs_query_split_mask or needs_query_group_mask ) or needs_head_mask
204+ # Only need causal masking if enabled and this program is not processing a decode
269205 needs_causal_mask = cxpr_is_causal and not is_pure_decode
270206
271207 # Load queries
272- query = _load (query_ptr + query_offsets , use_mask = needs_query_mask , mask = query_mask , other = 0.0 )
208+ query = tl . load (query_ptr + query_offsets , mask = query_mask , other = 0.0 )
273209
274210 if cxpr_apply_fp8_scaling :
275211 q_scale = tl .load (q_scale_ptr )
@@ -301,9 +237,6 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
301237 cache_block_offsets = tl .arange (0 , cxpr_cache_block_size )
302238 cache_block_mask = cache_block_offsets < num_entries_in_cache_block
303239
304- needs_cache_block_mask = num_entries_in_cache_block != cxpr_cache_block_size
305- needs_qk_mask = (needs_query_split_mask or needs_query_group_mask ) or needs_cache_block_mask
306-
307240 # Offset from the block_table row for the current batch by the number of cache blocks
308241 current_cache_block_number_ptr = current_block_table_ptr + cache_block_index
309242 physical_cache_block_number = tl .load (current_cache_block_number_ptr )
@@ -319,9 +252,8 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
319252
320253 key_block_mask = head_mask [:, None ] & cache_block_mask [None , :]
321254
322- key_block = _load (
255+ key_block = tl . load (
323256 key_cache_ptr + kv_cache_block_index_offset + key_block_offsets ,
324- use_mask = (needs_cache_block_mask or needs_head_mask ),
325257 mask = key_block_mask ,
326258 other = 0.0 ,
327259 )
@@ -345,9 +277,7 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
345277 causal_mask = query_split_group_seq_offsets [:, None ] >= effective_seqlen_k_offsets [None , :]
346278 qk_mask = qk_mask & causal_mask
347279
348- if needs_qk_mask or needs_causal_mask :
349- # Set masked out elements to -inf
350- qk = tl .where (qk_mask , qk , - float ("inf" )).to (dtype )
280+ qk = tl .where (qk_mask , qk , - float ("inf" )).to (dtype )
351281
352282 # Handle softcapping
353283 if cxpr_is_softcap :
@@ -378,9 +308,8 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
378308
379309 value_block_mask = cache_block_mask [:, None ] & head_mask [None , :]
380310
381- value_block = _load (
311+ value_block = tl . load (
382312 value_cache_ptr + kv_cache_block_index_offset + value_block_offsets ,
383- use_mask = (needs_cache_block_mask or needs_head_mask ),
384313 mask = value_block_mask ,
385314 other = 0.0 ,
386315 )
@@ -415,10 +344,9 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
415344 )
416345
417346 # Store output scratchpad results
418- _store (
347+ tl . store (
419348 output_scratchpad_ptr + output_scratch_offsets ,
420349 output ,
421- use_mask = needs_query_mask ,
422350 mask = query_mask ,
423351 )
424352
@@ -439,10 +367,9 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
439367 lse_mask = query_split_group_seq_mask & query_split_group_head_mask
440368
441369 # Store lse scratchpad results
442- _store (
370+ tl . store (
443371 lse_scratchpad_ptr + lse_scratch_offsets ,
444372 lse ,
445- use_mask = (needs_query_split_mask or needs_query_group_mask ),
446373 mask = lse_mask ,
447374 )
448375
@@ -522,8 +449,6 @@ def _varlen_attention_reduce_splits_kernel( # noqa: PLR0913
522449 head_offsets = tl .arange (0 , cxpr_head_size_padded )
523450 # Mask to only read valid indices of the actual head size
524451 head_mask = head_offsets < head_size
525- # Only enable masking if its necessary
526- needs_head_mask = head_size != cxpr_head_size_padded
527452
528453 # Iterate through every cache block for the current sequence
529454 for kv_split_index in range (num_kv_splits_this_seq ):
@@ -537,9 +462,8 @@ def _varlen_attention_reduce_splits_kernel( # noqa: PLR0913
537462 )
538463
539464 # Load output for this cache block, shape -> (cxpr_head_size_padded,)
540- block_output = _load (
465+ block_output = tl . load (
541466 output_scratchpad_ptr + output_scratchpad_offsets ,
542- use_mask = needs_head_mask ,
543467 mask = head_mask ,
544468 other = 0.0 ,
545469 )
@@ -583,10 +507,9 @@ def _varlen_attention_reduce_splits_kernel( # noqa: PLR0913
583507 output_offsets = batch_index * output_batch_stride + query_head_index * output_head_stride + head_offsets
584508
585509 # Store final result
586- _store (
510+ tl . store (
587511 output_ptr + output_offsets ,
588512 output ,
589- use_mask = needs_head_mask ,
590513 mask = head_mask ,
591514 )
592515
0 commit comments