Skip to content

Commit a7e327e

Browse files
authored
feat(ai-proxy): abort upstream read on client disconnect during streaming (#13254)
1 parent b15341d commit a7e327e

5 files changed

Lines changed: 271 additions & 10 deletions

File tree

apisix/plugin.lua

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1397,13 +1397,23 @@ function _M.run_global_rules(api_ctx, global_rules, conf_version, phase_name)
13971397
end
13981398
end
13991399

1400-
function _M.lua_response_filter(api_ctx, headers, body)
1400+
-- @param wait boolean When true, use synchronous flush (ngx.flush(true)) so callers
1401+
-- can detect client disconnection. Defaults to false (async flush).
1402+
-- @return boolean, string|nil Always returns (ok, err). On success returns true.
1403+
-- On flush failure or print failure returns false, err.
1404+
function _M.lua_response_filter(api_ctx, headers, body, wait)
14011405
local plugins = api_ctx.plugins
14021406
if not plugins or #plugins == 0 then
14031407
-- if there are no plugins, just print the original body to downstream
1404-
ngx_print(body)
1405-
ngx_flush()
1406-
return
1408+
local ok, err = ngx_print(body)
1409+
if not ok then
1410+
return false, err
1411+
end
1412+
ok, err = ngx_flush(wait == true)
1413+
if not ok then
1414+
return false, err
1415+
end
1416+
return true
14071417
end
14081418
for i = 1, #plugins, 2 do
14091419
local phase_func = plugins[i]["lua_body_filter"]
@@ -1430,8 +1440,15 @@ function _M.lua_response_filter(api_ctx, headers, body)
14301440

14311441
::CONTINUE::
14321442
end
1433-
ngx_print(body)
1434-
ngx_flush()
1443+
local ok, err = ngx_print(body)
1444+
if not ok then
1445+
return false, err
1446+
end
1447+
ok, err = ngx_flush(wait == true)
1448+
if not ok then
1449+
return false, err
1450+
end
1451+
return true
14351452
end
14361453

14371454

apisix/plugins/ai-providers/base.lua

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -333,6 +333,16 @@ function _M.parse_streaming_response(self, ctx, res, target_proto, converter, co
333333
-- uncommitted and causing nginx to fall through to the balancer phase.
334334
local output_sent = false
335335

336+
local function abort_on_disconnect(flush_err)
337+
core.log.info("client disconnected during AI streaming, ",
338+
"aborting upstream read: ", flush_err)
339+
if res._httpc then
340+
res._httpc:close()
341+
res._httpc = nil
342+
end
343+
ctx.var.llm_request_done = true
344+
end
345+
336346
-- Runaway-upstream safeguards. Both are opt-in; unset means no cap.
337347
local max_duration_ms = conf and conf.max_stream_duration_ms
338348
local max_bytes = conf and conf.max_response_bytes
@@ -424,15 +434,24 @@ function _M.parse_streaming_response(self, ctx, res, target_proto, converter, co
424434
::CONTINUE::
425435
end
426436

427-
-- Output: converter events or passthrough raw chunk
437+
-- Output: converter events or passthrough raw chunk.
438+
-- Pass wait=true for synchronous flush so we can detect client disconnection.
428439
if converter then
429440
for _, c in ipairs(converted_chunks) do
430-
plugin.lua_response_filter(ctx, res.headers, c)
441+
local ok, flush_err = plugin.lua_response_filter(ctx, res.headers, c, true)
431442
output_sent = true
443+
if not ok then
444+
abort_on_disconnect(flush_err)
445+
return
446+
end
432447
end
433448
else
434-
plugin.lua_response_filter(ctx, res.headers, chunk)
449+
local ok, flush_err = plugin.lua_response_filter(ctx, res.headers, chunk, true)
435450
output_sent = true
451+
if not ok then
452+
abort_on_disconnect(flush_err)
453+
return
454+
end
436455
end
437456

438457
-- Enforce runaway-upstream safeguards after processing the chunk.

t/cli/test_dns.sh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,7 @@ curl -v -k -i -m 20 -o /dev/null -s -X PUT http://127.0.0.1:9180/apisix/admin/st
158158
}
159159
}'
160160

161+
sleep 1 # wait for the stream route to propagate from etcd to stream workers
161162
curl http://127.0.0.1:9100 || true
162163
make stop
163164
sleep 0.1 # wait for logs output

t/plugin/ai-proxy-anthropic.t

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -442,7 +442,7 @@ Content-Type: application/json
442442
test-type: null-details
443443
--- error_code: 200
444444
--- response_body_like eval
445-
qr/"input_tokens":10.*"output_tokens":5/
445+
qr/(?s)(?=.*"input_tokens":10)(?=.*"output_tokens":5)/
446446
--- no_error_log
447447
[error]
448448
Lines changed: 224 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,224 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
use t::APISIX 'no_plan';
19+
20+
log_level("info");
21+
repeat_each(1);
22+
no_long_string();
23+
no_root_location();
24+
25+
26+
add_block_preprocessor(sub {
27+
my ($block) = @_;
28+
29+
if (!defined $block->request) {
30+
$block->set_value("request", "GET /t");
31+
}
32+
33+
# Mock upstream: slow SSE server that streams chunks until the connection
34+
# is closed, tracking the final chunk count in the "test" shared dict.
35+
my $http_config = $block->http_config // <<_EOC_;
36+
server {
37+
server_name slow_openai_sse;
38+
listen 7750;
39+
40+
default_type 'text/event-stream';
41+
42+
location /v1/chat/completions {
43+
content_by_lua_block {
44+
ngx.header["Content-Type"] = "text/event-stream"
45+
local dict = ngx.shared["test"]
46+
dict:set("upstream_chunks", 0)
47+
-- Stream up to 2000 chunks with 30ms sleep between each.
48+
-- The proxy should abort well before this completes when
49+
-- the client disconnects.
50+
for i = 1, 2000 do
51+
local ok, err = ngx.print(
52+
'data: {"id":"chatcmpl-1","object":'
53+
.. '"chat.completion.chunk","choices":[{"delta":'
54+
.. '{"content":"tok"},"index":0,'
55+
.. '"finish_reason":null}],"usage":null}\\n\\n')
56+
if not ok then
57+
return
58+
end
59+
local flush_ok = ngx.flush(true)
60+
if not flush_ok then
61+
return
62+
end
63+
dict:set("upstream_chunks", i)
64+
ngx.sleep(0.03)
65+
end
66+
}
67+
}
68+
69+
# Probe endpoint to read the current chunk count.
70+
location /chunks {
71+
content_by_lua_block {
72+
local dict = ngx.shared["test"]
73+
ngx.say(dict:get("upstream_chunks") or 0)
74+
}
75+
}
76+
}
77+
_EOC_
78+
79+
$block->set_value("http_config", $http_config);
80+
});
81+
82+
83+
run_tests();
84+
85+
__DATA__
86+
87+
=== TEST 1: set route for client disconnect test
88+
--- config
89+
location /t {
90+
content_by_lua_block {
91+
local t = require("lib.test_admin").test
92+
local code, body = t('/apisix/admin/routes/1',
93+
ngx.HTTP_PUT,
94+
[[{
95+
"uri": "/anything",
96+
"plugins": {
97+
"ai-proxy": {
98+
"provider": "openai",
99+
"auth": {
100+
"header": {
101+
"Authorization": "Bearer token"
102+
}
103+
},
104+
"options": {
105+
"model": "gpt-4",
106+
"stream": true
107+
},
108+
"override": {
109+
"endpoint": "http://localhost:7750"
110+
},
111+
"ssl_verify": false
112+
}
113+
}
114+
}]]
115+
)
116+
117+
if code >= 300 then
118+
ngx.status = code
119+
end
120+
ngx.say(body)
121+
}
122+
}
123+
--- response_body
124+
passed
125+
126+
127+
128+
=== TEST 2: client disconnect aborts upstream read early
129+
--- config
130+
location /t {
131+
content_by_lua_block {
132+
local http = require("resty.http")
133+
local httpc = http.new()
134+
135+
local ok, err = httpc:connect({
136+
scheme = "http",
137+
host = "localhost",
138+
port = ngx.var.server_port,
139+
})
140+
if not ok then
141+
ngx.status = 500
142+
ngx.say("connect failed: ", err)
143+
return
144+
end
145+
146+
local res, err = httpc:request({
147+
method = "POST",
148+
headers = { ["Content-Type"] = "application/json" },
149+
path = "/anything",
150+
body = [[{"messages": [{"role": "user", "content": "hi"}]}]],
151+
})
152+
if not res then
153+
ngx.status = 500
154+
ngx.say("request failed: ", err)
155+
return
156+
end
157+
158+
-- Read exactly 3 chunks then close the connection abruptly.
159+
for i = 1, 3 do
160+
local chunk, rerr = res.body_reader()
161+
if rerr or not chunk then
162+
ngx.status = 500
163+
ngx.say("unexpected end of stream at chunk ", i, ": ", rerr)
164+
return
165+
end
166+
end
167+
httpc:close()
168+
169+
-- Allow time for the proxy to detect the disconnect and stop
170+
-- feeding the upstream connection, then capture the chunk count.
171+
-- 1s window: unfixed path produces ~33 chunks (1000ms / 30ms per
172+
-- chunk); fixed path stops within a few chunks of the disconnect.
173+
ngx.sleep(1.0)
174+
175+
-- Read chunk count from the mock upstream's probe endpoint.
176+
local probe = http.new()
177+
ok, err = probe:connect({ scheme = "http", host = "localhost", port = 7750 })
178+
if not ok then
179+
ngx.status = 500
180+
ngx.say("probe connect failed: ", err)
181+
return
182+
end
183+
local probe_res, probe_err = probe:request({
184+
method = "GET",
185+
path = "/chunks",
186+
headers = { Host = "localhost" },
187+
})
188+
if not probe_res then
189+
ngx.status = 500
190+
ngx.say("probe request failed: ", probe_err)
191+
return
192+
end
193+
local count_str = probe_res:read_body()
194+
probe:close()
195+
196+
if probe_res.status ~= 200 then
197+
ngx.status = 500
198+
ngx.say("probe status unexpected: ", probe_res.status)
199+
return
200+
end
201+
202+
local count = tonumber(count_str)
203+
if not count then
204+
ngx.status = 500
205+
ngx.say("invalid probe response: ", count_str or "nil")
206+
return
207+
end
208+
209+
-- With the fix the upstream stops shortly after client disconnect
210+
-- (well under 15 chunks). Without the fix it reaches ~33 chunks in
211+
-- the 1s observation window, so this threshold reliably catches the
212+
-- regression while leaving ample headroom for timing variation.
213+
if count > 15 then
214+
ngx.status = 500
215+
ngx.say("upstream was not aborted promptly, chunks: ", count)
216+
return
217+
end
218+
ngx.say("ok, upstream aborted after ~", count, " chunks")
219+
}
220+
}
221+
--- response_body_like
222+
^ok, upstream aborted after ~\d+ chunks$
223+
--- error_log
224+
client disconnected during AI streaming

0 commit comments

Comments
 (0)