Skip to content

Commit 9aa6159

Browse files
CementZhangRader
authored andcommitted
feat: optimize sglang guard stream chat completion trace and session headers
1 parent 8fd5d67 commit 9aa6159

1 file changed

Lines changed: 168 additions & 38 deletions

File tree

docker/inference/sglang-guard-stream/start_engine.py

Lines changed: 168 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import os
22
import sys
33
import logging
4+
import asyncio
45
sys.path.insert(0, '/sgl-workspace/sglang/python')
56
import torch
67
import torch.nn.functional as F
78
import json
89
import uuid
910
import time
10-
from fastapi import FastAPI, HTTPException
11+
from fastapi import FastAPI, HTTPException, Request
1112
from pydantic import BaseModel
1213
from typing import List, Optional
1314
import uvicorn
@@ -30,6 +31,9 @@
3031
PAGE_SIZE = int(os.getenv('PAGE_SIZE', '1'))
3132
CHUNKED_PREFILL_SIZE = int(os.getenv('CHUNKED_PREFILL_SIZE', '131072'))
3233
PORT = int(os.getenv('PORT', '8000'))
34+
WARMUP_ENABLED = os.getenv('WARMUP_ENABLED', 'true').strip().lower() in {'1', 'true', 'yes', 'y', 'on'}
35+
WARMUP_DELAY_SECONDS = int(os.getenv('WARMUP_DELAY_SECONDS', '60'))
36+
WARMUP_PROMPT = os.getenv('WARMUP_PROMPT', 'hello')
3337

3438
logger.info(f'MODEL_PATH: {MODEL_PATH}')
3539
logger.info(f'CONTEXT_LENGTH: {CONTEXT_LENGTH}')
@@ -38,15 +42,13 @@
3842
logger.info(f'PAGE_SIZE: {PAGE_SIZE}')
3943
logger.info(f'CHUNKED_PREFILL_SIZE: {CHUNKED_PREFILL_SIZE}')
4044
logger.info(f'PORT: {PORT}')
45+
logger.info(f'WARMUP_ENABLED: {WARMUP_ENABLED}')
46+
logger.info(f'WARMUP_DELAY_SECONDS: {WARMUP_DELAY_SECONDS}')
4147

4248
logger.info('Loading tokenizer...')
4349
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
4450
logger.info('Tokenizer loaded')
4551

46-
im_start_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
47-
user_id = tokenizer.convert_tokens_to_ids('user')
48-
im_end_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
49-
5052
risk_level_map = {0: 'Safe', 1: 'Unsafe', 2: 'Controversial'}
5153
query_category_map = {0: 'Violent', 1: 'Sexual Content', 2: 'Self-Harm', 3: 'Political', 4: 'PII', 5: 'Copyright', 6: 'Illegal Acts', 7: 'Unethical', 8: 'Jailbreak'}
5254
response_category_map = {0: 'Violent', 1: 'Sexual Content', 2: 'Self-Harm', 3: 'Political', 4: 'PII', 5: 'Copyright', 6: 'Illegal Acts', 7: 'Unethical'}
@@ -55,6 +57,40 @@
5557

5658
app = FastAPI(title='Qwen3Guard-Stream API')
5759

60+
61+
async def delayed_warmup():
62+
if not WARMUP_ENABLED:
63+
return
64+
if engine is None:
65+
logger.warning('Warmup skipped: engine is not initialized')
66+
return
67+
68+
logger.info(f'Warmup scheduled, will run after {WARMUP_DELAY_SECONDS}s')
69+
await asyncio.sleep(WARMUP_DELAY_SECONDS)
70+
71+
warmup_trace_id = f'warmup-{uuid.uuid4().hex[:8]}'
72+
warmup_rid = f'warmup-{uuid.uuid4().hex}'
73+
try:
74+
start = time.time()
75+
with torch.inference_mode():
76+
outputs = await engine.async_generate(
77+
WARMUP_PROMPT,
78+
sampling_params={'max_new_tokens': 1},
79+
rid=warmup_rid,
80+
resumable=False
81+
)
82+
infer_time = time.time() - start
83+
token_stats = get_token_stats(outputs, prompt=WARMUP_PROMPT, infer_cost=infer_time)
84+
logger.info(
85+
f'[TraceID: {warmup_trace_id}] Warmup done'
86+
f' | PromptTokens: {token_stats["prompt_tokens"]}'
87+
f' | CompletionTokens: {token_stats["completion_tokens"]}'
88+
f' | Infer: {infer_time*1000:.1f}ms'
89+
f' | TPM(total): {token_stats["total_tpm"]:.2f}'
90+
)
91+
except Exception:
92+
logger.exception(f'[TraceID: {warmup_trace_id}] Warmup failed')
93+
5894
class ChatMessage(BaseModel):
5995
role: str
6096
content: str
@@ -64,6 +100,57 @@ class ChatCompletionRequest(BaseModel):
64100
messages: List[ChatMessage]
65101
stream: Optional[bool] = False
66102

