benchmark/benchmarks_visualizer.py (84 additions, 37 deletions)
@@ -319,50 +319,97 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):

plt.figure(figsize=(10, 6))
sns.set(style="whitegrid")
try:
ax = sns.lineplot(
data=df,
x="x_value",
y="y_value_50",
hue="kernel_provider",
marker="o",
palette="tab10",
errorbar=("ci", None),
)
except Exception:
ax = sns.lineplot(

use_bar_chart = config.sweep_mode == "model_config"

if use_bar_chart:
# Grouped bar chart for model_config sweep
ax = sns.barplot(
data=df,
x="x_value",
y="y_value_50",
hue="kernel_provider",
marker="o",
palette="tab10",
errorbar=None,
edgecolor="black",
linewidth=0.5,
)
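# errorbar=None turns off seaborn's own (bootstrap CI) error bars here; the
# percentile-based bars are drawn manually below from y_value_20/50/80.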

# For numeric x axes, show tick labels only at actual data points
if is_numeric_x:
tick_values = sorted(df["x_value"].unique())
ax.set_xticks(tick_values)
ax.set_xticklabels([str(int(v)) if v == int(v) else str(v) for v in tick_values])

# Seaborn can't plot pre-computed error bars, so we need to do it manually
lines = ax.get_lines()
colors = [line.get_color() for line in lines]

for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
y_error = [y_error_lower, y_error_upper]

plt.errorbar(
group_data["x_value"],
group_data["y_value_50"],
yerr=y_error,
fmt="o",
color=color,
capsize=5,
)
# Add error bars on each bar using pre-computed percentiles
providers = df.sort_values("kernel_provider")["kernel_provider"].unique()
x_values = df["x_value"].unique()
n_providers = len(providers)
bar_width = 0.8 / n_providers # seaborn default total width is 0.8

for i, provider in enumerate(providers):
group_data = df[df["kernel_provider"] == provider]
for j, x_val in enumerate(x_values):
row = group_data[group_data["x_value"] == x_val]
if row.empty:
continue
y_val = row["y_value_50"].values[0]
y_err_lower = y_val - row["y_value_20"].values[0]
y_err_upper = row["y_value_80"].values[0] - y_val
bar_x = j + (i - (n_providers - 1) / 2) * bar_width
ax.errorbar(
bar_x,
y_val,
yerr=[[y_err_lower], [y_err_upper]],
fmt="none",
color="black",
capsize=3,
linewidth=1,
)
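# Worked example of the offset arithmetic above: with 2 providers,
# bar_width = 0.8 / 2 = 0.4 and the offsets are (i - 0.5) * 0.4, i.e. -0.2
# and +0.2, so each error bar lands on its dodged bar, centered around the
# integer category position j.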

# Rotate x labels if they are long model config names
if not is_numeric_x:
plt.xticks(rotation=30, ha="right")
else:
# Line chart for token_length sweep
try:
ax = sns.lineplot(
data=df,
x="x_value",
y="y_value_50",
hue="kernel_provider",
marker="o",
palette="tab10",
errorbar=("ci", None),
)
except Exception:
ax = sns.lineplot(
data=df,
x="x_value",
y="y_value_50",
hue="kernel_provider",
marker="o",
palette="tab10",
errorbar=None,
)

# For numeric x axes, show tick labels only at actual data points
if is_numeric_x:
tick_values = sorted(df["x_value"].unique())
ax.set_xticks(tick_values)
ax.set_xticklabels([str(int(v)) if v == int(v) else str(v) for v in tick_values])

# Seaborn can't plot pre-computed error bars, so we need to do it manually
lines = ax.get_lines()
colors = [line.get_color() for line in lines]

for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
y_error = [y_error_lower, y_error_upper]

plt.errorbar(
group_data["x_value"],
group_data["y_value_50"],
yerr=y_error,
fmt="o",
color=color,
capsize=5,
)

plt.legend(title="Kernel Provider")
plt.xlabel(xlabel)
plt.ylabel(ylabel)
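For reference, a minimal self-contained sketch (synthetic numbers; only the column names are taken from this file) of how the pre-computed p20/p50/p80 columns translate into matplotlib's asymmetric yerr format used in both branches above:

import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame(
    {
        "x_value": [1024, 2048, 4096],
        "y_value_20": [0.9, 1.8, 3.6],
        "y_value_50": [1.0, 2.0, 4.0],
        "y_value_80": [1.2, 2.5, 5.0],
    }
)

# yerr takes distances from the plotted value, not absolute percentiles:
# lower = p50 - p20, upper = p80 - p50.
yerr = [df["y_value_50"] - df["y_value_20"], df["y_value_80"] - df["y_value_50"]]
plt.errorbar(df["x_value"], df["y_value_50"], yerr=yerr, fmt="o", capsize=5)
plt.show()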
benchmark/scripts/benchmark_attn_res.py (149 additions, 54 deletions)
@@ -12,6 +12,8 @@

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from benchmark_model_configs import MODEL_REGISTRY
from benchmark_model_configs import compute_model_config_sweep_config
from benchmark_model_configs import compute_seq_len_sweep_config
from benchmark_model_configs import estimate_kernel_peak_memory
from benchmark_model_configs import get_benchmark_model_config
@@ -69,61 +71,154 @@ def bench_memory_attn_res(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
return run_memory_benchmark(fn, input.kernel_operation_mode)


if __name__ == "__main__":
args = parse_benchmark_script_args()

model = get_benchmark_model_config(args.model)
probe_seq_len = 1024

def _probe():
probe_input = SingleBenchmarkRunInput(
x=probe_seq_len,
kernel_provider="pytorch",
def _resolve_model_config_attn_res(input: SingleBenchmarkRunInput):
"""Resolve model-config-sweep input into standard setup args."""
cfg = input.extra_benchmark_config
model_info = cfg["model_configs"][input.x]
return _setup_attn_res(
SingleBenchmarkRunInput(
x=cfg["seq_len"],
kernel_provider=input.kernel_provider,
extra_benchmark_config={
"N": 8,
"bsz": 1,
"hidden_size": model.hidden_size,
"dtype": model.dtype,
"eps": 1e-6,
"N": cfg["N"],
"bsz": cfg["bsz"],
"hidden_size": model_info["hidden_size"],
"dtype": model_info["dtype"],
"eps": cfg.get("eps", 1e-6),
},
)
V, fn = _setup_attn_res(probe_input)
return fn()

peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
kernel_bpt = peak_bytes // probe_seq_len

config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)

common_configs = {
"kernel_name": "attn_res",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
"kernel_providers": ["liger", "pytorch"],
"extra_benchmark_configs": [
{
"N": 8,
"bsz": config.batch_size,
"hidden_size": model.hidden_size,
"dtype": model.dtype,
"eps": 1e-6,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_attn_res,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_attn_res,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)


def bench_speed_attn_res_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
V, fn = _resolve_model_config_attn_res(input)
return run_speed_benchmark(fn, input.kernel_operation_mode, [V])


def bench_memory_attn_res_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
V, fn = _resolve_model_config_attn_res(input)
return run_memory_benchmark(fn, input.kernel_operation_mode)


if __name__ == "__main__":
args = parse_benchmark_script_args()

if args.sweep_mode == "model_config":
all_model_configs = list(MODEL_REGISTRY.values())
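# Build one zero-argument probe closure per candidate model config; the sweep
# helper presumably runs these to gauge each config's memory footprint at
# probe_seq_len before settling on a shared batch size and sequence length.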

def _probe_factory(model_cfg, probe_seq_len):
def _probe():
probe_input = SingleBenchmarkRunInput(
x=probe_seq_len,
kernel_provider="pytorch",
extra_benchmark_config={
"N": 8,
"bsz": 1,
"hidden_size": model_cfg.hidden_size,
"dtype": model_cfg.dtype,
"eps": 1e-6,
},
)
V, fn = _setup_attn_res(probe_input)
return fn()

return _probe

sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)

model_configs_info = {
cfg.name: {
"hidden_size": cfg.hidden_size,
"dtype": cfg.dtype,
}
for cfg in sweep.model_configs
}

common_configs = {
"kernel_name": "attn_res",
"x_name": "model_config",
"x_label": "model configuration",
"x_values": [cfg.name for cfg in sweep.model_configs],
"kernel_providers": ["liger", "pytorch"],
"extra_benchmark_configs": [
{
"model_configs": model_configs_info,
"N": 8,
"bsz": sweep.batch_size,
"seq_len": sweep.seq_len,
"eps": 1e-6,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_attn_res_model_config,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_attn_res_model_config,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
else:
model = get_benchmark_model_config(args.model)
probe_seq_len = 1024

def _probe():
probe_input = SingleBenchmarkRunInput(
x=probe_seq_len,
kernel_provider="pytorch",
extra_benchmark_config={
"N": 8,
"bsz": 1,
"hidden_size": model.hidden_size,
"dtype": model.dtype,
"eps": 1e-6,
},
)
V, fn = _setup_attn_res(probe_input)
return fn()

peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
kernel_bpt = peak_bytes // probe_seq_len
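# kernel_bpt is a rough bytes-per-token figure taken from the single
# 1024-token probe run; compute_seq_len_sweep_config presumably uses it to
# pick a sequence-length ceiling that fits in memory.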

config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)

common_configs = {
"kernel_name": "attn_res",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
"kernel_providers": ["liger", "pytorch"],
"extra_benchmark_configs": [
{
"N": 8,
"bsz": config.batch_size,
"hidden_size": model.hidden_size,
"dtype": model.dtype,
"eps": 1e-6,
}
],
"overwrite": args.overwrite,
}
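# For example, if config.seq_len comes back as 8192 (hypothetical value), the
# x_values above expand to [2**10, 2**11, 2**12, 2**13] == [1024, 2048, 4096, 8192].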

run_benchmarks(
bench_test_fn=bench_speed_attn_res,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_attn_res,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
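A hedged walk-through of the model_config sweep plumbing above, using an illustrative entry rather than a real one from MODEL_REGISTRY: the harness passes each model name as input.x, and _resolve_model_config_attn_res rebuilds a standard per-sequence-length input by looking that name up in the shared extra_benchmark_config:

example_cfg = {
    "model_configs": {"example_model": {"hidden_size": 4096, "dtype": "bfloat16"}},  # placeholder; the real registry stores a torch dtype
    "N": 8,
    "bsz": 4,
    "seq_len": 2048,
    "eps": 1e-6,
}
model_info = example_cfg["model_configs"]["example_model"]
resolved = {
    "N": example_cfg["N"],
    "bsz": example_cfg["bsz"],
    "hidden_size": model_info["hidden_size"],
    "dtype": model_info["dtype"],
    "eps": example_cfg.get("eps", 1e-6),
}
print(resolved)  # roughly what _setup_attn_res receives as extra_benchmark_config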