@@ -55,6 +55,7 @@ class EngineConfig:
5555 top_p: Default top-p sampling parameter.
5656 top_k: Default top-k sampling parameter.
5757 enable_graph: Whether to enable graph compiling.
58+ attn_backend: Attention backend to use ('default', 'flash-attn').
5859 """
5960
6061 model_path : str
@@ -71,6 +72,7 @@ class EngineConfig:
7172 top_p : float = 0.8
7273 top_k : int = 1
7374 enable_graph : bool = False
75+ attn_backend : str = "default"
7476
7577
7678class LLMEngine :
@@ -88,6 +90,7 @@ def __init__(self, config: EngineConfig):
8890 device = self .device ,
8991 distributed_config = DistConfig (config .tensor_parallel_size ),
9092 enable_graph_compiling = config .enable_graph ,
93+ attention_backend = config .attn_backend ,
9194 )
9295
9396 # Load model weights
@@ -383,6 +386,7 @@ def __init__(
383386 top_p : float = 0.8 ,
384387 top_k : int = 1 ,
385388 enable_graph : bool = False ,
389+ attn_backend : str = "default" ,
386390 ):
387391 """Initialize LLM.
388392
@@ -401,6 +405,7 @@ def __init__(
401405 top_p: Default top-p sampling parameter.
402406 top_k: Default top-k sampling parameter.
403407 enable_graph: Whether to enable graph compiling.
408+ attn_backend: Attention backend to use ('default', 'flash-attn').
404409 """
405410 config = EngineConfig (
406411 model_path = model_path ,
@@ -417,6 +422,7 @@ def __init__(
417422 top_p = top_p ,
418423 top_k = top_k ,
419424 enable_graph = enable_graph ,
425+ attn_backend = attn_backend ,
420426 )
421427 self .engine = LLMEngine (config )
422428 self .config = config
@@ -536,6 +542,7 @@ def __init__(
536542 top_p : float = 0.8 ,
537543 top_k : int = 1 ,
538544 enable_graph : bool = False ,
545+ attn_backend : str = "default" ,
539546 ):
540547 """Initialize AsyncLLMEngine.
541548
@@ -554,6 +561,7 @@ def __init__(
554561 top_p: Default top-p sampling parameter.
555562 top_k: Default top-k sampling parameter.
556563 enable_graph: Whether to enable graph compiling.
564+ attn_backend: Attention backend to use ('default', 'flash-attn').
557565 """
558566 config = EngineConfig (
559567 model_path = model_path ,
@@ -570,6 +578,7 @@ def __init__(
570578 top_p = top_p ,
571579 top_k = top_k ,
572580 enable_graph = enable_graph ,
581+ attn_backend = attn_backend ,
573582 )
574583 self .engine = LLMEngine (config )
575584 self .config = config
0 commit comments