@@ -252,16 +252,12 @@ def get_attn_backend_cls(
252252 "FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
253253 "APHRODITE_MLA_DISABLE=1 to disable MLA for this model."
254254 )
255- if not use_v1 :
256- raise RuntimeError (
257- "MLA attention backends require the V1 engine. Set APHRODITE_USE_V1=1 to enable them."
258- )
259255
260256 from aphrodite .attention .ops .flashmla import is_flashmla_dense_supported
261257 from aphrodite .attention .utils .fa_utils import flash_attn_supports_mla
262258
263259 if use_sparse :
264- logger .info_once ("Using Sparse MLA backend on V1 engine ." , scope = "global" )
260+ logger .info_once ("Using Sparse MLA backend." , scope = "global" )
265261 return "aphrodite.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
266262
267263 use_cutlassmla = selected_backend == _Backend .CUTLASS_MLA or (
@@ -281,13 +277,13 @@ def get_attn_backend_cls(
281277 use_triton = selected_backend == _Backend .TRITON_MLA or (selected_backend is None )
282278
283279 if use_cutlassmla :
284- logger .info_once ("Using Cutlass MLA backend on V1 engine ." , scope = "local" )
280+ logger .info_once ("Using Cutlass MLA backend." , scope = "local" )
285281 return "aphrodite.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
286282 if use_flashinfermla :
287283 from aphrodite .v1 .attention .backends .utils import set_kv_cache_layout
288284
289285 set_kv_cache_layout ("HND" )
290- logger .info_once ("Using FlashInfer MLA backend on V1 engine ." , scope = "global" )
286+ logger .info_once ("Using FlashInfer MLA backend." , scope = "global" )
291287 return "aphrodite.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
292288 if use_flashmla :
293289 if block_size % 64 != 0 :
@@ -296,106 +292,96 @@ def get_attn_backend_cls(
296292 block_size ,
297293 )
298294 else :
299- logger .info_once ("Using FlashMLA backend on V1 engine ." , scope = "global" )
295+ logger .info_once ("Using FlashMLA backend." , scope = "global" )
300296 return "aphrodite.v1.attention.backends.mla.flashmla.FlashMLABackend"
301297 if use_flashattn :
302- logger .info_once ("Using FlashAttention MLA backend on V1 engine ." , scope = "global" )
298+ logger .info_once ("Using FlashAttention MLA backend." , scope = "global" )
303299 return "aphrodite.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
304300 if use_triton :
305- logger .info_once ("Using Triton MLA backend on V1 engine ." , scope = "global" )
301+ logger .info_once ("Using Triton MLA backend." , scope = "global" )
306302 return "aphrodite.v1.attention.backends.mla.triton_mla.TritonMLABackend"
307- if use_v1 :
308- FLASHINFER_V1 = "aphrodite.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
309- FLEX_ATTENTION_V1 = "aphrodite.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
310- TRITON_ATTN = "aphrodite.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
311- FLASH_ATTN_V1 = "aphrodite.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
312- TREE_ATTN_V1 = "aphrodite.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
313- XFORMERS_V1 = "aphrodite.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
314303
315- use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype .startswith ("fp8" )
304+ FLASHINFER_V1 = "aphrodite.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
305+ FLEX_ATTENTION_V1 = "aphrodite.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
306+ TRITON_ATTN = "aphrodite.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
307+ FLASH_ATTN_V1 = "aphrodite.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
308+ TREE_ATTN_V1 = "aphrodite.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
309+ XFORMERS_V1 = "aphrodite.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
316310
317- if selected_backend == _Backend .FLASHINFER :
318- logger .info_once ("Using FlashInfer backend on V1 engine." , scope = "global" )
319- if cls .has_device_capability (100 ):
320- from aphrodite .v1 .attention .backends .utils import set_kv_cache_layout
311+ use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype .startswith ("fp8" )
321312
322- set_kv_cache_layout ("HND" )
323- return FLASHINFER_V1
324- elif selected_backend == _Backend .FLEX_ATTENTION :
325- logger .info_once ("Using FlexAttention backend on V1 engine." , scope = "global" )
326- return FLEX_ATTENTION_V1
327- elif selected_backend == _Backend .TRITON_ATTN :
328- logger .info_once ("Using Triton backend on V1 engine." , scope = "global" )
329- return TRITON_ATTN
330- elif selected_backend == _Backend .FLASH_ATTN :
331- logger .info_once ("Using Flash Attention backend on V1 engine." , scope = "global" )
332- return FLASH_ATTN_V1
333- elif selected_backend == _Backend .TREE_ATTN :
334- logger .info_once ("Using Tree Attention backend on V1 engine." , scope = "global" )
335- return TREE_ATTN_V1
336- elif selected_backend == _Backend .XFORMERS :
337- logger .info_once ("Using XFormers backend on V1 engine." , scope = "global" )
338- return XFORMERS_V1
339-
340- from aphrodite .attention .selector import is_attn_backend_supported
341-
342- # Default backends for V1 engine
343- # Prefer FlashInfer for Blackwell GPUs if installed
344- if cls .is_device_capability (100 ):
345- if is_default_backend_supported := is_attn_backend_supported (FLASHINFER_V1 , head_size , dtype ):
346- from aphrodite .v1 .attention .backends .utils import set_kv_cache_layout
313+ if selected_backend == _Backend .FLASHINFER :
314+ logger .info_once ("Using FlashInfer backend." )
315+ if cls .has_device_capability (100 ):
316+ from aphrodite .v1 .attention .backends .utils import set_kv_cache_layout
347317
348- logger .info_once (
349- "Using FlashInfer backend with HND KV cache layout on "
350- "V1 engine by default for Blackwell (SM 10.0) GPUs." ,
351- scope = "global" ,
352- )
353- set_kv_cache_layout ("HND" )
318+ set_kv_cache_layout ("HND" )
319+ return FLASHINFER_V1
320+ elif selected_backend == _Backend .FLEX_ATTENTION :
321+ logger .info_once ("Using FlexAttention backend." )
322+ return FLEX_ATTENTION_V1
323+ elif selected_backend == _Backend .TRITON_ATTN :
324+ logger .info_once ("Using Triton backend." )
325+ return TRITON_ATTN
326+ elif selected_backend == _Backend .FLASH_ATTN :
327+ logger .info_once ("Using Flash Attention backend." )
328+ return FLASH_ATTN_V1
329+ elif selected_backend == _Backend .TREE_ATTN :
330+ logger .info_once ("Using Tree Attention backend." )
331+ return TREE_ATTN_V1
332+ elif selected_backend == _Backend .XFORMERS :
333+ logger .info_once ("Using XFormers backend." )
334+ return XFORMERS_V1
335+
336+ from aphrodite .attention .selector import is_attn_backend_supported
337+
338+ # Default backends for V1 engine
339+ # Prefer FlashInfer for Blackwell GPUs if installed
340+ if cls .is_device_capability (100 ):
341+ if is_default_backend_supported := is_attn_backend_supported (FLASHINFER_V1 , head_size , dtype ):
342+ from aphrodite .v1 .attention .backends .utils import set_kv_cache_layout
354343
355- return FLASHINFER_V1
344+ logger .info_once (
345+ "Using FlashInfer backend with HND KV cache layout on "
346+ "V1 engine by default for Blackwell (SM 10.0) GPUs." ,
347+ scope = "global" ,
348+ )
349+ set_kv_cache_layout ("HND" )
356350
357- if not is_default_backend_supported .can_import :
358- logger .warning_once (
359- "FlashInfer failed to import for V1 engine on "
360- "Blackwell (SM 10.0) GPUs; it is recommended to "
361- "install FlashInfer for better performance." ,
362- scope = "global" ,
363- )
351+ return FLASHINFER_V1
364352
365- # FlashAttention is the default for SM 8.0+ GPUs
366- if cls .has_device_capability (80 ):
367- if (has_sink or use_fp8_kv_cache ) and not cls .is_device_capability (90 ):
368- logger .info_once ("Using Triton backend on V1 engine." , scope = "global" )
369- return TRITON_ATTN
370- elif is_default_backend_supported := is_attn_backend_supported (
371- FLASH_ATTN_V1 , head_size , dtype , allow_import_error = False
372- ):
373- logger .info_once ("Using Flash Attention backend on V1 engine." , scope = "global" )
374- return FLASH_ATTN_V1
375-
376- # FlexAttention is the default for older GPUs
377- else :
378- logger .info_once ("Using FlexAttention backend on V1 engine." , scope = "global" )
379- return FLEX_ATTENTION_V1
353+ if not is_default_backend_supported .can_import :
354+ logger .warning_once (
355+ "FlashInfer failed to import for V1 engine on "
356+ "Blackwell (SM 10.0) GPUs; it is recommended to "
357+ "install FlashInfer for better performance." ,
358+ scope = "global" ,
359+ )
380360
381- assert not is_default_backend_supported
361+ # FlashAttention is the default for SM 8.0+ GPUs
362+ if cls .has_device_capability (80 ):
363+ if (has_sink or use_fp8_kv_cache ) and not cls .is_device_capability (90 ):
364+ logger .info_once ("Using Triton backend." , scope = "global" )
365+ return TRITON_ATTN
366+ elif is_default_backend_supported := is_attn_backend_supported (
367+ FLASH_ATTN_V1 , head_size , dtype , allow_import_error = False
368+ ):
369+ logger .info_once ("Using Flash Attention backend." , scope = "global" )
370+ return FLASH_ATTN_V1
382371
383- use_flex_attention_reason = {}
384- if not is_default_backend_supported .head_size :
385- use_flex_attention_reason ["head_size" ] = head_size
386- if not is_default_backend_supported .dtype :
387- use_flex_attention_reason ["dtype" ] = dtype
372+ assert not is_default_backend_supported
388373
389- logger .info_once (
390- "Using FlexAttention backend for %s on V1 engine." ,
391- ", " .join (f"{ k } ={ v } " for k , v in use_flex_attention_reason .items ()),
392- scope = "global" ,
393- )
394- return FLEX_ATTENTION_V1
374+ use_flex_attention_reason = {}
375+ if not is_default_backend_supported .head_size :
376+ use_flex_attention_reason ["head_size" ] = head_size
377+ if not is_default_backend_supported .dtype :
378+ use_flex_attention_reason ["dtype" ] = dtype
395379
396- raise RuntimeError (
397- "V0 attention backends have been removed. Set APHRODITE_USE_V1=1 to select a supported backend."
380+ logger .info_once (
381+ "Using FlexAttention backend for %s." ,
382+ ", " .join (f"{ k } ={ v } " for k , v in use_flex_attention_reason .items ()),
398383 )
384+ return FLEX_ATTENTION_V1
399385
400386 @classmethod
401387 def get_punica_wrapper (cls ) -> str :
0 commit comments