103+
104+
def parse_bool_header(value: Optional[str]) -> Optional[bool]:
105+
if not isinstance(value, str):
106+
return None
107+
normalized = value.strip().lower()
108+
if normalized in {'1', 'true', 'yes', 'y', 'on'}:
109+
return True
110+
if normalized in {'0', 'false', 'no', 'n', 'off'}:
111+
return False
112+
return None
113+
114+
115+
def get_token_stats(result, prompt: str, infer_cost: float):
116+
prompt_tokens = 0
117+
completion_tokens = 1 # Guard classification typically returns one token.
118+
119+
meta = None
120+
if isinstance(result, dict):
121+
meta = result.get('meta_info')
122+
elif hasattr(result, 'meta_info'):
123+
meta = getattr(result, 'meta_info')
124+
125+
if isinstance(meta, dict):
126+
prompt_tokens = int(meta.get('prompt_tokens', 0) or 0)
127+
completion_tokens = int(meta.get('completion_tokens', 1) or 1)
128+
elif meta is not None:
129+
prompt_tokens = int(getattr(meta, 'prompt_tokens', 0) or 0)
130+
completion_tokens = int(getattr(meta, 'completion_tokens', 1) or 1)
131+
132+
# Fallback to an estimation for logging if engine meta does not provide token counts.
133+
if prompt_tokens <= 0 and prompt:
134+
prompt_tokens = max(len(prompt) // 2, 1)
135+
136+
total_tokens = prompt_tokens + completion_tokens
137+
generation_tps = completion_tokens / infer_cost if infer_cost > 0 else 0.0
138+
total_tps = total_tokens / infer_cost if infer_cost > 0 else 0.0
139+
140+
# Tokens per minute estimation.
141+
generation_tpm = generation_tps * 60
142+
total_tpm = total_tps * 60
143+
144+
return {
145+
'prompt_tokens': prompt_tokens,
146+
'completion_tokens': completion_tokens,
147+
'total_tokens': total_tokens,
148+
'generation_tps': generation_tps,
149+
'total_tps': total_tps,
150+
'generation_tpm': generation_tpm,
151+
'total_tpm': total_tpm
152+
}
153+
67154
def process_result(result, type_='query'):
68155
if type_ == 'query':
69156
risk_logits = torch.tensor(result['query_risk_level_logits']).view(-1, 3)
@@ -85,52 +172,95 @@ def health():
85172
def list_models():
86173
return {'object': 'list', 'data': [{'id': os.getenv('REPO_ID', 'qwen3-guard'), 'object': 'model', 'owned_by': 'qwen'}]}
87174

175+
176+
@app.on_event('startup')
177+
async def startup_event():
178+
asyncio.create_task(delayed_warmup())
179+
180+
88181
@app.post('/v1/chat/completions')
89-
async def chat_completions(request: ChatCompletionRequest):
90-
request_id = uuid.uuid4().hex[:8]
182+
async def chat_completions(raw_request: Request, request: ChatCompletionRequest):
91183
start_time = time.time()
184+
trace_id = (
185+
raw_request.headers.get('x-request-id')
186+
or raw_request.headers.get('request-id')
187+
or uuid.uuid4().hex
188+
)
92189

93190
if not request.messages:
94191
raise HTTPException(status_code=400, detail='messages cannot be empty')
95192

96193
last_msg = request.messages[-1]
97-
logger.info(f'[{request_id}] >>> Request: role={last_msg.role}, content={last_msg.content[:100]}...' if len(last_msg.content) > 100 else f'[{request_id}] >>> Request: role={last_msg.role}, content={last_msg.content}')
98-
99-
conversation = [{'role': m.role, 'content': m.content} for m in request.messages]
100-
prompt_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
101-
input_ids = tokenizer(prompt_text, return_tensors='pt').input_ids[0].tolist()
194+
prompt = last_msg.content
195+
sampling_params = {'max_new_tokens': 1}
102196

103-
logger.info(f'[{request_id}] Input tokens: {len(input_ids)}')
197+
rid = (
198+
raw_request.headers.get('x-session-id')
199+
or raw_request.headers.get('session-id')
200+
)
201+
resumable = parse_bool_header(
202+
raw_request.headers.get('x-resumable')
203+
or raw_request.headers.get('resumable')
204+
)
205+
if not rid:
206+
rid = uuid.uuid4().hex
207+
resumable = False
208+
if resumable is None:
209+
resumable = False
104210

105-
last_start = next((i for i in range(len(input_ids)-1, -1, -1) if input_ids[i:i+2] == [im_start_id, user_id]), None)
106-
user_end_index = next((i for i in range(last_start+2, len(input_ids)) if input_ids[i] == im_end_id), None) if last_start else None
211+
content_for_log = prompt[:100] + '...' if len(prompt) > 100 else prompt
212+
logger.info(
213+
f'[TraceID: {trace_id}] >>> Request: role={last_msg.role}, '
214+
f'content={content_for_log}, rid={rid}, resumable={resumable}'
215+
)
107216

108-
rid = uuid.uuid4().hex
109-
last_role = request.messages[-1].role
217+
try:
218+
infer_start = time.time()
219+
with torch.inference_mode():
220+
outputs = await engine.async_generate(
221+
prompt,
222+
sampling_params=sampling_params,
223+
rid=rid,
224+
resumable=resumable
225+
)
226+
infer_time = time.time() - infer_start
110227

111-
infer_start = time.time()
112-
if last_role == 'user':
113-
type_ = 'query'
114-
query_prompt = input_ids[:user_end_index+1] if user_end_index else input_ids
115-
outputs = await engine.async_generate(input_ids=query_prompt, sampling_params={'max_new_tokens': 1}, rid=rid, resumable=False)
116-
else:
117-
type_ = 'response'
118-
outputs = await engine.async_generate(input_ids=input_ids, sampling_params={'max_new_tokens': 1}, rid=rid, resumable=False)
119-
infer_time = time.time() - infer_start
228+
type_ = 'query' if last_msg.role == 'user' else 'response'
229+
eval_result = process_result(outputs, type_=type_)
230+
token_stats = get_token_stats(outputs, prompt=prompt, infer_cost=infer_time)
231+
total_time = time.time() - start_time
120232

121-
eval_result = process_result(outputs, type_=type_)
122-
total_time = time.time() - start_time
233+
logger.info(
234+
f'[TraceID: {trace_id}] <<< Response: {json.dumps(eval_result, ensure_ascii=False)}'
235+
f' | Mode: {type_}'
236+
f' | Infer: {infer_time*1000:.1f}ms'
237+
f' | Total: {total_time*1000:.1f}ms'
238+
f' | PromptTokens: {token_stats["prompt_tokens"]}'
239+
f' | CompletionTokens: {token_stats["completion_tokens"]}'
240+
f' | TPS(gen): {token_stats["generation_tps"]:.2f}'
241+
f' | TPS(total): {token_stats["total_tps"]:.2f}'
242+
f' | TPM(gen): {token_stats["generation_tpm"]:.2f}'
243+
f' | TPM(total): {token_stats["total_tpm"]:.2f}'
244+
)
123245

124-
logger.info(f'[{request_id}] <<< Response: {json.dumps(eval_result, ensure_ascii=False)} | Infer: {infer_time*1000:.1f}ms | Total: {total_time*1000:.1f}ms')
125-
126-
return {
127-
'id': f'chatcmpl-{rid}',
128-
'object': 'chat.completion',
129-
'created': int(time.time()),
130-
'model': request.model,
131-
'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': json.dumps(eval_result, ensure_ascii=False)}, 'finish_reason': 'stop'}],
132-
'usage': {'prompt_tokens': len(input_ids), 'completion_tokens': 1, 'total_tokens': len(input_ids) + 1}
133-
}
246+
return {
247+
'id': f'chatcmpl-{rid}',
248+
'object': 'chat.completion',
249+
'created': int(time.time()),
250+
'model': request.model,
251+
'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': json.dumps(eval_result, ensure_ascii=False)}, 'finish_reason': 'stop'}],
252+
'usage': {
253+
'prompt_tokens': token_stats['prompt_tokens'],
254+
'completion_tokens': token_stats['completion_tokens'],
255+
'total_tokens': token_stats['total_tokens']
256+
}
257+
}
258+
except torch.cuda.OutOfMemoryError:
259+
logger.error(f'[TraceID: {trace_id}] CUDA OOM')
260+
raise HTTPException(status_code=500, detail='GPU out of memory')
261+
except Exception as e:
262+
logger.exception(f'[TraceID: {trace_id}] Inference failed')
263+
raise HTTPException(status_code=500, detail=str(e))
134264

135265
if __name__ == '__main__':
136266
logger.info('Loading SGLang engine...')

0 commit comments

Comments
 (0)