Skip to content

Commit 03bcf23

Browse files
fix(llm): address CoPilot review on PR #42
Addresses the 8 CoPilot review threads on the structured-output PR: - strict_mode_supported now requires additionalProperties to be EXPLICITLY false (not just missing-or-false). Missing implies the JSON Schema default of permitting extras, which OpenAI's strict mode rejects. Pydantic's .model_json_schema() omits the key by default, so the class-input path would have 400ed against OpenAI even with conformance fixtures passing. - _normalize_response_schema now raises ProviderInvalidRequest when the class form is not a BaseModel subclass, instead of letting AttributeError leak from model_json_schema. - validate_response_schema now runs jsonschema.Draft202012Validator .check_schema() at the boundary, wrapping SchemaError as ProviderInvalidRequest. Malformed schemas now fail at the API boundary instead of escaping at decode time. - _derive_schema_name now regex-checks the title against OpenAI's name constraint (^[a-zA-Z0-9_-]{1,64}$) and falls back to the hashed name when the title doesn't match. Sanitizing-in-place would silently mutate user intent; the hash is a more honest fallback. - Two comments claiming Message instances are immutable Pydantic models were updated. The models are not configured with frozen=True; the safety actually comes from the helpers not modifying them in place. - match_wire_body now fails on extra keys in actual. The previous permissive default defeated the point of expected_wire_request being a literal compare; partial assertions continue to live in the sibling expected_wire_request_checks block. - _iter_calls now propagates expected_wire_request, expected_wire_request_checks, response_schema, and retry_middleware from sibling-of-call into the call dict. Only expected was being copied before. Cases-form fixtures with case-level wire expectations were silently running without those assertions. The _iter_calls fix surfaced two pre-existing gaps in the harness's handling of cases-shape fixtures, fixed inline: - The harness was never wiring config from the call spec into provider.complete(); fixture 005's runtime_config_passthrough case was effectively a no-op. - OpenAIProvider was using json.dumps default formatting for tool_call.function.arguments (with spaces after colons), which doesn't match the canonical compact form OpenAI emits or the spec's fixture 005 expectations. Switched to compact form. New unit tests cover the missing-additionalProperties strict-mode case, the non-BaseModel class rejection, the malformed JSON Schema rejection, and the title-falls-back hash cases.
1 parent 1d1e2df commit 03bcf23

5 files changed

Lines changed: 157 additions & 30 deletions

File tree

src/openarmature/llm/provider.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from collections.abc import Sequence
3939
from typing import Any, Protocol, cast
4040

41+
import jsonschema
4142
from pydantic import BaseModel
4243

4344
from .errors import ProviderInvalidRequest
@@ -184,8 +185,9 @@ def validate_response_schema(schema: object) -> None:
184185
"""Pre-send validation for a JSON Schema passed as the
185186
``response_schema`` argument to ``complete()``.
186187
187-
Raises :class:`ProviderInvalidRequest` if the schema is not a dict
188-
or does not declare a top-level object type.
188+
Raises :class:`ProviderInvalidRequest` if the schema is not a dict,
189+
does not declare a top-level object type, or is not a valid JSON
190+
Schema document.
189191
"""
190192
if not isinstance(schema, dict):
191193
raise ProviderInvalidRequest(f"response_schema: MUST be a dict (got {type(schema).__name__})")
@@ -195,12 +197,23 @@ def validate_response_schema(schema: object) -> None:
195197
raise ProviderInvalidRequest(
196198
f"response_schema: top-level type MUST be 'object' (got {schema_type!r})"
197199
)
200+
# Full JSON Schema validity check at the boundary so a malformed
201+
# schema raises ProviderInvalidRequest here instead of escaping as
202+
# jsonschema.SchemaError at decode time. ValidationError covers
203+
# instance-against-schema failures and is handled separately on the
204+
# parse path.
205+
try:
206+
jsonschema.Draft202012Validator.check_schema(schema_dict)
207+
except jsonschema.SchemaError as exc:
208+
raise ProviderInvalidRequest(f"response_schema: not a valid JSON Schema: {exc.message}") from exc
198209

199210

