Skip to content

Commit bbd247b

Browse files
committed
optimize rowwise
1 parent 87928a4 commit bbd247b

File tree

2 files changed

+112
-108
lines changed

2 files changed

+112
-108
lines changed

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@
5252
from liger_kernel.ops.backends._ascend.ops.llama4_rope import LigerLlama4RopeFunction
5353
from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_backward
5454
from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_forward
55+
from liger_kernel.ops.backends._ascend.ops.multi_token_attention import LigerMultiTokenAttentionFunction
5556
from liger_kernel.ops.backends._ascend.ops.poly_norm import LigerPolyNormFunction
5657
from liger_kernel.ops.backends._ascend.ops.poly_norm import poly_norm_backward
5758
from liger_kernel.ops.backends._ascend.ops.poly_norm import poly_norm_forward
58-
from liger_kernel.ops.backends._ascend.ops.multi_token_attention import LigerMultiTokenAttentionFunction
5959
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
6060
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward
6161
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward

src/liger_kernel/ops/backends/_ascend/ops/multi_token_attention.py

Lines changed: 111 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import torch
24
import torch.nn.functional as F
35
import triton
@@ -6,6 +8,7 @@
68
from torch.nn.modules.utils import _pair
79

810
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
11+
from liger_kernel.ops.backends._ascend.ub_manager import get_ub_manager
912
from liger_kernel.ops.utils import ensure_contiguous
1013
from liger_kernel.ops.utils import get_npu_core_count
1114

