Skip to content

Commit 3b8e1cb

Browse files
authored
Merge pull request #260 from InfiniTensor/issue/259
issue/259 - add attn backend option to inference server
2 parents dfec9d8 + 91cd299 commit 3b8e1cb

2 files changed

Lines changed: 22 additions & 1 deletion

File tree

python/infinilm/llm/llm.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ class EngineConfig:
5555
top_p: Default top-p sampling parameter.
5656
top_k: Default top-k sampling parameter.
5757
enable_graph: Whether to enable graph compiling.
58+
attn_backend: Attention backend to use ('default', 'flash-attn').
5859
"""
5960

6061
model_path: str
@@ -71,6 +72,7 @@ class EngineConfig:
7172
top_p: float = 0.8
7273
top_k: int = 1
7374
enable_graph: bool = False
75+
attn_backend: str = "default"
7476

7577

7678
class LLMEngine:
@@ -88,6 +90,7 @@ def __init__(self, config: EngineConfig):
8890
device=self.device,
8991
distributed_config=DistConfig(config.tensor_parallel_size),
9092
enable_graph_compiling=config.enable_graph,
93+
attention_backend=config.attn_backend,
9194
)
9295

9396
# Load model weights
@@ -383,6 +386,7 @@ def __init__(
383386
top_p: float = 0.8,
384387
top_k: int = 1,
385388
enable_graph: bool = False,
389+
attn_backend: str = "default",
386390
):
387391
"""Initialize LLM.
388392
@@ -401,6 +405,7 @@ def __init__(
401405
top_p: Default top-p sampling parameter.
402406
top_k: Default top-k sampling parameter.
403407
enable_graph: Whether to enable graph compiling.
408+
attn_backend: Attention backend to use ('default', 'flash-attn').
404409
"""
405410
config = EngineConfig(
406411
model_path=model_path,
@@ -417,6 +422,7 @@ def __init__(
417422
top_p=top_p,
418423
top_k=top_k,
419424
enable_graph=enable_graph,
425+
attn_backend=attn_backend,
420426
)
421427
self.engine = LLMEngine(config)
422428
self.config = config
@@ -536,6 +542,7 @@ def __init__(
536542
top_p: float = 0.8,
537543
top_k: int = 1,
538544
enable_graph: bool = False,
545+
attn_backend: str = "default",
539546
):
540547
"""Initialize AsyncLLMEngine.
541548
@@ -554,6 +561,7 @@ def __init__(
554561
top_p: Default top-p sampling parameter.
555562
top_k: Default top-k sampling parameter.
556563
enable_graph: Whether to enable graph compiling.
564+
attn_backend: Attention backend to use ('default', 'flash-attn').
557565
"""
558566
config = EngineConfig(
559567
model_path=model_path,
@@ -570,6 +578,7 @@ def __init__(
570578
top_p=top_p,
571579
top_k=top_k,
572580
enable_graph=enable_graph,
581+
attn_backend=attn_backend,
573582
)
574583
self.engine = LLMEngine(config)
575584
self.config = config

python/infinilm/server/inference_server.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(
108108
host: str = "0.0.0.0",
109109
port: int = 8000,
110110
enable_graph: bool = False,
111+
attn_backend: str = "default",
111112
):
112113
"""Initialize inference server.
113114
@@ -128,6 +129,7 @@ def __init__(
128129
host: Server host address.
129130
port: Server port number.
130131
enable_graph: Whether to enable graph compiling.
132+
attn_backend: Attention backend to use ('default', 'flash-attn').
131133
"""
132134
self.model_path = model_path
133135
# vLLM-like served model id: directory name of model_path
@@ -147,6 +149,7 @@ def __init__(
147149
self.host = host
148150
self.port = port
149151
self.enable_graph = enable_graph
152+
self.attn_backend = attn_backend
150153

151154
self.engine: AsyncLLMEngine = None
152155

@@ -177,6 +180,7 @@ async def lifespan(app: FastAPI):
177180
top_p=self.top_p,
178181
top_k=self.top_k,
179182
enable_graph=self.enable_graph,
183+
attn_backend=self.attn_backend,
180184
)
181185
self.engine.start()
182186
logger.info(f"Engine initialized with model at {self.model_path}")
@@ -613,6 +617,13 @@ def parse_args():
613617
action="store_true",
614618
help="Enable graph compiling",
615619
)
620+
parser.add_argument(
621+
"--attn",
622+
type=str,
623+
default="default",
624+
choices=["default", "flash-attn"],
625+
help="Attention backend to use: 'default' or 'flash-attn'",
626+
)
616627
parser.add_argument(
617628
"--log_level",
618629
type=str,
@@ -655,7 +666,7 @@ def main():
655666
"Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
656667
"--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
657668
"\n"
658-
"Optional: --enable-paged-attn --enable-graph"
669+
"Optional: --enable-paged-attn --enable-graph --attn=default"
659670
)
660671
sys.exit(1)
661672

@@ -676,6 +687,7 @@ def main():
676687
host=args.host,
677688
port=args.port,
678689
enable_graph=args.enable_graph,
690+
attn_backend=args.attn,
679691
)
680692
server.start()
681693

0 commit comments

Comments
 (0)