Skip to content

Commit 547cf4c

Browse files
authored
[NPU A3] Fix benchmark issues for fused_linear_jsd and dyt. (#1231)
## Summary

Fix benchmark issues for fused_linear_jsd and dyt.

1. dyt throws errors when using torch.compile on NPU. Add logic in the benchmark to disable the torch.compile baseline on NPU devices.
2. fused_linear_jsd hits a grid-size limit error on NPU when the grid exceeds 65536. The issue arises from using the row count (`n_rows`) as the grid size. Replace it with `min(num_cores, n_rows)` to fix the problem.

## Testing Done

dyt: <img width="1699" height="480" alt="image" src="https://github.com/user-attachments/assets/a0a44250-fc8d-45d5-9b5a-1c4529a1db2b" />

fused_linear_jsd: <img width="1676" height="499" alt="image" src="https://github.com/user-attachments/assets/c5a91b9f-5b74-4065-a6b7-74118820b43f" />

Atlas 800T-A3 x86

<!-- Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. -->

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
1 parent decb1b7 commit 547cf4c

2 files changed

Lines changed: 36 additions & 4 deletions

File tree

benchmark/scripts/benchmark_dyt.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
2121

2222

23+
def get_kernel_providers():
    """Return the kernel provider names to benchmark.

    "liger" and "torch" are always included. "torch_compile" is appended
    unless the module-level ``device`` equals "npu" (the commit summary
    notes that dyt fails under torch.compile on NPU).
    """
    base_providers = ["liger", "torch"]
    if device == "npu":
        return base_providers
    return [*base_providers, "torch_compile"]
28+
29+
2330
def setup_dyt(input: SingleBenchmarkRunInput):
2431
"""Create input tensor and DyT layer from benchmark config."""
2532
from test.transformers.test_dyt import LigerDyT
@@ -85,7 +92,7 @@ def setup_dyt(input: SingleBenchmarkRunInput):
8592
overwrite=args.overwrite,
8693
)
8794

88-
common_configs["kernel_providers"] = ["liger", "torch", "torch_compile"]
95+
common_configs["kernel_providers"] = get_kernel_providers()
8996
run_benchmarks(
9097
bench_test_fn=build_speed_bench_fn(setup_dyt),
9198
kernel_operation_modes=["forward", "backward", "full"],

src/liger_kernel/ops/backends/_ascend/ops/fused_linear_jsd.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,38 @@
22

33
import torch
44
import triton
5+
import triton.language as tl
56

67
from liger_kernel.ops.backends._ascend.ops.jsd import _jsd_kernel
78
from liger_kernel.ops.utils import amp_custom_bwd
89
from liger_kernel.ops.utils import amp_custom_fwd
9-
from liger_kernel.ops.utils import element_mul_kernel
1010
from liger_kernel.ops.utils import get_npu_core_count
1111

1212
MAX_FUSED_SIZE = 4096
1313

1414

15+
@triton.jit
def _element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_rows: tl.constexpr,
    n_cols: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Multiply every element of X, in place, by the scalar loaded from grad_output_ptr.

    Each program starts at row ``program_id(0)`` and advances by
    ``num_programs(0)`` rows, so the launch grid may be smaller than n_rows.

    Parameters:
    X_ptr: Pointer to the first element of the tensor; its contents are overwritten.
    X_stride: Offset, in elements, between the starts of consecutive rows.
    grad_output_ptr: Pointer to the single value every element is multiplied by.
    n_rows (int): Number of rows to process.
    n_cols (int): Number of elements in each row.
    BLOCK_SIZE (int): Number of elements handled per load/store.
    """
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)
    grad_output = tl.load(grad_output_ptr)

    for row_idx in range(pid, n_rows, num_progs):
        # Widen the row index to int64 before multiplying: for large tensors
        # row_idx * X_stride can exceed the 32-bit range and wrap around.
        row_ptr = X_ptr + row_idx.to(tl.int64) * X_stride
        for col_start in range(0, n_cols, BLOCK_SIZE):
            offsets = col_start + tl.arange(0, BLOCK_SIZE)
            # Mask off the tail when n_cols is not a multiple of BLOCK_SIZE.
            mask = offsets < n_cols
            values = tl.load(row_ptr + offsets, mask=mask)
            tl.store(row_ptr + offsets, values * grad_output, mask=mask)
35+
36+
1537
def fused_linear_jsd_forward(
1638
student_input,
1739
student_weight,
@@ -131,11 +153,13 @@ def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
131153
BT, H = grad_input.shape
132154
n_rows = BT
133155
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
156+
num_cores = get_npu_core_count()
134157

135-
element_mul_kernel[(n_rows,)](
158+
_element_mul_kernel[(min(num_cores, n_rows),)](
136159
grad_input,
137160
grad_input.stride(-2),
138161
grad_output,
162+
n_rows,
139163
H,
140164
BLOCK_SIZE=BLOCK_SIZE,
141165
)
@@ -145,10 +169,11 @@ def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
145169
V, H = grad_weight.shape
146170
n_rows = V
147171

148-
element_mul_kernel[(n_rows,)](
172+
_element_mul_kernel[(min(num_cores, n_rows),)](
149173
grad_weight,
150174
grad_weight.stride(-2),
151175
grad_output,
176+
n_rows,
152177
H,
153178
BLOCK_SIZE=BLOCK_SIZE,
154179
)

0 commit comments

Comments
 (0)