Skip to content

Commit 0ed243e

Browse files
committed
feat: adds ability to use inverted judges
1 parent e8c6692 commit 0ed243e

4 files changed

Lines changed: 278 additions & 7 deletions

File tree

packages/optimization/src/ldai_optimizer/client.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,17 @@ def _compute_validation_count(pool_size: int) -> int:
142142
}
143143

144144

145+
def _judge_passed(score: float, threshold: float, is_inverted: bool) -> bool:
146+
"""Return True when a judge score meets its threshold.
147+
148+
For standard judges (higher is better) the score must reach the threshold
149+
from below: ``score >= threshold``. For inverted judges (lower is better,
150+
e.g. toxicity) the score must stay at or below the threshold:
151+
``score <= threshold``.
152+
"""
153+
return score <= threshold if is_inverted else score >= threshold
154+
155+
145156
class OptimizationClient:
146157
_options: OptimizationOptions
147158
_ldClient: LDAIClient
@@ -470,13 +481,14 @@ async def _call_judges(
470481
if optimization_judge.threshold is not None
471482
else 1.0
472483
)
473-
passed = result.score >= threshold
484+
passed = _judge_passed(result.score, threshold, optimization_judge.is_inverted)
474485
logger.debug(
475-
"[Iteration %d] -> Judge '%s' scored %.3f (threshold=%.3f) -> %s%s",
486+
"[Iteration %d] -> Judge '%s' scored %.3f (threshold=%.3f, inverted=%s) -> %s%s",
476487
iteration,
477488
judge_key,
478489
result.score,
479490
threshold,
491+
optimization_judge.is_inverted,
480492
"PASSED" if passed else "FAILED",
481493
f" | {result.rationale}" if result.rationale else "",
482494
)
@@ -1492,9 +1504,13 @@ def _build_options_from_config(
14921504
)
14931505

14941506
for judge in config["judges"]:
1495-
judges[judge["key"]] = OptimizationJudge(
1507+
judge_key = judge["key"]
1508+
ai_config = api_client.get_ai_config(options.project_key, judge_key)
1509+
is_inverted = bool(ai_config.get("isInverted", False)) if ai_config else False
1510+
judges[judge_key] = OptimizationJudge(
14961511
threshold=float(judge.get("threshold", 0.95)),
1497-
judge_key=judge["key"],
1512+
judge_key=judge_key,
1513+
is_inverted=is_inverted,
14981514
)
14991515

15001516
raw_ground_truth: List[str] = config.get("groundTruthResponses") or []
@@ -1852,7 +1868,7 @@ def _evaluate_response(self, optimize_context: OptimizationContext) -> bool:
18521868
if optimization_judge.threshold is not None
18531869
else 1.0
18541870
)
1855-
if result.score < threshold:
1871+
if not _judge_passed(result.score, threshold, optimization_judge.is_inverted):
18561872
return False
18571873

18581874
return True

packages/optimization/src/ldai_optimizer/dataclasses.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ class OptimizationJudge:
196196
threshold: float
197197
judge_key: Optional[str] = None
198198
acceptance_statement: Optional[str] = None
199+
is_inverted: bool = False
199200

200201

201202
@dataclass

