
Commit 99eebe3

[feature] Add LLM connection base components and OpenAI connector (#636)
This PR introduces the foundational building blocks for integrating large language models (LLMs) into the system, with a focus on modularity, type safety, and support for both streaming and non-streaming interactions.

Highlights:

- Core data models: Added LLMRequest, LLMResponse, and LLMStreamChunk dataclasses to standardize input/output formats across LLM integrations.
- Protocol abstraction: Defined LLMConnection as a Python protocol to enforce a consistent interface (sync/async chat methods) for any LLM backend.
- OpenAI integration: Implemented OpenAIConn, a fully featured connector that supports synchronous and asynchronous invocation in both streaming and non-streaming modes via the OpenAI-compatible API.
- Token management: Integrated HuggingFaceTokenizer for accurate token counting and utility functions such as random text generation based on the tokenizer vocabulary.
- Streaming support: Built an SSE (Server-Sent Events) parser that converts OpenAI's streaming response format into the standardized LLMStreamChunk structure.
- HTTP layer: Leveraged httpx (sync and async clients) for robust and efficient HTTP communication with LLM endpoints.
- Security and utilities: Added helper functions for secure token encoding/decoding and request ID validation to ensure integrity and traceability.

This lays the groundwork for pluggable LLM backends: future connectors (e.g., Anthropic, Ollama, or custom endpoints) can implement the LLMConnection protocol and reuse the shared infrastructure.

Screenshots (not reproduced): token_utils test; the calculated token count is consistent with the engine's return; streaming token calculation method.
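
As a rough sketch of the streaming token-calculation check the screenshots illustrate (the connector import path, model path, and endpoint URL below are assumptions for illustration, not taken from this PR):

# Sketch only: the import path for OpenAIConn is assumed; see the connector file below.
from common.llm_connection.LLMBase import LLMRequest
from common.llm_connection.openai_conn import OpenAIConn  # hypothetical module name
from common.llm_connection.token_counter import HuggingFaceTokenizer

tok = HuggingFaceTokenizer("/models/Qwen3-32B")  # placeholder model path
conn = OpenAIConn(base_url="http://localhost:8000", tokenizer=tok)

text, streamed = "", 0
for chunk in conn.stream_chat(LLMRequest(num_tokens=10)):
    text += chunk.text
    streamed += chunk.num_tokens  # per-chunk estimate from the tokenizer

print(tok.count_tokens(text))  # one count over the full concatenated text (recommended)
print(streamed)                # sum of per-chunk estimates; the screenshots show agreement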
1 parent 5c48af8 commit 99eebe3

5 files changed

Lines changed: 519 additions & 1 deletion

common/llm_connection/LLMBase.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import (
    AsyncIterator,
    Iterator,
    Optional,
    Protocol,
    Sequence,
    runtime_checkable,
)


@dataclass
class LLMRequest:
    """
    Either `messages` or `num_tokens` should be provided (not both).

    - `messages`: A standard list of message dictionaries (e.g., with roles like 'user' and 'assistant').
    - `num_tokens`: If provided, the system will generate random messages totaling approximately this many tokens.
    - `max_tokens`: Maximum number of output tokens to generate.
    - `ignore_eos`: If True, generation continues past the end-of-sequence token (only respected by vLLM).
    - `temperature`: Sampling temperature (default: 0.0 for deterministic output).
    - `top_p`: Nucleus sampling parameter (default: 1.0 for the full distribution).
    - `timeout`: Request timeout in seconds (default: 600.0).
    """

    messages: Sequence[dict] = ()
    num_tokens: Optional[int] = None
    ignore_eos: bool = False
    max_tokens: Optional[int] = None
    temperature: float = 0.0
    top_p: float = 1.0
    timeout: float = 600.0


@dataclass
class LLMResponse:
    """Represents a complete response from an LLM."""

    text: str
    finish_reason: Optional[str]
    total_tokens: int


@dataclass
class LLMStreamChunk:
    """Represents a single streaming chunk during LLM generation."""

    text: str
    num_tokens: int
    is_finished: bool
    finish_reason: Optional[str]


@runtime_checkable
class LLMConnection(Protocol):
    """
    Minimal contract for LLM clients.

    Any connector that implements these four methods satisfies the LLMConnection
    protocol without needing to inherit from a base class.

    - `chat` and `achat`: Perform single-turn synchronous and asynchronous inference, respectively.
    - `stream_chat` and `astream_chat`: Yield structured `LLMStreamChunk` objects during generation.
    """

    def chat(self, request: LLMRequest, **kwargs) -> LLMResponse:
        """Synchronous single-turn chat completion."""
        ...

    def stream_chat(
        self, request: LLMRequest, **kwargs
    ) -> Iterator[LLMStreamChunk]:
        """Synchronous streaming chat that yields structured generation chunks."""
        ...

    async def achat(self, request: LLMRequest, **kwargs) -> LLMResponse:
        """Asynchronous single-turn chat completion."""
        ...

    async def astream_chat(
        self, request: LLMRequest, **kwargs
    ) -> AsyncIterator[LLMStreamChunk]:
        """Asynchronous streaming chat that yields structured generation chunks."""
        ...

    # TODO: Consider adding a unified calling interface.
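
Because LLMConnection is declared @runtime_checkable, any class with these four methods passes an isinstance() check without inheriting from anything. A minimal sketch of a hypothetical backend (EchoConn is illustrative, not part of this PR):

from typing import AsyncIterator, Iterator

from common.llm_connection.LLMBase import (
    LLMConnection,
    LLMRequest,
    LLMResponse,
    LLMStreamChunk,
)


class EchoConn:
    """Hypothetical backend that echoes the last user message."""

    def chat(self, request: LLMRequest, **kwargs) -> LLMResponse:
        text = request.messages[-1]["content"] if request.messages else ""
        return LLMResponse(text=text, finish_reason="stop", total_tokens=0)

    def stream_chat(self, request: LLMRequest, **kwargs) -> Iterator[LLMStreamChunk]:
        resp = self.chat(request)
        yield LLMStreamChunk(
            text=resp.text, num_tokens=0, is_finished=True, finish_reason="stop"
        )

    async def achat(self, request: LLMRequest, **kwargs) -> LLMResponse:
        return self.chat(request)

    async def astream_chat(
        self, request: LLMRequest, **kwargs
    ) -> AsyncIterator[LLMStreamChunk]:
        for chunk in self.stream_chat(request):
            yield chunk


# Structural typing: no base class needed for the check to pass.
assert isinstance(EchoConn(), LLMConnection)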

test/common/llm_connection/__init__.py

Whitespace-only changes.
Lines changed: 304 additions & 0 deletions
@@ -0,0 +1,304 @@
from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Dict, Iterator, Optional

import httpx

# from pathlib import Path
# import sys
# PRJ_ROOT = Path(__file__).resolve().parent.parent.parent
# sys.path.insert(0, str(PRJ_ROOT))
from common.llm_connection.LLMBase import (
    LLMConnection,
    LLMRequest,
    LLMResponse,
    LLMStreamChunk,
)
from common.llm_connection.token_counter import HuggingFaceTokenizer

logger = logging.getLogger(__name__)


def _to_chunk(line: str) -> Optional[LLMStreamChunk]:
    """
    Parse a single SSE line from an OpenAI-compatible streaming response.

    Notes:
    - Token count here is only an estimate based on delta text.
    - Designed for performance testing, not billing-accurate accounting.
    """
    if not line:
        return None

    line = line.strip()
    if not line.startswith("data:"):
        return None

    raw = line[len("data:") :].strip()
    if raw == "[DONE]":
        return LLMStreamChunk(
            text="",
            num_tokens=0,
            is_finished=True,
            finish_reason="stop",
        )

    try:
        ev = json.loads(raw)
        choice = ev["choices"][0]
        delta = choice.get("delta", {})
        text = delta.get("content", "")
        finish_reason = choice.get("finish_reason")

        return LLMStreamChunk(
            text=text,
            num_tokens=0,  # filled later (estimate)
            is_finished=finish_reason is not None,
            finish_reason=finish_reason,
        )
    except (json.JSONDecodeError, KeyError, IndexError) as e:
        logger.error("Failed to parse SSE line: %r (%s)", line, e)
        return None
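
# Illustration of the lines this parser handles (an assumed typical
# OpenAI-compatible stream, not quoted from this PR's tests):
#   data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
#   data: {"choices":[{"delta":{},"finish_reason":"stop"}]}
#   data: [DONE]
# _to_chunk maps the first line to LLMStreamChunk(text="Hel", is_finished=False),
# the second to a finished chunk carrying its finish_reason, and the [DONE]
# sentinel to a final empty "stop" chunk.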


@dataclass
class OpenAIConn(LLMConnection):
    """
    OpenAI-compatible LLM connection, intended for:
    - performance benchmarking
    - streaming latency measurement
    - basic accuracy testing

    Assumes the /v1/chat/completions API.
    """

    base_url: str
    tokenizer: HuggingFaceTokenizer = field(repr=False)
    api_key: str = ""
    model: str = "default"

    timeout: float = 3600.0  # connect + read + write + pool

    _client: httpx.Client = field(init=False, repr=False)
    _aclient: httpx.AsyncClient = field(init=False, repr=False)

    def __post_init__(self):
        self.base_url = self.base_url.rstrip("/")
        if not self.base_url.endswith("/v1"):
            self.base_url += "/v1"

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        limits = httpx.Limits(
            max_keepalive_connections=None,
            max_connections=None,
            keepalive_expiry=None,
        )

        self._client = httpx.Client(
            base_url=self.base_url,
            headers=headers,
            timeout=self.timeout,
            limits=limits,
        )
        self._aclient = httpx.AsyncClient(
            base_url=self.base_url,
            headers=headers,
            timeout=self.timeout,
            limits=limits,
        )

    # ---------------- internal helpers ----------------

    def _make_body(self, req: LLMRequest) -> Dict[str, Any]:
        if req.messages:
            messages = [
                {"role": m["role"], "content": m["content"]} for m in req.messages
            ]
        elif req.num_tokens:
            logger.warning(
                "LLMRequest has no messages, using synthetic tokens for warmup"
            )
            messages = [
                {
                    "role": "user",
                    "content": self.tokenizer.get_some_tokens(req.num_tokens),
                }
            ]
        else:
            raise ValueError("Either 'messages' or 'num_tokens' must be provided")

        body: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": req.temperature,
            "top_p": req.top_p,
            "stream": False,
        }

        if req.max_tokens is not None:
            body["max_tokens"] = req.max_tokens
        if req.ignore_eos:  # bool field; only send when set (vLLM-specific)
            body["ignore_eos"] = req.ignore_eos

        return body

    def _wrap_exception(self, e: Exception) -> Exception:
        if isinstance(e, httpx.HTTPStatusError):
            try:
                detail = e.response.json()
                msg = detail.get("error", {}).get("message", e.response.text)
            except Exception:
                msg = e.response.text
            raise RuntimeError(f"LLM API Error {e.response.status_code}: {msg}") from e

        if isinstance(e, httpx.RequestError):
            raise RuntimeError(
                f"LLM Network Error: {type(e).__name__} at {e.request.url}"
            ) from e

        raise e

    # ---------------- sync ----------------

    def chat(self, req: LLMRequest, **kwargs) -> LLMResponse:
        body = self._make_body(req)
        try:
            r = self._client.post("/chat/completions", json=body)
            r.raise_for_status()
            data = r.json()

            choice = data["choices"][0]
            text = choice["message"]["content"]

            return LLMResponse(
                text=text,
                finish_reason=choice.get("finish_reason"),
                total_tokens=self.tokenizer.count_tokens(text),
            )
        except Exception as e:
            raise self._wrap_exception(e)

    def stream_chat(self, req: LLMRequest, **kwargs) -> Iterator[LLMStreamChunk]:
        body = self._make_body(req)
        body["stream"] = True

        try:
            with self._client.stream(
                "POST",
                "/chat/completions",
                json=body,
            ) as resp:
                resp.raise_for_status()
                for line in resp.iter_lines():
                    chunk = _to_chunk(line)
                    if not chunk:
                        continue
                    if not chunk.is_finished:
                        # estimated token count, for throughput only
                        chunk.num_tokens = self.tokenizer.count_tokens(chunk.text)
                    yield chunk
        except Exception as e:
            raise self._wrap_exception(e)

    # ---------------- async ----------------

    async def achat(self, req: LLMRequest, **kwargs) -> LLMResponse:
        body = self._make_body(req)
        try:
            r = await self._aclient.post("/chat/completions", json=body)
            r.raise_for_status()
            data = r.json()

            choice = data["choices"][0]
            text = choice["message"]["content"]

            return LLMResponse(
                text=text,
                finish_reason=choice.get("finish_reason"),
                total_tokens=self.tokenizer.count_tokens(text),
            )
        except Exception as e:
            raise self._wrap_exception(e)

    async def astream_chat(
        self,
        req: LLMRequest,
        **kwargs,
    ) -> AsyncIterator[LLMStreamChunk]:
        body = self._make_body(req)
        body["stream"] = True

        try:
            async with self._aclient.stream(
                "POST",
                "/chat/completions",
                json=body,
            ) as resp:
                resp.raise_for_status()
                async for line in resp.aiter_lines():
                    chunk = _to_chunk(line)
                    if not chunk:
                        continue
                    if not chunk.is_finished:
                        chunk.num_tokens = self.tokenizer.count_tokens(chunk.text)
                    yield chunk
        except Exception as e:
            raise self._wrap_exception(e)

    # ---------------- lifecycle ----------------

    def close(self):
        self._client.close()

    async def aclose(self):
        await self._aclient.aclose()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()


# Example usage (for local testing)
if __name__ == "__main__":
    import os

    tok = HuggingFaceTokenizer("D:/Models/Qwen3-32B")

    conn = OpenAIConn(
        base_url="https://api.siliconflow.cn/v1",
        api_key=os.getenv("SILICON_API_KEY") or "",
        model="Qwen/Qwen2-7B-Instruct",
        tokenizer=tok,
    )

    # 1. Synchronous non-streaming test
    print("Test synchronous non-streaming")
    req = LLMRequest(
        messages=[{"role": "user", "content": "Hello, introduce yourself"}],
        max_tokens=1024,
    )
    print(conn.chat(req))

    # 2. Synchronous streaming test
    print("Test synchronous streaming")
    req = LLMRequest(num_tokens=10)
    res = ""
    stream_num = 0
    for c in conn.stream_chat(req):
        if c.text:
            res += c.text
        stream_num += c.num_tokens
    print(f"Non Stream: {tok.count_tokens(res)}")  # recommended for final counts
    print(f"Stream: {stream_num}")  # consistent with the non-streaming count
