Skip to content

Commit 96d221a

Browse files
authored
feat(core): add PartialValidationResult with tri-state semantics (#924)
* feat(core): add PartialValidationResult with tri-state semantics Adds PartialValidationResult to mellea/core/requirement.py — a tri-state result type for per-chunk streaming validation. Follows the same private-fields-with-property-accessors pattern as ValidationResult. success: Literal["pass", "fail", "unknown"]. __bool__ returns True for "pass", False for "fail" and "unknown". "pass" is informational only in Phase 1; orchestrators call validate() at stream end for all non-"fail" results. Closes #898 Part of #891 Assisted-by: Claude Code Signed-off-by: Nigel Jones <jonesn@uk.ibm.com> * fix(core): add explicit Literal type annotation to PartialValidationResult._success Pyright infers self._success as str without the annotation because it doesn't narrow Literal types through bare instance attribute assignment. The explicit annotation makes the property return type verifiable. Assisted-by: Claude Code Signed-off-by: Nigel Jones <jonesn@uk.ibm.com> * fix(core): address PR #924 review feedback on PartialValidationResult - Update mellea/core/__init__.py module docstring to mention PartialValidationResult - Update mellea/core/requirement.py module docstring to mention PartialValidationResult - Document ValidationResult/PartialValidationResult API asymmetry in class docstring - Add runtime validation of success parameter (fail fast on invalid values) - Expand as_bool() docstring with streaming-context warning for "unknown" → False - Add __repr__ for opaque debug output in streaming pipelines - Replace test_as_bool_matches_bool loop with parametrised test_as_bool_correctness - Add test_invalid_success_raises and test_repr_shows_state Assisted-by: Claude Code Signed-off-by: Nigel Jones <jonesn@uk.ibm.com> * fix(core): add ValidationResult.__repr__ and thunk/context tests - Add __repr__ to ValidationResult for parity with PartialValidationResult - Add test_thunk_field and test_context_field to close the test coverage gap for keyword-only constructor arguments (previously only default-None was verified) Assisted-by: Claude Code Signed-off-by: Nigel Jones <jonesn@uk.ibm.com> --------- Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>
1 parent 50b106c commit 96d221a

3 files changed

Lines changed: 180 additions & 2 deletions

File tree

mellea/core/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
other layer of mellea is built: the ``Backend``, ``Formatter``, and
55
``SamplingStrategy`` protocols; the ``Component``, ``CBlock``, ``Context``, and
66
``ModelOutputThunk`` data types that flow through the inference pipeline; and
7-
``Requirement`` / ``ValidationResult`` for constrained generation. Start here when
7+
``Requirement`` / ``ValidationResult`` / ``PartialValidationResult`` for constrained generation. Start here when
88
building a new backend, formatter, or sampling strategy, or when you need the type
99
definitions shared across the library.
1010
"""
@@ -29,7 +29,12 @@
2929
blockify,
3030
)
3131
from .formatter import Formatter
32-
from .requirement import Requirement, ValidationResult, default_output_to_bool
32+
from .requirement import (
33+
PartialValidationResult,
34+
Requirement,
35+
ValidationResult,
36+
default_output_to_bool,
37+
)
3338
from .sampling import SamplingResult, SamplingStrategy
3439
from .utils import MelleaLogger, clear_log_context, log_context, set_log_context
3540

@@ -66,6 +71,7 @@ def __getattr__(name: str) -> object:
6671
"MelleaLogger",
6772
"ModelOutputThunk",
6873
"ModelToolCall",
74+
"PartialValidationResult",
6975
"Requirement",
7076
"S",
7177
"SamplingResult",

mellea/core/requirement.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
inspects a ``Context`` (and optionally a backend) to determine whether a model output
55
meets a constraint. ``ValidationResult`` carries the pass/fail verdict along with an
66
optional reason, score, and the ``ModelOutputThunk`` produced during validation.
7+
``PartialValidationResult`` provides a tri-state variant (``"pass"``, ``"fail"``,
8+
``"unknown"``) for per-chunk streaming validation.
79
Helper factories such as ``default_output_to_bool`` make it easy to build requirements
810
without boilerplate.
911
"""
1012

1113
import re
1214
from collections.abc import Callable
1315
from copy import copy
16+
from typing import Literal
1417

1518
from .backend import Backend, BaseModelSubclass
1619
from .base import CBlock, Component, Context, ModelOutputThunk, TemplateRepresentation
@@ -76,6 +79,97 @@ def __bool__(self) -> bool:
7679
"""Return a boolean value based on the result."""
7780
return self.as_bool()
7881

82+
def __repr__(self) -> str:
83+
"""Return a developer-readable representation of the validation result."""
84+
return f"ValidationResult({self._result!r}, reason={self._reason!r}, score={self._score!r})"
85+
86+
87+
class PartialValidationResult:
88+
"""Tri-state result from per-chunk streaming validation.
89+
90+
Unlike :class:`ValidationResult`, which stores its verdict as a private
91+
``_result: bool``, this class exposes ``success`` as a public property.
92+
The asymmetry is intentional: the tri-state value cannot be recovered from
93+
a ``bool``, so a public property is the only way to distinguish ``"fail"``
94+
from ``"unknown"`` after construction.
95+
96+
Args:
97+
success: Validation state — ``"pass"`` (constraint satisfied so far),
98+
``"fail"`` (constraint violated, stop streaming), or
99+
``"unknown"`` (insufficient data yet, continue streaming).
100+
reason: Optional human-readable explanation.
101+
score: Optional numeric confidence score.
102+
thunk: Optional ModelOutputThunk from the validation call.
103+
context: Optional context associated with the validation call.
104+
105+
"""
106+
107+
def __init__(
108+
self,
109+
success: Literal["pass", "fail", "unknown"],
110+
*,
111+
reason: str | None = None,
112+
score: float | None = None,
113+
thunk: ModelOutputThunk | None = None,
114+
context: Context | None = None,
115+
):
116+
"""Initialize PartialValidationResult with a tri-state success value."""
117+
if success not in ("pass", "fail", "unknown"):
118+
raise ValueError(
119+
f"success must be 'pass', 'fail', or 'unknown', got {success!r}"
120+
)
121+
self._success: Literal["pass", "fail", "unknown"] = success
122+
self._reason = reason
123+
self._score = score
124+
self._thunk = thunk
125+
self._context = context
126+
127+
@property
128+
def success(self) -> Literal["pass", "fail", "unknown"]:
129+
"""The tri-state validation result."""
130+
return self._success
131+
132+
@property
133+
def reason(self) -> str | None:
134+
"""Reason for the validation result."""
135+
return self._reason
136+
137+
@property
138+
def score(self) -> float | None:
139+
"""An optional score for the validation result."""
140+
return self._score
141+
142+
@property
143+
def thunk(self) -> ModelOutputThunk | None:
144+
"""The ModelOutputThunk associated with the validation call, if any."""
145+
return self._thunk
146+
147+
@property
148+
def context(self) -> Context | None:
149+
"""The context associated with the validation call, if any."""
150+
return self._context
151+
152+
def as_bool(self) -> bool:
153+
"""Return True for ``"pass"``, False for ``"fail"`` or ``"unknown"``.
154+
155+
``"unknown"`` maps to ``False`` intentionally. In streaming contexts,
156+
check ``pvr.success == "unknown"`` before treating ``False`` as a definitive
157+
failure — ``"unknown"`` means insufficient data so far, not a constraint
158+
violation.
159+
160+
Returns:
161+
bool: ``True`` if the partial result is ``"pass"``, ``False`` otherwise.
162+
"""
163+
return self._success == "pass"
164+
165+
def __bool__(self) -> bool:
166+
"""Return a boolean value based on the success state."""
167+
return self.as_bool()
168+
169+
def __repr__(self) -> str:
170+
"""Return a developer-readable representation showing the tri-state value."""
171+
return f"PartialValidationResult({self._success!r}, reason={self._reason!r}, score={self._score!r})"
172+
79173

80174
def default_output_to_bool(x: CBlock | str) -> bool:
81175
"""Convert a model output string to a boolean by checking for a "yes" answer.
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Unit tests for PartialValidationResult tri-state semantics."""
2+
3+
import pytest
4+
5+
from mellea.core import PartialValidationResult
6+
7+
8+
def test_pass_state():
9+
pvr = PartialValidationResult("pass")
10+
assert pvr.success == "pass"
11+
assert pvr.as_bool() is True
12+
assert bool(pvr) is True
13+
14+
15+
def test_fail_state():
16+
pvr = PartialValidationResult("fail")
17+
assert pvr.success == "fail"
18+
assert pvr.as_bool() is False
19+
assert bool(pvr) is False
20+
21+
22+
def test_unknown_state():
23+
pvr = PartialValidationResult("unknown")
24+
assert pvr.success == "unknown"
25+
assert pvr.as_bool() is False
26+
assert bool(pvr) is False
27+
28+
29+
def test_default_optional_fields_are_none():
30+
pvr = PartialValidationResult("unknown")
31+
assert pvr.reason is None
32+
assert pvr.score is None
33+
assert pvr.thunk is None
34+
assert pvr.context is None
35+
36+
37+
def test_reason_field():
38+
pvr = PartialValidationResult("fail", reason="Too short")
39+
assert pvr.reason == "Too short"
40+
41+
42+
def test_score_field():
43+
pvr = PartialValidationResult("pass", score=0.95)
44+
assert pvr.score == 0.95
45+
46+
47+
@pytest.mark.parametrize(
48+
("state", "expected"), [("pass", True), ("fail", False), ("unknown", False)]
49+
)
50+
def test_as_bool_correctness(state: str, expected: bool) -> None:
51+
pvr = PartialValidationResult(state) # type: ignore[arg-type]
52+
assert pvr.as_bool() is expected
53+
assert bool(pvr) is expected
54+
55+
56+
def test_invalid_success_raises() -> None:
57+
with pytest.raises(ValueError, match="success must be"):
58+
PartialValidationResult("maybe") # type: ignore[arg-type]
59+
60+
61+
def test_repr_shows_state() -> None:
62+
pvr = PartialValidationResult("fail", reason="too short", score=0.1)
63+
r = repr(pvr)
64+
assert "'fail'" in r
65+
assert "too short" in r
66+
assert "0.1" in r
67+
68+
69+
def test_thunk_field() -> None:
70+
sentinel = object()
71+
pvr = PartialValidationResult("pass", thunk=sentinel) # type: ignore[arg-type]
72+
assert pvr.thunk is sentinel
73+
74+
75+
def test_context_field() -> None:
76+
sentinel = object()
77+
pvr = PartialValidationResult("pass", context=sentinel) # type: ignore[arg-type]
78+
assert pvr.context is sentinel

0 commit comments

Comments
 (0)