Skip to content

Commit c73ff20

Browse files
issue/189: add inference server support to InfiniLM (#190)
2 parents de3e6b9 + 97870d3 commit c73ff20

File tree

9 files changed

+1981
-1
lines changed

9 files changed

+1981
-1
lines changed

README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,28 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
8888
python examples/jiuge.py --nvidia --model_path=/models/9G7B_MHA/ --backend=cpp --tp=4 --batch_size=16
8989
```
9090

91+
92+
- 推理服务测试
93+
- 启动推理服务
94+
```bash
95+
python python/infinilm/server/inference_server.py [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon] --model_path=<path/to/model_dir> --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH --tp=NDEV --temperature=TEMP --top_p=TOP_P --top_k=TOP_K --host=HOST --port=PORT
96+
```
97+
98+
- 单卡示例:
99+
```bash
100+
CUDA_VISIBLE_DEVICES=0 python python/infinilm/server/inference_server.py --nvidia --model_path=/models/9G7B_MHA/ --max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1
101+
```
102+
103+
- 多卡分布式示例:
104+
```bash
105+
CUDA_VISIBLE_DEVICES=0,1,2,3 python python/infinilm/server/inference_server.py --nvidia --model_path=/models/9G7B_MHA/ --max_tokens=100 --max_batch_size=32 --tp=4 --temperature=1.0 --top_p=0.8 --top_k=1
106+
```
107+
108+
- 测试推理服务性能:
109+
```bash
110+
python scripts/test_perf.py --verbose
111+
```
112+
91113
- 运行推理基准测试(C-Eval/MMLU)
92114

93115
```bash

python/infinilm/__init__.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
from .models import AutoLlamaModel
22
from . import distributed
33
from . import cache
4+
from . import llm
45

5-
__all__ = ["AutoLlamaModel", "distributed", "cache"]
6+
from .llm import (
7+
LLM,
8+
AsyncLLMEngine,
9+
SamplingParams,
10+
RequestOutput,
11+
TokenOutput,
12+
)
13+
14+
__all__ = [
15+
"AutoLlamaModel",
16+
"distributed",
17+
"cache",
18+
"llm",
19+
# LLM classes
20+
"LLM",
21+
"AsyncLLMEngine",
22+
"SamplingParams",
23+
"RequestOutput",
24+
"TokenOutput",
25+
]

python/infinilm/llm/__init__.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
InfiniLM Engine - High-performance llm inference engine with batch generation and streaming support.
3+
"""
4+
5+
from infinilm.llm.sampling_params import SamplingParams
6+
from infinilm.llm.request import (
7+
RequestStatus,
8+
FinishReason,
9+
RequestOutput,
10+
CompletionOutput,
11+
TokenOutput,
12+
InferenceRequest,
13+
)
14+
from infinilm.llm.llm import (
15+
LLM,
16+
LLMEngine,
17+
AsyncLLMEngine,
18+
EngineConfig,
19+
)
20+
from infinilm.llm.scheduler import Scheduler, SchedulerOutput
21+
from infinilm.llm.cache_manager import BlockManager, Block
22+
23+
__all__ = [
24+
# Main classes
25+
"LLM",
26+
"AsyncLLMEngine",
27+
"LLMEngine",
28+
"EngineConfig",
29+
# Parameters
30+
"SamplingParams",
31+
# Request and Output
32+
"InferenceRequest",
33+
"RequestOutput",
34+
"CompletionOutput",
35+
"TokenOutput",
36+
"RequestStatus",
37+
"FinishReason",
38+
# Internal (for advanced use)
39+
"Scheduler",
40+
"SchedulerOutput",
41+
"BlockManager",
42+
"Block",
43+
]
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
"""
2+
KV Cache Manager - Paged Attention block-based cache allocation and management.
3+
"""
4+
5+
from collections import deque
6+
from typing import List, Dict, Set
7+
import xxhash
8+
import numpy as np
9+
10+
11+
class Block:
    """A single KV-cache block tracked by reference count and content hash.

    A hash of -1 means the block holds no (complete) cached content and is
    therefore not eligible for prefix reuse.
    """

    def __init__(self, block_id: int):
        # Identity is fixed at construction; bookkeeping starts in the
        # never-used state: zero references, no hash, no tokens.
        self.block_id = block_id
        self.ref_count = 0
        self.hash = -1
        self.token_ids: List[int] = []

    def update(self, hash_value: int, token_ids: List[int]) -> None:
        """Record the content hash and a defensive copy of the cached tokens."""
        self.hash = hash_value
        # Copy so later mutation of the caller's list cannot corrupt the cache.
        self.token_ids = list(token_ids)

    def reset(self) -> None:
        """Prepare the block for a fresh allocation: one live reference, no content."""
        self.ref_count = 1
        self.hash = -1
        self.token_ids = []

    def free(self) -> None:
        """Return the block to the unused state: no references, no content."""
        self.ref_count = 0
        self.hash = -1
        self.token_ids = []

    def __repr__(self) -> str:
        return (
            f"Block(id={self.block_id}, "
            f"ref={self.ref_count}, "
            f"hash={self.hash})"
        )
37+
38+
class BlockManager:
    """Manages Paged KV Cache allocation with prefix caching support.

    Features:
    - Block allocation/deallocation with reference counting
    - Hash-based prefix caching for token sequence reuse
    - Slot mapping generation for physical-to-logical position mapping

    Lifecycle of a block id:
    - ``free_block_ids``: available for allocation.
    - ``req_block_ids``: full blocks allocated during the current prefill;
      their hashes are only published to ``hash_to_block_id`` when
      ``reset_req_blocks`` is called.
    - ``used_block_ids``: allocated blocks visible for prefix reuse. Blocks
      whose ref_count drops to 0 stay here (see ``free_blocks``) until
      ``try_free_blocks`` reclaims them, so recently released content can
      still be re-matched.
    """

    def __init__(self, num_blocks: int, block_size: int):
        """Create a manager for ``num_blocks`` cache blocks of ``block_size`` tokens each."""
        # NOTE: asserts are stripped under ``python -O``; sizes should be
        # validated by the caller as well.
        assert (
            num_blocks > 0 and block_size > 0
        ), "num_blocks and block_size must be positive"
        self.num_blocks = num_blocks
        self.block_size = block_size

        self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
        # Chained hash of a *full* block's content -> block id holding it.
        self.hash_to_block_id: Dict[int, int] = {}
        self.free_block_ids: deque = deque(range(num_blocks))
        self.used_block_ids: Set[int] = set()
        # Full blocks allocated for the in-flight prefill; flushed to
        # ``used_block_ids`` by ``reset_req_blocks``.
        self.req_block_ids: Set[int] = set()

    def reset_req_blocks(self) -> None:
        """Move blocks from prefill stage to used blocks and update hash mappings."""
        for block_id in self.req_block_ids:
            self.used_block_ids.add(block_id)
            block = self.blocks[block_id]
            prefix_hash = block.hash
            # Hashes are published only here, not at allocation time —
            # presumably so an in-flight prefill's blocks are not matched by
            # other requests before their KV data is written; confirm with
            # the scheduler's call order.
            self.hash_to_block_id[prefix_hash] = block_id
        self.req_block_ids.clear()

    @classmethod
    def compute_hash(cls, token_ids: List[int], prefix_hash: int = -1) -> int:
        """Compute hash for token sequence with optional prefix chaining.

        Args:
            token_ids: Tokens of one (full) block.
            prefix_hash: Hash of the preceding block, or -1 for the first
                block. Chaining makes a block's hash depend on everything
                before it, which is what makes prefix matching sound.

        Returns:
            Unsigned 64-bit xxhash digest as an int.
        """
        h = xxhash.xxh64()
        if prefix_hash != -1:
            h.update(prefix_hash.to_bytes(8, "little"))
        # int32 is assumed wide enough for any vocabulary id — TODO confirm.
        h.update(np.array(token_ids, dtype=np.int32).tobytes())
        return h.intdigest()

    def _allocate_partial_block(self, block_id: int) -> Block:
        """Allocate an incomplete block and add to used blocks.

        Partial blocks carry no hash (not reusable), so they go straight to
        ``used_block_ids`` instead of ``req_block_ids``.
        """
        # NOTE: membership test and remove() on a deque are O(n); acceptable
        # for modest pool sizes.
        assert block_id in self.free_block_ids, f"Block {block_id} not in free list"
        block = self.blocks[block_id]
        assert block.ref_count == 0, f"Block {block_id} ref_count not zero"

        block.reset()
        self.free_block_ids.remove(block_id)
        self.used_block_ids.add(block_id)
        return block

    def _allocate_full_block(self, block_id: int) -> Block:
        """Allocate a complete block and add to request blocks.

        Full blocks are hashable; they stay in ``req_block_ids`` until
        ``reset_req_blocks`` publishes their hash.
        """
        assert block_id in self.free_block_ids, f"Block {block_id} not in free list"
        block = self.blocks[block_id]
        assert block.ref_count == 0, f"Block {block_id} ref_count not zero"

        block.reset()
        self.free_block_ids.remove(block_id)
        self.req_block_ids.add(block_id)
        return block

    def _deallocate_block(self, block_id: int):
        """Deallocate a block and return it to free list."""
        block = self.blocks[block_id]
        assert (
            block.ref_count == 0
        ), f"Block {block_id} ref_count not zero, cannot deallocate"

        # Drop the hash mapping only if it still points at this block; a
        # newer block may have republished the same hash.
        if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
            del self.hash_to_block_id[block.hash]

        block.free()
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)

    def can_allocate(self, num_required_blocks: int) -> bool:
        """Return True if at least ``num_required_blocks`` blocks are free right now."""
        return len(self.free_block_ids) >= num_required_blocks

    def allocate_blocks(
        self, token_ids: List[int], block_table: List[int] = None
    ) -> tuple[List[int], List[int], int]:
        """Allocate cache blocks for new request with prefix caching support.

        Args:
            token_ids: Input token sequence
            block_table: Existing block_table (for decode phase); mutated in
                place when provided, and also returned.

        Returns:
            Tuple of (block_table, slot_mapping, num_cached_tokens) where
            slot_mapping covers only the NON-cached tokens (cached blocks
            contribute no slots).

        Raises:
            RuntimeError: when no free block is available.
                NOTE(review): unlike ``append_slot`` this does not attempt
                ``try_free_blocks`` first — confirm the scheduler always
                checks ``can_allocate`` before calling.
        """
        if block_table is None:
            block_table = []

        num_tokens = len(token_ids)
        # Ceiling division: last block may be partial.
        num_blocks = (num_tokens + self.block_size - 1) // self.block_size
        slot_mapping = []
        num_cached_tokens = 0
        prefix_hash = -1
        cache_miss = False

        for block_idx in range(num_blocks):
            start_idx = block_idx * self.block_size
            end_idx = min(start_idx + self.block_size, num_tokens)
            block_tokens = token_ids[start_idx:end_idx]

            # Only full blocks can be hashed for reuse
            if len(block_tokens) == self.block_size:
                prefix_hash = self.compute_hash(block_tokens, prefix_hash)

                # Try to reuse existing block; once one block misses, every
                # later block must miss too (hashes are prefix-chained), so
                # skip lookups after the first miss.
                if not cache_miss:
                    cached_block_id = self.hash_to_block_id.get(prefix_hash, -1)
                    if (
                        cached_block_id != -1
                        # Guard against hash collisions by comparing tokens.
                        and self.blocks[cached_block_id].token_ids == block_tokens
                    ):
                        # Check if all tokens are cached: force the final
                        # block to be freshly allocated so prefill still has
                        # at least one uncached token to run on.
                        if num_cached_tokens + self.block_size == len(token_ids):
                            cache_miss = True
                        else:
                            # Reuse successful
                            block = self.blocks[cached_block_id]
                            block.ref_count += 1
                            block_table.append(cached_block_id)
                            num_cached_tokens += self.block_size
                            continue
                    else:
                        cache_miss = True
            else:
                # Partial trailing block: no hash, breaks the prefix chain.
                prefix_hash = -1

            # Cannot reuse, allocate new block
            if not self.free_block_ids:
                raise RuntimeError("No available cache blocks")

            new_block_id = self.free_block_ids[0]
            if prefix_hash != -1:
                block = self._allocate_full_block(new_block_id)
                block.update(prefix_hash, block_tokens)
            else:
                block = self._allocate_partial_block(new_block_id)
            block_table.append(new_block_id)

            # Generate slot_mapping: physical slot = block_id * block_size + offset.
            for i in range(len(block_tokens)):
                slot_mapping.append(new_block_id * self.block_size + i)

        return block_table, slot_mapping, num_cached_tokens

    def append_slot(
        self, block_table: List[int], num_tokens: int, total_token_ids: List[int] = None
    ) -> tuple[List[int], int]:
        """Append slot for decode phase (generate one new token).

        Args:
            block_table: Current block_table; mutated in place if a new
                block is appended, and also returned.
            num_tokens: Current total token count (including newly generated token)
            total_token_ids: All token sequence (for updating block hash).
                NOTE(review): dereferenced unconditionally on the
                block-boundary path below — passing None there raises
                TypeError; confirm callers always supply it.

        Returns:
            Tuple of (block_table, slot_id)
        """
        assert len(block_table) > 0, "block_table cannot be empty"
        assert num_tokens > 0, "num_tokens must be greater than 0"

        # num_tokens % block_size == 1 means the previous block just became
        # full and the new token starts a fresh block.
        # NOTE(review): with block_size == 1 this condition is never true
        # (n % 1 == 0), so no new block would ever be appended — confirm
        # block_size > 1 is enforced upstream.
        if num_tokens % self.block_size == 1:
            # Previous block is full, update its hash for future prefix caching
            last_block_id = block_table[-1]
            last_block = self.blocks[last_block_id]

            # Only update if block's token_ids is empty (avoid duplicate updates)
            if len(last_block.token_ids) == 0:
                # The just-completed block spans the block_size tokens that
                # precede the newly generated one.
                block_start_idx = num_tokens - self.block_size - 1
                block_end_idx = num_tokens - 1
                block_tokens = total_token_ids[block_start_idx:block_end_idx]

                # Compute prefix_hash using previous block's hash if available
                if len(block_table) > 1:
                    prev_block = self.blocks[block_table[-2]]
                    prefix_hash = prev_block.hash
                else:
                    prefix_hash = -1

                current_hash = self.compute_hash(block_tokens, prefix_hash)
                last_block.update(current_hash, block_tokens)
                # Published immediately (unlike prefill, which defers to
                # reset_req_blocks) since this request owns the block.
                self.hash_to_block_id[current_hash] = last_block_id

            # Need new block; reclaim ref_count==0 blocks before giving up.
            if not self.free_block_ids:
                if not self.try_free_blocks(1):
                    raise RuntimeError("No available cache blocks")
            new_block_id = self.free_block_ids[0]
            self._allocate_partial_block(new_block_id)
            block_table.append(new_block_id)

        # Calculate slot: physical slot of the newly generated token.
        last_block_id = block_table[-1]
        offset = (num_tokens - 1) % self.block_size
        slot_id = last_block_id * self.block_size + offset

        return block_table, slot_id

    def free_blocks(self, block_table: List[int]):
        """Decrease reference count for all blocks. Blocks with ref_count=0 are not
        immediately freed to allow reuse; ``try_free_blocks`` reclaims them later."""
        for block_id in reversed(block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1

    def try_free_blocks(self, num_required: int) -> bool:
        """Try to free blocks with ref_count=0.

        Deallocates unreferenced used blocks until ``num_required`` free
        blocks are available (or no candidates remain). Returns whether the
        requirement is met.
        """
        # Snapshot first: _deallocate_block mutates used_block_ids.
        to_free = [
            bid for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
        ]

        for block_id in to_free:
            self._deallocate_block(block_id)
            # Stop early once enough blocks are free — keeps the rest cached.
            if self.can_allocate(num_required):
                return True

        return self.can_allocate(num_required)

    def get_num_free_blocks(self) -> int:
        """Return the number of currently free blocks."""
        return len(self.free_block_ids)

    def __repr__(self):
        return (
            f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, "
            f"free={len(self.free_block_ids)}, used={len(self.used_block_ids)})"
        )

0 commit comments

Comments
 (0)