diff --git a/benchmark/benchmarks_visualizer.py b/benchmark/benchmarks_visualizer.py
index 613587a9a..0ede56406 100644
--- a/benchmark/benchmarks_visualizer.py
+++ b/benchmark/benchmarks_visualizer.py
@@ -319,50 +319,97 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
     plt.figure(figsize=(10, 6))
     sns.set(style="whitegrid")
-    try:
-        ax = sns.lineplot(
-            data=df,
-            x="x_value",
-            y="y_value_50",
-            hue="kernel_provider",
-            marker="o",
-            palette="tab10",
-            errorbar=("ci", None),
-        )
-    except Exception:
-        ax = sns.lineplot(
+
+    use_bar_chart = config.sweep_mode == "model_config"
+
+    if use_bar_chart:
+        # Grouped bar chart for model_config sweep
+        ax = sns.barplot(
             data=df,
             x="x_value",
             y="y_value_50",
             hue="kernel_provider",
-            marker="o",
             palette="tab10",
-            errorbar=None,
+            edgecolor="black",
+            linewidth=0.5,
         )
-    # For numeric x axes, show tick labels only at actual data points
-    if is_numeric_x:
-        tick_values = sorted(df["x_value"].unique())
-        ax.set_xticks(tick_values)
-        ax.set_xticklabels([str(int(v)) if v == int(v) else str(v) for v in tick_values])
-
-    # Seaborn can't plot pre-computed error bars, so we need to do it manually
-    lines = ax.get_lines()
-    colors = [line.get_color() for line in lines]
-
-    for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
-        y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
-        y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
-        y_error = [y_error_lower, y_error_upper]
-
-        plt.errorbar(
-            group_data["x_value"],
-            group_data["y_value_50"],
-            yerr=y_error,
-            fmt="o",
-            color=color,
-            capsize=5,
-        )
+        # Add error bars on each bar using pre-computed percentiles.
+        # Use order of appearance so the provider order matches seaborn's hue order.
+        providers = df["kernel_provider"].unique()
+        x_values = df["x_value"].unique()
+        n_providers = len(providers)
+        bar_width = 0.8 / n_providers  # seaborn default total width is 0.8
+
+        for i, provider in enumerate(providers):
+            group_data = df[df["kernel_provider"] == provider]
+            for j, x_val in enumerate(x_values):
+                row = group_data[group_data["x_value"] == x_val]
+                if row.empty:
+                    continue
+                y_val = row["y_value_50"].values[0]
+                y_err_lower = y_val - row["y_value_20"].values[0]
+                y_err_upper = row["y_value_80"].values[0] - y_val
+                bar_x = j + (i - (n_providers - 1) / 2) * bar_width
+                ax.errorbar(
+                    bar_x,
+                    y_val,
+                    yerr=[[y_err_lower], [y_err_upper]],
+                    fmt="none",
+                    color="black",
+                    capsize=3,
+                    linewidth=1,
+                )
+
+        # Rotate x labels if they are long model config names
+        if not is_numeric_x:
+            plt.xticks(rotation=30, ha="right")
+    else:
+        # Line chart for token_length sweep
+        try:
+            ax = sns.lineplot(
+                data=df,
+                x="x_value",
+                y="y_value_50",
+                hue="kernel_provider",
+                marker="o",
+                palette="tab10",
+                errorbar=("ci", None),
+            )
+        except Exception:
+            ax = sns.lineplot(
+                data=df,
+                x="x_value",
+                y="y_value_50",
+                hue="kernel_provider",
+                marker="o",
+                palette="tab10",
+                errorbar=None,
+            )
+
+        # For numeric x axes, show tick labels only at actual data points
+        if is_numeric_x:
+            tick_values = sorted(df["x_value"].unique())
+            ax.set_xticks(tick_values)
+            ax.set_xticklabels([str(int(v)) if v == int(v) else str(v) for v in tick_values])
+
+        # Seaborn can't plot pre-computed error bars, so we need to do it manually
+        lines = ax.get_lines()
+        colors = [line.get_color() for line in lines]
+
+        for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
+            y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
+            y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
+            y_error = [y_error_lower, y_error_upper]
+
+            plt.errorbar(
+                group_data["x_value"],
+                group_data["y_value_50"],
+                yerr=y_error,
+                fmt="o",
+                color=color,
+                capsize=5,
+            )
+
     plt.legend(title="Kernel Provider")
     plt.xlabel(xlabel)
     plt.ylabel(ylabel)
diff --git a/benchmark/scripts/benchmark_attn_res.py b/benchmark/scripts/benchmark_attn_res.py
index cf52361db..710555a8f 100644
--- a/benchmark/scripts/benchmark_attn_res.py
+++ b/benchmark/scripts/benchmark_attn_res.py
@@ -12,6 +12,8 @@
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 
+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
 from benchmark_model_configs import compute_seq_len_sweep_config
 from benchmark_model_configs import estimate_kernel_peak_memory
 from benchmark_model_configs import get_benchmark_model_config
@@ -69,61 +71,154 @@ def bench_memory_attn_res(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
     return run_memory_benchmark(fn, input.kernel_operation_mode)
 
 
-if __name__ == "__main__":
-    args = parse_benchmark_script_args()
-
-    model = get_benchmark_model_config(args.model)
-    probe_seq_len = 1024
-
-    def _probe():
-        probe_input = SingleBenchmarkRunInput(
-            x=probe_seq_len,
-            kernel_provider="pytorch",
+def _resolve_model_config_attn_res(input: SingleBenchmarkRunInput):
+    """Resolve model-config-sweep input into standard setup args."""
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_attn_res(
+        SingleBenchmarkRunInput(
+            x=cfg["seq_len"],
+            kernel_provider=input.kernel_provider,
             extra_benchmark_config={
-                "N": 8,
-                "bsz": 1,
-                "hidden_size": model.hidden_size,
-                "dtype": model.dtype,
-                "eps": 1e-6,
+                "N": cfg["N"],
+                "bsz": cfg["bsz"],
+                "hidden_size": model_info["hidden_size"],
+                "dtype": model_info["dtype"],
+                "eps": cfg.get("eps", 1e-6),
             },
         )
-        V, fn = _setup_attn_res(probe_input)
-        return fn()
-
-    peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
-    kernel_bpt = peak_bytes // probe_seq_len
-
-    config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
-
-    common_configs = {
-        "kernel_name": "attn_res",
-        "x_name": "T",
-        "x_label": "sequence length",
-        "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
-        "kernel_providers": ["liger", "pytorch"],
-        "extra_benchmark_configs": [
-            {
-                "N": 8,
-                "bsz": config.batch_size,
-                "hidden_size": model.hidden_size,
-                "dtype": model.dtype,
-                "eps": 1e-6,
-            }
-        ],
-        "overwrite": args.overwrite,
-    }
-
-    run_benchmarks(
-        bench_test_fn=bench_speed_attn_res,
-        kernel_operation_modes=["full", "forward", "backward"],
-        metric_name="speed",
-        metric_unit="ms",
-        **common_configs,
-    )
-    run_benchmarks(
-        bench_test_fn=bench_memory_attn_res,
-        kernel_operation_modes=["full", "forward", "backward"],
-        metric_name="memory",
-        metric_unit="MB",
-        **common_configs,
     )
+
+
+def bench_speed_attn_res_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    V, fn = _resolve_model_config_attn_res(input)
+    return run_speed_benchmark(fn, input.kernel_operation_mode, [V])
+
+
+def bench_memory_attn_res_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    V, fn = _resolve_model_config_attn_res(input)
+    return run_memory_benchmark(fn, input.kernel_operation_mode)
+
+
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()
+
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+
+        def _probe_factory(model_cfg, probe_seq_len):
+            def _probe():
+                probe_input = SingleBenchmarkRunInput(
+                    x=probe_seq_len,
+                    kernel_provider="pytorch",
+                    extra_benchmark_config={
+                        "N": 8,
+                        "bsz": 1,
+                        "hidden_size": model_cfg.hidden_size,
+                        "dtype": model_cfg.dtype,
+                        "eps": 1e-6,
+                    },
+                )
+                V, fn = _setup_attn_res(probe_input)
+                return fn()
+
+            return _probe
+
+        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+
+        model_configs_info = {
+            cfg.name: {
+                "hidden_size": cfg.hidden_size,
+                "dtype": cfg.dtype,
+            }
+            for cfg in sweep.model_configs
+        }
+
+        common_configs = {
+            "kernel_name": "attn_res",
+            "x_name": "model_config",
+            "x_label": "model configuration",
+            "x_values": [cfg.name for cfg in sweep.model_configs],
+            "kernel_providers": ["liger", "pytorch"],
+            "extra_benchmark_configs": [
+                {
+                    "model_configs": model_configs_info,
+                    "N": 8,
+                    "bsz": sweep.batch_size,
+                    "seq_len": sweep.seq_len,
+                    "eps": 1e-6,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_attn_res_model_config,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_attn_res_model_config,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
+    else:
+        model = get_benchmark_model_config(args.model)
+        probe_seq_len = 1024
+
+        def _probe():
+            probe_input = SingleBenchmarkRunInput(
+                x=probe_seq_len,
+                kernel_provider="pytorch",
+                extra_benchmark_config={
+                    "N": 8,
+                    "bsz": 1,
+                    "hidden_size": model.hidden_size,
+                    "dtype": model.dtype,
+                    "eps": 1e-6,
+                },
+            )
+            V, fn = _setup_attn_res(probe_input)
+            return fn()
+
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_seq_len
+
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": "attn_res",
+            "x_name": "T",
+            "x_label": "sequence length",
+            "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
+            "kernel_providers": ["liger", "pytorch"],
+            "extra_benchmark_configs": [
+                {
+                    "N": 8,
+                    "bsz": config.batch_size,
+                    "hidden_size": model.hidden_size,
+                    "dtype": model.dtype,
+                    "eps": 1e-6,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_attn_res,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_attn_res,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
diff --git a/benchmark/scripts/benchmark_cpo_loss.py b/benchmark/scripts/benchmark_cpo_loss.py
index 8b10d5188..9d2c225b5 100644
--- a/benchmark/scripts/benchmark_cpo_loss.py
+++ b/benchmark/scripts/benchmark_cpo_loss.py
@@ -1,9 +1,15 @@
+import math
 import os
 import sys
 
 import torch
 import triton
+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
+from benchmark_model_configs import compute_seq_len_sweep_config
+from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import get_benchmark_model_config
 from utils import QUANTILES
 from utils import SingleBenchmarkRunInput
 from utils import SingleBenchmarkRunOutput
@@ -18,94 +24,96 @@
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 
 
-#############################################################################
-# Test the memory consumption of the linear fused cross entropy loss
-#############################################################################
-
-
-def bench_memory_fused_linear_cpo_loss(
-    input: SingleBenchmarkRunInput,
-) -> SingleBenchmarkRunOutput:
+def _setup_cpo_loss(input: SingleBenchmarkRunInput):
+    """Create input tensors and CPO loss from benchmark config."""
     from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
     from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO
 
+    cfg = input.extra_benchmark_config
+    H = cfg["hidden_size"]
+    V = cfg["vocab_size"]
+    dtype = cfg["dtype"]
     B = input.x
-    T = input.extra_benchmark_config["T"]
-    H = input.extra_benchmark_config["H"]
-    V = input.extra_benchmark_config["V"]
-    dtype = input.extra_benchmark_config["dtype"]
-    provider = input.kernel_provider
-
-    # Instantiate once and retrieve the first output only
-    torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
-    liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
-    torch_fwd = lambda x, target: torch_lm_head_cpo(x, target)[0]
-    liger_fwd = lambda x, target: liger_lm_head_cpo(x, target)[0]
+    T = cfg["T"]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
 
-    def fwd():
-        if provider == "liger":
-            return liger_fwd(_input, target)
-        elif provider == "huggingface":
-            return torch_fwd(_input, target)
+    if input.kernel_provider == "liger":
+        loss_module = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
+    elif input.kernel_provider == "huggingface":
+        loss_module = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
+    else:
+        raise ValueError(f"Invalid provider: {input.kernel_provider} for CPOLoss")
 
-    def full():
-        y = fwd()
-        y.backward()
+    fwd_fn = lambda: loss_module(_input, target)[0]
+    return _input, fwd_fn
 
-    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
-    return SingleBenchmarkRunOutput(
-        y_20=mem_20,
-        y_50=mem_50,
-        y_80=mem_80,
-    )
 
+def bench_speed_cpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, fwd_fn = _setup_cpo_loss(input)
+    mode = input.kernel_operation_mode
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES)
+    elif mode == "backward":
+        y = fwd_fn()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(retain_graph=True),
+            grad_to_none=[_input],
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
+
+        def full():
+            y = fwd_fn()
+            y.backward()
 
-# #############################################################################
-# # Test the speed of the fused linear cross entropy loss
-# #############################################################################
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
 
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
 
-def bench_speed_fused_linear_cpo_loss(
-    input: SingleBenchmarkRunInput,
-) -> SingleBenchmarkRunOutput:
-    from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
-    from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO
 
-    B = input.x
-    T = input.extra_benchmark_config["T"]
-    H = input.extra_benchmark_config["H"]
-    V = input.extra_benchmark_config["V"]
-    dtype = input.extra_benchmark_config["dtype"]
-    provider = input.kernel_provider
-    mode = input.kernel_operation_mode
+def bench_memory_cpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, fwd_fn = _setup_cpo_loss(input)
 
-    # Instantiate once and retrieve the first output only
-    torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
-    liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
-    torch_fwd = lambda x, target: torch_lm_head_cpo(x, target)[0]
-    liger_fwd = lambda x, target: liger_lm_head_cpo(x, target)[0]
+    def full():
+        y = fwd_fn()
+        y.backward()
 
-    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
-    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
 
-    def fwd():
-        if provider == "liger":
-            return liger_fwd(_input, target)
-        elif provider == "huggingface":
-            return torch_fwd(_input, target)
 
-    if mode == "forward":
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            fwd,
-            rep=100,
-            quantiles=QUANTILES,
+def _resolve_model_config_cpo_loss(input: SingleBenchmarkRunInput):
+    """Resolve model-config-sweep input into standard setup args."""
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_cpo_loss(
+        SingleBenchmarkRunInput(
+            x=cfg["B"],
+            kernel_provider=input.kernel_provider,
+            extra_benchmark_config={
+                "hidden_size": model_info["hidden_size"],
+                "vocab_size": model_info["vocab_size"],
+                "dtype": model_info["dtype"],
+                "T": cfg["T"],
+            },
         )
-    elif mode == "backward":
-        y = fwd()
+    )
+
+
+def bench_speed_cpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, fwd_fn = _resolve_model_config_cpo_loss(input)
+    mode = input.kernel_operation_mode
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES)
+    elif mode == "backward":
+        y = fwd_fn()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: y.backward(retain_graph=True),
             grad_to_none=[_input],
@@ -115,53 +123,148 @@ def fwd():
     elif mode == "full":
 
         def full():
-            y = fwd()
+            y = fwd_fn()
             y.backward()
 
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            full,
-            rep=100,
-            quantiles=QUANTILES,
-        )
-    return SingleBenchmarkRunOutput(
-        y_20=ms_20,
-        y_50=ms_50,
-        y_80=ms_80,
-    )
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
+
+
+def bench_memory_cpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, fwd_fn = _resolve_model_config_cpo_loss(input)
+
+    def full():
+        y = fwd_fn()
+        y.backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
 
 
 if __name__ == "__main__":
     args = parse_benchmark_script_args()
 
-    common_configs = {
-        "kernel_name": "fused_linear_cpo_loss",
-        "x_name": "B",
-        "x_label": "B",
-        "x_values": [2**i for i in range(1, 5)],
-        "kernel_providers": ["liger", "huggingface"],
-        "extra_benchmark_configs": [
-            {
-                "T": 1024,
-                "H": 4096,
-                "V": 128256,
-                "mode": "forward",
-                "dtype": torch.bfloat16,
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+        T = 1024
+
+        def _probe_factory(model_cfg, probe_bt):
+            def _probe():
+                B = max(1, probe_bt // T)
+                probe_input = SingleBenchmarkRunInput(
+                    x=B,
+                    kernel_provider="huggingface",
+                    extra_benchmark_config={
+                        "hidden_size": model_cfg.hidden_size,
+                        "vocab_size": model_cfg.vocab_size,
+                        "dtype": model_cfg.dtype,
+                        "T": T,
+                    },
+                )
+                _, fwd_fn = _setup_cpo_loss(probe_input)
+                return fwd_fn()
+
+            return _probe
+
+        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+
+        model_configs_info = {
+            cfg.name: {
+                "hidden_size": cfg.hidden_size,
+                "vocab_size": cfg.vocab_size,
+                "dtype": cfg.dtype,
             }
-        ],
-        "overwrite": args.overwrite,
-    }
-
-    run_benchmarks(
-        bench_test_fn=bench_speed_fused_linear_cpo_loss,
-        kernel_operation_modes=["forward", "backward", "full"],
-        metric_name="speed",
-        metric_unit="ms",
-        **common_configs,
-    )
-    run_benchmarks(
-        bench_test_fn=bench_memory_fused_linear_cpo_loss,
-        kernel_operation_modes=["full"],
-        metric_name="memory",
-        metric_unit="MB",
-        **common_configs,
-    )
+            for cfg in sweep.model_configs
+        }
+
+        B = max(1, sweep.bt // T)
+
+        common_configs = {
+            "kernel_name": "fused_linear_cpo_loss",
+            "x_name": "model_config",
+            "x_label": "model configuration",
+            "x_values": [cfg.name for cfg in sweep.model_configs],
+            "kernel_providers": ["liger", "huggingface"],
+            "extra_benchmark_configs": [
+                {
+                    "model_configs": model_configs_info,
+                    "B": B,
+                    "T": T,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_cpo_loss_model_config,
+            kernel_operation_modes=["forward", "backward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_cpo_loss_model_config,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
+    else:
+        model = get_benchmark_model_config(args.model)
+        T = 1024
+        probe_bt = 1024
+
+        def _probe():
+            B = max(1, probe_bt // T)
+            probe_input = SingleBenchmarkRunInput(
+                x=B,
+                kernel_provider="huggingface",
+                extra_benchmark_config={
+                    "hidden_size": model.hidden_size,
+                    "vocab_size": model.vocab_size,
+                    "dtype": model.dtype,
+                    "T": T,
+                },
+            )
+            _, fwd_fn = _setup_cpo_loss(probe_input)
+            return fwd_fn()
+
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_bt
+
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": "fused_linear_cpo_loss",
+            "x_name": "B",
+            "x_label": "Batch Size (B)",
+            "x_values": [2**i for i in range(1, int(math.log2(max(2, config.batch_size * config.seq_len // T))) + 1)],
+            "kernel_providers": ["liger", "huggingface"],
+            "extra_benchmark_configs": [
+                {
+                    "hidden_size": model.hidden_size,
+                    "vocab_size": model.vocab_size,
+                    "dtype": model.dtype,
+                    "T": T,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_cpo_loss,
+            kernel_operation_modes=["forward", "backward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_cpo_loss,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
diff --git a/benchmark/scripts/benchmark_cross_entropy.py b/benchmark/scripts/benchmark_cross_entropy.py
index cdd61814a..576af32f0 100644
--- a/benchmark/scripts/benchmark_cross_entropy.py
+++ b/benchmark/scripts/benchmark_cross_entropy.py
@@ -1,6 +1,13 @@
+import math
+
 import torch
 import triton
+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
+from benchmark_model_configs import compute_seq_len_sweep_config
+from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import get_benchmark_model_config
 from torch.nn import CrossEntropyLoss
 from utils import QUANTILES
 from utils import SingleBenchmarkRunInput
@@ -15,58 +22,87 @@
 device = infer_device()
 
 
-def bench_memory_cross_entropy(
-    input: SingleBenchmarkRunInput,
-) -> SingleBenchmarkRunOutput:
-    torch_ce = CrossEntropyLoss()
-    liger_ce = LigerCrossEntropyLoss()
+def _setup_cross_entropy(input: SingleBenchmarkRunInput):
+    """Create input tensor, target, and CE loss from benchmark config."""
+    cfg = input.extra_benchmark_config
+    V = cfg["vocab_size"]
+    BT = input.x
+    _input = torch.randn(BT, V, requires_grad=True, device=device)
+    target = torch.randint(V, (BT, 1), device=device).squeeze(1)
+    if input.kernel_provider == "liger":
+        loss_fn = LigerCrossEntropyLoss()
+    elif input.kernel_provider == "huggingface":
+        loss_fn = CrossEntropyLoss()
+    else:
+        raise ValueError(f"Invalid provider: {input.kernel_provider} for CrossEntropy")
+    return _input, target, loss_fn
+
+
+def bench_speed_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, target, loss_fn = _setup_cross_entropy(input)
+    mode = input.kernel_operation_mode
 
-    V = input.x
-    provider = input.kernel_provider
-    B = input.extra_benchmark_config["B"]
-    T = input.extra_benchmark_config["T"]
+    def fwd():
+        return loss_fn(_input, target)
 
-    _input = torch.randn(B * T, V, requires_grad=True, device=device)
-    target = torch.randint(V, (B * T, 1), device=device).squeeze(1)
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
+    elif mode == "no-grad-forward":
+        with torch.no_grad():
+            ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
+    elif mode == "backward":
+        y = fwd()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(retain_graph=True),
+            grad_to_none=[_input],
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
 
-    def fwd():
-        if provider == "liger":
-            return liger_ce(_input, target)
-        else:
-            return torch_ce(_input, target)
+        def full():
+            y = fwd()
+            y.backward()
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
+
+
+def bench_memory_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, target, loss_fn = _setup_cross_entropy(input)
 
     def full():
-        y = fwd()
+        y = loss_fn(_input, target)
         y.backward()
 
     mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
-    return SingleBenchmarkRunOutput(
-        y_20=mem_20,
-        y_50=mem_50,
-        y_80=mem_80,
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
+
+
+def _resolve_model_config_cross_entropy(input: SingleBenchmarkRunInput):
+    """Resolve model-config-sweep input into standard setup args."""
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_cross_entropy(
+        SingleBenchmarkRunInput(
+            x=cfg["BT"],
+            kernel_provider=input.kernel_provider,
+            extra_benchmark_config={
+                "vocab_size": model_info["vocab_size"],
+            },
+        )
     )
 
 
-def bench_speed_cross_entropy(
-    input: SingleBenchmarkRunInput,
-) -> SingleBenchmarkRunOutput:
-    torch_ce = CrossEntropyLoss()
-    liger_ce = LigerCrossEntropyLoss()
-
-    V = input.x
-    provider = input.kernel_provider
+def bench_speed_cross_entropy_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, target, loss_fn = _resolve_model_config_cross_entropy(input)
     mode = input.kernel_operation_mode
-    B = input.extra_benchmark_config["B"]
-    T = input.extra_benchmark_config["T"]
-
-    _input = torch.randn(B * T, V, requires_grad=True, device=device)
-    target = torch.randint(V, (B * T, 1), device=device).squeeze(1)
 
     def fwd():
-        if provider == "liger":
-            return liger_ce(_input, target)
-        else:
-            return torch_ce(_input, target)
+        return loss_fn(_input, target)
 
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
@@ -75,7 +111,6 @@ def fwd():
             ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
     elif mode == "backward":
         y = fwd()
-
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: y.backward(retain_graph=True),
             grad_to_none=[_input],
@@ -89,38 +124,126 @@ def full():
             y.backward()
 
         ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
 
-    return SingleBenchmarkRunOutput(
-        y_20=ms_20,
-        y_50=ms_50,
-        y_80=ms_80,
-    )
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
+
+
+def bench_memory_cross_entropy_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, target, loss_fn = _resolve_model_config_cross_entropy(input)
+
+    def full():
+        y = loss_fn(_input, target)
+        y.backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
 
 
 if __name__ == "__main__":
     args = parse_benchmark_script_args()
 
-    common_configs = {
-        "kernel_name": "cross_entropy",
-        "x_name": "V",
-        "x_label": "vocab size",
-        "x_values": [2**i for i in range(12, 18)],
-        "kernel_providers": ["liger", "huggingface"],
-        "extra_benchmark_configs": [{"B": 8, "T": 2048}],
-        "overwrite": args.overwrite,
-    }
-
-    run_benchmarks(
-        bench_test_fn=bench_speed_cross_entropy,
-        kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
-        metric_name="speed",
-        metric_unit="ms",
-        **common_configs,
-    )
-    run_benchmarks(
-        bench_test_fn=bench_memory_cross_entropy,
-        kernel_operation_modes=["full"],
-        metric_name="memory",
-        metric_unit="MB",
-        **common_configs,
-    )
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+
+        def _probe_factory(model_cfg, probe_bt):
+            def _probe():
+                probe_input = SingleBenchmarkRunInput(
+                    x=probe_bt,
+                    kernel_provider="huggingface",
+                    extra_benchmark_config={
+                        "vocab_size": model_cfg.vocab_size,
+                    },
+                )
+                _input, target, loss_fn = _setup_cross_entropy(probe_input)
+                return loss_fn(_input, target)
+
+            return _probe
+
+        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+
+        model_configs_info = {
+            cfg.name: {
+                "vocab_size": cfg.vocab_size,
+            }
+            for cfg in sweep.model_configs
+        }
+
+        common_configs = {
+            "kernel_name": "cross_entropy",
+            "x_name": "model_config",
+            "x_label": "model configuration",
+            "x_values": [cfg.name for cfg in sweep.model_configs],
+            "kernel_providers": ["liger", "huggingface"],
+            "extra_benchmark_configs": [
+                {
+                    "model_configs": model_configs_info,
+                    "BT": sweep.bt,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_cross_entropy_model_config,
+            kernel_operation_modes=["forward", "backward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_cross_entropy_model_config,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
+    else:
+        model = get_benchmark_model_config(args.model)
+        probe_bt = 1024
+
+        def _probe():
+            probe_input = SingleBenchmarkRunInput(
+                x=probe_bt,
+                kernel_provider="huggingface",
+                extra_benchmark_config={
+                    "vocab_size": model.vocab_size,
+                },
+            )
+            _input, target, loss_fn = _setup_cross_entropy(probe_input)
+            return loss_fn(_input, target)
+
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_bt
+
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": "cross_entropy",
+            "x_name": "BT",
+            "x_label": "B * T",
+            "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)],
+            "kernel_providers": ["liger", "huggingface"],
+            "extra_benchmark_configs": [
+                {
+                    "vocab_size": model.vocab_size,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_cross_entropy,
+            kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_cross_entropy,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
diff --git a/benchmark/scripts/benchmark_distill_cosine_loss.py b/benchmark/scripts/benchmark_distill_cosine_loss.py
index 5cf12b495..23414c1e2 100644
--- a/benchmark/scripts/benchmark_distill_cosine_loss.py
+++ b/benchmark/scripts/benchmark_distill_cosine_loss.py
@@ -1,3 +1,4 @@
+import math
 import os
 import sys
 
@@ -5,6 +6,11 @@
 import torch.nn as nn
 import triton
+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
+from benchmark_model_configs import compute_seq_len_sweep_config
+from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import get_benchmark_model_config
 from utils import QUANTILES
 from utils import SingleBenchmarkRunInput
 from utils import SingleBenchmarkRunOutput
@@ -84,125 +90,124 @@ def forward(self, student: torch.Tensor, teacher: torch.Tensor, target: torch.Te
         )
 
 
-def bench_memory_cosine_similarity_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+def _setup_distill_cosine_loss(input: SingleBenchmarkRunInput):
+    """Create input tensors and cosine similarity loss from benchmark config."""
+    cfg = input.extra_benchmark_config
+    H = cfg["hidden_size"]
+    V = cfg["vocab_size"]
+    dtype = cfg["dtype"]
+    bias = cfg["bias"]
+    weight_hard_loss = cfg["weight_hard_loss"]
+    weight_soft_loss = cfg["weight_soft_loss"]
+    ignore_index = cfg["ignore_index"]
     BT = input.x
-    H = input.extra_benchmark_config["H"]
-    V = input.extra_benchmark_config["V"]
-    dtype = input.extra_benchmark_config["dtype"]
-    bias = input.extra_benchmark_config["bias"]
-    weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
-    weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
-    ignore_index = input.extra_benchmark_config["ignore_index"]
-    provider = input.kernel_provider
-
-    torch_cosine_loss = TorchCosineSimilarityLoss(
-        H=H,
-        V=V,
-        dtype=dtype,
-        weight_hard_loss=weight_hard_loss,
-        weight_soft_loss=weight_soft_loss,
-        bias=bias,
-    ).to(device)
-
-    liger_cosine_loss = LigerCosineSimilarityLoss(
-        H=H,
-        V=V,
-        dtype=dtype,
-        ignore_index=ignore_index,
-        bias=bias,
-        weight_hard_loss=weight_hard_loss,
-        weight_soft_loss=weight_soft_loss,
-    ).to(device)
 
     _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
-    student_input1 = _tensor.detach().clone().requires_grad_(True)
-    student_input2 = _tensor.detach().clone().requires_grad_(True)
-
+    student_input = _tensor.detach().clone().requires_grad_(True)
     teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
-
     target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
 
+    if input.kernel_provider == "liger":
+        loss_module = LigerCosineSimilarityLoss(
+            H=H,
+            V=V,
+            dtype=dtype,
+            ignore_index=ignore_index,
+            bias=bias,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+        ).to(device)
+    elif input.kernel_provider == "torch":
+        loss_module = TorchCosineSimilarityLoss(
+            H=H,
+            V=V,
+            dtype=dtype,
+            ignore_index=ignore_index,
+            bias=bias,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+        ).to(device)
+    else:
+        raise ValueError(f"Invalid provider: {input.kernel_provider} for DistillCosineLoss")
+    return student_input, teacher_input, target, loss_module
+
+
+def bench_speed_distill_cosine_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    student_input, teacher_input, target, loss_module = _setup_distill_cosine_loss(input)
+    mode = input.kernel_operation_mode
+
     def fwd():
-        if provider == "liger":
-            return liger_cosine_loss(student_input1, teacher_input, target)
-        elif provider == "torch":
-            return torch_cosine_loss(student_input2, teacher_input, target)
+        return loss_module(student_input, teacher_input, target)
 
-    def full():
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
+    elif mode == "backward":
         y = fwd()
-        y.backward()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(retain_graph=True),
+            grad_to_none=[student_input],
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
 
-    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
-    return SingleBenchmarkRunOutput(
-        y_20=mem_20,
-        y_50=mem_50,
-        y_80=mem_80,
-    )
+        def full():
+            y = fwd()
+            y.backward()
 
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
 
-def bench_speed_cosine_similarity_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    BT = input.x
-    H = input.extra_benchmark_config["H"]
-    V = input.extra_benchmark_config["V"]
-    dtype = input.extra_benchmark_config["dtype"]
-    bias = input.extra_benchmark_config["bias"]
-    weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
-    weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
-    ignore_index = input.extra_benchmark_config["ignore_index"]
-    provider = input.kernel_provider
-    mode = input.kernel_operation_mode
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
 
-    torch_cosine_loss = TorchCosineSimilarityLoss(
-        H=H,
-        V=V,
-        dtype=dtype,
-        ignore_index=ignore_index,
-        bias=bias,
-        weight_hard_loss=weight_hard_loss,
-        weight_soft_loss=weight_soft_loss,
-    ).to(device)
-
-    liger_cosine_loss = LigerCosineSimilarityLoss(
-        H=H,
-        V=V,
-        dtype=dtype,
-        ignore_index=ignore_index,
-        bias=bias,
-        weight_hard_loss=weight_hard_loss,
-        weight_soft_loss=weight_soft_loss,
-    ).to(device)
 
-    _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
-    student_input1 = _tensor.detach().clone().requires_grad_(True)
-    student_input2 = _tensor.detach().clone().requires_grad_(True)
+def bench_memory_distill_cosine_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    student_input, teacher_input, target, loss_module = _setup_distill_cosine_loss(input)
 
-    teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
+    def full():
+        y = loss_module(student_input, teacher_input, target)
+        y.backward()
 
-    target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
+
+
+def _resolve_model_config_distill_cosine_loss(input: SingleBenchmarkRunInput):
+    """Resolve model-config-sweep input into standard setup args."""
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_distill_cosine_loss(
+        SingleBenchmarkRunInput(
+            x=cfg["BT"],
+            kernel_provider=input.kernel_provider,
+            extra_benchmark_config={
+                "hidden_size": model_info["hidden_size"],
+                "vocab_size": model_info["vocab_size"],
+                "dtype": model_info["dtype"],
+                "bias": cfg["bias"],
+                "weight_hard_loss": cfg["weight_hard_loss"],
+                "weight_soft_loss": cfg["weight_soft_loss"],
+                "ignore_index": cfg["ignore_index"],
+            },
+        )
+    )
+
+
+def bench_speed_distill_cosine_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    student_input, teacher_input, target, loss_module = _resolve_model_config_distill_cosine_loss(input)
+    mode = input.kernel_operation_mode
 
     def fwd():
-        if provider == "liger":
-            return liger_cosine_loss(student_input1, teacher_input, target)
-        elif provider == "torch":
-            return torch_cosine_loss(student_input2, teacher_input, target)
+        return loss_module(student_input, teacher_input, target)
 
     if mode == "forward":
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            fwd,
-            rep=100,
-            quantiles=QUANTILES,
-        )
-    elif mode == "backward":
-        y = fwd()
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            fwd,
-            rep=100,
-            quantiles=QUANTILES,
-        )
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
     elif mode == "backward":
         y = fwd()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: y.backward(retain_graph=True),
-            grad_to_none=[student_input1, student_input2],
+            grad_to_none=[student_input],
             rep=100,
             quantiles=QUANTILES,
         )
@@ -212,55 +217,151 @@ def full():
             y = fwd()
             y.backward()
 
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            full,
-            rep=100,
-            quantiles=QUANTILES,
-        )
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
 
-    return SingleBenchmarkRunOutput(
-        y_20=ms_20,
-        y_50=ms_50,
-        y_80=ms_80,
-    )
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
+
+
+def bench_memory_distill_cosine_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    student_input, teacher_input, target, loss_module = _resolve_model_config_distill_cosine_loss(input)
+
+    def full():
+        y = loss_module(student_input, teacher_input, target)
+        y.backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
 
 
 if __name__ == "__main__":
     args = parse_benchmark_script_args()
 
-    common_configs = {
-        "kernel_name": "distill_cosine_loss",
-        "x_name": "BT",
-        "x_label": "B x T",
-        "x_values": [2**i for i in range(10, 14)],
-        "kernel_providers": ["liger", "torch"],
-        "extra_benchmark_configs": [
-            {
-                "H": 4096,
-                "V": 128256,
-                "mode": "forward",
-                "dtype": torch.bfloat16,
-                "bias": False,
-                "weight_hard_loss": 0.5,
-                "weight_soft_loss": 0.5,
-                "ignore_index": -100,
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+
+        def _probe_factory(model_cfg, probe_bt):
+            def _probe():
+                probe_input = SingleBenchmarkRunInput(
+                    x=probe_bt,
+                    kernel_provider="torch",
+                    extra_benchmark_config={
+                        "hidden_size": model_cfg.hidden_size,
+                        "vocab_size": model_cfg.vocab_size,
+                        "dtype": model_cfg.dtype,
+                        "bias": False,
+                        "weight_hard_loss": 0.5,
+                        "weight_soft_loss": 0.5,
+                        "ignore_index": -100,
+                    },
+                )
+                student_input, teacher_input, target, loss_module = _setup_distill_cosine_loss(probe_input)
+                return loss_module(student_input, teacher_input, target)
+
+            return _probe
+
+        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+
+        model_configs_info = {
+            cfg.name: {
+                "hidden_size": cfg.hidden_size,
+                "vocab_size": cfg.vocab_size,
+                "dtype": cfg.dtype,
             }
-        ],
-        "overwrite": args.overwrite,
-    }
-
-    run_benchmarks(
-        bench_test_fn=bench_speed_cosine_similarity_loss,
-        kernel_operation_modes=["forward", "full"],
-        metric_name="speed",
-        metric_unit="ms",
-        **common_configs,
-    )
-
-    run_benchmarks(
-        bench_test_fn=bench_memory_cosine_similarity_loss,
-        kernel_operation_modes=["full"],
-        metric_name="memory",
-        metric_unit="MB",
-        **common_configs,
-    )
+            for cfg in sweep.model_configs
+        }
+
+        common_configs = {
+            "kernel_name": "distill_cosine_loss",
+            "x_name": "model_config",
+            "x_label": "model configuration",
+            "x_values": [cfg.name for cfg in sweep.model_configs],
+            "kernel_providers": ["liger", "torch"],
+            "extra_benchmark_configs": [
+                {
+                    "model_configs": model_configs_info,
+                    "BT": sweep.bt,
+                    "bias": False,
+                    "weight_hard_loss": 0.5,
+                    "weight_soft_loss": 0.5,
+                    "ignore_index": -100,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_distill_cosine_loss_model_config,
+            kernel_operation_modes=["forward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_distill_cosine_loss_model_config,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
+    else:
+        model = get_benchmark_model_config(args.model)
+        probe_bt = 1024
+
+        def _probe():
+            probe_input = SingleBenchmarkRunInput(
+                x=probe_bt,
+                kernel_provider="torch",
+                extra_benchmark_config={
+                    "hidden_size": model.hidden_size,
+                    "vocab_size": model.vocab_size,
+                    "dtype": model.dtype,
+                    "bias": False,
+                    "weight_hard_loss": 0.5,
+                    "weight_soft_loss": 0.5,
+                    "ignore_index": -100,
+                },
+            )
+            student_input, teacher_input, target, loss_module = _setup_distill_cosine_loss(probe_input)
+            return loss_module(student_input, teacher_input, target)
+
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_bt
+
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": "distill_cosine_loss",
+            "x_name": "BT",
+            "x_label": "B * T",
+            "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)],
+            "kernel_providers": ["liger", "torch"],
+            "extra_benchmark_configs": [
+                {
+                    "hidden_size": model.hidden_size,
+                    "vocab_size": model.vocab_size,
+                    "dtype": model.dtype,
+                    "bias": False,
+                    "weight_hard_loss": 0.5,
+                    "weight_soft_loss": 0.5,
+                    "ignore_index": -100,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_distill_cosine_loss,
+            kernel_operation_modes=["forward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_distill_cosine_loss,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
diff --git a/benchmark/scripts/benchmark_distill_jsd_loss.py b/benchmark/scripts/benchmark_distill_jsd_loss.py
index 324418e17..a8ea3eaca 100644
--- a/benchmark/scripts/benchmark_distill_jsd_loss.py
+++ b/benchmark/scripts/benchmark_distill_jsd_loss.py
@@ -1,9 +1,15 @@
+import math
 import os
 import sys
 
 import torch
 import triton
+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
+from benchmark_model_configs import compute_seq_len_sweep_config
+from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import get_benchmark_model_config
 from utils import QUANTILES
 from utils import SingleBenchmarkRunInput
 from utils import SingleBenchmarkRunOutput
@@ -12,7 +18,6 @@
 from utils import run_benchmarks
 
 from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
-from liger_kernel.utils import get_total_gpu_memory
 from liger_kernel.utils import infer_device
 
 device = infer_device()
@@ -89,118 +94,124 @@ def forward(self, student, teacher, target):
         )
 
 
-def bench_memory_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+def _setup_distill_jsd_loss(input: SingleBenchmarkRunInput):
+    """Create input tensors and JSD loss from benchmark config."""
+    cfg = input.extra_benchmark_config
+    H = cfg["hidden_size"]
+    V = cfg["vocab_size"]
+    dtype = cfg["dtype"]
+    bias = cfg["bias"]
+    weight_hard_loss = cfg["weight_hard_loss"]
+    weight_soft_loss = cfg["weight_soft_loss"]
+    ignore_index = cfg["ignore_index"]
     BT = input.x
-    H = input.extra_benchmark_config["H"]
-    V = input.extra_benchmark_config["V"]
-    dtype = input.extra_benchmark_config["dtype"]
-    bias = input.extra_benchmark_config["bias"]
-    weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
-    weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
-    ignore_index = input.extra_benchmark_config["ignore_index"]
-    provider = input.kernel_provider
-
-    torch_jsd_loss = TorchJSDLoss(
-        H=H,
-        V=V,
-        dtype=dtype,
-        ignore_index=ignore_index,
-        bias=bias,
-        weight_hard_loss=weight_hard_loss,
-        weight_soft_loss=weight_soft_loss,
-    ).to(device)
-    liger_jsd_loss = LigerJSDLoss(
-        H=H,
-        V=V,
-        dtype=dtype,
-        ignore_index=ignore_index,
-        bias=bias,
-        weight_hard_loss=weight_hard_loss,
-        weight_soft_loss=weight_soft_loss,
-    ).to(device)
 
     _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
-    student_input1 = _tensor.detach().clone().requires_grad_(True)
-    student_input2 = _tensor.detach().clone().requires_grad_(True)
-
+    student_input = _tensor.detach().clone().requires_grad_(True)
     teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
-
     target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
 
+    if input.kernel_provider == "liger":
+        loss_module = LigerJSDLoss(
+            H=H,
+            V=V,
+            dtype=dtype,
+            ignore_index=ignore_index,
+            bias=bias,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+        ).to(device)
+    elif input.kernel_provider == "torch":
+        loss_module = TorchJSDLoss(
+            H=H,
+            V=V,
+            dtype=dtype,
+            ignore_index=ignore_index,
+            bias=bias,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+        ).to(device)
+    else:
+        raise ValueError(f"Invalid provider: {input.kernel_provider} for DistillJSDLoss")
+    return student_input, teacher_input, target, loss_module
+
+
+def bench_speed_distill_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    student_input, teacher_input, target, loss_module = _setup_distill_jsd_loss(input)
+    mode = input.kernel_operation_mode
+
     def fwd():
-        if provider == "liger":
-            return liger_jsd_loss(student_input1, teacher_input, target)
-        elif provider == "torch":
-            return torch_jsd_loss(student_input2, teacher_input, target)
+        return loss_module(student_input, teacher_input, target)
 
-    def full():
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
+    elif mode == "backward":
         y = fwd()
-        y.backward()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(retain_graph=True),
+            grad_to_none=[student_input],
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
 
-    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
-    return SingleBenchmarkRunOutput(
-        y_20=mem_20,
-        y_50=mem_50,
-        y_80=mem_80,
-    )
+        def full():
+            y = fwd()
+            y.backward()
 
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
 
-def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    BT = input.x
-    H = input.extra_benchmark_config["H"]
-    V = input.extra_benchmark_config["V"]
-    dtype = input.extra_benchmark_config["dtype"]
-    bias = input.extra_benchmark_config["bias"]
-    weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
-    weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
-    ignore_index = input.extra_benchmark_config["ignore_index"]
-    provider = input.kernel_provider
-    mode = input.kernel_operation_mode
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
 
-    torch_jsd_loss = TorchJSDLoss(
-        H=H,
-        V=V,
-        dtype=dtype,
-        ignore_index=ignore_index,
-        bias=bias,
-        weight_hard_loss=weight_hard_loss,
-        weight_soft_loss=weight_soft_loss,
-    ).to(device)
-    liger_jsd_loss = LigerJSDLoss(
-        H=H,
-        V=V,
-        dtype=dtype,
-        ignore_index=ignore_index,
-        bias=bias,
-        weight_hard_loss=weight_hard_loss,
-        weight_soft_loss=weight_soft_loss,
-    ).to(device)
-    _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
-    student_input1 = _tensor.detach().clone().requires_grad_(True)
-    student_input2 = _tensor.detach().clone().requires_grad_(True)
 
+def bench_memory_distill_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    student_input, teacher_input, target, loss_module = _setup_distill_jsd_loss(input)
 
-    teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
+    def full():
+        y = loss_module(student_input, teacher_input, target)
+        y.backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
+
+
+def _resolve_model_config_distill_jsd_loss(input: SingleBenchmarkRunInput):
+    """Resolve model-config-sweep input into standard setup args."""
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_distill_jsd_loss(
+        SingleBenchmarkRunInput(
+            x=cfg["BT"],
+            kernel_provider=input.kernel_provider,
+            extra_benchmark_config={
+                "hidden_size": model_info["hidden_size"],
+                "vocab_size": model_info["vocab_size"],
+                "dtype": model_info["dtype"],
+                "bias": cfg["bias"],
+                "weight_hard_loss": cfg["weight_hard_loss"],
+                "weight_soft_loss": cfg["weight_soft_loss"],
+                "ignore_index": cfg["ignore_index"],
+            },
+        )
    )
 
-    target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
 
+def bench_speed_distill_jsd_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    student_input, teacher_input, target, loss_module = _resolve_model_config_distill_jsd_loss(input)
+    mode = input.kernel_operation_mode
 
     def fwd():
-        if provider == "liger":
-            return liger_jsd_loss(student_input1, teacher_input, target)
-        elif provider == "torch":
-            return torch_jsd_loss(student_input2, teacher_input, target)
+        return loss_module(student_input, teacher_input, target)
 
     if mode == "forward":
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            fwd,
-            rep=100,
-            quantiles=QUANTILES,
-        )
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
     elif mode == "backward":
         y = fwd()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: y.backward(retain_graph=True),
-            grad_to_none=[student_input1, student_input2],
+            grad_to_none=[student_input],
             rep=100,
             quantiles=QUANTILES,
         )
@@ -210,63 +221,151 @@ def full():
             y = fwd()
             y.backward()
 
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            full,
-            rep=100,
-            quantiles=QUANTILES,
-        )
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
 
-    return SingleBenchmarkRunOutput(
-        y_20=ms_20,
-        y_50=ms_50,
-        y_80=ms_80,
-    )
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
+
+
+def bench_memory_distill_jsd_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    student_input, teacher_input, target, loss_module = _resolve_model_config_distill_jsd_loss(input)
+
+    def full():
+        y = loss_module(student_input, teacher_input, target)
+        y.backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
 
 
 if __name__ == "__main__":
     args = parse_benchmark_script_args()
 
-    gpu_memory_gbs = get_total_gpu_memory()
-    # We know that the full test will require 69GBs for vocab size 2^13 and 39GBs for vocab size 2^12 on torch
-    if gpu_memory_gbs >= 69:
-        x_max = 13
-    elif gpu_memory_gbs >= 39:
-        x_max = 12
-    else:
-        x_max = 11
-
-    common_configs = {
-        "kernel_name": "distill_jsd_loss",
-        "x_name": "BT",
-        "x_label": "B x T",
-        "x_values": [2**i for i in range(10, x_max + 1)],
-        "kernel_providers": ["liger", "torch"],
-        "extra_benchmark_configs": [
-            {
-                "H": 4096,
-                "V": 128256,
-                "mode": "forward",
-                "dtype": torch.bfloat16,
-                "bias": False,
-                "weight_hard_loss": 0.5,
-                "weight_soft_loss": 0.5,
-                "ignore_index": -100,
-            }
-        ],
-        "overwrite": args.overwrite,
-    }
-
-    run_benchmarks(
-        bench_test_fn=bench_speed_jsd_loss,
-        kernel_operation_modes=["forward", "backward", "full"],
-        metric_name="speed",
-        metric_unit="ms",
-        **common_configs,
-    )
-    run_benchmarks(
-        bench_test_fn=bench_memory_jsd_loss,
-        kernel_operation_modes=["full"],
-        metric_name="memory",
-        metric_unit="MB",
-        **common_configs,
-    )
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+
+        def _probe_factory(model_cfg, probe_bt):
+            def _probe():
+                probe_input = SingleBenchmarkRunInput(
+                    x=probe_bt,
+                    kernel_provider="torch",
+                    extra_benchmark_config={
+                        "hidden_size": model_cfg.hidden_size,
"vocab_size": model_cfg.vocab_size, + "dtype": model_cfg.dtype, + "bias": False, + "weight_hard_loss": 0.5, + "weight_soft_loss": 0.5, + "ignore_index": -100, + }, + ) + student_input, teacher_input, target, loss_module = _setup_distill_jsd_loss(probe_input) + return loss_module(student_input, teacher_input, target) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "vocab_size": cfg.vocab_size, + "dtype": cfg.dtype, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "distill_jsd_loss", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "BT": sweep.bt, + "bias": False, + "weight_hard_loss": 0.5, + "weight_soft_loss": 0.5, + "ignore_index": -100, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_distill_jsd_loss_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_distill_jsd_loss_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 1024 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "vocab_size": model.vocab_size, + "dtype": model.dtype, + "bias": False, + "weight_hard_loss": 0.5, + "weight_soft_loss": 0.5, + "ignore_index": -100, + }, + ) + student_input, teacher_input, target, loss_module = _setup_distill_jsd_loss(probe_input) + return loss_module(student_input, teacher_input, target) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "distill_jsd_loss", + "x_name": "BT", + "x_label": "B * T", + "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "vocab_size": model.vocab_size, + "dtype": model.dtype, + "bias": False, + "weight_hard_loss": 0.5, + "weight_soft_loss": 0.5, + "ignore_index": -100, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_distill_jsd_loss, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_distill_jsd_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py index 228a228d5..17c793333 100644 --- a/benchmark/scripts/benchmark_dpo_loss.py +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -1,9 +1,15 @@ +import math import os import sys import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from 
benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -18,30 +24,23 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: +def _setup_dpo_loss(input: SingleBenchmarkRunInput): + """Create input tensors and DPO loss from benchmark config.""" from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO + cfg = input.extra_benchmark_config + H = cfg["hidden_size"] + V = cfg["vocab_size"] + dtype = cfg["dtype"] + bias = cfg["bias"] + beta = cfg["beta"] + ignore_index = cfg["ignore_index"] B = input.x - T = input.extra_benchmark_config["T"] - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - bias = input.extra_benchmark_config["bias"] - beta = input.extra_benchmark_config["beta"] - ignore_index = input.extra_benchmark_config["ignore_index"] - provider = input.kernel_provider - - # Instantiate once and retrieve the first output only - torch_dpo_loss = TorchLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) - liger_dpo_loss = LigerLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) - torch_fwd = lambda x, ref_x, target: torch_dpo_loss(x, ref_x, target)[0] - liger_fwd = lambda x, ref_x, target: liger_dpo_loss(x, ref_x, target)[0] - - # Input shape: [B, T, H] + T = cfg["T"] + _input = torch.randn(B, T, H, device=device, dtype=dtype) ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) - # Target shape: [B, T] target = torch.randint(V, (B, T), dtype=torch.long, device=device) # Add ignore_index tokens to simulate padding @@ -49,70 +48,25 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index - def fwd(): - if provider == "liger": - return liger_fwd(_input, ref_input, target) - elif provider == "huggingface": - return torch_fwd(_input, ref_input, target) - - def full(): - y = fwd() - y.backward() + if input.kernel_provider == "liger": + loss_module = LigerLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) + elif input.kernel_provider == "huggingface": + loss_module = TorchLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for DPOLoss") - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + fwd_fn = lambda: loss_module(_input, ref_input, target)[0] + return _input, fwd_fn def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO - from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO - - B = input.x - T = input.extra_benchmark_config["T"] - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - bias = input.extra_benchmark_config["bias"] - beta = input.extra_benchmark_config["beta"] - 
ignore_index = input.extra_benchmark_config["ignore_index"] - provider = input.kernel_provider + _input, fwd_fn = _setup_dpo_loss(input) mode = input.kernel_operation_mode - # Instantiate once and retrieve the first output only - torch_dpo_loss = TorchLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) - liger_dpo_loss = LigerLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) - torch_fwd = lambda x, ref_x, target: torch_dpo_loss(x, ref_x, target)[0] - liger_fwd = lambda x, ref_x, target: liger_dpo_loss(x, ref_x, target)[0] - - # Input shape: [B, T, H] - _input = torch.randn(B, T, H, device=device, dtype=dtype) - ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) - # Target shape: [B, T] - target = torch.randint(V, (B, T), device=device, dtype=torch.long) - - # Add ignore_index tokens - num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() - indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] - target.view(-1)[indices_to_assign] = ignore_index - - def fwd(): - if provider == "liger": - return liger_fwd(_input, ref_input, target) - elif provider == "huggingface": - return torch_fwd(_input, ref_input, target) - if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - rep=100, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES) elif mode == "backward": - y = fwd() + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(retain_graph=True), grad_to_none=[_input], @@ -122,58 +76,220 @@ def fwd(): elif mode == "full": def full(): - y = fwd() + y = fwd_fn() y.backward() + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _setup_dpo_loss(input) + + def full(): + y = fwd_fn() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_dpo_loss(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_dpo_loss( + SingleBenchmarkRunInput( + x=cfg["B"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "vocab_size": model_info["vocab_size"], + "dtype": model_info["dtype"], + "T": cfg["T"], + "bias": cfg["bias"], + "beta": cfg["beta"], + "ignore_index": cfg["ignore_index"], + }, + ) + ) + + +def bench_speed_dpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _resolve_model_config_dpo_loss(input) + mode = input.kernel_operation_mode + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], rep=100, quantiles=QUANTILES, ) + elif mode == "full": - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + def full(): + y = fwd_fn() + y.backward() + + ms_50, ms_20, ms_80 = 
triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_dpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _resolve_model_config_dpo_loss(input) + + def full(): + y = fwd_fn() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "dpo_loss", - "x_name": "B", - "x_label": "Batch Size (B)", - "x_values": [2**i for i in range(1, 6)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "T": 512, - "H": 1024, - "V": 128256, - "mode": "forward", - "dtype": torch.bfloat16, - "bias": True, - "beta": 0.1, - "ignore_index": 42, + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + T = 512 + B = max(1, probe_bt // T) + probe_input = SingleBenchmarkRunInput( + x=B, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "vocab_size": model_cfg.vocab_size, + "dtype": model_cfg.dtype, + "T": T, + "bias": True, + "beta": 0.1, + "ignore_index": 42, + }, + ) + _, fwd_fn = _setup_dpo_loss(probe_input) + return fwd_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "vocab_size": cfg.vocab_size, + "dtype": cfg.dtype, } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_dpo_loss, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) + for cfg in sweep.model_configs + } - run_benchmarks( - bench_test_fn=bench_memory_dpo_loss, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + T = 512 + B = max(1, sweep.bt // T) + + common_configs = { + "kernel_name": "dpo_loss", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "B": B, + "T": T, + "bias": True, + "beta": 0.1, + "ignore_index": 42, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_dpo_loss_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_dpo_loss_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + T = 512 + probe_bt = 1024 + + def _probe(): + B = max(1, probe_bt // T) + probe_input = SingleBenchmarkRunInput( + x=B, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "vocab_size": model.vocab_size, + "dtype": model.dtype, + "T": T, + "bias": True, + "beta": 0.1, + "ignore_index": 42, + }, + ) + _, fwd_fn = _setup_dpo_loss(probe_input) + return fwd_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt 
= peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "dpo_loss", + "x_name": "B", + "x_label": "Batch Size (B)", + "x_values": [2**i for i in range(1, int(math.log2(config.batch_size * config.seq_len // T)) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "vocab_size": model.vocab_size, + "dtype": model.dtype, + "T": T, + "bias": True, + "beta": 0.1, + "ignore_index": 42, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_dpo_loss, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_dpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_embedding.py b/benchmark/scripts/benchmark_embedding.py index 2bd0c60be..09ea62c61 100644 --- a/benchmark/scripts/benchmark_embedding.py +++ b/benchmark/scripts/benchmark_embedding.py @@ -1,6 +1,15 @@ +import math +import os +import sys + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from torch.nn import Embedding from utils import QUANTILES from utils import SingleBenchmarkRunInput @@ -14,42 +23,47 @@ device = infer_device() -# NOTE: For torch compile, we will just use default inductor settings. No further customization -# is needed. 
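+# Prepend the repo root to sys.path so shared modules resolve when this script
+# is run directly; assumed to mirror the setup in the other benchmark scripts.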
-
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


-def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    V = input.x
-    provider = input.kernel_provider
-    mode = input.kernel_operation_mode
-    B = input.extra_benchmark_config["B"]
-    T = input.extra_benchmark_config["T"]
-    D = input.extra_benchmark_config["D"]
-    dtype = input.extra_benchmark_config["dtype"]

+def _setup_embedding(input: SingleBenchmarkRunInput):
+    """Create input tensors and embedding module from benchmark config."""
+    cfg = input.extra_benchmark_config
+    V = cfg.get("vocab_size", input.x)
+    D = cfg["hidden_size"]
+    dtype = cfg["dtype"]
+    T = cfg.get("T", 512)
+    # Total tokens: taken from the config when present (model_config sweep);
+    # otherwise the swept x value is interpreted as BT.
+    BT = cfg.get("BT", input.x)
+    B = max(1, BT // T)

-    torch_emb = Embedding(V, D).to(device).to(dtype)
-    liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
-    torch_compile_emb = torch.compile(torch_emb)

     input_ids = torch.randint(0, V, (B, T), device=device)

-    def fwd():
-        if provider == "liger":
-            return liger_emb(input_ids)
-        elif provider == "torch_compile":
-            return torch_compile_emb(input_ids)
-        else:
-            return torch_emb(input_ids)
+    if input.kernel_provider == "liger":
+        emb = LigerEmbedding(V, D).to(device).to(dtype)
+    elif input.kernel_provider == "torch_compile":
+        emb = torch.compile(Embedding(V, D).to(device).to(dtype))
+    elif input.kernel_provider == "huggingface":
+        emb = Embedding(V, D).to(device).to(dtype)
+    else:
+        raise ValueError(f"Invalid provider: {input.kernel_provider} for embedding")
+
+    fwd_fn = lambda: emb(input_ids)
+    return input_ids, fwd_fn

-    def full():
-        output = fwd()
-        output.backward(torch.randn_like(output))
+
+def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    input_ids, fwd_fn = _setup_embedding(input)
+    mode = input.kernel_operation_mode

     if mode == "forward":
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, quantiles=QUANTILES, rep=100)
     elif mode == "backward":
-        output = fwd()
+        output = fwd_fn()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: output.backward(torch.randn_like(output), retain_graph=True),
             quantiles=QUANTILES,
@@ -57,78 +71,189 @@ def full():
             rep=100,
         )
     elif mode == "full":
+
+        def full():
+            output = fwd_fn()
+            output.backward(torch.randn_like(output))
+
         ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
-    return SingleBenchmarkRunOutput(
-        y_20=ms_20,
-        y_50=ms_50,
-        y_80=ms_80,
-    )
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


 def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    V = input.x
-    provider = input.kernel_provider
+    input_ids, fwd_fn = _setup_embedding(input)

-    B = input.extra_benchmark_config["B"]
-    T = input.extra_benchmark_config["T"]
-    D = input.extra_benchmark_config["D"]
-    dtype = input.extra_benchmark_config["dtype"]
+    def full():
+        output = fwd_fn()
+        output.backward(torch.randn_like(output))

-    torch_emb = Embedding(V, D).to(device).to(dtype)
-    liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
-    torch_compile_emb = torch.compile(torch_emb)
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)

-    input_ids = 
torch.randint(0, V, (B, T), device=device) - def fwd(): - if provider == "liger": - return liger_emb(input_ids) - elif provider == "torch_compile": - return torch_compile_emb(input_ids) - else: - return torch_emb(input_ids) +def _resolve_model_config_embedding(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_embedding( + SingleBenchmarkRunInput( + x=input.x, + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "vocab_size": model_info["vocab_size"], + "hidden_size": model_info["hidden_size"], + "dtype": model_info["dtype"], + "BT": cfg["BT"], + "T": cfg["T"], + }, + ) + ) + + +def bench_speed_embedding_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + input_ids, fwd_fn = _resolve_model_config_embedding(input) + mode = input.kernel_operation_mode + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, quantiles=QUANTILES, rep=100) + elif mode == "backward": + output = fwd_fn() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: output.backward(torch.randn_like(output), retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[input_ids], + rep=100, + ) + elif mode == "full": + + def full(): + output = fwd_fn() + output.backward(torch.randn_like(output)) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_embedding_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + input_ids, fwd_fn = _resolve_model_config_embedding(input) def full(): - output = fwd() + output = fwd_fn() output.backward(torch.randn_like(output)) mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "embedding", - "x_name": "V", - "x_label": "embedding dimension", - "x_values": [2**i for i in range(10, 18)], - "kernel_providers": ["liger", "huggingface", "torch_compile"], - "extra_benchmark_configs": [ - # BERT - {"B": 32, "T": 512, "D": 768, "dtype": torch.float32}, - # Llama - {"B": 8, "T": 2048, "D": 4096, "dtype": torch.float32}, - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_embedding, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_embedding, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + T = 512 + BT = 2048 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + B = max(1, probe_bt // T) + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="huggingface", + extra_benchmark_config={ + "vocab_size": model_cfg.vocab_size, + "hidden_size": model_cfg.hidden_size, + "dtype": model_cfg.dtype, + "BT": B * T, + "T": T, + }, + ) + _, fwd_fn = _setup_embedding(probe_input) + return fwd_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: {"vocab_size": 
cfg.vocab_size, "hidden_size": cfg.hidden_size, "dtype": cfg.dtype} + for cfg in sweep.model_configs + } + BT = sweep.bt + + common_configs = { + "kernel_name": "embedding", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface", "torch_compile"], + "extra_benchmark_configs": [{"model_configs": model_configs_info, "BT": BT, "T": T}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_embedding_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_embedding_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + T = 512 + probe_bt = 2048 + + def _probe(): + B = max(1, probe_bt // T) + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="huggingface", + extra_benchmark_config={ + "vocab_size": model.vocab_size, + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "BT": B * T, + "T": T, + }, + ) + _, fwd_fn = _setup_embedding(probe_input) + return fwd_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "embedding", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, int(math.log2(max(1024, config.batch_size * config.seq_len))) + 1)], + "kernel_providers": ["liger", "huggingface", "torch_compile"], + "extra_benchmark_configs": [ + {"vocab_size": model.vocab_size, "hidden_size": model.hidden_size, "dtype": model.dtype, "T": T} + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_embedding, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_embedding, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_fused_add_rms_norm.py b/benchmark/scripts/benchmark_fused_add_rms_norm.py index 935871e90..f1e008afe 100644 --- a/benchmark/scripts/benchmark_fused_add_rms_norm.py +++ b/benchmark/scripts/benchmark_fused_add_rms_norm.py @@ -1,7 +1,14 @@ +import math + import torch import torch.nn as nn import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -56,56 +63,43 @@ def forward(self, hidden_states, residual): return self.weight * hidden_states.to(input_dtype), residual.to(input_dtype) -def bench_speed_fused_residual_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - N = input.x - provider = input.kernel_provider +def _setup_fused_add_rms_norm(input: SingleBenchmarkRunInput): + """Create input tensors and FusedAddRMSNorm layer from benchmark config.""" + cfg = input.extra_benchmark_config + hidden_size = cfg["hidden_size"] + eps = cfg["eps"] + x_shape 
= (input.x, hidden_size) + x = torch.randn(x_shape, dtype=cfg["dtype"], device=device, requires_grad=True) + r = torch.randn(x_shape, dtype=cfg["dtype"], device=device, requires_grad=True) + + if input.kernel_provider == "liger_fused_add_rms_norm": + layer = LigerFusedAddRMSNorm(hidden_size=hidden_size, eps=eps).to(device) + elif input.kernel_provider == "huggingface": + layer = NaiveAddRMSNorm(hidden_size=hidden_size, eps=eps).to(device) + elif input.kernel_provider == "liger_rms_norm": + layer = AddLigerRMSNorm(hidden_size=hidden_size, eps=eps).to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for FusedAddRMSNorm") + return x, r, layer + + +def bench_speed_fused_add_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, r, layer = _setup_fused_add_rms_norm(input) mode = input.kernel_operation_mode - - extra_benchmark_config = input.extra_benchmark_config - M = extra_benchmark_config["M"] - eps = extra_benchmark_config["eps"] - dtype = extra_benchmark_config["dtype"] - - x_shape = (M, N) - - # Fused Add RMS Norm - fused_add_rms_norm = LigerFusedAddRMSNorm(hidden_size=N, eps=eps).to(device) - # Naive implementation - naive_rms_norm = NaiveAddRMSNorm(hidden_size=N, eps=eps).to(device) - # LigerRMSNorm without fused residual addition - liger_rms_norm = AddLigerRMSNorm(hidden_size=N, eps=eps).to(device) - - x = torch.randn(x_shape, dtype=dtype, device=device) - r = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) ds = torch.randn_like(r) - x.requires_grad_(True) - r.requires_grad_(True) - # utility functions def y_fwd(): - if provider == "liger_fused_add_rms_norm": - return fused_add_rms_norm(x, r) - - if provider == "huggingface": - return naive_rms_norm(x, r) - - if provider == "liger_rms_norm": - return liger_rms_norm(x, r) + return layer(x, r) if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - y_fwd, - grad_to_none=[x, r], - rep=500, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, grad_to_none=[x, r], rep=100, quantiles=QUANTILES) elif mode == "backward": y, s = y_fwd() ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: torch.autograd.backward((y, s), (dy, ds), retain_graph=True), grad_to_none=[x, r], - rep=500, + rep=100, quantiles=QUANTILES, ) elif mode == "full": @@ -114,88 +108,207 @@ def full(): y, s = y_fwd() torch.autograd.backward((y, s), (dy, ds)) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x, r], rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_fused_add_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, r, layer = _setup_fused_add_rms_norm(input) + dy = torch.randn_like(x) + ds = torch.randn_like(r) + + def y_fwd(): + return layer(x, r) + + def full(): + y, s = y_fwd() + torch.autograd.backward((y, s), (dy, ds)) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_fused_add_rms_norm(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_fused_add_rms_norm( + SingleBenchmarkRunInput( + x=cfg["BT"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + 
"dtype": model_info["dtype"], + "eps": cfg["eps"], + }, + ) + ) + + +def bench_speed_fused_add_rms_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, r, layer = _resolve_model_config_fused_add_rms_norm(input) + mode = input.kernel_operation_mode + dy = torch.randn_like(x) + ds = torch.randn_like(r) + + def y_fwd(): + return layer(x, r) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, grad_to_none=[x, r], rep=100, quantiles=QUANTILES) + elif mode == "backward": + y, s = y_fwd() ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, + lambda: torch.autograd.backward((y, s), (dy, ds), retain_graph=True), grad_to_none=[x, r], - rep=500, + rep=100, quantiles=QUANTILES, ) + elif mode == "full": - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) - - -def bench_memory_fused_residual_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - N = input.x - provider = input.kernel_provider + def full(): + y, s = y_fwd() + torch.autograd.backward((y, s), (dy, ds)) - extra_benchmark_config = input.extra_benchmark_config - M = extra_benchmark_config["M"] - eps = extra_benchmark_config["eps"] - dtype = extra_benchmark_config["dtype"] + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x, r], rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") - x_shape = (M, N) + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) - fused_add_rms_norm = LigerFusedAddRMSNorm(hidden_size=N, eps=eps).to(device) - naive_rms_norm = NaiveAddRMSNorm(hidden_size=N, eps=eps).to(device) - liger_rms_norm = AddLigerRMSNorm(hidden_size=N, eps=eps).to(device) - x = torch.randn(x_shape, dtype=dtype, device=device) - r = torch.randn(x_shape, dtype=dtype, device=device) +def bench_memory_fused_add_rms_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, r, layer = _resolve_model_config_fused_add_rms_norm(input) dy = torch.randn_like(x) ds = torch.randn_like(r) - x.requires_grad_(True) - r.requires_grad_(True) - # utility functions def y_fwd(): - if provider == "liger_fused_add_rms_norm": - return fused_add_rms_norm(x, r) - if provider == "huggingface": - return naive_rms_norm(x, r) - if provider == "liger_rms_norm": - return liger_rms_norm(x, r) + return layer(x, r) def full(): y, s = y_fwd() torch.autograd.backward((y, s), (dy, ds)) mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) - - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "fused_add_rms_norm", - "x_name": "H", - "x_label": "hidden size", - "x_values": [2**i for i in range(10, 16)], - "kernel_providers": ["liger_fused_add_rms_norm", "huggingface", "liger_rms_norm"], - "extra_benchmark_configs": [{"M": 2048, "dtype": torch.float32, "eps": 1e-6}], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_fused_residual_rms_norm, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_fused_residual_rms_norm, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + def 
_probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "dtype": model_cfg.dtype, + "eps": 1e-6, + }, + ) + x, r, layer = _setup_fused_add_rms_norm(probe_input) + y, s = layer(x, r) + return y + s # combine for backward probe + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "dtype": cfg.dtype, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "fused_add_rms_norm", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger_fused_add_rms_norm", "huggingface", "liger_rms_norm"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "BT": sweep.bt, + "eps": 1e-6, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_add_rms_norm_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_add_rms_norm_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 1024 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "eps": 1e-6, + }, + ) + x, r, layer = _setup_fused_add_rms_norm(probe_input) + y, s = layer(x, r) + return y + s + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "fused_add_rms_norm", + "x_name": "BT", + "x_label": "B * T", + "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)], + "kernel_providers": ["liger_fused_add_rms_norm", "huggingface", "liger_rms_norm"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "eps": 1e-6, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_add_rms_norm, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_add_rms_norm, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py index 4d36a66a6..70755af64 100644 --- a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +++ b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py @@ -1,6 +1,13 @@ +import math + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import 
SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -45,92 +52,103 @@ def forward(self, x, y): return self.ce_loss(self.lin.weight, x, y) -############################################################################# -# Test the memory consumption of the linear fused cross entropy loss -############################################################################# - - -def bench_memory_fused_linear_cross_entropy( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: +def _setup_fused_linear_cross_entropy(input: SingleBenchmarkRunInput): + """Create input tensor, target, and fused linear CE from benchmark config.""" + cfg = input.extra_benchmark_config + H = cfg["hidden_size"] + V = cfg["vocab_size"] + dtype = cfg["dtype"] BT = input.x - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - provider = input.kernel_provider - lm_head_ce = None - if provider == "liger": + _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1) + + if input.kernel_provider == "liger": lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) - elif provider == "liger-fp32-accum": + elif input.kernel_provider == "liger-fp32-accum": lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device) - else: + elif input.kernel_provider == "huggingface": lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for FusedLinearCrossEntropy") + return _input, target, lm_head_ce - _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device) - target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1) + +def bench_speed_fused_linear_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, lm_head_ce = _setup_fused_linear_cross_entropy(input) + mode = input.kernel_operation_mode def fwd(): return lm_head_ce(_input, target) - def full(): + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) + elif mode == "no-grad-forward": + with torch.no_grad(): + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) + elif mode == "backward": y = fwd() - y.backward() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + def full(): + y = fwd() + y.backward() - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) -# ############################################################################# -# # Test the speed of the fused linear cross entropy loss -# ############################################################################# +def bench_memory_fused_linear_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, lm_head_ce = _setup_fused_linear_cross_entropy(input) -def bench_speed_fused_linear_cross_entropy( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: - BT = input.x - H = 
input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - provider = input.kernel_provider - mode = input.kernel_operation_mode + def full(): + y = lm_head_ce(_input, target) + y.backward() - lm_head_ce = None - if provider == "liger": - lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) - elif provider == "liger-fp32-accum": - lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device) - else: - lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_fused_linear_cross_entropy(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_fused_linear_cross_entropy( + SingleBenchmarkRunInput( + x=cfg["BT"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "vocab_size": model_info["vocab_size"], + "dtype": model_info["dtype"], + }, + ) + ) - _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device) - target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1) + +def bench_speed_fused_linear_cross_entropy_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, lm_head_ce = _resolve_model_config_fused_linear_cross_entropy(input) + mode = input.kernel_operation_mode def fwd(): return lm_head_ce(_input, target) if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - rep=100, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) elif mode == "no-grad-forward": with torch.no_grad(): - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - rep=100, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) elif mode == "backward": y = fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(retain_graph=True), grad_to_none=[_input], @@ -143,42 +161,135 @@ def full(): y = fwd() y.backward() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - rep=100, - quantiles=QUANTILES, - ) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_fused_linear_cross_entropy_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, lm_head_ce = _resolve_model_config_fused_linear_cross_entropy(input) + + def full(): + y = lm_head_ce(_input, target) + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "fused_linear_cross_entropy", - "x_name": "BT", - "x_label": "B x T", - "x_values": [2**i for i in range(12, 16)], - "kernel_providers": ["liger", "liger-fp32-accum", "huggingface"], - "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}], - "overwrite": args.overwrite, - 
} - - run_benchmarks( - bench_test_fn=bench_speed_fused_linear_cross_entropy, - kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_fused_linear_cross_entropy, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "vocab_size": model_cfg.vocab_size, + "dtype": model_cfg.dtype, + }, + ) + _input, target, lm_head_ce = _setup_fused_linear_cross_entropy(probe_input) + return lm_head_ce(_input, target) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "vocab_size": cfg.vocab_size, + "dtype": cfg.dtype, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "fused_linear_cross_entropy", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "liger-fp32-accum", "huggingface"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "BT": sweep.bt, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_cross_entropy_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_cross_entropy_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 1024 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "vocab_size": model.vocab_size, + "dtype": model.dtype, + }, + ) + _input, target, lm_head_ce = _setup_fused_linear_cross_entropy(probe_input) + return lm_head_ce(_input, target) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "fused_linear_cross_entropy", + "x_name": "BT", + "x_label": "B * T", + "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)], + "kernel_providers": ["liger", "liger-fp32-accum", "huggingface"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "vocab_size": model.vocab_size, + "dtype": model.dtype, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_cross_entropy, + kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_cross_entropy, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py 
b/benchmark/scripts/benchmark_fused_linear_jsd.py index ac62863b2..15bb0df65 100644 --- a/benchmark/scripts/benchmark_fused_linear_jsd.py +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -1,6 +1,13 @@ +import math + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -112,19 +119,13 @@ def forward(self, student_input, teacher_input, label=None): ) -############################################################################# -# Test the memory consumption of the fused linear JSD -############################################################################# - - -def bench_memory_fused_linear_jsd( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: +def _setup_fused_linear_jsd(input: SingleBenchmarkRunInput): + """Create input tensors and fused linear JSD from benchmark config.""" + cfg = input.extra_benchmark_config + H = cfg["hidden_size"] + V = cfg["vocab_size"] + dtype = cfg["dtype"] BT = input.x - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - provider = input.kernel_provider torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) @@ -140,76 +141,88 @@ def bench_memory_fused_linear_jsd( student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) teacher_input = torch.rand(BT, H, dtype=dtype, device=device) + if input.kernel_provider == "liger": + lm_head = liger_lm_head_jsd + elif input.kernel_provider == "torch": + lm_head = torch_lm_head_jsd + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for FusedLinearJSD") + + return student_input, teacher_input, lm_head + + +def bench_speed_fused_linear_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + student_input, teacher_input, lm_head = _setup_fused_linear_jsd(input) + mode = input.kernel_operation_mode + def fwd(): - if provider == "liger": - return liger_lm_head_jsd(student_input, teacher_input) - elif provider == "torch": - return torch_lm_head_jsd(student_input, teacher_input) + return lm_head(student_input, teacher_input) - def full(): + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) + elif mode == "backward": y = fwd() - y.backward() - - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[student_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + def full(): + y = fwd() + y.backward() -# ############################################################################# -# # Test the speed of the fused linear JSD -# ############################################################################# + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, 
y_50=ms_50, y_80=ms_80) -def bench_speed_fused_linear_jsd( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: - BT = input.x - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - mode = input.kernel_operation_mode - dtype = input.extra_benchmark_config["dtype"] - provider = input.kernel_provider +def bench_memory_fused_linear_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + student_input, teacher_input, lm_head = _setup_fused_linear_jsd(input) - torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) - liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + def full(): + y = lm_head(student_input, teacher_input) + y.backward() - # init the linear in all FusedLinearJSDs with the same weights - torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( - V, H, device=device, dtype=dtype - ) - torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( - V, H, device=device, dtype=dtype + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_fused_linear_jsd(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_fused_linear_jsd( + SingleBenchmarkRunInput( + x=cfg["BT"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "vocab_size": model_info["vocab_size"], + "dtype": model_info["dtype"], + }, + ) ) - student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) - teacher_input = torch.rand(BT, H, dtype=dtype, device=device) + +def bench_speed_fused_linear_jsd_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + student_input, teacher_input, lm_head = _resolve_model_config_fused_linear_jsd(input) + mode = input.kernel_operation_mode def fwd(): - if provider == "liger": - return liger_lm_head_jsd(student_input, teacher_input) - elif provider == "torch": - return torch_lm_head_jsd(student_input, teacher_input) + return lm_head(student_input, teacher_input) if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - rep=100, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) elif mode == "backward": y = fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(retain_graph=True), - grad_to_none=[ - student_input, - torch_lm_head_jsd.student_lin.weight, - torch_lm_head_jsd.teacher_lin.weight, - ], + grad_to_none=[student_input], rep=100, quantiles=QUANTILES, ) @@ -219,42 +232,135 @@ def full(): y = fwd() y.backward() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - rep=100, - quantiles=QUANTILES, - ) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_fused_linear_jsd_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + student_input, teacher_input, lm_head = _resolve_model_config_fused_linear_jsd(input) + + def full(): + y = 
lm_head(student_input, teacher_input) + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "fused_linear_jsd", - "x_name": "BT", - "x_label": "B x T", - "x_values": [2**i for i in range(10, 14)], - "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_fused_linear_jsd, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_fused_linear_jsd, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "vocab_size": model_cfg.vocab_size, + "dtype": model_cfg.dtype, + }, + ) + student_input, teacher_input, lm_head = _setup_fused_linear_jsd(probe_input) + return lm_head(student_input, teacher_input) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "vocab_size": cfg.vocab_size, + "dtype": cfg.dtype, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "fused_linear_jsd", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "BT": sweep.bt, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_jsd_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_jsd_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 1024 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "vocab_size": model.vocab_size, + "dtype": model.dtype, + }, + ) + student_input, teacher_input, lm_head = _setup_fused_linear_jsd(probe_input) + return lm_head(student_input, teacher_input) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "fused_linear_jsd", + "x_name": "BT", + "x_label": "B * T", + "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "vocab_size": model.vocab_size, + "dtype": model.dtype, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + 
bench_test_fn=bench_speed_fused_linear_jsd, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_jsd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_fused_neighborhood_attention.py b/benchmark/scripts/benchmark_fused_neighborhood_attention.py index 515d65cad..d5cbd0db6 100644 --- a/benchmark/scripts/benchmark_fused_neighborhood_attention.py +++ b/benchmark/scripts/benchmark_fused_neighborhood_attention.py @@ -1,8 +1,15 @@ import math +import os +import sys import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -15,6 +22,8 @@ device = infer_device() +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + class TorchNeighborhoodAttention(torch.nn.Module): def __init__( @@ -93,21 +102,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return output -def bench_speed_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - seq_len = input.x - provider = input.kernel_provider - mode = input.kernel_operation_mode - - extra_benchmark_config = input.extra_benchmark_config - batch_size = extra_benchmark_config["batch_size"] - hidden_size = extra_benchmark_config["hidden_size"] - num_heads = extra_benchmark_config["num_heads"] - kernel_size = extra_benchmark_config["kernel_size"] - dilation = extra_benchmark_config["dilation"] - bias = extra_benchmark_config["bias"] - dtype = extra_benchmark_config["dtype"] - - x_shape = (batch_size, seq_len, hidden_size) +def _setup_fused_neighborhood_attention(input: SingleBenchmarkRunInput): + """Create input tensors and fused neighborhood attention from benchmark config.""" + cfg = input.extra_benchmark_config + hidden_size = cfg["hidden_size"] + num_heads = cfg["num_heads"] + kernel_size = cfg.get("kernel_size", 7) + dilation = cfg.get("dilation", 1) + bias = cfg.get("bias", True) + dtype = cfg["dtype"] + batch_size = cfg.get("batch_size", 2) + seq_len = cfg.get("seq_len", input.x) liger_attn = ( LigerFusedNeighborhoodAttention( @@ -140,34 +145,38 @@ def bench_speed_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> torch_attn.k_proj.weight.copy_(liger_attn.k_proj.weight) torch_attn.v_proj.weight.copy_(liger_attn.v_proj.weight) torch_attn.out_proj.weight.copy_(liger_attn.out_proj.weight) - if bias: torch_attn.q_proj.bias.copy_(liger_attn.q_proj.bias) torch_attn.k_proj.bias.copy_(liger_attn.k_proj.bias) torch_attn.v_proj.bias.copy_(liger_attn.v_proj.bias) torch_attn.out_proj.bias.copy_(liger_attn.out_proj.bias) - x = torch.randn(x_shape, dtype=dtype, device=device) + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device, requires_grad=True) dy = torch.randn_like(x) - x.requires_grad_(True) - def fwd(): - if provider == "liger": - return liger_attn(x) - elif provider == "torch": - return torch_attn(x) + if input.kernel_provider == "liger": + fwd_fn = lambda: liger_attn(x) + elif input.kernel_provider == "torch": + fwd_fn 
= lambda: torch_attn(x)
+    else:
+        raise ValueError(f"Invalid provider: {input.kernel_provider} for fused neighborhood attention")
+
+    return x, dy, fwd_fn
+
+
+def bench_speed_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, dy, fwd_fn = _setup_fused_neighborhood_attention(input)
+    mode = input.kernel_operation_mode
 
-    print(f"Starting Warmup for input size: {x_shape}")
-    _ = fwd()
+    # Warmup: run one forward (and backward, if timed) pass before benchmarking
+    y = fwd_fn()
     if mode in ("backward", "full"):
-        y = _
-        y.backward(dy, retain_graph=True)
-        print("Done Warmup")
+        y.backward(dy, retain_graph=True)
 
     if mode == "forward":
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES)
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[x], rep=100, quantiles=QUANTILES)
     elif mode == "backward":
-        y = fwd()
+        y = fwd_fn()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: y.backward(dy, retain_graph=True),
             grad_to_none=[x],
@@ -177,191 +186,218 @@ def fwd():
     elif mode == "full":
 
         def full():
-            y = fwd()
+            y = fwd_fn()
             y.backward(dy, retain_graph=True)
 
         ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES)
-
-    return SingleBenchmarkRunOutput(
-        y_20=ms_20,
-        y_50=ms_50,
-        y_80=ms_80,
-    )
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
 
 
 def bench_memory_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    seq_len = input.x
-    provider = input.kernel_provider
+    x, dy, fwd_fn = _setup_fused_neighborhood_attention(input)
 
-    extra_benchmark_config = input.extra_benchmark_config
-    batch_size = extra_benchmark_config["batch_size"]
-    hidden_size = extra_benchmark_config["hidden_size"]
-    num_heads = extra_benchmark_config["num_heads"]
-    kernel_size = extra_benchmark_config["kernel_size"]
-    dilation = extra_benchmark_config["dilation"]
-    bias = extra_benchmark_config["bias"]
-    dtype = extra_benchmark_config["dtype"]
-
-    x_shape = (batch_size, seq_len, hidden_size)
+    def full():
+        y = fwd_fn()
+        y.backward(dy, retain_graph=True)
 
-    liger_attn = (
-        LigerFusedNeighborhoodAttention(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            kernel_size=kernel_size,
-            dilation=dilation,
-            bias=bias,
-            dropout=0.0,
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
+
+
+def _resolve_model_config_fused_neighborhood_attention(input: SingleBenchmarkRunInput):
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_fused_neighborhood_attention(
+        SingleBenchmarkRunInput(
+            x=input.x,
+            kernel_provider=input.kernel_provider,
+            extra_benchmark_config={
+                "hidden_size": model_info["hidden_size"],
+                "num_heads": model_info["num_heads"],
+                "dtype": model_info["dtype"],
+                "seq_len": cfg["seq_len"],
+                "batch_size": cfg["batch_size"],
+                "kernel_size": cfg.get("kernel_size", 7),
+                "dilation": cfg.get("dilation", 1),
+                "bias": cfg.get("bias", True),
+            },
         )
-        .to(device)
-        .to(dtype)
     )
-    torch_attn = (
-        TorchNeighborhoodAttention(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            kernel_size=kernel_size,
-            dilation=dilation,
-            bias=bias,
-            dropout=0.0,
+
+def bench_speed_fused_neighborhood_attention_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, dy, fwd_fn = _resolve_model_config_fused_neighborhood_attention(input)
+    mode = input.kernel_operation_mode
+
+    y = fwd_fn()
+    if mode in ("backward", "full"):
+        y.backward(dy, retain_graph=True)
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[x], rep=100, quantiles=QUANTILES)
+    elif mode == "backward":
+        y = fwd_fn()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(dy, retain_graph=True),
+            grad_to_none=[x],
+            rep=100,
+            quantiles=QUANTILES,
        )
-        .to(device)
-        .to(dtype)
-    )
+    elif mode == "full":
 
-    with torch.no_grad():
-        torch_attn.q_proj.weight.copy_(liger_attn.q_proj.weight)
-        torch_attn.k_proj.weight.copy_(liger_attn.k_proj.weight)
-        torch_attn.v_proj.weight.copy_(liger_attn.v_proj.weight)
-        torch_attn.out_proj.weight.copy_(liger_attn.out_proj.weight)
+        def full():
+            y = fwd_fn()
+            y.backward(dy, retain_graph=True)
 
-        if bias:
-            torch_attn.q_proj.bias.copy_(liger_attn.q_proj.bias)
-            torch_attn.k_proj.bias.copy_(liger_attn.k_proj.bias)
-            torch_attn.v_proj.bias.copy_(liger_attn.v_proj.bias)
-            torch_attn.out_proj.bias.copy_(liger_attn.out_proj.bias)
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
 
-    x = torch.randn(x_shape, dtype=dtype, device=device)
-    dy = torch.randn_like(x)
-    x.requires_grad_(True)
 
-    def fwd():
-        if provider == "liger":
-            return liger_attn(x)
-        elif provider == "torch":
-            return torch_attn(x)
+def bench_memory_fused_neighborhood_attention_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, dy, fwd_fn = _resolve_model_config_fused_neighborhood_attention(input)
 
     def full():
-        y = fwd()
+        y = fwd_fn()
         y.backward(dy, retain_graph=True)
 
     mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
-
-    return SingleBenchmarkRunOutput(
-        y_20=mem_20,
-        y_50=mem_50,
-        y_80=mem_80,
-    )
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
 
 
 if __name__ == "__main__":
     args = parse_benchmark_script_args()
 
-    common_configs = {
-        "kernel_name": "fused_neighborhood_attention",
-        "x_name": "seq_len",
-        "x_label": "sequence length",
-        "x_values": [2**i for i in range(6, 13)],
-        "kernel_providers": ["liger", "torch"],
-        "extra_benchmark_configs": [
-            {
-                "batch_size": 2,
-                "hidden_size": 512,
-                "num_heads": 8,
-                "kernel_size": 7,
-                "dilation": 1,
-                "bias": True,
-                "dtype": torch.float32,
-            },
-            {
-                "batch_size": 4,
-                "hidden_size": 768,
-                "num_heads": 12,
-                "kernel_size": 7,
-                "dilation": 1,
-                "bias": True,
-                "dtype": torch.float32,
-            },
-            {
-                "batch_size": 2,
-                "hidden_size": 1024,
-                "num_heads": 16,
-                "kernel_size": 9,
-                "dilation": 1,
-                "bias": True,
-                "dtype": torch.float32,
-            },
-            {
-                "batch_size": 2,
-                "hidden_size": 512,
-                "num_heads": 8,
-                "kernel_size": 7,
-                "dilation": 2,
-                "bias": True,
-                "dtype": torch.float32,
-            },
-            {
-                "batch_size": 2,
-                "hidden_size": 512,
-                "num_heads": 8,
-                "kernel_size": 7,
-                "dilation": 1,
-                "bias": True,
-                "dtype": torch.bfloat16,
-            },
-            {
-                "batch_size": 4,
-                "hidden_size": 768,
-                "num_heads": 12,
-                "kernel_size": 7,
-                "dilation": 1,
-                "bias": True,
-                "dtype": torch.bfloat16,
-            },
-            {
-                "batch_size": 2,
-                "hidden_size": 1024,
-                "num_heads": 16,
-                "kernel_size": 9,
-                "dilation": 1,
-                "bias": True,
-                "dtype": torch.bfloat16,
-            },
-            {
-                "batch_size": 2,
-                "hidden_size": 512,
-                "num_heads": 8,
-                "kernel_size": 7,
-                "dilation": 2,
-                "bias": True,
-                "dtype": torch.bfloat16,
-            },
-        ],
-    }
-
-    run_benchmarks(
-        bench_test_fn=bench_speed_fused_neighborhood_attention,
-        
kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - - run_benchmarks( - bench_test_fn=bench_memory_fused_neighborhood_attention, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + seq_len = 256 + batch_size = 2 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "num_heads": model_cfg.num_attention_heads, + "dtype": model_cfg.dtype, + "seq_len": seq_len, + "batch_size": batch_size, + "kernel_size": 7, + "dilation": 1, + "bias": True, + }, + ) + _, _, fwd_fn = _setup_fused_neighborhood_attention(probe_input) + return fwd_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "num_heads": cfg.num_attention_heads, + "dtype": cfg.dtype, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "fused_neighborhood_attention", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "seq_len": seq_len, + "batch_size": batch_size, + "kernel_size": 7, + "dilation": 1, + "bias": True, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_neighborhood_attention_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_neighborhood_attention_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + batch_size = 2 + probe_seq_len = 256 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "num_heads": model.num_attention_heads, + "dtype": model.dtype, + "seq_len": probe_seq_len, + "batch_size": batch_size, + "kernel_size": 7, + "dilation": 1, + "bias": True, + }, + ) + _, _, fwd_fn = _setup_fused_neighborhood_attention(probe_input) + return fwd_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_seq_len + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "fused_neighborhood_attention", + "x_name": "seq_len", + "x_label": "sequence length", + "x_values": [2**i for i in range(6, int(math.log2(max(64, config.seq_len))) + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "num_heads": model.num_attention_heads, + "dtype": model.dtype, + "batch_size": batch_size, + "kernel_size": 7, + "dilation": 1, + "bias": True, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_neighborhood_attention, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + 
bench_test_fn=bench_memory_fused_neighborhood_attention, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_group_norm.py b/benchmark/scripts/benchmark_group_norm.py index 5a8bf37f4..4905ca43f 100644 --- a/benchmark/scripts/benchmark_group_norm.py +++ b/benchmark/scripts/benchmark_group_norm.py @@ -1,12 +1,18 @@ +import math + import torch -import triton -from utils import QUANTILES +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput -from utils import _test_memory from utils import parse_benchmark_script_args from utils import run_benchmarks +from utils import run_memory_benchmark +from utils import run_speed_benchmark from liger_kernel.transformers.group_norm import LigerGroupNorm from liger_kernel.utils import infer_device @@ -14,124 +20,197 @@ device = infer_device() -def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - C = input.x - provider = input.kernel_provider - mode = input.kernel_operation_mode - extra_benchmark_config = input.extra_benchmark_config - M = extra_benchmark_config["M"] - H = extra_benchmark_config["H"] - channels_per_group = extra_benchmark_config["channels_per_group"] - eps = extra_benchmark_config["eps"] - dtype = extra_benchmark_config["dtype"] - - x_shape = (M, C, H) - triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device) - torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device) - - x = torch.randn(x_shape, dtype=dtype, device=device) - dy = torch.randn_like(x) - x.requires_grad_(True) - - def y_fwd(): - if provider == "liger": - return triton_ln(x) - if provider == "huggingface": - return torch_ln(x) - - if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500) - elif mode == "backward": - y = y_fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(dy, retain_graph=True), - quantiles=QUANTILES, - grad_to_none=[x], - rep=500, - ) - elif mode == "full": - - def full(): - y = y_fwd() - y.backward(dy, retain_graph=True) +def _setup_group_norm(input: SingleBenchmarkRunInput): + """Create input tensor and GroupNorm layer from benchmark config.""" + cfg = input.extra_benchmark_config + num_channels = cfg["num_channels"] + channels_per_group = cfg["channels_per_group"] + H = cfg["H"] + eps = cfg["eps"] + num_groups = num_channels // channels_per_group + x = torch.randn( + input.x, + num_channels, + H, + device=device, + dtype=cfg["dtype"], + requires_grad=True, + ) + if input.kernel_provider == "liger": + layer = LigerGroupNorm(num_channels=num_channels, num_groups=num_groups, eps=eps).to(device) + elif input.kernel_provider == "huggingface": + layer = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps).to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for GroupNorm") + return x, layer - ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, 
- ) +def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _setup_group_norm(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - C = input.x - provider = input.kernel_provider - extra_benchmark_config = input.extra_benchmark_config - M = extra_benchmark_config["M"] - H = extra_benchmark_config["H"] - channels_per_group = extra_benchmark_config["channels_per_group"] - eps = extra_benchmark_config["eps"] - dtype = extra_benchmark_config["dtype"] - - x_shape = (M, C, H) - triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device) - torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device) - - x = torch.randn(x_shape, dtype=dtype, device=device) - dy = torch.randn_like(x) - x.requires_grad_(True) - - def y_fwd(): - if provider == "liger": - return triton_ln(x) - if provider == "huggingface": - return torch_ln(x) - - def full(): - y = y_fwd() - y.backward(dy, retain_graph=True) - - mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, + x, layer = _setup_group_norm(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) + + +def _resolve_model_config_group_norm(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_group_norm( + SingleBenchmarkRunInput( + x=cfg["M"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "num_channels": model_info["hidden_size"], + "channels_per_group": cfg["channels_per_group"], + "H": cfg["H"], + "dtype": model_info["dtype"], + "eps": cfg["eps"], + }, + ) ) +def bench_speed_group_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_group_norm(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) + + +def bench_memory_group_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_group_norm(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) + + if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "group_norm", - "x_name": "C", - "x_label": "num_channels", - "x_values": [2**i for i in range(5, 12)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "M": 128, - "H": 512, - "channels_per_group": 4, - "dtype": torch.float32, - "eps": 1e-6, + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + channels_per_group = 4 + H = 512 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + M = max(1, probe_bt // H) + probe_input = SingleBenchmarkRunInput( + x=M, + kernel_provider="huggingface", + extra_benchmark_config={ + "num_channels": model_cfg.hidden_size, + "channels_per_group": channels_per_group, + "H": H, + "dtype": model_cfg.dtype, + "eps": 1e-6, + }, + ) + x, layer = _setup_group_norm(probe_input) + return layer(x) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "dtype": 
cfg.dtype, } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_group_norm, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_group_norm, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + for cfg in sweep.model_configs + } + + M = max(1, sweep.bt // H) + + common_configs = { + "kernel_name": "group_norm", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "M": M, + "channels_per_group": channels_per_group, + "H": H, + "eps": 1e-6, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_group_norm_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_group_norm_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + channels_per_group = 4 + H = 512 + probe_bt = 1024 + + def _probe(): + M = max(1, probe_bt // H) + probe_input = SingleBenchmarkRunInput( + x=M, + kernel_provider="huggingface", + extra_benchmark_config={ + "num_channels": model.hidden_size, + "channels_per_group": channels_per_group, + "H": H, + "dtype": model.dtype, + "eps": 1e-6, + }, + ) + x, layer = _setup_group_norm(probe_input) + return layer(x) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "group_norm", + "x_name": "M", + "x_label": "batch size (M)", + "x_values": [2**i for i in range(2, int(math.log2(config.batch_size * config.seq_len // H)) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "num_channels": model.hidden_size, + "channels_per_group": channels_per_group, + "H": H, + "dtype": model.dtype, + "eps": 1e-6, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_group_norm, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_group_norm, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_grpo_loss.py b/benchmark/scripts/benchmark_grpo_loss.py index 497d8692c..df3244c7b 100644 --- a/benchmark/scripts/benchmark_grpo_loss.py +++ b/benchmark/scripts/benchmark_grpo_loss.py @@ -1,9 +1,15 @@ +import math import os import sys import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -18,217 +24,257 @@ 
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -############################################################################# -# Test the memory consumption of the linear fused GRPO loss -############################################################################# - - -def bench_memory_fused_linear_grpo_loss( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: +def _setup_grpo_loss(input: SingleBenchmarkRunInput): + """Create input tensors and GRPO loss from benchmark config.""" from test.chunked_loss.test_grpo_loss import LigerLMHeadGRPO from test.chunked_loss.test_grpo_loss import TorchLMHeadGRPO + cfg = input.extra_benchmark_config + H = cfg["hidden_size"] + V = cfg["vocab_size"] + dtype = cfg["dtype"] + importance_sampling_level = cfg["importance_sampling_level"] B = input.x - T = input.extra_benchmark_config["T"] - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"] - provider = input.kernel_provider - - # Instantiate once and retrieve the first output only - torch_lm_head_grpo = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to( - device - ) - liger_lm_head_grpo = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to( - device - ) + T = cfg["T"] - # Create inputs _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) selected_token_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device) attention_mask = torch.ones(B, T, device=device) advantages = torch.randn(B, dtype=dtype, device=device) ref_input = torch.randn(B, T, H, dtype=dtype, device=device) - torch_fwd = lambda: torch_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[ - 0 - ] - liger_fwd = lambda: liger_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[ - 0 - ] + if input.kernel_provider == "liger": + loss_module = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to( + device + ) + elif input.kernel_provider == "torch": + loss_module = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to( + device + ) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for GRPOLoss") - def fwd(): - if provider == "liger": - return liger_fwd() - elif provider == "torch": - return torch_fwd() + fwd_fn = lambda: loss_module(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[0] + return _input, fwd_fn - def full(): - y = fwd() - y.backward() - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) +def bench_speed_grpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _setup_grpo_loss(input) + mode = input.kernel_operation_mode + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd_fn() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), grad_to_none=[_input], rep=100, quantiles=QUANTILES + ) + elif mode == "full": -############################################################################# -# Test the speed of the fused linear GRPO loss 
-############################################################################# + def full(): + y = fwd_fn() + y.backward() + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) -def bench_speed_fused_linear_grpo_loss( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_grpo_loss import LigerLMHeadGRPO - from test.chunked_loss.test_grpo_loss import TorchLMHeadGRPO - B = input.x - T = input.extra_benchmark_config["T"] - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"] - provider = input.kernel_provider - mode = input.kernel_operation_mode +def bench_memory_grpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _setup_grpo_loss(input) - # Instantiate once and retrieve the first output only - torch_lm_head_grpo = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to( - device - ) - liger_lm_head_grpo = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to( - device - ) + def full(): + y = fwd_fn() + y.backward() - # Create inputs - _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) - selected_token_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device) - attention_mask = torch.ones(B, T, device=device) - advantages = torch.randn(B, dtype=dtype, device=device) - ref_input = torch.randn(B, T, H, dtype=dtype, device=device) + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_grpo_loss(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_grpo_loss( + SingleBenchmarkRunInput( + x=cfg["B"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "vocab_size": model_info["vocab_size"], + "dtype": model_info["dtype"], + "T": cfg["T"], + "importance_sampling_level": cfg["importance_sampling_level"], + }, + ) + ) - torch_fwd = lambda: torch_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[ - 0 - ] - liger_fwd = lambda: liger_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[ - 0 - ] - def fwd(): - if provider == "liger": - return liger_fwd() - elif provider == "torch": - return torch_fwd() +def bench_speed_grpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _resolve_model_config_grpo_loss(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - rep=100, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES) elif mode == "backward": - y = fwd() - + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(retain_graph=True), - grad_to_none=[_input], - rep=100, - quantiles=QUANTILES, + lambda: y.backward(retain_graph=True), grad_to_none=[_input], rep=100, quantiles=QUANTILES ) elif mode == "full": def full(): - y = fwd() + y = fwd_fn() y.backward() - ms_50, 
ms_20, ms_80 = triton.testing.do_bench(
-            full,
-            rep=100,
-            quantiles=QUANTILES,
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
+
+
+def bench_memory_grpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, fwd_fn = _resolve_model_config_grpo_loss(input)
+
+    def full():
+        y = fwd_fn()
+        y.backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
+
+
+def _run_grpo_benchmarks(args, importance_sampling_level, kernel_name_suffix):
+    """Run speed and memory benchmarks for a given importance_sampling_level ("token" for GRPO, "sequence" for GSPO)."""
+    kernel_name = f"fused_linear_grpo_loss_{kernel_name_suffix}"
+
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+        T = 1024
+
+        def _probe_factory(model_cfg, probe_bt):
+            def _probe():
+                B = max(1, probe_bt // T)
+                probe_input = SingleBenchmarkRunInput(
+                    x=B,
+                    kernel_provider="torch",
+                    extra_benchmark_config={
+                        "hidden_size": model_cfg.hidden_size,
+                        "vocab_size": model_cfg.vocab_size,
+                        "dtype": model_cfg.dtype,
+                        "T": T,
+                        "importance_sampling_level": importance_sampling_level,
+                    },
+                )
+                _, fwd_fn = _setup_grpo_loss(probe_input)
+                return fwd_fn()
+
+            return _probe
+
+        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+        model_configs_info = {
+            cfg.name: {"hidden_size": cfg.hidden_size, "vocab_size": cfg.vocab_size, "dtype": cfg.dtype}
+            for cfg in sweep.model_configs
+        }
+        B = max(1, sweep.bt // T)
+
+        common_configs = {
+            "kernel_name": kernel_name,
+            "x_name": "model_config",
+            "x_label": "model configuration",
+            "x_values": [cfg.name for cfg in sweep.model_configs],
+            "kernel_providers": ["liger", "torch"],
+            "extra_benchmark_configs": [
+                {
+                    "model_configs": model_configs_info,
+                    "B": B,
+                    "T": T,
+                    "importance_sampling_level": importance_sampling_level,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_grpo_loss_model_config,
+            kernel_operation_modes=["forward", "backward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_grpo_loss_model_config,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
+    else:
+        model = get_benchmark_model_config(args.model)
+        T = 1024
+        probe_bt = 1024
+
+        def _probe():
+            B = max(1, probe_bt // T)
+            probe_input = SingleBenchmarkRunInput(
+                x=B,
+                kernel_provider="torch",
+                extra_benchmark_config={
+                    "hidden_size": model.hidden_size,
+                    "vocab_size": model.vocab_size,
+                    "dtype": model.dtype,
+                    "T": T,
+                    "importance_sampling_level": importance_sampling_level,
+                },
+            )
+            _, fwd_fn = _setup_grpo_loss(probe_input)
+            return fwd_fn()
+
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_bt
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": kernel_name,
+            "x_name": "B",
+            "x_label": "Batch Size (B)",
+            "x_values": [2**i for i in range(1, int(math.log2(max(2, config.batch_size * config.seq_len // T))) + 1)],
+            "kernel_providers": ["liger", "torch"],
+            "extra_benchmark_configs": [
+                {
+                    "hidden_size": model.hidden_size,
+                    "vocab_size": model.vocab_size,
+                    "dtype": model.dtype,
+                    "T": T,
+                    
"importance_sampling_level": importance_sampling_level, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_grpo_loss, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_grpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, ) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) if __name__ == "__main__": args = parse_benchmark_script_args() # Benchmark token-level importance sampling (original GRPO) - token_configs = { - "kernel_name": "fused_linear_grpo_loss_token", - "x_name": "B", - "x_label": "B", - "x_values": [2**i for i in range(1, 5)], - "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [ - { - "T": 1024, - "H": 4096, - "V": 128256, - "importance_sampling_level": "token", - "dtype": torch.bfloat16, - } - ], - "overwrite": args.overwrite, - } - - # Benchmark sequence-level importance sampling (GSPO) - sequence_configs = { - "kernel_name": "fused_linear_grpo_loss_sequence", - "x_name": "B", - "x_label": "B", - "x_values": [2**i for i in range(1, 5)], - "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [ - { - "T": 1024, - "H": 4096, - "V": 128256, - "importance_sampling_level": "sequence", - "dtype": torch.bfloat16, - } - ], - "overwrite": args.overwrite, - } - - # Run benchmarks for token-level (GRPO) print("Benchmarking GRPO (token-level importance sampling)...") - run_benchmarks( - bench_test_fn=bench_speed_fused_linear_grpo_loss, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **token_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_fused_linear_grpo_loss, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **token_configs, - ) + _run_grpo_benchmarks(args, importance_sampling_level="token", kernel_name_suffix="token") - # Run benchmarks for sequence-level (GSPO) + # Benchmark sequence-level importance sampling (GSPO) print("Benchmarking GSPO (sequence-level importance sampling)...") - run_benchmarks( - bench_test_fn=bench_speed_fused_linear_grpo_loss, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **sequence_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_fused_linear_grpo_loss, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **sequence_configs, - ) + _run_grpo_benchmarks(args, importance_sampling_level="sequence", kernel_name_suffix="sequence") diff --git a/benchmark/scripts/benchmark_jsd.py b/benchmark/scripts/benchmark_jsd.py index 16d71eac0..16733c425 100644 --- a/benchmark/scripts/benchmark_jsd.py +++ b/benchmark/scripts/benchmark_jsd.py @@ -1,6 +1,13 @@ +import math + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -9,7 +16,6 @@ from utils import run_benchmarks from liger_kernel.transformers.jsd import LigerJSD -from liger_kernel.utils import get_total_gpu_memory from liger_kernel.utils import 
infer_device device = infer_device() @@ -53,105 +59,223 @@ def forward( return loss.to(self.dtype) -def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - V = input.x - B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] - torch_jsd = TorchJSD() - liger_jsd = LigerJSD() +def _setup_jsd(input: SingleBenchmarkRunInput): + """Create input tensors and JSD loss from benchmark config.""" + cfg = input.extra_benchmark_config + V = cfg["vocab_size"] + BT = input.x + _input = torch.randn(BT, V, requires_grad=True, device=device).log_softmax(dim=-1) + target = torch.randn(BT, V, device=device).log_softmax(dim=-1) + + if input.kernel_provider == "liger": + loss_fn = LigerJSD() + elif input.kernel_provider == "torch": + loss_fn = TorchJSD() + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for JSD") + return _input, target, loss_fn - _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) - target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) + +def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, loss_fn = _setup_jsd(input) + mode = input.kernel_operation_mode def fwd(): - if input.kernel_provider == "liger": - return liger_jsd(_input, target) - else: - return torch_jsd(_input, target) + return loss_fn(_input, target) - if input.kernel_operation_mode == "forward": + if mode == "forward": ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) - elif input.kernel_operation_mode == "backward": + elif mode == "backward": y = fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(retain_graph=True), quantiles=QUANTILES, grad_to_none=[_input], rep=100, ) - elif input.kernel_operation_mode == "full": + elif mode == "full": def full(): y = fwd() y.backward(retain_graph=True) ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + else: + raise ValueError(f"Unsupported mode: {mode}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - torch_jsd = TorchJSD() - liger_jsd = LigerJSD() + _input, target, loss_fn = _setup_jsd(input) + + def full(): + y = loss_fn(_input, target) + y.backward(retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_jsd(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_jsd( + SingleBenchmarkRunInput( + x=cfg["BT"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "vocab_size": model_info["vocab_size"], + }, + ) + ) - V = input.x - B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) - target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) +def bench_speed_jsd_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, loss_fn = _resolve_model_config_jsd(input) + mode = input.kernel_operation_mode def fwd(): - if input.kernel_provider == "liger": - return liger_jsd(_input, target) - else: - return torch_jsd(_input, target) + 
return loss_fn(_input, target) - def full(): + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) + elif mode == "backward": y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[_input], + rep=100, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward(retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) + else: + raise ValueError(f"Unsupported mode: {mode}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_jsd_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, loss_fn = _resolve_model_config_jsd(input) + + def full(): + y = loss_fn(_input, target) y.backward(retain_graph=True) mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) - - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - gpu_memory_gbs = get_total_gpu_memory() - # We know that the full test will require 54GBs for vocab size 2^17 on torch - if gpu_memory_gbs >= 54: - x_max = 17 + + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="torch", + extra_benchmark_config={ + "vocab_size": model_cfg.vocab_size, + }, + ) + _input, target, loss_fn = _setup_jsd(probe_input) + return loss_fn(_input, target) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "vocab_size": cfg.vocab_size, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "jsd", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "BT": sweep.bt, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_jsd_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_jsd_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) else: - x_max = 16 - common_args = { - "kernel_name": "jsd", - "x_name": "V", - "x_label": "vocab size", - "x_values": [2**i for i in range(12, x_max + 1)], - "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [{"B": 4, "T": 2048}], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_memory_jsd, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_args, - ) + model = get_benchmark_model_config(args.model) + probe_bt = 1024 - run_benchmarks( - bench_test_fn=bench_speed_jsd, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_args, - ) + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="torch", + extra_benchmark_config={ + "vocab_size": model.vocab_size, + }, + ) + _input, target, loss_fn = 
_setup_jsd(probe_input) + return loss_fn(_input, target) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "jsd", + "x_name": "BT", + "x_label": "B * T", + "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "vocab_size": model.vocab_size, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_jsd, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_jsd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_kl_div.py b/benchmark/scripts/benchmark_kl_div.py index 09948c38b..2fc10b7d5 100644 --- a/benchmark/scripts/benchmark_kl_div.py +++ b/benchmark/scripts/benchmark_kl_div.py @@ -1,7 +1,14 @@ +import math + import torch import torch.nn as nn import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -14,104 +21,226 @@ device = infer_device() -S, E = 12, 18 - -def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: +def _setup_kl_div(input: SingleBenchmarkRunInput): + """Create input tensors and KL div loss from benchmark config.""" + cfg = input.extra_benchmark_config + V = cfg["vocab_size"] + BT = input.x reduction = "batchmean" - V = input.x - B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] - torch_kl_div = nn.KLDivLoss(reduction=reduction) - liger_kl_div = LigerKLDIVLoss(reduction=reduction) - _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) - target = torch.randn(B * T, V, device=device).softmax(dim=-1) + _input = torch.randn(BT, V, requires_grad=True, device=device).log_softmax(dim=-1) + target = torch.randn(BT, V, device=device).softmax(dim=-1) + + if input.kernel_provider == "liger": + loss_fn = LigerKLDIVLoss(reduction=reduction) + elif input.kernel_provider == "torch": + loss_fn = nn.KLDivLoss(reduction=reduction) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for KLDiv") + return _input, target, loss_fn + + +def bench_speed_kl_div(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, loss_fn = _setup_kl_div(input) + mode = input.kernel_operation_mode def fwd(): - if input.kernel_provider == "liger": - return liger_kl_div(_input, target) - else: - return torch_kl_div(_input, target) + return loss_fn(_input, target) - if input.kernel_operation_mode == "forward": + if mode == "forward": ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) - elif input.kernel_operation_mode == "backward": + elif mode == "backward": y = fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(retain_graph=True), quantiles=QUANTILES, grad_to_none=[_input], rep=100, ) - elif input.kernel_operation_mode == 
"full": + elif mode == "full": def full(): y = fwd() y.backward(retain_graph=True) ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) -def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - reduction = "batchmean" - torch_kl_div = nn.KLDivLoss(reduction=reduction) - liger_kl_div = LigerKLDIVLoss(reduction=reduction) - V = input.x - B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] +def bench_memory_kl_div(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, loss_fn = _setup_kl_div(input) + + def full(): + y = loss_fn(_input, target) + y.backward(retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_kl_div(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_kl_div( + SingleBenchmarkRunInput( + x=cfg["BT"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "vocab_size": model_info["vocab_size"], + }, + ) + ) + - _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) - target = torch.randn(B * T, V, device=device).softmax(dim=-1) +def bench_speed_kl_div_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, loss_fn = _resolve_model_config_kl_div(input) + mode = input.kernel_operation_mode def fwd(): - if input.kernel_provider == "liger": - return liger_kl_div(_input, target) - else: - return torch_kl_div(_input, target) + return loss_fn(_input, target) - def full(): + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) + elif mode == "backward": y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[_input], + rep=100, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward(retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) + else: + raise ValueError(f"Unsupported mode: {mode}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_kl_div_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, loss_fn = _resolve_model_config_kl_div(input) + + def full(): + y = loss_fn(_input, target) y.backward(retain_graph=True) mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) - - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_args = { - "kernel_name": "kl_div", - "x_name": "V", - "x_label": "vocab size", - "x_values": [2**i for i in range(12, 18)], - "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [{"B": 8, "T": 512}], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_memory_kldiv, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_args, - ) - run_benchmarks( - 
bench_test_fn=bench_speed_kldiv, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_args, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="torch", + extra_benchmark_config={ + "vocab_size": model_cfg.vocab_size, + }, + ) + _input, target, loss_fn = _setup_kl_div(probe_input) + return loss_fn(_input, target) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "vocab_size": cfg.vocab_size, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "kl_div", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "BT": sweep.bt, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_kl_div_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_kl_div_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 1024 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="torch", + extra_benchmark_config={ + "vocab_size": model.vocab_size, + }, + ) + _input, target, loss_fn = _setup_kl_div(probe_input) + return loss_fn(_input, target) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "kl_div", + "x_name": "BT", + "x_label": "B * T", + "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "vocab_size": model.vocab_size, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_kl_div, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_kl_div, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_kto_loss.py b/benchmark/scripts/benchmark_kto_loss.py index bbde1d5c6..27929420f 100644 --- a/benchmark/scripts/benchmark_kto_loss.py +++ b/benchmark/scripts/benchmark_kto_loss.py @@ -1,9 +1,15 @@ +import math import os import sys import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -19,26 +25,13 @@ class TorchLMHeadKTO(torch.nn.Module): - def __init__( - self, - H: int, - V: int, - dtype: 
torch.dtype, - use_bias: bool = False, - use_ref_bias: bool = False, - ignore_index: int = -100, - beta: float = 0.1, - ): + def __init__(self, H, V, dtype, use_bias=False, use_ref_bias=False, ignore_index=-100, beta=0.1): from test.chunked_loss.test_kto_loss import HFKTOLoss super().__init__() self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype) self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype) - self.KTO_loss = HFKTOLoss( - ignore_index=ignore_index, - beta=beta, - use_ref_model=True, - ).get_batch_loss_metrics + self.KTO_loss = HFKTOLoss(ignore_index=ignore_index, beta=beta, use_ref_model=True).get_batch_loss_metrics def forward(self, x, ref_x, y, preference_labels, kl=None): return self.KTO_loss( @@ -55,24 +48,11 @@ def forward(self, x, ref_x, y, preference_labels, kl=None): class LigerLMHeadKTO(torch.nn.Module): - def __init__( - self, - H: int, - V: int, - dtype: torch.dtype, - use_bias: bool = False, - use_ref_bias: bool = False, - ignore_index: int = -100, - beta: float = 0.1, - ): + def __init__(self, H, V, dtype, use_bias=False, use_ref_bias=False, ignore_index=-100, beta=0.1): super().__init__() self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype) self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype) - self.KTO_loss = LigerFusedLinearKTOLoss( - ignore_index=ignore_index, - beta=beta, - use_ref_model=True, - ) + self.KTO_loss = LigerFusedLinearKTOLoss(ignore_index=ignore_index, beta=beta, use_ref_model=True) def forward(self, x, ref_x, y, preference_labels, kl=None): return self.KTO_loss( @@ -88,227 +68,250 @@ def forward(self, x, ref_x, y, preference_labels, kl=None): ) -def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: +def _setup_kto_loss(input: SingleBenchmarkRunInput): + """Create input tensors and KTO loss from benchmark config.""" + cfg = input.extra_benchmark_config + H = cfg["hidden_size"] + V = cfg["vocab_size"] + dtype = cfg["dtype"] + bias = cfg["bias"] + beta = cfg["beta"] + ignore_index = cfg["ignore_index"] B = input.x - T = input.extra_benchmark_config["T"] - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - bias = input.extra_benchmark_config["bias"] - beta = input.extra_benchmark_config["beta"] - ignore_index = input.extra_benchmark_config["ignore_index"] - provider = input.kernel_provider - - torch_kto_loss = TorchLMHeadKTO( - H=H, - V=V, - dtype=dtype, - use_bias=bias, - use_ref_bias=bias, - ignore_index=ignore_index, - beta=beta, - ).to(device) - - liger_kto_loss = LigerLMHeadKTO( - H=H, - V=V, - dtype=dtype, - use_bias=bias, - use_ref_bias=bias, - ignore_index=ignore_index, - beta=beta, - ).to(device) - - # Input shape: [B, T, H] - _input = torch.randn(B, T, H, device=device, dtype=dtype) + T = cfg["T"] - # Target shape: [B, T] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + ref_input = torch.randn(B, T, H, device=device, dtype=dtype) target = torch.randint(V, (B, T), dtype=torch.long, device=device) - - # Preference labels shape: [B] - # Create binary preference labels (0 or 1) for each sequence in the batch - # Used to indicate preferred sequences (1) vs non-preferred sequences (0) preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device) - - # Precomputed KL divergence between policy and reference distributions kl = torch.randn(1, device=device, dtype=dtype) - # 
Add ignore_index tokens to simulate padding num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index - # Add ref_x with the same shape as _input - ref_input = torch.randn(B, T, H, device=device, dtype=dtype) - - def fwd(): - if provider == "liger": - return liger_kto_loss( - x=_input, - ref_x=ref_input, - y=target, - preference_labels=preference_labels, - kl=kl, - )[0] - elif provider == "huggingface": - return torch_kto_loss( - x=_input, - ref_x=ref_input, - y=target, - preference_labels=preference_labels, - kl=kl, - )[0] + if input.kernel_provider == "liger": + loss_module = LigerLMHeadKTO( + H=H, V=V, dtype=dtype, use_bias=bias, use_ref_bias=bias, ignore_index=ignore_index, beta=beta + ).to(device) + elif input.kernel_provider == "huggingface": + loss_module = TorchLMHeadKTO( + H=H, V=V, dtype=dtype, use_bias=bias, use_ref_bias=bias, ignore_index=ignore_index, beta=beta + ).to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for KTOLoss") - def full(): - y = fwd() - y.backward() - - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + fwd_fn = lambda: loss_module(x=_input, ref_x=ref_input, y=target, preference_labels=preference_labels, kl=kl)[0] + return _input, fwd_fn def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - B = input.x - T = input.extra_benchmark_config["T"] - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - bias = input.extra_benchmark_config["bias"] - beta = input.extra_benchmark_config["beta"] - ignore_index = input.extra_benchmark_config["ignore_index"] - provider = input.kernel_provider + _input, fwd_fn = _setup_kto_loss(input) mode = input.kernel_operation_mode - torch_kto_loss = TorchLMHeadKTO( - H=H, - V=V, - dtype=dtype, - beta=beta, - ignore_index=ignore_index, - use_bias=bias, - ).to(device) - liger_kto_loss = LigerLMHeadKTO( - H=H, - V=V, - dtype=dtype, - beta=beta, - ignore_index=ignore_index, - use_bias=bias, - ).to(device) - - # Input shape: [B, T, H] - _input = torch.randn(B, T, H, device=device, dtype=dtype) + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd_fn() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), grad_to_none=[_input], rep=100, quantiles=QUANTILES + ) + elif mode == "full": - # Target shape: [B, T] - target = torch.randint(V, (B, T), device=device, dtype=torch.long) + def full(): + y = fwd_fn() + y.backward() - # Preference labels shape: [B] - # Create binary preference labels (0 or 1) for each sequence in the batch - # Used to indicate preferred sequences (1) vs non-preferred sequences (0) - preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) - # Precomputed KL divergence between policy and reference distributions - kl = torch.randn(1, device=device, dtype=dtype) - # Add ignore_index tokens - num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() - indices_to_assign = torch.randperm(B * 
T)[:num_elements_to_assign] - target.view(-1)[indices_to_assign] = ignore_index +def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _setup_kto_loss(input) + + def full(): + y = fwd_fn() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_kto_loss(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_kto_loss( + SingleBenchmarkRunInput( + x=cfg["B"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "vocab_size": model_info["vocab_size"], + "dtype": model_info["dtype"], + "T": cfg["T"], + "bias": cfg["bias"], + "beta": cfg["beta"], + "ignore_index": cfg["ignore_index"], + }, + ) + ) - # Add ref_x with the same shape as _input - ref_input = torch.randn(B, T, H, device=device, dtype=dtype) - def fwd(): - if provider == "liger": - return liger_kto_loss( - x=_input, - ref_x=ref_input, - y=target, - preference_labels=preference_labels, - kl=kl, - )[0] - elif provider == "huggingface": - return torch_kto_loss( - x=_input, - ref_x=ref_input, - y=target, - preference_labels=preference_labels, - kl=kl, - )[0] +def bench_speed_kto_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _resolve_model_config_kto_loss(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - rep=100, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES) elif mode == "backward": - y = fwd() + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(retain_graph=True), - grad_to_none=[_input], - rep=100, - quantiles=QUANTILES, + lambda: y.backward(retain_graph=True), grad_to_none=[_input], rep=100, quantiles=QUANTILES ) elif mode == "full": def full(): - y = fwd() + y = fwd_fn() y.backward() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - rep=100, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + +def bench_memory_kto_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _resolve_model_config_kto_loss(input) + + def full(): + y = fwd_fn() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "kto_loss", - "x_name": "B", - "x_label": "Batch Size (B)", - "x_values": [2**i for i in range(1, 6)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "T": 512, - "H": 1024, - "V": 128256, - "mode": "forward", - "dtype": torch.bfloat16, - "bias": True, - "beta": 0.1, - "ignore_index": 42, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_kto_loss, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - - run_benchmarks( - 

 if __name__ == "__main__":
     args = parse_benchmark_script_args()

-    common_configs = {
-        "kernel_name": "kto_loss",
-        "x_name": "B",
-        "x_label": "Batch Size (B)",
-        "x_values": [2**i for i in range(1, 6)],
-        "kernel_providers": ["liger", "huggingface"],
-        "extra_benchmark_configs": [
-            {
-                "T": 512,
-                "H": 1024,
-                "V": 128256,
-                "mode": "forward",
-                "dtype": torch.bfloat16,
-                "bias": True,
-                "beta": 0.1,
-                "ignore_index": 42,
-            }
-        ],
-        "overwrite": args.overwrite,
-    }
-
-    run_benchmarks(
-        bench_test_fn=bench_speed_kto_loss,
-        kernel_operation_modes=["forward", "backward", "full"],
-        metric_name="speed",
-        metric_unit="ms",
-        **common_configs,
-    )
-
-    run_benchmarks(
-        bench_test_fn=bench_memory_kto_loss,
-        kernel_operation_modes=["full"],
-        metric_name="memory",
-        metric_unit="MB",
-        **common_configs,
-    )
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+        T = 512
+
+        def _probe_factory(model_cfg, probe_bt):
+            def _probe():
+                B = max(1, probe_bt // T)
+                probe_input = SingleBenchmarkRunInput(
+                    x=B,
+                    kernel_provider="huggingface",
+                    extra_benchmark_config={
+                        "hidden_size": model_cfg.hidden_size,
+                        "vocab_size": model_cfg.vocab_size,
+                        "dtype": model_cfg.dtype,
+                        "T": T,
+                        "bias": True,
+                        "beta": 0.1,
+                        "ignore_index": 42,
+                    },
+                )
+                _, fwd_fn = _setup_kto_loss(probe_input)
+                return fwd_fn()
+
+            return _probe
+
+        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+        model_configs_info = {
+            cfg.name: {"hidden_size": cfg.hidden_size, "vocab_size": cfg.vocab_size, "dtype": cfg.dtype}
+            for cfg in sweep.model_configs
+        }
+        B = max(1, sweep.bt // T)
+
+        common_configs = {
+            "kernel_name": "kto_loss",
+            "x_name": "model_config",
+            "x_label": "model configuration",
+            "x_values": [cfg.name for cfg in sweep.model_configs],
+            "kernel_providers": ["liger", "huggingface"],
+            "extra_benchmark_configs": [
+                {"model_configs": model_configs_info, "B": B, "T": T, "bias": True, "beta": 0.1, "ignore_index": 42}
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_kto_loss_model_config,
+            kernel_operation_modes=["forward", "backward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_kto_loss_model_config,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
+    else:
+        model = get_benchmark_model_config(args.model)
+        T = 512
+        probe_bt = 1024
+
+        def _probe():
+            B = max(1, probe_bt // T)
+            probe_input = SingleBenchmarkRunInput(
+                x=B,
+                kernel_provider="huggingface",
+                extra_benchmark_config={
+                    "hidden_size": model.hidden_size,
+                    "vocab_size": model.vocab_size,
+                    "dtype": model.dtype,
+                    "T": T,
+                    "bias": True,
+                    "beta": 0.1,
+                    "ignore_index": 42,
+                },
+            )
+            _, fwd_fn = _setup_kto_loss(probe_input)
+            return fwd_fn()
+
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_bt
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": "kto_loss",
+            "x_name": "B",
+            "x_label": "Batch Size (B)",
+            "x_values": [2**i for i in range(1, int(math.log2(max(2, config.batch_size * config.seq_len // T))) + 1)],
+            "kernel_providers": ["liger", "huggingface"],
+            "extra_benchmark_configs": [
+                {
+                    "hidden_size": model.hidden_size,
+                    "vocab_size": model.vocab_size,
+                    "dtype": model.dtype,
+                    "T": T,
+                    "bias": True,
+                    "beta": 0.1,
+                    "ignore_index": 42,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_kto_loss,
+            kernel_operation_modes=["forward", "backward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_kto_loss,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
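The batch-size sweep in the fallback branch above is derived rather than hard-coded: `config.batch_size * config.seq_len` is the largest B*T the probe suggests will fit, and dividing by the fixed T = 512 turns it into a largest feasible batch. Worked through with made-up probe numbers (not taken from a real run):

import math

batch_size, seq_len, T = 8, 32768, 512       # hypothetical probe-derived sweep config
max_b = max(2, batch_size * seq_len // T)    # 8 * 32768 // 512 = 512
x_values = [2**i for i in range(1, int(math.log2(max_b)) + 1)]
print(x_values)                              # [2, 4, 8, ..., 512]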
diff --git a/benchmark/scripts/benchmark_layer_norm.py b/benchmark/scripts/benchmark_layer_norm.py
index 0addf78ed..456e62b69 100644
--- a/benchmark/scripts/benchmark_layer_norm.py
+++ b/benchmark/scripts/benchmark_layer_norm.py
@@ -1,12 +1,18 @@
+import math
+
 import torch
-import triton

-from utils import QUANTILES
+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
+from benchmark_model_configs import compute_seq_len_sweep_config
+from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import get_benchmark_model_config
 from utils import SingleBenchmarkRunInput
 from utils import SingleBenchmarkRunOutput
-from utils import _test_memory
 from utils import parse_benchmark_script_args
 from utils import run_benchmarks
+from utils import run_memory_benchmark
+from utils import run_speed_benchmark

 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.utils import infer_device
@@ -14,112 +20,175 @@
 device = infer_device()


-def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    N = input.x
-    provider = input.kernel_provider
-    mode = input.kernel_operation_mode
-    extra_benchmark_config = input.extra_benchmark_config
-    M = extra_benchmark_config["M"]
-    eps = extra_benchmark_config["eps"]
-    dtype = extra_benchmark_config["dtype"]
-
-    x_shape = (M, N)
-    triton_ln = LigerLayerNorm(hidden_size=N).to(device)
-    torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)
-
-    x = torch.randn(x_shape, dtype=dtype, device=device)
-    dy = torch.randn_like(x)
-    x.requires_grad_(True)
-
-    def y_fwd():
-        if provider == "liger":
-            return triton_ln(x)
-        if provider == "huggingface":
-            return torch_ln(x)
-
-    if mode == "forward":
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
-    elif mode == "backward":
-        y = y_fwd()
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            lambda: y.backward(dy, retain_graph=True),
-            quantiles=QUANTILES,
-            grad_to_none=[x],
-            rep=500,
-        )
-    elif mode == "full":
-
-        def full():
-            y = y_fwd()
-            y.backward(dy, retain_graph=True)
+def _setup_layer_norm(input: SingleBenchmarkRunInput):
+    """Create input tensor and LayerNorm layer from benchmark config."""
+    cfg = input.extra_benchmark_config
+    hidden_size = cfg["hidden_size"]
+    eps = cfg["eps"]
+    x = torch.randn(
+        input.x,
+        hidden_size,
+        device=device,
+        dtype=cfg["dtype"],
+        requires_grad=True,
+    )
+    if input.kernel_provider == "liger":
+        layer = LigerLayerNorm(hidden_size=hidden_size, eps=eps).to(device)
+    elif input.kernel_provider == "huggingface":
+        layer = torch.nn.LayerNorm(hidden_size, eps=eps).to(device)
+    else:
+        raise ValueError(f"Invalid provider: {input.kernel_provider} for LayerNorm")
+    return x, layer

-        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
-    return SingleBenchmarkRunOutput(
-        y_20=ms_20,
-        y_50=ms_50,
-        y_80=ms_80,
-    )

+def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, layer = _setup_layer_norm(input)
+    return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x])
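`bench_speed_layer_norm` now delegates its forward/backward/full dispatch to `utils.run_speed_benchmark`, whose body is not part of this diff. Judging only from how it is called here (forward fn, operation mode, grad_to_none list) and from the `triton.testing.do_bench` pattern the scripts that still inline it use, a plausible sketch — not the actual implementation — is:

import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunOutput

def run_speed_benchmark(fwd_fn, mode, grad_to_none=None, rep=100):
    # Sketch only: the real helper lives in benchmark/scripts/utils.py and may differ.
    def loss_fn():
        y = fwd_fn()
        return y if y.dim() == 0 else y.sum()  # non-scalar outputs need a reduction before backward

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=grad_to_none, rep=rep, quantiles=QUANTILES)
    elif mode == "backward":
        y = loss_fn()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True), grad_to_none=grad_to_none, rep=rep, quantiles=QUANTILES
        )
    else:  # "full"

        def full():
            loss_fn().backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=grad_to_none, rep=rep, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)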

 def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    N = input.x
-    provider = input.kernel_provider
-    dtype = input.extra_benchmark_config["dtype"]
-    M = input.extra_benchmark_config["M"]
-    eps = input.extra_benchmark_config["eps"]
-
-    x_shape = (M, N)
-
-    triton_ln = LigerLayerNorm(hidden_size=N).to(device)
-    torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)
-
-    x = torch.randn(x_shape, dtype=dtype, device=device)
-    dy = torch.randn_like(x)
-    x.requires_grad_(True)
-
-    def y_fwd():
-        if provider == "liger":
-            return triton_ln(x)
-        if provider == "huggingface":
-            return torch_ln(x)
-
-    def full():
-        y = y_fwd()
-        y.backward(dy, retain_graph=True)
-
-    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
-    return SingleBenchmarkRunOutput(
-        y_20=mem_20,
-        y_50=mem_50,
-        y_80=mem_80,
+    x, layer = _setup_layer_norm(input)
+    return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)
+
+
+def _resolve_model_config_layer_norm(input: SingleBenchmarkRunInput):
+    """Resolve model-config-sweep input into standard setup args."""
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_layer_norm(
+        SingleBenchmarkRunInput(
+            x=cfg["BT"],
+            kernel_provider=input.kernel_provider,
+            extra_benchmark_config={
+                "hidden_size": model_info["hidden_size"],
+                "dtype": model_info["dtype"],
+                "eps": cfg["eps"],
+            },
+        )
     )


+def bench_speed_layer_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, layer = _resolve_model_config_layer_norm(input)
+    return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x])
+
+
+def bench_memory_layer_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, layer = _resolve_model_config_layer_norm(input)
+    return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)
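`compute_model_config_sweep_config` takes a `probe_fn_factory` so that it can build one probe per candidate model: the factory receives a model config plus a trial token count and returns a zero-argument callable whose peak memory can be measured, which lets oversized models be dropped from the sweep before any benchmark runs. A schematic of that contract — the real internals live in benchmark_model_configs.py and may well differ — could look like:

import torch

def compute_model_config_sweep_config(model_configs, probe_fn_factory, bt):
    # Sketch: keep only configs whose probe fits in device memory at token budget bt.
    feasible = []
    for cfg in model_configs:
        probe = probe_fn_factory(cfg, bt)  # zero-arg callable running one forward pass
        try:
            estimate_kernel_peak_memory(probe_fn=probe)
            feasible.append(cfg)
        except torch.cuda.OutOfMemoryError:  # CUDA-specific; other devices would need their own check
            continue
    return ModelConfigSweepConfig(model_configs=feasible, bt=bt)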

 if __name__ == "__main__":
     args = parse_benchmark_script_args()

-    common_configs = {
-        "kernel_name": "layer_norm",
-        "x_name": "N",
-        "x_label": "hidden size",
-        "x_values": [2**i for i in range(10, 15)],
-        "kernel_providers": ["liger", "huggingface"],
-        "extra_benchmark_configs": [{"M": 4096, "dtype": torch.float32, "eps": 1e-6}],
-        "overwrite": args.overwrite,
-    }
-
-    run_benchmarks(
-        bench_test_fn=bench_speed_layer_norm,
-        kernel_operation_modes=["forward", "backward", "full"],
-        metric_name="speed",
-        metric_unit="ms",
-        **common_configs,
-    )
-    run_benchmarks(
-        bench_test_fn=bench_memory_layer_norm,
-        kernel_operation_modes=["full"],
-        metric_name="memory",
-        metric_unit="MB",
-        **common_configs,
-    )
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+
+        def _probe_factory(model_cfg, probe_bt):
+            def _probe():
+                probe_input = SingleBenchmarkRunInput(
+                    x=probe_bt,
+                    kernel_provider="huggingface",
+                    extra_benchmark_config={
+                        "hidden_size": model_cfg.hidden_size,
+                        "dtype": model_cfg.dtype,
+                        "eps": 1e-6,
+                    },
+                )
+                x, layer = _setup_layer_norm(probe_input)
+                return layer(x)
+
+            return _probe
+
+        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+
+        model_configs_info = {
+            cfg.name: {
+                "hidden_size": cfg.hidden_size,
+                "dtype": cfg.dtype,
+            }
+            for cfg in sweep.model_configs
+        }
+
+        common_configs = {
+            "kernel_name": "layer_norm",
+            "x_name": "model_config",
+            "x_label": "model configuration",
+            "x_values": [cfg.name for cfg in sweep.model_configs],
+            "kernel_providers": ["liger", "huggingface"],
+            "extra_benchmark_configs": [
+                {
+                    "model_configs": model_configs_info,
+                    "BT": sweep.bt,
+                    "eps": 1e-6,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_layer_norm_model_config,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_layer_norm_model_config,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
+    else:
+        model = get_benchmark_model_config(args.model)
+        probe_bt = 1024
+
+        def _probe():
+            probe_input = SingleBenchmarkRunInput(
+                x=probe_bt,
+                kernel_provider="huggingface",
+                extra_benchmark_config={
+                    "hidden_size": model.hidden_size,
+                    "dtype": model.dtype,
+                    "eps": 1e-6,
+                },
+            )
+            x, layer = _setup_layer_norm(probe_input)
+            return layer(x)
+
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_bt
+
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": "layer_norm",
+            "x_name": "BT",
+            "x_label": "B * T",
+            "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)],
+            "kernel_providers": ["liger", "huggingface"],
+            "extra_benchmark_configs": [
+                {
+                    "hidden_size": model.hidden_size,
+                    "dtype": model.dtype,
+                    "eps": 1e-6,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_layer_norm,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_layer_norm,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
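The probe-based sizing in the fallback branch is plain arithmetic: measure peak bytes at a known token count, divide to get bytes per token, and let `compute_seq_len_sweep_config` pick the largest power-of-two sweep that fits. Worked through with made-up numbers (not real measurements):

probe_bt = 1024
peak_bytes = 512 * 1024 * 1024           # hypothetical probe measurement: 512 MiB
kernel_bpt = peak_bytes // probe_bt      # 524288 bytes of kernel state per token
# If the resulting sweep config allows batch_size * seq_len = 2**20 tokens, then
# x_values = [2**i for i in range(10, 21)] sweeps from 1 Ki to 1 Mi tokens.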
diff --git a/benchmark/scripts/benchmark_llama4_rope.py b/benchmark/scripts/benchmark_llama4_rope.py
index 47d06051e..20890d06c 100644
--- a/benchmark/scripts/benchmark_llama4_rope.py
+++ b/benchmark/scripts/benchmark_llama4_rope.py
@@ -1,6 +1,15 @@
+import math
+import os
+import sys
+
 import torch
 import triton

+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
+from benchmark_model_configs import compute_seq_len_sweep_config
+from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import get_benchmark_model_config
 from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
 from transformers.models.llama4.modeling_llama4 import Llama4TextRotaryEmbedding
 from transformers.models.llama4.modeling_llama4 import apply_rotary_emb
@@ -17,23 +26,19 @@
 device = infer_device()

+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

-def bench_speed_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    provider = input.kernel_provider
-    mode = input.kernel_operation_mode
-
-    extra_benchmark_config = input.extra_benchmark_config
-    num_q_heads = extra_benchmark_config["num_q_heads"]
-    num_kv_heads = extra_benchmark_config["num_kv_heads"]
-    dtype = extra_benchmark_config["dtype"]
-    # x can be either hidden_size or seq_len
-    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
-    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
+def _setup_llama4_rope(input: SingleBenchmarkRunInput):
+    """Create input tensors and Llama4 RoPE embedding from benchmark config."""
+    cfg = input.extra_benchmark_config
+    num_q_heads = cfg["num_q_heads"]
+    num_kv_heads = cfg["num_kv_heads"]
+    dtype = cfg["dtype"]
+    hidden_size = cfg.get("hidden_size", input.x)
+    seq_len = cfg.get("seq_len", input.x)
     head_dim = hidden_size // num_q_heads
-
-    # Create Llama4TextConfig for the rotary embedding
     config = Llama4TextConfig(
         hidden_size=hidden_size,
         num_attention_heads=num_q_heads,
         num_key_value_heads=num_kv_heads,
         head_dim=head_dim,
         max_position_embeddings=seq_len,
     )
-
     rotary_emb = transformers_version_dispatch(
         "4.48.0",
         Llama4TextRotaryEmbedding,
@@ -50,42 +54,30 @@
         after_kwargs={"config": config, "device": device},
     )

-    q = torch.randn(
-        (1, seq_len, num_q_heads, head_dim),
-        device=device,
-        requires_grad=True,
-        dtype=dtype,
-    )
-    k = torch.randn(
-        (1, seq_len, num_kv_heads, head_dim),
-        device=device,
-        requires_grad=True,
-        dtype=dtype,
-    )
-    dq, dk = (
-        torch.randn_like(q, device=device, dtype=dtype),
-        torch.randn_like(k, device=device),
-    )
+    q = torch.randn((1, seq_len, num_q_heads, head_dim), device=device, requires_grad=True, dtype=dtype)
+    k = torch.randn((1, seq_len, num_kv_heads, head_dim), device=device, requires_grad=True, dtype=dtype)
+    dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(k, device=device)
     pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
     freqs_cis = rotary_emb(q, pos_ids)

-    def fwd():
-        if provider == "liger":
-            return liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
-        elif provider == "huggingface":
-            return apply_rotary_emb(q, k, freqs_cis)
-        else:
-            raise ValueError(f"Invalid provider: {provider} for Llama4 RoPE embedding")
+    if input.kernel_provider == "liger":
+        fwd_fn = lambda: liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
+    elif input.kernel_provider == "huggingface":
+        fwd_fn = lambda: apply_rotary_emb(q, k, freqs_cis)
+    else:
+        raise ValueError(f"Invalid provider: {input.kernel_provider} for Llama4 RoPE embedding")
+
+    return q, k, dq, dk, fwd_fn
+
+
+def bench_speed_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    q, k, dq, dk, fwd_fn = _setup_llama4_rope(input)
+    mode = input.kernel_operation_mode

     if mode == "forward":
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            fwd,
-            grad_to_none=[q, k],
-            rep=400,
-            quantiles=QUANTILES,
-        )
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[q, k], rep=400, quantiles=QUANTILES)
     elif mode == "backward":
-        q_out, k_out = fwd()
+        q_out, k_out = fwd_fn()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
             grad_to_none=[q, k],
@@ -95,151 +87,192 @@ def fwd():
     elif mode == "full":

         def full():
-            q_out, k_out = fwd()
+            q_out, k_out = fwd_fn()
             torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            full,
-            grad_to_none=[q, k],
-            rep=400,
-            quantiles=QUANTILES,
-        )
-    return SingleBenchmarkRunOutput(
-        y_20=ms_20,
-        y_50=ms_50,
-        y_80=ms_80,
-    )
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[q, k], rep=400, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)

 def bench_memory_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    provider = input.kernel_provider
+    q, k, dq, dk, fwd_fn = _setup_llama4_rope(input)

-    extra_benchmark_config = input.extra_benchmark_config
-    num_q_heads = extra_benchmark_config["num_q_heads"]
-    num_kv_heads = extra_benchmark_config["num_kv_heads"]
-    dtype = extra_benchmark_config["dtype"]
+    def full():
+        q_out, k_out = fwd_fn()
+        torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)

-    # x can be either hidden_size or seq_len
-    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
-    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)

-    head_dim = hidden_size // num_q_heads
-    # Create Llama4TextConfig for the rotary embedding
-    config = Llama4TextConfig(
-        hidden_size=hidden_size,
-        num_attention_heads=num_q_heads,
-        num_key_value_heads=num_kv_heads,
-        head_dim=head_dim,
-        max_position_embeddings=seq_len,
+
+def _resolve_model_config_llama4_rope(input: SingleBenchmarkRunInput):
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_llama4_rope(
+        SingleBenchmarkRunInput(
+            x=input.x,
+            kernel_provider=input.kernel_provider,
+            extra_benchmark_config={
+                "hidden_size": model_info["hidden_size"],
+                "num_q_heads": model_info["num_q_heads"],
+                "num_kv_heads": model_info["num_kv_heads"],
+                "dtype": model_info["dtype"],
+                "seq_len": cfg["seq_len"],
+            },
+        )
     )
-    rotary_emb = transformers_version_dispatch(
-        "4.48.0",
-        Llama4TextRotaryEmbedding,
-        Llama4TextRotaryEmbedding,
-        before_kwargs={"config": config, "device": device},
-        after_kwargs={"config": config, "device": device},
-    )
-    q = torch.randn(
-        (1, seq_len, num_q_heads, head_dim),
-        device=device,
-        requires_grad=True,
-        dtype=dtype,
-    )
-    k = torch.randn(
-        (1, seq_len, num_kv_heads, head_dim),
-        device=device,
-        requires_grad=True,
-        dtype=dtype,
-    )
-    dq, dk = (
-        torch.randn_like(q, device=device, dtype=dtype),
-        torch.randn_like(k, device=device),
-    )
-    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
-    freqs_cis = rotary_emb(q, pos_ids)

+def bench_speed_llama4_rope_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    q, k, dq, dk, fwd_fn = _resolve_model_config_llama4_rope(input)
+    mode = input.kernel_operation_mode
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[q, k], rep=400, quantiles=QUANTILES)
+    elif mode == "backward":
+        q_out, k_out = fwd_fn()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
+            grad_to_none=[q, k],
+            rep=400,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
+
+        def full():
+            q_out, k_out = fwd_fn()
+            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[q, k], rep=400, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
+
+
+def bench_memory_llama4_rope_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    q, k, dq, dk, fwd_fn = _resolve_model_config_llama4_rope(input)

     def full():
-        if provider == "liger":
-            q_out, k_out = liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
-        else:
-            q_out, k_out = apply_rotary_emb(q, k, freqs_cis)
+        q_out, k_out = fwd_fn()
         torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)

-    mem_50, mem_20, mem_80 = _test_memory(
-        full,
-        quantiles=QUANTILES,
-    )
-    return SingleBenchmarkRunOutput(
-        y_20=mem_20,
-        y_50=mem_50,
-        y_80=mem_80,
-    )
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)


 if __name__ == "__main__":
     args = parse_benchmark_script_args()

-    common_configs_varying_hidden_size = {
-        "kernel_name": "llama4_rope",
-        "x_name": "H",
-        "x_label": "hidden size",
"x_values": [32 * (2**i) for i in range(4, 10, 2)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "dtype": torch.bfloat16, - "seq_len": 2048, - "num_q_heads": 32, - "num_kv_heads": 8, - } - ], - "overwrite": args.overwrite, - } - run_benchmarks( - bench_test_fn=bench_speed_llama4_rope, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs_varying_hidden_size, - ) - run_benchmarks( - bench_test_fn=bench_memory_llama4_rope, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs_varying_hidden_size, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + seq_len = 2048 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "num_q_heads": model_cfg.num_attention_heads, + "num_kv_heads": model_cfg.num_key_value_heads, + "dtype": model_cfg.dtype, + "seq_len": seq_len, + }, + ) + _, _, _, _, fwd_fn = _setup_llama4_rope(probe_input) + return fwd_fn()[0] + + return _probe - common_configs_varying_seq_len = { - "kernel_name": "llama4_rope", - "x_name": "T", - "x_label": "sequence length", - "x_values": [2**i for i in range(10, 15)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "dtype": torch.bfloat16, - "hidden_size": 8192, - "num_q_heads": 32, - "num_kv_heads": 8, + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "num_q_heads": cfg.num_attention_heads, + "num_kv_heads": cfg.num_key_value_heads, + "dtype": cfg.dtype, } - ], - "overwrite": args.overwrite, - } - run_benchmarks( - bench_test_fn=bench_speed_llama4_rope, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs_varying_seq_len, - ) - run_benchmarks( - bench_test_fn=bench_memory_llama4_rope, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs_varying_seq_len, - ) + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "llama4_rope", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [{"model_configs": model_configs_info, "seq_len": seq_len}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_llama4_rope_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_llama4_rope_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_seq_len = 2048 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "num_q_heads": model.num_attention_heads, + "num_kv_heads": model.num_key_value_heads, + "dtype": model.dtype, + "seq_len": probe_seq_len, + }, + ) + _, _, _, _, fwd_fn = _setup_llama4_rope(probe_input) + return fwd_fn()[0] + + peak_bytes = 
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_seq_len
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": "llama4_rope",
+            "x_name": "T",
+            "x_label": "sequence length",
+            "x_values": [2**i for i in range(10, int(math.log2(max(1024, config.seq_len))) + 1)],
+            "kernel_providers": ["liger", "huggingface"],
+            "extra_benchmark_configs": [
+                {
+                    "hidden_size": model.hidden_size,
+                    "num_q_heads": model.num_attention_heads,
+                    "num_kv_heads": model.num_key_value_heads,
+                    "dtype": model.dtype,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_llama4_rope,
+            kernel_operation_modes=["forward", "backward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_llama4_rope,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
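The q/k shapes in `_setup_llama4_rope` make the GQA asymmetry explicit: queries use `num_q_heads` while keys use the smaller `num_kv_heads`, so the key tensor is proportionally cheaper. For a Llama-3-8B-like configuration (32 query heads, 8 KV heads, head_dim 128 — illustrative numbers, not read from the registry entry):

import torch

seq_len, head_dim = 2048, 128
q = torch.randn(1, seq_len, 32, head_dim, dtype=torch.bfloat16)  # 2048*32*128*2 B = 16 MiB
k = torch.randn(1, seq_len, 8, head_dim, dtype=torch.bfloat16)   # 2048*8*128*2 B  =  4 MiB
print(q.numel() * q.element_size(), k.numel() * k.element_size())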
diff --git a/benchmark/scripts/benchmark_mhc.py b/benchmark/scripts/benchmark_mhc.py
index 47cdd6336..534268a51 100644
--- a/benchmark/scripts/benchmark_mhc.py
+++ b/benchmark/scripts/benchmark_mhc.py
@@ -1,9 +1,15 @@
+import math
 import os
 import sys

 import torch
 import triton

+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
+from benchmark_model_configs import compute_seq_len_sweep_config
+from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import get_benchmark_model_config
 from utils import QUANTILES
 from utils import SingleBenchmarkRunInput
 from utils import SingleBenchmarkRunOutput
@@ -21,20 +27,21 @@
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


-def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+def _setup_mhc(input: SingleBenchmarkRunInput):
+    """Create input tensors and MHC kernel from benchmark config."""
     from test.transformers.test_mhc import mhc_coeffs_ref

-    T = input.x
-    B = input.extra_benchmark_config["B"]
-    HC = input.extra_benchmark_config["HC"]
-    C = input.extra_benchmark_config["C"]
-    sub_kernel = input.extra_benchmark_config["sub_kernel"]
-    tmax = input.extra_benchmark_config["tmax"]
-    rms_eps = input.extra_benchmark_config["rms_eps"]
-    pre_eps = input.extra_benchmark_config["pre_eps"]
-    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
-    post_mult = input.extra_benchmark_config["post_mult"]
-    provider = input.kernel_provider
+    cfg = input.extra_benchmark_config
+    T = cfg.get("T", input.x)
+    B = cfg["B"]
+    HC = cfg["HC"]
+    C = cfg["C"]
+    sub_kernel = cfg["sub_kernel"]
+    tmax = cfg["tmax"]
+    rms_eps = cfg["rms_eps"]
+    pre_eps = cfg["pre_eps"]
+    sinkhorn_eps = cfg["sinkhorn_eps"]
+    post_mult = cfg["post_mult"]
     mode = input.kernel_operation_mode

     coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)
@@ -53,7 +60,7 @@
     if sub_kernel == "coeffs":

         def fwd():
-            if provider == "liger":
+            if input.kernel_provider == "liger":
                 return liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
             return mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)

@@ -76,7 +83,7 @@ def fwd_loss():
         grad_to_none = [x, h_pre_c] if need_grad else None

         def fwd():
-            if provider == "liger":
+            if input.kernel_provider == "liger":
                 return liger_mhc_pre(x, h_pre_c)
             return (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)

@@ -100,7 +107,7 @@ def fwd_loss():
         grad_to_none = [x, f_out, h_post_c, h_res_c] if need_grad else None

         def fwd():
-            if provider == "liger":
+            if input.kernel_provider == "liger":
                 return liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
             return torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(
                 -1
@@ -109,6 +116,13 @@ def fwd():

         def fwd_loss():
             return fwd().square().mean()

+    return grad_to_none, fwd, fwd_loss
+
+
+def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    grad_to_none, fwd, fwd_loss = _setup_mhc(input)
+    mode = input.kernel_operation_mode
+
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
     elif mode == "backward":
@@ -126,87 +140,78 @@ def full():
             y.backward()

         ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=grad_to_none, rep=100, quantiles=QUANTILES)
-
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
     return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


 def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    from test.transformers.test_mhc import mhc_coeffs_ref
+    grad_to_none, fwd, fwd_loss = _setup_mhc(input)

-    T = input.x
-    B = input.extra_benchmark_config["B"]
-    HC = input.extra_benchmark_config["HC"]
-    C = input.extra_benchmark_config["C"]
-    sub_kernel = input.extra_benchmark_config["sub_kernel"]
-    tmax = input.extra_benchmark_config["tmax"]
-    rms_eps = input.extra_benchmark_config["rms_eps"]
-    pre_eps = input.extra_benchmark_config["pre_eps"]
-    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
-    post_mult = input.extra_benchmark_config["post_mult"]
-    provider = input.kernel_provider
+    def full():
+        y = fwd_loss()
+        y.backward()

-    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)

-    x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=True)
-    K, M = HC * C, HC * HC + 2 * HC
-    phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(True)
-    b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True)
-    alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
-    alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
-    alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)

-    if sub_kernel == "coeffs":
+def _resolve_model_config_mhc(input: SingleBenchmarkRunInput):
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_mhc(
+        SingleBenchmarkRunInput(
+            x=input.x,
+            kernel_provider=input.kernel_provider,
+            kernel_operation_mode=input.kernel_operation_mode,
+            extra_benchmark_config={
+                "B": cfg["B"],
+                "HC": cfg["HC"],
+                "C": model_info["hidden_size"],
+                "T": cfg["T"],
+                "sub_kernel": cfg["sub_kernel"],
+                "tmax": cfg["tmax"],
+                "rms_eps": cfg["rms_eps"],
+                "pre_eps": cfg["pre_eps"],
+                "sinkhorn_eps": cfg["sinkhorn_eps"],
+                "post_mult": cfg["post_mult"],
+            },
+        )
     )
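`_setup_mhc` returns a `(grad_to_none, fwd, fwd_loss)` triple: `fwd` runs the selected sub-kernel, `fwd_loss` reduces its output(s) to a scalar so backward can run, and `grad_to_none` lists the tensors `do_bench` should clear between reps. Used outside the bench harness it would look roughly like this (parameter values hypothetical):

inp = SingleBenchmarkRunInput(
    x=128,  # T
    kernel_provider="liger",
    kernel_operation_mode="full",
    extra_benchmark_config={
        "B": 4, "HC": 4, "C": 4096, "sub_kernel": "coeffs",
        "tmax": 20, "rms_eps": 1e-6, "pre_eps": 0.0, "sinkhorn_eps": 1e-6, "post_mult": 2.0,
    },
)
grad_to_none, fwd, fwd_loss = _setup_mhc(inp)
fwd_loss().backward()  # scalar loss, so gradients flow to the grad_to_none tensors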

-        def full():
-            if provider == "liger":
-                hp, hpo, hr = liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
-            else:
-                hp, hpo, hr = mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
-            (hp.square().mean() + hpo.square().mean() + hr.square().mean()).backward()

-    elif sub_kernel == "pre":
-        with torch.no_grad():
-            h_pre_c, _, _ = liger_mhc_coeffs(
-                x.detach(),
-                phi.detach(),
-                b_param.detach(),
-                alpha_pre.detach(),
-                alpha_post.detach(),
-                alpha_res.detach(),
-                **coeffs_cfg,
-            )
-        h_pre_c.requires_grad_(True)

+def bench_speed_mhc_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    grad_to_none, fwd, fwd_loss = _resolve_model_config_mhc(input)
+    mode = input.kernel_operation_mode
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
+    elif mode == "backward":
+        y = fwd_loss()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(retain_graph=True),
+            grad_to_none=grad_to_none,
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":

         def full():
-            if provider == "liger":
-                out = liger_mhc_pre(x, h_pre_c)
-            else:
-                out = (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)
-            out.square().mean().backward()
+            y = fwd_loss()
+            y.backward()

-    elif sub_kernel == "post_res":
-        with torch.no_grad():
-            _, h_post_c, h_res_c = liger_mhc_coeffs(
-                x.detach(),
-                phi.detach(),
-                b_param.detach(),
-                alpha_pre.detach(),
-                alpha_post.detach(),
-                alpha_res.detach(),
-                **coeffs_cfg,
-            )
-        h_post_c.requires_grad_(True)
-        h_res_c.requires_grad_(True)
-        f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=True)
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=grad_to_none, rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)

-        def full():
-            if provider == "liger":
-                out = liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
-            else:
-                out = torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(
-                    -1
-                ) * f_out.float().unsqueeze(-2)
-            out.square().mean().backward()
+
+def bench_memory_mhc_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    grad_to_none, fwd, fwd_loss = _resolve_model_config_mhc(input)
+
+    def full():
+        y = fwd_loss()
+        y.backward()

     mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
     return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
@@ -215,41 +220,129 @@ def full():

 if __name__ == "__main__":
     args = parse_benchmark_script_args()

-    for sub_kernel in ["coeffs", "pre", "post_res"]:
-        common_configs = {
-            "kernel_name": f"mhc_{sub_kernel}",
-            "x_name": "T",
-            "x_label": "Sequence Length (T)",
-            "x_values": [2**i for i in range(7, 12)],
-            "kernel_providers": ["liger", "torch"],
-            "extra_benchmark_configs": [
-                {
-                    "B": 4,
-                    "HC": 4,
-                    "C": 4096,
-                    "tmax": 20,
-                    "rms_eps": 1e-6,
-                    "pre_eps": 0.0,
-                    "sinkhorn_eps": 1e-6,
-                    "post_mult": 2.0,
-                    "sub_kernel": sub_kernel,
-                }
-            ],
-            "overwrite": args.overwrite,
-        }
-
-        run_benchmarks(
-            bench_test_fn=bench_speed_mhc,
-            kernel_operation_modes=["forward", "backward", "full"],
-            metric_name="speed",
-            metric_unit="ms",
-            **common_configs,
-        )
+    mhc_defaults = {"tmax": 20, "rms_eps": 1e-6, "pre_eps": 0.0, "sinkhorn_eps": 1e-6, "post_mult": 2.0}

-        run_benchmarks(
-            bench_test_fn=bench_memory_mhc,
-            kernel_operation_modes=["full"],
-            metric_name="memory",
-            metric_unit="MB",
-            **common_configs,
-        )
+    for sub_kernel in ["coeffs", "pre", "post_res"]:
+        if args.sweep_mode == "model_config":
+            all_model_configs = list(MODEL_REGISTRY.values())
+            T = 256
+            B = 4
+            HC = 4
+
+            def _probe_factory(model_cfg, probe_bt, _sk=sub_kernel):
+                def _probe():
+                    probe_input = SingleBenchmarkRunInput(
+                        x=0,
+                        kernel_provider="torch",
+                        kernel_operation_mode="full",
+                        extra_benchmark_config={
+                            "B": B,
+                            "HC": HC,
+                            "C": model_cfg.hidden_size,
+                            "T": T,
+                            "sub_kernel": _sk,
+                            **mhc_defaults,
+                        },
+                    )
+                    _, _, fwd_loss = _setup_mhc(probe_input)
+                    return fwd_loss()
+
+                return _probe
+
+            sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+            model_configs_info = {
+                cfg.name: {"hidden_size": cfg.hidden_size, "dtype": cfg.dtype} for cfg in sweep.model_configs
+            }
+
+            common_configs = {
+                "kernel_name": f"mhc_{sub_kernel}",
+                "x_name": "model_config",
+                "x_label": "model configuration",
+                "x_values": [cfg.name for cfg in sweep.model_configs],
+                "kernel_providers": ["liger", "torch"],
+                "extra_benchmark_configs": [
+                    {
+                        "model_configs": model_configs_info,
+                        "B": B,
+                        "HC": HC,
+                        "T": T,
+                        "sub_kernel": sub_kernel,
+                        **mhc_defaults,
+                    }
+                ],
+                "overwrite": args.overwrite,
+            }
+
+            run_benchmarks(
+                bench_test_fn=bench_speed_mhc_model_config,
+                kernel_operation_modes=["forward", "backward", "full"],
+                metric_name="speed",
+                metric_unit="ms",
+                **common_configs,
+            )
+            run_benchmarks(
+                bench_test_fn=bench_memory_mhc_model_config,
+                kernel_operation_modes=["full"],
+                metric_name="memory",
+                metric_unit="MB",
+                **common_configs,
+            )
+        else:
+            model = get_benchmark_model_config(args.model)
+            B = 4
+            HC = 4
+            probe_T = 256
+
+            def _probe(_sk=sub_kernel):
+                probe_input = SingleBenchmarkRunInput(
+                    x=0,
+                    kernel_provider="torch",
+                    kernel_operation_mode="full",
+                    extra_benchmark_config={
+                        "B": B,
+                        "HC": HC,
+                        "C": model.hidden_size,
+                        "T": probe_T,
+                        "sub_kernel": _sk,
+                        **mhc_defaults,
+                    },
+                )
+                _, _, fwd_loss = _setup_mhc(probe_input)
+                return fwd_loss()
+
+            peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+            kernel_bpt = peak_bytes // probe_T
+            config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+            common_configs = {
+                "kernel_name": f"mhc_{sub_kernel}",
+                "x_name": "T",
+                "x_label": "Sequence Length (T)",
+                "x_values": [2**i for i in range(7, int(math.log2(max(128, config.seq_len))) + 1)],
+                "kernel_providers": ["liger", "torch"],
+                "extra_benchmark_configs": [
+                    {
+                        "B": B,
+                        "HC": HC,
+                        "C": model.hidden_size,
+                        "sub_kernel": sub_kernel,
+                        **mhc_defaults,
+                    }
+                ],
+                "overwrite": args.overwrite,
+            }
+
+            run_benchmarks(
+                bench_test_fn=bench_speed_mhc,
+                kernel_operation_modes=["forward", "backward", "full"],
+                metric_name="speed",
+                metric_unit="ms",
+                **common_configs,
+            )
+            run_benchmarks(
+                bench_test_fn=bench_memory_mhc,
+                kernel_operation_modes=["full"],
+                metric_name="memory",
+                metric_unit="MB",
+                **common_configs,
+            )
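The `_sk=sub_kernel` default argument in `_probe_factory` and `_probe` above is deliberate: Python closures capture variables by reference, so without the default every probe created inside the `for sub_kernel in [...]` loop would see the loop's final value. A minimal illustration of the pitfall and the fix:

makers = [lambda: name for name in ["coeffs", "pre", "post_res"]]
print([m() for m in makers])             # ['post_res', 'post_res', 'post_res']

makers = [lambda _n=name: _n for name in ["coeffs", "pre", "post_res"]]
print([m() for m in makers])             # ['coeffs', 'pre', 'post_res']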
diff --git a/benchmark/scripts/benchmark_mhc_lm.py b/benchmark/scripts/benchmark_mhc_lm.py
index 6330a0e1a..23113ed8e 100644
--- a/benchmark/scripts/benchmark_mhc_lm.py
+++ b/benchmark/scripts/benchmark_mhc_lm.py
@@ -1,3 +1,4 @@
+import math
 import os
 import sys

@@ -6,6 +7,11 @@
 import torch.nn.functional as F
 import triton

+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
+from benchmark_model_configs import compute_seq_len_sweep_config
+from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import get_benchmark_model_config
 from utils import QUANTILES
 from utils import SingleBenchmarkRunInput
 from utils import SingleBenchmarkRunOutput
@@ -182,34 +188,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class MHCDecoderLayer(nn.Module):
-    def __init__(
-        self,
-        mhc_cls: type[nn.Module],
-        *,
-        hidden_size: int,
-        hc: int,
-        num_heads: int,
-        intermediate_mult: int,
-        tmax: int,
-        dtype: torch.dtype,
-        device: str,
-    ):
+    def __init__(self, mhc_cls, *, hidden_size, hc, num_heads, intermediate_mult, tmax, dtype, device):
         super().__init__()
         attn = AttentionBlock(hidden_size, num_heads, dtype=dtype, device=device)
         mlp = MLPBlock(hidden_size, intermediate_mult, dtype=dtype, device=device)
-        self.attn = mhc_cls(
-            attn,
-            hc=hc,
-            c=hidden_size,
-            tmax=tmax,
-            rms_eps=1e-6,
-            pre_eps=1e-4,
-            sinkhorn_eps=1e-6,
-            post_mult=2.0,
-            phi_dtype=dtype,
-        )
-        self.mlp = mhc_cls(
-            mlp,
+        mhc_kwargs = dict(
             hc=hc,
             c=hidden_size,
             tmax=tmax,
@@ -219,6 +202,8 @@ def __init__(
             post_mult=2.0,
             phi_dtype=dtype,
         )
+        self.attn = mhc_cls(attn, **mhc_kwargs)
+        self.mlp = mhc_cls(mlp, **mhc_kwargs)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.attn(x)
@@ -228,18 +213,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 class BenchMiniMHCLM(nn.Module):
     def __init__(
-        self,
-        mhc_cls: type[nn.Module],
-        *,
-        vocab_size: int,
-        hidden_size: int,
-        hc: int,
-        num_layers: int,
-        num_heads: int,
-        intermediate_mult: int,
-        tmax: int,
-        dtype: torch.dtype,
-        device: str,
+        self, mhc_cls, *, vocab_size, hidden_size, hc, num_layers, num_heads, intermediate_mult, tmax, dtype, device
     ):
         super().__init__()
         self.hc = hc
@@ -274,18 +248,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.lm_head(x)


-def _build_model(
-    provider: str,
-    *,
-    hidden_size: int,
-    hc: int,
-    num_layers: int,
-    num_heads: int,
-    intermediate_mult: int,
-    vocab_size: int,
-    tmax: int,
-    dtype: torch.dtype,
-):
+def _build_model(provider, *, hidden_size, hc, num_layers, num_heads, intermediate_mult, vocab_size, tmax, dtype):
     mhc_cls = LigerMHC if provider == "liger" else TorchMHC
     return BenchMiniMHCLM(
         mhc_cls,
@@ -301,26 +264,22 @@
     )


-def bench_speed_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    hidden_size = int(input.x)
-    provider = input.kernel_provider
-    mode = input.kernel_operation_mode
-    extra = input.extra_benchmark_config
-    bsz = extra["B"]
-    seq_len = extra["T"]
-    hc = extra["HC"]
-    num_layers = extra["layers"]
-    num_heads = extra["heads"]
-    vocab_size = extra["vocab"]
-    dtype = extra["dtype"]
-    tmax = extra["tmax"]
-    intermediate_mult = extra["intermediate_mult"]
-
-    if hidden_size % num_heads != 0:
-        raise ValueError("hidden_size must be divisible by num_heads")
+def _setup_mhc_lm(input: SingleBenchmarkRunInput):
+    """Create model and inputs for MHC LM benchmark."""
+    cfg = input.extra_benchmark_config
+    hidden_size = cfg["hidden_size"]
+    bsz = cfg["B"]
+    seq_len = cfg.get("T", input.x)
+    hc = cfg["HC"]
+    num_layers = cfg["layers"]
+    num_heads = cfg["heads"]
+    vocab_size = cfg["vocab"]
+    dtype = cfg["dtype"]
+    tmax = cfg["tmax"]
+    intermediate_mult = cfg["intermediate_mult"]

     model = _build_model(
-        provider,
+        input.kernel_provider,
         hidden_size=hidden_size,
         hc=hc,
         num_layers=num_layers,
@@ -332,19 +291,21 @@
     )

     input_ids = torch.randint(0, vocab_size, (bsz, seq_len), device=device)
+    grad_to_none = list(model.parameters())

-    def fwd():
-        return model(input_ids)
+    fwd_fn = lambda: model(input_ids)
+    fwd_loss_fn = lambda: fwd_fn().float().mean()
+    return grad_to_none, fwd_fn, fwd_loss_fn

-    def fwd_loss():
-        return fwd().float().mean()
-
-    grad_to_none = list(model.parameters())

+def bench_speed_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    grad_to_none, fwd_fn, fwd_loss_fn = _setup_mhc_lm(input)
+    mode = input.kernel_operation_mode

     if mode == "forward":
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100)
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100)
     elif mode == "backward":
-        loss = fwd_loss()
+        loss = fwd_loss_fn()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: loss.backward(retain_graph=True),
             quantiles=QUANTILES,
@@ -354,102 +315,202 @@ def fwd_loss():
     elif mode == "full":

         def full():
-            loss = fwd_loss()
+            loss = fwd_loss_fn()
             loss.backward()

         ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100)
     else:
-        raise ValueError(f"Unknown mode: {mode}")
-
-    return SingleBenchmarkRunOutput(
-        y_20=ms_20,
-        y_50=ms_50,
-        y_80=ms_80,
-    )
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


 def bench_memory_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    hidden_size = int(input.x)
-    provider = input.kernel_provider
-    extra = input.extra_benchmark_config
-    bsz = extra["B"]
-    seq_len = extra["T"]
-    hc = extra["HC"]
-    num_layers = extra["layers"]
-    num_heads = extra["heads"]
-    vocab_size = extra["vocab"]
-    dtype = extra["dtype"]
-    tmax = extra["tmax"]
-    intermediate_mult = extra["intermediate_mult"]
-
-    if hidden_size % num_heads != 0:
-        raise ValueError("hidden_size must be divisible by num_heads")
+    grad_to_none, fwd_fn, fwd_loss_fn = _setup_mhc_lm(input)

-    model = _build_model(
-        provider,
-        hidden_size=hidden_size,
-        hc=hc,
-        num_layers=num_layers,
-        num_heads=num_heads,
-        intermediate_mult=intermediate_mult,
-        vocab_size=vocab_size,
-        tmax=tmax,
-        dtype=dtype,
+    def full():
+        loss = fwd_loss_fn()
+        loss.backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
+
+
+def _resolve_model_config_mhc_lm(input: SingleBenchmarkRunInput):
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_mhc_lm(
+        SingleBenchmarkRunInput(
+            x=input.x,
+            kernel_provider=input.kernel_provider,
+            extra_benchmark_config={
+                "hidden_size": model_info["hidden_size"],
+                "dtype": model_info["dtype"],
+                "B": cfg["B"],
+                "T": cfg["T"],
+                "HC": cfg["HC"],
+                "layers": cfg["layers"],
+                "heads": cfg["heads"],
+                "vocab": cfg["vocab"],
+                "tmax": cfg["tmax"],
+                "intermediate_mult": cfg["intermediate_mult"],
+            },
+        )
     )

-    input_ids = torch.randint(0, vocab_size, (bsz, seq_len), device=device)

-    def fwd():
-        return model(input_ids)
+def bench_speed_mhc_lm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    grad_to_none, fwd_fn, fwd_loss_fn = _resolve_model_config_mhc_lm(input)
+    mode = input.kernel_operation_mode
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100)
+    elif mode == "backward":
+        loss = fwd_loss_fn()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: loss.backward(retain_graph=True),
+            quantiles=QUANTILES,
+            grad_to_none=grad_to_none,
+            rep=100,
+        )
+    elif mode == "full":
+
+        def full():
+            loss = fwd_loss_fn()
+            loss.backward()
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100)
+    else:
ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_mhc_lm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + grad_to_none, fwd_fn, fwd_loss_fn = _resolve_model_config_mhc_lm(input) def full(): - loss = fwd().float().mean() + loss = fwd_loss_fn() loss.backward() mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "mhc_llama_like_lm", - "x_name": "hidden_size", - "x_label": "hidden_size", - "x_values": [256, 512, 1024], - "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [ - { - "B": 2, - "T": 256, - "HC": 4, - "layers": 2, - "heads": 8, - "vocab": 4096, - "dtype": torch.bfloat16, - "tmax": 8, - "intermediate_mult": 4, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_mhc_lm, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_mhc_lm, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + mhc_lm_defaults = {"HC": 4, "layers": 2, "heads": 8, "vocab": 4096, "tmax": 8, "intermediate_mult": 4} + + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + B = 2 + T = 256 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "dtype": model_cfg.dtype, + "B": B, + "T": T, + **mhc_lm_defaults, + }, + ) + _, _, fwd_loss_fn = _setup_mhc_lm(probe_input) + return fwd_loss_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: {"hidden_size": cfg.hidden_size, "dtype": cfg.dtype} for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "mhc_llama_like_lm", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "B": B, + "T": T, + **mhc_lm_defaults, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_mhc_lm_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_mhc_lm_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + B = 2 + probe_T = 256 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "B": B, + "T": probe_T, + **mhc_lm_defaults, + }, + ) + _, _, fwd_loss_fn = _setup_mhc_lm(probe_input) + return fwd_loss_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_T + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + 
+        common_configs = {
+            "kernel_name": "mhc_llama_like_lm",
+            "x_name": "T",
+            "x_label": "sequence length",
+            "x_values": [2**i for i in range(7, int(math.log2(max(128, config.seq_len))) + 1)],
+            "kernel_providers": ["liger", "torch"],
+            "extra_benchmark_configs": [
+                {
+                    "hidden_size": model.hidden_size,
+                    "B": B,
+                    "dtype": model.dtype,
+                    **mhc_lm_defaults,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_mhc_lm,
+            kernel_operation_modes=["forward", "backward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_mhc_lm,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
diff --git a/benchmark/scripts/benchmark_model_configs.py b/benchmark/scripts/benchmark_model_configs.py
index c4b101366..bbd957bf4 100644
--- a/benchmark/scripts/benchmark_model_configs.py
+++ b/benchmark/scripts/benchmark_model_configs.py
@@ -117,9 +117,76 @@ class ModelConfigSweepConfig:
     max_position_embeddings=8192,
 )

+QWEN_2_5_7B = ModelConfig(
+    name="qwen2.5_7b",
+    hidden_size=3584,
+    intermediate_size=18944,
+    vocab_size=152064,
+    num_attention_heads=28,
+    num_key_value_heads=4,
+    head_dim=128,
+    hidden_act="silu",
+    max_position_embeddings=131072,
+)
+
+QWEN_2_5_14B = ModelConfig(
+    name="qwen2.5_14b",
+    hidden_size=5120,
+    intermediate_size=13824,
+    vocab_size=152064,
+    num_attention_heads=40,
+    num_key_value_heads=8,
+    head_dim=128,
+    hidden_act="silu",
+    max_position_embeddings=131072,
+)
+
+QWEN_2_5_72B = ModelConfig(
+    name="qwen2.5_72b",
+    hidden_size=8192,
+    intermediate_size=29568,
+    vocab_size=152064,
+    num_attention_heads=64,
+    num_key_value_heads=8,
+    head_dim=128,
+    hidden_act="silu",
+    max_position_embeddings=131072,
+)
+
+DEEPSEEK_V2_LITE = ModelConfig(
+    name="deepseek_v2_lite",
+    hidden_size=2048,
+    intermediate_size=10944,
+    vocab_size=102400,
+    num_attention_heads=16,
+    num_key_value_heads=16,
+    head_dim=128,
+    hidden_act="silu",
+    max_position_embeddings=163840,
+)
+
+DEEPSEEK_V3 = ModelConfig(
+    name="deepseek_v3",
+    hidden_size=7168,
+    intermediate_size=18432,
+    vocab_size=129280,
+    num_attention_heads=128,
+    num_key_value_heads=128,
+    head_dim=128,  # v_head_dim; MLA splits Q/K into nope(128) + rope(64) dims internally
+    # MLA-specific params for reference:
+    # qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128
+    hidden_act="silu",
+    max_position_embeddings=163840,
+)
+
 MODEL_REGISTRY: Dict[str, ModelConfig] = {
     "llama_2_7b": LLAMA_2_7B,
     "llama_3_8b": LLAMA_3_8B,
+    "qwen2.5_7b": QWEN_2_5_7B,
+    "qwen2.5_14b": QWEN_2_5_14B,
+    "qwen2.5_72b": QWEN_2_5_72B,
+    "deepseek_v2_lite": DEEPSEEK_V2_LITE,
+    "deepseek_v3": DEEPSEEK_V3,
 }

 DEFAULT_MODEL_CONFIG = LLAMA_3_8B
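Extending the sweep to another model is just one more `ModelConfig` plus a registry entry; every benchmark script picks it up through `MODEL_REGISTRY.values()`. For example, a Mistral-7B-v0.1-style entry (a hypothetical addition, with dimensions matching that model's published config) would be:

MISTRAL_7B = ModelConfig(
    name="mistral_7b",
    hidden_size=4096,
    intermediate_size=14336,
    vocab_size=32000,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=32768,
)

MODEL_REGISTRY["mistral_7b"] = MISTRAL_7B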
diff --git a/benchmark/scripts/benchmark_multi_token_attention.py b/benchmark/scripts/benchmark_multi_token_attention.py
index b5319af5c..4045cfecc 100644
--- a/benchmark/scripts/benchmark_multi_token_attention.py
+++ b/benchmark/scripts/benchmark_multi_token_attention.py
@@ -1,6 +1,15 @@
+import math
+import os
+import sys
+
 import torch
 import triton

+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
+from benchmark_model_configs import compute_seq_len_sweep_config
+from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import get_benchmark_model_config
 from utils import QUANTILES
 from utils import SingleBenchmarkRunInput
 from utils import SingleBenchmarkRunOutput
@@ -13,6 +22,8 @@
 device = infer_device()

+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+

 class TorchMultiTokenAttention(torch.nn.Module):
     def __init__(self, C_in, C_out, K, groups, bias, dtype, device):
@@ -35,23 +46,19 @@ def forward(self, scores):
         return out_c.masked_fill(~mask, zero)


-def bench_speed_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    L = input.x
-    provider = input.kernel_provider
-    mode = input.kernel_operation_mode
-
-    extra_benchmark_config = input.extra_benchmark_config
-    B = extra_benchmark_config["B"]
-    C_in = extra_benchmark_config["C_in"]
-    C_out = extra_benchmark_config["C_out"]
-    K = extra_benchmark_config["K"]
-    groups = extra_benchmark_config["groups"]
-    bias = extra_benchmark_config["bias"]
-    dtype = extra_benchmark_config["dtype"]
-
-    x_shape = (B, C_in, L, L)
+def _setup_multi_token_attention(input: SingleBenchmarkRunInput):
+    """Create input tensors and multi-token attention from benchmark config."""
+    cfg = input.extra_benchmark_config
+    C_in = cfg["C_in"]
+    C_out = cfg["C_out"]
+    K = cfg["K"]
+    groups = cfg["groups"]
+    bias = cfg["bias"]
+    dtype = cfg["dtype"]
+    B = cfg.get("B", 2)
+    L = cfg.get("L", input.x)

-    triton_attn = (
+    liger_attn = (
         LigerMultiTokenAttention(
             in_channels=C_in,
             out_channels=C_out,
@@ -67,35 +74,45 @@
     )

     torch_attn = TorchMultiTokenAttention(
-        C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device
+        C_in=C_in,
+        C_out=C_out,
+        K=K,
+        groups=groups,
+        bias=bias,
+        dtype=dtype,
+        device=device,
     )

     with torch.no_grad():
-        torch_attn.weight.copy_(triton_attn.weight)
+        torch_attn.weight.copy_(liger_attn.weight)
         if bias:
-            torch_attn.bias.copy_(triton_attn.bias)
+            torch_attn.bias.copy_(liger_attn.bias)

-    x = torch.randn(x_shape, dtype=dtype, device=device)
+    x = torch.randn(B, C_in, L, L, dtype=dtype, device=device, requires_grad=True)
     dy = torch.randn_like(x)
-    x.requires_grad_(True)
-
-    def fwd():
-        if provider == "liger":
-            return triton_attn(x)
-        elif provider == "torch":
-            return torch_attn(x)
-
-    print(f"Starting Warmup for input size: {x_shape}")
-    _ = fwd()
-    if mode in ("backward", "full"):
-        y = _
-        y.backward(dy, retain_graph=True)
-    print("Done Warmup")
+
+    if input.kernel_provider == "liger":
+        fwd_fn = lambda: liger_attn(x)
+    elif input.kernel_provider == "torch":
+        fwd_fn = lambda: torch_attn(x)
+    else:
+        raise ValueError(f"Invalid provider: {input.kernel_provider} for multi-token attention")
+
+    # Warmup: trigger Triton compilation/autotuning for forward and backward before timing
+    y_warm = fwd_fn()
+    y_warm.backward(dy, retain_graph=True)
+
+    return x, dy, fwd_fn
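The explicit warmup in `_setup_multi_token_attention` replaces the old print-bracketed warmup and matters because the first call triggers Triton compilation and autotuning; warming the backward graph here also keeps the memory benchmark from attributing one-time allocations to the measured region. Note that the backward now runs even when only the forward pass is benchmarked — a simplification of the shared setup path. The equivalent standalone pattern before timing any Triton kernel is roughly:

# Sketch: warm up a kernel before handing it to triton.testing.do_bench.
y = fwd_fn()                        # first call compiles + autotunes the forward kernel
y.backward(dy, retain_graph=True)   # compile the backward kernels too
torch.cuda.synchronize()            # CUDA-specific; ensures compilation work has finished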

+def bench_speed_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, dy, fwd_fn = _setup_multi_token_attention(input)
+    mode = input.kernel_operation_mode

     if mode == "forward":
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES)
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[x], rep=100, quantiles=QUANTILES)
     elif mode == "backward":
-        y = fwd()
+        y = fwd_fn()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: y.backward(dy, retain_graph=True),
             grad_to_none=[x],
@@ -105,114 +122,201 @@ def fwd():
     elif mode == "full":

         def full():
-            y = fwd()
+            y = fwd_fn()
             y.backward(dy, retain_graph=True)

         ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES)
-
-    return SingleBenchmarkRunOutput(
-        y_20=ms_20,
-        y_50=ms_50,
-        y_80=ms_80,
-    )
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


 def bench_memory_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    L = input.x
-    provider = input.kernel_provider
-
-    extra_benchmark_config = input.extra_benchmark_config
-    B = extra_benchmark_config["B"]
-    C_in = extra_benchmark_config["C_in"]
-    C_out = extra_benchmark_config["C_out"]
-    K = extra_benchmark_config["K"]
-    groups = extra_benchmark_config["groups"]
-    bias = extra_benchmark_config["bias"]
-    dtype = extra_benchmark_config["dtype"]
+    x, dy, fwd_fn = _setup_multi_token_attention(input)

-    x_shape = (B, C_in, L, L)
+    def full():
+        y = fwd_fn()
+        y.backward(dy, retain_graph=True)

-    triton_attn = (
-        LigerMultiTokenAttention(
-            in_channels=C_in,
-            out_channels=C_out,
-            kernel_size=K,
-            stride=1,
-            padding=K // 2,
-            dilation=1,
-            groups=groups,
-            bias=bias,
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
+
+
+def _resolve_model_config_multi_token_attention(input: SingleBenchmarkRunInput):
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_multi_token_attention(
+        SingleBenchmarkRunInput(
+            x=input.x,
+            kernel_provider=input.kernel_provider,
+            extra_benchmark_config={
+                "C_in": cfg["C_in"],
+                "C_out": cfg["C_out"],
+                "K": cfg["K"],
+                "groups": cfg["groups"],
+                "bias": cfg["bias"],
+                "dtype": model_info["dtype"],
+                "B": cfg["B"],
+                "L": cfg["L"],
+            },
         )
-        .to(device)
-        .to(dtype)
     )

-    torch_attn = TorchMultiTokenAttention(
-        C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device
-    )

-    with torch.no_grad():
-        torch_attn.weight.copy_(triton_attn.weight)
-        if bias:
-            torch_attn.bias.copy_(triton_attn.bias)
+def bench_speed_multi_token_attention_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, dy, fwd_fn = _resolve_model_config_multi_token_attention(input)
+    mode = input.kernel_operation_mode

-    x = torch.randn(x_shape, dtype=dtype, device=device)
-    dy = torch.randn_like(x)
-    x.requires_grad_(True)
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[x], rep=100, quantiles=QUANTILES)
+    elif mode == "backward":
+        y = fwd_fn()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(dy, retain_graph=True),
+            grad_to_none=[x],
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
+
+        def full():
+            y = fwd_fn()
+            y.backward(dy, retain_graph=True)
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)

-    def fwd():
-        if provider == "liger":
-            return triton_attn(x)
-        elif provider == "torch":
-            return torch_attn(x)
+
+def bench_memory_multi_token_attention_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, dy, fwd_fn = _resolve_model_config_multi_token_attention(input)

     def full():
-        y = fwd()
+        y = fwd_fn()
         y.backward(dy, retain_graph=True)

     mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
-
-    return SingleBenchmarkRunOutput(
-        y_20=mem_20,
-        y_50=mem_50,
-        y_80=mem_80,
-    )
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)


 if __name__ == "__main__":
     args = parse_benchmark_script_args()

-    common_configs = {
"kernel_name": "multi_token_attention", - "x_name": "L", - "x_label": "sequence length", - "x_values": [2**i for i in range(5, 10)], - "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [ - { - "B": 2, - "C_in": 4, - "C_out": 4, - "K": 3, - "groups": 1, - "bias": True, - "dtype": torch.bfloat16, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_multi_token_attention, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_multi_token_attention, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + L = 256 + B = 2 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "C_in": 4, + "C_out": 4, + "K": 3, + "groups": 1, + "bias": True, + "dtype": model_cfg.dtype, + "B": B, + "L": L, + }, + ) + _, _, fwd_fn = _setup_multi_token_attention(probe_input) + return fwd_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = {cfg.name: {"dtype": cfg.dtype} for cfg in sweep.model_configs} + + common_configs = { + "kernel_name": "multi_token_attention", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "C_in": 4, + "C_out": 4, + "K": 3, + "groups": 1, + "bias": True, + "B": B, + "L": L, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_multi_token_attention_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_multi_token_attention_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + B = 2 + probe_L = 256 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "C_in": 4, + "C_out": 4, + "K": 3, + "groups": 1, + "bias": True, + "dtype": model.dtype, + "B": B, + "L": probe_L, + }, + ) + _, _, fwd_fn = _setup_multi_token_attention(probe_input) + return fwd_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_L + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "multi_token_attention", + "x_name": "L", + "x_label": "sequence length", + "x_values": [2**i for i in range(5, int(math.log2(max(32, config.seq_len))) + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + {"C_in": 4, "C_out": 4, "K": 3, "groups": 1, "bias": True, "dtype": model.dtype, "B": B} + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_multi_token_attention, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_multi_token_attention, + kernel_operation_modes=["full"], + 
metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_orpo_loss.py b/benchmark/scripts/benchmark_orpo_loss.py index 30b308c42..afbf47487 100644 --- a/benchmark/scripts/benchmark_orpo_loss.py +++ b/benchmark/scripts/benchmark_orpo_loss.py @@ -1,9 +1,15 @@ +import math import os import sys import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -18,152 +24,221 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -############################################################################# -# Test the memory consumption of the linear fused cross entropy loss -############################################################################# - - -def bench_memory_fused_linear_orpo_loss( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: +def _setup_orpo_loss(input: SingleBenchmarkRunInput): + """Create input tensors and ORPO loss from benchmark config.""" from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO + cfg = input.extra_benchmark_config + H = cfg["hidden_size"] + V = cfg["vocab_size"] + dtype = cfg["dtype"] B = input.x - T = input.extra_benchmark_config["T"] - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - provider = input.kernel_provider - - # Instantiate once and retrieve the first output only - torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) - liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) - torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0] - liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0] + T = cfg["T"] _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device) - def fwd(): - if provider == "liger": - return liger_fwd(_input, target, nll_target) - elif provider == "huggingface": - return torch_fwd(_input, target, nll_target) + if input.kernel_provider == "liger": + loss_module = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) + elif input.kernel_provider == "huggingface": + loss_module = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for ORPOLoss") - def full(): - y = fwd() - y.backward() + fwd_fn = lambda: loss_module(_input, target, nll_target)[0] + return _input, fwd_fn - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) +def bench_speed_orpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _setup_orpo_loss(input) + mode = input.kernel_operation_mode + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd_fn() + ms_50, ms_20, ms_80 = 
triton.testing.do_bench( + lambda: y.backward(retain_graph=True), grad_to_none=[_input], rep=100, quantiles=QUANTILES + ) + elif mode == "full": -# ############################################################################# -# # Test the speed of the fused linear cross entropy loss -# ############################################################################# + def full(): + y = fwd_fn() + y.backward() + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) -def bench_speed_fused_linear_orpo_loss( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO - from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO - B = input.x - T = input.extra_benchmark_config["T"] - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - provider = input.kernel_provider - mode = input.kernel_operation_mode +def bench_memory_orpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _setup_orpo_loss(input) - # Instantiate once and retrieve the first output only - torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) - liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) - torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0] - liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0] + def full(): + y = fwd_fn() + y.backward() - _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) - target = torch.randint(V, (B, T), dtype=torch.long, device=device) - nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device) + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_orpo_loss(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_orpo_loss( + SingleBenchmarkRunInput( + x=cfg["B"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "vocab_size": model_info["vocab_size"], + "dtype": model_info["dtype"], + "T": cfg["T"], + }, + ) + ) - def fwd(): - if provider == "liger": - return liger_fwd(_input, target, nll_target) - elif provider == "huggingface": - return torch_fwd(_input, target, nll_target) + +def bench_speed_orpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _resolve_model_config_orpo_loss(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - rep=100, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES) elif mode == "backward": - y = fwd() - + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(retain_graph=True), - grad_to_none=[_input], - rep=100, - quantiles=QUANTILES, + lambda: y.backward(retain_graph=True), grad_to_none=[_input], rep=100, quantiles=QUANTILES ) elif mode == "full": def full(): - y = fwd() + y = fwd_fn() y.backward() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - rep=100, - quantiles=QUANTILES, - ) - return SingleBenchmarkRunOutput( - y_20=ms_20, 
- y_50=ms_50, - y_80=ms_80, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_orpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _resolve_model_config_orpo_loss(input) + + def full(): + y = fwd_fn() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "fused_linear_orpo_loss", - "x_name": "B", - "x_label": "B", - "x_values": [2**i for i in range(1, 5)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "T": 1024, - "H": 4096, - "V": 128256, - "mode": "forward", - "dtype": torch.bfloat16, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_fused_linear_orpo_loss, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_fused_linear_orpo_loss, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + T = 1024 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + B = max(1, probe_bt // T) + probe_input = SingleBenchmarkRunInput( + x=B, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "vocab_size": model_cfg.vocab_size, + "dtype": model_cfg.dtype, + "T": T, + }, + ) + _, fwd_fn = _setup_orpo_loss(probe_input) + return fwd_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: {"hidden_size": cfg.hidden_size, "vocab_size": cfg.vocab_size, "dtype": cfg.dtype} + for cfg in sweep.model_configs + } + B = max(1, sweep.bt // T) + + common_configs = { + "kernel_name": "fused_linear_orpo_loss", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [{"model_configs": model_configs_info, "B": B, "T": T}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_orpo_loss_model_config, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_orpo_loss_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + T = 1024 + probe_bt = 1024 + + def _probe(): + B = max(1, probe_bt // T) + probe_input = SingleBenchmarkRunInput( + x=B, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "vocab_size": model.vocab_size, + "dtype": model.dtype, + "T": T, + }, + ) + _, fwd_fn = _setup_orpo_loss(probe_input) + return fwd_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + 
"kernel_name": "fused_linear_orpo_loss", + "x_name": "B", + "x_label": "Batch Size (B)", + "x_values": [2**i for i in range(1, int(math.log2(max(2, config.batch_size * config.seq_len // T))) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + {"hidden_size": model.hidden_size, "vocab_size": model.vocab_size, "dtype": model.dtype, "T": T} + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_orpo_loss, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_orpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_poly_norm.py b/benchmark/scripts/benchmark_poly_norm.py index ddff431d7..460a8daf9 100644 --- a/benchmark/scripts/benchmark_poly_norm.py +++ b/benchmark/scripts/benchmark_poly_norm.py @@ -1,13 +1,19 @@ +import math + import torch import torch.nn as nn -import triton -from utils import QUANTILES +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput -from utils import _test_memory from utils import parse_benchmark_script_args from utils import run_benchmarks +from utils import run_memory_benchmark +from utils import run_speed_benchmark from liger_kernel.transformers.poly_norm import LigerPolyNorm from liger_kernel.utils import infer_device @@ -39,159 +45,191 @@ def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) def forward(self, hidden_states): - """ - Forward pass of PolyNorm - - Args: - hidden_states: input tensor of shape (..., H) - - Returns: - output tensor of same shape as input - """ input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) - # Compute powers x_pow3 = hidden_states**3 x_pow2 = hidden_states**2 x_pow1 = hidden_states**1 - # Normalize each power norm_x3 = self._norm(x_pow3) norm_x2 = self._norm(x_pow2) norm_x1 = self._norm(x_pow1) - # Weighted sum with bias output = self.weight[0] * norm_x3 + self.weight[1] * norm_x2 + self.weight[2] * norm_x1 + self.bias return output.to(input_dtype) -def bench_speed_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - N = input.x - provider = input.kernel_provider - mode = input.kernel_operation_mode - - extra_benchmark_config = input.extra_benchmark_config - M = extra_benchmark_config["M"] - eps = extra_benchmark_config["eps"] - dtype = extra_benchmark_config["dtype"] - - x_shape = (M, N) - - triton_poly = LigerPolyNorm(eps=eps).to(device) - naive_poly = NaivePolyNorm(eps=eps).to(device) - - x = torch.randn(x_shape, dtype=dtype, device=device) - dy = torch.randn_like(x) - x.requires_grad_(True) - - # utility functions - - def y_fwd(): - if provider == "liger": - return triton_poly(x) +def _setup_poly_norm(input: SingleBenchmarkRunInput): + """Create input tensor and PolyNorm layer from benchmark config.""" + cfg = input.extra_benchmark_config + hidden_size = cfg["hidden_size"] + eps = cfg["eps"] + x = torch.randn( + input.x, + hidden_size, + device=device, + dtype=cfg["dtype"], + requires_grad=True, + ) + if 
input.kernel_provider == "liger": + layer = LigerPolyNorm(eps=eps).to(device) + elif input.kernel_provider == "huggingface": + layer = NaivePolyNorm(eps=eps).to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for PolyNorm") + return x, layer - if provider == "huggingface": - return naive_poly(x) - if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - y_fwd, - grad_to_none=[x], - rep=500, - quantiles=QUANTILES, - ) - elif mode == "backward": - y = y_fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(dy, retain_graph=True), - grad_to_none=[x], - rep=500, - quantiles=QUANTILES, - ) - elif mode == "full": +def bench_speed_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _setup_poly_norm(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - def full(): - y = y_fwd() - y.backward(dy, retain_graph=True) - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - grad_to_none=[x], - rep=500, - quantiles=QUANTILES, +def bench_memory_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _setup_poly_norm(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) + + +def _resolve_model_config_poly_norm(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_poly_norm( + SingleBenchmarkRunInput( + x=cfg["BT"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "dtype": model_info["dtype"], + "eps": cfg["eps"], + }, ) - - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, ) -def bench_memory_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - N = input.x - provider = input.kernel_provider - - extra_benchmark_config = input.extra_benchmark_config - M = extra_benchmark_config["M"] - eps = extra_benchmark_config["eps"] - dtype = extra_benchmark_config["dtype"] - - x_shape = (M, N) +def bench_speed_poly_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_poly_norm(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - triton_poly = LigerPolyNorm(eps=eps).to(device) - naive_poly = NaivePolyNorm(eps=eps).to(device) - x = torch.randn(x_shape, dtype=dtype, device=device) - dy = torch.randn_like(x) - x.requires_grad_(True) - - # utility functions - def y_fwd(): - if provider == "liger": - return triton_poly(x) - if provider == "huggingface": - return naive_poly(x) - - def full(): - y = y_fwd() - y.backward(dy, retain_graph=True) - - mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) - - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) +def bench_memory_poly_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_poly_norm(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "poly_norm", - "x_name": "H", - "x_label": "hidden size", - "x_values": [2**i for i in range(10, 16)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}], - "overwrite": args.overwrite, - } - - 
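
Aside: every script touched by this PR follows the same two-path shape, a _setup_* helper that builds tensors from a plain per-run input, and a _resolve_model_config_* shim that translates a model-config sweep input (where x is a registry name rather than a size) back into that plain input. Below is a minimal, self-contained sketch of the pattern; RunInput, _setup, and _resolve_model_config are illustrative stand-ins, not the suite's actual classes.

    from dataclasses import dataclass, field

    @dataclass
    class RunInput:
        x: object                      # a size for standard sweeps, a model name for config sweeps
        kernel_provider: str
        extra_benchmark_config: dict = field(default_factory=dict)

    def _setup(inp: RunInput):
        # Standard path: x is the workload size, extra config carries per-model constants.
        return inp.x, inp.extra_benchmark_config["hidden_size"], inp.kernel_provider

    def _resolve_model_config(inp: RunInput):
        # Config-sweep path: x keys into the registry snapshot; rebuild a plain
        # RunInput so the existing _setup stays the single source of truth.
        cfg = inp.extra_benchmark_config
        model_info = cfg["model_configs"][inp.x]
        return _setup(RunInput(
            x=cfg["BT"],               # fixed workload size shared by every config
            kernel_provider=inp.kernel_provider,
            extra_benchmark_config={"hidden_size": model_info["hidden_size"]},
        ))

    inp = RunInput(
        x="llama3-8b",
        kernel_provider="liger",
        extra_benchmark_config={"model_configs": {"llama3-8b": {"hidden_size": 4096}}, "BT": 2048},
    )
    print(_resolve_model_config(inp))  # -> (2048, 4096, 'liger')

Keeping a single _setup_* per kernel means both sweep modes construct identical tensors, so the resulting charts stay directly comparable.
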
run_benchmarks( - bench_test_fn=bench_speed_poly_norm, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_poly_norm, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "dtype": model_cfg.dtype, + "eps": 1e-6, + }, + ) + x, layer = _setup_poly_norm(probe_input) + return layer(x) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "dtype": cfg.dtype, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "poly_norm", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "BT": sweep.bt, + "eps": 1e-6, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_poly_norm_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_poly_norm_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 1024 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "eps": 1e-6, + }, + ) + x, layer = _setup_poly_norm(probe_input) + return layer(x) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "poly_norm", + "x_name": "BT", + "x_label": "B * T", + "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "eps": 1e-6, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_poly_norm, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_poly_norm, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_qwen2vl_mrope.py b/benchmark/scripts/benchmark_qwen2vl_mrope.py index ec1c53b89..ef92bbad2 100644 --- a/benchmark/scripts/benchmark_qwen2vl_mrope.py +++ b/benchmark/scripts/benchmark_qwen2vl_mrope.py @@ -1,6 +1,15 @@ +import math +import os +import sys + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from 
benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLTextConfig from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb @@ -16,29 +25,22 @@ device = infer_device() +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -def bench_speed_qwen2vl_mrope( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: - provider = input.kernel_provider - mode = input.kernel_operation_mode - - extra_benchmark_config = input.extra_benchmark_config - num_q_heads = extra_benchmark_config["num_q_heads"] - num_kv_heads = extra_benchmark_config["num_kv_heads"] - dtype = extra_benchmark_config["dtype"] - # x can be either hidden_size or seq_len - hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x - seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x +def _setup_qwen2vl_mrope(input: SingleBenchmarkRunInput): + """Create input tensors and Qwen2VL M-RoPE embedding from benchmark config.""" + cfg = input.extra_benchmark_config + num_q_heads = cfg["num_q_heads"] + num_kv_heads = cfg["num_kv_heads"] + dtype = cfg["dtype"] + hidden_size = cfg.get("hidden_size", input.x) + seq_len = cfg.get("seq_len", input.x) head_dim = hidden_size // num_q_heads mrope_section_hw = head_dim * 3 // 16 - mrope_section = [ - head_dim // 2 - 2 * mrope_section_hw, - mrope_section_hw, - mrope_section_hw, - ] + mrope_section = [head_dim // 2 - 2 * mrope_section_hw, mrope_section_hw, mrope_section_hw] + config = Qwen2VLTextConfig( hidden_size=hidden_size, num_attention_heads=num_q_heads, @@ -59,30 +61,28 @@ def bench_speed_qwen2vl_mrope( requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = ( - torch.randn_like(q, device=device, dtype=dtype), - torch.randn_like(k, device=device, dtype=dtype), - ) + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(k, device=device, dtype=dtype) pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k, pos_ids) - def fwd(): - if provider == "liger": - return liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) - elif provider == "huggingface": - return apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) - else: - raise ValueError(f"Invalid provider: {provider} for M-RoPE embedding") + if input.kernel_provider == "liger": + fwd_fn = lambda: liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) + elif input.kernel_provider == "huggingface": + fwd_fn = lambda: apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for M-RoPE embedding") + + return q, k, dq, dk, fwd_fn + + +def bench_speed_qwen2vl_mrope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + q, k, dq, dk, fwd_fn = _setup_qwen2vl_mrope(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - grad_to_none=[q, k], - rep=400, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[q, k], rep=400, quantiles=QUANTILES) elif mode == "backward": - q_out, k_out = fwd() + q_out, k_out = fwd_fn() 
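
Aside: the ms_50, ms_20, ms_80 unpacking used throughout these scripts relies on triton.testing.do_bench returning one latency per requested quantile, in the order given. QUANTILES comes from utils and is not shown in this diff; the [0.5, 0.2, 0.8] value below is an inference from the unpacking order, so treat it as an assumption. Minimal standalone sketch (requires a CUDA device):

    import torch
    import triton.testing

    QUANTILES = [0.5, 0.2, 0.8]  # median first, then 20th/80th percentiles (assumed)

    x = torch.randn(1024, 1024, device="cuda")

    def fwd_fn():
        return x @ x

    # do_bench returns a list with one entry per quantile, in the order requested,
    # which is why the median lands in the first slot of the unpacking.
    ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES)
    print(f"p50={ms_50:.3f} ms  p20={ms_20:.3f} ms  p80={ms_80:.3f} ms")
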
ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True), grad_to_none=[q, k], @@ -92,150 +92,192 @@ def fwd(): elif mode == "full": def full(): - q_out, k_out = fwd() + q_out, k_out = fwd_fn() torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[q, k], rep=400, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_qwen2vl_mrope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + q, k, dq, dk, fwd_fn = _setup_qwen2vl_mrope(input) + + def full(): + q_out, k_out = fwd_fn() + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_qwen2vl_mrope(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_qwen2vl_mrope( + SingleBenchmarkRunInput( + x=input.x, + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "num_q_heads": model_info["num_q_heads"], + "num_kv_heads": model_info["num_kv_heads"], + "dtype": model_info["dtype"], + "seq_len": cfg["seq_len"], + }, + ) + ) + + +def bench_speed_qwen2vl_mrope_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + q, k, dq, dk, fwd_fn = _resolve_model_config_qwen2vl_mrope(input) + mode = input.kernel_operation_mode + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[q, k], rep=400, quantiles=QUANTILES) + elif mode == "backward": + q_out, k_out = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, + lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True), grad_to_none=[q, k], rep=400, quantiles=QUANTILES, ) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) - + elif mode == "full": -def bench_memory_qwen2vl_mrope( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: - provider = input.kernel_provider + def full(): + q_out, k_out = fwd_fn() + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) - extra_benchmark_config = input.extra_benchmark_config - num_q_heads = extra_benchmark_config["num_q_heads"] - num_kv_heads = extra_benchmark_config["num_kv_heads"] - dtype = extra_benchmark_config["dtype"] + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[q, k], rep=400, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) - # x can be either hidden_size or seq_len - hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x - seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x - head_dim = hidden_size // num_q_heads - - mrope_section_hw = head_dim * 3 // 16 - mrope_section = [ - head_dim // 2 - 2 * mrope_section_hw, - mrope_section_hw, - mrope_section_hw, - ] - config = Qwen2VLTextConfig( - hidden_size=hidden_size, - num_attention_heads=num_q_heads, - num_key_value_heads=num_kv_heads, - rope_theta=1000000.0, - mrope_section=mrope_section, - ) - 
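
Aside: the mrope_section arithmetic in _setup_qwen2vl_mrope splits the head_dim // 2 rotary pairs into temporal, height, and width sections, with H and W each taking 3/16 of head_dim. The split is taken straight from the code above; the temporal/height/width reading follows Qwen2-VL's multimodal RoPE, where the three sections must sum to head_dim // 2. Worked example for head_dim = 128:

    head_dim = 128
    mrope_section_hw = head_dim * 3 // 16       # 24 rotary pairs each for H and W
    mrope_section = [
        head_dim // 2 - 2 * mrope_section_hw,   # 64 - 48 = 16 temporal pairs
        mrope_section_hw,
        mrope_section_hw,
    ]
    assert sum(mrope_section) == head_dim // 2  # 16 + 24 + 24 == 64
    print(mrope_section)                        # [16, 24, 24]
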
rotary_emb = Qwen2VLRotaryEmbedding(config, device=device) - q = torch.randn( - (1, seq_len, num_q_heads, head_dim), - device=device, - requires_grad=True, - dtype=dtype, - ).transpose(1, 2) - k = torch.randn( - (1, seq_len, num_kv_heads, head_dim), - device=device, - requires_grad=True, - dtype=dtype, - ).transpose(1, 2) - dq, dk = ( - torch.randn_like(q, device=device, dtype=dtype), - torch.randn_like(k, device=device, dtype=dtype), - ) - pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) - cos, sin = rotary_emb(k, pos_ids) +def bench_memory_qwen2vl_mrope_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + q, k, dq, dk, fwd_fn = _resolve_model_config_qwen2vl_mrope(input) def full(): - if provider == "liger": - q_out, k_out = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) - else: - q_out, k_out = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) + q_out, k_out = fwd_fn() torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True) - mem_50, mem_20, mem_80 = _test_memory( - full, - quantiles=QUANTILES, - ) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs_varying_hidden_size = { - "kernel_name": "qwen2vl_mrope", - "x_name": "H", - "x_label": "hidden size", - "x_values": [32 * (2**i) for i in range(4, 10, 2)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "dtype": torch.bfloat16, - "seq_len": 2048, - "num_q_heads": 32, - "num_kv_heads": 8, - } - ], - "overwrite": args.overwrite, - } - run_benchmarks( - bench_test_fn=bench_speed_qwen2vl_mrope, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs_varying_hidden_size, - ) - run_benchmarks( - bench_test_fn=bench_memory_qwen2vl_mrope, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs_varying_hidden_size, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + seq_len = 2048 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "num_q_heads": model_cfg.num_attention_heads, + "num_kv_heads": model_cfg.num_key_value_heads, + "dtype": model_cfg.dtype, + "seq_len": seq_len, + }, + ) + _, _, _, _, fwd_fn = _setup_qwen2vl_mrope(probe_input) + return fwd_fn()[0] + + return _probe - common_configs_varying_seq_len = { - "kernel_name": "qwen2vl_mrope", - "x_name": "T", - "x_label": "sequence length", - "x_values": [2**i for i in range(10, 15)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "dtype": torch.bfloat16, - "hidden_size": 8192, - "num_q_heads": 32, - "num_kv_heads": 8, + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "num_q_heads": cfg.num_attention_heads, + "num_kv_heads": cfg.num_key_value_heads, + "dtype": cfg.dtype, } - ], - "overwrite": args.overwrite, - } - run_benchmarks( - bench_test_fn=bench_speed_qwen2vl_mrope, - kernel_operation_modes=["forward", 
"backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs_varying_seq_len, - ) - run_benchmarks( - bench_test_fn=bench_memory_qwen2vl_mrope, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs_varying_seq_len, - ) + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "qwen2vl_mrope", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [{"model_configs": model_configs_info, "seq_len": seq_len}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_qwen2vl_mrope_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_qwen2vl_mrope_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_seq_len = 2048 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "num_q_heads": model.num_attention_heads, + "num_kv_heads": model.num_key_value_heads, + "dtype": model.dtype, + "seq_len": probe_seq_len, + }, + ) + _, _, _, _, fwd_fn = _setup_qwen2vl_mrope(probe_input) + return fwd_fn()[0] + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_seq_len + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "qwen2vl_mrope", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, int(math.log2(max(1024, config.seq_len))) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "num_q_heads": model.num_attention_heads, + "num_kv_heads": model.num_key_value_heads, + "dtype": model.dtype, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_qwen2vl_mrope, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_qwen2vl_mrope, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_relu_squared.py b/benchmark/scripts/benchmark_relu_squared.py index 401aea3f9..bb1f0ed12 100644 --- a/benchmark/scripts/benchmark_relu_squared.py +++ b/benchmark/scripts/benchmark_relu_squared.py @@ -1,6 +1,15 @@ +import math +import os +import sys + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -13,6 +22,8 @@ device = infer_device() +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + class TorchReLUSquared(torch.nn.Module): def forward(self, x): @@ -20,32 +31,35 @@ def forward(self, x): return torch.square(relu_applied) 
-def bench_speed_relu_squared(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - N = input.x - provider = input.kernel_provider - mode = input.kernel_operation_mode - extra_benchmark_config = input.extra_benchmark_config - M = extra_benchmark_config["M"] - dtype = extra_benchmark_config["dtype"] +def _setup_relu_squared(input: SingleBenchmarkRunInput): + """Create input tensors and relu_squared module from benchmark config.""" + cfg = input.extra_benchmark_config + hidden_size = cfg["hidden_size"] + M = cfg.get("M", input.x) + dtype = cfg["dtype"] - x_shape = (M, N) - liger_relu_squared = LigerReLUSquared().to(device) - torch_relu_squared = TorchReLUSquared().to(device) - - x = torch.randn(x_shape, dtype=dtype, device=device) + x = torch.randn(M, hidden_size, dtype=dtype, device=device, requires_grad=True) dy = torch.randn_like(x) - x.requires_grad_(True) - def y_fwd(): - if provider == "liger": - return liger_relu_squared(x) - if provider == "torch": - return torch_relu_squared(x) + if input.kernel_provider == "liger": + relu_sq = LigerReLUSquared().to(device) + elif input.kernel_provider == "torch": + relu_sq = TorchReLUSquared().to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for relu_squared") + + fwd_fn = lambda: relu_sq(x) + return x, dy, fwd_fn + + +def bench_speed_relu_squared(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _setup_relu_squared(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, quantiles=QUANTILES, grad_to_none=[x], rep=500) elif mode == "backward": - y = y_fwd() + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(dy, retain_graph=True), quantiles=QUANTILES, @@ -55,91 +69,173 @@ def y_fwd(): elif mode == "full": def full(): - y = y_fwd() + y = fwd_fn() y.backward(dy, retain_graph=True) ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500) - - if any(val is None for val in (ms_20, ms_50, ms_80)): - raise RuntimeError(f"Benchmark speed result is None: ms_20={ms_20}, ms_50={ms_50}, ms_80={ms_80}") - - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) def bench_memory_relu_squared(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - shape = input.x - provider = input.kernel_provider - mode = input.kernel_operation_mode - extra_benchmark_config = input.extra_benchmark_config - dtype = extra_benchmark_config.get("dtype", torch.float32) + x, dy, fwd_fn = _setup_relu_squared(input) - torch_relu_squared = TorchReLUSquared() - liger_relu_squared = LigerReLUSquared().to(device) + def full(): + y = fwd_fn() + y.backward(torch.ones_like(y), retain_graph=True) - x = torch.randn(shape, device=device, dtype=dtype, requires_grad=True) + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_relu_squared(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_relu_squared( + SingleBenchmarkRunInput( + x=input.x, + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": 
model_info["hidden_size"], + "dtype": model_info["dtype"], + "M": cfg["M"], + }, + ) + ) - def fwd(): - if provider == "liger": - return liger_relu_squared(x) - elif provider == "torch": - return torch_relu_squared(x) - else: - raise ValueError(f"Invalid provider: {provider} for relu_squared") - def full(): - y = fwd() - y.backward(torch.ones_like(y), retain_graph=True) +def bench_speed_relu_squared_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _resolve_model_config_relu_squared(input) + mode = input.kernel_operation_mode if mode == "forward": - mem_50, mem_20, mem_80 = _test_memory(fwd, quantiles=QUANTILES) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, quantiles=QUANTILES, grad_to_none=[x], rep=500) elif mode == "backward": - do = torch.ones_like(x) - y = fwd() - mem_50, mem_20, mem_80 = _test_memory(lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES) + y = fwd_fn() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[x], + rep=500, + ) + elif mode == "full": + + def full(): + y = fwd_fn() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500) else: - mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) - if any(val is None for val in (mem_20, mem_50, mem_80)): - raise RuntimeError(f"Benchmark memory result is None: mem_20={mem_20}, mem_50={mem_50}, mem_80={mem_80}") - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) +def bench_memory_relu_squared_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _resolve_model_config_relu_squared(input) + + def full(): + y = fwd_fn() + y.backward(torch.ones_like(y), retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = dict( - kernel_name="relu_squared", - x_name="N", - x_label="hidden size", - x_values=[128, 256, 512, 1024, 2048, 4096, 8192, 16384], - kernel_providers=["liger", "torch"], - extra_benchmark_configs=[ - {"M": 4096, "dtype": torch.bfloat16}, - ], - ) - - run_benchmarks( - bench_test_fn=bench_speed_relu_squared, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - overwrite=args.overwrite, - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_relu_squared, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - overwrite=args.overwrite, - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + M = 2048 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "dtype": model_cfg.dtype, + "M": M, + }, + ) + _, _, fwd_fn = _setup_relu_squared(probe_input) + return fwd_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: {"hidden_size": cfg.hidden_size, "dtype": cfg.dtype} for cfg in sweep.model_configs + } + + 
common_configs = { + "kernel_name": "relu_squared", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"model_configs": model_configs_info, "M": M}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_relu_squared_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_relu_squared_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 2048 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "M": probe_bt, + }, + ) + _, _, fwd_fn = _setup_relu_squared(probe_input) + return fwd_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "relu_squared", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, int(math.log2(max(1024, config.batch_size * config.seq_len))) + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"hidden_size": model.hidden_size, "dtype": model.dtype}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_relu_squared, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_relu_squared, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_rms_norm.py b/benchmark/scripts/benchmark_rms_norm.py index 6bcd56a83..a0bccad34 100644 --- a/benchmark/scripts/benchmark_rms_norm.py +++ b/benchmark/scripts/benchmark_rms_norm.py @@ -1,13 +1,19 @@ +import math + import torch import torch.nn as nn -import triton -from utils import QUANTILES +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput -from utils import _test_memory from utils import parse_benchmark_script_args from utils import run_benchmarks +from utils import run_memory_benchmark +from utils import run_speed_benchmark from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.utils import infer_device @@ -32,131 +38,175 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - N = input.x - provider = input.kernel_provider - mode = input.kernel_operation_mode - - extra_benchmark_config = input.extra_benchmark_config - M = extra_benchmark_config["M"] - eps = extra_benchmark_config["eps"] - dtype = extra_benchmark_config["dtype"] - - x_shape = (M, N) - - triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device) - llama_rms = 
LlamaRMSNorm(hidden_size=N, eps=eps).to(device) - - x = torch.randn(x_shape, dtype=dtype, device=device) - dy = torch.randn_like(x) - x.requires_grad_(True) - - # utility functions - - def y_fwd(): - if provider == "liger": - return triton_rms(x) +def _setup_rms_norm(input: SingleBenchmarkRunInput): + """Create input tensor and RMSNorm layer from benchmark config.""" + cfg = input.extra_benchmark_config + hidden_size = cfg["hidden_size"] + eps = cfg["eps"] + x = torch.randn( + input.x, + hidden_size, + device=device, + dtype=cfg["dtype"], + requires_grad=True, + ) + if input.kernel_provider == "liger": + layer = LigerRMSNorm(hidden_size=hidden_size, eps=eps).to(device) + elif input.kernel_provider == "huggingface": + layer = LlamaRMSNorm(hidden_size=hidden_size, eps=eps).to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for RMSNorm") + return x, layer - if provider == "huggingface": - return llama_rms(x) - if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - y_fwd, - grad_to_none=[x], - rep=500, - quantiles=QUANTILES, - ) - elif mode == "backward": - y = y_fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(dy, retain_graph=True), - grad_to_none=[x], - rep=500, - quantiles=QUANTILES, - ) - elif mode == "full": +def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _setup_rms_norm(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - def full(): - y = y_fwd() - y.backward(dy, retain_graph=True) - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - grad_to_none=[x], - rep=500, - quantiles=QUANTILES, +def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _setup_rms_norm(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) + + +def _resolve_model_config_rms_norm(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_rms_norm( + SingleBenchmarkRunInput( + x=cfg["BT"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "dtype": model_info["dtype"], + "eps": cfg["eps"], + }, ) - - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, ) -def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - N = input.x - provider = input.kernel_provider - - extra_benchmark_config = input.extra_benchmark_config - M = extra_benchmark_config["M"] - eps = extra_benchmark_config["eps"] - dtype = extra_benchmark_config["dtype"] - - x_shape = (M, N) - - triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device) - llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device) - - x = torch.randn(x_shape, dtype=dtype, device=device) - dy = torch.randn_like(x) - x.requires_grad_(True) - - # utility functions - def y_fwd(): - if provider == "liger": - return triton_rms(x) - if provider == "huggingface": - return llama_rms(x) - - def full(): - y = y_fwd() - y.backward(dy, retain_graph=True) +def bench_speed_rms_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_rms_norm(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - 
y_50=mem_50, - y_80=mem_80, - ) +def bench_memory_rms_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_rms_norm(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "rms_norm", - "x_name": "H", - "x_label": "hidden size", - "x_values": [2**i for i in range(10, 16)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_rms_norm, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_rms_norm, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "dtype": model_cfg.dtype, + "eps": 1e-6, + }, + ) + x, layer = _setup_rms_norm(probe_input) + return layer(x) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "dtype": cfg.dtype, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "rms_norm", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "BT": sweep.bt, + "eps": 1e-6, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_rms_norm_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_rms_norm_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 1024 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "eps": 1e-6, + }, + ) + x, layer = _setup_rms_norm(probe_input) + return layer(x) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "rms_norm", + "x_name": "BT", + "x_label": "B * T", + "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "eps": 1e-6, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_rms_norm, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + 
bench_test_fn=bench_memory_rms_norm, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_rope.py b/benchmark/scripts/benchmark_rope.py index 1951c3c23..d6c52a3c9 100644 --- a/benchmark/scripts/benchmark_rope.py +++ b/benchmark/scripts/benchmark_rope.py @@ -1,6 +1,15 @@ +import math +import os +import sys + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from transformers.models.llama.modeling_llama import apply_rotary_pos_emb @@ -17,19 +26,17 @@ device = infer_device() +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - provider = input.kernel_provider - mode = input.kernel_operation_mode - - extra_benchmark_config = input.extra_benchmark_config - num_q_heads = extra_benchmark_config["num_q_heads"] - num_kv_heads = extra_benchmark_config["num_kv_heads"] - dtype = extra_benchmark_config["dtype"] - # x can be either hidden_size or seq_len - hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x - seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x +def _setup_rope(input: SingleBenchmarkRunInput): + """Create input tensors and RoPE embedding from benchmark config.""" + cfg = input.extra_benchmark_config + num_q_heads = cfg["num_q_heads"] + num_kv_heads = cfg["num_kv_heads"] + dtype = cfg["dtype"] + hidden_size = cfg.get("hidden_size", input.x) + seq_len = cfg.get("seq_len", input.x) head_dim = hidden_size // num_q_heads rotary_emb = transformers_version_dispatch( @@ -51,30 +58,28 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = ( - torch.randn_like(q, device=device, dtype=dtype), - torch.randn_like(k, device=device), - ) + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(k, device=device) pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k, pos_ids) - def fwd(): - if provider == "liger": - return liger_rotary_pos_emb(q, k, cos, sin, pos_ids) - elif provider == "huggingface": - return apply_rotary_pos_emb(q, k, cos, sin) - else: - raise ValueError(f"Invalid provider: {provider} for RoPE embedding") + if input.kernel_provider == "liger": + fwd_fn = lambda: liger_rotary_pos_emb(q, k, cos, sin, pos_ids) + elif input.kernel_provider == "huggingface": + fwd_fn = lambda: apply_rotary_pos_emb(q, k, cos, sin, pos_ids) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for RoPE embedding") + + return q, k, dq, dk, fwd_fn + + +def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + q, k, dq, dk, fwd_fn = _setup_rope(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - grad_to_none=[q, k], - rep=400, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = 
triton.testing.do_bench(fwd_fn, grad_to_none=[q, k], rep=400, quantiles=QUANTILES) elif mode == "backward": - q_out, k_out = fwd() + q_out, k_out = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True), grad_to_none=[q, k], @@ -84,140 +89,192 @@ def fwd(): elif mode == "full": def full(): - q_out, k_out = fwd() + q_out, k_out = fwd_fn() torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[q, k], rep=400, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + q, k, dq, dk, fwd_fn = _setup_rope(input) + + def full(): + q_out, k_out = fwd_fn() + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_rope(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_rope( + SingleBenchmarkRunInput( + x=input.x, + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "num_q_heads": model_info["num_q_heads"], + "num_kv_heads": model_info["num_kv_heads"], + "dtype": model_info["dtype"], + "seq_len": cfg["seq_len"], + }, + ) + ) + + +def bench_speed_rope_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + q, k, dq, dk, fwd_fn = _resolve_model_config_rope(input) + mode = input.kernel_operation_mode + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[q, k], rep=400, quantiles=QUANTILES) + elif mode == "backward": + q_out, k_out = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, + lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True), grad_to_none=[q, k], rep=400, quantiles=QUANTILES, ) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + elif mode == "full": + def full(): + q_out, k_out = fwd_fn() + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) -def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - provider = input.kernel_provider + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[q, k], rep=400, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) - extra_benchmark_config = input.extra_benchmark_config - num_q_heads = extra_benchmark_config["num_q_heads"] - num_kv_heads = extra_benchmark_config["num_kv_heads"] - dtype = extra_benchmark_config["dtype"] - # x can be either hidden_size or seq_len - hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x - seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x - - head_dim = hidden_size // num_q_heads - rotary_emb = transformers_version_dispatch( - "4.48.0", - LlamaRotaryEmbedding, - LlamaRotaryEmbedding, - before_kwargs={"dim": head_dim, "device": device}, - after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), 
"device": device}, - ) - q = torch.randn( - (1, seq_len, num_q_heads, head_dim), - device=device, - requires_grad=True, - dtype=dtype, - ).transpose(1, 2) - k = torch.randn( - (1, seq_len, num_kv_heads, head_dim), - device=device, - requires_grad=True, - dtype=dtype, - ).transpose(1, 2) - dq, dk = ( - torch.randn_like(q, device=device, dtype=dtype), - torch.randn_like(k, device=device), - ) - pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) - cos, sin = rotary_emb(k, pos_ids) +def bench_memory_rope_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + q, k, dq, dk, fwd_fn = _resolve_model_config_rope(input) def full(): - if provider == "liger": - q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids) - else: - q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin) + q_out, k_out = fwd_fn() torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True) - mem_50, mem_20, mem_80 = _test_memory( - full, - quantiles=QUANTILES, - ) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs_varying_hidden_size = { - "kernel_name": "rope", - "x_name": "H", - "x_label": "hidden size", - "x_values": [32 * (2**i) for i in range(4, 10, 2)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "dtype": torch.bfloat16, - "seq_len": 2048, - "num_q_heads": 32, - "num_kv_heads": 8, - } - ], - "overwrite": args.overwrite, - } - run_benchmarks( - bench_test_fn=bench_speed_rope, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs_varying_hidden_size, - ) - run_benchmarks( - bench_test_fn=bench_memory_rope, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs_varying_hidden_size, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + seq_len = 2048 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "num_q_heads": model_cfg.num_attention_heads, + "num_kv_heads": model_cfg.num_key_value_heads, + "dtype": model_cfg.dtype, + "seq_len": seq_len, + }, + ) + _, _, _, _, fwd_fn = _setup_rope(probe_input) + return fwd_fn()[0] # return q_out for memory estimation + + return _probe - common_configs_varying_seq_len = { - "kernel_name": "rope", - "x_name": "T", - "x_label": "sequence length", - "x_values": [2**i for i in range(10, 15)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "dtype": torch.bfloat16, - "hidden_size": 8192, - "num_q_heads": 32, - "num_kv_heads": 8, + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "num_q_heads": cfg.num_attention_heads, + "num_kv_heads": cfg.num_key_value_heads, + "dtype": cfg.dtype, } - ], - "overwrite": args.overwrite, - } - run_benchmarks( - bench_test_fn=bench_speed_rope, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs_varying_seq_len, - ) - run_benchmarks( - 
bench_test_fn=bench_memory_rope, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs_varying_seq_len, - ) + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "rope", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [{"model_configs": model_configs_info, "seq_len": seq_len}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_rope_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_rope_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_seq_len = 2048 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "num_q_heads": model.num_attention_heads, + "num_kv_heads": model.num_key_value_heads, + "dtype": model.dtype, + "seq_len": probe_seq_len, + }, + ) + _, _, _, _, fwd_fn = _setup_rope(probe_input) + return fwd_fn()[0] + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_seq_len + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "rope", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, int(math.log2(max(1024, config.seq_len))) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "num_q_heads": model.num_attention_heads, + "num_kv_heads": model.num_key_value_heads, + "dtype": model.dtype, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_rope, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_rope, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_simpo_loss.py b/benchmark/scripts/benchmark_simpo_loss.py index 148b8e3e4..b02c3e3bf 100644 --- a/benchmark/scripts/benchmark_simpo_loss.py +++ b/benchmark/scripts/benchmark_simpo_loss.py @@ -1,9 +1,15 @@ +import math import os import sys import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -18,150 +24,220 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -############################################################################# -# Test the memory consumption of the linear fused cross entropy loss -############################################################################# - - -def bench_memory_fused_linear_simpo_loss( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: +def _setup_simpo_loss(input: 
SingleBenchmarkRunInput): + """Create input tensors and SimPO loss from benchmark config.""" from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO + cfg = input.extra_benchmark_config + H = cfg["hidden_size"] + V = cfg["vocab_size"] + dtype = cfg["dtype"] B = input.x - T = input.extra_benchmark_config["T"] - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - provider = input.kernel_provider - - # Instantiate once and retrieve the first output only - torch_lm_head_simpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) - liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) - torch_fwd = lambda x, target: torch_lm_head_simpo(x, target)[0] - liger_fwd = lambda x, target: liger_lm_head_simpo(x, target)[0] + T = cfg["T"] _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) - def fwd(): - if provider == "liger": - return liger_fwd(_input, target) - elif provider == "huggingface": - return torch_fwd(_input, target) + if input.kernel_provider == "liger": + loss_module = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) + elif input.kernel_provider == "huggingface": + loss_module = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for SimPOLoss") - def full(): - y = fwd() - y.backward() + fwd_fn = lambda: loss_module(_input, target)[0] + return _input, fwd_fn - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) +def bench_speed_simpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _setup_simpo_loss(input) + mode = input.kernel_operation_mode -# ############################################################################# -# # Test the speed of the fused linear cross entropy loss -# ############################################################################# + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd_fn() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), grad_to_none=[_input], rep=100, quantiles=QUANTILES + ) + elif mode == "full": + def full(): + y = fwd_fn() + y.backward() -def bench_speed_fused_linear_simpo_loss( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO - from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) - B = input.x - T = input.extra_benchmark_config["T"] - H = input.extra_benchmark_config["H"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - provider = input.kernel_provider - mode = input.kernel_operation_mode - # Instantiate once and retrieve the first output only - torch_lm_head_simpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) - liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) - torch_fwd = lambda x, target: torch_lm_head_simpo(x, target)[0] - liger_fwd = lambda x, target: liger_lm_head_simpo(x, target)[0] 
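+# NOTE: _setup_simpo_loss builds only the requested provider's LM head, so the
+# unused module's parameters never occupy device memory during measurement
+# (the code removed here constructed both modules on every call).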
+def bench_memory_simpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _setup_simpo_loss(input) - _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) - target = torch.randint(V, (B, T), dtype=torch.long, device=device) + def full(): + y = fwd_fn() + y.backward() - def fwd(): - if provider == "liger": - return liger_fwd(_input, target) - elif provider == "huggingface": - return torch_fwd(_input, target) + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_simpo_loss(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_simpo_loss( + SingleBenchmarkRunInput( + x=cfg["B"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "vocab_size": model_info["vocab_size"], + "dtype": model_info["dtype"], + "T": cfg["T"], + }, + ) + ) + + +def bench_speed_simpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _resolve_model_config_simpo_loss(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - rep=100, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, rep=100, quantiles=QUANTILES) elif mode == "backward": - y = fwd() - + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(retain_graph=True), - grad_to_none=[_input], - rep=100, - quantiles=QUANTILES, + lambda: y.backward(retain_graph=True), grad_to_none=[_input], rep=100, quantiles=QUANTILES ) elif mode == "full": def full(): - y = fwd() + y = fwd_fn() y.backward() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - rep=100, - quantiles=QUANTILES, - ) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_simpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, fwd_fn = _resolve_model_config_simpo_loss(input) + + def full(): + y = fwd_fn() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "fused_linear_simpo_loss", - "x_name": "B", - "x_label": "B", - "x_values": [2**i for i in range(1, 5)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "T": 1024, - "H": 4096, - "V": 128256, - "mode": "forward", - "dtype": torch.bfloat16, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_fused_linear_simpo_loss, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_fused_linear_simpo_loss, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + T = 1024 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + B = 
max(1, probe_bt // T) + probe_input = SingleBenchmarkRunInput( + x=B, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "vocab_size": model_cfg.vocab_size, + "dtype": model_cfg.dtype, + "T": T, + }, + ) + _, fwd_fn = _setup_simpo_loss(probe_input) + return fwd_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: {"hidden_size": cfg.hidden_size, "vocab_size": cfg.vocab_size, "dtype": cfg.dtype} + for cfg in sweep.model_configs + } + B = max(1, sweep.bt // T) + + common_configs = { + "kernel_name": "fused_linear_simpo_loss", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [{"model_configs": model_configs_info, "B": B, "T": T}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_simpo_loss_model_config, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_simpo_loss_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + T = 1024 + probe_bt = 1024 + + def _probe(): + B = max(1, probe_bt // T) + probe_input = SingleBenchmarkRunInput( + x=B, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "vocab_size": model.vocab_size, + "dtype": model.dtype, + "T": T, + }, + ) + _, fwd_fn = _setup_simpo_loss(probe_input) + return fwd_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "fused_linear_simpo_loss", + "x_name": "B", + "x_label": "Batch Size (B)", + "x_values": [2**i for i in range(1, int(math.log2(max(2, config.batch_size * config.seq_len // T))) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + {"hidden_size": model.hidden_size, "vocab_size": model.vocab_size, "dtype": model.dtype, "T": T} + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_simpo_loss, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_simpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_softmax.py b/benchmark/scripts/benchmark_softmax.py index 10e994c8c..31d7f31e1 100644 --- a/benchmark/scripts/benchmark_softmax.py +++ b/benchmark/scripts/benchmark_softmax.py @@ -1,6 +1,15 @@ +import math +import os +import sys + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -13,33 +22,38 @@ device = infer_device() +sys.path.insert(0, 
os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -def bench_speed_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - N = input.x - provider = input.kernel_provider - mode = input.kernel_operation_mode - extra_benchmark_config = input.extra_benchmark_config - M = extra_benchmark_config["M"] - dtype = extra_benchmark_config["dtype"] - x_shape = (M, N) - liger_softmax = LigerSoftmax().to(device).to(dtype) - torch_softmax = torch.nn.Softmax(dim=-1).to(device).to(dtype) +def _setup_softmax(input: SingleBenchmarkRunInput): + """Create input tensors and softmax module from benchmark config.""" + cfg = input.extra_benchmark_config + hidden_size = cfg.get("hidden_size", input.x) + M = cfg.get("M", input.x) + dtype = cfg["dtype"] - x = torch.randn(x_shape, dtype=dtype, device=device) + x = torch.randn(M, hidden_size, dtype=dtype, device=device, requires_grad=True) dy = torch.randn_like(x) - x.requires_grad_(True) - def y_fwd(): - if provider == "liger": - return liger_softmax(x) - if provider == "torch": - return torch_softmax(x) + if input.kernel_provider == "liger": + softmax = LigerSoftmax().to(device).to(dtype) + elif input.kernel_provider == "torch": + softmax = torch.nn.Softmax(dim=-1).to(device).to(dtype) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for softmax") + + fwd_fn = lambda: softmax(x) + return x, dy, fwd_fn + + +def bench_speed_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _setup_softmax(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, quantiles=QUANTILES, grad_to_none=[x], rep=500) elif mode == "backward": - y = y_fwd() + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(dy, retain_graph=True), quantiles=QUANTILES, @@ -49,92 +63,173 @@ def y_fwd(): elif mode == "full": def full(): - y = y_fwd() + y = fwd_fn() y.backward(dy, retain_graph=True) ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500) - - if any(val is None for val in (ms_20, ms_50, ms_80)): - raise RuntimeError(f"Benchmark speed result is None: ms_20={ms_20}, ms_50={ms_50}, ms_80={ms_80}") - - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) def bench_memory_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - shape = input.x - provider = input.kernel_provider - mode = input.kernel_operation_mode - extra_benchmark_config = input.extra_benchmark_config - dtype = extra_benchmark_config.get("dtype", torch.float32) + x, dy, fwd_fn = _setup_softmax(input) - torch_softmax = torch.nn.Softmax(dim=-1) - liger_softmax = LigerSoftmax().to(device).to(dtype) + def full(): + y = fwd_fn() + y.backward(torch.ones_like(y), retain_graph=True) - x = torch.randn(shape, device=device, dtype=dtype, requires_grad=True) + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_softmax(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_softmax( + SingleBenchmarkRunInput( + x=input.x, + kernel_provider=input.kernel_provider, + 
extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "dtype": model_info["dtype"], + "M": cfg["M"], + }, + ) + ) - def fwd(): - if provider == "liger": - return liger_softmax(x) - elif provider == "torch": - return torch_softmax(x) - else: - raise ValueError(f"Invalid provider: {provider} for softmax") - def full(): - y = fwd() - y.backward(torch.ones_like(y), retain_graph=True) +def bench_speed_softmax_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _resolve_model_config_softmax(input) + mode = input.kernel_operation_mode if mode == "forward": - mem_50, mem_20, mem_80 = _test_memory(fwd, quantiles=QUANTILES) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, quantiles=QUANTILES, grad_to_none=[x], rep=500) elif mode == "backward": - do = torch.ones_like(x) - y = fwd() - mem_50, mem_20, mem_80 = _test_memory(lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES) + y = fwd_fn() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[x], + rep=500, + ) + elif mode == "full": + + def full(): + y = fwd_fn() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500) else: - mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) - if any(val is None for val in (mem_20, mem_50, mem_80)): - raise RuntimeError(f"Benchmark memory result is None: mem_20={mem_20}, mem_50={mem_50}, mem_80={mem_80}") - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) +def bench_memory_softmax_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _resolve_model_config_softmax(input) + + def full(): + y = fwd_fn() + y.backward(torch.ones_like(y), retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = dict( - kernel_name="softmax", - x_name="N", - x_label="hidden size", - x_values=[128, 256, 512, 1024, 2048, 4096], - kernel_providers=["liger", "torch"], - extra_benchmark_configs=[ - {"M": 2048, "dtype": torch.float32}, - {"M": 2048, "dtype": torch.bfloat16}, - ], - ) - - run_benchmarks( - bench_test_fn=bench_speed_softmax, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - overwrite=args.overwrite, - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_softmax, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - overwrite=args.overwrite, - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + M = 2048 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "dtype": model_cfg.dtype, + "M": M, + }, + ) + _, _, fwd_fn = _setup_softmax(probe_input) + return fwd_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: {"hidden_size": cfg.hidden_size, "dtype": cfg.dtype} for cfg in sweep.model_configs + } + 
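+        # x_values are model-config names; the *_model_config bench functions
+        # resolve each name through model_configs_info to recover hidden_size
+        # and dtype, while the row count M stays fixed across configurations.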
+ common_configs = { + "kernel_name": "softmax", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"model_configs": model_configs_info, "M": M}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_softmax_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_softmax_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 2048 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "M": probe_bt, + }, + ) + _, _, fwd_fn = _setup_softmax(probe_input) + return fwd_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "softmax", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, int(math.log2(max(1024, config.batch_size * config.seq_len))) + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"hidden_size": model.hidden_size, "dtype": model.dtype}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_softmax, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_softmax, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_sparse_multi_token_attention.py b/benchmark/scripts/benchmark_sparse_multi_token_attention.py index 98f47d713..16a9c4ab5 100644 --- a/benchmark/scripts/benchmark_sparse_multi_token_attention.py +++ b/benchmark/scripts/benchmark_sparse_multi_token_attention.py @@ -1,6 +1,15 @@ +import math +import os +import sys + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -13,6 +22,8 @@ device = infer_device() +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + class TorchSparseMultiTokenAttention(torch.nn.Module): def __init__(self, C_in, C_out, K, groups, bias, dtype, device): @@ -37,9 +48,7 @@ def forward(self, scores): z = s_inf z_sorted, _ = torch.sort(z, dim=dim, descending=True) - cum_sum = torch.cumsum(z_sorted, dim=dim) - k_indices = torch.arange(1, L + 1, device=z.device, dtype=z.dtype).view(1, 1, 1, L) is_positive = z_sorted > -1e8 @@ -47,7 +56,6 @@ def forward(self, scores): k_sparsemax = torch.sum(condition, dim=dim, keepdim=True) k_sparsemax_safe = torch.max(k_sparsemax, torch.ones_like(k_sparsemax)) - cum_sum_k = torch.gather(cum_sum, dim=dim, index=k_sparsemax_safe.long() - 1) tau = (cum_sum_k - 1) / 
k_sparsemax_safe.to(z.dtype) @@ -64,21 +72,17 @@ def forward(self, scores): return out_c.masked_fill(~mask, zero).to(scores.dtype) -def bench_speed_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - L = input.x - provider = input.kernel_provider - mode = input.kernel_operation_mode - - extra_benchmark_config = input.extra_benchmark_config - B = extra_benchmark_config["B"] - C_in = extra_benchmark_config["C_in"] - C_out = extra_benchmark_config["C_out"] - K = extra_benchmark_config["K"] - groups = extra_benchmark_config["groups"] - bias = extra_benchmark_config["bias"] - dtype = extra_benchmark_config["dtype"] - - x_shape = (B, C_in, L, L) +def _setup_sparse_multi_token_attention(input: SingleBenchmarkRunInput): + """Create input tensors and sparse multi-token attention from benchmark config.""" + cfg = input.extra_benchmark_config + C_in = cfg["C_in"] + C_out = cfg["C_out"] + K = cfg["K"] + groups = cfg["groups"] + bias = cfg["bias"] + dtype = cfg["dtype"] + B = cfg.get("B", 2) + L = cfg.get("L", input.x) liger_attn = ( LigerMultiTokenAttention( @@ -97,7 +101,13 @@ def bench_speed_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> ) torch_attn = TorchSparseMultiTokenAttention( - C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device + C_in=C_in, + C_out=C_out, + K=K, + groups=groups, + bias=bias, + dtype=dtype, + device=device, ) with torch.no_grad(): @@ -108,26 +118,31 @@ def bench_speed_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> if bias: torch_attn.bias.copy_(liger_attn.bias) - x = torch.randn(x_shape, dtype=dtype, device=device) + x = torch.randn(B, C_in, L, L, dtype=dtype, device=device, requires_grad=True) dy = torch.randn_like(x) - x.requires_grad_(True) - - def fwd(): - if provider == "liger": - return liger_attn(x) - elif provider == "torch": - return torch_attn(x) - - print(f"Starting Warmup for input size: {x_shape}") - _ = fwd() - if mode in ("backward", "full"): - y = _ - y.backward(dy, retain_graph=True) - print("Done Warmup") + + if input.kernel_provider == "liger": + fwd_fn = lambda: liger_attn(x) + elif input.kernel_provider == "torch": + fwd_fn = lambda: torch_attn(x) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for sparse multi-token attention") + + # Warmup + _ = fwd_fn() + _.backward(dy, retain_graph=True) + + return x, dy, fwd_fn + + +def bench_speed_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _setup_sparse_multi_token_attention(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[x], rep=100, quantiles=QUANTILES) elif mode == "backward": + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(dy, retain_graph=True), grad_to_none=[x], @@ -137,118 +152,201 @@ def fwd(): elif mode == "full": def full(): - y = fwd() + y = fwd_fn() y.backward(dy, retain_graph=True) ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES) - - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) def bench_memory_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - L = 
input.x - provider = input.kernel_provider - - extra_benchmark_config = input.extra_benchmark_config - B = extra_benchmark_config["B"] - C_in = extra_benchmark_config["C_in"] - C_out = extra_benchmark_config["C_out"] - K = extra_benchmark_config["K"] - groups = extra_benchmark_config["groups"] - bias = extra_benchmark_config["bias"] - dtype = extra_benchmark_config["dtype"] + x, dy, fwd_fn = _setup_sparse_multi_token_attention(input) - x_shape = (B, C_in, L, L) + def full(): + y = fwd_fn() + y.backward(dy, retain_graph=True) - liger_attn = ( - LigerMultiTokenAttention( - in_channels=C_in, - out_channels=C_out, - kernel_size=K, - stride=1, - padding=K // 2, - dilation=1, - groups=groups, - bias=bias, - sparse=True, + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_sparse_multi_token_attention(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_sparse_multi_token_attention( + SingleBenchmarkRunInput( + x=input.x, + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "C_in": cfg["C_in"], + "C_out": cfg["C_out"], + "K": cfg["K"], + "groups": cfg["groups"], + "bias": cfg["bias"], + "dtype": model_info["dtype"], + "B": cfg["B"], + "L": cfg["L"], + }, ) - .to(device) - .to(dtype) ) - torch_attn = TorchSparseMultiTokenAttention( - C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device - ) - with torch.no_grad(): - torch.nn.init.kaiming_uniform_(liger_attn.weight, a=5**0.5) - if bias: - torch.nn.init.zeros_(liger_attn.bias) - torch_attn.weight.copy_(liger_attn.weight) - if bias: - torch_attn.bias.copy_(liger_attn.bias) +def bench_speed_sparse_multi_token_attention_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _resolve_model_config_sparse_multi_token_attention(input) + mode = input.kernel_operation_mode + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[x], rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd_fn() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + grad_to_none=[x], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd_fn() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) - x = torch.randn(x_shape, dtype=dtype, device=device) - dy = torch.randn_like(x) - x.requires_grad_(True) - def fwd(): - if provider == "liger": - return liger_attn(x) - elif provider == "torch": - return torch_attn(x) +def bench_memory_sparse_multi_token_attention_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _resolve_model_config_sparse_multi_token_attention(input) def full(): - y = fwd() + y = fwd_fn() y.backward(dy, retain_graph=True) mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) - - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "sparse_multi_token_attention", - "x_name": "L", - "x_label": "sequence length", - 
"x_values": [2**i for i in range(5, 10)], - "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [ - { - "B": 2, - "C_in": 4, - "C_out": 4, - "K": 3, - "groups": 1, - "bias": True, - "dtype": torch.float32, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_sparse_multi_token_attention, - kernel_operation_modes=["forward", "full", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_sparse_multi_token_attention, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + L = 256 + B = 2 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "C_in": 4, + "C_out": 4, + "K": 3, + "groups": 1, + "bias": True, + "dtype": model_cfg.dtype, + "B": B, + "L": L, + }, + ) + _, _, fwd_fn = _setup_sparse_multi_token_attention(probe_input) + return fwd_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = {cfg.name: {"dtype": cfg.dtype} for cfg in sweep.model_configs} + + common_configs = { + "kernel_name": "sparse_multi_token_attention", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "C_in": 4, + "C_out": 4, + "K": 3, + "groups": 1, + "bias": True, + "B": B, + "L": L, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_sparse_multi_token_attention_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_sparse_multi_token_attention_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + B = 2 + probe_L = 256 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "C_in": 4, + "C_out": 4, + "K": 3, + "groups": 1, + "bias": True, + "dtype": model.dtype, + "B": B, + "L": probe_L, + }, + ) + _, _, fwd_fn = _setup_sparse_multi_token_attention(probe_input) + return fwd_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_L + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "sparse_multi_token_attention", + "x_name": "L", + "x_label": "sequence length", + "x_values": [2**i for i in range(5, int(math.log2(max(32, config.seq_len))) + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + {"C_in": 4, "C_out": 4, "K": 3, "groups": 1, "bias": True, "dtype": model.dtype, "B": B} + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_sparse_multi_token_attention, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_sparse_multi_token_attention, + kernel_operation_modes=["full"], + metric_name="memory", + 
metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_sparsemax.py b/benchmark/scripts/benchmark_sparsemax.py index 919f4c66d..fcc6884d9 100644 --- a/benchmark/scripts/benchmark_sparsemax.py +++ b/benchmark/scripts/benchmark_sparsemax.py @@ -1,6 +1,15 @@ +import math +import os +import sys + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -13,6 +22,8 @@ device = infer_device() +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + def torch_sparsemax(input_tensor: torch.Tensor, dim: int = -1) -> torch.Tensor: input_dims = input_tensor.dim() @@ -42,42 +53,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch_sparsemax(x, dim=self.dim) -def bench_speed_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - V = input.x - provider = input.kernel_provider - mode = input.kernel_operation_mode +def _setup_sparsemax(input: SingleBenchmarkRunInput): + """Create input tensors and sparsemax module from benchmark config.""" + cfg = input.extra_benchmark_config + hidden_size = cfg.get("hidden_size", input.x) + M = cfg.get("M", input.x) + dtype = cfg["dtype"] + dim = cfg.get("dim", -1) - extra_benchmark_config = input.extra_benchmark_config - B = extra_benchmark_config["B"] - T = extra_benchmark_config["T"] - dim = extra_benchmark_config["dim"] - dtype = extra_benchmark_config["dtype"] + x = torch.randn(M, hidden_size, dtype=dtype, device=device, requires_grad=True) + dy = torch.randn_like(x) - x_shape = (B * T, V) + if input.kernel_provider == "liger": + sparsemax_module = LigerSparsemax(dim=dim).to(device) + elif input.kernel_provider == "torch": + sparsemax_module = TorchSparsemax(dim=dim).to(device) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for sparsemax") - torch_sparsemax_module = TorchSparsemax(dim=dim).to(device) - liger_sparsemax_module = LigerSparsemax(dim=dim).to(device) + fwd_fn = lambda: sparsemax_module(x) + return x, dy, fwd_fn - x = torch.randn(x_shape, dtype=dtype, device=device) - dy = torch.randn_like(x) - x.requires_grad_(True) - # utility functions - def y_fwd(): - if provider == "liger": - return liger_sparsemax_module(x) - elif provider == "torch": - return torch_sparsemax_module(x) +def bench_speed_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _setup_sparsemax(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - y_fwd, - grad_to_none=[x], - rep=500, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[x], rep=500, quantiles=QUANTILES) elif mode == "backward": - y = y_fwd() + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(dy, retain_graph=True), grad_to_none=[x], @@ -87,86 +92,176 @@ def y_fwd(): elif mode == "full": def full(): - y = y_fwd() + y = fwd_fn() y.backward(dy, retain_graph=True) - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - grad_to_none=[x], - rep=500, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, 
grad_to_none=[x], rep=500, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _setup_sparsemax(input) + + def full(): + y = fwd_fn() + y.backward(dy, retain_graph=True) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_sparsemax(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_sparsemax( + SingleBenchmarkRunInput( + x=input.x, + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "dtype": model_info["dtype"], + "M": cfg["M"], + "dim": cfg.get("dim", -1), + }, + ) ) -def bench_memory_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - V = input.x - provider = input.kernel_provider +def bench_speed_sparsemax_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _resolve_model_config_sparsemax(input) + mode = input.kernel_operation_mode - extra_benchmark_config = input.extra_benchmark_config - B = extra_benchmark_config["B"] - T = extra_benchmark_config["T"] - dim = extra_benchmark_config["dim"] - dtype = extra_benchmark_config["dtype"] + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[x], rep=500, quantiles=QUANTILES) + elif mode == "backward": + y = fwd_fn() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + grad_to_none=[x], + rep=500, + quantiles=QUANTILES, + ) + elif mode == "full": - x_shape = (B * T, V) + def full(): + y = fwd_fn() + y.backward(dy, retain_graph=True) - torch_sparsemax_module = TorchSparsemax(dim=dim).to(device) - liger_sparsemax_module = LigerSparsemax(dim=dim).to(device) + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=500, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) - x = torch.randn(x_shape, dtype=dtype, device=device) - dy = torch.randn_like(x) - x.requires_grad_(True) - # utility functions - def y_fwd(): - if provider == "liger": - return liger_sparsemax_module(x) - elif provider == "torch": - return torch_sparsemax_module(x) +def bench_memory_sparsemax_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, dy, fwd_fn = _resolve_model_config_sparsemax(input) def full(): - y = y_fwd() + y = fwd_fn() y.backward(dy, retain_graph=True) mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) - - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) if __name__ == "__main__": args = parse_benchmark_script_args() - common_configs = { - "kernel_name": "sparsemax", - "x_name": "V", - "x_label": "feature size", - "x_values": [2**i for i in range(10, 16)], - "kernel_providers": ["liger", "torch"], - "extra_benchmark_configs": [{"B": 4, "T": 512, "dim": -1, "dtype": torch.float32}], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_sparsemax, - kernel_operation_modes=["forward", "full", "backward"], - 
metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_sparsemax, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + M = 2048 + + def _probe_factory(model_cfg, probe_bt): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "dtype": model_cfg.dtype, + "M": M, + "dim": -1, + }, + ) + _, _, fwd_fn = _setup_sparsemax(probe_input) + return fwd_fn() + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + model_configs_info = { + cfg.name: {"hidden_size": cfg.hidden_size, "dtype": cfg.dtype} for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "sparsemax", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"model_configs": model_configs_info, "M": M, "dim": -1}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_sparsemax_model_config, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_sparsemax_model_config, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 2048 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=0, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "M": probe_bt, + "dim": -1, + }, + ) + _, _, fwd_fn = _setup_sparsemax(probe_input) + return fwd_fn() + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "sparsemax", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, int(math.log2(max(1024, config.batch_size * config.seq_len))) + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"hidden_size": model.hidden_size, "dtype": model.dtype, "dim": -1}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_sparsemax, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_sparsemax, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_tiled_mlp.py b/benchmark/scripts/benchmark_tiled_mlp.py index 1eaf21dac..7c64f1715 100644 --- a/benchmark/scripts/benchmark_tiled_mlp.py +++ b/benchmark/scripts/benchmark_tiled_mlp.py @@ -1,9 +1,16 @@ import math +import os +import sys import torch import torch.nn as nn import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from 
transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP from utils import QUANTILES @@ -21,18 +28,12 @@ device = infer_device() +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + # DeepSpeed TiledMLP implementation # Based on: https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838 class DeepSpeedTiledMLP(torch.autograd.Function): - """ - DeepSpeed's TiledMLP implementation for fair comparison. - This is the actual DeepSpeed algorithm that performs tiled MLP computation - to massively reduce memory usage with very long sequence lengths. - - This module re-computes forward in the backward, so forward occurs twice per iteration. - """ - @staticmethod def forward(ctx, fn, self, x, shards, compute_params) -> torch.Tensor: ctx.fn = fn @@ -41,12 +42,10 @@ def forward(ctx, fn, self, x, shards, compute_params) -> torch.Tensor: ctx.compute_params = [p for p in compute_params if p.requires_grad] if compute_params else [] ctx.save_for_backward(x) - # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts) x_shards = list(torch.chunk(x, chunks=shards, dim=-2)) with torch.no_grad(): output_shards = [fn(self, x_shard) for x_shard in x_shards] output_unsharded = torch.cat(output_shards, dim=-2) - return output_unsharded @staticmethod @@ -59,14 +58,11 @@ def backward(ctx, *grads): x_requires_grad = x.requires_grad x = x.detach() - # detach() unsets x.requires_grad, so restore it x.requires_grad_(x_requires_grad) - # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts) hidden_size = x.shape[-1] x_shape_orig = x.shape - # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1 x = x.view(-1, hidden_size) incoming_grad = grads[0].view(-1, hidden_size) x_grad = torch.zeros_like(x) @@ -74,22 +70,18 @@ def backward(ctx, *grads): x_shards = list(torch.chunk(x, chunks=shards, dim=0)) for i, x_shard in enumerate(x_shards): - # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run - # XXX: DDP, FSDP will need something similar to make it work if compute_params: if i + 1 < shards: for param in compute_params: if hasattr(param, "ds_grad_is_ready"): param.ds_grad_is_ready = False else: - # last shard, can add the grad for param in compute_params: if hasattr(param, "ds_grad_is_ready"): param.ds_grad_is_ready = True x_shard.requires_grad_(x_requires_grad) - # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step shard_step = x_shards[i].shape[0] shard_offset = i * x_shards[0].shape[0] @@ -99,30 +91,20 @@ def backward(ctx, *grads): output = fn(self, x_shard) torch.autograd.backward(output, incoming_grad_shard) - # unflatten x_grad = x_grad.view(x_shape_orig) - return (None, None, x_grad, None, None) -# DeepSpeed TiledMLP wrapper to match our interface class DeepSpeedTiledMLPWrapper(nn.Module): - """ - Wrapper for DeepSpeed's TiledMLP to match the interface used in benchmarks. - Uses the DeepSpeed TiledMLP algorithm for memory-efficient MLP computation. 
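+    # DeepSpeedTiledMLP re-computes the forward pass during backward, so the
+    # forward runs twice per training iteration.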
- """ - def __init__(self, config, num_shards=None): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.num_shards = num_shards - self.mlp = LlamaMLP(config=config) def forward(self, x): - # Calculate num_shards if not provided num_shards = self.num_shards if num_shards is None: hidden_size = x.shape[-1] @@ -130,38 +112,29 @@ def forward(self, x): num_shards = math.ceil(seqlen / hidden_size) num_shards = max(1, num_shards) - # Collect compute parameters for DeepSpeed ZeRO compatibility compute_params = [ self.mlp.down_proj.weight, self.mlp.gate_proj.weight, self.mlp.up_proj.weight, ] - # Define the MLP forward function for DeepSpeed TiledMLP def mlp_forward(mlp_module, x_input): return mlp_module.down_proj(mlp_module.act_fn(mlp_module.gate_proj(x_input)) * mlp_module.up_proj(x_input)) - # Use DeepSpeed's TiledMLP implementation - return DeepSpeedTiledMLP.apply( - mlp_forward, - self.mlp, - x, - num_shards, - compute_params, - ) + return DeepSpeedTiledMLP.apply(mlp_forward, self.mlp, x, num_shards, compute_params) -def bench_speed_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - seq_len = input.x - bsz = input.extra_benchmark_config["bsz"] - hidden_size = input.extra_benchmark_config["hidden_size"] - intermediate_size = input.extra_benchmark_config["intermediate_size"] - hidden_act = input.extra_benchmark_config["hidden_act"] - dtype = input.extra_benchmark_config["dtype"] - num_shards = input.extra_benchmark_config.get("num_shards", None) - activation_type = input.extra_benchmark_config["activation_type"] - provider = input.kernel_provider - mode = input.kernel_operation_mode +def _setup_tiled_mlp(input: SingleBenchmarkRunInput): + """Create input tensors and tiled MLP from benchmark config.""" + cfg = input.extra_benchmark_config + hidden_size = cfg["hidden_size"] + intermediate_size = cfg["intermediate_size"] + hidden_act = cfg["hidden_act"] + dtype = cfg["dtype"] + activation_type = cfg["activation_type"] + num_shards = cfg.get("num_shards", None) + bsz = cfg.get("bsz", 2) + seq_len = cfg.get("seq_len", input.x) llama_config = LlamaConfig( hidden_size=hidden_size, @@ -169,229 +142,293 @@ def bench_speed_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO hidden_act=hidden_act, ) - x_shape = (bsz, seq_len, hidden_size) - - # initialize input - x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) + x = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) if activation_type == "geglu": - if provider == "huggingface": + if input.kernel_provider == "huggingface": layer = LlamaMLP(config=llama_config).to(device).to(dtype) - elif provider == "liger": + elif input.kernel_provider == "liger": layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype) - elif provider == "liger_tiled": + elif input.kernel_provider == "liger_tiled": layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype) - elif provider == "deepspeed_tiled": + elif input.kernel_provider == "deepspeed_tiled": layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype) else: - raise ValueError(f"Invalid provider: {provider} for GEGLU") + raise ValueError(f"Invalid provider: {input.kernel_provider} for GEGLU") elif activation_type == "swiglu": - if provider == "huggingface": + if input.kernel_provider == "huggingface": layer = LlamaMLP(config=llama_config).to(device).to(dtype) - elif 
provider == "liger": + elif input.kernel_provider == "liger": layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype) - elif provider == "liger_tiled": + elif input.kernel_provider == "liger_tiled": layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype) - elif provider == "deepspeed_tiled": + elif input.kernel_provider == "deepspeed_tiled": layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype) else: - raise ValueError(f"Invalid provider: {provider} for SwiGLU") + raise ValueError(f"Invalid provider: {input.kernel_provider} for SwiGLU") else: raise ValueError(f"Invalid activation_type: {activation_type}") - def fwd(): - return layer(x) + fwd_fn = lambda: layer(x) + return x, fwd_fn + + +def bench_speed_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, fwd_fn = _setup_tiled_mlp(input) + mode = input.kernel_operation_mode if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - grad_to_none=[x], - rep=10, - quantiles=QUANTILES, - ) + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[x], rep=10, quantiles=QUANTILES) elif mode == "backward": do = torch.randn_like(x) - y = fwd() + y = fwd_fn() ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(do, retain_graph=True), grad_to_none=[x], rep=10, quantiles=QUANTILES, ) + elif mode == "full": + + def full(): + y = fwd_fn() + y.backward(torch.randn_like(y), retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=10, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, fwd_fn = _setup_tiled_mlp(input) + mode = input.kernel_operation_mode + + if mode == "forward": + mem_50, mem_20, mem_80 = _test_memory(fwd_fn, quantiles=QUANTILES) + elif mode == "backward": + do = torch.randn_like(x) + y = fwd_fn() + mem_50, mem_20, mem_80 = _test_memory(lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES) else: def full(): - y = fwd() + y = fwd_fn() y.backward(torch.randn_like(y), retain_graph=True) - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - grad_to_none=[x], - rep=10, - quantiles=QUANTILES, + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_tiled_mlp(input: SingleBenchmarkRunInput): + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_tiled_mlp( + SingleBenchmarkRunInput( + x=input.x, + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "intermediate_size": model_info["intermediate_size"], + "hidden_act": model_info["hidden_act"], + "dtype": model_info["dtype"], + "activation_type": cfg["activation_type"], + "num_shards": cfg.get("num_shards", None), + "bsz": cfg["bsz"], + "seq_len": cfg["seq_len"], + }, ) - - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, ) -def bench_memory_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - seq_len = input.x - bsz = input.extra_benchmark_config["bsz"] - hidden_size = input.extra_benchmark_config["hidden_size"] - intermediate_size = input.extra_benchmark_config["intermediate_size"] - hidden_act = 
-def bench_memory_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    seq_len = input.x
-    bsz = input.extra_benchmark_config["bsz"]
-    hidden_size = input.extra_benchmark_config["hidden_size"]
-    intermediate_size = input.extra_benchmark_config["intermediate_size"]
-    hidden_act = input.extra_benchmark_config["hidden_act"]
-    dtype = input.extra_benchmark_config["dtype"]
-    num_shards = input.extra_benchmark_config.get("num_shards", None)
-    activation_type = input.extra_benchmark_config["activation_type"]
-    provider = input.kernel_provider
+def bench_speed_tiled_mlp_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, fwd_fn = _resolve_model_config_tiled_mlp(input)
     mode = input.kernel_operation_mode
 
-    llama_config = LlamaConfig(
-        hidden_size=hidden_size,
-        intermediate_size=intermediate_size,
-        hidden_act=hidden_act,
-    )
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[x], rep=10, quantiles=QUANTILES)
+    elif mode == "backward":
+        do = torch.randn_like(x)
+        y = fwd_fn()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(do, retain_graph=True),
+            grad_to_none=[x],
+            rep=10,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
 
-    x_shape = (bsz, seq_len, hidden_size)
-    # initialize input
-    x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)
+        def full():
+            y = fwd_fn()
+            y.backward(torch.randn_like(y), retain_graph=True)
 
-    if activation_type == "geglu":
-        if provider == "huggingface":
-            layer = LlamaMLP(config=llama_config).to(device).to(dtype)
-        elif provider == "liger":
-            layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
-        elif provider == "liger_tiled":
-            layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
-        elif provider == "deepspeed_tiled":
-            layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype)
-        else:
-            raise ValueError(f"Invalid provider: {provider} for GEGLU")
-    elif activation_type == "swiglu":
-        if provider == "huggingface":
-            layer = LlamaMLP(config=llama_config).to(device).to(dtype)
-        elif provider == "liger":
-            layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
-        elif provider == "liger_tiled":
-            layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
-        elif provider == "deepspeed_tiled":
-            layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype)
-        else:
-            raise ValueError(f"Invalid provider: {provider} for SwiGLU")
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=10, quantiles=QUANTILES)
     else:
-        raise ValueError(f"Invalid activation_type: {activation_type}")
+        raise ValueError(f"Unsupported mode: {mode}")
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
 
-    def fwd():
-        return layer(x)
 
-    def full():
-        y = fwd()
-        y.backward(torch.randn_like(y), retain_graph=True)
+def bench_memory_tiled_mlp_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, fwd_fn = _resolve_model_config_tiled_mlp(input)
+    mode = input.kernel_operation_mode
 
     if mode == "forward":
-        mem_50, mem_20, mem_80 = _test_memory(
-            fwd,
-            quantiles=QUANTILES,
-        )
+        mem_50, mem_20, mem_80 = _test_memory(fwd_fn, quantiles=QUANTILES)
     elif mode == "backward":
         do = torch.randn_like(x)
-        y = fwd()
-        mem_50, mem_20, mem_80 = _test_memory(
-            lambda: y.backward(do, retain_graph=True),
-            quantiles=QUANTILES,
+        y = fwd_fn()
+        mem_50, mem_20, mem_80 = _test_memory(lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES)
+    else:
+
+        def full():
+            y = fwd_fn()
+            y.backward(torch.randn_like(y), retain_graph=True)
+
+        mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
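`_test_memory` itself is not shown in this diff; on CUDA it presumably amounts to resetting and reading the allocator's peak statistics around the callable, along these lines (a sketch, not the suite's actual helper):

# Minimal sketch of peak-memory measurement; the real _test_memory in the
# benchmark utilities may differ. CUDA-only for brevity.
import torch


def measure_peak_mb(fn, warmup=1, reps=3):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    peaks = []
    for _ in range(reps):
        torch.cuda.reset_peak_memory_stats()
        fn()
        torch.cuda.synchronize()
        peaks.append(torch.cuda.max_memory_allocated() / 2**20)
    return sorted(peaks)[len(peaks) // 2]  # median peak, in MB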
+
+
+def _run_tiled_mlp_benchmarks(args, activation_type, hidden_act, kernel_name):
+    """Run D1 or D2 benchmarks for a given activation type."""
+    kernel_providers = ["huggingface", "liger", "liger_tiled", "deepspeed_tiled"]
+
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+        bsz = 2
+        seq_len = 2048
+
+        def _probe_factory(model_cfg, probe_bt):
+            def _probe():
+                probe_input = SingleBenchmarkRunInput(
+                    x=0,
+                    kernel_provider="huggingface",
+                    extra_benchmark_config={
+                        "hidden_size": model_cfg.hidden_size,
+                        "intermediate_size": model_cfg.intermediate_size,
+                        "hidden_act": hidden_act,
+                        "dtype": model_cfg.dtype,
+                        "activation_type": activation_type,
+                        "num_shards": 4,
+                        "bsz": bsz,
+                        "seq_len": seq_len,
+                    },
+                )
+                _, fwd_fn = _setup_tiled_mlp(probe_input)
+                return fwd_fn()
+
+            return _probe
+
+        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+        model_configs_info = {
+            cfg.name: {
+                "hidden_size": cfg.hidden_size,
+                "intermediate_size": cfg.intermediate_size,
+                "hidden_act": cfg.hidden_act,
+                "dtype": cfg.dtype,
+            }
+            for cfg in sweep.model_configs
+        }
+
+        common_configs = {
+            "kernel_name": kernel_name,
+            "x_name": "model_config",
+            "x_label": "model configuration",
+            "x_values": [cfg.name for cfg in sweep.model_configs],
+            "kernel_providers": kernel_providers,
+            "extra_benchmark_configs": [
+                {
+                    "model_configs": model_configs_info,
+                    "activation_type": activation_type,
+                    "num_shards": 4,
+                    "bsz": bsz,
+                    "seq_len": seq_len,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_tiled_mlp_model_config,
+            kernel_operation_modes=["forward", "backward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_tiled_mlp_model_config,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
         )
     else:
-        mem_50, mem_20, mem_80 = _test_memory(
-            full,
-            quantiles=QUANTILES,
+        model = get_benchmark_model_config(args.model)
+        bsz = 2
+        probe_seq_len = 2048
+
+        def _probe():
+            probe_input = SingleBenchmarkRunInput(
+                x=0,
+                kernel_provider="huggingface",
+                extra_benchmark_config={
+                    "hidden_size": model.hidden_size,
+                    "intermediate_size": model.intermediate_size,
+                    "hidden_act": hidden_act,
+                    "dtype": model.dtype,
+                    "activation_type": activation_type,
+                    "num_shards": 4,
+                    "bsz": bsz,
+                    "seq_len": probe_seq_len,
+                },
+            )
+            _, fwd_fn = _setup_tiled_mlp(probe_input)
+            return fwd_fn()
+
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_seq_len
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": kernel_name,
+            "x_name": "T",
+            "x_label": "sequence length",
+            "x_values": [2**i for i in range(10, int(math.log2(max(1024, config.seq_len))) + 1)],
+            "kernel_providers": kernel_providers,
+            "extra_benchmark_configs": [
+                {
+                    "hidden_size": model.hidden_size,
+                    "intermediate_size": model.intermediate_size,
+                    "hidden_act": hidden_act,
+                    "dtype": model.dtype,
+                    "activation_type": activation_type,
+                    "num_shards": 4,
+                    "bsz": bsz,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_tiled_mlp,
+            kernel_operation_modes=["forward", "backward", "full"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_tiled_mlp,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
         )
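The probe pattern above (run the kernel once at a known size, read the peak, size the sweep from it) deserves a gloss; note that the script divides the probe's peak by probe_seq_len alone, so with bsz = 2 the batch factor is folded into the per-token figure. estimate_kernel_peak_memory and compute_seq_len_sweep_config come from benchmark_model_configs and are not in this diff, but the arithmetic they must perform is roughly:

# Rough sketch of the sizing logic (the real helpers live in
# benchmark_model_configs.py and may differ): extrapolate peak bytes per
# token linearly from one probe run, then cap the sweep at the largest
# power-of-two sequence length expected to fit in a fraction of the device.
import math

import torch


def max_pow2_seq_len(probe_fn, probe_tokens, budget_frac=0.8):
    torch.cuda.reset_peak_memory_stats()
    probe_fn()
    torch.cuda.synchronize()
    bytes_per_token = torch.cuda.max_memory_allocated() / probe_tokens
    budget = budget_frac * torch.cuda.get_device_properties(0).total_memory
    return 2 ** int(math.log2(max(1024, budget / bytes_per_token)))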
kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, ) - - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) if __name__ == "__main__": args = parse_benchmark_script_args() # Benchmark GEGLU variants - kernel_providers_geglu = ["huggingface", "liger", "liger_tiled", "deepspeed_tiled"] - - common_configs_geglu = { - "kernel_name": "tiled_geglu", - "x_name": "T", - "x_label": "sequence length", - "x_values": [2**i for i in range(10, 15)], # 1024 to 16384 - "kernel_providers": kernel_providers_geglu, - "extra_benchmark_configs": [ - { - "bsz": 2, - "hidden_size": 2048, - "intermediate_size": 4096, - "hidden_act": "gelu_pytorch_tanh", - "activation_type": "geglu", - "num_shards": 4, - "dtype": torch.bfloat16, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_tiled_mlp, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs_geglu, - ) - run_benchmarks( - bench_test_fn=bench_memory_tiled_mlp, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs_geglu, - ) + _run_tiled_mlp_benchmarks(args, activation_type="geglu", hidden_act="gelu_pytorch_tanh", kernel_name="tiled_geglu") # Benchmark SwiGLU variants - kernel_providers_swiglu = ["huggingface", "liger", "liger_tiled", "deepspeed_tiled"] - - common_configs_swiglu = { - "kernel_name": "tiled_swiglu", - "x_name": "T", - "x_label": "sequence length", - "x_values": [2**i for i in range(10, 15)], # 1024 to 16384 - "kernel_providers": kernel_providers_swiglu, - "extra_benchmark_configs": [ - { - "bsz": 2, - "hidden_size": 2048, - "intermediate_size": 4096, - "hidden_act": "silu", - "activation_type": "swiglu", - "num_shards": 4, - "dtype": torch.bfloat16, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_tiled_mlp, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs_swiglu, - ) - run_benchmarks( - bench_test_fn=bench_memory_tiled_mlp, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs_swiglu, - ) + _run_tiled_mlp_benchmarks(args, activation_type="swiglu", hidden_act="silu", kernel_name="tiled_swiglu") diff --git a/benchmark/scripts/benchmark_tvd.py b/benchmark/scripts/benchmark_tvd.py index ef76380a2..7bdba1540 100644 --- a/benchmark/scripts/benchmark_tvd.py +++ b/benchmark/scripts/benchmark_tvd.py @@ -1,6 +1,13 @@ +import math + import torch import triton +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -9,7 +16,6 @@ from utils import run_benchmarks from liger_kernel.transformers.tvd import LigerTVDLoss -from liger_kernel.utils import get_total_gpu_memory from liger_kernel.utils import infer_device device = infer_device() @@ -34,112 +40,225 @@ def forward(self, p, q): raise ValueError("Invalid reduction type.") -S, E = 12, 18 +def _setup_tvd(input: SingleBenchmarkRunInput): + """Create input tensors and TVD loss from benchmark 
config.""" + cfg = input.extra_benchmark_config + V = cfg["vocab_size"] + BT = input.x + reduction = "batchmean" + _input = torch.randn(BT, V, requires_grad=True, device=device).softmax(dim=-1) + target = torch.randn(BT, V, device=device).softmax(dim=-1) + + if input.kernel_provider == "liger": + loss_fn = LigerTVDLoss(reduction=reduction) + elif input.kernel_provider == "torch": + loss_fn = TorchTVDLoss(reduction=reduction) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for TVD") + return _input, target, loss_fn -def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - reduction = "batchmean" - V = input.x - B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] - torch_tvd = TorchTVDLoss(reduction=reduction) - liger_tvd = LigerTVDLoss(reduction=reduction) - _input = torch.randn(B * T, V, requires_grad=True, device=device).softmax(dim=-1) - target = torch.randn(B * T, V, device=device).softmax(dim=-1) +def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, loss_fn = _setup_tvd(input) + mode = input.kernel_operation_mode def fwd(): - if input.kernel_provider == "liger": - return liger_tvd(_input, target) - else: - return torch_tvd(_input, target) + return loss_fn(_input, target) - if input.kernel_operation_mode == "forward": + if mode == "forward": ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) - elif input.kernel_operation_mode == "backward": + elif mode == "backward": y = fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(retain_graph=True), quantiles=QUANTILES, grad_to_none=[_input], rep=100, ) - elif input.kernel_operation_mode == "full": + elif mode == "full": def full(): y = fwd() y.backward(retain_graph=True) ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) + else: + raise ValueError(f"Unsupported mode: {mode}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - reduction = "batchmean" - torch_tvd = TorchTVDLoss(reduction=reduction) - liger_tvd = LigerTVDLoss(reduction=reduction) + _input, target, loss_fn = _setup_tvd(input) - V = input.x - B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + def full(): + y = loss_fn(_input, target) + y.backward(retain_graph=True) - _input = torch.randn(B * T, V, requires_grad=True, device=device).softmax(dim=-1) - target = torch.randn(B * T, V, device=device).softmax(dim=-1) + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def _resolve_model_config_tvd(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_tvd( + SingleBenchmarkRunInput( + x=cfg["BT"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "vocab_size": model_info["vocab_size"], + }, + ) + ) + + +def bench_speed_tvd_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + _input, target, loss_fn = _resolve_model_config_tvd(input) + mode = input.kernel_operation_mode def fwd(): - if input.kernel_provider == "liger": - return liger_tvd(_input, target) - else: - return torch_tvd(_input, target) + return 
+def bench_speed_tvd_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, target, loss_fn = _resolve_model_config_tvd(input)
+    mode = input.kernel_operation_mode
 
     def fwd():
-        if input.kernel_provider == "liger":
-            return liger_tvd(_input, target)
-        else:
-            return torch_tvd(_input, target)
+        return loss_fn(_input, target)
 
-    def full():
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
+    elif mode == "backward":
         y = fwd()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(retain_graph=True),
+            quantiles=QUANTILES,
+            grad_to_none=[_input],
+            rep=100,
+        )
+    elif mode == "full":
+
+        def full():
+            y = fwd()
+            y.backward(retain_graph=True)
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
+
+
+def bench_memory_tvd_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    _input, target, loss_fn = _resolve_model_config_tvd(input)
+
+    def full():
+        y = loss_fn(_input, target)
         y.backward(retain_graph=True)
 
     mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
-
-    return SingleBenchmarkRunOutput(
-        y_20=mem_20,
-        y_50=mem_50,
-        y_80=mem_80,
-    )
+    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
 
 
 if __name__ == "__main__":
     args = parse_benchmark_script_args()
-    gpu_memory_gbs = get_total_gpu_memory()
-    # We know that the full test will require 66GBs for vocab size 2^17
-    if gpu_memory_gbs >= 66:
-        x_max = 17
-    elif gpu_memory_gbs >= 32:
-        x_max = 16
+
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+
+        def _probe_factory(model_cfg, probe_bt):
+            def _probe():
+                probe_input = SingleBenchmarkRunInput(
+                    x=probe_bt,
+                    kernel_provider="torch",
+                    extra_benchmark_config={
+                        "vocab_size": model_cfg.vocab_size,
+                    },
+                )
+                _input, target, loss_fn = _setup_tvd(probe_input)
+                return loss_fn(_input, target)
+
+            return _probe
+
+        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+
+        model_configs_info = {
+            cfg.name: {
+                "vocab_size": cfg.vocab_size,
+            }
+            for cfg in sweep.model_configs
+        }
+
+        common_configs = {
+            "kernel_name": "tvd",
+            "x_name": "model_config",
+            "x_label": "model configuration",
+            "x_values": [cfg.name for cfg in sweep.model_configs],
+            "kernel_providers": ["liger", "torch"],
+            "extra_benchmark_configs": [
+                {
+                    "model_configs": model_configs_info,
+                    "BT": sweep.bt,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_tvd_model_config,
+            kernel_operation_modes=["forward", "full", "backward"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_tvd_model_config,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
     else:
-        x_max = 15
-    common_args = {
-        "kernel_name": "tvd",
-        "x_name": "V",
-        "x_label": "vocab size",
-        "x_values": [2**i for i in range(12, x_max + 1)],
-        "kernel_providers": ["liger", "torch"],
-        "extra_benchmark_configs": [{"B": 8, "T": 2048}],
-        "overwrite": args.overwrite,
-    }
-
-    run_benchmarks(
-        bench_test_fn=bench_memory_tvd,
-        kernel_operation_modes=["full"],
-        metric_name="memory",
-        metric_unit="MB",
-        **common_args,
-    )
+        model = get_benchmark_model_config(args.model)
+        probe_bt = 1024
 
-    run_benchmarks(
-        bench_test_fn=bench_speed_tvd,
-        kernel_operation_modes=["forward", "full", "backward"],
-        metric_name="speed",
-        metric_unit="ms",
-        **common_args,
-    )
+        def _probe():
+            probe_input = SingleBenchmarkRunInput(
+                x=probe_bt,
+                kernel_provider="torch",
+                extra_benchmark_config={
+                    "vocab_size": model.vocab_size,
+                },
+            )
+            _input, target, loss_fn = _setup_tvd(probe_input)
+            return loss_fn(_input, target)
+
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_bt
+
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": "tvd",
+            "x_name": "BT",
+            "x_label": "B * T",
+            "x_values": [2**i for i in range(10, int(math.log2(max(1024, config.batch_size * config.seq_len))) + 1)],
+            "kernel_providers": ["liger", "torch"],
+            "extra_benchmark_configs": [
+                {
+                    "vocab_size": model.vocab_size,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_tvd,
+            kernel_operation_modes=["forward", "full", "backward"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_tvd,
+            kernel_operation_modes=["full"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
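One detail shared by the sequence-length sweeps in this diff: x_values is the list of powers of two from 2^10 up to the largest size the memory estimate allows, and clamping the log2 argument to at least 1024 keeps the range non-empty on small budgets. In isolation:

# Illustration of the x_values arithmetic used in the sweep configs above.
import math


def pow2_sweep(max_value: int) -> list[int]:
    return [2**i for i in range(10, int(math.log2(max(1024, max_value))) + 1)]


assert pow2_sweep(500) == [1024]  # clamp keeps at least one sweep point
assert pow2_sweep(16384) == [1024, 2048, 4096, 8192, 16384]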