|
6 | 6 | - AsyncLLM class for asynchronous streaming (server use) |
7 | 7 | """ |
8 | 8 |
|
| 9 | +import os |
9 | 10 | import asyncio |
10 | 11 | import time |
11 | 12 | import uuid |
@@ -66,10 +67,20 @@ def __init__(self, config: EngineConfig): |
66 | 67 | f"KV Connector created: {config.kv_transfer_config.kv_connector} " |
67 | 68 | f"(role={config.kv_transfer_config.kv_role})" |
68 | 69 | ) |
| 70 | + |
| 71 | + max_position_embeddings = self.model_runner.model_engine.hf_config[ |
| 72 | + "max_position_embeddings" |
| 73 | + ] |
| 74 | + max_num_batched_tokens = int( |
| 75 | + os.getenv("INFINILM_MAX_NUM_BATCHED_TOKENS", max_position_embeddings) |
| 76 | + ) |
| 77 | + assert 1024 <= max_num_batched_tokens <= max_position_embeddings |
| 78 | + |
69 | 79 | self.scheduler = Scheduler( |
70 | 80 | max_batch_size=config.max_batch_size, |
71 | 81 | num_blocks=config.num_blocks, |
72 | 82 | block_size=config.block_size, |
| 83 | + max_num_batched_tokens=max_num_batched_tokens, |
73 | 84 | connector=connector, |
74 | 85 | ) |
75 | 86 | logger.info(f"Using Paged KV Cache with num_blocks={config.num_blocks}") |
@@ -685,13 +696,13 @@ def add_request( |
685 | 696 | elif prompt is not None: |
686 | 697 | prompt_token_ids = self.engine.tokenize(prompt) |
687 | 698 | else: |
688 | | - assert ( |
689 | | - messages is not None |
690 | | - ), "Either messages or prompt/prompt_token_ids must be provided" |
| 699 | + assert messages is not None, ( |
| 700 | + "Either messages or prompt/prompt_token_ids must be provided" |
| 701 | + ) |
691 | 702 |
|
692 | | - assert ( |
693 | | - apply_chat_template |
694 | | - ), "apply_chat_template needs to be true for multi-role conversation" |
| 703 | + assert apply_chat_template, ( |
| 704 | + "apply_chat_template needs to be true for multi-role conversation" |
| 705 | + ) |
695 | 706 |
|
696 | 707 | prompt = self.engine.apply_chat_template( |
697 | 708 | messages, add_generation_prompt=add_generation_prompt |
|
0 commit comments