Skip to content

Commit ef7a21b

Browse files
authored
add sglang cancellation (#1019)
1 parent 8dbe72b commit ef7a21b

3 files changed

Lines changed: 714 additions & 13 deletions

File tree

clarifai/cli/templates/toolkits/sglang/1/model.py

Lines changed: 82 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
import sys
33
from typing import Iterator, List
44

5+
import httpx
6+
from clarifai_protocol import get_item_id
57
from openai import OpenAI
68

79
from clarifai.runners.models.model_builder import ModelBuilder
810
from clarifai.runners.models.openai_class import OpenAIModelClass
11+
from clarifai.runners.models.sglang_openai_class import (
12+
SGLangCancellationHandler,
13+
SGLangOpenAIModelClass,
14+
)
915
from clarifai.runners.utils.data_utils import Param
1016
from clarifai.runners.utils.openai_convertor import build_openai_messages
1117
from clarifai.utils.logging import logger
@@ -62,7 +68,7 @@ def sglang_openai_server(checkpoints, **kwargs):
6268
return server
6369

6470

65-
class SGLangModel(OpenAIModelClass):
71+
class SGLangModel(SGLangOpenAIModelClass):
6672
client = True
6773
model = True
6874

@@ -90,17 +96,21 @@ def load_model(self):
9096
checkpoints = builder.download_checkpoints(stage=stage)
9197

9298
self.server = sglang_openai_server(checkpoints, **server_args)
99+
self.base_url = f"http://{self.server.host}:{self.server.port}"
93100
self.client = OpenAI(
94101
api_key="notset",
95-
base_url=f"http://{self.server.host}:{self.server.port}/v1",
102+
base_url=f"{self.base_url}/v1",
96103
)
97104
self.model = self.client.models.list().data[0].id
105+
self.cancellation_handler = SGLangCancellationHandler(self.base_url)
98106

99107
@OpenAIModelClass.method
100108
def predict(
101109
self,
102110
prompt: str = "",
103111
chat_history: List[dict] = None,
112+
tools: List[dict] = None,
113+
tool_choice: str = None,
104114
max_tokens: int = Param(
105115
default=512,
106116
description="The maximum number of tokens to generate.",
@@ -115,21 +125,34 @@ def predict(
115125
),
116126
) -> str:
117127
"""Return a single completion."""
128+
if tools is not None and tool_choice is None:
129+
tool_choice = "auto"
130+
118131
messages = build_openai_messages(prompt=prompt, messages=chat_history)
119132
response = self.client.chat.completions.create(
120133
model=self.model,
121134
messages=messages,
135+
tools=tools,
136+
tool_choice=tool_choice,
122137
max_completion_tokens=max_tokens,
123138
temperature=temperature,
124139
top_p=top_p,
125140
)
141+
142+
if response.choices[0] and response.choices[0].message.tool_calls:
143+
import json
144+
145+
tool_calls = response.choices[0].message.tool_calls
146+
return json.dumps([tc.to_dict() for tc in tool_calls], indent=2)
126147
return response.choices[0].message.content
127148

128149
@OpenAIModelClass.method
129150
def generate(
130151
self,
131152
prompt: str = "",
132153
chat_history: List[dict] = None,
154+
tools: List[dict] = None,
155+
tool_choice: str = None,
133156
max_tokens: int = Param(
134157
default=512,
135158
description="The maximum number of tokens to generate.",
@@ -144,15 +167,61 @@ def generate(
144167
),
145168
) -> Iterator[str]:
146169
"""Stream a completion response."""
170+
if tools is not None and tool_choice is None:
171+
tool_choice = "auto"
172+
173+
item_id = None
174+
cancel_event = None
175+
try:
176+
item_id = get_item_id()
177+
except Exception:
178+
pass
179+
147180
messages = build_openai_messages(prompt=prompt, messages=chat_history)
148-
for chunk in self.client.chat.completions.create(
149-
model=self.model,
150-
messages=messages,
151-
max_completion_tokens=max_tokens,
152-
temperature=temperature,
153-
top_p=top_p,
154-
stream=True,
155-
):
156-
if chunk.choices:
157-
text = chunk.choices[0].delta.content if chunk.choices[0].delta.content else ''
158-
yield text
181+
try:
182+
response = self.client.chat.completions.create(
183+
model=self.model,
184+
messages=messages,
185+
tools=tools,
186+
tool_choice=tool_choice,
187+
max_completion_tokens=max_tokens,
188+
temperature=temperature,
189+
top_p=top_p,
190+
stream=True,
191+
stream_options={"include_usage": True},
192+
)
193+
194+
if item_id and self.cancellation_handler:
195+
cancel_event = self.cancellation_handler.register_request(
196+
item_id, response=response.response
197+
)
198+
199+
rid_registered = False
200+
for chunk in response:
201+
if item_id and self.cancellation_handler and not rid_registered:
202+
rid = getattr(chunk, 'id', None)
203+
if rid:
204+
self.cancellation_handler.register_rid(item_id, rid)
205+
rid_registered = True
206+
if cancel_event and cancel_event.is_set():
207+
return
208+
if chunk.choices:
209+
if chunk.choices[0].delta.tool_calls:
210+
import json
211+
212+
tool_calls_json = [
213+
tc.to_dict() for tc in chunk.choices[0].delta.tool_calls
214+
]
215+
yield json.dumps(tool_calls_json, indent=2)
216+
else:
217+
text = (
218+
chunk.choices[0].delta.content
219+
if chunk.choices[0].delta.content
220+
else ''
221+
)
222+
yield text
223+
except httpx.ReadError:
224+
pass
225+
finally:
226+
if item_id and self.cancellation_handler:
227+
self.cancellation_handler.unregister_request(item_id)
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import threading
2+
from typing import Iterator
3+
4+
import httpx
5+
import requests
6+
from clarifai_protocol import get_item_id, register_item_abort_callback
7+
8+
from clarifai.runners.models.openai_class import OpenAIModelClass
9+
from clarifai.utils.logging import logger
10+
11+
12+
class SGLangCancellationHandler:
13+
# Important: in addition to closing the httpx response (which kills the TCP
14+
# connection and lets sglang detect the disconnect), we also POST the captured
15+
# request id (rid) to sglang's /abort_request endpoint. This frees the KV cache
16+
# immediately instead of waiting for the engine to notice the disconnect.
17+
def __init__(self, base_url: str):
18+
self._cancel_events = {}
19+
self._responses = {}
20+
self._rids = {}
21+
self._early_aborts = set()
22+
self._lock = threading.Lock()
23+
self._base_url = base_url
24+
register_item_abort_callback(self._handle_abort)
25+
26+
def _call_abort_request(self, rid: str) -> None:
27+
try:
28+
resp = requests.post(f"{self._base_url}/abort_request", json={"rid": rid}, timeout=2)
29+
logger.info(
30+
f"[SGLangCancellationHandler] /abort_request rid={rid} "
31+
f"status={resp.status_code} body={resp.text}"
32+
)
33+
except Exception as e:
34+
logger.warning(f"[SGLangCancellationHandler] /abort_request failed: {e}")
35+
36+
def _handle_abort(self, item_id: str) -> None:
37+
rid = None
38+
with self._lock:
39+
event = self._cancel_events.get(item_id)
40+
response = self._responses.get(item_id)
41+
rid = self._rids.get(item_id)
42+
if event:
43+
event.set()
44+
if response:
45+
try:
46+
response.close()
47+
except Exception:
48+
pass
49+
else:
50+
self._early_aborts.add(item_id)
51+
# Call outside the lock to avoid holding it during network I/O.
52+
if rid:
53+
self._call_abort_request(rid)
54+
55+
def register_request(self, item_id: str, response=None) -> threading.Event:
56+
cancel_event = threading.Event()
57+
with self._lock:
58+
self._cancel_events[item_id] = cancel_event
59+
if response is not None:
60+
self._responses[item_id] = response
61+
if item_id in self._early_aborts:
62+
cancel_event.set()
63+
self._early_aborts.discard(item_id)
64+
if response is not None:
65+
try:
66+
response.close()
67+
except Exception:
68+
pass
69+
return cancel_event
70+
71+
def register_rid(self, item_id: str, rid: str) -> None:
72+
"""Register the sglang request id once captured from the first chunk.
73+
If the request was already cancelled before the rid was known, issue
74+
/abort_request now so the engine frees the KV cache immediately.
75+
"""
76+
should_abort = False
77+
with self._lock:
78+
if item_id in self._cancel_events:
79+
self._rids[item_id] = rid
80+
if self._cancel_events[item_id].is_set():
81+
should_abort = True
82+
if should_abort:
83+
self._call_abort_request(rid)
84+
85+
def unregister_request(self, item_id: str) -> None:
86+
with self._lock:
87+
self._cancel_events.pop(item_id, None)
88+
self._responses.pop(item_id, None)
89+
self._rids.pop(item_id, None)
90+
self._early_aborts.discard(item_id)
91+
92+
93+
class SGLangOpenAIModelClass(OpenAIModelClass):
94+
"""SGLang-backed OpenAI model with /health probes and cancellation support.
95+
96+
Subclasses must set client, model, server, base_url and cancellation_handler in
97+
load_model(), for example:
98+
99+
def load_model(self):
100+
self.server = sglang_openai_server(checkpoints, **server_args)
101+
self.base_url = f"http://{self.server.host}:{self.server.port}"
102+
self.client = OpenAI(base_url=f"{self.base_url}/v1", api_key="x")
103+
self.model = self.client.models.list().data[0].id
104+
self.cancellation_handler = SGLangCancellationHandler(self.base_url)
105+
106+
For cancellation in generate() or custom streaming methods, follow this pattern:
107+
108+
def generate(self, prompt, ...) -> Iterator[str]:
109+
item_id = None
110+
cancel_event = None
111+
try:
112+
item_id = get_item_id()
113+
except Exception:
114+
pass
115+
try:
116+
response = self.client.chat.completions.create(..., stream=True)
117+
if item_id:
118+
cancel_event = self.cancellation_handler.register_request(
119+
item_id, response=response.response
120+
)
121+
rid_registered = False
122+
for chunk in response:
123+
if item_id and not rid_registered:
124+
rid = getattr(chunk, 'id', None)
125+
if rid:
126+
self.cancellation_handler.register_rid(item_id, rid)
127+
rid_registered = True
128+
if cancel_event and cancel_event.is_set():
129+
return
130+
yield ...
131+
except httpx.ReadError:
132+
pass
133+
finally:
134+
if item_id:
135+
self.cancellation_handler.unregister_request(item_id)
136+
"""
137+
138+
server = None
139+
base_url = None
140+
cancellation_handler = None
141+
142+
def _health_url(self) -> str:
143+
if self.base_url:
144+
return f"{self.base_url}/health"
145+
return f"http://{self.server.host}:{self.server.port}/health"
146+
147+
def handle_liveness_probe(self) -> bool:
148+
if self.server is None:
149+
return super().handle_liveness_probe()
150+
try:
151+
resp = httpx.get(self._health_url(), timeout=5.0)
152+
return resp.status_code == 200
153+
except Exception:
154+
return False
155+
156+
def handle_readiness_probe(self) -> bool:
157+
if self.server is None:
158+
return super().handle_readiness_probe()
159+
try:
160+
resp = httpx.get(self._health_url(), timeout=10.0)
161+
return resp.status_code == 200
162+
except Exception:
163+
return False
164+
165+
@OpenAIModelClass.method
166+
def openai_stream_transport(self, msg: str) -> Iterator[str]:
167+
from pydantic_core import from_json
168+
169+
try:
170+
item_id = get_item_id()
171+
except Exception:
172+
item_id = None
173+
174+
cancel_event = None
175+
try:
176+
request_data = from_json(msg)
177+
request_data = self._update_old_fields(request_data)
178+
endpoint = request_data.pop("openai_endpoint", self.DEFAULT_ENDPOINT)
179+
if endpoint not in [self.ENDPOINT_CHAT_COMPLETIONS, self.ENDPOINT_RESPONSES]:
180+
raise ValueError(
181+
f"Only {self.ENDPOINT_CHAT_COMPLETIONS} and {self.ENDPOINT_RESPONSES} endpoints are supported for streaming."
182+
)
183+
184+
if endpoint == self.ENDPOINT_RESPONSES:
185+
response_args = {**request_data}
186+
response_args.update({"model": self.model})
187+
response = self.client.responses.create(**response_args)
188+
else:
189+
completion_args = self._create_completion_args(request_data)
190+
response = self.client.chat.completions.create(**completion_args)
191+
192+
if item_id and self.cancellation_handler:
193+
cancel_event = self.cancellation_handler.register_request(
194+
item_id, response=response.response
195+
)
196+
197+
rid_registered = False
198+
for chunk in response:
199+
if item_id and self.cancellation_handler and not rid_registered:
200+
rid = getattr(chunk, 'id', None) or getattr(
201+
getattr(chunk, 'response', None), 'id', None
202+
)
203+
if rid:
204+
self.cancellation_handler.register_rid(item_id, rid)
205+
rid_registered = True
206+
if cancel_event and cancel_event.is_set():
207+
return
208+
self._set_usage(chunk)
209+
yield chunk.model_dump_json()
210+
except httpx.ReadError:
211+
pass
212+
finally:
213+
if item_id and self.cancellation_handler:
214+
self.cancellation_handler.unregister_request(item_id)

0 commit comments

Comments
 (0)