Skip to content

Commit d3e4c3e

Browse files
BV-Venkykaghatim
authored andcommitted
feat(models): add shared strict schema transformation utility
Add ensure_strict_json_schema utility that recursively injects additionalProperties: false on all object types in tool input schemas. Handles $defs, definitions, anyOf, allOf, oneOf, items, and $ref resolution. Includes require_all_properties flag for OpenAI support. From PR strands-agents#1862
1 parent b340dc4 commit d3e4c3e

2 files changed

Lines changed: 302 additions & 0 deletions

File tree

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""Strict JSON schema transformation for tool definitions.
2+
3+
When model providers require `strict: true` on tool definitions, they also require
4+
`"additionalProperties": false` on every `object` type in the input schema. This module
5+
provides a utility to recursively apply that constraint.
6+
7+
Modeled after OpenAI's `_ensure_strict_json_schema`:
8+
https://github.com/openai/openai-python/blob/main/src/openai/lib/_pydantic.py
9+
"""
10+
11+
import copy
12+
import logging
13+
from typing import Any
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
def ensure_strict_json_schema(
19+
schema: dict[str, Any],
20+
*,
21+
require_all_properties: bool = False,
22+
) -> dict[str, Any]:
23+
"""Ensure a JSON schema conforms to strict tool use requirements.
24+
25+
Creates a deep copy of the schema and recursively:
26+
1. Adds ``"additionalProperties": false`` to all ``object`` types that do not already define it
27+
2. Optionally adds all properties to the ``required`` array (needed for OpenAI)
28+
3. Handles ``$defs``, ``definitions``, ``anyOf``, ``allOf``, ``items``, and ``$ref``
29+
30+
Args:
31+
schema: The JSON schema to process. A deep copy is made internally so the original is not mutated.
32+
require_all_properties: If True, set ``required`` to include all property keys. OpenAI strict mode
33+
requires this; Bedrock and Anthropic do not.
34+
35+
Returns:
36+
A new schema dict with strict-mode constraints applied.
37+
"""
38+
schema_copy = copy.deepcopy(schema)
39+
_apply_strict(schema_copy, root=schema_copy, require_all_properties=require_all_properties)
40+
return schema_copy
41+
42+
43+
def _apply_strict(
44+
schema: dict[str, Any],
45+
*,
46+
root: dict[str, Any],
47+
require_all_properties: bool,
48+
) -> None:
49+
"""Recursively apply strict-mode constraints to a JSON schema in place.
50+
51+
Args:
52+
schema: The schema node to process (modified in place).
53+
root: The root schema, used for resolving ``$ref`` pointers.
54+
require_all_properties: If True, add all properties to ``required``.
55+
"""
56+
# Process $defs / definitions blocks
57+
for defs_key in ("$defs", "definitions"):
58+
defs = schema.get(defs_key)
59+
if isinstance(defs, dict):
60+
for def_schema in defs.values():
61+
if isinstance(def_schema, dict):
62+
_apply_strict(def_schema, root=root, require_all_properties=require_all_properties)
63+
64+
# Add additionalProperties: false to object types that lack it
65+
if schema.get("type") == "object" and "additionalProperties" not in schema:
66+
schema["additionalProperties"] = False
67+
68+
# Process properties and optionally enforce required
69+
properties = schema.get("properties")
70+
if isinstance(properties, dict):
71+
if require_all_properties:
72+
schema["required"] = list(properties.keys())
73+
74+
for prop_schema in properties.values():
75+
if isinstance(prop_schema, dict):
76+
_apply_strict(prop_schema, root=root, require_all_properties=require_all_properties)
77+
78+
# Process array items
79+
items = schema.get("items")
80+
if isinstance(items, dict):
81+
_apply_strict(items, root=root, require_all_properties=require_all_properties)
82+
83+
# Process anyOf variants
84+
any_of = schema.get("anyOf")
85+
if isinstance(any_of, list):
86+
for variant in any_of:
87+
if isinstance(variant, dict):
88+
_apply_strict(variant, root=root, require_all_properties=require_all_properties)
89+
90+
# Process allOf variants
91+
all_of = schema.get("allOf")
92+
if isinstance(all_of, list):
93+
for entry in all_of:
94+
if isinstance(entry, dict):
95+
_apply_strict(entry, root=root, require_all_properties=require_all_properties)
96+
97+
# Process oneOf variants
98+
one_of = schema.get("oneOf")
99+
if isinstance(one_of, list):
100+
for variant in one_of:
101+
if isinstance(variant, dict):
102+
_apply_strict(variant, root=root, require_all_properties=require_all_properties)
103+
104+
# Resolve $ref combined with other keys by inlining the referenced schema
105+
ref = schema.get("$ref")
106+
if isinstance(ref, str) and len(schema) > 1:
107+
resolved = _resolve_ref(root, ref)
108+
if isinstance(resolved, dict):
109+
# Inline the resolved schema, giving priority to existing keys
110+
merged = {**resolved, **schema}
111+
merged.pop("$ref", None)
112+
schema.clear()
113+
schema.update(merged)
114+
# Re-apply strict to the inlined schema
115+
_apply_strict(schema, root=root, require_all_properties=require_all_properties)
116+
117+
118+
def _resolve_ref(root: dict[str, Any], ref: str) -> dict[str, Any] | None:
119+
"""Resolve a JSON Schema ``$ref`` pointer against the root schema.
120+
121+
Args:
122+
root: The root schema containing definitions.
123+
ref: A JSON pointer string (e.g., ``#/$defs/MyModel``).
124+
125+
Returns:
126+
The resolved schema dict, or None if resolution fails.
127+
"""
128+
if not ref.startswith("#/"):
129+
logger.warning("ref=<%s> | unexpected $ref format, skipping resolution", ref)
130+
return None
131+
132+
path = ref[2:].split("/")
133+
current: Any = root
134+
for key in path:
135+
if not isinstance(current, dict) or key not in current:
136+
logger.warning("ref=<%s> | failed to resolve $ref path", ref)
137+
return None
138+
current = current[key]
139+
140+
if not isinstance(current, dict):
141+
logger.warning("ref=<%s> | resolved to non-dict value", ref)
142+
return None
143+
144+
return current
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
from strands.models._strict_schema import ensure_strict_json_schema
2+
3+
4+
def test_ensure_strict_json_schema_basic():
5+
schema = {
6+
"type": "object",
7+
"properties": {"x": {"type": "string"}},
8+
}
9+
strict_schema = ensure_strict_json_schema(schema)
10+
11+
assert strict_schema["additionalProperties"] is False
12+
assert strict_schema["properties"]["x"] == {"type": "string"}
13+
# Original should be untouched
14+
assert "additionalProperties" not in schema
15+
16+
17+
def test_ensure_strict_json_schema_nested():
18+
schema = {
19+
"type": "object",
20+
"properties": {
21+
"outer": {
22+
"type": "object",
23+
"properties": {"inner": {"type": "integer"}},
24+
}
25+
},
26+
}
27+
strict_schema = ensure_strict_json_schema(schema)
28+
29+
assert strict_schema["additionalProperties"] is False
30+
assert strict_schema["properties"]["outer"]["additionalProperties"] is False
31+
assert strict_schema["properties"]["outer"]["properties"]["inner"] == {"type": "integer"}
32+
33+
34+
def test_ensure_strict_json_schema_with_defs():
35+
schema = {
36+
"type": "object",
37+
"properties": {"item": {"$ref": "#/$defs/MyItem"}},
38+
"$defs": {
39+
"MyItem": {
40+
"type": "object",
41+
"properties": {"name": {"type": "string"}},
42+
}
43+
},
44+
}
45+
strict_schema = ensure_strict_json_schema(schema)
46+
47+
assert strict_schema["additionalProperties"] is False
48+
assert strict_schema["$defs"]["MyItem"]["additionalProperties"] is False
49+
50+
51+
def test_ensure_strict_json_schema_with_ref_inline():
52+
# When a $ref is combined with other keys, the reference should be inlined.
53+
# Note: ensure_strict_json_schema resolves refs from the root schema.
54+
schema = {
55+
"type": "object",
56+
"properties": {
57+
"item": {
58+
"$ref": "#/$defs/MyItem",
59+
"description": "An item",
60+
}
61+
},
62+
"$defs": {
63+
"MyItem": {
64+
"type": "object",
65+
"properties": {"name": {"type": "string"}},
66+
}
67+
},
68+
}
69+
strict_schema = ensure_strict_json_schema(schema)
70+
71+
assert strict_schema["additionalProperties"] is False
72+
# The reference should have been inlined, retaining 'description', and gaining additionalProperties
73+
item_prop = strict_schema["properties"]["item"]
74+
assert "$ref" not in item_prop
75+
assert item_prop["type"] == "object"
76+
assert item_prop["description"] == "An item"
77+
assert item_prop["additionalProperties"] is False
78+
assert item_prop["properties"]["name"] == {"type": "string"}
79+
80+
81+
def test_ensure_strict_json_schema_arrays_and_unions():
82+
schema = {
83+
"type": "object",
84+
"properties": {
85+
"items": {
86+
"type": "array",
87+
"items": {"type": "object", "properties": {"a": {"type": "string"}}},
88+
},
89+
"union": {
90+
"anyOf": [
91+
{"type": "object", "properties": {"b": {"type": "string"}}},
92+
{"type": "object", "properties": {"c": {"type": "string"}}},
93+
]
94+
},
95+
"intersection": {
96+
"allOf": [
97+
{"type": "object", "properties": {"d": {"type": "string"}}},
98+
]
99+
},
100+
},
101+
}
102+
strict_schema = ensure_strict_json_schema(schema)
103+
104+
assert strict_schema["additionalProperties"] is False
105+
assert strict_schema["properties"]["items"]["items"]["additionalProperties"] is False
106+
assert strict_schema["properties"]["union"]["anyOf"][0]["additionalProperties"] is False
107+
assert strict_schema["properties"]["union"]["anyOf"][1]["additionalProperties"] is False
108+
assert strict_schema["properties"]["intersection"]["allOf"][0]["additionalProperties"] is False
109+
110+
111+
def test_ensure_strict_json_schema_require_all_properties():
112+
schema = {
113+
"type": "object",
114+
"properties": {
115+
"required_field": {"type": "string"},
116+
"optional_field": {"type": "string"},
117+
},
118+
"required": ["required_field"],
119+
}
120+
121+
# Test without require_all_properties
122+
strict_schema = ensure_strict_json_schema(schema)
123+
assert strict_schema["required"] == ["required_field"]
124+
125+
# Test with require_all_properties
126+
strict_req = ensure_strict_json_schema(schema, require_all_properties=True)
127+
# The order of keys is typically preserved from the dict iteration
128+
assert set(strict_req["required"]) == {"required_field", "optional_field"}
129+
130+
131+
def test_ensure_strict_json_schema_preserves_additional_properties_true():
132+
schema = {
133+
"type": "object",
134+
"properties": {"x": {"type": "string"}},
135+
"additionalProperties": True,
136+
}
137+
strict_schema = ensure_strict_json_schema(schema)
138+
139+
assert strict_schema["additionalProperties"] is True
140+
141+
142+
def test_ensure_strict_json_schema_oneOf():
143+
schema = {
144+
"type": "object",
145+
"properties": {
146+
"value": {
147+
"oneOf": [
148+
{"type": "object", "properties": {"a": {"type": "string"}}},
149+
{"type": "object", "properties": {"b": {"type": "integer"}}},
150+
]
151+
}
152+
},
153+
}
154+
strict_schema = ensure_strict_json_schema(schema)
155+
156+
assert strict_schema["additionalProperties"] is False
157+
assert strict_schema["properties"]["value"]["oneOf"][0]["additionalProperties"] is False
158+
assert strict_schema["properties"]["value"]["oneOf"][1]["additionalProperties"] is False

0 commit comments

Comments
 (0)