200211
# Strict mode (OpenAI's response_format strict:true and the analogous
201212
# native-decoding paths in Anthropic / Gemini) requires the schema to
202213
# satisfy two rules at every nested level:
203-
# 1. additionalProperties is NOT true (false or absent).
214+
# 1. additionalProperties is EXPLICITLY false. OpenAI rejects schemas
215+
# where the key is absent, since absence means JSON Schema's
216+
# default of permitting extras.
204217
# 2. every key in `properties` is listed in `required`.
205218
# strict_mode_supported() walks the schema tree (object properties,
206219
# array items, anyOf/oneOf/allOf branches, $ref targets with cycle
@@ -272,7 +285,7 @@ def _strict_mode_check(
272285
)
273286

274287
if is_object_type:
275-
if schema_dict.get("additionalProperties") is True:
288+
if schema_dict.get("additionalProperties") is not False:
276289
return False
277290
properties = schema_dict.get("properties")
278291
if properties is not None and not isinstance(properties, dict):

src/openarmature/llm/providers/openai.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import hashlib
4444
import json
45+
import re
4546
import uuid
4647
from collections.abc import Sequence
4748
from typing import Any, Literal, cast
@@ -237,9 +238,8 @@ async def complete(
237238
# On the fallback path, the wire-side messages list is an
238239
# augmented COPY of the caller's messages — original messages
239240
# MUST NOT be mutated. _augment_messages_with_schema_directive
240-
# builds a fresh list; the original instances are reused
241-
# (immutable Pydantic models) so the caller's sequence is
242-
# untouched.
241+
# builds a fresh list and does not modify the reused Message
242+
# instances in place; the caller's sequence is untouched.
243243
wire_messages: Sequence[Message] = messages
244244
if schema_dict is not None and self._force_prompt_augmentation_fallback:
245245
wire_messages = _augment_messages_with_schema_directive(messages, schema_dict)
@@ -461,24 +461,38 @@ def _normalize_response_schema(
461461
if response_schema is None:
462462
return None, None
463463
if isinstance(response_schema, type):
464-
# Per the Protocol signature, the only class form accepted is
465-
# a BaseModel subclass; non-BaseModel classes will AttributeError
466-
# on model_json_schema below.
464+
# Defensive runtime check: the Protocol signature accepts
465+
# type[BaseModel], but Python doesn't enforce that at the call
466+
# boundary. Reject non-BaseModel classes with a canonical error
467+
# instead of letting AttributeError leak from model_json_schema.
468+
if not issubclass(response_schema, BaseModel): # pyright: ignore[reportUnnecessaryIsInstance]
469+
raise ProviderInvalidRequest(
470+
f"response_schema: class form MUST be a Pydantic BaseModel subclass "
471+
f"(got {response_schema.__name__})"
472+
)
467473
schema_dict = response_schema.model_json_schema()
468474
validate_response_schema(schema_dict)
469475
return schema_dict, response_schema
470476
validate_response_schema(response_schema)
471477
return response_schema, None
472478

473479

480+
# OpenAI's response_format.json_schema.name field is restricted to
481+
# letters, digits, underscores, and dashes with a max length of 64
482+
# characters. A JSON Schema title can be any string ("Person Record",
483+
# "User's Profile", etc.), so verbatim use risks a 400 on the wire.
484+
_OPENAI_SCHEMA_NAME_RE = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
485+
486+
474487
# Derive a stable identifier for the JSON Schema for OpenAI's
475488
# response_format.json_schema.name field. Uses the schema's `title`
476-
# when present (and a valid identifier-shaped string); otherwise
477-
# derives a deterministic short hash so the same schema always
478-
# produces the same name across calls.
489+
# when it satisfies the provider's name constraints; otherwise derives
490+
# a deterministic short hash so the same schema always produces the
491+
# same name across calls. Sanitizing-in-place would silently mutate
492+
# user intent; the hash is a more honest fallback.
479493
def _derive_schema_name(schema: dict[str, Any]) -> str:
480494
title = schema.get("title")
481-
if isinstance(title, str) and title:
495+
if isinstance(title, str) and _OPENAI_SCHEMA_NAME_RE.match(title):
482496
return title
483497
canonical = json.dumps(schema, sort_keys=True).encode("utf-8")
484498
return f"oa_schema_{hashlib.sha256(canonical).hexdigest()[:16]}"
@@ -546,9 +560,11 @@ def _parse_and_validate(
546560
# Construct a fresh message list with a schema directive added. The
547561
# directive is appended to the existing system message's content when
548562
# present, or prepended as a new system message otherwise. The caller's
549-
# original list is never mutated; Message instances are reused because
550-
# they are immutable Pydantic models. The serialized schema appears
551-
# verbatim in the directive so callers that need to verify the directive
563+
# original list is never mutated; Message instances are reused, and
564+
# this helper does not modify them in place (the message models are
565+
# not frozen Pydantic models, so the safety is structural, not
566+
# enforced by the type). The serialized schema appears verbatim in
567+
# the directive so callers that need to verify the directive
552568
# references the schema (conformance harnesses, observability spans)
553569
# can substring-match the canonical JSON form.
554570
def _augment_messages_with_schema_directive(
@@ -585,7 +601,10 @@ def _message_to_wire(msg: Message) -> dict[str, Any]:
585601
"type": "function",
586602
"function": {
587603
"name": tc.name,
588-
"arguments": json.dumps(tc.arguments or {}),
604+
# Canonical compact form (no inter-token spaces). Matches
605+
# the spec's wire-mapping fixture (005, cases shape) and
606+
# the form OpenAI itself emits.
607+
"arguments": json.dumps(tc.arguments or {}, separators=(",", ":")),
589608
},
590609
}
591610
for tc in msg.tool_calls

tests/conformance/harness/wire.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,10 @@ def match_wire_body(
4242
) -> None:
4343
"""Recursive deep-equal between an actual wire-body value and an
4444
expected shape. Strings equal to ``"*"`` in the expected value match
45-
any non-empty string in the actual value. Keys present in
46-
``expected`` MUST be present in ``actual`` and equal; keys present
47-
in ``actual`` but absent from ``expected`` are allowed.
45+
any non-empty string in the actual value. ``expected_wire_request``
46+
is a literal compare: keys present in ``actual`` but absent from
47+
``expected`` are NOT allowed. Partial assertions belong in the
48+
sibling ``expected_wire_request_checks`` block.
4849
4950
Raises :class:`AssertionError` with a JSON-pointer-style path on
5051
mismatch.
@@ -61,6 +62,9 @@ def match_wire_body(
6162
raise AssertionError(f"wire mismatch at {path}: expected object, got {type(actual).__name__}")
6263
expected_map = cast("Mapping[str, Any]", expected)
6364
actual_map = cast("Mapping[str, Any]", actual)
65+
extra = set(actual_map) - set(expected_map)
66+
if extra:
67+
raise AssertionError(f"wire mismatch at {path}: unexpected extra keys in actual: {sorted(extra)}")
6468
for key, exp_v in expected_map.items():
6569
if key not in actual_map:
6670
raise AssertionError(f"wire mismatch at {path}: missing key {key!r}")

tests/conformance/test_llm_provider.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
ProviderInvalidRequest,
3838
ProviderRateLimit,
3939
Response,
40+
RuntimeConfig,
4041
SystemMessage,
4142
Tool,
4243
ToolCall,
@@ -410,23 +411,40 @@ async def _run_one_case(spec: Mapping[str, Any]) -> None:
410411
await provider.aclose()
411412

412413

414+
# Keys that may live as siblings to a ``call:`` block in a cases-shape
415+
# fixture but are conceptually call-level metadata. ``_iter_calls``
416+
# copies these from the case into the yielded call so the test runner
417+
# sees them in one place.
418+
_CASE_LEVEL_CALL_KEYS = (
419+
"expected",
420+
"expected_wire_request",
421+
"expected_wire_request_checks",
422+
"response_schema",
423+
"retry_middleware",
424+
)
425+
426+
413427
def _iter_calls(spec: Mapping[str, Any]) -> Iterator[Mapping[str, Any]]:
414-
"""Yield each call dict with its ``expected`` block attached.
428+
"""Yield each call dict with its case-level metadata attached.
415429
416430
Two shapes the fixtures use:
417431
- ``calls: [{operation, messages, expected, ...}]`` — call and
418432
expected are siblings inside each call entry.
419433
- ``call: {operation, messages, ...}`` + sibling ``expected: ...``
420-
— the case-shape, where expected lives alongside the call.
421-
Both are normalised here to a flat dict where ``expected`` is on
422-
the call.
434+
(and possibly ``expected_wire_request:``, ``response_schema:``,
435+
``retry_middleware:``) — the case-shape, where call-level
436+
metadata lives alongside the call. All sibling keys in
437+
``_CASE_LEVEL_CALL_KEYS`` are folded into the call dict here so
438+
the runner reads them from one place. The nested ``call`` block
439+
takes precedence when both are present.
423440
"""
424441
if "calls" in spec:
425442
yield from cast("list[Mapping[str, Any]]", spec["calls"])
426443
elif "call" in spec:
427444
call = dict(cast("Mapping[str, Any]", spec["call"]))
428-
if "expected" in spec and "expected" not in call:
429-
call["expected"] = spec["expected"]
445+
for key in _CASE_LEVEL_CALL_KEYS:
446+
if key in spec and key not in call:
447+
call[key] = spec[key]
430448
yield call
431449
else:
432450
raise AssertionError("fixture has neither `calls` nor `call` block")
@@ -441,6 +459,8 @@ async def _run_one_call(
441459
expected = cast("Mapping[str, Any]", call_spec.get("expected") or {})
442460
response_schema = call_spec.get("response_schema")
443461
retry_mw_cfg = cast("Mapping[str, Any] | None", call_spec.get("retry_middleware"))
462+
config_block = call_spec.get("config")
463+
config = RuntimeConfig(**cast("Mapping[str, Any]", config_block)) if config_block else None
444464

445465
if operation == "complete":
446466
# Per spec §3 "Validation timing" — complete() validates at
@@ -461,7 +481,7 @@ async def _run_one_call(
461481
except ValidationError as ve:
462482
raise ProviderInvalidRequest(str(ve)) from ve
463483
await _maybe_with_retry(
464-
lambda: provider.complete(messages, tools, response_schema=response_schema),
484+
lambda: provider.complete(messages, tools, config, response_schema=response_schema),
465485
retry_mw_cfg,
466486
)
467487
_assert_raises_matches(excinfo, expected["raises"])
@@ -476,7 +496,7 @@ async def _run_one_call(
476496
messages_snapshot = [m.model_dump(mode="json") for m in messages]
477497
tools = _build_tools(cast("list[Mapping[str, Any]] | None", call_spec.get("tools")))
478498
response = await _maybe_with_retry(
479-
lambda: provider.complete(messages, tools, response_schema=response_schema),
499+
lambda: provider.complete(messages, tools, config, response_schema=response_schema),
480500
retry_mw_cfg,
481501
)
482502
_assert_response_matches(response, cast("Mapping[str, Any]", expected.get("response") or {}))

tests/unit/test_structured_output.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,21 @@ def test_validate_response_schema_rejects_missing_type() -> None:
5757
validate_response_schema({"properties": {"x": {"type": "integer"}}})
5858

5959

60+
def test_validate_response_schema_rejects_malformed_schema() -> None:
61+
# `"type": "foobar"` is not a valid JSON Schema type keyword; the
62+
# boundary check should catch this and raise ProviderInvalidRequest
63+
# rather than letting jsonschema.SchemaError leak at parse time.
64+
with pytest.raises(ProviderInvalidRequest, match="not a valid JSON Schema"):
65+
validate_response_schema(
66+
{
67+
"type": "object",
68+
"properties": {"x": {"type": "foobar"}},
69+
"required": ["x"],
70+
"additionalProperties": False,
71+
}
72+
)
73+
74+
6075
# ---------------------------------------------------------------------------
6176
# strict_mode_supported
6277
# ---------------------------------------------------------------------------
@@ -92,6 +107,18 @@ def test_strict_mode_additional_properties_true_fails() -> None:
92107
assert strict_mode_supported(schema) is False
93108

94109

110+
def test_strict_mode_missing_additional_properties_fails() -> None:
111+
# OpenAI strict mode requires additionalProperties: false to be
112+
# EXPLICITLY set; absence (the default for Pydantic-derived schemas)
113+
# is not strict-compatible.
114+
schema = {
115+
"type": "object",
116+
"properties": {"a": {"type": "string"}},
117+
"required": ["a"],
118+
}
119+
assert strict_mode_supported(schema) is False
120+
121+
95122
def test_strict_mode_recurses_into_nested_object() -> None:
96123
schema: dict[str, Any] = {
97124
"type": "object",
@@ -132,10 +159,12 @@ def test_strict_mode_resolves_internal_ref() -> None:
132159
"type": "object",
133160
"properties": {"a": {"type": "string"}},
134161
"required": ["a"],
162+
"additionalProperties": False,
135163
}
136164
},
137165
"properties": {"inner": {"$ref": "#/$defs/Inner"}},
138166
"required": ["inner"],
167+
"additionalProperties": False,
139168
}
140169
assert strict_mode_supported(schema) is True
141170

@@ -153,7 +182,7 @@ def test_strict_mode_handles_ref_cycle() -> None:
153182
# Self-referential schema: each entry has a "children" key pointing
154183
# back to the same definition. Without cycle protection this would
155184
# recurse forever.
156-
schema = {
185+
schema: dict[str, Any] = {
157186
"type": "object",
158187
"$defs": {
159188
"Node": {
@@ -163,10 +192,12 @@ def test_strict_mode_handles_ref_cycle() -> None:
163192
"children": {"$ref": "#/$defs/Node"},
164193
},
165194
"required": ["value", "children"],
195+
"additionalProperties": False,
166196
}
167197
},
168198
"properties": {"root": {"$ref": "#/$defs/Node"}},
169199
"required": ["root"],
200+
"additionalProperties": False,
170201
}
171202
assert strict_mode_supported(schema) is True
172203

@@ -198,6 +229,28 @@ def test_derive_schema_name_ignores_empty_title() -> None:
198229
assert _derive_schema_name(schema).startswith("oa_schema_")
199230

200231

232+
def test_derive_schema_name_falls_back_on_title_with_spaces() -> None:
233+
# OpenAI's name field rejects spaces; the hash fallback fires.
234+
schema = {
235+
"type": "object",
236+
"title": "Person Record",
237+
"properties": {"x": {"type": "string"}},
238+
"required": ["x"],
239+
}
240+
assert _derive_schema_name(schema).startswith("oa_schema_")
241+
242+
243+
def test_derive_schema_name_falls_back_on_title_too_long() -> None:
244+
# OpenAI's name field has a 64-char cap; longer titles fall back.
245+
schema = {
246+
"type": "object",
247+
"title": "A" * 65,
248+
"properties": {"x": {"type": "string"}},
249+
"required": ["x"],
250+
}
251+
assert _derive_schema_name(schema).startswith("oa_schema_")
252+
253+
201254
# ---------------------------------------------------------------------------
202255
# _augment_messages_with_schema_directive
203256
# ---------------------------------------------------------------------------
@@ -273,6 +326,24 @@ def handler(request: httpx.Request) -> httpx.Response:
273326
return httpx.MockTransport(handler)
274327

275328

329+
async def test_non_basemodel_class_raises_provider_invalid_request() -> None:
330+
transport = _mock_chat_completion_response('{"x":1}')
331+
provider = OpenAIProvider(
332+
base_url="http://mock-llm.test",
333+
model="test-model",
334+
api_key="test-key",
335+
transport=transport,
336+
)
337+
try:
338+
with pytest.raises(ProviderInvalidRequest, match="BaseModel subclass"):
339+
await provider.complete(
340+
[UserMessage(content="x")],
341+
response_schema=str, # type: ignore[arg-type]
342+
)
343+
finally:
344+
await provider.aclose()
345+
346+
276347
async def test_pydantic_class_returns_validated_instance() -> None:
277348
transport = _mock_chat_completion_response('{"name":"Alice","age":30}')
278349
provider = OpenAIProvider(

0 commit comments

Comments
 (0)