Skip to content

Commit b6cc1e6

Browse files
committed
review comments
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent 6733711 commit b6cc1e6

6 files changed

Lines changed: 127 additions & 85 deletions

File tree

docs/examples/intrinsics/factuality_correction.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@
8686
)
8787
# NOTE: This example can also be run with the OpenAIBackend using a GraniteSwitch model. See docs/examples/granite-switch/.
8888

89-
ctx = ctx.add(Message("user", user_text)).add(Message("assistant", response_text))
89+
ctx = ctx.add(Message("user", user_text))
9090

91-
result = guardian.factuality_correction(ctx, backend, documents=[document])
91+
result = guardian.factuality_correction(
92+
response_text, ctx, backend, documents=[document]
93+
)
9294
print(f"Result of factuality correction: {result}") # corrected response string

docs/examples/intrinsics/factuality_detection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
)
3030
# NOTE: This example can also be run with the OpenAIBackend using a GraniteSwitch model. See docs/examples/granite-switch/.
3131

32-
ctx = ctx.add(Message("user", user_text)).add(Message("assistant", response_text))
32+
ctx = ctx.add(Message("user", user_text))
3333

34-
result = guardian.factuality_detection(ctx, backend, documents=[document])
34+
result = guardian.factuality_detection(
35+
response_text, ctx, backend, documents=[document]
36+
)
3537
print(f"Result of factuality detection: {result}") # string "yes" or "no"

mellea/stdlib/components/intrinsic/_util.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ....stdlib import functional as mfuncs
1414
from ...components import Document
1515
from ...context import ChatContext
16+
from ..chat import Message
1617
from .intrinsic import Intrinsic
1718

1819

