|
6 | 6 | from functools import partial |
7 | 7 | from typing import Optional |
8 | 8 |
|
| 9 | +import anthropic |
9 | 10 | import openai |
10 | 11 | from huggingface_hub import InferenceClient |
11 | 12 | from openai import AzureOpenAI, OpenAI |
@@ -471,3 +472,77 @@ def __init__( |
471 | 472 | client_args={"base_url": "http://0.0.0.0:8000/v1"}, |
472 | 473 | pricing_func=None, |
473 | 474 | ) |
| 475 | + |
| 476 | + |
class AnthropicChatModel(AbstractChatModel):
    """Chat model that sends OpenAI-style messages to the Anthropic Messages API.

    Messages with role ``"system"`` are lifted out of the conversation and
    passed through the API's separate ``system`` argument; every other message
    is forwarded with its ``role`` and ``content`` unchanged.
    """

    def __init__(
        self,
        model_name,
        api_key=None,
        temperature=0.5,
        max_tokens=100,
        max_retry=4,
        log_probs=False,
        retry_delay=60,
    ):
        """Store the sampling settings and create the Anthropic client.

        Args:
            model_name: Model identifier sent as ``model`` in each request.
            api_key: API key; when falsy, the ``ANTHROPIC_API_KEY`` environment
                variable is used instead.
            temperature: Default sampling temperature for calls that do not
                supply their own.
            max_tokens: Value sent as ``max_tokens`` in each request.
            max_retry: Maximum number of API attempts per call. Values below 1
                are treated as 1 so a call always reaches the API.
            log_probs: Stored on the instance; not used by this class.
            retry_delay: Seconds to sleep between failed attempts.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retry = max_retry
        self.log_probs = log_probs
        self.retry_delay = retry_delay

        api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic(api_key=api_key)

    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: Optional[float] = None) -> dict:
        """Send ``messages`` to the model and return its reply.

        Args:
            messages: OpenAI-format messages, each with ``"role"`` and
                ``"content"`` keys.
            n_samples: Accepted for signature compatibility; not used here.
            temperature: Per-call override; ``None`` falls back to the
                instance default.

        Returns:
            An ``AIMessage`` built from the concatenated text blocks of the
            response (an empty string if the response has no text block).

        Raises:
            Exception: Whatever the final API attempt raised, re-raised
                unchanged once all attempts are exhausted.
        """
        # Convert OpenAI format to Anthropic format.
        # NOTE: if several system messages are present, the last one wins.
        system_message = None
        anthropic_messages = []
        for msg in messages:
            if msg["role"] == "system":
                system_message = msg["content"]
            else:
                anthropic_messages.append({"role": msg["role"], "content": msg["content"]})

        if temperature is None:
            temperature = self.temperature

        # The request does not change between attempts, so build it once.
        request = {
            "model": self.model_name,
            "messages": anthropic_messages,
            "max_tokens": self.max_tokens,
            "temperature": temperature,
        }
        if system_message:
            request["system"] = system_message

        # Always try at least once, even if max_retry was configured as 0.
        attempts = max(1, self.max_retry)
        for attempt in range(1, attempts + 1):
            # Only the network call is retried: a failure in tracking or in
            # response parsing must not re-issue an already completed request.
            try:
                response = self.client.messages.create(**request)
            except Exception as e:
                if attempt == attempts:
                    raise
                logging.warning(f"Anthropic API error (attempt {attempt}): {e}")
                time.sleep(self.retry_delay)
            else:
                break

        # Track usage if available
        if hasattr(tracking.TRACKER, "instance"):
            tracking.TRACKER.instance(
                response.usage.input_tokens,
                response.usage.output_tokens,
                0,  # cost calculation would need pricing info
            )

        # Join every text-bearing block instead of assuming the first block
        # exists and is text; blocks without a string ``text`` are skipped.
        text_parts = [
            block.text
            for block in response.content
            if isinstance(getattr(block, "text", None), str)
        ]
        return AIMessage("".join(text_parts))
| 538 | + |
| 539 | + |
@dataclass
class AnthropicModelArgs(BaseModelArgs):
    """Model arguments whose ``make_model`` builds an ``AnthropicChatModel``."""

    def make_model(self):
        """Create an ``AnthropicChatModel`` from the fields on this object.

        ``max_new_tokens`` is passed to the model as ``max_tokens``; the other
        forwarded fields keep their names.
        """
        settings = {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_new_tokens,
            "log_probs": self.log_probs,
        }
        return AnthropicChatModel(**settings)
0 commit comments