From 21a6747c3d456d216b892a3488d385c88208d7a4 Mon Sep 17 00:00:00 2001 From: MYH <84758754+MAYUNHUI666@users.noreply.github.com> Date: Tue, 27 Jan 2026 09:19:27 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E3=80=90NPU=E3=80=91fixed=20oom=20error=20?= =?UTF-8?q?for=20benchmark=5Ffused=5Flinear=5Fjsd.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- benchmark/scripts/benchmark_fused_linear_jsd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py index ac62863b2..261fc99ad 100644 --- a/benchmark/scripts/benchmark_fused_linear_jsd.py +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -240,7 +240,7 @@ def full(): "x_label": "B x T", "x_values": [2**i for i in range(10, 14)], "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}], + "extra_benchmark_configs": [{"H": 4096, "V": 65536, "mode": "forward", "dtype": torch.bfloat16}], "overwrite": args.overwrite, } From 1e8e5f87ebd72c6d39c69bbc2a981fe2016ed830 Mon Sep 17 00:00:00 2001 From: MYH <84758754+MAYUNHUI666@users.noreply.github.com> Date: Tue, 27 Jan 2026 09:21:06 +0800 Subject: [PATCH 2/3] [NPU] FIX fused_linear_jsd ub overflow on NPU --- src/liger_kernel/ops/fused_linear_jsd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/ops/fused_linear_jsd.py b/src/liger_kernel/ops/fused_linear_jsd.py index e31b10769..238623a3c 100644 --- a/src/liger_kernel/ops/fused_linear_jsd.py +++ b/src/liger_kernel/ops/fused_linear_jsd.py @@ -13,7 +13,7 @@ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling # The optimal maximum block size 
depends on your hardware, your kernel, and your dtype -MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 +MAX_FUSED_SIZE = 4096 if infer_device() in ("xpu", "npu") else 65536 // 2 def fused_linear_jsd_forward( From 74b77d1538c6d7c4f4ad73d8a5d0739c42a7aba7 Mon Sep 17 00:00:00 2001 From: MYH <84758754+MAYUNHUI666@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:29:59 +0800 Subject: [PATCH 3/3] Update benchmark_fused_linear_jsd.py --- benchmark/scripts/benchmark_fused_linear_jsd.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py index 261fc99ad..e4cf3113b 100644 --- a/benchmark/scripts/benchmark_fused_linear_jsd.py +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -10,6 +10,7 @@ from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD from liger_kernel.utils import infer_device +from liger_kernel.utils import get_total_gpu_memory device = infer_device() @@ -233,14 +234,19 @@ def full(): if __name__ == "__main__": args = parse_benchmark_script_args() - + gpu_memory_gbs = get_total_gpu_memory() + if gpu_memory_gbs >= 69: + vocab_size = 128256 + else: + vocab_size = 65536 + common_configs = { "kernel_name": "fused_linear_jsd", "x_name": "BT", "x_label": "B x T", "x_values": [2**i for i in range(10, 14)], "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [{"H": 4096, "V": 65536, "mode": "forward", "dtype": torch.bfloat16}], + "extra_benchmark_configs": [{"H": 4096, "V": vocab_size, "mode": "forward", "dtype": torch.bfloat16}], "overwrite": args.overwrite, }