Skip to content

Commit 5137f64

Browse files
committed
final quickstart example server + evaluator
1 parent e3260c6 commit 5137f64

12 files changed

Lines changed: 1932 additions & 0 deletions

File tree

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Image for the SVGBench evaluator: Python 3.11 plus a Chromium install for rendering.
# NOTE(review): no CMD/ENTRYPOINT is defined in the visible lines; the command is
# presumably supplied when the container is run — confirm.
FROM python:3.11-slim

# Set working directory; the relative COPY destinations below resolve against /app
WORKDIR /app

# Install system dependencies for SVG rendering.
# The apt package lists are deleted in the same RUN step so they are not kept in the layer.
RUN apt-get update && apt-get install -y \
    chromium \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better Docker caching: the dependency layer below is
# only rebuilt when requirements.txt changes, not on every source edit.
COPY requirements.txt .

# Install Python dependencies (--no-cache-dir keeps pip's download cache out of the image)
RUN pip install --no-cache-dir -r requirements.txt

# Copy evaluation code
COPY . .

# Set environment for better logging: Python writes stdout/stderr unbuffered,
# so log lines appear immediately in the container output.
ENV PYTHONUNBUFFERED=1

eval_protocol/quickstart/svg_agent/evaluator/svgbench_dataset.jsonl

Lines changed: 50 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
"""
2+
SVGBench evaluation test for EvalProtocol.io using RemoteRolloutProcessor.
3+
4+
This test evaluates LLM ability to generate SVG code that meets specific visual requirements.
5+
The remote server handles:
6+
1. SVG code generation from text prompts (model calls)
7+
8+
The local test handles:
9+
2. SVG to PNG rendering using Selenium
10+
3. LLM judge evaluation of requirement fulfillment
11+
4. Scoring based on fulfilled requirements ratio
12+
"""
13+
14+
import base64
15+
import json
16+
import logging
17+
import os
18+
import tempfile
19+
import traceback
20+
from pathlib import Path
21+
from typing import Any, Dict, List
22+
import asyncio
23+
24+
import litellm
25+
from pydantic import BaseModel
26+
27+
from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message, MetricResult
28+
from eval_protocol.pytest import evaluation_test
29+
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
30+
31+
from utils import extract_svg_code, render_svg_to_png
32+
33+
# Module-level logger, named after this module; used below to report SVG extraction
# and LLM-judge failures for individual rows.
logger = logging.getLogger(__name__)
34+
35+
36+
# Structured answer expected from the LLM judge. The JSON schema generated from this
# model is passed as the judge call's ``response_format`` in evaluate_with_llm_judge.
# Documented with comments rather than a class docstring on purpose: pydantic copies a
# class docstring into the generated schema, which would change what is sent to the judge.
class SVGBenchResponse(BaseModel):
    # Judge's free-text justification; surfaced as the evaluation result's reason.
    reasoning: str
    # How many of the listed requirements the judge considers fulfilled.
    number_of_fulfilled_requirements: int
39+
40+
41+
# NOTE(review): this model is not referenced anywhere else in the visible part of this
# file — confirm whether another module imports it before removing or changing it.
class IntentMatchingResponse(BaseModel):
    """Response structure for intent matching evaluation."""

    # Presumably the judge's explanation for the score below — TODO confirm with the caller.
    intent_reasoning: str
    intent_matching_score: float  # 0-1: Does the content match the intended purpose?
46+
47+
48+
def svgbench_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
    """
    Build one EvaluationRow per SVGBench dataset entry.

    The entry's prompt is extended with an instruction to answer inside an
    ``svg`` code block, a small example, and the entry's requirements as a
    1-based numbered list. That text becomes the row's single user message.

    Args:
        data: Dataset entries; each must provide "prompt" and "requirements".

    Returns:
        Rows in input order. Each row's ``row_id`` is ``row_<index>`` (0-based)
        and its ``dataset_info`` holds the original prompt, the requirements,
        their count, and the formatted prompt.
    """

    def _build_row(index: int, entry: Dict[str, Any]) -> EvaluationRow:
        # Number the requirements starting at 1, one per line.
        numbered = "\n".join(f"{n}. {req}" for n, req in enumerate(entry["requirements"], start=1))

        # Generation prompt in the SVGBench format.
        formatted = f"""{entry["prompt"]} Wrap the SVG code in an SVG code block following the example below.

Example:
```svg
<svg viewBox="0 0 100 100" width="100" height="100">
<circle cx="50" cy="50" r="40" fill="red" />
</svg>
```

Requirements:
{numbered}"""

        metadata = InputMetadata(
            row_id=f"row_{index}",
            dataset_info={
                "original_prompt": entry["prompt"],
                "requirements": entry["requirements"],
                "total_requirements": len(entry["requirements"]),
                "formatted_prompt": formatted,
            },
        )
        return EvaluationRow(
            messages=[Message(role="user", content=formatted)],
            input_metadata=metadata,
        )

    return [_build_row(index, entry) for index, entry in enumerate(data)]
