Skip to content

Commit ddef61e

Browse files
BV-Venkykaghatim
andcommitted
feat(models): add shared strict schema transformation utility
Add ensure_strict_json_schema utility that recursively injects additionalProperties: false on all object types in tool input schemas. Handles $defs, definitions, anyOf, allOf, oneOf, items, and $ref resolution with deep copy safety. Includes require_all_properties flag for OpenAI support. From PR strands-agents#1862 Co-authored-by: Hatim Kagalwala <kaghatim@amazon.com>
1 parent b340dc4 commit ddef61e

2 files changed

Lines changed: 446 additions & 0 deletions

File tree

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""Strict JSON schema transformation for tool definitions.
2+
3+
When model providers require `strict: true` on tool definitions, they also require
4+
`"additionalProperties": false` on every `object` type in the input schema. This module
5+
provides a utility to recursively apply that constraint.
6+
7+
Modeled after OpenAI's `_ensure_strict_json_schema`:
8+
https://github.com/openai/openai-python/blob/main/src/openai/lib/_pydantic.py
9+
"""
10+
11+
import copy
12+
import logging
13+
from typing import Any
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
def ensure_strict_json_schema(
19+
schema: dict[str, Any],
20+
*,
21+
require_all_properties: bool = False,
22+
) -> dict[str, Any]:
23+
"""Ensure a JSON schema conforms to strict tool use requirements.
24+
25+
Creates a deep copy of the schema and recursively:
26+
1. Adds ``"additionalProperties": false`` to all ``object`` types that do not already define it
27+
2. Optionally adds all properties to the ``required`` array (needed for OpenAI)
28+
3. Handles ``$defs``, ``definitions``, ``anyOf``, ``allOf``, ``items``, and ``$ref``
29+
30+
Args:
31+
schema: The JSON schema to process. A deep copy is made internally so the original is not mutated.
32+
require_all_properties: If True, set ``required`` to include all property keys. OpenAI strict mode
33+
requires this; Bedrock and Anthropic do not.
34+
35+
Returns:
36+
A new schema dict with strict-mode constraints applied.
37+
"""
38+
schema_copy = copy.deepcopy(schema)
39+
_apply_strict(schema_copy, root=schema_copy, require_all_properties=require_all_properties)
40+
return schema_copy
41+
42+
43+
def _apply_strict(
44+
schema: dict[str, Any],
45+
*,
46+
root: dict[str, Any],
47+
require_all_properties: bool,
48+
) -> None:
49+
"""Recursively apply strict-mode constraints to a JSON schema in place.
50+
51+
Args:
52+
schema: The schema node to process (modified in place).
53+
root: The root schema, used for resolving ``$ref`` pointers.
54+
require_all_properties: If True, add all properties to ``required``.
55+
"""
56+
# Process $defs / definitions blocks
57+
for defs_key in ("$defs", "definitions"):
58+
defs = schema.get(defs_key)
59+
if isinstance(defs, dict):
60+
for def_schema in defs.values():
61+
if isinstance(def_schema, dict):
62+
_apply_strict(def_schema, root=root, require_all_properties=require_all_properties)
63+
64+
# Add additionalProperties: false to object types that lack it
65+
if schema.get("type") == "object" and "additionalProperties" not in schema:
66+
schema["additionalProperties"] = False
67+
68+
# Process properties and optionally enforce required
69+
properties = schema.get("properties")
70+
if isinstance(properties, dict):
71+
if require_all_properties:
72+
schema["required"] = list(properties.keys())
73+
74+
for prop_schema in properties.values():
75+
if isinstance(prop_schema, dict):
76+
_apply_strict(prop_schema, root=root, require_all_properties=require_all_properties)
77+
78+
# Process array items
79+
items = schema.get("items")
80+
if isinstance(items, dict):
81+
_apply_strict(items, root=root, require_all_properties=require_all_properties)
82+
83+
# Process anyOf variants
84+
any_of = schema.get("anyOf")
85+
if isinstance(any_of, list):
86+
for variant in any_of:
87+
if isinstance(variant, dict):
88+
_apply_strict(variant, root=root, require_all_properties=require_all_properties)
89+
90+
# Process allOf variants
91+
all_of = schema.get("allOf")
92+
if isinstance(all_of, list):
93+
for entry in all_of:
94+
if isinstance(entry, dict):
95+
_apply_strict(entry, root=root, require_all_properties=require_all_properties)
96+
97+
# Process oneOf variants
98+
one_of = schema.get("oneOf")
99+
if isinstance(one_of, list):
100+
for variant in one_of:
101+
if isinstance(variant, dict):
102+
_apply_strict(variant, root=root, require_all_properties=require_all_properties)
103+
104+
# Resolve $ref combined with other keys by inlining the referenced schema
105+
ref = schema.get("$ref")
106+
if isinstance(ref, str) and len(schema) > 1:
107+
resolved = _resolve_ref(root, ref)
108+
if isinstance(resolved, dict):
109+
# Inline the resolved schema, giving priority to existing keys
110+
merged = {**copy.deepcopy(resolved), **schema}
111+
merged.pop("$ref", None)
112+
schema.clear()
113+
schema.update(merged)
114+
# Re-apply strict to the inlined schema
115+
_apply_strict(schema, root=root, require_all_properties=require_all_properties)
116+
117+
118+
def _resolve_ref(root: dict[str, Any], ref: str) -> dict[str, Any] | None:
119+
"""Resolve a JSON Schema ``$ref`` pointer against the root schema.
120+
121+
Args:
122+
root: The root schema containing definitions.
123+
ref: A JSON pointer string (e.g., ``#/$defs/MyModel``).
124+
125+
Returns:
126+
The resolved schema dict, or None if resolution fails.
127+
"""
128+
if not ref.startswith("#/"):
129+
logger.warning("ref=<%s> | unexpected $ref format, skipping resolution", ref)
130+
return None
131+
132+
path = ref[2:].split("/")
133+
current: Any = root
134+
for key in path:
135+
if not isinstance(current, dict) or key not in current:
136+
logger.warning("ref=<%s> | failed to resolve $ref path", ref)
137+
return None
138+
current = current[key]
139+
140+
if not isinstance(current, dict):
141+
logger.warning("ref=<%s> | resolved to non-dict value", ref)
142+
return None
143+
144+
return current

0 commit comments

Comments
 (0)