Commit ee262bc

wooway777 authored and PanZezhong1725 committed

issue/204 - support graph in server scripts

1 parent 71fe805 · commit ee262bc

2 files changed

Lines changed: 32 additions & 7 deletions

python/infinilm/llm/llm.py

Lines changed: 16 additions & 6 deletions
@@ -50,6 +50,7 @@ class EngineConfig:
         temperature: Default sampling temperature.
         top_p: Default top-p sampling parameter.
         top_k: Default top-k sampling parameter.
+        enable_graph: Whether to enable graph compiling.
     """

     model_path: str
@@ -63,6 +64,7 @@ class EngineConfig:
     temperature: float = 1.0
     top_p: float = 0.8
     top_k: int = 1
+    enable_graph: bool = False


 class LLMEngine:
@@ -74,11 +76,18 @@ def __init__(self, config: EngineConfig):
         # Initialize device and dtype
         self._init_device()

+        # Initialize KV cache
+        cache_config = PagedKVCacheConfig(
+            num_blocks=config.num_blocks, block_size=config.block_size
+        )
+
         # Initialize model engine
         self.model_engine = InferEngine(
             model_path=config.model_path,
             device=self.device,
             distributed_config=DistConfig(config.tensor_parallel_size),
+            cache_config=cache_config,
+            enable_graph_compiling=config.enable_graph,
         )

         # Load model weights
@@ -92,12 +101,6 @@ def __init__(self, config: EngineConfig):
         )
         self._fix_tokenizer_decoder()

-        # Initialize KV cache
-        cache_config = PagedKVCacheConfig(
-            num_blocks=config.num_blocks, block_size=config.block_size
-        )
-        self.model_engine.reset_cache(cache_config)
-
         # Initialize scheduler
         self.scheduler = Scheduler(
             max_batch_size=config.max_batch_size,
@@ -113,6 +116,7 @@ def __init__(self, config: EngineConfig):
         logger.info(
             f"LLMEngine initialized with model at {config.model_path} "
             f"on device {config.device}"
+            f"enable_graph={config.enable_graph}"
         )

     def _init_device(self):
@@ -308,6 +312,7 @@ def __init__(
         temperature: float = 1.0,
         top_p: float = 0.8,
         top_k: int = 1,
+        enable_graph: bool = False,
     ):
         """Initialize LLM.

@@ -323,6 +328,7 @@ def __init__(
             temperature: Default sampling temperature.
             top_p: Default top-p sampling parameter.
             top_k: Default top-k sampling parameter.
+            enable_graph: Whether to enable graph compiling.
         """
         config = EngineConfig(
             model_path=model_path,
@@ -336,6 +342,7 @@ def __init__(
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            enable_graph=enable_graph,
         )
         self.engine = LLMEngine(config)
         self.config = config
@@ -452,6 +459,7 @@ def __init__(
         temperature: float = 1.0,
         top_p: float = 0.8,
         top_k: int = 1,
+        enable_graph: bool = False,
     ):
         """Initialize AsyncLLMEngine.

@@ -467,6 +475,7 @@ def __init__(
             temperature: Default sampling temperature.
             top_p: Default top-p sampling parameter.
             top_k: Default top-k sampling parameter.
+            enable_graph: Whether to enable graph compiling.
         """
         config = EngineConfig(
             model_path=model_path,
@@ -480,6 +489,7 @@ def __init__(
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            enable_graph=enable_graph,
         )
         self.engine = LLMEngine(config)
         self.config = config

python/infinilm/server/inference_server.py

Lines changed: 16 additions & 1 deletion
@@ -22,7 +22,9 @@
 DEFAULT_REQUEST_TIMEOUT = 1000.0


-def chunk_json(id_, content=None, role=None, finish_reason=None):
+def chunk_json(
+    id_, content=None, role=None, finish_reason=None, model: str = "unknown"
+):
     """Generate JSON chunk for streaming response."""
     delta = {}
     if content:
@@ -65,6 +67,7 @@ def __init__(
         top_k: int = 1,
         host: str = "0.0.0.0",
         port: int = 8000,
+        enable_graph: bool = False,
     ):
         """Initialize inference server.

@@ -82,6 +85,7 @@ def __init__(
             top_k: Default top-k sampling parameter.
             host: Server host address.
             port: Server port number.
+            enable_graph: Whether to enable graph compiling.
         """
         self.model_path = model_path
         self.device = device
@@ -96,6 +100,7 @@ def __init__(
         self.top_k = top_k
         self.host = host
         self.port = port
+        self.enable_graph = enable_graph

         self.engine: AsyncLLMEngine = None

@@ -123,9 +128,11 @@ async def lifespan(app: FastAPI):
                 temperature=self.temperature,
                 top_p=self.top_p,
                 top_k=self.top_k,
+                enable_graph=self.enable_graph,
             )
             self.engine.start()
             logger.info(f"Engine initialized with model at {self.model_path}")
+            logger.info(f" enable_graph: {self.enable_graph}")
             yield
             self.engine.stop()

@@ -407,6 +414,11 @@ def parse_args():
     parser.add_argument("--moore", action="store_true", help="Use Moore device")
     parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device")
     parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device")
+    parser.add_argument(
+        "--enable-graph",
+        action="store_true",
+        help="Enable graph compiling",
+    )
     parser.add_argument(
         "--log_level",
         type=str,
@@ -442,6 +454,8 @@ def main():
             "\n"
             "Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
             "--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
+            "\n"
+            "Optional: --enable-paged-attn --enable-graph"
         )
         sys.exit(1)

@@ -459,6 +473,7 @@ def main():
         top_k=args.top_k,
         host=args.host,
         port=args.port,
+        enable_graph=args.enable_graph,
     )
     server.start()
