Skip to content

Commit a37f229

Browse files
committed
feat: (wip) update protect support
1 parent 09653e3 commit a37f229

2 files changed

Lines changed: 122 additions & 10 deletions

File tree

evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna2/config.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,24 @@
55
from agent_control_evaluators import EvaluatorConfig
66
from pydantic import Field, model_validator
77

8-
# Supported Luna-2 metrics
8+
# Supported Luna-2 metrics — names must match the Galileo Protect API exactly.
9+
# Use the Protect API / golden-demo code as the source of truth:
10+
# output PII → "pii" input PII → "input_pii"
11+
# output tox → "toxicity" input tox → "input_toxicity"
912
Luna2Metric = Literal[
1013
"input_toxicity",
11-
"output_toxicity",
14+
"toxicity", # output toxicity (API name, not "output_toxicity")
1215
"input_sexism",
1316
"output_sexism",
1417
"prompt_injection",
15-
"pii_detection",
18+
"pii", # output PII (API name, replaces "pii_detection")
19+
"input_pii", # input PII
1620
"hallucination",
1721
"tone",
1822
]
1923

20-
# Supported operators
21-
Luna2Operator = Literal["gt", "lt", "gte", "lte", "eq", "contains", "any"]
24+
# Supported operators — "not_empty" matches categorical PII/injection results
25+
Luna2Operator = Literal["gt", "lt", "gte", "lte", "eq", "contains", "any", "not_empty"]
2226

2327

2428
class Luna2EvaluatorConfig(EvaluatorConfig):
@@ -113,7 +117,8 @@ def validate_stage_config(self) -> "Luna2EvaluatorConfig":
113117
raise ValueError("'metric' is required for local stage")
114118
if not self.operator:
115119
raise ValueError("'operator' is required for local stage")
116-
if self.target_value is None:
120+
# not_empty / not_null operators don't need a comparison value
121+
if self.target_value is None and self.operator not in ("not_empty", "not_null"):
117122
raise ValueError("'target_value' is required for local stage")
118123
elif self.stage_type == "central":
119124
if not self.stage_name:

evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna2/evaluator.py

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ def _get_numeric_target_value(self) -> float | int | str | None:
172172
async def _evaluate_local_stage(self, data: Any) -> EvaluatorResult:
173173
"""Evaluate using a local stage (runtime rulesets).
174174
175+
We use PASSTHROUGH action so Protect computes the metric and returns
176+
metric_results without making a block decision itself — agent-control
177+
owns that decision via the control's action.decision field.
178+
179+
Numeric operators (gt, lt, gte, lte, eq): Protect evaluates the rule
180+
server-side and returns status="triggered" when the condition is met,
181+
so _parse_response picks it up directly.
182+
183+
Categorical operators (not_empty, any): the Protect local-stage rule
184+
engine does not support these operators and always returns
185+
status="not_triggered", even when the metric value is non-empty.
186+
_parse_response falls back to _evaluate_metric_results which evaluates
187+
the condition client-side from the raw metric_results dict.
188+
175189
Args:
176190
data: The data to evaluate.
177191
@@ -187,7 +201,8 @@ async def _evaluate_local_stage(self, data: Any) -> EvaluatorResult:
187201
target_value=self._get_numeric_target_value() or 0,
188202
)
189203

190-
# Create proper Ruleset with PassthroughAction
204+
# PASSTHROUGH: Protect scores the content and returns metric_results,
205+
# but does not block — agent-control's deny action handles that.
191206
ruleset = Ruleset(
192207
rules=[rule],
193208
action=PassthroughAction(type="PASSTHROUGH"),
@@ -204,6 +219,7 @@ async def _evaluate_local_stage(self, data: Any) -> EvaluatorResult:
204219
payload=payload,
205220
prioritized_rulesets=[ruleset],
206221
project_name=self.config.galileo_project,
222+
stage_name=self.config.stage_name,
207223
timeout=self.get_timeout_seconds(),
208224
metadata=self.config.metadata or {},
209225
)
@@ -279,10 +295,20 @@ def _prepare_payload(self, data: Any) -> Payload:
279295
is_output_metric = "output" in metric
280296

281297
if is_output_metric:
282-
return Payload(input="", output=data_str)
298+
payload = Payload(input="", output=data_str)
283299
else:
284300
# Default to input for central stages or input metrics
285-
return Payload(input=data_str, output="")
301+
payload = Payload(input=data_str, output="")
302+
303+
logger.debug(
304+
"[Luna2] _prepare_payload: metric=%r payload_field_config=%r "
305+
"→ input=%d chars, output=%d chars",
306+
self.config.metric,
307+
self.config.payload_field,
308+
len(payload.input),
309+
len(payload.output),
310+
)
311+
return payload
286312

287313
def _parse_response(self, response: ProtectResponse | None) -> EvaluatorResult:
288314
"""Parse Galileo Protect response into EvaluatorResult.
@@ -304,16 +330,34 @@ def _parse_response(self, response: ProtectResponse | None) -> EvaluatorResult:
304330
status = response.status.lower() if response.status else "unknown"
305331
triggered = status == "triggered"
306332

