Skip to content

Commit f38eb57

Browse files
author
pengcheng888
committed
issue/340 - enhance validation with some checks
1 parent 87a4da1 commit f38eb57

5 files changed

Lines changed: 24 additions & 20 deletions

File tree

python/infinilm/config/kv_transfer.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -41,14 +41,16 @@ def __post_init__(self) -> None:
4141
if self.engine_id is None:
4242
self.engine_id = f"{self.kv_role}_" + str(uuid.uuid4())
4343

44-
if not self.kv_connector_extra_config:
45-
self.kv_connector_extra_config = dict(self.kv_connector_extra_config or {})
46-
self.kv_connector_extra_config.setdefault("mooncake_protocol", "rdma")
47-
48-
assert all(
49-
key in ["mooncake_protocol", "num_workers"]
50-
for key in self.kv_connector_extra_config.keys()
51-
)
44+
self.kv_connector_extra_config = dict(self.kv_connector_extra_config or {})
45+
self.kv_connector_extra_config.setdefault("mooncake_protocol", "rdma")
46+
47+
allowed_extra_config_keys = frozenset({"mooncake_protocol", "num_workers"})
48+
unknown_keys = set(self.kv_connector_extra_config.keys()) - allowed_extra_config_keys
49+
if unknown_keys:
50+
raise ValueError(
51+
f"Unsupported kv_connector_extra_config keys: {sorted(unknown_keys)}. "
52+
f"Supported keys are {sorted(allowed_extra_config_keys)}"
53+
)
5254

5355
mooncake_protocol = self.kv_connector_extra_config["mooncake_protocol"]
5456
if mooncake_protocol not in ["tcp", "rdma"]:

python/infinilm/kv_connector/mooncake/mooncake_connector_worker.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
try:
22
from mooncake.engine import TransferEngine
33
except ImportError as e:
4-
raise ImportError("Please install mooncake") from e
4+
raise ImportError("Please pip install mooncake-transfer-engine") from e
55

66
import asyncio
77
import logging
@@ -933,6 +933,7 @@ async def receive_kv_from_single_worker(
933933

934934
except zmq.ContextTerminated:
935935
logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
936+
# TODO: handle this error
936937
except Exception as e:
937938
logger.error("MooncakeXferMetadata transfer failed for %s: %s", req_ids, e)
938939
return

python/infinilm/llm/cache_manager.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -62,9 +62,9 @@ def compute_hash(
6262
return h.intdigest()
6363

6464
def __init__(self, num_blocks: int, block_size: int):
65-
assert (
66-
num_blocks > 0 and block_size > 0
67-
), "num_blocks and block_size must be positive"
65+
assert num_blocks > 0 and block_size > 0, (
66+
"num_blocks and block_size must be positive"
67+
)
6868
self.num_blocks = num_blocks
6969
self.block_size = block_size
7070

@@ -105,9 +105,9 @@ def _allocate_full_block(self) -> Block:
105105
def _deallocate_block(self, block_id: int):
106106
"""Deallocate a block and return it to free list."""
107107
block = self.blocks[block_id]
108-
assert (
109-
block.ref_count == 0
110-
), f"Block {block_id} ref_count not zero, cannot deallocate"
108+
assert block.ref_count == 0, (
109+
f"Block {block_id} ref_count not zero, cannot deallocate"
110+
)
111111

112112
if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
113113
del self.hash_to_block_id[block.hash]
@@ -396,6 +396,7 @@ def free_blocks(self, block_table: List[int]):
396396
immediately freed to allow reuse."""
397397
for block_id in reversed(block_table):
398398
block = self.blocks[block_id]
399+
assert block.ref_count > 0, "block ref_count must be greater than 0"
399400
block.ref_count -= 1
400401

401402
def try_free_blocks(self, num_required: int) -> bool:
@@ -425,9 +426,9 @@ def update_blocks_hash(self, block_table: List[int], num_local_cached_tokens: in
425426
num_local_cached_tokens: Number of locally cached tokens (must be a multiple of
426427
block_size).
427428
"""
428-
assert (
429-
num_local_cached_tokens % self.block_size == 0
430-
), "num_local_cached_tokens must be multiple of block_size"
429+
assert num_local_cached_tokens % self.block_size == 0, (
430+
"num_local_cached_tokens must be multiple of block_size"
431+
)
431432
for idx in range(num_local_cached_tokens // self.block_size, len(block_table)):
432433
block_id = block_table[idx]
433434
block = self.blocks[block_id]

python/infinilm/llm/model_runner/model_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class ModelRunner:
4545
def __init__(self, config: EngineConfig):
4646
self.config = config
4747
self.kv_transfer_config = config.kv_transfer_config
48-
print(f"kv_transfer_config: {self.kv_transfer_config}")
48+
logger.info(f"kv_transfer_config: {self.kv_transfer_config}")
4949

5050
self._init_device()
5151

python/infinilm/llm/sampling_params.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ class SamplingParams:
1313
temperature: float = 1.0
1414
top_p: float = 0.8
1515
top_k: int = 1
16-
max_tokens: Optional[int] = None
16+
max_tokens: int = 512
1717
stop: Optional[List[str]] = None
1818
stop_token_ids: Optional[List[int]] = (
1919
None # Placeholder for future usage, not currently handled

0 commit comments

Comments
 (0)