Skip to content

Commit 3b4acd0

Browse files
committed
add pricing info for AnthropicChatModel
1 parent cbe1e7a commit 3b4acd0

1 file changed

Lines changed: 60 additions & 33 deletions

File tree

src/agentlab/llm/chat_api.py

Lines changed: 60 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -331,48 +331,47 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
331331
res_think["log_probs"] = completion.choices[0].logprobs
332332
return res_think, res_action
333333
else:
334-
return [
335-
self._build_think_action_pair(choice)
336-
for choice in completion.choices
337-
]
334+
return [self._build_think_action_pair(choice) for choice in completion.choices]
338335

339-
def _extract_thinking_content_from_response(
    self, response, wrap_tag="think"
) -> tuple[str, str]:
    """Split the first choice of an API response into wrapped reasoning and action.

    The first choice's message is converted to a dict (via ``to_dict()`` when
    available, otherwise ``dict(message)``) and handled in one of three ways:

    1. A ``reasoning_content`` / ``reasoning`` field is present: it becomes the
       think part, and the content (after ``_remove_final_response_tokens``)
       becomes the action part.
    2. No reasoning field, but the content holds both the
       ``[BEGIN FINAL RESPONSE]`` and ``[END FINAL RESPONSE]`` markers: the
       content is split by ``_parse_apriel_format``; an empty part yields an
       empty string instead of an empty tag pair.
    3. Neither applies: the whole content is treated as the action.

    Args:
        response: The API response object; only ``choices[0].message`` is read.
        wrap_tag: Tag name used to wrap the reasoning text (default: "think").

    Returns:
        tuple: (think_wrapped, action_wrapped)
    """
    first_message = response.choices[0].message
    if hasattr(first_message, "to_dict"):
        fields = first_message.to_dict()
    else:
        fields = dict(first_message)

    # Fall back through alternative field names; a missing or None value ends up as "".
    reasoning = fields.get("reasoning_content") or fields.get("reasoning") or ""
    content = fields.get("content", "") or fields.get("text", "") or ""

    # Case 1: the API supplied the reasoning separately.
    if reasoning:
        think_wrapped = f"<{wrap_tag}>{reasoning}</{wrap_tag}>"
        cleaned_action = self._remove_final_response_tokens(content)
        return think_wrapped, f"<action>{cleaned_action}</action>"

    # Case 2: reasoning and final answer share the content field.
    has_begin = "[BEGIN FINAL RESPONSE]" in content
    if has_begin and "[END FINAL RESPONSE]" in content:
        think_text, action_text = self._parse_apriel_format(content)
        think_wrapped = ""
        if think_text:
            think_wrapped = f"<{wrap_tag}>{think_text}</{wrap_tag}>"
        action_wrapped = ""
        if action_text:
            action_wrapped = f"<action>{action_text}</action>"
        return think_wrapped, action_wrapped

    # Case 3: plain content, no reasoning to report.
    if not content:
        return "", ""
    return "", f"<action>{content}</action>"
378377

@@ -383,7 +382,7 @@ def _remove_final_response_tokens(self, content: str) -> str:
383382

384383
def _extract_last_action_from_tags(self, content: str) -> str:
    """Return the text inside the LAST [BEGIN FINAL RESPONSE]...[END FINAL RESPONSE] block.

    The captured text is stripped of surrounding whitespace. An empty string
    is returned when ``content`` holds no complete block.
    """
    final_block = r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]"
    last_body = None
    # DOTALL lets a block span several lines; keep only the final match.
    for found in re.finditer(final_block, content, re.DOTALL):
        last_body = found.group(1)
    return "" if last_body is None else last_body.strip()
389388

@@ -392,20 +391,18 @@ def _parse_apriel_format(self, content: str) -> tuple[str, str]:
392391
last_begin = content.rfind("[BEGIN FINAL RESPONSE]")
393392
if last_begin == -1:
394393
return "", content
395-
394+
396395
reasoning = content[:last_begin].strip()
397396
if reasoning.startswith("Here are my reasoning steps:"):
398-
reasoning = reasoning[len("Here are my reasoning steps:"):].strip()
399-
397+
reasoning = reasoning[len("Here are my reasoning steps:") :].strip()
398+
400399
action = self._extract_last_action_from_tags(content)
401400
return reasoning, action
402401

