@@ -114,11 +114,7 @@ def tokamax_splash_attention_benchmark(
     trace_dir: str = None,
 ) -> Dict[str, Any]:
   """Benchmarks the Tokamax Splash attention kernel."""
-
-  if tune_pallas_only:
-    event_filter_regex = _pallas_call_hlo_pattern(mode, q_heads != kv_heads)
-  else:
-    event_filter_regex = None
+  event_filter_regex = _pallas_call_hlo_pattern(mode, q_heads != kv_heads)

   hyperparams_override = {}
   if mode == "bwd":
@@ -146,9 +142,7 @@ def tokamax_splash_attention_benchmark(
   mask = mask_lib.FullMask(_shape=(q_seq_len, kv_seq_len))
   if causal:
     # Pick offset for causal masks for a "representative" slice of the causal mask.
-    offset = (
-        0 if q.shape[-2] == v.shape[-2] else (v.shape[-2] // 2 - q.shape[-2] // 2)
-    )
+    offset = v.shape[-2] - q.shape[-2]
     mask = mask_lib.CausalMask(shape=(q_seq_len, kv_seq_len), offset=offset)

   def attention_fn(
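
Note on the offset change above: assuming mask_lib.CausalMask marks position
(i, j) as visible when j <= i + offset, the new offset v.shape[-2] - q.shape[-2]
pins the query block to the tail of the kv sequence, so the last query row sees
every kv position, as in decode-style attention. A minimal numpy sketch, with
causal_mask as a hypothetical stand-in for the real mask class:

    import numpy as np

    def causal_mask(q_len: int, kv_len: int, offset: int) -> np.ndarray:
      # mask[i, j] is True where query i may attend to kv position j.
      q_idx = np.arange(q_len)[:, None]
      kv_idx = np.arange(kv_len)[None, :]
      return kv_idx <= q_idx + offset

    # q_len=4, kv_len=8: the new offset is 8 - 4 = 4.
    m = causal_mask(4, 8, offset=8 - 4)
    assert m[-1].all()                            # last row sees all 8 kv slots
    assert m[0, :5].all() and not m[0, 5:].any()  # first row sees slots 0..4

    # The replaced expression centered the query block instead:
    # 0 if q_len == kv_len else kv_len // 2 - q_len // 2, i.e. offset 2 here.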
@@ -250,7 +244,7 @@ def attention_fn(
   tuned_splash = tune_jax.tune(
       splash_fn,
       hyperparams=hyperparams,
-      event_filter_regex=event_filter_regex,
+      event_filter_regex=event_filter_regex if tune_pallas_only else None,
       sample_num=num_samples,
   )

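Note: the tune_pallas_only gate removed in the first hunk reappears here at the
call site. The regex is now always computed because the last hunk derives the
profiler event names from it, while tuning still only filters timing events
when tune_pallas_only is set. A tiny sketch of the equivalence, using a
hypothetical helper (gated_kwargs is not part of this repo):

    def gated_kwargs(tune_pallas_only: bool, pattern: str) -> dict:
      # Old code computed the pattern conditionally; new code computes it once
      # and gates at the point of use. Both yield the same kwarg.
      return {"event_filter_regex": pattern if tune_pallas_only else None}

    assert gated_kwargs(True, "splash_mqa_fwd") == {"event_filter_regex": "splash_mqa_fwd"}
    assert gated_kwargs(False, "splash_mqa_fwd") == {"event_filter_regex": None}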
@@ -268,8 +262,7 @@ def attention_fn(
       task="tokamax_splash_attention",
       trace_dir=trace_dir,
       event_name_str_list=[
-          "splash_mqa_fwd_no_residuals.1",
-          "splash_mqa_dkv_no_residuals.1",
+          f"{event_filter_regex}_no_residuals.1",
       ]
   )
   return {"time_ms_list": time_ms_list, "output": output}
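
Note: the two hard-coded event names are replaced by a single name derived from
event_filter_regex, which is why the first hunk computes it unconditionally.
_pallas_call_hlo_pattern lives elsewhere in this repo; the replaced literals
"splash_mqa_fwd_no_residuals.1" and "splash_mqa_dkv_no_residuals.1" suggest a
return value of roughly this shape (a speculative sketch, not the actual
implementation):

    def _pallas_call_hlo_pattern_sketch(mode: str, grouped_heads: bool) -> str:
      # Speculative reconstruction for illustration only: the real helper is
      # called as _pallas_call_hlo_pattern(mode, q_heads != kv_heads). The
      # prefix presumably selects the MQA vs. MHA Splash kernel and the suffix
      # the forward kernel vs. the dkv backward kernel.
      variant = "mqa" if grouped_heads else "mha"
      suffix = "fwd" if mode == "fwd" else "dkv"
      return f"splash_{variant}_{suffix}"

    # mode="fwd" with grouped heads yields "splash_mqa_fwd", so the f-string in
    # the hunk above reproduces the old literal "splash_mqa_fwd_no_residuals.1".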