diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py index e4e4224..5a091ae 100644 --- a/fms_mo/aiu_addons/fp8/fp8_attn.py +++ b/fms_mo/aiu_addons/fp8/fp8_attn.py @@ -29,6 +29,7 @@ # Third Party from fms.modules.attention import ( AttentionKwargs, + _sdpa_compute_op, _sdpa_update_attn_kwargs, register_attention_op, ) @@ -340,7 +341,7 @@ def __spyre_scaled_paged_validate_attn_kwargs_op( register_attention_op( "spyre_paged_attn_fp8", _spyre_scaled_paged_store_op, - compute_op=_math_fp8_compute_op, + compute_op=_sdpa_compute_op, is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None) is None, compute_decode_op=_spyre_scaled_paged_compute_op,