403402
def _build_think_action_pair(self, choice) -> tuple[AIMessage, AIMessage]:
    """Turn one completion choice into a (think, action) pair of AIMessages.

    Args:
        choice: A single entry of a completion's ``choices`` list.

    Returns:
        tuple: (think message, action message); a missing part becomes an
        AIMessage built from an empty string.
    """
    # The extraction method only reads `.choices[0]`, so a throwaway object
    # exposing a one-element `choices` list is enough to reuse it here.
    single_choice_response = type("MockResponse", (), {"choices": [choice]})()
    think_text, action_text = self._extract_thinking_content_from_response(
        single_choice_response
    )
    think_message = AIMessage(think_text or "")
    action_message = AIMessage(action_text or "")
    return think_message, action_message
411408

@@ -575,12 +572,9 @@ def __init__(
575572
max_retry=4,
576573
min_retry_wait_time=60,
577574
):
578-
base_url = base_url or os.getenv(
579-
"APRIEL_API_URL",
580-
""
581-
)
575+
base_url = base_url or os.getenv("APRIEL_API_URL", "")
582576
api_key = api_key or os.getenv("APRIEL_API_KEY")
583-
577+
584578
super().__init__(
585579
model_name=model_name,
586580
api_key=api_key,
@@ -597,7 +591,7 @@ def __init__(
597591
@dataclass
598592
class AprielModelArgs(BaseModelArgs):
599593
"""Serializable args for Apriel models."""
600-
594+
601595
base_url: str = None
602596
api_key: str = None
603597

@@ -619,6 +613,7 @@ def __init__(
619613
temperature=0.5,
620614
max_tokens=100,
621615
max_retry=4,
616+
pricing_func=None,
622617
):
623618
self.model_name = model_name
624619
self.temperature = temperature
@@ -628,6 +623,22 @@ def __init__(
628623
api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
629624
self.client = anthropic.Anthropic(api_key=api_key)
630625

626+
# Get pricing information
627+
if pricing_func:
628+
pricings = pricing_func()
629+
try:
630+
self.input_cost = float(pricings[model_name]["prompt"])
631+
self.output_cost = float(pricings[model_name]["completion"])
632+
except KeyError:
633+
logging.warning(
634+
f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
635+
)
636+
self.input_cost = 0.0
637+
self.output_cost = 0.0
638+
else:
639+
self.input_cost = 0.0
640+
self.output_cost = 0.0
641+
631642
def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
632643
# Convert OpenAI format to Anthropic format
633644
system_message = None
@@ -655,13 +666,28 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
655666

656667
response = self.client.messages.create(**kwargs)
657668

669+
usage = getattr(response, "usage", {})
670+
new_input_tokens = getattr(usage, "input_tokens", 0)
671+
output_tokens = getattr(usage, "output_tokens", 0)
672+
cache_read_tokens = getattr(usage, "cache_input_tokens", 0)
673+
cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0)
674+
cache_read_cost = (
675+
self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"]
676+
)
677+
cache_write_cost = (
678+
self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"]
679+
)
680+
cost = (
681+
new_input_tokens * self.input_cost
682+
+ output_tokens * self.output_cost
683+
+ cache_read_tokens * cache_read_cost
684+
+ cache_write_tokens * cache_write_cost
685+
)
658686
# Track usage if available
659-
if hasattr(tracking.TRACKER, "instance"):
660-
tracking.TRACKER.instance(
661-
response.usage.input_tokens,
662-
response.usage.output_tokens,
663-
0, # cost calculation would need pricing info
664-
)
687+
if hasattr(tracking.TRACKER, "instance") and isinstance(
688+
tracking.TRACKER.instance, tracking.LLMTracker
689+
):
690+
tracking.TRACKER.instance(new_input_tokens, output_tokens, cost)
665691

666692
return AIMessage(response.content[0].text)
667693

@@ -679,6 +705,7 @@ def make_model(self):
679705
model_name=self.model_name,
680706
temperature=self.temperature,
681707
max_tokens=self.max_new_tokens,
708+
pricing_func=partial(tracking.get_pricing_litellm, model_name=self.model_name),
682709
)
683710

684711

0 commit comments

Comments
 (0)