Skip to content

Commit f5c3592

Browse files
authored
Add missing judge_tools.py (#347)
* missed a file in intial mlflow judge evaluation. updated README.md to be a bit more clear * rate limit workaround adjustments * linting fixes * linting fixes
1 parent 1d74320 commit f5c3592

File tree

7 files changed

+383
-24
lines changed

7 files changed

+383
-24
lines changed

.test/README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,14 @@ Agent evaluation runs a real Claude Code instance via the [Claude Agent SDK](htt
6767
}
6868
```
6969

70-
| Field | Purpose |
71-
|-------|---------|
72-
| `ANTHROPIC_MODEL` | Default model the agent uses |
73-
| `ANTHROPIC_BASE_URL` | Claude API endpoint (Databricks AI Gateway or direct) |
74-
| `ANTHROPIC_AUTH_TOKEN` | Auth token — supports `${VAR:-default}` interpolation |
75-
| `ANTHROPIC_CUSTOM_HEADERS` | Extra headers (e.g., coding agent mode for Databricks) |
76-
| `DATABRICKS_CONFIG_PROFILE` | Databricks CLI profile for MCP tools |
77-
| `DATABRICKS_API_KEY` | Databricks token for MCP tool calls |
70+
| Field | Purpose |
71+
|-------|-------------------------------------------------------------------------|
72+
| `ANTHROPIC_MODEL` | Default model the agent uses. Currently points to Databricks by default |
73+
| `ANTHROPIC_BASE_URL` | Claude API endpoint (Databricks AI Gateway or direct) |
74+
| `ANTHROPIC_AUTH_TOKEN` | Auth token — supports `${VAR:-default}` interpolation |
75+
| `ANTHROPIC_CUSTOM_HEADERS` | Extra headers (e.g., coding agent mode for Databricks) |
76+
| `DATABRICKS_CONFIG_PROFILE` | Databricks CLI profile for MCP tools |
77+
| `DATABRICKS_API_KEY` | Databricks token for MCP tool calls |
7878

7979
The `${VAR:-default}` syntax lets you reference environment variables with fallbacks. The agent runs with `bypassPermissions` mode so it doesn't prompt for tool approval.
8080

.test/src/skill_test/agent/executor.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,25 @@
3333
_mlflow_env_lock = threading.Lock()
3434
_mlflow_env_configured = False
3535

36+
# Serialize process_transcript calls across parallel agents to avoid
37+
# burst HTTP load on the MLflow tracking server when multiple agents
38+
# finish concurrently (e.g. --parallel-agents 3).
39+
# Lazy per-loop factory: asyncio.Semaphore binds to the running loop at
40+
# creation time. When _run_in_fresh_loop creates a new loop the module-level
41+
# semaphore would crash with "attached to a different loop". Instead we
42+
# cache one semaphore per event-loop id.
43+
_transcript_semaphores: dict[int, asyncio.Semaphore] = {}
44+
_transcript_semaphore_lock = threading.Lock()
45+
46+
47+
def _get_transcript_semaphore() -> asyncio.Semaphore:
48+
"""Return a Semaphore(1) bound to the current running event loop."""
49+
loop_id = id(asyncio.get_running_loop())
50+
with _transcript_semaphore_lock:
51+
if loop_id not in _transcript_semaphores:
52+
_transcript_semaphores[loop_id] = asyncio.Semaphore(1)
53+
return _transcript_semaphores[loop_id]
54+
3655

3756
@dataclass
3857
class AgentEvent:
@@ -391,20 +410,34 @@ async def mlflow_stop_hook(input_data, tool_use_id, context):
391410
# Run process_transcript synchronously — it does HTTP I/O per span
392411
# so can take 20-40s for large sessions. Use a generous timeout to
393412
# prevent hangs from rate limits or network issues.
413+
# Serialize across parallel agents to avoid burst HTTP load on the
414+
# MLflow tracking server when multiple agents finish concurrently.
394415
loop = asyncio.get_running_loop()
395-
try:
396-
trace = await asyncio.wait_for(
397-
loop.run_in_executor(None, process_transcript, transcript_path, session_id),
398-
timeout=120.0,
399-
)
400-
except asyncio.TimeoutError:
401-
print(
402-
f" [MLflow] ERROR: process_transcript timed out after 120s "
403-
f"(session={session_id}). This may indicate rate limiting or "
404-
f"network issues. Continuing without trace."
405-
)
406-
result_holder["trace"] = None
407-
return {"continue": True}
416+
max_retries = 3
417+
trace = None
418+
async with _get_transcript_semaphore():
419+
for attempt in range(max_retries):
420+
try:
421+
trace = await asyncio.wait_for(
422+
loop.run_in_executor(None, process_transcript, transcript_path, session_id),
423+
timeout=300.0,
424+
)
425+
break
426+
except asyncio.TimeoutError:
427+
if attempt < max_retries - 1:
428+
wait = 2 ** (attempt + 1) # 2s, 4s
429+
print(
430+
f" [MLflow] process_transcript attempt {attempt + 1} timed out, "
431+
f"retrying in {wait}s..."
432+
)
433+
await asyncio.sleep(wait)
434+
else:
435+
print(
436+
f" [MLflow] ERROR: process_transcript timed out after {max_retries} "
437+
f"attempts (session={session_id}). Continuing without trace."
438+
)
439+
result_holder["trace"] = None
440+
return {"continue": True}
408441

409442
result_holder["trace"] = trace
410443

.test/src/skill_test/optimize/agent_evaluator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ def __init__(
157157
self._mlflow_experiment = mlflow_experiment
158158
self._skill_name = skill_name
159159

160+
# Cache WITH-skill evaluation results keyed on (prompt_hash, candidate_hash)
161+
self._with_skill_cache: dict[str, tuple[float, dict]] = {}
162+
160163
# Caches for WITHOUT-skill runs (keyed by prompt hash)
161164
self._baseline_response_cache: dict[str, str] = {}
162165
self._baseline_trace_cache: dict[str, dict] = {}
@@ -258,6 +261,12 @@ def _evaluate(
258261
skill_md = candidate.get("skill_md", "")
259262
prompt = example.get("input", "")
260263

264+
# Check candidate-level cache
265+
candidate_hash = hashlib.sha256(json.dumps(candidate, sort_keys=True).encode()).hexdigest()[:16]
266+
cache_key = f"{_prompt_hash(prompt)}:{candidate_hash}"
267+
if cache_key in self._with_skill_cache:
268+
return self._with_skill_cache[cache_key]
269+
261270
# Decode expectations
262271
expectations: dict[str, Any] = {}
263272
expectations_json = example.get("additional_context", {}).get("expectations", "")
@@ -629,6 +638,9 @@ def _judge_with_fallback(
629638
f"guideline_adherence={guideline_adherence_score:.2f}"
630639
)
631640

641+
# Store in candidate-level cache
642+
self._with_skill_cache[cache_key] = (final_score, side_info)
643+
632644
return final_score, side_info
633645

634646

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""Custom judge tools for adaptive evaluation criteria loading.
2+
3+
Implements ``ReadSkillTool`` and ``ReadSkillReferenceTool`` from the MLflow
4+
#21255 design spec, registered in MLflow's global ``JudgeToolRegistry`` so
5+
they are available to any trace-based judge.
6+
7+
Key difference from spec: tools accept ``trace: Trace`` (required by the
8+
``JudgeTool`` interface) but use the internal ``EvalCriteriaSet`` for skill
9+
lookup. When the native ``make_judge(skills=[...])`` API lands, replace
10+
this module with MLflow's built-in skill tools which route via type
11+
annotation.
12+
13+
Registry invocation flow::
14+
15+
registry.invoke(tool_call, trace)
16+
→ json.loads(tool_call.function.arguments)
17+
→ tool.invoke(trace, **parsed_args)
18+
"""
19+
20+
from __future__ import annotations
21+
22+
import logging
23+
import os
24+
from typing import Any
25+
26+
from mlflow.entities.trace import Trace
27+
from mlflow.genai.judges.tools.base import JudgeTool
28+
from mlflow.genai.judges.tools.registry import register_judge_tool
29+
from mlflow.types.llm import FunctionToolDefinition, ParamProperty, ToolDefinition, ToolParamsSchema
30+
31+
from .eval_criteria import EvalCriteriaSet
32+
33+
logger = logging.getLogger(__name__)
34+
35+
36+
class ReadEvalCriteriaTool(JudgeTool):
    """Expose the full body of an evaluation-criteria skill to the judge.

    Invoked by the judge when a criteria's description appears to match
    the trace it is evaluating.
    """

    def __init__(self, criteria_set: EvalCriteriaSet):
        self._criteria_set = criteria_set

    @property
    def name(self) -> str:
        return "read_eval_criteria"

    def get_definition(self) -> ToolDefinition:
        """Build the tool schema advertised to the judge's LLM.

        The list of available criteria names is interpolated into the
        description so the judge knows what it may request.
        """
        names = self._criteria_set.names
        skill_name_param = ParamProperty(
            type="string", description="Name of the evaluation criteria to read"
        )
        function = FunctionToolDefinition(
            name="read_eval_criteria",
            description=(
                "Read the full content of an evaluation criteria skill to get domain-specific "
                "rubrics, scoring rules, and reference material. Use this when a criteria's "
                f"description matches the trace content. Available criteria: {names}"
            ),
            parameters=ToolParamsSchema(properties={"skill_name": skill_name_param}),
        )
        return ToolDefinition(function=function)

    def invoke(self, trace: Trace, skill_name: str) -> str:
        """Return the named skill's body, or an error string the judge can read."""
        skill = self._criteria_set.get_skill(skill_name)
        if not skill:
            available = self._criteria_set.names
            return f"Error: No criteria named '{skill_name}'. Available: {available}"
        return skill.body
76+
77+
78+
class ReadEvalReferenceTool(JudgeTool):
    """Read a reference document from a criteria's ``references/`` directory.

    Used for detailed rubrics, edge cases, and scoring examples.
    """

    def __init__(self, criteria_set: EvalCriteriaSet):
        self._criteria_set = criteria_set

    @property
    def name(self) -> str:
        return "read_eval_reference"

    def get_definition(self) -> ToolDefinition:
        """Build the tool schema advertised to the judge's LLM."""
        return ToolDefinition(
            function=FunctionToolDefinition(
                name="read_eval_reference",
                description=(
                    "Read a reference document from an evaluation criteria skill for detailed "
                    "rubrics, edge cases, or scoring examples."
                ),
                parameters=ToolParamsSchema(
                    properties={
                        "skill_name": ParamProperty(type="string", description="Name of the evaluation criteria"),
                        "file_path": ParamProperty(
                            type="string",
                            description="Relative path within the skill (e.g., 'references/RUBRIC.md')",
                        ),
                    },
                ),
            ),
        )

    def invoke(self, trace: Trace, skill_name: str, file_path: str) -> str:
        """Return the reference file's content, or an error string the judge can read.

        ``file_path`` is validated so the judge cannot escape the skill's
        directory: absolute paths and paths whose first normalized component
        is ``..`` are rejected.
        """
        skill = self._criteria_set.get_skill(skill_name)
        if not skill:
            available = self._criteria_set.names
            return f"Error: No criteria named '{skill_name}'. Available: {available}"
        normalized = os.path.normpath(file_path)
        # normpath collapses interior "..", so traversal can only survive as a
        # leading ".." path component. Comparing the first component (instead
        # of startswith("..")) avoids falsely rejecting legitimate filenames
        # that merely begin with two dots, e.g. "..notes.md".
        if os.path.isabs(normalized) or normalized.split(os.sep)[0] == "..":
            return f"Error: Invalid file path '{file_path}'. Must be relative."
        if normalized not in skill.references:
            return f"Error: File '{file_path}' not found in '{skill_name}'"
        return skill.references[normalized]
123+
124+
# Process-wide guard: MLflow's registry is global, so register at most once.
_registered = False


def register_eval_tools(criteria_set: EvalCriteriaSet) -> None:
    """Register eval-criteria tools in MLflow's global ``JudgeToolRegistry``.

    Safe to call multiple times — tools are registered only once per process.
    """
    global _registered
    if _registered:
        return
    skills = criteria_set.skills
    if not skills:
        logger.debug("No eval criteria loaded; skipping tool registration")
        return
    for tool in (ReadEvalCriteriaTool(criteria_set), ReadEvalReferenceTool(criteria_set)):
        register_judge_tool(tool)
    _registered = True
    logger.info(
        "Registered eval criteria judge tools (%d criteria available)", len(skills)
    )

0 commit comments

Comments
 (0)