Skip to content

Commit 341b38e

Browse files
feat: support Anthropic-compatible endpoints for benchmark LLMs
1 parent 45fa380 commit 341b38e

6 files changed

Lines changed: 253 additions & 10 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.1.0"
44
description = "Open Memory Benchmark"
55
requires-python = ">=3.11"
66
dependencies = [
7+
"anthropic>=0.84.0",
78
"datasets>=2.0",
89
"typer>=0.12",
910
"rich>=13",

src/memory_bench/cli.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
load_dotenv(dotenv_path=Path(__file__).parents[2] / ".env", override=True)
1313

1414
from .dataset import REGISTRY as DATASET_REGISTRY, get_dataset
15-
from .llm import REGISTRY as LLM_REGISTRY, get_llm, get_answer_llm
15+
from .llm import REGISTRY as LLM_REGISTRY, get_answer_llm
1616
from .memory import REGISTRY as MEMORY_REGISTRY, get_memory_provider
1717
from .modes import REGISTRY as MODE_REGISTRY, get_mode
1818
from .runner import EvalRunner
@@ -22,12 +22,63 @@
2222
console = Console()
2323

2424

25-
def _resolve_gemini_key() -> None:
26-
key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
27-
if not key:
28-
typer.echo("Error: GEMINI_API_KEY environment variable is not set.", err=True)
25+
def _ensure_provider_env(provider: str, role: str) -> None:
    """Exit with an error unless *provider* is registered and has its API key set.

    Args:
        provider: LLM provider name; must be a key of ``LLM_REGISTRY``.
        role: Human-readable role ("Answer" / "Judge") used in error messages.

    Raises:
        typer.Exit: with code 1 when the provider is unknown or its key is missing.
    """
    if provider not in LLM_REGISTRY:
        typer.echo(
            f"Error: unknown {role.lower()} LLM provider '{provider}'. Available: {', '.join(LLM_REGISTRY)}.",
            err=True,
        )
        raise typer.Exit(1)

    # Environment variables accepted per provider, in lookup order. The first
    # entry is the one named in the error message.
    credential_vars = {
        "anthropic": ("ANTHROPIC_API_KEY",),
        "gemini": ("GEMINI_API_KEY", "GOOGLE_API_KEY"),
        "groq": ("GROQ_API_KEY",),
        "openai": ("OPENAI_API_KEY",),
    }
    candidates = credential_vars.get(provider)
    if candidates is None:
        # Registered provider with no credential rule: nothing to verify.
        return

    key = next((value for name in candidates if (value := os.environ.get(name))), None)
    if not key:
        typer.echo(f"Error: {role} LLM provider '{provider}' requires {candidates[0]}.", err=True)
        raise typer.Exit(1)

    if provider == "gemini":
        # Mirror whichever Gemini key was found into GOOGLE_API_KEY.
        os.environ["GOOGLE_API_KEY"] = key
58+
59+
60+
def _validate_run_env(memory: str, mode: str, answer_provider: str | None = None) -> None:
    """Check that the environment can support the requested run, exiting if not.

    Args:
        memory: Memory provider name.
        mode: Response mode name.
        answer_provider: Optional answer-LLM provider. When given it is written
            to ``OMB_ANSWER_LLM``, overriding any existing value.

    Raises:
        typer.Exit: with code 1 when a provider key is missing, the provider is
            unknown, or the mode/provider combination is unsupported.
    """
    if answer_provider is not None:
        os.environ["OMB_ANSWER_LLM"] = answer_provider

    answer = os.environ.get("OMB_ANSWER_LLM", "groq")
    judge = os.environ.get("OMB_JUDGE_LLM", "gemini")
    for provider, role in ((answer, "Answer"), (judge, "Judge")):
        _ensure_provider_env(provider, role)

    if mode == "agentic-rag" and answer != "gemini":
        typer.echo(
            f"Error: response mode 'agentic-rag' requires a tool-capable LLM provider; '{answer}' is not supported.",
            err=True,
        )
        raise typer.Exit(1)

    if memory != "hindsight":
        return

    # Hindsight needs a Gemini key whichever answer/judge LLMs are selected.
    gemini_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not gemini_key:
        typer.echo("Error: memory provider 'hindsight' requires GEMINI_API_KEY for embedded extraction.", err=True)
        raise typer.Exit(1)
    os.environ["GOOGLE_API_KEY"] = gemini_key
3182

3283

3384
@app.command()
@@ -36,7 +87,7 @@ def run(
3687
dataset: str = typer.Option("tempo", "--dataset", help=f"Dataset. Available: {', '.join(DATASET_REGISTRY)}"),
3788
memory: str = typer.Option("bm25", "--memory", "-m", help=f"Memory provider. Available: {', '.join(MEMORY_REGISTRY)}"),
3889
mode: str = typer.Option("rag", "--mode", help=f"Response mode. Available: {', '.join(MODE_REGISTRY)}"),
39-
llm: str = typer.Option("gemini", "--llm", help=f"LLM for answer generation. Available: {', '.join(LLM_REGISTRY)}"),
90+
llm: str | None = typer.Option(None, "--llm", help=f"LLM provider for answer generation. Overrides OMB_ANSWER_LLM. Available: {', '.join(LLM_REGISTRY)}"),
4091
category: str = typer.Option(None, "--category", "-c", help="Category filter(s), comma-separated (e.g. 'a,b,c'). With --query-limit, runs N queries per category."),
4192
query_limit: int = typer.Option(None, "--query-limit", "-q", help="Max queries to evaluate. When combined with multiple --category values, applies per category."),
4293
query_id: str = typer.Option(None, "--query-id", help="Run a single specific query by ID"),
@@ -53,7 +104,7 @@ def run(
53104
description: str = typer.Option(None, "--description", "-d", help="Optional description for this run (stored in the result JSON)"),
54105
) -> None:
55106
"""Run an evaluation on a single split (optionally filtered to a category)."""
56-
_resolve_gemini_key()
107+
_validate_run_env(memory, mode, llm)
57108

58109
ds = get_dataset(dataset)
59110

src/memory_bench/llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import os
22

3+
from .anthropic import AnthropicLLM
34
from .base import LLM, Schema
45
from .gemini import GeminiLLM
56
from .groq import GroqLLM
67
from .openai import OpenAILLM
78

89
REGISTRY: dict[str, type[LLM]] = {
10+
"anthropic": AnthropicLLM,
911
"gemini": GeminiLLM,
1012
"groq": GroqLLM,
1113
"openai": OpenAILLM,

src/memory_bench/llm/anthropic.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import json
2+
import os
3+
import re
4+
import time
5+
6+
from .base import LLM, Schema
7+
8+
# Total number of request attempts made before giving up.
_MAX_RETRIES = 6
# Seconds slept after the first failed attempt; the wait doubles after each one.
_RETRY_BASE_DELAY = 5
10+
11+
12+
class _StructuredOutputError(ValueError):
    """Signal that a model reply failed to satisfy the requested schema."""
14+
15+
16+
def _parse_json_payload(text: str) -> dict:
    """Extract the JSON object carried by a model response.

    The whole (stripped) response is tried first. If that is not valid JSON,
    the text is scanned for the first ``{`` from which a complete JSON object
    can be decoded, preferring one that opens a Markdown code fence. Anything
    after the decoded object (closing fences, commentary, further braces) is
    ignored, so replies such as ``{"a": 1} (see {docs})`` or replies with
    several fenced blocks still parse.

    Args:
        text: Raw text returned by the model.

    Returns:
        The decoded JSON object.

    Raises:
        _StructuredOutputError: The whole response is valid JSON but is not an
            object (for example a bare list or number).
        json.JSONDecodeError: No JSON object could be decoded from the text.
    """
    text = text.strip()

    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        if not isinstance(payload, dict):
            raise _StructuredOutputError("Model response must be a JSON object")
        return payload

    decoder = json.JSONDecoder()

    # Prefer an object that opens a ``` / ```json fence, then fall back to
    # scanning the whole text from the start.
    fence = re.search(r"```(?:json)?\s*\{", text, flags=re.IGNORECASE)
    origins = [fence.end() - 1, 0] if fence else [0]

    for origin in origins:
        start = text.find("{", origin)
        while start != -1:
            try:
                # raw_decode tolerates trailing text after the object; decoding
                # that starts at "{" can only ever yield a dict.
                payload, _ = decoder.raw_decode(text[start:])
            except json.JSONDecodeError:
                start = text.find("{", start + 1)
            else:
                return payload

    raise json.JSONDecodeError("Could not find JSON object in model response", text, 0)
44+
45+
46+
def _coerce_text_payload(text: str, schema: Schema) -> dict | None:
    """Wrap a plain-text reply into a payload for a single-required-field schema.

    Args:
        text: Raw model reply that could not be parsed as JSON.
        schema: Requested schema; only ``required`` and ``properties`` are read.

    Returns:
        ``{field: value}`` when the schema has exactly one required field and
        that field is a string (value is the stripped text) or a boolean
        (text is "true"/"false", case-insensitive). ``None`` otherwise.
    """
    stripped = text.strip()
    if not stripped or len(schema.required) != 1:
        return None

    name = schema.required[0]
    # A property with no declared type is treated as a string.
    kind = schema.properties.get(name, {}).get("type", "string")

    if kind == "string":
        return {name: stripped}

    if kind == "boolean":
        literals = {"true": True, "false": False}
        word = stripped.lower()
        if word in literals:
            return {name: literals[word]}

    return None
68+
69+
70+
def _validate_schema_payload(payload: dict, schema: Schema) -> dict:
71+
extra = sorted(set(payload) - set(schema.properties))
72+
if extra:
73+
raise _StructuredOutputError(f"Model response included unsupported field(s): {', '.join(extra)}")
74+
75+
missing = [field for field in schema.required if field not in payload]
76+
if missing:
77+
raise _StructuredOutputError(f"Model response omitted required field(s): {', '.join(missing)}")
78+
79+
for field, value in payload.items():
80+
spec = schema.properties.get(field, {})
81+
expected_type = spec.get("type", "string")
82+
if expected_type == "string":
83+
valid = isinstance(value, str)
84+
elif expected_type == "boolean":
85+
valid = isinstance(value, bool)
86+
elif expected_type == "integer":
87+
valid = isinstance(value, int) and not isinstance(value, bool)
88+
elif expected_type == "number":
89+
valid = isinstance(value, (int, float)) and not isinstance(value, bool)
90+
elif expected_type == "array":
91+
valid = isinstance(value, list)
92+
elif expected_type == "object":
93+
valid = isinstance(value, dict)
94+
else:
95+
valid = True
96+
97+
if not valid:
98+
raise _StructuredOutputError(
99+
f"Model response field '{field}' must be {expected_type}, got {type(value).__name__}"
100+
)
101+
102+
return payload
103+
104+
105+
class AnthropicLLM(LLM):
    """LLM backend that calls the Anthropic Messages API.

    Configuration is read from the environment: ``ANTHROPIC_API_KEY``
    (required), ``ANTHROPIC_BASE_URL`` (optional endpoint override) and
    ``ANTHROPIC_MODEL`` (optional default model name).
    """

    def __init__(self, model: str | None = None):
        """Create the API client.

        Args:
            model: Model name. Falls back to ``ANTHROPIC_MODEL`` and then to
                ``"claude-sonnet-4-5"``.

        Raises:
            RuntimeError: ``ANTHROPIC_API_KEY`` is unset or empty.
        """
        from anthropic import Anthropic

        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            raise RuntimeError("Anthropic provider requires ANTHROPIC_API_KEY")

        base_url = os.environ.get("ANTHROPIC_BASE_URL")
        self._client = Anthropic(
            api_key=api_key,
            # An empty ANTHROPIC_BASE_URL is passed as None rather than "".
            base_url=base_url or None,
            # SDK-level retries are disabled; generate() runs its own retry loop.
            max_retries=0,
        )
        self._model = (
            model
            or os.environ.get("ANTHROPIC_MODEL")
            or "claude-sonnet-4-5"
        )

    @property
    def model_id(self) -> str:
        """Return the model name prefixed with ``anthropic:``."""
        return f"anthropic:{self._model}"

    def generate(self, prompt: str, schema: Schema) -> dict:
        """Ask the model for a JSON object matching *schema*.

        The schema is embedded in the system prompt. The reply text is parsed
        as JSON (falling back to plain-text coercion for single-field schemas)
        and validated. Rate limits, connection errors, HTTP 429/5xx responses
        and schema violations are retried with exponential backoff.

        Args:
            prompt: User message sent to the model.
            schema: Expected shape of the response object.

        Returns:
            The validated response payload.

        Raises:
            RuntimeError: Every attempt failed with a retryable error; the last
                error is attached as ``__cause__``.
            anthropic.APIStatusError: The API returned a non-retryable status.
            Exception: Any other unexpected error is re-raised immediately.
        """
        from anthropic import APIConnectionError, APIStatusError, RateLimitError

        schema_json = {
            "type": "object",
            "properties": schema.properties,
            "required": schema.required,
            "additionalProperties": False,
        }
        system_prompt = (
            "Return only a valid JSON object matching this schema. "
            "Do not wrap JSON in markdown fences.\n\n"
            f"{json.dumps(schema_json, ensure_ascii=False)}"
        )

        delay = _RETRY_BASE_DELAY
        last_exc: Exception | None = None

        for attempt in range(_MAX_RETRIES):
            try:
                response = self._client.messages.create(
                    model=self._model,
                    max_tokens=4096,
                    temperature=0.0,
                    system=system_prompt,
                    messages=[{"role": "user", "content": prompt}],
                )
                # Only text blocks carry the answer; other block types are skipped.
                text = "".join(block.text for block in response.content if getattr(block, "type", None) == "text")
                try:
                    payload = _parse_json_payload(text)
                except json.JSONDecodeError:
                    coerced = _coerce_text_payload(text, schema)
                    if coerced is None:
                        raise _StructuredOutputError("Model response was not valid JSON") from None
                    payload = coerced
                return _validate_schema_payload(payload, schema)
            except (RateLimitError, APIConnectionError) as e:
                last_exc = e
            except APIStatusError as e:
                if e.status_code not in (429, 500, 502, 503, 504):
                    raise
                last_exc = e
            except _StructuredOutputError as e:
                last_exc = e
            except Exception as e:
                # Last-resort detection of rate limiting reported through an
                # unexpected exception type. "rate" must stand alone as a word
                # ("rate limit", "rate_limit_error"), so unrelated messages
                # containing "generate" or "accurate" are not retried.
                msg = str(e)
                looks_rate_limited = "429" in msg or re.search(
                    r"(?<![a-z])rate(?![a-z])", msg, flags=re.IGNORECASE
                )
                if not looks_rate_limited:
                    raise
                last_exc = e

            # No sleep after the final attempt; the error is raised straight away.
            if attempt < _MAX_RETRIES - 1:
                time.sleep(delay)
                delay *= 2

        raise RuntimeError(f"Anthropic request failed after {_MAX_RETRIES} retries: {last_exc}") from last_exc

src/memory_bench/modes/agentic_rag.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .base import ResponseMode
55
from .rag import RAGMode, _OPEN_SCHEMA, _MCQ_SCHEMA
66
from ..dataset.base import _DEFAULT_OPEN_PROMPT as _OPEN_PROMPT, _DEFAULT_MCQ_PROMPT as _MCQ_PROMPT
7-
from ..llm.base import ToolDef
7+
from ..llm.base import LLM, ToolDef
88
from ..llm.gemini import GeminiLLM
99
from ..memory.base import MemoryProvider
1010
from ..models import AnswerResult
@@ -23,8 +23,10 @@ class AgenticRAGMode(ResponseMode):
2323
name = "agentic-rag"
2424
description = "The LLM acts as an agent with a recall tool and can make multiple retrieval calls with different queries before finalising its answer."
2525

26-
def __init__(self, llm: GeminiLLM | None = None, k: int = 10):
26+
def __init__(self, llm: LLM | None = None, k: int = 10):
2727
self._llm = llm or GeminiLLM()
28+
if type(self._llm).tool_loop is LLM.tool_loop:
29+
raise ValueError(f"{self._llm.model_id} does not support agentic-rag tool calling")
2830
self._rag = RAGMode(llm=self._llm, k=k)
2931
self.k = k
3032

uv.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)