Skip to content

Commit 45e5df4

Browse files
committed
refactor: also rename factory function rewrite_max_tokens in openai/openai-compatible
1 parent a769de2 commit 45e5df4

10 files changed

Lines changed: 142 additions & 31 deletions

File tree

apisix/plugins/ai-providers/aimlapi.lua

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 -- limitations under the License.
 --

-local function rewrite_request_body(body, override, force)
+local function rewrite_chat_request_body(body, override, force)
     if override.max_tokens then
         if force or body.max_tokens == nil then
             body.max_tokens = override.max_tokens
@@ -30,7 +30,7 @@ return require("apisix.plugins.ai-providers.base").new(
     capabilities = {
         ["openai-chat"] = {
             path = "/chat/completions",
-            rewrite_request_body = rewrite_request_body,
+            rewrite_request_body = rewrite_chat_request_body,
         },
     },
 }

apisix/plugins/ai-providers/anthropic.lua

Lines changed: 12 additions & 3 deletions
@@ -15,7 +15,16 @@
 -- limitations under the License.
 --

-local function rewrite_request_body(body, override, force)
+local function rewrite_chat_request_body(body, override, force)
+    if override.max_tokens then
+        if force or body.max_tokens == nil then
+            body.max_tokens = override.max_tokens
+        end
+    end
+end
+
+
+local function rewrite_messages_request_body(body, override, force)
     if override.max_tokens then
         if force or body.max_tokens == nil then
             body.max_tokens = override.max_tokens
@@ -30,11 +39,11 @@ return require("apisix.plugins.ai-providers.base").new(
     capabilities = {
         ["openai-chat"] = {
             path = "/v1/chat/completions",
-            rewrite_request_body = rewrite_request_body,
+            rewrite_request_body = rewrite_chat_request_body,
         },
         ["anthropic-messages"] = {
             path = "/v1/messages",
-            rewrite_request_body = rewrite_request_body,
+            rewrite_request_body = rewrite_messages_request_body,
         },
     },
 }

apisix/plugins/ai-providers/azure-openai.lua

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 -- limitations under the License.
 --

-local function rewrite_request_body(body, override, force)
+local function rewrite_chat_request_body(body, override, force)
     if override.max_tokens then
         if force or body.max_tokens == nil then
             body.max_tokens = override.max_tokens
@@ -30,7 +30,7 @@ return require("apisix.plugins.ai-providers.base").new(
     capabilities = {
         ["openai-chat"] = {
             path = "/completions",
-            rewrite_request_body = rewrite_request_body,
+            rewrite_request_body = rewrite_chat_request_body,
         },
     },
 }

apisix/plugins/ai-providers/deepseek.lua

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 -- limitations under the License.
 --

-local function rewrite_request_body(body, override, force)
+local function rewrite_chat_request_body(body, override, force)
     if override.max_tokens then
         if force or body.max_tokens == nil then
             body.max_tokens = override.max_tokens
@@ -30,7 +30,7 @@ return require("apisix.plugins.ai-providers.base").new(
     capabilities = {
         ["openai-chat"] = {
             path = "/chat/completions",
-            rewrite_request_body = rewrite_request_body,
+            rewrite_request_body = rewrite_chat_request_body,
         },
     },
 }

apisix/plugins/ai-providers/gemini.lua

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 -- limitations under the License.
 --

-local function rewrite_request_body(body, override, force)
+local function rewrite_chat_request_body(body, override, force)
     if override.max_tokens then
         if force or body.max_completion_tokens == nil then
             body.max_completion_tokens = override.max_tokens
@@ -30,7 +30,7 @@ return require("apisix.plugins.ai-providers.base").new(
     capabilities = {
         ["openai-chat"] = {
             path = "/v1beta/openai/chat/completions",
-            rewrite_request_body = rewrite_request_body,
+            rewrite_request_body = rewrite_chat_request_body,
         },
     },
 }

apisix/plugins/ai-providers/openai-compatible.lua

Lines changed: 15 additions & 8 deletions
@@ -15,12 +15,19 @@
 -- limitations under the License.
 --

-local function rewrite_max_tokens(field_name)
-    return function(body, override, force)
-        if override.max_tokens then
-            if force or body[field_name] == nil then
-                body[field_name] = override.max_tokens
-            end
+local function rewrite_chat_request_body(body, override, force)
+    if override.max_tokens then
+        if force or body.max_tokens == nil then
+            body.max_tokens = override.max_tokens
+        end
+    end
+end
+
+
+local function rewrite_responses_request_body(body, override, force)
+    if override.max_tokens then
+        if force or body.max_output_tokens == nil then
+            body.max_output_tokens = override.max_tokens
         end
     end
 end
