Skip to content

Commit 8ed334c

Browse files
fix(llm): second Copilot review pass on PR #42
Addresses 19 review threads from the second Copilot pass; about half were duplicates of the same underlying issue:
- examples/00-hello-world/main.py + README hello-world: api_key now uses `os.environ.get("LLM_API_KEY") or None` so an exported-but-empty env var falls through to no-auth (matters for local servers that reject an empty bearer header).
- Both examples now close the OpenAIProvider in the finally block alongside graph.drain(). Long-running consumers that copy the snippet had been leaking the underlying httpx.AsyncClient.
- errors.py header dropped the hard-coded "seven canonical categories" count after StructuredOutputInvalid landed.
- strict_mode_supported docstring and the surrounding spec-anchor comment block both updated to match the implementation: additionalProperties must be EXPLICITLY false (an omitted key counts as non-strict, since JSON Schema's default permits extras).
- _resolve_ref now handles ref == "#" as the document root before rejecting external refs. Root-recursive schemas that use the bare JSON-Pointer-root form now resolve correctly. Unit test added.
- _strict_mode_check tightened to return False on unrecognized shapes (empty {}, const-only, enum-only, unknown keywords) instead of falling through to True. Primitive types (string/integer/number/boolean/null) classified as terminal-strict-compatible. Two unit tests added.
- _build_request_body now explicitly strips response_format from the body when the provider is in fallback mode. RuntimeConfig is extra="allow", so a caller could have piped response_format through the extras loop past the include_response_format gate.
- provider.py module docstring's summary signature line updated to match the Protocol's response_schema parameter.
- validate_response_schema's spec-anchor comment updated to reflect that JSON Schema validity is now checked at the boundary via Draft202012Validator.check_schema(), not delegated to parse time.
- test_pydantic_class_wire_body_matches_dict_form: widened the assertion from response_format-only to full body equality, so any regression in the class-input wire mapping (not just response_format) gets caught.
- test_inspect_property_native_default and test_inspect_property_fallback_when_forced converted to async with try/finally + aclose() to match the rest of the file's provider-lifecycle pattern.
1 parent 58f6c2f commit 8ed334c

6 files changed

Lines changed: 115 additions & 24 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class PipelineState(State):
9494
provider = OpenAIProvider(
9595
base_url=os.environ.get("LLM_BASE_URL", "https://api.openai.com"), # host root; impl adds /v1
9696
model=os.environ.get("LLM_MODEL", "gpt-4o-mini"),
97-
api_key=os.environ.get("LLM_API_KEY"),
97+
api_key=os.environ.get("LLM_API_KEY") or None, # empty → no-auth
9898
)
9999

100100

@@ -168,6 +168,7 @@ async def main() -> None:
168168
print(f"summary: {final.summary}")
169169
finally:
170170
await graph.drain()
171+
await provider.aclose()
171172

172173

173174
asyncio.run(main())

examples/00-hello-world/main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ class PipelineState(State):
8181
_provider = OpenAIProvider(
8282
base_url=os.environ.get("LLM_BASE_URL", "https://api.openai.com"),
8383
model=os.environ.get("LLM_MODEL", "gpt-4o-mini"),
84-
api_key=os.environ.get("LLM_API_KEY"),
84+
# ``or None`` so an exported-but-empty LLM_API_KEY falls through to
85+
# no-auth (matters for local servers like vLLM that reject an empty
86+
# bearer header).
87+
api_key=os.environ.get("LLM_API_KEY") or None,
8588
)
8689

8790

@@ -197,6 +200,7 @@ async def main() -> None:
197200
print(f"metadata: {final.metadata}")
198201
finally:
199202
await graph.drain()
203+
await _provider.aclose()
200204

201205

202206
if __name__ == "__main__":

src/openarmature/llm/errors.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
# Spec: realizes llm-provider §7 (seven canonical error categories).
1+
# Spec: realizes llm-provider §7 (canonical error categories).
22

33
"""Errors raised by an llm-provider implementation.
44
5-
A provider call (``ready()`` or ``complete()``) MAY raise one of
6-
seven canonical category errors. Each error class carries a
7-
``category`` class attribute matching the canonical string identifier
8-
so callers can dispatch on the category without matching exception
9-
types directly.
5+
A provider call (``ready()`` or ``complete()``) MAY raise one of the
6+
canonical category errors documented below. Each error class carries
7+
a ``category`` class attribute matching the canonical string
8+
identifier so callers can dispatch on the category without matching
9+
exception types directly.
1010
1111
This module is also the single source of truth for the canonical
12-
category strings :data:`TRANSIENT_CATEGORIES` lives here, and
12+
category strings; :data:`TRANSIENT_CATEGORIES` lives here, and
1313
``openarmature.graph.middleware.retry``'s default classifier imports
1414
it.
1515
"""

src/openarmature/llm/provider.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
A successful return implies the next ``complete()`` would not
1515
raise errors that surface mismatched configuration or unloaded
1616
state.
17-
- ``async complete(messages, tools=None, config=None) -> Response``
18-
— performs a single completion. Stateless, reentrant, MUST NOT
19-
mutate its inputs.
17+
- ``async complete(messages, tools=None, config=None, response_schema=None) -> Response``
18+
performs a single completion. Stateless, reentrant, MUST NOT mutate
19+
its inputs. When ``response_schema`` is supplied (a JSON Schema
20+
dict or Pydantic class), the implementation constrains the model's
21+
output and populates ``Response.parsed``.
2022
2123
This module also exports :func:`validate_message_list`: a list-level
2224
invariant check that complements per-message Pydantic validation. A
@@ -177,9 +179,12 @@ def validate_tools(tools: Sequence[Tool] | None) -> None:
177179

178180
# Spec llm-provider §5 requires the response_schema argument to
179181
# complete() to be a valid JSON Schema with a top-level type "object".
180-
# The pre-send check here is the structural minimum; deeper validity
181-
# (recursive JSON Schema correctness, vendor extensions) is delegated
182-
# to the runtime validator at parse time.
182+
# The boundary check here validates BOTH constraints: structural
183+
# (must be a dict with top-level type: "object") AND full JSON Schema
184+
# validity via Draft202012Validator.check_schema(). The runtime
185+
# validator on the parse path only handles instance-against-schema
186+
# failures; malformed schemas fail here rather than escaping at decode
187+
# time as jsonschema.SchemaError.
183188
def validate_response_schema(schema: object) -> None:
184189
"""Pre-send validation for a JSON Schema passed as the
185190
``response_schema`` argument to ``complete()``.
@@ -224,9 +229,11 @@ def strict_mode_supported(schema: dict[str, Any]) -> bool:
224229
by native-decoding LLM wire paths.
225230
226231
Returns True iff for every nested (sub)schema in the tree
227-
``additionalProperties`` is not ``true`` and every key in
228-
``properties`` appears in ``required``. False on any violation, on
229-
an unresolvable ``$ref``, or on an unknown shape.
232+
``additionalProperties`` is explicitly ``false`` (an omitted key
233+
counts as non-strict, since JSON Schema's default is to permit
234+
extras) and every key in ``properties`` appears in ``required``.
235+
False on any violation, on an unresolvable ``$ref``, or on an
236+
unknown shape.
230237
231238
Args:
232239
schema: The root JSON Schema dict.
@@ -238,6 +245,13 @@ def strict_mode_supported(schema: dict[str, Any]) -> bool:
238245
return _strict_mode_check(schema, root=schema, visited=set())
239246

240247

248+
# JSON Schema primitive types: terminal-strict-compatible because they
249+
# carry no nested structure to verify. Object/array types have their
250+
# own branch checks; anything else (const, enum, unknown keywords,
251+
# empty {}) is conservatively non-strict.
252+
_PRIMITIVE_TYPES = frozenset({"string", "integer", "number", "boolean", "null"})
253+
254+
241255
def _strict_mode_check(
242256
schema: Any,
243257
*,
@@ -311,7 +325,23 @@ def _strict_mode_check(
311325
if not _strict_mode_check(item, root=root, visited=visited):
312326
return False
313327

314-
return True
328+
# Determine whether the schema declared a shape we know how to
329+
# verify. Object/array branches above already returned False on
330+
# any internal violation; reaching here means all internal checks
331+
# passed. Combinators with all branches passing are likewise OK.
332+
# Primitive types are terminal. Anything else (empty schema,
333+
# `const`/`enum`-only, unknown keywords) is conservatively
334+
# non-strict — the walker can't statically verify it.
335+
has_combinator = any(k in schema_dict for k in ("anyOf", "oneOf", "allOf"))
336+
if is_object_type or is_array_type or has_combinator:
337+
return True
338+
if isinstance(schema_type, str) and schema_type in _PRIMITIVE_TYPES:
339+
return True
340+
if isinstance(schema_type, list) and all(
341+
isinstance(t, str) and t in _PRIMITIVE_TYPES for t in cast("list[Any]", schema_type)
342+
):
343+
return True
344+
return False
315345

316346

317347
# Internal-only $ref resolver. Handles JSON Pointer fragments rooted
@@ -320,6 +350,11 @@ def _strict_mode_check(
320350
# None. JSON Pointer escape rules (`~0` for `~`, `~1` for `/`) are
321351
# unescaped per RFC 6901.
322352
def _resolve_ref(ref: str, root: dict[str, Any]) -> dict[str, Any] | None:
353+
# Bare "#" is the JSON Pointer for the document root; "#/" prefixes
354+
# an internal path. Anything else (external URIs, relative refs we
355+
# can't resolve without a base) we treat as unresolvable.
356+
if ref == "#":
357+
return root
323358
if not ref.startswith("#/"):
324359
return None
325360
parts = ref[2:].split("/")

src/openarmature/llm/providers/openai.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,13 @@ def _build_request_body(
360360
"strict": strict_mode_supported(schema_dict),
361361
},
362362
}
363+
elif not include_response_format:
364+
# On the fallback path the §8.5.1 contract is "response_format
365+
# MUST NOT be on the wire." RuntimeConfig is extra="allow" so
366+
# a caller could pass response_format through via the extras
367+
# loop above; strip it here so the fallback contract holds
368+
# regardless of caller-supplied extras.
369+
body.pop("response_format", None)
363370
return body
364371

365372
# ------------------------------------------------------------------

tests/unit/test_structured_output.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,44 @@ def test_strict_mode_unresolvable_ref_fails() -> None:
178178
assert strict_mode_supported(schema) is False
179179

180180

181+
def test_strict_mode_empty_property_schema_fails() -> None:
182+
# A property schema of {} (matches anything) cannot be statically
183+
# verified as strict-compatible. The walker should return False
184+
# rather than fall through to True.
185+
schema: dict[str, Any] = {
186+
"type": "object",
187+
"properties": {"x": {}},
188+
"required": ["x"],
189+
"additionalProperties": False,
190+
}
191+
assert strict_mode_supported(schema) is False
192+
193+
194+
def test_strict_mode_primitive_property_passes() -> None:
195+
# Primitive types (string, integer, number, boolean, null) carry no
196+
# nested structure to verify, so they are terminal-strict-compatible.
197+
schema: dict[str, Any] = {
198+
"type": "object",
199+
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
200+
"required": ["name", "age"],
201+
"additionalProperties": False,
202+
}
203+
assert strict_mode_supported(schema) is True
204+
205+
206+
def test_strict_mode_resolves_bare_root_ref() -> None:
207+
# JSON Pointer "#" is a valid reference to the document root
208+
# (RFC 6901). A schema using $ref: "#" for self-recursion should
209+
# resolve through and inherit the root's strict-mode status.
210+
schema: dict[str, Any] = {
211+
"type": "object",
212+
"properties": {"value": {"type": "string"}, "self": {"$ref": "#"}},
213+
"required": ["value", "self"],
214+
"additionalProperties": False,
215+
}
216+
assert strict_mode_supported(schema) is True
217+
218+
181219
def test_strict_mode_handles_ref_cycle() -> None:
182220
# Self-referential schema: each entry has a "children" key pointing
183221
# back to the same definition. Without cycle protection this would
@@ -445,28 +483,34 @@ def handler_dict(request: httpx.Request) -> httpx.Response:
445483

446484
body_class = json.loads(captured_class[0].content)
447485
body_dict = json.loads(captured_dict[0].content)
448-
assert body_class["response_format"] == body_dict["response_format"]
486+
assert body_class == body_dict
449487

450488

451489
# ---------------------------------------------------------------------------
452490
# uses_prompt_augmentation_fallback inspect property
453491
# ---------------------------------------------------------------------------
454492

455493

456-
def test_inspect_property_native_default() -> None:
494+
async def test_inspect_property_native_default() -> None:
457495
provider = OpenAIProvider(
458496
base_url="http://mock-llm.test",
459497
model="test-model",
460498
api_key="test-key",
461499
)
462-
assert provider.uses_prompt_augmentation_fallback is False
500+
try:
501+
assert provider.uses_prompt_augmentation_fallback is False
502+
finally:
503+
await provider.aclose()
463504

464505

465-
def test_inspect_property_fallback_when_forced() -> None:
506+
async def test_inspect_property_fallback_when_forced() -> None:
466507
provider = OpenAIProvider(
467508
base_url="http://mock-llm.test",
468509
model="test-model",
469510
api_key="test-key",
470511
force_prompt_augmentation_fallback=True,
471512
)
472-
assert provider.uses_prompt_augmentation_fallback is True
513+
try:
514+
assert provider.uses_prompt_augmentation_fallback is True
515+
finally:
516+
await provider.aclose()

0 commit comments

Comments
 (0)