@@ -74,20 +75,37 @@ def _resolve_response(
7475
"""Return ``(response_text, context_to_use)``.
7576
7677
When *response* is not ``None``, returns it with *context* unchanged.
77-
When ``None``, extracts from the last turn's ``output.value`` and rewinds
78-
*context* to before that output.
78+
When ``None``, extracts from the last turn's ``output.value`` (generated) or
79+
``model_input.content`` (manually-added Message), then rewinds *context*
80+
to before that turn.
7981
"""
8082
if response is not None:
8183
return response, context
8284
turn = context.last_turn()
83-
if turn is None or turn.output is None:
85+
if turn is None:
8486
raise ValueError("response is None and context has no last turn with output")
85-
if turn.output.value is None:
86-
raise ValueError("response is None and last turn output has no value")
87+
88+
# Try generated output first
89+
if turn.output is not None:
90+
if turn.output.value is None:
91+
raise ValueError("response is None and last turn output has no value")
92+
response_text = turn.output.value
93+
# Fall back to manually-added assistant Message
94+
elif (
95+
turn.model_input is not None
96+
and isinstance(turn.model_input, Message)
97+
and turn.model_input.role == "assistant"
98+
):
99+
response_text = turn.model_input.content
100+
else:
101+
raise ValueError(
102+
"response is None and context has no last turn with output or assistant message"
103+
)
104+
87105
rewound = context.previous_node
88106
if rewound is None:
89107
raise ValueError("Cannot rewind context past the root node")
90-
return turn.output.value, rewound # type: ignore[return-value]
108+
return response_text, rewound # type: ignore[return-value]
91109

92110

93111
def call_intrinsic(

mellea/stdlib/components/intrinsic/guardian.py

Lines changed: 79 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77
"""
88

99
import collections.abc
10-
from typing import cast
1110

1211
from ....backends import model_ids
1312
from ....backends.adapters import AdapterMixin
1413
from ...components import Document
1514
from ...context import ChatContext
1615
from ..chat import Message
1716
from ..docs.document import _coerce_to_documents
18-
from ._util import call_intrinsic
17+
from ._util import _resolve_response, call_intrinsic
1918

2019

2120
def policy_guardrails(
@@ -153,6 +152,8 @@ def guardian_check(
153152
backend: AdapterMixin,
154153
criteria: str,
155154
target_role: str = "assistant",
155+
*,
156+
documents: collections.abc.Iterable[str | Document] | None = None,
156157
model_options: dict | None = None,
157158
) -> float:
158159
"""Check whether text meets specified safety/quality criteria.
@@ -168,12 +169,20 @@ def guardian_check(
168169
criteria string.
169170
target_role: Role whose last message is being evaluated
170171
(``"user"`` or ``"assistant"``).
172+
documents: Optional document snippets to attach to the target message.
173+
Primarily used for the ``"groundedness"`` criterion, to provide
174+
reference context for grounding checks. Each element may be a
175+
``Document`` or a plain string (automatically wrapped in ``Document``).
176+
Keyword-only.
171177
model_options: Optional model options to pass to the backend (e.g.,
172178
temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``.
173179
174180
Returns:
175181
Risk score as a float between 0.0 (no risk) and 1.0 (risk detected).
176182
"""
183+
if documents is not None and target_role == "assistant":
184+
context = _reattach_documents(context, documents)
185+
177186
criteria_text = CRITERIA_BANK.get(criteria, criteria)
178187

179188
scoring = (
@@ -209,61 +218,57 @@ def _reattach_documents(
209218
New context with documents attached to the last assistant message.
210219
211220
Raises:
212-
ValueError: If context cannot be rewound or assistant content cannot be extracted.
221+
ValueError: If context cannot be rewound or content cannot be extracted.
213222
"""
214-
last_turn = context.last_turn()
215-
if last_turn is None:
216-
raise ValueError("Cannot reattach documents: context has no last turn")
217-
218-
# Extract assistant content, preferring generated output over input
219-
if last_turn.output is not None and last_turn.output.value is not None:
220-
assistant_content = last_turn.output.value
221-
elif last_turn.output is not None and last_turn.output.value is None:
222-
# Uncomputed thunk — avoid silent fallthrough to model_input
223-
raise ValueError(
224-
"Cannot reattach documents: last turn output is uncomputed (thunk with no value)"
225-
)
226-
elif last_turn.model_input is not None and isinstance(
227-
last_turn.model_input, Message
223+
turn = context.last_turn()
224+
if turn is None:
225+
raise ValueError("Cannot extract response from empty context")
226+
227+
# Try to get response from output first (generated), then from message content (manual)
228+
response_text = None
229+
rewound = context.previous_node
230+
231+
if turn.output is not None and turn.output.value is not None:
232+
# Response is from a generated output
233+
response_text = turn.output.value
234+
elif (
235+
turn.model_input is not None
236+
and isinstance(turn.model_input, Message)
237+
and turn.model_input.role == "assistant"
228238
):
229-
assistant_content = last_turn.model_input.content
239+
# Response is from a manually added assistant Message
240+
response_text = turn.model_input.content
230241
else:
231242
raise ValueError(
232-
"Cannot reattach documents: cannot extract assistant content from last turn"
243+
"Cannot extract response: turn has neither output nor assistant message"
233244
)
234245

235-
# Rewind and re-add with documents
236-
rewound = context.previous_node
237246
if rewound is None:
238247
raise ValueError("Cannot rewind context past the root node")
239248

240-
return cast(
241-
ChatContext,
242-
rewound.add(
243-
Message(
244-
"assistant",
245-
assistant_content,
246-
documents=_coerce_to_documents(documents),
247-
)
248-
),
249+
return rewound.add( # type: ignore[return-value]
250+
Message("assistant", response_text, documents=_coerce_to_documents(documents))
249251
)
250252

251253

252254
def factuality_detection(
255+
response: str | None,
253256
context: ChatContext,
254257
backend: AdapterMixin,
255258
*,
256259
documents: collections.abc.Iterable[str | Document] | None = None,
257260
model_options: dict | None = None,
258261
) -> str:
259-
"""Determine if the last response is factually incorrect.
262+
"""Determine if a response is factually incorrect.
260263
261-
Intrinsic function that evaluates the factuality of the
262-
assistant's response to a user's question. The context should end with
263-
a user question followed by an assistant answer.
264+
Intrinsic function that evaluates the factuality of an assistant's response
265+
to a user's question. The context should typically end with a user question
266+
followed by an assistant answer.
264267
265268
Args:
266-
context: Chat context containing user question and assistant answer.
269+
response: The assistant's response text to evaluate. When ``None``, the
270+
response is extracted from the last assistant output in ``context``.
271+
context: Chat context containing user question and conversation history.
267272
backend: Backend instance that supports LoRA/aLoRA adapters.
268273
documents: Document snippets that provide factual context for evaluation.
269274
Each element may be a ``Document`` or a plain string (automatically
@@ -283,8 +288,24 @@ def factuality_detection(
283288
### Scoring Schema: If the last assistant's text meets the criteria, return 'yes'; otherwise, return 'no'.
284289
"""
285290

291+
if response is None:
292+
response, context = _resolve_response(None, context)
293+
286294
if documents is not None:
287-
context = _reattach_documents(context, documents)
295+
if response is not None:
296+
# Response was explicitly provided, add it with documents
297+
context = context.add(
298+
Message(
299+
"assistant", response, documents=_coerce_to_documents(documents)
300+
)
301+
)
302+
else:
303+
# Response came from context output, reattach documents
304+
context = _reattach_documents(context, documents)
305+
else:
306+
# No documents provided, add response to context if it was explicitly provided
307+
if response is not None:
308+
context = context.add(Message("assistant", response))
288309

289310
context = context.add(Message("user", detector_message))
290311
result_json = call_intrinsic(
@@ -294,19 +315,22 @@ def factuality_detection(
294315

295316

296317
def factuality_correction(
318+
response: str | None,
297319
context: ChatContext,
298320
backend: AdapterMixin,
299321
*,
300322
documents: collections.abc.Iterable[str | Document] | None = None,
301323
model_options: dict | None = None,
302324
) -> str:
303-
"""Corrects the last response so that it is factually correct.
325+
"""Correct a response to be factually accurate.
304326
305327
Intrinsic function that corrects the assistant's response to a user's
306328
question relative to the given contextual information.
307329
308330
Args:
309-
context: Chat context containing user question and assistant answer.
331+
response: The assistant's response text to correct. When ``None``, the
332+
response is extracted from the last assistant output in ``context``.
333+
context: Chat context containing user question and conversation history.
310334
backend: Backend instance that supports LoRA/aLoRA adapters.
311335
documents: Document snippets that provide factual context for correction.
312336
Each element may be a ``Document`` or a plain string (automatically
@@ -326,8 +350,24 @@ def factuality_correction(
326350
### Scoring Schema: If the last assistant's text meets the criteria, return a corrected version of the assistant's message based on the given context; otherwise, return 'none'.
327351
"""
328352

353+
if response is None:
354+
response, context = _resolve_response(None, context)
355+
329356
if documents is not None:
330-
context = _reattach_documents(context, documents)
357+
if response is not None:
358+
# Response was explicitly provided, add it with documents
359+
context = context.add(
360+
Message(
361+
"assistant", response, documents=_coerce_to_documents(documents)
362+
)
363+
)
364+
else:
365+
# Response came from context output, reattach documents
366+
context = _reattach_documents(context, documents)
367+
else:
368+
# No documents provided, add response to context if it was explicitly provided
369+
if response is not None:
370+
context = context.add(Message("assistant", response))
331371

332372
context = context.add(Message("user", corrector_message))
333373
result_json = call_intrinsic(

test/backends/test_openai_intrinsics.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,9 @@ def test_call_intrinsic_factuality_detection(call_intrinsic_backend):
559559
for m in messages:
560560
context = context.add(Message(m["role"], m["content"]))
561561

562-
result = guardian.factuality_detection(docs, context, call_intrinsic_backend)
562+
result = guardian.factuality_detection(
563+
None, context, call_intrinsic_backend, documents=docs
564+
)
563565
assert result in ("yes", "no")
564566

565567

@@ -580,5 +582,7 @@ def test_call_intrinsic_factuality_correction(call_intrinsic_backend):
580582
for m in messages:
581583
context = context.add(Message(m["role"], m["content"]))
582584

583-
result = guardian.factuality_correction(docs, context, call_intrinsic_backend)
585+
result = guardian.factuality_correction(
586+
None, context, call_intrinsic_backend, documents=docs
587+
)
584588
assert isinstance(result, str)

test/stdlib/components/intrinsic/test_guardian.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def test_factuality_detection(backend):
176176
context, documents = _read_guardian_input("factuality_detection.json")
177177

178178
# Test with documents passed as argument
179-
result = guardian.factuality_detection(context, backend, documents=documents)
179+
result = guardian.factuality_detection(None, context, backend, documents=documents)
180180
assert result == "yes" or result == "no"
181181

182182

@@ -185,23 +185,11 @@ def test_factuality_detection_from_context(backend):
185185
"""Verify factuality detection works when documents are in the last message."""
186186
context, documents = _read_guardian_input("factuality_detection.json")
187187

188-
# Extract assistant content using the same logic as _reattach_documents
189-
last_turn = context.last_turn()
190-
assert last_turn is not None
191-
if last_turn.output is not None and last_turn.output.value is not None:
192-
content = last_turn.output.value
193-
else:
194-
assert isinstance(last_turn.model_input, Message)
195-
content = last_turn.model_input.content
196-
197-
rewound = context.previous_node
198-
assert rewound is not None
199-
context_with_docs: ChatContext = rewound.add( # type: ignore[assignment]
200-
Message("assistant", content, documents=documents)
201-
)
188+
# Extract response and rewind, then add back with documents
189+
context_with_docs = guardian._reattach_documents(context, documents)
202190

203-
# Test with documents=None (should extract from context)
204-
result = guardian.factuality_detection(context_with_docs, backend)
191+
# Test with response=None (should extract from context)
192+
result = guardian.factuality_detection(None, context_with_docs, backend)
205193
assert result == "yes" or result == "no"
206194

207195

@@ -211,7 +199,7 @@ def test_factuality_correction(backend):
211199
context, documents = _read_guardian_input("factuality_correction.json")
212200

213201
# Test with documents passed as argument
214-
result = guardian.factuality_correction(context, backend, documents=documents)
202+
result = guardian.factuality_correction(None, context, backend, documents=documents)
215203
assert isinstance(result, str)
216204

217205

@@ -220,23 +208,11 @@ def test_factuality_correction_from_context(backend):
220208
"""Verify factuality correction works when documents are in the last message."""
221209
context, documents = _read_guardian_input("factuality_correction.json")
222210

223-
# Extract assistant content using the same logic as _reattach_documents
224-
last_turn = context.last_turn()
225-
assert last_turn is not None
226-
if last_turn.output is not None and last_turn.output.value is not None:
227-
content = last_turn.output.value
228-
else:
229-
assert isinstance(last_turn.model_input, Message)
230-
content = last_turn.model_input.content
231-
232-
rewound = context.previous_node
233-
assert rewound is not None
234-
context_with_docs: ChatContext = rewound.add( # type: ignore[assignment]
235-
Message("assistant", content, documents=documents)
236-
)
211+
# Extract response and rewind, then add back with documents
212+
context_with_docs = guardian._reattach_documents(context, documents)
237213

238-
# Test with documents=None (should extract from context)
239-
result = guardian.factuality_correction(context_with_docs, backend)
214+
# Test with response=None (should extract from context)
215+
result = guardian.factuality_correction(None, context_with_docs, backend)
240216
assert isinstance(result, str)
241217

242218

0 commit comments

Comments
 (0)