Skip to content

Commit 7618323

Browse files
authored
Refactor event_filter_regex handling in Ironwood attention benchmark (#77)
1 parent ecd4c82 commit 7618323

1 file changed

Lines changed: 4 additions & 11 deletions

File tree

Ironwood/src/benchmark_attention.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,7 @@ def tokamax_splash_attention_benchmark(
114114
trace_dir: str = None,
115115
) -> Dict[str, Any]:
116116
"""Benchmarks the Tokamax Splash attention kernel."""
117-
118-
if tune_pallas_only:
119-
event_filter_regex = _pallas_call_hlo_pattern(mode, q_heads != kv_heads)
120-
else:
121-
event_filter_regex = None
117+
event_filter_regex = _pallas_call_hlo_pattern(mode, q_heads != kv_heads)
122118

123119
hyperparams_override = {}
124120
if mode == "bwd":
@@ -146,9 +142,7 @@ def tokamax_splash_attention_benchmark(
146142
mask = mask_lib.FullMask(_shape=(q_seq_len, kv_seq_len))
147143
if causal:
148144
# Pick offset for causal masks for a "representative" slice of the causal
149-
offset = (
150-
0 if q.shape[-2] == v.shape[-2] else (v.shape[-2] // 2 - q.shape[-2] // 2)
151-
)
145+
offset = v.shape[-2] - q.shape[-2]
152146
mask = mask_lib.CausalMask(shape=(q_seq_len, kv_seq_len), offset=offset)
153147

154148
def attention_fn(
@@ -250,7 +244,7 @@ def attention_fn(
250244
tuned_splash = tune_jax.tune(
251245
splash_fn,
252246
hyperparams=hyperparams,
253-
event_filter_regex=event_filter_regex,
247+
event_filter_regex=event_filter_regex if tune_pallas_only else None,
254248
sample_num=num_samples,
255249
)
256250

@@ -268,8 +262,7 @@ def attention_fn(
268262
task="tokamax_splash_attentionatt",
269263
trace_dir=trace_dir,
270264
event_name_str_list=[
271-
"splash_mqa_fwd_no_residuals.1",
272-
"splash_mqa_dkv_no_residuals.1",
265+
f"{event_filter_regex}_no_residuals.1",
273266
]
274267
)
275268
return {"time_ms_list": time_ms_list, "output": output}

0 commit comments

Comments
 (0)