diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py
index ac62863b2..e4cf3113b 100644
--- a/benchmark/scripts/benchmark_fused_linear_jsd.py
+++ b/benchmark/scripts/benchmark_fused_linear_jsd.py
@@ -10,6 +10,7 @@
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import get_total_gpu_memory
 
 device = infer_device()
 
@@ -233,14 +234,19 @@ def full():
 
 
 if __name__ == "__main__":
     args = parse_benchmark_script_args()
-
+    gpu_memory_gbs = get_total_gpu_memory()
+    if gpu_memory_gbs >= 69:
+        vocab_size = 128256
+    else:
+        vocab_size = 65536
+
     common_configs = {
         "kernel_name": "fused_linear_jsd",
         "x_name": "BT",
         "x_label": "B x T",
         "x_values": [2**i for i in range(10, 14)],
         "kernel_providers": ["liger", "torch"],
-        "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}],
+        "extra_benchmark_configs": [{"H": 4096, "V": vocab_size, "mode": "forward", "dtype": torch.bfloat16}],
         "overwrite": args.overwrite,
     }
diff --git a/src/liger_kernel/ops/fused_linear_jsd.py b/src/liger_kernel/ops/fused_linear_jsd.py
index e31b10769..238623a3c 100644
--- a/src/liger_kernel/ops/fused_linear_jsd.py
+++ b/src/liger_kernel/ops/fused_linear_jsd.py
@@ -13,7 +13,7 @@
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
+MAX_FUSED_SIZE = 4096 if infer_device() == "npu" else 65536 // 2
 
 
 def fused_linear_jsd_forward(
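
Note on the new helper: the patch imports get_total_gpu_memory from liger_kernel.utils but does not show its implementation. The comparison gpu_memory_gbs >= 69 implies the helper returns total device memory in gigabytes, so that an 80 GB card passes the check and keeps the full 128256 vocab size, while smaller devices fall back to 65536 to keep the BT x V logits tensor within memory. The sketch below is an assumption of how such a helper could look; only the name and import path come from the diff, the body is illustrative.

# Hypothetical sketch of get_total_gpu_memory (not shown in this diff).
# Assumption: it reports the current device's total memory in GB.
import torch


def get_total_gpu_memory() -> float:
    """Return total memory of the current CUDA device in GB (assumed behavior)."""
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        return props.total_memory / (1024**3)
    return 0.0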
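
On the MAX_FUSED_SIZE change: the diff swaps the special-cased device string from "xpu" to "npu" but does not show where the constant is consumed. As a point of reference only, the usual pattern for such a cap in Triton kernels is to clamp the per-row block size to a power of two no larger than the cap, which is what the register-spilling comment in the hunk alludes to. The snippet below illustrates that pattern under the assumption that fused_linear_jsd.py uses it the same way; the exact call site is not part of this diff.

# Illustration only: typical use of a MAX_FUSED_SIZE cap when picking a
# Triton block size for a row-wise reduction over the vocab dimension.
import triton

from liger_kernel.ops.fused_linear_jsd import MAX_FUSED_SIZE

V = 128256  # vocab size of the logits row being reduced
# Round V up to a power of two, but never past the device-dependent cap,
# to limit register pressure as described in the comment above.
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))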