@@ -50,6 +50,7 @@ class EngineConfig:
5050 temperature: Default sampling temperature.
5151 top_p: Default top-p sampling parameter.
5252 top_k: Default top-k sampling parameter.
53+ enable_graph: Whether to enable graph compiling.
5354 """
5455
5556 model_path : str
@@ -63,6 +64,7 @@ class EngineConfig:
6364 temperature : float = 1.0
6465 top_p : float = 0.8
6566 top_k : int = 1
67+ enable_graph : bool = False
6668
6769
6870class LLMEngine :
@@ -74,11 +76,18 @@ def __init__(self, config: EngineConfig):
7476 # Initialize device and dtype
7577 self ._init_device ()
7678
79+ # Initialize KV cache
80+ cache_config = PagedKVCacheConfig (
81+ num_blocks = config .num_blocks , block_size = config .block_size
82+ )
83+
7784 # Initialize model engine
7885 self .model_engine = InferEngine (
7986 model_path = config .model_path ,
8087 device = self .device ,
8188 distributed_config = DistConfig (config .tensor_parallel_size ),
89+ cache_config = cache_config ,
90+ enable_graph_compiling = config .enable_graph ,
8291 )
8392
8493 # Load model weights
@@ -92,12 +101,6 @@ def __init__(self, config: EngineConfig):
92101 )
93102 self ._fix_tokenizer_decoder ()
94103
95- # Initialize KV cache
96- cache_config = PagedKVCacheConfig (
97- num_blocks = config .num_blocks , block_size = config .block_size
98- )
99- self .model_engine .reset_cache (cache_config )
100-
101104 # Initialize scheduler
102105 self .scheduler = Scheduler (
103106 max_batch_size = config .max_batch_size ,
@@ -113,6 +116,7 @@ def __init__(self, config: EngineConfig):
113116 logger .info (
114117 f"LLMEngine initialized with model at { config .model_path } "
115118 f"on device { config .device } "
119+ f"enable_graph={ config .enable_graph } "
116120 )
117121
118122 def _init_device (self ):
@@ -308,6 +312,7 @@ def __init__(
308312 temperature : float = 1.0 ,
309313 top_p : float = 0.8 ,
310314 top_k : int = 1 ,
315+ enable_graph : bool = False ,
311316 ):
312317 """Initialize LLM.
313318
@@ -323,6 +328,7 @@ def __init__(
323328 temperature: Default sampling temperature.
324329 top_p: Default top-p sampling parameter.
325330 top_k: Default top-k sampling parameter.
331+ enable_graph: Whether to enable graph compiling.
326332 """
327333 config = EngineConfig (
328334 model_path = model_path ,
@@ -336,6 +342,7 @@ def __init__(
336342 temperature = temperature ,
337343 top_p = top_p ,
338344 top_k = top_k ,
345+ enable_graph = enable_graph ,
339346 )
340347 self .engine = LLMEngine (config )
341348 self .config = config
@@ -452,6 +459,7 @@ def __init__(
452459 temperature : float = 1.0 ,
453460 top_p : float = 0.8 ,
454461 top_k : int = 1 ,
462+ enable_graph : bool = False ,
455463 ):
456464 """Initialize AsyncLLMEngine.
457465
@@ -467,6 +475,7 @@ def __init__(
467475 temperature: Default sampling temperature.
468476 top_p: Default top-p sampling parameter.
469477 top_k: Default top-k sampling parameter.
478+ enable_graph: Whether to enable graph compiling.
470479 """
471480 config = EngineConfig (
472481 model_path = model_path ,
@@ -480,6 +489,7 @@ def __init__(
480489 temperature = temperature ,
481490 top_p = top_p ,
482491 top_k = top_k ,
492+ enable_graph = enable_graph ,
483493 )
484494 self .engine = LLMEngine (config )
485495 self .config = config
0 commit comments