Commit 4cde098

optimize rowwise
1 parent 87928a4 commit 4cde098

2 files changed: 118 additions & 95 deletions


src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -52,10 +52,10 @@
 from liger_kernel.ops.backends._ascend.ops.llama4_rope import LigerLlama4RopeFunction
 from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_backward
 from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_forward
+from liger_kernel.ops.backends._ascend.ops.multi_token_attention import LigerMultiTokenAttentionFunction
 from liger_kernel.ops.backends._ascend.ops.poly_norm import LigerPolyNormFunction
 from liger_kernel.ops.backends._ascend.ops.poly_norm import poly_norm_backward
 from liger_kernel.ops.backends._ascend.ops.poly_norm import poly_norm_forward
-from liger_kernel.ops.backends._ascend.ops.multi_token_attention import LigerMultiTokenAttentionFunction
 from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
 from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward
 from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward

src/liger_kernel/ops/backends/_ascend/ops/multi_token_attention.py

Lines changed: 117 additions & 94 deletions
@@ -1,3 +1,5 @@
+import math
+
 import torch
 import torch.nn.functional as F
 import triton
@@ -6,6 +8,7 @@
 from torch.nn.modules.utils import _pair

 from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+from liger_kernel.ops.backends._ascend.ub_manager import get_ub_manager
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import get_npu_core_count

@@ -58,23 +61,23 @@ def _fused_mask_softmax_fwd_kernel(
     for block_start in range(0, valid_len, BLOCK_SIZE):
         col_idx = block_start + tl.arange(0, BLOCK_SIZE)
         col_mask = col_idx < valid_len
-        vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"), eviction_policy="evict_first")
+        vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
         max_val = tl.maximum(max_val, tl.max(vals))

     # Second pass: compute exp and sum
     sum_exp = 0.0
     for block_start in range(0, valid_len, BLOCK_SIZE):
         col_idx = block_start + tl.arange(0, BLOCK_SIZE)
         col_mask = col_idx < valid_len
-        vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"), eviction_policy="evict_first")
+        vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
         exp_vals = tl.exp(vals - max_val)
         sum_exp += tl.sum(tl.where(col_mask, exp_vals, 0.0))

     # Third pass: normalize and store
     for block_start in range(0, valid_len, BLOCK_SIZE):
         col_idx = block_start + tl.arange(0, BLOCK_SIZE)
         col_mask = col_idx < valid_len
-        vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"), eviction_policy="evict_first")
+        vals = tl.load(row_ptr + col_idx, mask=col_mask, other=float("-inf"))
         exp_vals = tl.exp(vals - max_val)
         probs = exp_vals / sum_exp
         tl.store(out_row_ptr + col_idx, probs, mask=col_mask)
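Note: the forward kernel keeps the standard three-pass streaming softmax (row max, then exp-sum, then normalize); dropping the `eviction_policy="evict_first"` hint changes cache behavior only, not the math. A minimal plain-PyTorch sketch of the same three-pass scheme, for illustration only (`three_pass_softmax` is a hypothetical helper, not code from this commit):

import torch

def three_pass_softmax(row: torch.Tensor, block: int = 1024) -> torch.Tensor:
    # Pass 1: running maximum over fixed-size blocks (numerical stability).
    max_val = torch.tensor(float("-inf"))
    for s in range(0, row.numel(), block):
        max_val = torch.maximum(max_val, row[s : s + block].max())
    # Pass 2: accumulate the sum of exp(x - max) block by block.
    sum_exp = torch.tensor(0.0)
    for s in range(0, row.numel(), block):
        sum_exp = sum_exp + torch.exp(row[s : s + block] - max_val).sum()
    # Pass 3: normalize each block and write the probabilities out.
    out = torch.empty_like(row)
    for s in range(0, row.numel(), block):
        out[s : s + block] = torch.exp(row[s : s + block] - max_val) / sum_exp
    return out

row = torch.randn(5000)
assert torch.allclose(three_pass_softmax(row), torch.softmax(row, dim=0), atol=1e-5)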
@@ -133,8 +136,8 @@ def _fused_mask_softmax_bwd_kernel(
     for block_start in range(0, valid_len, BLOCK_SIZE):
         col_idx = block_start + tl.arange(0, BLOCK_SIZE)
         col_mask = col_idx < valid_len
-        grad_vals = tl.load(grad_row_ptr + col_idx, mask=col_mask, other=0.0, eviction_policy="evict_first")
-        prob_vals = tl.load(probs_row_ptr + col_idx, mask=col_mask, other=0.0, eviction_policy="evict_first")
+        grad_vals = tl.load(grad_row_ptr + col_idx, mask=col_mask, other=0.0)
+        prob_vals = tl.load(probs_row_ptr + col_idx, mask=col_mask, other=0.0)
         dot += tl.sum(tl.where(col_mask, grad_vals * prob_vals, 0.0))

     # Second pass: compute gradient
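Note: the two passes here implement the usual softmax Jacobian-vector product: first accumulate dot = sum(g * p) over the row, then emit dx = p * (g - dot). A sketch checking that identity against autograd (illustrative only, not part of the commit):

import torch

x = torch.randn(8, 128, requires_grad=True)
p = torch.softmax(x, dim=-1)
g = torch.randn_like(p)
p.backward(g)

dot = (g * p).sum(dim=-1, keepdim=True)  # the kernel's first-pass accumulator
dx = p * (g - dot)                       # the kernel's second-pass output
assert torch.allclose(x.grad, dx, atol=1e-5)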
@@ -201,12 +204,8 @@ def _fused_mask_sparsemax_bwd_kernel(
     for block_start in range(0, valid_len, BLOCK_SIZE):
         col_idx = block_start + tl.arange(0, BLOCK_SIZE)
         col_mask = col_idx < valid_len
-        prob_vals = tl.load(probs_row_ptr + col_idx, mask=col_mask, other=0.0, eviction_policy="evict_first").to(
-            tl.float32
-        )
-        grad_vals = tl.load(grad_row_ptr + col_idx, mask=col_mask, other=0.0, eviction_policy="evict_first").to(
-            tl.float32
-        )
+        prob_vals = tl.load(probs_row_ptr + col_idx, mask=col_mask, other=0.0).to(tl.float32)
+        grad_vals = tl.load(grad_row_ptr + col_idx, mask=col_mask, other=0.0).to(tl.float32)
         supp = prob_vals > 0.0
         go_sum += tl.sum(tl.where(supp & col_mask, grad_vals, 0.0))
         supp_cnt += tl.sum(tl.where(supp & col_mask, 1.0, 0.0))
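Note: the accumulators go_sum and supp_cnt realize the sparsemax Jacobian: over the support S = {i : p_i > 0}, dx_i = g_i - (sum over S of g_j) / |S|, and dx_i = 0 off the support. A one-row sketch with made-up numbers (illustrative only):

import torch

p = torch.tensor([0.0, 0.3, 0.7, 0.0])    # sparsemax output (example values)
g = torch.tensor([0.5, -1.0, 2.0, 0.25])  # upstream gradient dL/dp
supp = p > 0.0
go_sum = g[supp].sum()                    # matches the kernel's go_sum
supp_cnt = supp.sum()                     # matches the kernel's supp_cnt
dx = torch.where(supp, g - go_sum / supp_cnt, torch.zeros_like(g))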
@@ -229,110 +228,103 @@ def _fused_mask_sparsemax_bwd_kernel(
         tl.store(out_row_ptr + col_idx, 0.0, mask=col_mask)


-def mask_zero_rowwise(scores: torch.Tensor) -> torch.Tensor:
-    """
-    Forward pass for causal masking with zero values.
-    Uses 1D row-wise processing.
-
-    Args:
-        scores: Input scores tensor of shape (*batch, L, L)
-
-    Returns:
-        Masked scores tensor with future positions set to 0.0
-    """
-    *batch, L, _ = scores.shape
-    N = int(torch.prod(torch.tensor(batch))) if batch else 1
-    scores_f = scores.view(N, L, L)
-    out = torch.empty_like(scores_f)
-
-    BLOCK_SIZE = get_optimal_block_size(L, is_forward=True)
-    num_cores = get_npu_core_count()
-    grid_size = min(num_cores, N * L)
-
-    _mask_row_kernel[(grid_size,)](
-        scores_f,
-        out,
-        scores_f.stride(0),
-        scores_f.stride(1),
-        N,
-        L,
-        mask_val=0.0,
-        BLOCK_SIZE=BLOCK_SIZE,
-    )
-
-    return out.view(*batch, L, L)
-
-
 @triton.jit
 def _mask_row_kernel(
     scores_ptr,
     out_ptr,
     stride_b,
     stride_row,
-    N,
     L,
     mask_val: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(0)
-    num_progs = tl.num_programs(0)
-    n_rows = N * L
+    """
+    2D-tiled causal masking kernel.

-    for linear_idx in tl.range(pid, n_rows, num_progs):
-        batch_id = linear_idx // L
-        row_idx = linear_idx % L
+    Grid:
+        axis 0 (pid_m): row-block index
+        axis 1 (pid_n): col-block index
+        axis 2 (pid_b): batch index

-        row_ptr = scores_ptr + batch_id * stride_b + row_idx * stride_row
-        out_row_ptr = out_ptr + batch_id * stride_b + row_idx * stride_row
+    Each program handles a BLOCK_SIZE x BLOCK_SIZE tile.
+    """

-        # columns handled in blocks
-        for block_start in range(0, L, BLOCK_SIZE):
-            col_idx = block_start + tl.arange(0, BLOCK_SIZE)
-            col_mask = col_idx < L
+    pid_m = tl.program_id(0)  # row block
+    pid_n = tl.program_id(1)  # col block
+    pid_b = tl.program_id(2)  # batch
+
+    # Compute row/col indices
+    row_start = pid_m * BLOCK_SIZE
+    col_start = pid_n * BLOCK_SIZE

-            # causal condition
-            keep = col_idx <= row_idx
+    row_idx = row_start + tl.arange(0, BLOCK_SIZE)
+    col_idx = col_start + tl.arange(0, BLOCK_SIZE)

-            vals = tl.load(
-                row_ptr + col_idx,
-                mask=col_mask,
-                other=0.0,
-            )
+    row_mask = row_idx < L
+    col_mask = col_idx < L

-            masked_vals = tl.where(keep, vals, mask_val)
+    ptrs = scores_ptr + pid_b * stride_b + row_idx[:, None] * stride_row + col_idx[None, :]
+    out_ptrs = out_ptr + pid_b * stride_b + row_idx[:, None] * stride_row + col_idx[None, :]

-            tl.store(
-                out_row_ptr + col_idx,
-                masked_vals,
-                mask=col_mask,
-            )
+    vals = tl.load(ptrs, mask=row_mask[:, None] & col_mask[None, :], other=0.0)
+    causal = col_idx[None, :] <= row_idx[:, None]
+    masked_vals = tl.where(causal, vals, mask_val)

+    tl.store(out_ptrs, masked_vals, mask=row_mask[:, None] & col_mask[None, :])

-def get_optimal_block_size(n_cols: int, is_forward: bool) -> int:
+
+def get_2d_optimal_block_size(
+    n: int,
+    dtype_size: int,
+    memory_multiplier: float,
+    safety_margin: float = 0.9,
+    min_block: int = 16,
+    max_block: int = 512,
+):
     """
-    Compute optimal block size for mask-zero rowwise kernel.
+    Compute the optimal square tile size (BLOCK_SIZE x BLOCK_SIZE) for a square matrix.
+
+    Strategy:
+        - Compute max tile area allowed by UB
+        - Use square tiling for best balance
+        - Round to power-of-two for Triton efficiency
+        - Clamp to problem size
+
+    Args:
+        n: matrix size (N x N)
+        dtype_size: bytes per element
+        memory_multiplier: estimated live buffer multiplier
+        safety_margin: UB safety factor
+        min_block: minimum block size
+        max_block: maximum block size
+
+    Returns:
+        BLOCK_SIZE
     """
+    ub_manager = get_ub_manager()

-    # For small sizes, just use next power of 2
-    if n_cols <= 4096:
-        return triton.next_power_of_2(n_cols)
+    dtype_size = max(dtype_size, 4)

-    # Mask kernel is light → small multiplier
-    memory_multiplier = 4.0 if is_forward else 6.0  # slightly conservative
+    # Safe UB budget
+    safe_ub = int(ub_manager.ub_capacity_bits * safety_margin)

-    tile_shapes = compute_default_tiling_strategy(
-        safety_margin=0.9,
-        dtype_size=4,
-        memory_multiplier=memory_multiplier,
-        shapes=((n_cols,),),
-        tiling_dims=(0,),
-    )
+    # Max tile area
+    max_area = safe_ub // (memory_multiplier * dtype_size * 8)
+    max_area = max(1, max_area)

-    if tile_shapes and len(tile_shapes) > 0:
-        block_size = tile_shapes[0][0]
-        return max(4096, block_size)
+    # Ideal square tile size
+    block = int(math.sqrt(max_area))
+
+    # Clamp to problem size
+    block = min(block, n, max_block)
+    block = max(block, min_block)
+
+    block = triton.next_power_of_2(block)
+    if block > n:
+        block = n

-    return 4096
+    return block


 def get_optimal_size_fused_mask_softmax(L: int, is_forward: bool = True, dtype_size: int = 2):
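Note: per batch slice, the rewritten _mask_row_kernel now computes, one BLOCK_SIZE x BLOCK_SIZE tile per program, exactly what torch.where over a lower-triangular mask computes. A reference one could use in a test (a hypothetical helper, not part of the commit):

import torch

def causal_mask_reference(scores: torch.Tensor, mask_val: float) -> torch.Tensor:
    # Keep entries on or below the diagonal; set strictly-upper entries to mask_val.
    L = scores.shape[-1]
    keep = torch.tril(torch.ones(L, L, dtype=torch.bool, device=scores.device))
    return torch.where(keep, scores, torch.full_like(scores, mask_val))

scores = torch.randn(2, 3, 16, 16)             # (*batch, L, L)
expected = causal_mask_reference(scores, 0.0)  # what mask_zero_rowwise should return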
@@ -360,6 +352,38 @@ def get_optimal_size_fused_mask_softmax(L: int, is_forward: bool = True, dtype_s
     return 2048


+def mask_zero_rowwise(scores: torch.Tensor) -> torch.Tensor:
+    """
+    Forward pass for causal masking with zero values.
+    Uses a 2D tiled kernel.
+
+    Args:
+        scores: Input scores tensor of shape (*batch, L, L)
+
+    Returns:
+        Masked scores tensor with future positions set to 0.0
+    """
+    *batch, L, _ = scores.shape
+    N = int(torch.prod(torch.tensor(batch))) if batch else 1
+    scores_f = scores.view(N, L, L)
+    out = torch.empty_like(scores_f)
+
+    BLOCK_SIZE = get_2d_optimal_block_size(L, dtype_size=scores_f.element_size(), memory_multiplier=12.0)
+
+    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+    _mask_row_kernel[grid](
+        scores_f,
+        out,
+        scores_f.stride(0),
+        scores_f.stride(1),
+        L,
+        mask_val=0.0,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+    return out.view(*batch, L, L)
+
+
 def fused_mask_softmax_forward(scores: torch.Tensor) -> torch.Tensor:
     """
     Fused forward pass: causal masking + softmax.
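Note: to make the tile-size heuristic concrete, here is a worked example with an assumed UB capacity of 192 KiB (the real value comes from get_ub_manager(); these numbers are illustrative only):

# Assumed: ub_capacity_bits = 192 KiB * 8 = 1_572_864 bits
safe_ub = int(1_572_864 * 0.9)         # 1_415_577 bits of budget
max_area = safe_ub // (12.0 * 4 * 8)   # fp32, memory_multiplier=12.0 -> 3686 elements
block = int(max_area ** 0.5)           # 60
# Clamping leaves 60 for any L >= 60; triton.next_power_of_2(60) rounds up to 64.
# For L = 1024, N = 8: BLOCK_SIZE = 64 and grid = (16, 16, 8), i.e. 2048 programs,
# versus the old persistent 1D launch of min(num_cores, N * L) programs.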
@@ -453,16 +477,15 @@ def fused_mask_sparsemax_forward(scores: torch.Tensor) -> tuple[torch.Tensor, to
     scores_f = scores.view(N, L, L)
     scores_masked = torch.empty_like(scores_f)

-    BLOCK_SIZE = get_optimal_size_fused_mask_softmax(L, is_forward=True)
-    num_cores = get_npu_core_count()
-    grid_size = min(num_cores, N * L)
+    # BLOCK_SIZE = get_optimal_size_fused_mask_softmax(L, is_forward=True)
+    BLOCK_SIZE = get_2d_optimal_block_size(L, dtype_size=scores_f.element_size(), memory_multiplier=12.0)

-    _mask_row_kernel[(grid_size,)](
+    grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+    _mask_row_kernel[grid](
         scores_f,
         scores_masked,
         scores_f.stride(0),
         scores_f.stride(1),
-        N,
         L,
         mask_val=-1e9,
         BLOCK_SIZE=BLOCK_SIZE,
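Note: for background, sparsemax (Martins & Astudillo, 2016) is the Euclidean projection of a score vector onto the probability simplex; it produces the sparse probabilities whose support drives the backward kernel above. A single-row reference sketch of the definition (illustrative only, not the fused kernel's code):

import torch

def sparsemax_1d(z: torch.Tensor) -> torch.Tensor:
    # Sort scores, find the support size k(z), derive the threshold tau,
    # then clip: p = max(z - tau, 0).
    zs, _ = torch.sort(z, descending=True)
    cssv = zs.cumsum(0) - 1.0
    k = torch.arange(1, z.numel() + 1, dtype=z.dtype)
    support = zs * k > cssv
    k_z = int(support.sum())
    tau = cssv[k_z - 1] / k_z
    return torch.clamp(z - tau, min=0.0)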