93+
94+
95+
def _build_judge_prompt(requirements: List[str]) -> str:
    """Return the judge instructions followed by the 1-based numbered requirements."""
    requirements_text = "\n".join(f"{n}. {req}" for n, req in enumerate(requirements, start=1))

    # Doubled braces are literal braces in the f-string: they show the judge the JSON shape.
    return f"""Examine the generated image. How many of the following {len(requirements)} requirements were fulfilled?

Be strict about the requirements and respond ONLY with a JSON object in this exact format:
{{"reasoning": <reasoning_text>,
"number_of_fulfilled_requirements": <count>}}

Where <count> is a number between 0 and {len(requirements)}.

Requirements:
{requirements_text}"""


def _png_to_data_url(image_path: str) -> str:
    """Read the PNG at ``image_path`` and return it as a base64 ``data:image/png`` URL."""
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"


async def evaluate_with_llm_judge(image_path: str, requirements: List[str]) -> Dict[str, Any]:
    """
    Use an LLM judge to evaluate how many requirements are fulfilled.

    Sends the rendered image and the numbered requirements to GPT-4.1 (chosen for
    its vision capabilities; note the original SVGBench repo uses Gemini 2.5 Flash)
    and asks for a JSON object matching the SVGBenchResponse schema.

    Args:
        image_path: Path to the rendered PNG image.
        requirements: List of requirements to evaluate.

    Returns:
        The judge's parsed JSON object. It is guaranteed to contain
        "number_of_fulfilled_requirements"; other keys (e.g. "reasoning") are
        passed through as returned by the judge.

    Raises:
        ValueError: If the judge returns an empty response, invalid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), a JSON value that is
            not an object, or an object without the required field.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": _build_judge_prompt(requirements)},
                {"type": "image_url", "image_url": {"url": _png_to_data_url(image_path)}},
            ],
        }
    ]

    # Temperature 0 keeps the judge as repeatable as the API allows.
    response = await litellm.acompletion(
        model="gpt-4.1",
        messages=messages,
        temperature=0.0,
        response_format={
            "type": "json_schema",
            "json_schema": {"name": "SVGBenchResponse", "schema": SVGBenchResponse.model_json_schema()},
        },
    )

    response_content = response.choices[0].message.content  # pyright: ignore[reportAttributeAccessIssue]

    if not response_content or not response_content.strip():
        raise ValueError("Empty response from LLM judge")

    result = json.loads(response_content)

    # Valid JSON is not necessarily an object: a bare string would satisfy the ``in``
    # test by substring match and a number would raise TypeError. Reject both explicitly
    # so callers always receive a dict.
    if not isinstance(result, dict) or "number_of_fulfilled_requirements" not in result:
        raise ValueError("Missing required field in response")

    return result
162+
163+
164+
@evaluation_test(
    input_dataset=[str(Path(__file__).parent / "svgbench_dataset.jsonl")],
    dataset_adapter=svgbench_to_evaluation_row,
    completion_params=[
        {
            "temperature": 0.8,
            "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
            "extra_body": {"reasoning_effort": "medium"},
        },
    ],
    rollout_processor=RemoteRolloutProcessor(
        remote_base_url="https://vercel-svg-server-ts.vercel.app",
    ),
    passed_threshold=0.5,
    max_dataset_rows=8,
    num_runs=1,
    mode="pointwise",
)
async def test_svg_generation_evaluation(row: EvaluationRow) -> EvaluationRow:
    """
    SVG generation evaluation.

    This evaluation asks: How many of the requirements were fulfilled?

    The last message of the row is treated as the model response. Its SVG code is
    extracted, rendered to PNG in a worker thread, and shown to the LLM judge; the
    score is ``fulfilled / total`` with the judge's count clamped to
    ``[0, total]``.

    Failure handling (all failures score 0.0):
      * SVG extraction failure is a valid score (the model produced no usable SVG).
      * Missing or non-text model response, a dataset row without requirements,
        render failure, and judge failure are marked ``is_score_valid=False``.

    Set the environment variable ``SVGBENCH_SAVE_DEBUG_FILES=true`` to keep the
    extracted SVG and rendered PNG under ``svgbench_debug_intent_matching/``;
    otherwise the PNG is written to a temporary file that is removed afterwards.

    Args:
        row: Row produced by the rollout; ``input_metadata.dataset_info`` must
            contain "requirements" and "total_requirements".

    Returns:
        The same row with ``evaluation_result`` set.
    """
    assert row.input_metadata.dataset_info is not None

    # Extract dataset info
    requirements = row.input_metadata.dataset_info["requirements"]
    total_requirements = row.input_metadata.dataset_info["total_requirements"]
    row_id = row.input_metadata.row_id

    # Check if we should save debug files
    save_debug_files = os.environ.get("SVGBENCH_SAVE_DEBUG_FILES", "false").lower() == "true"

    # Get model response: at least the user prompt plus one reply must be present
    if not row.messages or len(row.messages) < 2:
        row.evaluation_result = EvaluateResult(score=0.0, reason="No model response found", is_score_valid=False)
        return row

    model_response = row.messages[-1].content
    # Explicit check instead of ``assert``: asserts vanish under ``python -O`` and would
    # otherwise abort the whole test instead of recording an invalid score for this row.
    if not isinstance(model_response, str):
        row.evaluation_result = EvaluateResult(
            score=0.0,
            reason="Model response content is not a string",
            is_score_valid=False,
        )
        return row

    # Extract SVG code
    try:
        svg_code = extract_svg_code(model_response)
        if not svg_code:
            raise ValueError("No valid SVG code found in response")
    except Exception as e:
        logger.error(f"Error extracting SVG code for question {row_id}: {e}")
        row.evaluation_result = EvaluateResult(score=0.0, reason=f"SVG extraction failed: {str(e)}")
        return row

    # A row without requirements cannot be scored; bail out before rendering and before
    # paying for a judge call that would only end in a division by zero.
    if total_requirements <= 0:
        row.evaluation_result = EvaluateResult(
            score=0.0,
            reason="Dataset row has no requirements to evaluate",
            is_score_valid=False,
        )
        return row

    # Setup file paths
    if save_debug_files:
        model = row.input_metadata.completion_params["model"]
        safe_model_name = model.replace("/", "_").replace(":", "_")
        debug_dir = "svgbench_debug_intent_matching"
        os.makedirs(debug_dir, exist_ok=True)
        png_path = os.path.join(debug_dir, f"question_{row_id}_{safe_model_name}.png")
        svg_path = os.path.join(debug_dir, f"question_{row_id}_{safe_model_name}.svg")
        # Explicit UTF-8: generated SVG may contain non-ASCII text, which a platform
        # default encoding is not guaranteed to accept.
        with open(svg_path, "w", encoding="utf-8") as f:
            f.write(svg_code)
    else:
        # delete=False: only the path is needed here; the renderer writes the file
        # later and the ``finally`` block below removes it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            png_path = f.name

    try:
        # Render SVG to PNG (in a thread so the blocking renderer does not stall the event loop)
        try:
            svg_render_success = await asyncio.to_thread(render_svg_to_png, svg_code, png_path)
            if not svg_render_success:
                row.evaluation_result = EvaluateResult(
                    score=0.0,
                    reason="Failed to render SVG to PNG - render_svg_to_png returned False",
                    is_score_valid=False,
                )
                return row
        except Exception as e:
            # Capture full stack trace for debugging
            full_traceback = traceback.format_exc()
            error_reason = f"Failed to render SVG to PNG - Exception occurred:\n\nError: {str(e)}\n\nFull Stack Trace:\n{full_traceback}"
            row.evaluation_result = EvaluateResult(score=0.0, reason=error_reason, is_score_valid=False)
            return row

        # Run LLM judge evaluation
        judge_result = await evaluate_with_llm_judge(png_path, requirements)

        # Calculate score
        fulfilled_count = judge_result.get("number_of_fulfilled_requirements", 0)
        fulfilled_count = max(0, min(fulfilled_count, total_requirements))  # Clamp to valid range
        score = fulfilled_count / total_requirements

        row.evaluation_result = EvaluateResult(
            score=score,
            reason=judge_result.get("reasoning", ""),
        )

        return row

    except Exception as e:
        logger.error(f"LLM judge evaluation failed for question {row_id}: {e}")
        row.evaluation_result = EvaluateResult(score=0.0, reason=f"Evaluation error: {str(e)}", is_score_valid=False)
        return row

    finally:
        # Clean up temporary PNG file (only if not saving debug files)
        if not save_debug_files:
            try:
                os.unlink(png_path)
            except OSError:
                # Best-effort cleanup: the file may already be gone, and a leftover
                # temp file must never mask the evaluation result.
                pass

0 commit comments

Comments
 (0)