Skip to content

Commit a7d2c78

Browse files
author
wangpengcheng
committed
issue/340 - set default max_num_batched_tokens
1 parent 7d7e39d commit a7d2c78

3 files changed

Lines changed: 21 additions & 13 deletions

File tree

csrc/global_state/infinilm_config.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct InfinilmConfig {
2323

2424
if (const char *max_num_batched_tokens_env = getenv("INFINILM_MAX_NUM_BATCHED_TOKENS")) {
2525
max_num_batched_tokens = std::stoi(max_num_batched_tokens_env);
26-
ASSERT(max_num_batched_tokens >= 1024 && max_num_batched_tokens < max_position_embeddings);
26+
ASSERT(max_num_batched_tokens >= 1024 && max_num_batched_tokens <= max_position_embeddings);
2727
}
2828
}
2929

python/infinilm/llm/llm.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- AsyncLLM class for asynchronous streaming (server use)
77
"""
88

9+
import os
910
import asyncio
1011
import time
1112
import uuid
@@ -66,10 +67,20 @@ def __init__(self, config: EngineConfig):
6667
f"KV Connector created: {config.kv_transfer_config.kv_connector} "
6768
f"(role={config.kv_transfer_config.kv_role})"
6869
)
70+
71+
max_position_embeddings = self.model_runner.model_engine.hf_config[
72+
"max_position_embeddings"
73+
]
74+
max_num_batched_tokens = int(
75+
os.getenv("INFINILM_MAX_NUM_BATCHED_TOKENS", max_position_embeddings)
76+
)
77+
assert 1024 <= max_num_batched_tokens <= max_position_embeddings
78+
6979
self.scheduler = Scheduler(
7080
max_batch_size=config.max_batch_size,
7181
num_blocks=config.num_blocks,
7282
block_size=config.block_size,
83+
max_num_batched_tokens=max_num_batched_tokens,
7384
connector=connector,
7485
)
7586
logger.info(f"Using Paged KV Cache with num_blocks={config.num_blocks}")
@@ -685,13 +696,13 @@ def add_request(
685696
elif prompt is not None:
686697
prompt_token_ids = self.engine.tokenize(prompt)
687698
else:
688-
assert (
689-
messages is not None
690-
), "Either messages or prompt/prompt_token_ids must be provided"
699+
assert messages is not None, (
700+
"Either messages or prompt/prompt_token_ids must be provided"
701+
)
691702

692-
assert (
693-
apply_chat_template
694-
), "apply_chat_template needs to be true for multi-role conversation"
703+
assert apply_chat_template, (
704+
"apply_chat_template needs to be true for multi-role conversation"
705+
)
695706

696707
prompt = self.engine.apply_chat_template(
697708
messages, add_generation_prompt=add_generation_prompt

python/infinilm/llm/scheduler.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
max_batch_size: int = 16,
4242
num_blocks: int = 512,
4343
block_size: int = 256,
44+
max_num_batched_tokens: int = 1024,
4445
connector=None,
4546
):
4647
self.waiting_queue = janus.Queue()
@@ -56,13 +57,9 @@ def __init__(
5657
self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size)
5758
self.block_size = block_size
5859

59-
self.connector = connector
60+
self.max_num_batched_tokens = max_num_batched_tokens
6061

61-
assert "INFINILM_MAX_NUM_BATCHED_TOKENS" in os.environ
62-
self.max_num_batched_tokens = int(
63-
os.getenv("INFINILM_MAX_NUM_BATCHED_TOKENS", 65535)
64-
)
65-
assert self.max_num_batched_tokens > 1024
62+
self.connector = connector
6663

6764
def add_request(self, request: InferenceRequest):
6865
if request is not None:

0 commit comments

Comments
 (0)