@@ -29,11 +36,11 @@ return require("apisix.plugins.ai-providers.base").new({
     capabilities = {
         ["openai-chat"] = {
             path = "/v1/chat/completions",
-            rewrite_request_body = rewrite_max_tokens("max_tokens"),
+            rewrite_request_body = rewrite_chat_request_body,
         },
         ["openai-responses"] = {
             path = "/v1/responses",
-            rewrite_request_body = rewrite_max_tokens("max_output_tokens"),
+            rewrite_request_body = rewrite_responses_request_body,
         },
         ["openai-embeddings"] = { path = "/v1/embeddings" },
     },

apisix/plugins/ai-providers/openai.lua

Lines changed: 16 additions & 8 deletions
@@ -15,12 +15,20 @@
 -- limitations under the License.
 --

-local function rewrite_max_tokens(field_name)
-    return function(body, override, force)
-        if override.max_tokens then
-            if force or body[field_name] == nil then
-                body[field_name] = override.max_tokens
-            end
+local function rewrite_chat_request_body(body, override, force)
+    if override.max_tokens then
+        if force or (body.max_completion_tokens == nil and body.max_tokens == nil) then
+            body.max_completion_tokens = override.max_tokens
+            body.max_tokens = nil
+        end
+    end
+end
+
+
+local function rewrite_responses_request_body(body, override, force)
+    if override.max_tokens then
+        if force or body.max_output_tokens == nil then
+            body.max_output_tokens = override.max_tokens
         end
     end
 end
@@ -32,11 +40,11 @@ return require("apisix.plugins.ai-providers.base").new(
     capabilities = {
         ["openai-chat"] = {
             path = "/v1/chat/completions",
-            rewrite_request_body = rewrite_max_tokens("max_completion_tokens"),
+            rewrite_request_body = rewrite_chat_request_body,
         },
         ["openai-responses"] = {
             path = "/v1/responses",
-            rewrite_request_body = rewrite_max_tokens("max_output_tokens"),
+            rewrite_request_body = rewrite_responses_request_body,
         },
         ["openai-embeddings"] = { path = "/v1/embeddings" },
     },

apisix/plugins/ai-providers/openrouter.lua

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 -- limitations under the License.
 --

-local function rewrite_request_body(body, override, force)
+local function rewrite_chat_request_body(body, override, force)
     if override.max_tokens then
         if force or body.max_tokens == nil then
             body.max_tokens = override.max_tokens
@@ -30,7 +30,7 @@ return require("apisix.plugins.ai-providers.base").new(
     capabilities = {
         ["openai-chat"] = {
             path = "/api/v1/chat/completions",
-            rewrite_request_body = rewrite_request_body,
+            rewrite_request_body = rewrite_chat_request_body,
         },
     },
 }

apisix/plugins/ai-providers/vertex-ai.lua

Lines changed: 2 additions & 2 deletions
@@ -54,7 +54,7 @@ local function get_node(instance_conf)
 end


-local function rewrite_request_body(body, override, force)
+local function rewrite_chat_request_body(body, override, force)
     if override.max_tokens then
         if force or body.max_completion_tokens == nil then
             body.max_completion_tokens = override.max_tokens
@@ -74,7 +74,7 @@ return require("apisix.plugins.ai-providers.base").new({
                 return get_chat_completions_path(conf.project_id, conf.region)
             end
         end,
-        rewrite_request_body = rewrite_request_body,
+        rewrite_request_body = rewrite_chat_request_body,
     },
     ["vertex-predict"] = {
         host = function(conf)

t/plugin/ai-proxy-request-body-override.t

Lines changed: 87 additions & 0 deletions
@@ -496,3 +496,90 @@ max_tokens=555
 }
 --- response_body
 max_tokens=555
+
+
+
+=== TEST 10: openai chat - deprecated max_tokens in body is respected in default mode and cleared in force mode
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+
+            -- Route with default mode (no force)
+            local code = t('/apisix/admin/routes/1',
+                ngx.HTTP_PUT,
+                [[{
+                    "uri": "/chat",
+                    "plugins": {
+                        "ai-proxy": {
+                            "provider": "openai",
+                            "model": { "name": "gpt-4" },
+                            "auth": { "header": { "Authorization": "Bearer t" } },
+                            "override": {
+                                "endpoint": "http://localhost:6732",
+                                "request_body": {
+                                    "max_tokens": 999
+                                }
+                            },
+                            "ssl_verify": false
+                        }
+                    }
+                }]]
+            )
+            if code >= 300 then ngx.status = code; return end
+
+            local http = require("resty.http").new()
+            local cjson = require("cjson.safe")
+
+            -- Client sends deprecated max_tokens=200; default mode should NOT override
+            local res = assert(http:request_uri("http://127.0.0.1:" .. ngx.var.server_port .. "/chat", {
+                method = "POST",
+                body = '{"messages":[{"role":"user","content":"hi"}],"max_tokens":200}',
+                headers = { ["Content-Type"] = "application/json" },
+            }))
+            local body = cjson.decode(res.body)
+            local echoed = cjson.decode(body.choices[1].message.content)
+            ngx.say("default: max_tokens=", echoed.max_tokens,
+                    " max_completion_tokens=", echoed.max_completion_tokens)
+
+            -- Switch to force mode
+            code = t('/apisix/admin/routes/1',
+                ngx.HTTP_PUT,
+                [[{
+                    "uri": "/chat",
+                    "plugins": {
+                        "ai-proxy": {
+                            "provider": "openai",
+                            "model": { "name": "gpt-4" },
+                            "auth": { "header": { "Authorization": "Bearer t" } },
+                            "override": {
+                                "endpoint": "http://localhost:6732",
+                                "request_body": {
+                                    "max_tokens": 999
+                                },
+                                "request_body_force_override": true
+                            },
+                            "ssl_verify": false
+                        }
+                    }
+                }]]
+            )
+            if code >= 300 then ngx.status = code; return end
+
+            ngx.sleep(0.5)
+
+            -- Client sends deprecated max_tokens=200; force mode should clear it and set max_completion_tokens
+            res = assert(http:request_uri("http://127.0.0.1:" .. ngx.var.server_port .. "/chat", {
+                method = "POST",
+                body = '{"messages":[{"role":"user","content":"hi"}],"max_tokens":200}',
+                headers = { ["Content-Type"] = "application/json" },
+            }))
+            body = cjson.decode(res.body)
+            echoed = cjson.decode(body.choices[1].message.content)
+            ngx.say("force: max_tokens=", echoed.max_tokens,
+                    " max_completion_tokens=", echoed.max_completion_tokens)
+        }
+    }
+--- response_body
+default: max_tokens=200 max_completion_tokens=nil
+force: max_tokens=nil max_completion_tokens=999

0 commit comments

Comments
 (0)