Skip to content

Commit 33bf58a

Browse files
Remove unnecessary conditional masking
1 parent af95d2e commit 33bf58a

1 file changed

Lines changed: 9 additions & 86 deletions

File tree

conch/kernels/attention/varlen_attention.py

Lines changed: 9 additions & 86 deletions
Original file line number | Diff line number | Diff line change
@@ -13,70 +13,10 @@
1313
import triton.language as tl
1414
from triton.language.extra import libdevice # type: ignore[attr-defined]
1515

16-
# Note: adding `bool` or `str` type annotations to these load/store helper functions doesn't work
17-
1816
# FP8 representation on CUDA and ROCm
1917
_FP8_DTYPES: Final = [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
2018

2119

22-
@triton.jit  # type: ignore[misc]
def _load_2d_block_ptr(  # type: ignore[no-untyped-def]
    data_ptr: tl.tensor,
    mask_first_dim,
    mask_second_dim,
    padding_option,
) -> tl.tensor:
    """Load a 2D tile through a Triton block pointer, boundary-checking only where needed.

    Dispatches to the cheapest `tl.load` variant: boundary checks are applied
    per-dimension only when the caller says that dimension may run past the
    tensor's edge, so the fully in-bounds fast path skips masking entirely.

    Note: `mask_first_dim`/`mask_second_dim`/`padding_option` are intentionally
    left unannotated — adding `bool`/`str` annotations to these load/store
    helpers does not work under `triton.jit` (see module note).

    Args:
        data_ptr: Block pointer (e.g. from `tl.make_block_ptr`) to load from.
        mask_first_dim: Whether dimension 0 may be partially out of bounds.
        mask_second_dim: Whether dimension 1 may be partially out of bounds.
        padding_option: Padding spec (e.g. "zero") for out-of-bounds lanes.

    Returns:
        The loaded 2D tile, padded wherever a boundary check applied.
    """
    if mask_first_dim and mask_second_dim:
        # Both dimensions may overrun: check both
        data = tl.load(data_ptr, boundary_check=(0, 1), padding_option=padding_option)
    elif mask_first_dim:
        # Only the first dimension needs a boundary check
        data = tl.load(data_ptr, boundary_check=(0,), padding_option=padding_option)
    elif mask_second_dim:
        # Only the second dimension needs a boundary check
        data = tl.load(data_ptr, boundary_check=(1,), padding_option=padding_option)
    else:
        # Fast path: tile is fully in bounds, no checks
        data = tl.load(data_ptr)

    return data
44-
45-
46-
@triton.jit  # type: ignore[misc]
def _load(  # type: ignore[no-untyped-def]
    data_ptr: tl.tensor,
    use_mask,
    mask: tl.tensor,
    other,
) -> tl.tensor:
    """Load from `data_ptr`, applying `mask` only when `use_mask` is set.

    Wraps `tl.load` so callers can skip mask generation overhead on the
    fully in-bounds fast path.

    Note: `use_mask`/`other` are intentionally left unannotated — adding
    `bool` annotations to these load/store helpers does not work under
    `triton.jit` (see module note).

    Args:
        data_ptr: Pointer (plus offsets) to load from.
        use_mask: Whether any lane may be out of bounds.
        mask: Per-lane validity mask; only consulted when `use_mask` is set.
        other: Fill value for masked-out lanes.

    Returns:
        The loaded tensor, with `other` substituted for masked lanes.
    """
    if use_mask:
        # Masked load: out-of-bounds lanes read `other`
        data = tl.load(data_ptr, mask=mask, other=other)
    else:
        # Fast path: every lane is in bounds
        data = tl.load(data_ptr)

    return data
62-
63-
64-
@triton.jit  # type: ignore[misc]
def _store(  # type: ignore[no-untyped-def]
    data_ptr: tl.tensor,
    value: tl.tensor,
    use_mask,
    mask: tl.tensor,
) -> None:
    """Store `value` to `data_ptr`, applying `mask` only when `use_mask` is set.

    Wraps `tl.store` so callers can skip masking overhead on the fully
    in-bounds fast path.

    Note: `use_mask` is intentionally left unannotated — adding a `bool`
    annotation to these load/store helpers does not work under `triton.jit`
    (see module note).

    Args:
        data_ptr: Pointer (plus offsets) to store to.
        value: Tensor of values to write.
        use_mask: Whether any lane may be out of bounds.
        mask: Per-lane validity mask; only consulted when `use_mask` is set.
    """
    if use_mask:
        # Masked store: out-of-bounds lanes are left untouched
        tl.store(data_ptr, value, mask=mask)
    else:
        # Fast path: every lane is in bounds
        tl.store(data_ptr, value)
78-
79-
8020
@triton.jit # type: ignore[misc]
8121
def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
8222
# Pointers to tensors
@@ -261,15 +201,11 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
261201
# Mask out query elements that are just for padding
262202
query_mask = query_split_group_seq_mask[:, None] & query_split_group_head_mask[:, None] & head_mask[None, :]
263203

264-
# Determine whether or not we need masking for different dimensions
265-
needs_query_split_mask = end_seqlen_q > this_query_length
266-
needs_query_group_mask = query_group_size != cxpr_query_group_size_padded
267-
needs_head_mask = head_size != cxpr_head_size_padded
268-
needs_query_mask = (needs_query_split_mask or needs_query_group_mask) or needs_head_mask
204+
# Only need causal masking if enabled and this program is not processing a decode
269205
needs_causal_mask = cxpr_is_causal and not is_pure_decode
270206

271207
# Load queries
272-
query = _load(query_ptr + query_offsets, use_mask=needs_query_mask, mask=query_mask, other=0.0)
208+
query = tl.load(query_ptr + query_offsets, mask=query_mask, other=0.0)
273209

274210
if cxpr_apply_fp8_scaling:
275211
q_scale = tl.load(q_scale_ptr)
@@ -301,9 +237,6 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
301237
cache_block_offsets = tl.arange(0, cxpr_cache_block_size)
302238
cache_block_mask = cache_block_offsets < num_entries_in_cache_block
303239

304-
needs_cache_block_mask = num_entries_in_cache_block != cxpr_cache_block_size
305-
needs_qk_mask = (needs_query_split_mask or needs_query_group_mask) or needs_cache_block_mask
306-
307240
# Offset from the block_table row for the current batch by the number of cache blocks
308241
current_cache_block_number_ptr = current_block_table_ptr + cache_block_index
309242
physical_cache_block_number = tl.load(current_cache_block_number_ptr)
@@ -319,9 +252,8 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
319252

320253
key_block_mask = head_mask[:, None] & cache_block_mask[None, :]
321254

322-
key_block = _load(
255+
key_block = tl.load(
323256
key_cache_ptr + kv_cache_block_index_offset + key_block_offsets,
324-
use_mask=(needs_cache_block_mask or needs_head_mask),
325257
mask=key_block_mask,
326258
other=0.0,
327259
)
@@ -345,9 +277,7 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
345277
causal_mask = query_split_group_seq_offsets[:, None] >= effective_seqlen_k_offsets[None, :]
346278
qk_mask = qk_mask & causal_mask
347279

348-
if needs_qk_mask or needs_causal_mask:
349-
# Set masked out elements to -inf
350-
qk = tl.where(qk_mask, qk, -float("inf")).to(dtype)
280+
qk = tl.where(qk_mask, qk, -float("inf")).to(dtype)
351281

352282
# Handle softcapping
353283
if cxpr_is_softcap:
@@ -378,9 +308,8 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
378308

379309
value_block_mask = cache_block_mask[:, None] & head_mask[None, :]
380310

381-
value_block = _load(
311+
value_block = tl.load(
382312
value_cache_ptr + kv_cache_block_index_offset + value_block_offsets,
383-
use_mask=(needs_cache_block_mask or needs_head_mask),
384313
mask=value_block_mask,
385314
other=0.0,
386315
)
@@ -415,10 +344,9 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
415344
)
416345

417346
# Store output scratchpad results
418-
_store(
347+
tl.store(
419348
output_scratchpad_ptr + output_scratch_offsets,
420349
output,
421-
use_mask=needs_query_mask,
422350
mask=query_mask,
423351
)
424352

@@ -439,10 +367,9 @@ def _varlen_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
439367
lse_mask = query_split_group_seq_mask & query_split_group_head_mask
440368

441369
# Store lse scratchpad results
442-
_store(
370+
tl.store(
443371
lse_scratchpad_ptr + lse_scratch_offsets,
444372
lse,
445-
use_mask=(needs_query_split_mask or needs_query_group_mask),
446373
mask=lse_mask,
447374
)
448375

@@ -522,8 +449,6 @@ def _varlen_attention_reduce_splits_kernel( # noqa: PLR0913
522449
head_offsets = tl.arange(0, cxpr_head_size_padded)
523450
# Mask to only read valid indices of the actual head size
524451
head_mask = head_offsets < head_size
525-
# Only enable masking if its necessary
526-
needs_head_mask = head_size != cxpr_head_size_padded
527452

528453
# Iterate through every cache block for the current sequence
529454
for kv_split_index in range(num_kv_splits_this_seq):
@@ -537,9 +462,8 @@ def _varlen_attention_reduce_splits_kernel( # noqa: PLR0913
537462
)
538463

539464
# Load output for this cache block, shape -> (cxpr_head_size_padded,)
540-
block_output = _load(
465+
block_output = tl.load(
541466
output_scratchpad_ptr + output_scratchpad_offsets,
542-
use_mask=needs_head_mask,
543467
mask=head_mask,
544468
other=0.0,
545469
)
@@ -583,10 +507,9 @@ def _varlen_attention_reduce_splits_kernel( # noqa: PLR0913
583507
output_offsets = batch_index * output_batch_stride + query_head_index * output_head_stride + head_offsets
584508

585509
# Store final result
586-
_store(
510+
tl.store(
587511
output_ptr + output_offsets,
588512
output,
589-
use_mask=needs_head_mask,
590513
mask=head_mask,
591514
)
592515

0 commit comments

Comments (0)