Skip to content

Commit fdaf82b

Browse files
authored
feat: implement OpenAPI discriminator support for oneOf/anyOf (#4)
1 parent 04a4a1f commit fdaf82b

6 files changed

Lines changed: 540 additions & 4 deletions

File tree

lib/resty/openapi_validator/body.lua

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,105 @@ local function find_body_schema_for_content_type(route, content_type)
222222
end
223223

224224

225+
-- Check if a schema (or its allOf sub-schemas) declares the discriminator
226+
-- property with an enum that contains the given value.
227+
local function branch_matches_discriminator_enum(branch, prop_name, value)
228+
local function check_props(s)
229+
local p = s.properties and s.properties[prop_name]
230+
if p and p.enum then
231+
for _, v in ipairs(p.enum) do
232+
if v == value then
233+
return true
234+
end
235+
end
236+
end
237+
return false
238+
end
239+
240+
if check_props(branch) then
241+
return true
242+
end
243+
244+
if branch.allOf then
245+
for _, sub in ipairs(branch.allOf) do
246+
if check_props(sub) then
247+
return true
248+
end
249+
end
250+
end
251+
252+
return false
253+
end
254+
255+
256+
-- Find the branch whose _ref matches the mapping target.
257+
local function find_branch_by_mapping(branches, mapping, value)
258+
local target_ref = mapping[value]
259+
if not target_ref then
260+
return nil
261+
end
262+
263+
for _, branch in ipairs(branches) do
264+
if branch._ref == target_ref then
265+
return branch
266+
end
267+
if branch.allOf then
268+
for _, sub in ipairs(branch.allOf) do
269+
if sub._ref == target_ref then
270+
return branch
271+
end
272+
end
273+
end
274+
end
275+
276+
return nil
277+
end
278+
279+
280+
-- Resolve an OpenAPI discriminator to select the correct oneOf/anyOf branch.
281+
-- Returns (selected_schema, nil) on success, or (nil, error_string) on failure.
282+
-- Returns (nil, nil) when the schema has no discriminator.
283+
local function resolve_discriminator(schema, body_data)
284+
local disc = schema.discriminator
285+
if not disc or not disc.propertyName then
286+
return nil, nil
287+
end
288+
289+
local prop_name = disc.propertyName
290+
local branches = schema.oneOf or schema.anyOf
291+
if not branches then
292+
return nil, nil
293+
end
294+
295+
if type(body_data) ~= "table" then
296+
return nil, "discriminator property '" .. prop_name .. "' is missing"
297+
end
298+
299+
local value = body_data[prop_name]
300+
if value == nil then
301+
return nil, "discriminator property '" .. prop_name .. "' is missing"
302+
end
303+
304+
-- try mapping-based lookup first (uses _ref annotations from ref resolver)
305+
if disc.mapping then
306+
local branch = find_branch_by_mapping(branches, disc.mapping, value)
307+
if branch then
308+
return branch, nil
309+
end
310+
end
311+
312+
-- fallback: match by enum on the discriminator property
313+
for _, branch in ipairs(branches) do
314+
if branch_matches_discriminator_enum(branch, prop_name, value) then
315+
return branch, nil
316+
end
317+
end
318+
319+
return nil, "discriminator value '" .. tostring(value)
320+
.. "' does not match any schema"
321+
end
322+
323+
225324
-- Check for readOnly properties present in the request body data.
226325
local function check_readonly_properties(data, schema, errs)
227326
if type(data) ~= "table" or type(schema) ~= "table" then
@@ -297,7 +396,15 @@ function _M.validate(route, body_str, content_type, opts)
297396
check_readonly_properties(body_data, schema, errs)
298397
end
299398

300-
local validator = get_validator(schema)
399+
local effective_schema = schema
400+
local disc_schema, disc_err = resolve_discriminator(schema, body_data)
401+
if disc_err then
402+
tab_insert(errs, errors.new("body", nil, disc_err))
403+
elseif disc_schema then
404+
effective_schema = disc_schema
405+
end
406+
407+
local validator = get_validator(effective_schema)
301408
if validator then
302409
local ok, err = validator(body_data)
303410
if not ok then

lib/resty/openapi_validator/refs.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ function _M.resolve(spec)
122122
return nil, err
123123
end
124124

125+
resolved._ref = ref
125126
registry[pointer] = resolved
126127
resolving[pointer] = nil
127128

t/conformance/test_issue201.lua

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#!/usr/bin/env resty
2+
--- Conformance test ported from kin-openapi issue201_test.go
3+
-- Tests duplicate path templates with different methods and overlapping path segments.
4+
dofile("t/lib/test_bootstrap.lua")
5+
6+
local T = require("test_helper")
7+
local cjson = require("cjson.safe")
8+
local ov = require("resty.openapi_validator")
9+
10+
local spec = cjson.encode({
11+
openapi = "3.0.0",
12+
info = { title = "Sample API", version = "1.0.0" },
13+
paths = {
14+
["/users/{id}"] = {
15+
get = {
16+
parameters = {
17+
{ name = "id", ["in"] = "path", required = true,
18+
schema = { type = "string" } },
19+
},
20+
responses = { ["200"] = { description = "OK" } },
21+
},
22+
post = {
23+
parameters = {
24+
{ name = "id", ["in"] = "path", required = true,
25+
schema = { type = "string" } },
26+
},
27+
requestBody = {
28+
required = true,
29+
content = {
30+
["application/json"] = {
31+
schema = {
32+
type = "object",
33+
required = { "name" },
34+
properties = {
35+
name = { type = "string" },
36+
},
37+
},
38+
},
39+
},
40+
},
41+
responses = { ["200"] = { description = "OK" } },
42+
},
43+
},
44+
},
45+
})
46+
47+
local v = ov.compile(spec)
48+
assert(v, "compile failed")
49+
50+
T.describe("issue201: GET /users/123 (valid)", function()
51+
local ok, err = v:validate_request({
52+
method = "GET",
53+
path = "/users/123",
54+
})
55+
T.ok(ok, "should pass: " .. tostring(err))
56+
end)
57+
58+
T.describe("issue201: POST /users/123 with valid body (valid)", function()
59+
local ok, err = v:validate_request({
60+
method = "POST",
61+
path = "/users/123",
62+
body = '{"name": "alice"}',
63+
content_type = "application/json",
64+
headers = { ["content-type"] = "application/json" },
65+
})
66+
T.ok(ok, "should pass: " .. tostring(err))
67+
end)
68+
69+
T.describe("issue201: POST /users/123 missing required body field (fail)", function()
70+
local ok, err = v:validate_request({
71+
method = "POST",
72+
path = "/users/123",
73+
body = '{}',
74+
content_type = "application/json",
75+
headers = { ["content-type"] = "application/json" },
76+
})
77+
T.ok(not ok, "should fail - missing required field 'name'")
78+
end)
79+
80+
T.describe("issue201: GET /users/abc string id (valid)", function()
81+
local ok, err = v:validate_request({
82+
method = "GET",
83+
path = "/users/abc",
84+
})
85+
T.ok(ok, "should pass: " .. tostring(err))
86+
end)
87+
88+
T.done()

t/conformance/test_issue639.lua

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
#!/usr/bin/env resty
2+
--- Conformance test ported from kin-openapi issue639_test.go
3+
-- Tests request body decode edge cases: empty objects, optional bodies,
4+
-- additional properties, and nested object validation.
5+
dofile("t/lib/test_bootstrap.lua")
6+
7+
local T = require("test_helper")
8+
local cjson = require("cjson.safe")
9+
local ov = require("resty.openapi_validator")
10+
11+
local spec = cjson.encode({
12+
openapi = "3.0.0",
13+
info = { title = "Sample API", version = "1.0.0" },
14+
paths = {
15+
["/items"] = {
16+
post = {
17+
requestBody = {
18+
required = false,
19+
content = {
20+
["application/json"] = {
21+
schema = {
22+
type = "object",
23+
properties = {
24+
name = { type = "string" },
25+
count = { type = "integer" },
26+
metadata = {
27+
type = "object",
28+
properties = {
29+
tags = {
30+
type = "array",
31+
items = { type = "string" },
32+
},
33+
nested = {
34+
type = "object",
35+
properties = {
36+
level = { type = "integer" },
37+
},
38+
},
39+
},
40+
},
41+
},
42+
},
43+
},
44+
},
45+
},
46+
responses = { ["200"] = { description = "OK" } },
47+
},
48+
},
49+
["/strict"] = {
50+
post = {
51+
requestBody = {
52+
required = true,
53+
content = {
54+
["application/json"] = {
55+
schema = {
56+
type = "object",
57+
required = { "id" },
58+
properties = {
59+
id = { type = "integer" },
60+
label = { type = "string" },
61+
},
62+
},
63+
},
64+
},
65+
},
66+
responses = { ["200"] = { description = "OK" } },
67+
},
68+
},
69+
},
70+
})
71+
72+
local v = ov.compile(spec)
73+
assert(v, "compile failed")
74+
75+
T.describe("issue639: empty object with only optional properties (valid)", function()
76+
local ok, err = v:validate_request({
77+
method = "POST",
78+
path = "/items",
79+
body = "{}",
80+
content_type = "application/json",
81+
headers = { ["content-type"] = "application/json" },
82+
})
83+
T.ok(ok, "should pass: " .. tostring(err))
84+
end)
85+
86+
T.describe("issue639: no body when body is not required (valid)", function()
87+
local ok, err = v:validate_request({
88+
method = "POST",
89+
path = "/items",
90+
})
91+
T.ok(ok, "should pass: " .. tostring(err))
92+
end)
93+
94+
T.describe("issue639: object with all fields (valid)", function()
95+
local ok, err = v:validate_request({
96+
method = "POST",
97+
path = "/items",
98+
body = '{"name": "widget", "count": 5}',
99+
content_type = "application/json",
100+
headers = { ["content-type"] = "application/json" },
101+
})
102+
T.ok(ok, "should pass: " .. tostring(err))
103+
end)
104+
105+
T.describe("issue639: object with additional properties (valid - no restriction)", function()
106+
local ok, err = v:validate_request({
107+
method = "POST",
108+
path = "/items",
109+
body = '{"name": "widget", "extra_field": "hello", "another": 42}',
110+
content_type = "application/json",
111+
headers = { ["content-type"] = "application/json" },
112+
})
113+
T.ok(ok, "should pass: " .. tostring(err))
114+
end)
115+
116+
T.describe("issue639: deeply nested valid object (valid)", function()
117+
local ok, err = v:validate_request({
118+
method = "POST",
119+
path = "/items",
120+
body = '{"name": "widget", "metadata": {"tags": ["a", "b"], "nested": {"level": 3}}}',
121+
content_type = "application/json",
122+
headers = { ["content-type"] = "application/json" },
123+
})
124+
T.ok(ok, "should pass: " .. tostring(err))
125+
end)
126+
127+
T.describe("issue639: missing required field in /strict (fail)", function()
128+
local ok, err = v:validate_request({
129+
method = "POST",
130+
path = "/strict",
131+
body = '{"label": "test"}',
132+
content_type = "application/json",
133+
headers = { ["content-type"] = "application/json" },
134+
})
135+
T.ok(not ok, "should fail - missing required field 'id'")
136+
end)
137+
138+
T.describe("issue639: valid required field in /strict (valid)", function()
139+
local ok, err = v:validate_request({
140+
method = "POST",
141+
path = "/strict",
142+
body = '{"id": 1, "label": "test"}',
143+
content_type = "application/json",
144+
headers = { ["content-type"] = "application/json" },
145+
})
146+
T.ok(ok, "should pass: " .. tostring(err))
147+
end)
148+
149+
T.done()

0 commit comments

Comments
 (0)