Skip to content

Commit b76390a

Browse files
committed
deprecate target_role in guardian_check helper
Signed-off-by: Paul S. Schweigert <paul@paulschweigert.com>
1 parent ca421e6 commit b76390a

2 files changed

Lines changed: 166 additions & 5 deletions

File tree

mellea/stdlib/components/intrinsic/guardian.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,21 @@
99
resolved kwargs through.
1010
"""
1111

12+
import warnings
13+
1214
from ....backends.adapters import AdapterMixin
15+
from ....core.utils import MelleaLogger
1316
from ...context import ChatContext
1417
from ._util import call_intrinsic
1518

19+
_UNSET: object = object()
20+
"""Sentinel distinguishing 'caller omitted scoring_schema' from 'caller passed
21+
the default value explicitly'. Used only to detect conflicts with the
22+
deprecated ``target_role`` kwarg."""
23+
24+
_TARGET_ROLE_TO_SCHEMA = {"user": "user_prompt", "assistant": "assistant_response"}
25+
"""Mapping used by the deprecated ``target_role`` path of :func:`guardian_check`."""
26+
1627

1728
def policy_guardrails(
1829
context: ChatContext, backend: AdapterMixin, policy_text: str
@@ -153,7 +164,8 @@ def guardian_check(
153164
context: ChatContext,
154165
backend: AdapterMixin,
155166
criteria: str,
156-
scoring_schema: str = "assistant_response",
167+
scoring_schema: str | object = _UNSET,
168+
target_role: str | None = None,
157169
) -> float:
158170
"""Check whether text meets specified safety/quality criteria.
159171
@@ -170,15 +182,55 @@ def guardian_check(
170182
scoring_schema: Sentence that tells the judge which span to
171183
evaluate and how to decide. Can be a key from
172184
:data:`SCORING_SCHEMA_BANK` (e.g. ``"user_prompt"``) or a
173-
custom string. Must still resolve to a yes/no verdict —
174-
the adapter's ``response_format`` constrains output to
175-
``"yes"``/``"no"``.
185+
custom string. Defaults to ``"assistant_response"``. Must
186+
still resolve to a yes/no verdict — the adapter's
187+
``response_format`` constrains output to ``"yes"``/``"no"``.
188+
target_role: Deprecated. Role whose last message is being
189+
evaluated (``"user"`` or ``"assistant"``). Prefer
190+
``scoring_schema`` with a key from
191+
:data:`SCORING_SCHEMA_BANK`. Passing both
192+
``scoring_schema`` and ``target_role`` raises
193+
:class:`TypeError`.
176194
177195
Returns:
178196
Risk score as a float between 0.0 (no risk) and 1.0 (risk detected).
179197
"""
198+
if target_role is not None:
199+
warnings.warn(
200+
"`target_role` is deprecated; use `scoring_schema` instead "
201+
"(e.g. scoring_schema='user_prompt'). Will be removed in a "
202+
"future release.",
203+
DeprecationWarning,
204+
stacklevel=2,
205+
)
206+
if scoring_schema is not _UNSET:
207+
raise TypeError("Pass either `scoring_schema` or `target_role`, not both.")
208+
if target_role not in _TARGET_ROLE_TO_SCHEMA:
209+
raise ValueError(
210+
f"target_role must be 'user' or 'assistant', got {target_role!r}"
211+
)
212+
resolved_schema = _TARGET_ROLE_TO_SCHEMA[target_role]
213+
elif scoring_schema is _UNSET:
214+
resolved_schema = "assistant_response"
215+
else:
216+
assert isinstance(scoring_schema, str)
217+
if scoring_schema in _TARGET_ROLE_TO_SCHEMA:
218+
# Looks like an old-style target_role value passed positionally.
219+
suggested = _TARGET_ROLE_TO_SCHEMA[scoring_schema]
220+
MelleaLogger.get_logger().warning(
221+
"guardian_check(scoring_schema=%r) looks like an old-style "
222+
"target_role value. It will be used as a literal "
223+
"scoring-schema sentence, which is probably not what you "
224+
"want. Did you mean scoring_schema=%r? (target_role is "
225+
"deprecated; prefer SCORING_SCHEMA_BANK keys like "
226+
"'user_prompt' or 'assistant_response'.)",
227+
scoring_schema,
228+
suggested,
229+
)
230+
resolved_schema = scoring_schema
231+
180232
criteria_text = CRITERIA_BANK.get(criteria, criteria)
181-
scoring_schema_text = SCORING_SCHEMA_BANK.get(scoring_schema, scoring_schema)
233+
scoring_schema_text = SCORING_SCHEMA_BANK.get(resolved_schema, resolved_schema)
182234
result_json = call_intrinsic(
183235
"guardian-core",
184236
context,
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Unit tests for the deprecated ``target_role`` path of ``guardian_check``.
2+
3+
Exercises the sentinel/mapping logic without touching a model. We monkeypatch
4+
``call_intrinsic`` and assert on (a) the ``kwargs["scoring_schema"]`` that
5+
reaches the adapter boundary and (b) the warnings/errors the caller sees.
6+
"""
7+
8+
import warnings
9+
10+
import pytest
11+
12+
from mellea.stdlib.components.intrinsic import guardian
13+
from mellea.stdlib.context import ChatContext
14+
15+
16+
@pytest.fixture
17+
def capture_kwargs(monkeypatch):
18+
"""Replace call_intrinsic with a spy that returns a stub yes=1.0 result."""
19+
captured: dict = {}
20+
21+
def fake_call_intrinsic(name, context, backend, /, kwargs=None, model_options=None):
22+
captured["name"] = name
23+
captured["kwargs"] = kwargs
24+
return {"guardian": {"score": 1.0}}
25+
26+
monkeypatch.setattr(guardian, "call_intrinsic", fake_call_intrinsic)
27+
return captured
28+
29+
30+
def test_default_scoring_schema_resolves_to_assistant_response(capture_kwargs):
31+
guardian.guardian_check(ChatContext(), object(), criteria="harm")
32+
assert (
33+
capture_kwargs["kwargs"]["scoring_schema"]
34+
== guardian.SCORING_SCHEMA_BANK["assistant_response"]
35+
)
36+
37+
38+
def test_target_role_user_maps_to_user_prompt_with_deprecation_warning(capture_kwargs):
39+
with pytest.warns(DeprecationWarning, match="target_role"):
40+
guardian.guardian_check(
41+
ChatContext(), object(), criteria="harm", target_role="user"
42+
)
43+
assert (
44+
capture_kwargs["kwargs"]["scoring_schema"]
45+
== guardian.SCORING_SCHEMA_BANK["user_prompt"]
46+
)
47+
48+
49+
def test_target_role_assistant_maps_to_assistant_response_with_warning(capture_kwargs):
50+
with pytest.warns(DeprecationWarning, match="target_role"):
51+
guardian.guardian_check(
52+
ChatContext(), object(), criteria="harm", target_role="assistant"
53+
)
54+
assert (
55+
capture_kwargs["kwargs"]["scoring_schema"]
56+
== guardian.SCORING_SCHEMA_BANK["assistant_response"]
57+
)
58+
59+
60+
def test_target_role_invalid_value_raises_value_error(capture_kwargs):
61+
with warnings.catch_warnings():
62+
warnings.simplefilter("ignore", DeprecationWarning)
63+
with pytest.raises(ValueError, match="target_role must be"):
64+
guardian.guardian_check(
65+
ChatContext(), object(), criteria="harm", target_role="system"
66+
)
67+
68+
69+
def test_passing_both_scoring_schema_and_target_role_raises_type_error(capture_kwargs):
70+
with warnings.catch_warnings():
71+
warnings.simplefilter("ignore", DeprecationWarning)
72+
with pytest.raises(TypeError, match="not both"):
73+
guardian.guardian_check(
74+
ChatContext(),
75+
object(),
76+
criteria="harm",
77+
scoring_schema="user_prompt",
78+
target_role="user",
79+
)
80+
81+
82+
def test_positional_user_logs_warning_and_sends_literal(capture_kwargs, caplog):
83+
"""Positional 'user' is NOT auto-remapped — it's sent as a literal schema
84+
sentence, with a logger warning pointing the caller at the fix.
85+
"""
86+
with caplog.at_level("WARNING"):
87+
guardian.guardian_check(ChatContext(), object(), "harm", "user")
88+
# The literal "user" flows to the adapter unchanged.
89+
assert capture_kwargs["kwargs"]["scoring_schema"] == "user"
90+
# The warning text nudges the caller toward the bank key.
91+
assert any("user_prompt" in rec.message for rec in caplog.records)
92+
93+
94+
def test_scoring_schema_bank_key_resolves_to_full_sentence(capture_kwargs):
95+
guardian.guardian_check(
96+
ChatContext(), object(), criteria="harm", scoring_schema="tool_call"
97+
)
98+
assert (
99+
capture_kwargs["kwargs"]["scoring_schema"]
100+
== guardian.SCORING_SCHEMA_BANK["tool_call"]
101+
)
102+
103+
104+
def test_custom_scoring_schema_passes_through(capture_kwargs):
105+
custom = "If the previous turn mentions cats, return 'yes'; otherwise, return 'no'."
106+
guardian.guardian_check(
107+
ChatContext(), object(), criteria="harm", scoring_schema=custom
108+
)
109+
assert capture_kwargs["kwargs"]["scoring_schema"] == custom

0 commit comments

Comments
 (0)