333+
# Numeric operators (gt/lt/etc.) are evaluated server-side by Protect and
334+
# return status="triggered" correctly even with PASSTHROUGH.
335+
# Categorical operators (not_empty, any) are NOT supported by Protect's
336+
# local-stage rule engine — it always returns status="not_triggered" for
337+
# them regardless of the metric value. Fall back to client-side evaluation
338+
# from metric_results for those cases.
339+
if not triggered and response.metric_results:
340+
triggered = self._evaluate_metric_results(response.metric_results)
341+
342+
logger.info(
343+
"[Luna2] response: status=%r triggered=%s metric_results=%s",
344+
status,
345+
triggered,
346+
response.metric_results,
347+
)
348+
307349
# Extract trace metadata
308350
trace_id = response.trace_metadata.id if response.trace_metadata else None
309351
execution_time = response.trace_metadata.execution_time if response.trace_metadata else None
310352
received_at = response.trace_metadata.received_at if response.trace_metadata else None
311353
response_at = response.trace_metadata.response_at if response.trace_metadata else None
312354

355+
message = self._build_message(triggered, status, response.metric_results)
356+
313357
return EvaluatorResult(
314358
matched=triggered,
315359
confidence=1.0 if triggered else 0.0,
316-
message=response.text or f"Luna-2 check: {status}",
360+
message=message,
317361
metadata={
318362
"status": status,
319363
"metric": self.config.metric or "unknown",
@@ -324,6 +368,69 @@ def _parse_response(self, response: ProtectResponse | None) -> EvaluatorResult:
324368
},
325369
)
326370

371+
def _build_message(self, triggered: bool, status: str, metric_results: dict) -> str:
372+
"""Build a human-readable message from the evaluation result."""
373+
metric = self.config.metric or "unknown"
374+
375+
if not triggered:
376+
return f"Luna-2 {metric} check passed"
377+
378+
result = (metric_results or {}).get(metric, {})
379+
value = result.get("value")
380+
381+
if isinstance(value, list) and value:
382+
categories = ", ".join(str(v).replace("_", " ") for v in value)
383+
return f"PII detected: {categories}"
384+
if isinstance(value, (int, float)):
385+
return f"{metric} score {value:.2f} exceeded threshold"
386+
387+
return f"Luna-2 {metric} check triggered"
388+
389+
def _evaluate_metric_results(self, metric_results: dict) -> bool:
    """Evaluate the configured operator/target against raw metric_results.

    Used as a client-side fallback when the Protect response did not report
    a server-side trigger but the rule condition still has to be decided
    from the returned metric values.

    Args:
        metric_results: The metric_results dict from the Protect API response.

    Returns:
        True if the rule condition is satisfied. False when the metric is
        missing, its entry is malformed or not in "SUCCESS" status, or the
        operator cannot be applied to the returned value.
    """
    metric = self.config.metric
    if not metric or not metric_results or metric not in metric_results:
        return False

    result = metric_results[metric]
    # A malformed entry (None, bare value) is treated as "no match" rather
    # than raising from the evaluation path.
    if not isinstance(result, dict) or result.get("status") != "SUCCESS":
        return False

    value = result.get("value")
    operator = self.config.operator
    target = self.config.target_value

    # NOTE(review): "not_null", "empty" and "is_null" are handled here but are
    # not listed in the Luna2Operator literal — confirm whether configuration
    # validation can ever deliver them.
    if operator in ("not_empty", "not_null"):
        return bool(value)
    if operator in ("empty", "is_null"):
        return not bool(value)
    if operator == "any" and isinstance(value, list):
        return target in value if target is not None else bool(value)
    if operator == "contains":
        if not value or target is None:
            return False
        try:
            return target in value
        except TypeError:
            # value is not a container (e.g. a numeric score) or the target
            # type is incompatible with it (e.g. int target, str value).
            return False

    if isinstance(value, (int, float)) and target is not None:
        try:
            threshold = float(target)
        except (TypeError, ValueError):
            # Non-numeric target cannot be compared against a numeric score.
            return False
        if operator == "gt":
            return value > threshold
        if operator == "gte":
            return value >= threshold
        if operator == "lt":
            return value < threshold
        if operator == "lte":
            return value <= threshold
        if operator == "eq":
            return value == threshold

    return False
433+
327434
def _handle_error(self, error: Exception) -> EvaluatorResult:
328435
"""Handle errors from Luna-2 evaluation.
329436

0 commit comments

Comments
 (0)