@@ -22,13 +25,6 @@ def _fused_mask_softmax_fwd_kernel(
2225
):
2326
"""
2427
Fused forward kernel: causal masking + softmax with grid-stride loop.
25-
Each program processes multiple rows for better resource utilization.
26-
27-
Optimizations:
28-
- Grid-stride loop to reduce kernel launch overhead
29-
- Fuses masking and softmax to reduce memory traffic
30-
- Online softmax algorithm for numerical stability
31-
- Only loads valid elements (causal mask)
3228
3329
Args:
3430
scores_ptr: Input scores tensor pointer
@@ -58,23 +54,23 @@ def _fused_mask_softmax_fwd_kernel(
5854
for block_start in range(0, valid_len, BLOCK_SIZE):
5955
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
6056
col_mask = col_idx < valid_len
61-
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"), eviction_policy="evict_first")
57+
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
6258
max_val = tl.maximum(max_val, tl.max(vals))
6359

6460
# Second pass: compute exp and sum
6561
sum_exp = 0.0
6662
for block_start in range(0, valid_len, BLOCK_SIZE):
6763
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
6864
col_mask = col_idx < valid_len
69-
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"), eviction_policy="evict_first")
65+
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
7066
exp_vals = tl.exp(vals - max_val)
7167
sum_exp += tl.sum(tl.where(col_mask, exp_vals, 0.0))
7268

7369
# Third pass: normalize and store
7470
for block_start in range(0, valid_len, BLOCK_SIZE):
7571
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
7672
col_mask = col_idx < valid_len
77-
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"), eviction_policy="evict_first")
73+
vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
7874
exp_vals = tl.exp(vals - max_val)
7975
probs = exp_vals / sum_exp
8076
tl.store(out_row_ptr + col_idx, probs, mask=col_mask)
@@ -133,8 +129,8 @@ def _fused_mask_softmax_bwd_kernel(
133129
for block_start in range(0, valid_len, BLOCK_SIZE):
134130
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
135131
col_mask = col_idx < valid_len
136-
grad_vals = tl.load(grad_row_ptr + col_idx, mask=col_mask, other=0.0, eviction_policy="evict_first")
137-
prob_vals = tl.load(probs_row_ptr + col_idx, mask=col_mask, other=0.0, eviction_policy="evict_first")
132+
grad_vals = tl.load(grad_row_ptr + col_idx, mask=col_mask, other=0.0)
133+
prob_vals = tl.load(probs_row_ptr + col_idx, mask=col_mask, other=0.0)
138134
dot += tl.sum(tl.where(col_mask, grad_vals * prob_vals, 0.0))
139135

140136
# Second pass: compute gradient
@@ -201,12 +197,8 @@ def _fused_mask_sparsemax_bwd_kernel(
201197
for block_start in range(0, valid_len, BLOCK_SIZE):
202198
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
203199
col_mask = col_idx < valid_len
204-
prob_vals = tl.load(probs_row_ptr + col_idx, mask=col_mask, other=0.0, eviction_policy="evict_first").to(
205-
tl.float32
206-
)
207-
grad_vals = tl.load(grad_row_ptr + col_idx, mask=col_mask, other=0.0, eviction_policy="evict_first").to(
208-
tl.float32
209-
)
200+
prob_vals = tl.load(probs_row_ptr + col_idx, mask=col_mask, other=0.0).to(tl.float32)
201+
grad_vals = tl.load(grad_row_ptr + col_idx, mask=col_mask, other=0.0).to(tl.float32)
210202
supp = prob_vals > 0.0
211203
go_sum += tl.sum(tl.where(supp & col_mask, grad_vals, 0.0))
212204
supp_cnt += tl.sum(tl.where(supp & col_mask, 1.0, 0.0))
@@ -229,110 +221,97 @@ def _fused_mask_sparsemax_bwd_kernel(
229221
tl.store(out_row_ptr + col_idx, 0.0, mask=col_mask)
230222

231223

232-
def mask_zero_rowwise(scores: torch.Tensor) -> torch.Tensor:
233-
"""
234-
Forward pass for causal masking with zero values.
235-
Uses 1D row-wise processing.
236-
237-
Args:
238-
scores: Input scores tensor of shape (*batch, L, L)
239-
240-
Returns:
241-
Masked scores tensor with future positions set to 0.0
242-
"""
243-
*batch, L, _ = scores.shape
244-
N = int(torch.prod(torch.tensor(batch))) if batch else 1
245-
scores_f = scores.view(N, L, L)
246-
out = torch.empty_like(scores_f)
247-
248-
BLOCK_SIZE = get_optimal_block_size(L, is_forward=True)
249-
num_cores = get_npu_core_count()
250-
grid_size = min(num_cores, N * L)
251-
252-
_mask_row_kernel[(grid_size,)](
253-
scores_f,
254-
out,
255-
scores_f.stride(0),
256-
scores_f.stride(1),
257-
N,
258-
L,
259-
mask_val=0.0,
260-
BLOCK_SIZE=BLOCK_SIZE,
261-
)
262-
263-
return out.view(*batch, L, L)
264-
265-
266224
@triton.jit
267225
def _mask_row_kernel(
268226
scores_ptr,
269227
out_ptr,
270228
stride_b,
271229
stride_row,
272-
N,
273230
L,
274231
mask_val: tl.constexpr,
275232
BLOCK_SIZE: tl.constexpr,
276233
):
277-
pid = tl.program_id(0)
278-
num_progs = tl.num_programs(0)
279-
n_rows = N * L
234+
"""
235+
2D-tiled causal masking kernel.
280236
281-
for linear_idx in tl.range(pid, n_rows, num_progs):
282-
batch_id = linear_idx // L
283-
row_idx = linear_idx % L
237+
Grid:
238+
axis 0 (pid_m): row-block index
239+
axis 1 (pid_n): col-block index
240+
axis 2 (pid_b): batch index
284241
285-
row_ptr = scores_ptr + batch_id * stride_b + row_idx * stride_row
286-
out_row_ptr = out_ptr + batch_id * stride_b + row_idx * stride_row
242+
Each program handles a BLOCK_SIZE x BLOCK_SIZE tile.
243+
"""
287244

288-
# columns handled in blocks
289-
for block_start in range(0, L, BLOCK_SIZE):
290-
col_idx = block_start + tl.arange(0, BLOCK_SIZE)
291-
col_mask = col_idx < L
245+
pid_m = tl.program_id(0) # row block
246+
pid_n = tl.program_id(1) # col block
247+
pid_b = tl.program_id(2) # batch
248+
249+
# Compute row/col indices
250+
row_start = pid_m * BLOCK_SIZE
251+
col_start = pid_n * BLOCK_SIZE
292252

293-
# causal condition
294-
keep = col_idx <= row_idx
253+
row_idx = row_start + tl.arange(0, BLOCK_SIZE)
254+
col_idx = col_start + tl.arange(0, BLOCK_SIZE)
295255

296-
vals = tl.load(
297-
row_ptr + col_idx,
298-
mask=col_mask,
299-
other=0.0,
300-
)
256+
row_mask = row_idx < L
257+
col_mask = col_idx < L
301258

302-
masked_vals = tl.where(keep, vals, mask_val)
259+
ptrs = scores_ptr + pid_b * stride_b + row_idx[:, None] * stride_row + col_idx[None, :]
260+
out_ptrs = out_ptr + pid_b * stride_b + row_idx[:, None] * stride_row + col_idx[None, :]
303261

304-
tl.store(
305-
out_row_ptr + col_idx,
306-
masked_vals,
307-
mask=col_mask,
308-
)
262+
vals = tl.load(ptrs, mask=row_mask[:, None] & col_mask[None, :], other=0.0)
263+
causal = col_idx[None, :] <= row_idx[:, None]
264+
masked_vals = tl.where(causal, vals, mask_val)
309265

266+
tl.store(out_ptrs, masked_vals, mask=row_mask[:, None] & col_mask[None, :])
310267

311-
def get_optimal_block_size(n_cols: int, is_forward: bool) -> int:
268+
269+
def get_2d_optimal_block_size(
270+
n: int,
271+
dtype_size: int,
272+
memory_multiplier: float,
273+
safety_margin: float = 0.9,
274+
min_block: int = 16,
275+
max_block: int = 512,
276+
):
312277
"""
313-
Compute optimal block size for mask-zero rowwise kernel.
278+
Compute optimal 2D block sizes (BLOCK_SIZE, BLOCK_SIZE) for a square matrix.
279+
280+
Args:
281+
n: matrix size (N x N)
282+
ub_capacity_bits: UB capacity in bits
283+
dtype_size: bytes per element
284+
memory_multiplier: estimated live buffer multiplier
285+
safety_margin: UB safety factor
286+
min_block: minimum block size
287+
max_block: maximum block size
288+
289+
Returns:
290+
BLOCK_SIZE
314291
"""
292+
ub_manager = get_ub_manager()
315293

316-
# For small sizes, just use next power of 2
317-
if n_cols <= 4096:
318-
return triton.next_power_of_2(n_cols)
294+
dtype_size = max(dtype_size, 4)
319295

320-
# Mask kernel is light → small multiplier
321-
memory_multiplier = 4.0 if is_forward else 6.0 # slightly conservative
296+
# Safe UB budget
297+
safe_ub = int(ub_manager.ub_capacity_bits * safety_margin)
322298

323-
tile_shapes = compute_default_tiling_strategy(
324-
safety_margin=0.9,
325-
dtype_size=4,
326-
memory_multiplier=memory_multiplier,
327-
shapes=((n_cols,),),
328-
tiling_dims=(0,),
329-
)
299+
# Max tile area
300+
max_area = safe_ub // (memory_multiplier * dtype_size * 8)
301+
max_area = max(1, max_area)
330302

331-
if tile_shapes and len(tile_shapes) > 0:
332-
block_size = tile_shapes[0][0]
333-
return max(4096, block_size)
303+
# Ideal square tile size
304+
block = int(math.sqrt(max_area))
334305

335-
return 4096
306+
# Clamp to problem size
307+
block = min(block, n, max_block)
308+
block = max(block, min_block)
309+
310+
block = triton.next_power_of_2(block)
311+
if block > n:
312+
block = n
313+
314+
return block
336315

337316

338317
def get_optimal_size_fused_mask_softmax(L: int, is_forward: bool = True, dtype_size: int = 2):
@@ -360,6 +339,38 @@ def get_optimal_size_fused_mask_softmax(L: int, is_forward: bool = True, dtype_s
360339
return 2048
361340

362341

342+
def mask_zero_rowwise(scores: torch.Tensor) -> torch.Tensor:
343+
"""
344+
Forward pass for causal masking with zero values.
345+
Uses 1D row-wise processing.
346+
347+
Args:
348+
scores: Input scores tensor of shape (*batch, L, L)
349+
350+
Returns:
351+
Masked scores tensor with future positions set to 0.0
352+
"""
353+
*batch, L, _ = scores.shape
354+
N = int(torch.prod(torch.tensor(batch))) if batch else 1
355+
scores_f = scores.view(N, L, L)
356+
out = torch.empty_like(scores_f)
357+
358+
BLOCK_SIZE = get_2d_optimal_block_size(L, dtype_size=scores_f.element_size(), memory_multiplier=12.0)
359+
360+
grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
361+
_mask_row_kernel[grid](
362+
scores_f,
363+
out,
364+
scores_f.stride(0),
365+
scores_f.stride(1),
366+
L,
367+
mask_val=0.0,
368+
BLOCK_SIZE=BLOCK_SIZE,
369+
)
370+
371+
return out.view(*batch, L, L)
372+
373+
363374
def fused_mask_softmax_forward(scores: torch.Tensor) -> torch.Tensor:
364375
"""
365376
Fused forward pass: causal masking + softmax.
@@ -438,7 +449,6 @@ def fused_mask_sparsemax_forward(scores: torch.Tensor) -> tuple[torch.Tensor, to
438449
"""
439450
Forward pass: causal masking + sparsemax using reference implementation.
440451
Uses one-axis grid (one program per row).
441-
Because of the complexity of sparsemax, we implement it in PyTorch and fuse only the masking.
442452
443453
Args:
444454
scores: Input scores tensor of shape (*batch, L, L)
@@ -453,16 +463,15 @@ def fused_mask_sparsemax_forward(scores: torch.Tensor) -> tuple[torch.Tensor, to
453463
scores_f = scores.view(N, L, L)
454464
scores_masked = torch.empty_like(scores_f)
455465

456-
BLOCK_SIZE = get_optimal_size_fused_mask_softmax(L, is_forward=True)
457-
num_cores = get_npu_core_count()
458-
grid_size = min(num_cores, N * L)
466+
# BLOCK_SIZE = get_optimal_size_fused_mask_softmax(L, is_forward=True)
467+
BLOCK_SIZE = get_2d_optimal_block_size(L, dtype_size=scores_f.element_size(), memory_multiplier=12.0)
459468

460-
_mask_row_kernel[(grid_size,)](
469+
grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
470+
_mask_row_kernel[grid](
461471
scores_f,
462472
scores_masked,
463473
scores_f.stride(0),
464474
scores_f.stride(1),
465-
N,
466475
L,
467476
mask_val=-1e9,
468477
BLOCK_SIZE=BLOCK_SIZE,
@@ -513,11 +522,6 @@ def fused_mask_sparsemax_backward(grad_out: torch.Tensor, probs: torch.Tensor) -
513522
class LigerMultiTokenAttentionFunction(torch.autograd.Function):
514523
"""
515524
NPU-optimized Multi-Token Attention using 1D row-wise processing.
516-
517-
This implementation is optimized for NPU hardware by:
518-
1. Using 1D row-wise kernels instead of 2D block-based kernels
519-
2. Larger block sizes for better memory throughput
520-
3. Reduced kernel launch overhead
521525
"""
522526

523527
@staticmethod

0 commit comments

Comments
 (0)