11import os
22import sys
33import logging
4+ import asyncio
45sys .path .insert (0 , '/sgl-workspace/sglang/python' )
56import torch
67import torch .nn .functional as F
78import json
89import uuid
910import time
10- from fastapi import FastAPI , HTTPException
11+ from fastapi import FastAPI , HTTPException , Request
1112from pydantic import BaseModel
1213from typing import List , Optional
1314import uvicorn
# Runtime configuration, read once from the environment at import time.
PAGE_SIZE = int(os.getenv('PAGE_SIZE', '1'))
CHUNKED_PREFILL_SIZE = int(os.getenv('CHUNKED_PREFILL_SIZE', '131072'))
PORT = int(os.getenv('PORT', '8000'))
# Warmup settings: whether to send one throw-away request after startup,
# how long to wait before sending it, and the prompt text to send.
WARMUP_ENABLED = os.getenv('WARMUP_ENABLED', 'true').strip().lower() in {'1', 'true', 'yes', 'y', 'on'}
WARMUP_DELAY_SECONDS = int(os.getenv('WARMUP_DELAY_SECONDS', '60'))
WARMUP_PROMPT = os.getenv('WARMUP_PROMPT', 'hello')
3337
# Echo the effective configuration so it is visible in the startup logs.
logger.info(f'MODEL_PATH: {MODEL_PATH}')
logger.info(f'CONTEXT_LENGTH: {CONTEXT_LENGTH}')
logger.info(f'PAGE_SIZE: {PAGE_SIZE}')
logger.info(f'CHUNKED_PREFILL_SIZE: {CHUNKED_PREFILL_SIZE}')
logger.info(f'PORT: {PORT}')
logger.info(f'WARMUP_ENABLED: {WARMUP_ENABLED}')
logger.info(f'WARMUP_DELAY_SECONDS: {WARMUP_DELAY_SECONDS}')
4147
# Load the tokenizer at import time; trust_remote_code lets the model
# repository supply its own tokenizer implementation.
logger.info('Loading tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
logger.info('Tokenizer loaded')
4551
46- im_start_id = tokenizer .convert_tokens_to_ids ('<|im_start|>' )
47- user_id = tokenizer .convert_tokens_to_ids ('user' )
48- im_end_id = tokenizer .convert_tokens_to_ids ('<|im_end|>' )
49-
# Lookup tables from integer class indices to human-readable labels.
risk_level_map = {0: 'Safe', 1: 'Unsafe', 2: 'Controversial'}
# Categories for user queries; index 8 ('Jailbreak') exists only on this side.
query_category_map = {
    0: 'Violent',
    1: 'Sexual Content',
    2: 'Self-Harm',
    3: 'Political',
    4: 'PII',
    5: 'Copyright',
    6: 'Illegal Acts',
    7: 'Unethical',
    8: 'Jailbreak',
}
# Categories for assistant responses.
response_category_map = {
    0: 'Violent',
    1: 'Sexual Content',
    2: 'Self-Harm',
    3: 'Political',
    4: 'PII',
    5: 'Copyright',
    6: 'Illegal Acts',
    7: 'Unethical',
}
5557
# ASGI application object that the route and event decorators in this file attach to.
app = FastAPI(title='Qwen3Guard-Stream API')
5759
60+
async def delayed_warmup():
    """Send one throw-away generation request after a configurable delay.

    Does nothing when ``WARMUP_ENABLED`` is false or when ``engine`` is
    ``None``. Otherwise sleeps ``WARMUP_DELAY_SECONDS``, submits
    ``WARMUP_PROMPT`` with ``max_new_tokens=1`` and logs timing and token
    statistics. Any failure of the warmup request is logged and swallowed.
    """
    if not WARMUP_ENABLED:
        return
    if engine is None:
        logger.warning('Warmup skipped: engine is not initialized')
        return

    logger.info(f'Warmup scheduled, will run after {WARMUP_DELAY_SECONDS}s')
    await asyncio.sleep(WARMUP_DELAY_SECONDS)

    # Short id for log correlation; full-length id for the engine request.
    warmup_trace_id = f'warmup-{uuid.uuid4().hex[:8]}'
    warmup_rid = f'warmup-{uuid.uuid4().hex}'
    try:
        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments the way time.time() differences can.
        start = time.perf_counter()
        # NOTE(review): inference_mode is held across an await here;
        # confirm it is actually required around the engine call.
        with torch.inference_mode():
            outputs = await engine.async_generate(
                WARMUP_PROMPT,
                sampling_params={'max_new_tokens': 1},
                rid=warmup_rid,
                resumable=False,
            )
        infer_time = time.perf_counter() - start
        token_stats = get_token_stats(outputs, prompt=WARMUP_PROMPT, infer_cost=infer_time)
        logger.info(
            f'[TraceID: {warmup_trace_id}] Warmup done'
            f' | PromptTokens: {token_stats["prompt_tokens"]}'
            f' | CompletionTokens: {token_stats["completion_tokens"]}'
            f' | Infer: {infer_time * 1000:.1f}ms'
            f' | TPM(total): {token_stats["total_tpm"]:.2f}'
        )
    except Exception:
        # Warmup is best-effort: record the traceback and carry on.
        logger.exception(f'[TraceID: {warmup_trace_id}] Warmup failed')
93+
class ChatMessage(BaseModel):
    """One message of a chat-completion request."""

    # Speaker of the message; the endpoint treats 'user' as a query and
    # any other value as a response.
    role: str
    # Text of the message.
    content: str
@@ -64,6 +100,57 @@ class ChatCompletionRequest(BaseModel):
64100 messages : List [ChatMessage ]
65101 stream : Optional [bool ] = False
66102
103+
def parse_bool_header(value: Optional[str]) -> Optional[bool]:
    """Interpret an HTTP header value as a tri-state boolean.

    Args:
        value: Raw header text, or ``None`` when the header is absent.

    Returns:
        ``True`` for ``1/true/yes/y/on``, ``False`` for ``0/false/no/n/off``
        (case-insensitive, surrounding whitespace ignored), and ``None`` for
        non-string input or any unrecognised text.
    """
    if not isinstance(value, str):
        return None
    normalized = value.strip().lower()
    if normalized in {'1', 'true', 'yes', 'y', 'on'}:
        return True
    if normalized in {'0', 'false', 'no', 'n', 'off'}:
        return False
    return None
113+
114+
def get_token_stats(result, prompt: str, infer_cost: float):
    """Derive token counts and throughput figures for one generation call.

    Args:
        result: Engine output. Token counts are read from its ``meta_info``
            (a dict entry or an attribute), which may itself be a dict or an
            object exposing ``prompt_tokens`` / ``completion_tokens``.
        prompt: Prompt text, used only to estimate the prompt token count
            when the engine did not report one.
        infer_cost: Inference duration in seconds.

    Returns:
        A dict with ``prompt_tokens``, ``completion_tokens``,
        ``total_tokens``, ``generation_tps``, ``total_tps``,
        ``generation_tpm`` and ``total_tpm``. Rates are ``0.0`` when
        ``infer_cost`` is not positive.
    """
    if isinstance(result, dict):
        meta = result.get('meta_info')
    else:
        meta = getattr(result, 'meta_info', None)

    def _positive_int(key, default):
        # Read a count from dict- or attribute-style meta; anything missing,
        # malformed or non-positive falls back to the supplied default.
        raw = meta.get(key) if isinstance(meta, dict) else getattr(meta, key, None)
        try:
            value = int(raw or 0)
        except (TypeError, ValueError):
            value = 0
        return value if value > 0 else default

    prompt_tokens = _positive_int('prompt_tokens', 0)
    # Guard classification typically returns one token.
    completion_tokens = _positive_int('completion_tokens', 1)

    # Fallback to an estimation for logging if engine meta does not provide token counts.
    if prompt_tokens <= 0 and prompt:
        prompt_tokens = max(len(prompt) // 2, 1)

    total_tokens = prompt_tokens + completion_tokens
    if infer_cost > 0:
        generation_tps = completion_tokens / infer_cost
        total_tps = total_tokens / infer_cost
    else:
        generation_tps = 0.0
        total_tps = 0.0

    return {
        'prompt_tokens': prompt_tokens,
        'completion_tokens': completion_tokens,
        'total_tokens': total_tokens,
        'generation_tps': generation_tps,
        'total_tps': total_tps,
        # Tokens per minute estimation.
        'generation_tpm': generation_tps * 60,
        'total_tpm': total_tps * 60,
    }
153+
67154def process_result (result , type_ = 'query' ):
68155 if type_ == 'query' :
69156 risk_logits = torch .tensor (result ['query_risk_level_logits' ]).view (- 1 , 3 )
@@ -85,52 +172,95 @@ def health():
def list_models():
    """Return a model listing with a single entry.

    The entry's id comes from the ``REPO_ID`` environment variable and
    defaults to ``'qwen3-guard'``.
    """
    return {
        'object': 'list',
        'data': [
            {
                'id': os.getenv('REPO_ID', 'qwen3-guard'),
                'object': 'model',
                'owned_by': 'qwen',
            }
        ],
    }
87174
175+
@app.on_event('startup')
async def startup_event():
    """Schedule the delayed warmup in the background when the app starts."""
    # The event loop holds only weak references to tasks, so an unreferenced
    # task may be garbage-collected before it finishes. Keep a strong
    # reference on the application state for the life of the process.
    app.state.warmup_task = asyncio.create_task(delayed_warmup())
179+
180+
@app.post('/v1/chat/completions')
async def chat_completions(raw_request: Request, request: ChatCompletionRequest):
    """Classify the last message of a chat request and return the verdict.

    Only the content of the final message is sent to the engine, with
    ``max_new_tokens=1``. The role of that message selects the mode:
    ``'user'`` is evaluated as a query, anything else as a response.

    Request headers:
        ``x-request-id`` / ``request-id``: trace id used in log lines
            (a random id is generated when absent).
        ``x-session-id`` / ``session-id``: engine request id. When absent a
            random id is generated and ``resumable`` is forced to ``False``.
        ``x-resumable`` / ``resumable``: boolean text parsed by
            ``parse_bool_header``; unrecognised or missing means ``False``.

    Returns:
        A dict shaped like a ``chat.completion`` object whose assistant
        message content is the JSON-encoded evaluation result.

    Raises:
        HTTPException: 400 when ``messages`` is empty; 500 on CUDA
            out-of-memory or any other failure during inference.
    """
    # Monotonic clock: used only for measuring durations.
    start_time = time.perf_counter()
    headers = raw_request.headers
    trace_id = (
        headers.get('x-request-id')
        or headers.get('request-id')
        or uuid.uuid4().hex
    )

    if not request.messages:
        raise HTTPException(status_code=400, detail='messages cannot be empty')

    last_msg = request.messages[-1]
    prompt = last_msg.content
    sampling_params = {'max_new_tokens': 1}

    rid = headers.get('x-session-id') or headers.get('session-id')
    if rid:
        # A missing or unparseable flag (None) is treated as not resumable.
        resumable = parse_bool_header(
            headers.get('x-resumable') or headers.get('resumable')
        ) or False
    else:
        # Without a caller-supplied session id there is nothing to resume.
        rid = uuid.uuid4().hex
        resumable = False

    # Truncate long content so a single request cannot flood the log.
    content_for_log = prompt[:100] + '...' if len(prompt) > 100 else prompt
    logger.info(
        f'[TraceID: {trace_id}] >>> Request: role={last_msg.role}, '
        f'content={content_for_log}, rid={rid}, resumable={resumable}'
    )

    try:
        infer_start = time.perf_counter()
        # NOTE(review): inference_mode is held across an await here;
        # confirm it is actually required around the engine call.
        with torch.inference_mode():
            outputs = await engine.async_generate(
                prompt,
                sampling_params=sampling_params,
                rid=rid,
                resumable=resumable,
            )
        infer_time = time.perf_counter() - infer_start

        type_ = 'query' if last_msg.role == 'user' else 'response'
        eval_result = process_result(outputs, type_=type_)
        token_stats = get_token_stats(outputs, prompt=prompt, infer_cost=infer_time)
        total_time = time.perf_counter() - start_time

        logger.info(
            f'[TraceID: {trace_id}] <<< Response: {json.dumps(eval_result, ensure_ascii=False)}'
            f' | Mode: {type_}'
            f' | Infer: {infer_time * 1000:.1f}ms'
            f' | Total: {total_time * 1000:.1f}ms'
            f' | PromptTokens: {token_stats["prompt_tokens"]}'
            f' | CompletionTokens: {token_stats["completion_tokens"]}'
            f' | TPS(gen): {token_stats["generation_tps"]:.2f}'
            f' | TPS(total): {token_stats["total_tps"]:.2f}'
            f' | TPM(gen): {token_stats["generation_tpm"]:.2f}'
            f' | TPM(total): {token_stats["total_tpm"]:.2f}'
        )

        return {
            'id': f'chatcmpl-{rid}',
            'object': 'chat.completion',
            # Wall-clock timestamp, as opposed to the monotonic timers above.
            'created': int(time.time()),
            'model': request.model,
            'choices': [
                {
                    'index': 0,
                    'message': {
                        'role': 'assistant',
                        'content': json.dumps(eval_result, ensure_ascii=False),
                    },
                    'finish_reason': 'stop',
                }
            ],
            'usage': {
                'prompt_tokens': token_stats['prompt_tokens'],
                'completion_tokens': token_stats['completion_tokens'],
                'total_tokens': token_stats['total_tokens'],
            },
        }
    except torch.cuda.OutOfMemoryError as exc:
        logger.error(f'[TraceID: {trace_id}] CUDA OOM')
        raise HTTPException(status_code=500, detail='GPU out of memory') from exc
    except Exception as e:
        logger.exception(f'[TraceID: {trace_id}] Inference failed')
        # NOTE(review): str(e) exposes the internal error text to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
134264
135265if __name__ == '__main__' :
136266 logger .info ('Loading SGLang engine...' )
0 commit comments