Commit 7912a1d

feat: add stream_validate() hook to Requirement (#900) (#925)
* feat: add stream_validate() hook to Requirement (#900)

Add an async `stream_validate(chunk, backend, ctx)` method to the base `Requirement` class. The default implementation returns `PartialValidationResult("unknown")`; subclasses override to inspect the accumulated chunk and return `"pass"` or `"fail"` early.

Per the Phase 1 design: `"pass"` is informational and does not short-circuit the final `validate()` call. The method must not mutate `self` — state isolation is the orchestrator's responsibility.

Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>
Assisted-by: Claude Code
Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>

* fix(core): address PR #925 review feedback on stream_validate

- Remove "In Phase 1" temporal qualifier from docstring — reworded to timeless statement about orchestrator responsibility
- Add type annotations (str, Backend, Context) to test subclass overrides
- Add idempotency test: multiple calls on the same Requirement instance leave state unchanged

Assisted-by: Claude Code
Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>

* fix(core): make stream_validate backend/ctx keyword-only

Prevents positional confusion and makes future parameter additions to the signature non-breaking for existing subclass overrides.

Assisted-by: Claude Code

* fix(core): fix stream_validate docstring and add missing stateful tests

The docstring incorrectly stated that implementations must not mutate self. Issue #900 spec explicitly allows stateful accumulation and requires the shallow-copy caveat to be documented. Fix the docstring to match the spec.

Add two tests required by the issue acceptance criteria:

- test_stateful_subclass_accumulates_state: verifies a subclass can accumulate state (bullet counter) across stream_validate calls
- test_stateful_subclass_clone_isolation: verifies copy() gives an independent clone, confirming the orchestrator clone pattern

Assisted-by: Claude Code
Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>

* test(core): make BulletCounter genuinely stateful via delta extraction

The previous implementation overwrote _bullet_count from the full accumulated chunk on each call — equivalent to a pure function with no real dependency on prior state. Use _seen_len to extract only the new portion of each accumulated chunk, accumulating the count additively. This genuinely requires prior-call state to know where to slice, making the test name "accumulates_state" accurate.

Assisted-by: Claude Code
Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>

* test(core): fix missing type: ignore on backend=None in multi-line call

In multi-line calls, # type: ignore only suppresses errors on its own line. The backend=None argument was uncovered; add the ignore there too.

Assisted-by: Claude Code
Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>

* test(core): fix import path and clone test orchestrator pattern

Use the public API for imports: Backend and Context both appear in mellea.core.__all__, so import from mellea.core rather than the internal submodules.

Rewrite test_stateful_subclass_clone_isolation to simulate the correct orchestrator pattern: the original requirement is never called directly; each attempt clones from the fresh original, giving _calls == 0 at the start of every attempt. The previous test cloned mid-stream, which tested shallow-copy isolation but demonstrated the wrong usage pattern.

Assisted-by: Claude Code
Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>

* fix(core): stream_validate receives a single chunk, not accumulated text

Restores the chunk-at-a-time semantics set out in the #891 epic and #900 spec: stream_validate is called once per complete chunk produced by the chunking strategy, and receives that single chunk. Requirements that need history accumulate it on self.

Commit 315a98c inadvertently flipped this: the BulletCounter test was rewritten to recover deltas from accumulated text via self._seen_len, and the docstring was updated to match ("The accumulated model output so far"). Neither change reflected a design decision — it was drift during a test fix, and buries a confusing workaround in what should be a straightforward stateful override.

Changes:

- requirement.py: rewrite chunk Args description to name the chunking-strategy-produced delta, clarify that ctx does not contain the generated output during streaming, and note the MOT single-consumer constraint
- test_stream_validate.py: rewrite BulletCounter to accumulate its own running count (no self._seen_len); calls pass delta chunks ("\n- one\n- two") rather than re-sending accumulated text

The corresponding orchestrator fix in stream_with_chunking() -- pass the chunk, iterate per chunk -- is in the stacked Wave 3 branch.

Assisted-by: Claude Code

---------

Signed-off-by: Nigel Jones <jonesn@uk.ibm.com>
1 parent c1622f5 commit 7912a1d

2 files changed

Lines changed: 209 additions & 0 deletions

mellea/core/requirement.py

Lines changed: 45 additions & 0 deletions
@@ -283,6 +283,51 @@ async def validate(
                 context=val_ctx,
             )
 
+    async def stream_validate(
+        self, chunk: str, *, backend: Backend, ctx: Context
+    ) -> PartialValidationResult:
+        """Hook for per-chunk streaming validation.
+
+        The default implementation returns ``PartialValidationResult("unknown")``
+        — meaning insufficient data to decide yet. Subclasses override this method
+        to inspect the current chunk and return ``"pass"`` or ``"fail"`` early.
+
+        Implementations may accumulate state on ``self`` across calls within a
+        single attempt. The orchestrator clones the requirement (``copy(req)``)
+        before each attempt, so state does not bleed across retries.
+
+        Shallow-copy caveat: mutable container fields (e.g. ``self._buffer = []``)
+        are shared by reference under ``copy()``. Reassign rather than mutate in
+        place (``self._buffer = self._buffer + [chunk]``, not
+        ``self._buffer.append(chunk)``), or override ``__copy__`` for proper
+        isolation.
+
+        Implementations must not call ``mot.astream()`` or otherwise read the
+        underlying stream; the orchestrator is the single consumer of the MOT
+        stream (see ``ModelOutputThunk.astream``). Requirements that need access
+        to the text seen so far should accumulate it themselves from the
+        ``chunk`` values they receive.
+
+        Args:
+            chunk: A single complete semantic chunk produced by the chunking
+                strategy (e.g. one sentence for ``SentenceChunker``). This is
+                the delta since the previous ``stream_validate`` call for this
+                attempt, not the accumulated output. Requirements that need
+                earlier context should retain it on ``self`` across calls.
+            backend: The inference backend, available for backend-assisted checks.
+            ctx: The current generation context. During streaming the MOT is
+                not yet computed, so ``ctx`` does not contain the generated
+                output; use ``chunk`` (and any state accumulated on ``self``)
+                instead.
+
+        Returns:
+            PartialValidationResult: ``"unknown"`` by default. Subclasses may return
+                ``"pass"`` (constraint satisfied so far) or ``"fail"`` (constraint violated,
+                streaming should be aborted). ``"pass"`` does not short-circuit the final
+                ``validate()`` call; the orchestrator decides whether to skip it.
+        """
+        return PartialValidationResult("unknown")
+
     def parts(self) -> list[Component | CBlock]:
         """Returns all of the constituent parts of a Requirement.

test/core/test_stream_validate.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+"""Unit tests for Requirement.stream_validate() hook."""
+
+import inspect
+from copy import copy
+
+import pytest
+
+from mellea.core import Backend, Context, PartialValidationResult, Requirement
+
+
+@pytest.mark.asyncio
+async def test_default_returns_unknown():
+    req = Requirement(description="some requirement")
+    result = await req.stream_validate("some chunk", backend=None, ctx=None)  # type: ignore[arg-type]
+    assert result.success == "unknown"
+
+
+@pytest.mark.asyncio
+async def test_default_returns_partial_validation_result_instance():
+    req = Requirement()
+    result = await req.stream_validate("chunk", backend=None, ctx=None)  # type: ignore[arg-type]
+    assert isinstance(result, PartialValidationResult)
+
+
+def test_stream_validate_is_coroutine():
+    req = Requirement()
+    assert inspect.iscoroutinefunction(req.stream_validate)
+
+
+@pytest.mark.asyncio
+async def test_subclass_can_return_pass():
+    class PassRequirement(Requirement):
+        async def stream_validate(
+            self, chunk: str, *, backend: Backend, ctx: Context
+        ) -> PartialValidationResult:
+            return PartialValidationResult("pass")
+
+    req = PassRequirement(description="always passes")
+    result = await req.stream_validate("any chunk", backend=None, ctx=None)  # type: ignore[arg-type]
+    assert result.success == "pass"
+
+
+@pytest.mark.asyncio
+async def test_subclass_can_return_fail():
+    class FailRequirement(Requirement):
+        async def stream_validate(
+            self, chunk: str, *, backend: Backend, ctx: Context
+        ) -> PartialValidationResult:
+            if "bad" in chunk:
+                return PartialValidationResult("fail", reason="bad word detected")
+            return PartialValidationResult("unknown")
+
+    req = FailRequirement(description="no bad words")
+    result = await req.stream_validate("this is bad content", backend=None, ctx=None)  # type: ignore[arg-type]
+    assert result.success == "fail"
+    assert result.reason == "bad word detected"
+
+    result_unknown = await req.stream_validate("good content", backend=None, ctx=None)  # type: ignore[arg-type]
+    assert result_unknown.success == "unknown"
+
+
+@pytest.mark.asyncio
+async def test_does_not_mutate_requirement():
+    req = Requirement(description="original description")
+    original_description = req.description
+    original_output = req._output
+    original_validation_fn = req.validation_fn
+
+    await req.stream_validate("some chunk", backend=None, ctx=None)  # type: ignore[arg-type]
+
+    assert req.description == original_description
+    assert req._output == original_output
+    assert req.validation_fn == original_validation_fn
+
+
+@pytest.mark.asyncio
+async def test_stream_validate_idempotent():
+    req = Requirement(description="repeated calls")
+    result1 = await req.stream_validate("chunk one", backend=None, ctx=None)  # type: ignore[arg-type]
+    result2 = await req.stream_validate("chunk two", backend=None, ctx=None)  # type: ignore[arg-type]
+    assert result1.success == "unknown"
+    assert result2.success == "unknown"
+    assert req._output is None
+
+
+@pytest.mark.asyncio
+async def test_stateful_subclass_accumulates_state():
+    """Stateful subclass correctly accumulates state across stream_validate calls.
+
+    Each call receives a single chunk (the delta produced by the chunking
+    strategy). Requirements maintain their own running state across calls
+    rather than re-scanning accumulated text.
+    """
+
+    class BulletCounter(Requirement):
+        def __init__(self) -> None:
+            super().__init__(description="no more than 3 bullets")
+            self._bullet_count = 0
+
+        async def stream_validate(
+            self, chunk: str, *, backend: Backend, ctx: Context
+        ) -> PartialValidationResult:
+            self._bullet_count += chunk.count("\n-")
+            if self._bullet_count > 3:
+                return PartialValidationResult(
+                    "fail", reason=f"{self._bullet_count} bullets exceeds limit"
+                )
+            return PartialValidationResult("unknown")
+
+    req = BulletCounter()
+    assert req._bullet_count == 0
+
+    await req.stream_validate("intro text", backend=None, ctx=None)  # type: ignore[arg-type]
+    assert req._bullet_count == 0
+
+    await req.stream_validate("\n- one\n- two", backend=None, ctx=None)  # type: ignore[arg-type]
+    assert req._bullet_count == 2
+
+    result = await req.stream_validate(
+        "\n- three\n- four",
+        backend=None,  # type: ignore[arg-type]
+        ctx=None,  # type: ignore[arg-type]
+    )
+    assert req._bullet_count == 4
+    assert result.success == "fail"
+    assert result.reason is not None and "4" in result.reason
+
+
+@pytest.mark.asyncio
+async def test_stateful_subclass_clone_isolation():
+    """Orchestrator clone pattern: copy() before each attempt gives a fresh independent clone.
+
+    The orchestrator holds the original requirement and never calls stream_validate on it
+    directly. Before each attempt it clones the original; each clone starts from the
+    original's (zero) state and advances independently.
+    """
+
+    class CallCounter(Requirement):
+        def __init__(self) -> None:
+            super().__init__(description="call counter")
+            self._calls = 0
+
+        async def stream_validate(
+            self, chunk: str, *, backend: Backend, ctx: Context
+        ) -> PartialValidationResult:
+            self._calls += 1
+            return PartialValidationResult("unknown")
+
+    req = CallCounter()  # original — never used directly by the orchestrator
+
+    # Attempt 1
+    attempt1 = copy(req)
+    assert attempt1._calls == 0
+    await attempt1.stream_validate("a", backend=None, ctx=None)  # type: ignore[arg-type]
+    await attempt1.stream_validate("b", backend=None, ctx=None)  # type: ignore[arg-type]
+    assert attempt1._calls == 2
+
+    # Attempt 2 (retry) — fresh clone from the same original
+    attempt2 = copy(req)
+    assert attempt2._calls == 0  # starts clean, not carrying attempt1's state
+    await attempt2.stream_validate("c", backend=None, ctx=None)  # type: ignore[arg-type]
+    assert attempt2._calls == 1
+
+    assert req._calls == 0  # original never mutated
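
The tests above exercise the hook directly; the orchestrator that drives it (`stream_with_chunking()`) is in a separate branch and is not shown in this commit. As a rough sketch of how the pieces might fit together, the following self-contained example clones a requirement per attempt, feeds it delta chunks, and stops at the first `"fail"`. Everything here is an assumption for illustration: `PartialResult`, `BulletLimit`, and `run_attempt` are hypothetical stand-ins, not mellea APIs, and the real orchestrator may differ.

```python
import asyncio
from copy import copy
from dataclasses import dataclass
from typing import Optional


@dataclass
class PartialResult:
    """Hypothetical stand-in for PartialValidationResult."""

    success: str  # "pass", "fail", or "unknown"
    reason: Optional[str] = None


class BulletLimit:
    """Hypothetical stand-in for a stateful Requirement subclass."""

    def __init__(self, limit: int) -> None:
        self.limit = limit
        self._bullet_count = 0

    async def stream_validate(self, chunk, *, backend=None, ctx=None):
        # Each call sees only the new chunk, so the running total lives on self.
        self._bullet_count += chunk.count("\n-")
        if self._bullet_count > self.limit:
            return PartialResult(
                "fail", reason=f"{self._bullet_count} bullets exceeds limit"
            )
        return PartialResult("unknown")


async def run_attempt(requirement, chunks):
    """Feed delta chunks to a per-attempt clone; stop at the first "fail"."""
    attempt_req = copy(requirement)  # the original is never called directly
    consumed = 0
    for chunk in chunks:
        consumed += 1
        result = await attempt_req.stream_validate(chunk, backend=None, ctx=None)
        if result.success == "fail":
            return result, consumed  # abort: remaining chunks are not read
    return PartialResult("unknown"), consumed


original = BulletLimit(limit=3)
too_many = ["intro text", "\n- one\n- two", "\n- three\n- four", "\n- five"]
within_limit = ["intro text", "\n- one\n- two"]

first_result, first_consumed = asyncio.run(run_attempt(original, too_many))
retry_result, retry_consumed = asyncio.run(run_attempt(original, within_limit))
```

The first attempt fails on its third chunk and never reads the fourth; the retry starts from a zero count because it clones the untouched original rather than the failed attempt.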
