Skip to content

Commit 26173bc

Browse files
committed
Add AnthropicChatModel and corresponding test cases for API integration
1 parent 88e949f commit 26173bc

2 files changed

Lines changed: 100 additions & 0 deletions

File tree

src/agentlab/llm/chat_api.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from functools import partial
77
from typing import Optional
88

9+
import anthropic
910
import openai
1011
from huggingface_hub import InferenceClient
1112
from openai import AzureOpenAI, OpenAI
@@ -471,3 +472,77 @@ def __init__(
471472
client_args={"base_url": "http://0.0.0.0:8000/v1"},
472473
pricing_func=None,
473474
)
475+
476+
477+
class AnthropicChatModel(AbstractChatModel):
478+
def __init__(
479+
self,
480+
model_name,
481+
api_key=None,
482+
temperature=0.5,
483+
max_tokens=100,
484+
max_retry=4,
485+
log_probs=False,
486+
):
487+
self.model_name = model_name
488+
self.temperature = temperature
489+
self.max_tokens = max_tokens
490+
self.max_retry = max_retry
491+
self.log_probs = log_probs
492+
493+
api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
494+
self.client = anthropic.Anthropic(api_key=api_key)
495+
496+
def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
497+
# Convert OpenAI format to Anthropic format
498+
system_message = None
499+
anthropic_messages = []
500+
501+
for msg in messages:
502+
if msg["role"] == "system":
503+
system_message = msg["content"]
504+
else:
505+
anthropic_messages.append({"role": msg["role"], "content": msg["content"]})
506+
507+
temperature = temperature if temperature is not None else self.temperature
508+
509+
for attempt in range(self.max_retry):
510+
try:
511+
kwargs = {
512+
"model": self.model_name,
513+
"messages": anthropic_messages,
514+
"max_tokens": self.max_tokens,
515+
"temperature": temperature,
516+
}
517+
518+
if system_message:
519+
kwargs["system"] = system_message
520+
521+
response = self.client.messages.create(**kwargs)
522+
523+
# Track usage if available
524+
if hasattr(tracking.TRACKER, "instance"):
525+
tracking.TRACKER.instance(
526+
response.usage.input_tokens,
527+
response.usage.output_tokens,
528+
0, # cost calculation would need pricing info
529+
)
530+
531+
return AIMessage(response.content[0].text)
532+
533+
except Exception as e:
534+
if attempt == self.max_retry - 1:
535+
raise e
536+
logging.warning(f"Anthropic API error (attempt {attempt + 1}): {e}")
537+
time.sleep(60) # Simple retry delay
538+
539+
540+
@dataclass
541+
class AnthropicModelArgs(BaseModelArgs):
542+
def make_model(self):
543+
return AnthropicChatModel(
544+
model_name=self.model_name,
545+
temperature=self.temperature,
546+
max_tokens=self.max_new_tokens,
547+
log_probs=self.log_probs,
548+
)

tests/llm/test_chat_api.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
from agentlab.llm.chat_api import (
6+
AnthropicModelArgs,
67
AzureModelArgs,
78
OpenAIModelArgs,
89
make_system_message,
@@ -59,3 +60,27 @@ def test_api_model_args_openai():
5960
answer = model(messages)
6061

6162
assert "5" in answer.get("content")
63+
64+
65+
@pytest.mark.pricy
66+
@pytest.mark.skipif(skip_tests, reason="Skipping on remote as Anthropic is pricy")
67+
def test_api_model_args_anthropic():
68+
model_args = AnthropicModelArgs(
69+
model_name="claude-3-haiku-20240307",
70+
max_total_tokens=8192,
71+
max_input_tokens=8192 - 512,
72+
max_new_tokens=512,
73+
temperature=1e-1,
74+
)
75+
model = model_args.make_model()
76+
77+
messages = [
78+
make_system_message("You are an helpful virtual assistant"),
79+
make_user_message("Give the third prime number. Just the number, no explanation."),
80+
]
81+
answer = model(messages)
82+
assert "5" in answer.get("content")
83+
84+
85+
if __name__ == "__main__":
86+
test_api_model_args_anthropic()

0 commit comments

Comments
 (0)