Skip to content

Commit 2ca3bd0

Browse files
lowdy1 and noemotiovon authored
[Benchmark] Add ModelConfig Sweep Support and Pre-Probe to Remaining Benchmarks (#1195)
## Summary

* Add `model_config` sweep support to training and Rope op benchmarks, enabling evaluation across different model architectures at a fixed sequence length
* Introduce a pre-sweep probe step to determine safe configurations and prevent OOM during benchmarking

Affected files:
- benchmark_attn_res.py
- benchmark_embedding.py
- benchmark_fused_neighborhood_attention.py (retains all kernel configs for `model_config` sweep instead of sweeping model architectures)
- benchmark_mhc.py
- benchmark_mhc_lm.py
- benchmark_multi_token_attention.py
- benchmark_relu_squared.py
- benchmark_softmax.py
- benchmark_sparsemax.py
- benchmark_sparse_multi_token_attention.py
- benchmark_tiled_mlp.py

## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: noemotiovon <757486878@qq.com>
1 parent ef85f24 commit 2ca3bd0

11 files changed

Lines changed: 2267 additions & 1040 deletions

benchmark/scripts/benchmark_attn_res.py

Lines changed: 149 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
1414

15+
from benchmark_model_configs import MODEL_REGISTRY
16+
from benchmark_model_configs import compute_model_config_sweep_config
1517
from benchmark_model_configs import compute_seq_len_sweep_config
1618
from benchmark_model_configs import estimate_kernel_peak_memory
1719
from benchmark_model_configs import get_benchmark_model_config
@@ -69,61 +71,154 @@ def bench_memory_attn_res(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
6971
return run_memory_benchmark(fn, input.kernel_operation_mode)
7072

7173

72-
if __name__ == "__main__":
73-
args = parse_benchmark_script_args()
74-
75-
model = get_benchmark_model_config(args.model)
76-
probe_seq_len = 1024
77-
78-
def _probe():
79-
probe_input = SingleBenchmarkRunInput(
80-
x=probe_seq_len,
81-
kernel_provider="pytorch",
74+
def _resolve_model_config_attn_res(input: SingleBenchmarkRunInput):
75+
"""Resolve model-config-sweep input into standard setup args."""
76+
cfg = input.extra_benchmark_config
77+
model_info = cfg["model_configs"][input.x]
78+
return _setup_attn_res(
79+
SingleBenchmarkRunInput(
80+
x=cfg["seq_len"],
81+
kernel_provider=input.kernel_provider,
8282
extra_benchmark_config={
83-
"N": 8,
84-
"bsz": 1,
85-
"hidden_size": model.hidden_size,
86-
"dtype": model.dtype,
87-
"eps": 1e-6,
83+
"N": cfg["N"],
84+
"bsz": cfg["bsz"],
85+
"hidden_size": model_info["hidden_size"],
86+
"dtype": model_info["dtype"],
87+
"eps": cfg.get("eps", 1e-6),
8888
},
8989
)
90-
V, fn = _setup_attn_res(probe_input)
91-
return fn()
92-
93-
peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
94-
kernel_bpt = peak_bytes // probe_seq_len
95-
96-
config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
97-
98-
common_configs = {
99-
"kernel_name": "attn_res",
100-
"x_name": "T",
101-
"x_label": "sequence length",
102-
"x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
103-
"kernel_providers": ["liger", "pytorch"],
104-
"extra_benchmark_configs": [
105-
{
106-
"N": 8,
107-
"bsz": config.batch_size,
108-
"hidden_size": model.hidden_size,
109-
"dtype": model.dtype,
110-
"eps": 1e-6,
111-
}
112-
],
113-
"overwrite": args.overwrite,
114-
}
115-
116-
run_benchmarks(
117-
bench_test_fn=bench_speed_attn_res,
118-
kernel_operation_modes=["full", "forward", "backward"],
119-
metric_name="speed",
120-
metric_unit="ms",
121-
**common_configs,
122-
)
123-
run_benchmarks(
124-
bench_test_fn=bench_memory_attn_res,
125-
kernel_operation_modes=["full", "forward", "backward"],
126-
metric_name="memory",
127-
metric_unit="MB",
128-
**common_configs,
12990
)
91+
92+
93+
def bench_speed_attn_res_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
94+
V, fn = _resolve_model_config_attn_res(input)
95+
return run_speed_benchmark(fn, input.kernel_operation_mode, [V])
96+
97+
98+
def bench_memory_attn_res_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
99+
V, fn = _resolve_model_config_attn_res(input)
100+
return run_memory_benchmark(fn, input.kernel_operation_mode)
101+
102+
103+
if __name__ == "__main__":
104+
args = parse_benchmark_script_args()
105+
106+
if args.sweep_mode == "model_config":
107+
all_model_configs = list(MODEL_REGISTRY.values())
108+
109+
def _probe_factory(model_cfg, probe_seq_len):
110+
def _probe():
111+
probe_input = SingleBenchmarkRunInput(
112+
x=probe_seq_len,
113+
kernel_provider="pytorch",
114+
extra_benchmark_config={
115+
"N": 8,
116+
"bsz": 1,
117+
"hidden_size": model_cfg.hidden_size,
118+
"dtype": model_cfg.dtype,
119+
"eps": 1e-6,
120+
},
121+
)
122+
V, fn = _setup_attn_res(probe_input)
123+
return fn()
124+
125+
return _probe
126+
127+
sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
128+
129+
model_configs_info = {
130+
cfg.name: {
131+
"hidden_size": cfg.hidden_size,
132+
"dtype": cfg.dtype,
133+
}
134+
for cfg in sweep.model_configs
135+
}
136+
137+
common_configs = {
138+
"kernel_name": "attn_res",
139+
"x_name": "model_config",
140+
"x_label": "model configuration",
141+
"x_values": [cfg.name for cfg in sweep.model_configs],
142+
"kernel_providers": ["liger", "pytorch"],
143+
"extra_benchmark_configs": [
144+
{
145+
"model_configs": model_configs_info,
146+
"N": 8,
147+
"bsz": sweep.batch_size,
148+
"seq_len": sweep.seq_len,
149+
"eps": 1e-6,
150+
}
151+
],
152+
"overwrite": args.overwrite,
153+
}
154+
155+
run_benchmarks(
156+
bench_test_fn=bench_speed_attn_res_model_config,
157+
kernel_operation_modes=["full", "forward", "backward"],
158+
metric_name="speed",
159+
metric_unit="ms",
160+
**common_configs,
161+
)
162+
run_benchmarks(
163+
bench_test_fn=bench_memory_attn_res_model_config,
164+
kernel_operation_modes=["full", "forward", "backward"],
165+
metric_name="memory",
166+
metric_unit="MB",
167+
**common_configs,
168+
)
169+
else:
170+
model = get_benchmark_model_config(args.model)
171+
probe_seq_len = 1024
172+
173+
def _probe():
174+
probe_input = SingleBenchmarkRunInput(
175+
x=probe_seq_len,
176+
kernel_provider="pytorch",
177+
extra_benchmark_config={
178+
"N": 8,
179+
"bsz": 1,
180+
"hidden_size": model.hidden_size,
181+
"dtype": model.dtype,
182+
"eps": 1e-6,
183+
},
184+
)
185+
V, fn = _setup_attn_res(probe_input)
186+
return fn()
187+
188+
peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
189+
kernel_bpt = peak_bytes // probe_seq_len
190+
191+
config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
192+
193+
common_configs = {
194+
"kernel_name": "attn_res",
195+
"x_name": "T",
196+
"x_label": "sequence length",
197+
"x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
198+
"kernel_providers": ["liger", "pytorch"],
199+
"extra_benchmark_configs": [
200+
{
201+
"N": 8,
202+
"bsz": config.batch_size,
203+
"hidden_size": model.hidden_size,
204+
"dtype": model.dtype,
205+
"eps": 1e-6,
206+
}
207+
],
208+
"overwrite": args.overwrite,
209+
}
210+
211+
run_benchmarks(
212+
bench_test_fn=bench_speed_attn_res,
213+
kernel_operation_modes=["full", "forward", "backward"],
214+
metric_name="speed",
215+
metric_unit="ms",
216+
**common_configs,
217+
)
218+
run_benchmarks(
219+
bench_test_fn=bench_memory_attn_res,
220+
kernel_operation_modes=["full", "forward", "backward"],
221+
metric_name="memory",
222+
metric_unit="MB",
223+
**common_configs,
224+
)

0 commit comments

Comments (0)