packages/optimization/src/ldai_optimizer/prompts.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,10 @@ def variation_prompt_feedback(
285285
if optimization_judge:
286286
score = result.score
287287
if optimization_judge.threshold is not None:
288-
passed = score >= optimization_judge.threshold
288+
if optimization_judge.is_inverted:
289+
passed = score <= optimization_judge.threshold
290+
else:
291+
passed = score >= optimization_judge.threshold
289292
status = "PASSED" if passed else "FAILED"
290293
feedback_line = (
291294
f"- {judge_key}: Score {score:.3f}"

packages/optimization/tests/test_client.py

Lines changed: 252 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ldai.tracker import TokenUsage
1111
from ldclient import Context
1212

13-
from ldai_optimizer.client import OptimizationClient, _compute_validation_count, _find_model_config
13+
from ldai_optimizer.client import OptimizationClient, _compute_validation_count, _find_model_config, _judge_passed
1414
from ldai_optimizer.dataclasses import (
1515
AIJudgeCallConfig,
1616
GroundTruthOptimizationOptions,
@@ -28,6 +28,7 @@
2828
_acceptance_criteria_implies_duration_optimization,
2929
build_new_variation_prompt,
3030
variation_prompt_acceptance_criteria,
31+
variation_prompt_feedback,
3132
variation_prompt_improvement_instructions,
3233
variation_prompt_overfit_warning,
3334
variation_prompt_preamble,
@@ -1847,6 +1848,8 @@ def _make_mock_api_client() -> MagicMock:
18471848
mock.post_agent_optimization_result = MagicMock(return_value="result-uuid-789")
18481849
mock.patch_agent_optimization_result = MagicMock()
18491850
mock.get_model_configs = MagicMock(return_value=[])
1851+
# Default: AI Configs do not have isInverted set
1852+
mock.get_ai_config = MagicMock(return_value={})
18501853
return mock
18511854

18521855

@@ -4404,3 +4407,251 @@ async def test_optimization_key_in_post_url_uses_string_key_not_uuid(self):
44044407
assert opt_key_arg == "my-optimization", (
44054408
f"Expected string key 'my-optimization', got '{opt_key_arg}'"
44064409
)
4410+
4411+
4412+
# ---------------------------------------------------------------------------
4413+
# _judge_passed helper
4414+
# ---------------------------------------------------------------------------
4415+
4416+
4417+
class TestJudgePassed:
    """Unit tests for the ``_judge_passed`` threshold helper."""

    def test_standard_judge_passes_at_or_above_threshold(self):
        # Equality with the threshold counts as a pass for standard judges.
        assert _judge_passed(0.8, 0.8, is_inverted=False) is True
        assert _judge_passed(1.0, 0.8, is_inverted=False) is True

    def test_standard_judge_fails_below_threshold(self):
        assert _judge_passed(0.5, 0.8, is_inverted=False) is False

    def test_inverted_judge_passes_at_or_below_threshold(self):
        # Equality with the threshold counts as a pass for inverted judges too.
        assert _judge_passed(0.1, 0.3, is_inverted=True) is True
        assert _judge_passed(0.3, 0.3, is_inverted=True) is True

    def test_inverted_judge_fails_above_threshold(self):
        assert _judge_passed(0.8, 0.3, is_inverted=True) is False
4431+
4432+
4433+
# ---------------------------------------------------------------------------
4434+
# _evaluate_response with inverted judges
4435+
# ---------------------------------------------------------------------------
4436+
4437+
4438+
class TestEvaluateResponseInvertedJudges:
    """``_evaluate_response`` must honour ``is_inverted`` on each configured judge."""

    def setup_method(self):
        self.client = _make_client()

    def _ctx_with_scores(self, scores: Dict[str, JudgeResult]) -> OptimizationContext:
        """Build a context carrying *scores*; all other fields are fixed placeholders."""
        return OptimizationContext(
            scores=scores,
            completion_response="Some response.",
            current_instructions="Do X.",
            current_parameters={},
            current_variables={},
            iteration=1,
        )

    def test_inverted_judge_passes_when_score_below_threshold(self):
        self.client._options = _make_options(
            judges={
                "toxicity": OptimizationJudge(
                    threshold=0.3,
                    acceptance_statement="Low toxicity.",
                    is_inverted=True,
                )
            }
        )
        ctx = self._ctx_with_scores({"toxicity": JudgeResult(score=0.1)})
        assert self.client._evaluate_response(ctx) is True

    def test_inverted_judge_passes_at_exact_threshold(self):
        self.client._options = _make_options(
            judges={
                "toxicity": OptimizationJudge(
                    threshold=0.3,
                    acceptance_statement="Low toxicity.",
                    is_inverted=True,
                )
            }
        )
        ctx = self._ctx_with_scores({"toxicity": JudgeResult(score=0.3)})
        assert self.client._evaluate_response(ctx) is True

    def test_inverted_judge_fails_when_score_above_threshold(self):
        self.client._options = _make_options(
            judges={
                "toxicity": OptimizationJudge(
                    threshold=0.3,
                    acceptance_statement="Low toxicity.",
                    is_inverted=True,
                )
            }
        )
        ctx = self._ctx_with_scores({"toxicity": JudgeResult(score=0.8)})
        assert self.client._evaluate_response(ctx) is False

    def test_mixed_judges_all_must_pass(self):
        """A standard judge and an inverted judge must both pass for overall pass."""
        self.client._options = _make_options(
            judges={
                "relevance": OptimizationJudge(
                    threshold=0.8,
                    acceptance_statement="Relevant.",
                    is_inverted=False,
                ),
                "toxicity": OptimizationJudge(
                    threshold=0.3,
                    acceptance_statement="Low toxicity.",
                    is_inverted=True,
                ),
            }
        )
        # Both pass: relevance high, toxicity low
        ctx = self._ctx_with_scores({
            "relevance": JudgeResult(score=0.9),
            "toxicity": JudgeResult(score=0.1),
        })
        assert self.client._evaluate_response(ctx) is True

    def test_mixed_judges_fails_when_inverted_judge_too_high(self):
        self.client._options = _make_options(
            judges={
                "relevance": OptimizationJudge(
                    threshold=0.8,
                    acceptance_statement="Relevant.",
                    is_inverted=False,
                ),
                "toxicity": OptimizationJudge(
                    threshold=0.3,
                    acceptance_statement="Low toxicity.",
                    is_inverted=True,
                ),
            }
        )
        # Relevance passes but toxicity fails (score too high)
        ctx = self._ctx_with_scores({
            "relevance": JudgeResult(score=0.9),
            "toxicity": JudgeResult(score=0.8),
        })
        assert self.client._evaluate_response(ctx) is False

    def test_mixed_judges_fails_when_standard_judge_too_low(self):
        self.client._options = _make_options(
            judges={
                "relevance": OptimizationJudge(
                    threshold=0.8,
                    acceptance_statement="Relevant.",
                    is_inverted=False,
                ),
                "toxicity": OptimizationJudge(
                    threshold=0.3,
                    acceptance_statement="Low toxicity.",
                    is_inverted=True,
                ),
            }
        )
        # Toxicity passes but relevance fails (score too low)
        ctx = self._ctx_with_scores({
            "relevance": JudgeResult(score=0.5),
            "toxicity": JudgeResult(score=0.1),
        })
        assert self.client._evaluate_response(ctx) is False
4515+
4516+
4517+
# ---------------------------------------------------------------------------
4518+
# _build_options_from_config reads isInverted via get_ai_config REST call
4519+
# ---------------------------------------------------------------------------
4520+
4521+
4522+
class TestBuildOptionsFromConfigIsInverted:
    """``_build_options_from_config`` reads ``isInverted`` via ``get_ai_config`` per judge."""

    def setup_method(self):
        self.client = _make_client()
        self.client._agent_key = "my-agent"
        self.client._initialize_class_members_from_config(_make_agent_config())
        self.client._options = _make_options()
        self.api_client = _make_mock_api_client()

    def _build(self, config=None, options=None) -> OptimizationOptions:
        """Call ``_build_options_from_config`` with fixed identifiers and the mock API client."""
        return self.client._build_options_from_config(
            config or dict(_API_CONFIG),
            options or _make_from_config_options(),
            self.api_client,
            optimization_key="opt-key-123",
            run_id="run-uuid-456",
            model_configs=[],
        )

    def test_is_inverted_true_when_ai_config_returns_isInverted(self):
        """is_inverted is set from the AI Config REST API response for each judge."""
        self.api_client.get_ai_config.return_value = {"isInverted": True}
        config = dict(
            _API_CONFIG,
            acceptanceStatements=[],
            judges=[{"key": "toxicity", "threshold": 0.3}],
        )
        result = self._build(config=config)
        assert result.judges["toxicity"].is_inverted is True

    def test_is_inverted_false_when_ai_config_has_no_isInverted(self):
        self.api_client.get_ai_config.return_value = {}
        config = dict(
            _API_CONFIG,
            acceptanceStatements=[],
            judges=[{"key": "relevance", "threshold": 0.8}],
        )
        result = self._build(config=config)
        assert result.judges["relevance"].is_inverted is False

    def test_is_inverted_false_when_ai_config_has_isInverted_false(self):
        self.api_client.get_ai_config.return_value = {"isInverted": False}
        config = dict(
            _API_CONFIG,
            acceptanceStatements=[],
            judges=[{"key": "relevance", "threshold": 0.8}],
        )
        result = self._build(config=config)
        assert result.judges["relevance"].is_inverted is False

    def test_get_ai_config_called_once_per_judge(self):
        config = dict(
            _API_CONFIG,
            acceptanceStatements=[],
            judges=[
                {"key": "toxicity", "threshold": 0.3},
                {"key": "relevance", "threshold": 0.8},
            ],
        )
        self._build(config=config)
        assert self.api_client.get_ai_config.call_count == 2

    def test_acceptance_statements_skip_get_ai_config(self):
        """Acceptance statement judges are not backed by AI Configs."""
        config = dict(
            _API_CONFIG,
            judges=[],
            acceptanceStatements=[{"statement": "Be accurate.", "threshold": 0.9}],
        )
        self._build(config=config)
        self.api_client.get_ai_config.assert_not_called()

    def test_raises_when_get_ai_config_fails(self):
        """A failing get_ai_config call propagates — the build should not silently ignore it."""
        self.api_client.get_ai_config.side_effect = Exception("API error")
        config = dict(
            _API_CONFIG,
            acceptanceStatements=[],
            judges=[{"key": "toxicity", "threshold": 0.3}],
        )
        with pytest.raises(Exception, match="API error"):
            self._build(config=config)

    def test_per_judge_isInverted_mixed(self):
        """Different judges can have different isInverted values."""

        def _get_ai_config_side_effect(project_key, config_key):
            return {"isInverted": True} if config_key == "toxicity" else {"isInverted": False}

        self.api_client.get_ai_config.side_effect = _get_ai_config_side_effect
        config = dict(
            _API_CONFIG,
            acceptanceStatements=[],
            judges=[
                {"key": "toxicity", "threshold": 0.3},
                {"key": "relevance", "threshold": 0.8},
            ],
        )
        result = self._build(config=config)
        assert result.judges["toxicity"].is_inverted is True
        assert result.judges["relevance"].is_inverted is False
4603+
4604+
4605+
# ---------------------------------------------------------------------------
4606+
# variation_prompt_feedback with inverted judges
4607+
# ---------------------------------------------------------------------------
4608+
4609+
4610+
class TestVariationPromptFeedbackInvertedJudges:
    """``variation_prompt_feedback`` labels PASSED/FAILED according to ``is_inverted``."""

    def _make_ctx(self, scores: Dict[str, JudgeResult], iteration: int = 1) -> OptimizationContext:
        """Build a context carrying *scores*; all other fields are fixed placeholders."""
        return OptimizationContext(
            scores=scores,
            completion_response="Some response.",
            current_instructions="Do X.",
            current_parameters={},
            current_variables={},
            iteration=iteration,
        )

    def test_inverted_judge_shows_passed_when_score_below_threshold(self):
        ctx = self._make_ctx({"toxicity": JudgeResult(score=0.1, rationale="Very clean.")})
        judges = {
            "toxicity": OptimizationJudge(
                threshold=0.3,
                acceptance_statement="Low toxicity.",
                is_inverted=True,
            )
        }
        result = variation_prompt_feedback([ctx], judges)
        assert "PASSED" in result

    def test_inverted_judge_shows_failed_when_score_above_threshold(self):
        ctx = self._make_ctx({"toxicity": JudgeResult(score=0.8, rationale="Very toxic.")})
        judges = {
            "toxicity": OptimizationJudge(
                threshold=0.3,
                acceptance_statement="Low toxicity.",
                is_inverted=True,
            )
        }
        result = variation_prompt_feedback([ctx], judges)
        assert "FAILED" in result

    def test_standard_judge_shows_passed_when_score_above_threshold(self):
        ctx = self._make_ctx({"relevance": JudgeResult(score=0.9)})
        judges = {
            "relevance": OptimizationJudge(
                threshold=0.8,
                acceptance_statement="Relevant.",
                is_inverted=False,
            )
        }
        result = variation_prompt_feedback([ctx], judges)
        assert "PASSED" in result

    def test_standard_judge_shows_failed_when_score_below_threshold(self):
        ctx = self._make_ctx({"relevance": JudgeResult(score=0.5)})
        judges = {
            "relevance": OptimizationJudge(
                threshold=0.8,
                acceptance_statement="Relevant.",
                is_inverted=False,
            )
        }
        result = variation_prompt_feedback([ctx], judges)
        assert "FAILED" in result

    def test_mixed_judges_feedback_reflects_correct_pass_fail(self):
        ctx = self._make_ctx({
            "relevance": JudgeResult(score=0.9),
            "toxicity": JudgeResult(score=0.05),
        })
        judges = {
            "relevance": OptimizationJudge(
                threshold=0.8,
                acceptance_statement="Relevant.",
                is_inverted=False,
            ),
            "toxicity": OptimizationJudge(
                threshold=0.3,
                acceptance_statement="Low toxicity.",
                is_inverted=True,
            ),
        }
        result = variation_prompt_feedback([ctx], judges)
        # Both should be PASSED — relevance high enough, toxicity low enough
        assert result.count("PASSED") == 2
        assert "FAILED" not in result

0 commit comments

Comments
 (0)