Commit 0e07fc6
Author: zhangchi47
Parent: 45fa380

feat: support Anthropic-compatible endpoints for benchmark LLMs

6 files changed: 212 additions, 10 deletions

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ version = "0.1.0"
 description = "Open Memory Benchmark"
 requires-python = ">=3.11"
 dependencies = [
+    "anthropic>=0.84.0",
     "datasets>=2.0",
     "typer>=0.12",
     "rich>=13",

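The only dependency change is the anthropic pin. As a quick smoke test (not part of the commit; it assumes ANTHROPIC_API_KEY is exported and uses the same default model the new provider class picks), the SDK it pulls in can be exercised directly:

    from anthropic import Anthropic

    client = Anthropic()  # reads ANTHROPIC_API_KEY from the environment
    message = client.messages.create(
        model="claude-sonnet-4-5",
        max_tokens=64,
        messages=[{"role": "user", "content": "Reply with the single word: ok"}],
    )
    print(message.content[0].text)
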
src/memory_bench/cli.py

Lines changed: 59 additions & 8 deletions

@@ -12,7 +12,7 @@
 load_dotenv(dotenv_path=Path(__file__).parents[2] / ".env", override=True)

 from .dataset import REGISTRY as DATASET_REGISTRY, get_dataset
-from .llm import REGISTRY as LLM_REGISTRY, get_llm, get_answer_llm
+from .llm import REGISTRY as LLM_REGISTRY, get_answer_llm
 from .memory import REGISTRY as MEMORY_REGISTRY, get_memory_provider
 from .modes import REGISTRY as MODE_REGISTRY, get_mode
 from .runner import EvalRunner
@@ -22,12 +22,63 @@
 console = Console()


-def _resolve_gemini_key() -> None:
-    key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
-    if not key:
-        typer.echo("Error: GEMINI_API_KEY environment variable is not set.", err=True)
+def _ensure_provider_env(provider: str, role: str) -> None:
+    if provider not in LLM_REGISTRY:
+        typer.echo(
+            f"Error: unknown {role.lower()} LLM provider '{provider}'. Available: {', '.join(LLM_REGISTRY)}.",
+            err=True,
+        )
         raise typer.Exit(1)
-    os.environ["GOOGLE_API_KEY"] = key
+
+    if provider == "anthropic":
+        if not os.environ.get("ANTHROPIC_API_KEY"):
+            typer.echo(f"Error: {role} LLM provider '{provider}' requires ANTHROPIC_API_KEY.", err=True)
+            raise typer.Exit(1)
+        return
+
+    if provider == "gemini":
+        key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
+        if not key:
+            typer.echo(f"Error: {role} LLM provider '{provider}' requires GEMINI_API_KEY.", err=True)
+            raise typer.Exit(1)
+        os.environ["GOOGLE_API_KEY"] = key
+        return
+
+    if provider == "groq":
+        if not os.environ.get("GROQ_API_KEY"):
+            typer.echo(f"Error: {role} LLM provider '{provider}' requires GROQ_API_KEY.", err=True)
+            raise typer.Exit(1)
+        return
+
+    if provider == "openai":
+        if not os.environ.get("OPENAI_API_KEY"):
+            typer.echo(f"Error: {role} LLM provider '{provider}' requires OPENAI_API_KEY.", err=True)
+            raise typer.Exit(1)
+        return
+
+
+def _validate_run_env(memory: str, mode: str, answer_provider: str | None = None) -> None:
+    if answer_provider is not None:
+        os.environ["OMB_ANSWER_LLM"] = answer_provider
+
+    answer_provider = os.environ.get("OMB_ANSWER_LLM", "groq")
+    judge_provider = os.environ.get("OMB_JUDGE_LLM", "gemini")
+    _ensure_provider_env(answer_provider, "Answer")
+    _ensure_provider_env(judge_provider, "Judge")
+
+    if mode == "agentic-rag" and answer_provider != "gemini":
+        typer.echo(
+            f"Error: response mode 'agentic-rag' requires a tool-capable LLM provider; '{answer_provider}' is not supported.",
+            err=True,
+        )
+        raise typer.Exit(1)
+
+    if memory == "hindsight":
+        key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
+        if not key:
+            typer.echo("Error: memory provider 'hindsight' requires GEMINI_API_KEY for embedded extraction.", err=True)
+            raise typer.Exit(1)
+        os.environ["GOOGLE_API_KEY"] = key


 @app.command()
@@ -36,7 +87,7 @@ def run(
     dataset: str = typer.Option("tempo", "--dataset", help=f"Dataset. Available: {', '.join(DATASET_REGISTRY)}"),
     memory: str = typer.Option("bm25", "--memory", "-m", help=f"Memory provider. Available: {', '.join(MEMORY_REGISTRY)}"),
     mode: str = typer.Option("rag", "--mode", help=f"Response mode. Available: {', '.join(MODE_REGISTRY)}"),
-    llm: str = typer.Option("gemini", "--llm", help=f"LLM for answer generation. Available: {', '.join(LLM_REGISTRY)}"),
+    llm: str | None = typer.Option(None, "--llm", help=f"LLM provider for answer generation. Overrides OMB_ANSWER_LLM. Available: {', '.join(LLM_REGISTRY)}"),
     category: str = typer.Option(None, "--category", "-c", help="Category filter(s), comma-separated (e.g. 'a,b,c'). With --query-limit, runs N queries per category."),
     query_limit: int = typer.Option(None, "--query-limit", "-q", help="Max queries to evaluate. When combined with multiple --category values, applies per category."),
     query_id: str = typer.Option(None, "--query-id", help="Run a single specific query by ID"),
@@ -53,7 +104,7 @@ def run(
     description: str = typer.Option(None, "--description", "-d", help="Optional description for this run (stored in the result JSON)"),
 ) -> None:
     """Run an evaluation on a single split (optionally filtered to a category)."""
-    _resolve_gemini_key()
+    _validate_run_env(memory, mode, llm)

     ds = get_dataset(dataset)
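
Provider selection now flows through environment variables: a --llm value is written to OMB_ANSWER_LLM, the judge provider comes from OMB_JUDGE_LLM (default gemini), and _ensure_provider_env checks the matching API key for each. A minimal sketch using only the functions added in this diff; the key value is a placeholder:

    import os

    os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder, not a real key
    os.environ["OMB_JUDGE_LLM"] = "anthropic"       # route the judge through Anthropic too

    from memory_bench.cli import _validate_run_env

    # answer_provider (the --llm value) is written back to OMB_ANSWER_LLM
    # before both the answer and judge providers are validated.
    _validate_run_env(memory="bm25", mode="rag", answer_provider="anthropic")
    assert os.environ["OMB_ANSWER_LLM"] == "anthropic"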

src/memory_bench/llm/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,11 +1,13 @@
 import os

+from .anthropic import AnthropicLLM
 from .base import LLM, Schema
 from .gemini import GeminiLLM
 from .groq import GroqLLM
 from .openai import OpenAILLM

 REGISTRY: dict[str, type[LLM]] = {
+    "anthropic": AnthropicLLM,
     "gemini": GeminiLLM,
     "groq": GroqLLM,
     "openai": OpenAILLM,

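get_answer_llm itself is unchanged and its body is not shown here, so the exact resolution is an assumption; but with the registry entry in place, the straightforward path is a name lookup followed by instantiation:

    # Illustrative sketch only: get_answer_llm's internals are not in this diff.
    import os

    from memory_bench.llm import REGISTRY

    provider = os.environ.get("OMB_ANSWER_LLM", "groq")
    llm = REGISTRY[provider]()  # e.g. AnthropicLLM for "anthropic"
    print(llm.model_id)         # e.g. "anthropic:claude-sonnet-4-5"
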
src/memory_bench/llm/anthropic.py

Lines changed: 142 additions & 0 deletions

@@ -0,0 +1,142 @@
+import json
+import os
+import re
+import time
+
+from .base import LLM, Schema
+
+_MAX_RETRIES = 6
+_RETRY_BASE_DELAY = 5
+
+
+def _parse_json_payload(text: str) -> dict:
+    text = text.strip()
+
+    try:
+        return json.loads(text)
+    except json.JSONDecodeError:
+        pass
+
+    fenced = re.search(r"```(?:json)?\s*(\{.*\})\s*```", text, flags=re.DOTALL | re.IGNORECASE)
+    if fenced:
+        return json.loads(fenced.group(1))
+
+    start = text.find("{")
+    end = text.rfind("}")
+    if start != -1 and end != -1 and end > start:
+        return json.loads(text[start : end + 1])
+
+    raise json.JSONDecodeError("Could not find JSON object in model response", text, 0)
+
+
+def _coerce_text_payload(text: str, schema: Schema) -> dict | None:
+    text = text.strip()
+    if not text:
+        return None
+
+    result: dict = {}
+    for field in schema.required:
+        spec = schema.properties.get(field, {})
+        field_type = spec.get("type", "string")
+        lowered = text.lower()
+
+        if field_type == "string":
+            result[field] = text
+            continue
+
+        if field_type == "boolean":
+            if re.search(r"\btrue\b", lowered):
+                result[field] = True
+                continue
+            if re.search(r"\bfalse\b", lowered):
+                result[field] = False
+                continue
+            if re.search(r"\b(correct|yes)\b", lowered) and not re.search(r"\b(incorrect|wrong|no)\b", lowered):
+                result[field] = True
+                continue
+            if re.search(r"\b(incorrect|wrong|no)\b", lowered):
+                result[field] = False
+                continue
+            return None
+
+        return None
+
+    return result
+
+
+class AnthropicLLM(LLM):
+    def __init__(self, model: str | None = None):
+        from anthropic import Anthropic
+
+        api_key = os.environ.get("ANTHROPIC_API_KEY")
+        if not api_key:
+            raise RuntimeError("Anthropic provider requires ANTHROPIC_API_KEY")
+
+        base_url = os.environ.get("ANTHROPIC_BASE_URL")
+        self._client = Anthropic(
+            api_key=api_key,
+            base_url=base_url or None,
+            max_retries=0,
+        )
+        self._model = (
+            model
+            or os.environ.get("ANTHROPIC_MODEL")
+            or "claude-sonnet-4-5"
+        )
+
+    @property
+    def model_id(self) -> str:
+        return f"anthropic:{self._model}"
+
+    def generate(self, prompt: str, schema: Schema) -> dict:
+        from anthropic import APIConnectionError, APIStatusError, RateLimitError
+
+        schema_json = {
+            "type": "object",
+            "properties": schema.properties,
+            "required": schema.required,
+            "additionalProperties": False,
+        }
+        system_prompt = (
+            "Return only a valid JSON object matching this schema. "
+            "Do not wrap JSON in markdown fences.\n\n"
+            f"{json.dumps(schema_json, ensure_ascii=False)}"
+        )
+
+        delay = _RETRY_BASE_DELAY
+        last_exc = None
+
+        for attempt in range(_MAX_RETRIES):
+            try:
+                response = self._client.messages.create(
+                    model=self._model,
+                    max_tokens=4096,
+                    temperature=0.0,
+                    system=system_prompt,
+                    messages=[{"role": "user", "content": prompt}],
+                )
+                text = "".join(block.text for block in response.content if getattr(block, "type", None) == "text")
+                try:
+                    return _parse_json_payload(text)
+                except json.JSONDecodeError:
+                    coerced = _coerce_text_payload(text, schema)
+                    if coerced is not None:
+                        return coerced
+                    raise
+            except (RateLimitError, APIConnectionError) as e:
+                last_exc = e
+            except APIStatusError as e:
+                last_exc = e
+                if e.status_code not in (429, 500, 502, 503, 504):
+                    raise
+            except Exception as e:
+                last_exc = e
+                msg = str(e)
+                if "429" not in msg and "rate" not in msg.lower():
+                    raise
+
+            if attempt < _MAX_RETRIES - 1:
+                time.sleep(delay)
+                delay *= 2
+
+        raise RuntimeError(f"Anthropic request failed after {_MAX_RETRIES} retries: {last_exc}")
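
The JSON recovery in _parse_json_payload is three-tiered: bare json.loads, then a fenced-block regex, then a first-{ to last-} slice. The assertions below are not part of the commit; they simply exercise the helper as written:

    from memory_bench.llm.anthropic import _parse_json_payload

    assert _parse_json_payload('{"answer": "42"}') == {"answer": "42"}                # bare JSON
    assert _parse_json_payload('```json\n{"answer": "42"}\n```') == {"answer": "42"}  # fenced
    assert _parse_json_payload('Sure: {"answer": "42"}.') == {"answer": "42"}         # embedded

On failure handling, _RETRY_BASE_DELAY = 5 doubled across _MAX_RETRIES = 6 attempts gives back-offs of 5, 10, 20, 40, and 80 seconds before the final RuntimeError.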

src/memory_bench/modes/agentic_rag.py

Lines changed: 4 additions & 2 deletions

@@ -4,7 +4,7 @@
 from .base import ResponseMode
 from .rag import RAGMode, _OPEN_SCHEMA, _MCQ_SCHEMA
 from ..dataset.base import _DEFAULT_OPEN_PROMPT as _OPEN_PROMPT, _DEFAULT_MCQ_PROMPT as _MCQ_PROMPT
-from ..llm.base import ToolDef
+from ..llm.base import LLM, ToolDef
 from ..llm.gemini import GeminiLLM
 from ..memory.base import MemoryProvider
 from ..models import AnswerResult
@@ -23,8 +23,10 @@ class AgenticRAGMode(ResponseMode):
     name = "agentic-rag"
     description = "The LLM acts as an agent with a recall tool and can make multiple retrieval calls with different queries before finalising its answer."

-    def __init__(self, llm: GeminiLLM | None = None, k: int = 10):
+    def __init__(self, llm: LLM | None = None, k: int = 10):
         self._llm = llm or GeminiLLM()
+        if type(self._llm).tool_loop is LLM.tool_loop:
+            raise ValueError(f"{self._llm.model_id} does not support agentic-rag tool calling")
         self._rag = RAGMode(llm=self._llm, k=k)
         self.k = k
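
The new guard relies on method identity: if an LLM subclass never overrides tool_loop, looking the attribute up on its class yields the base-class function itself, so the mode can reject providers without real tool support before any query runs. A self-contained toy of the pattern (Base, NoTools, and WithTools are illustrative stand-ins, not project classes):

    class Base:
        def tool_loop(self):
            raise NotImplementedError

    class NoTools(Base):
        pass

    class WithTools(Base):
        def tool_loop(self):
            return "ran tool loop"

    assert type(NoTools()).tool_loop is Base.tool_loop        # no override: rejected
    assert type(WithTools()).tool_loop is not Base.tool_loop  # override: accepted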

uv.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default.
