@@ -331,48 +331,47 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
331331 res_think ["log_probs" ] = completion .choices [0 ].logprobs
332332 return res_think , res_action
333333 else :
334- return [
335- self ._build_think_action_pair (choice )
336- for choice in completion .choices
337- ]
334+ return [self ._build_think_action_pair (choice ) for choice in completion .choices ]
338335
339- def _extract_thinking_content_from_response (self , response , wrap_tag = "think" ) -> tuple [str , str ]:
336+ def _extract_thinking_content_from_response (
337+ self , response , wrap_tag = "think"
338+ ) -> tuple [str , str ]:
340339 """Extract reasoning and action content from an API response.
341-
340+
342341 Logic:
343- 1. If reasoning_content exists: use it as think, use content as action
342+ 1. If reasoning_content exists: use it as think, use content as action
344343 (remove BEGIN/END FINAL RESPONSE tokens if present, add action tags)
345344 2. If reasoning_content is empty: search content for last BEGIN/END FINAL RESPONSE block,
346345 use everything before as think, use content inside tags as action
347-
346+
348347 Args:
349348 response: The API response object.
350349 wrap_tag: Tag name to wrap reasoning content (default: "think").
351-
350+
352351 Returns:
353352 tuple: (think_wrapped, action_wrapped)
354353 """
355354 message = response .choices [0 ].message
356- msg_dict = message .to_dict () if hasattr (message , ' to_dict' ) else dict (message )
357-
355+ msg_dict = message .to_dict () if hasattr (message , " to_dict" ) else dict (message )
356+
358357 reasoning = msg_dict .get ("reasoning_content" ) or msg_dict .get ("reasoning" ) or ""
359358 content = msg_dict .get ("content" , "" ) or msg_dict .get ("text" , "" ) or ""
360-
359+
361360 # Case 1: Explicit reasoning field from API
362361 if reasoning :
363362 think_wrapped = f"<{ wrap_tag } >{ reasoning } </{ wrap_tag } >"
364363 # Remove BEGIN/END FINAL RESPONSE tokens from content if present
365364 action_text = self ._remove_final_response_tokens (content )
366365 action_wrapped = f"<action>{ action_text } </action>"
367366 return think_wrapped , action_wrapped
368-
367+
369368 # Case 2: No reasoning field - parse content for BEGIN/END FINAL RESPONSE
370369 if "[BEGIN FINAL RESPONSE]" in content and "[END FINAL RESPONSE]" in content :
371370 think_text , action_text = self ._parse_apriel_format (content )
372371 think_wrapped = f"<{ wrap_tag } >{ think_text } </{ wrap_tag } >" if think_text else ""
373372 action_wrapped = f"<action>{ action_text } </action>" if action_text else ""
374373 return think_wrapped , action_wrapped
375-
374+
376375 # Case 3: No special format - return content as action
377376 return "" , f"<action>{ content } </action>" if content else ""
378377
@@ -383,7 +382,7 @@ def _remove_final_response_tokens(self, content: str) -> str:
383382
384383 def _extract_last_action_from_tags (self , content : str ) -> str :
385384 """Extract content from the LAST [BEGIN FINAL RESPONSE]...[END FINAL RESPONSE] block."""
386- pattern = r' \[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]'
385+ pattern = r" \[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]"
387386 matches = re .findall (pattern , content , re .DOTALL )
388387 return matches [- 1 ].strip () if matches else ""
389388
@@ -392,20 +391,18 @@ def _parse_apriel_format(self, content: str) -> tuple[str, str]:
392391 last_begin = content .rfind ("[BEGIN FINAL RESPONSE]" )
393392 if last_begin == - 1 :
394393 return "" , content
395-
394+
396395 reasoning = content [:last_begin ].strip ()
397396 if reasoning .startswith ("Here are my reasoning steps:" ):
398- reasoning = reasoning [len ("Here are my reasoning steps:" ):].strip ()
399-
397+ reasoning = reasoning [len ("Here are my reasoning steps:" ) :].strip ()
398+
400399 action = self ._extract_last_action_from_tags (content )
401400 return reasoning , action
402401
def _build_think_action_pair(self, choice) -> tuple[AIMessage, AIMessage]:
    """Turn one completion choice into a (think, action) pair of AIMessages.

    The extraction method expects a response-like object with a ``choices``
    list, so the single choice is wrapped in a throwaway class instance that
    exposes exactly that attribute.
    """
    # Ad-hoc stand-in for a full response: only `choices` is provided.
    response_cls = type("MockResponse", (), {"choices": [choice]})
    think_text, action_text = self._extract_thinking_content_from_response(response_cls())
    # Fall back to "" so AIMessage never receives a falsy non-string value.
    think_msg = AIMessage(think_text or "")
    action_msg = AIMessage(action_text or "")
    return think_msg, action_msg
411408
@@ -575,12 +572,9 @@ def __init__(
575572 max_retry = 4 ,
576573 min_retry_wait_time = 60 ,
577574 ):
578- base_url = base_url or os .getenv (
579- "APRIEL_API_URL" ,
580- ""
581- )
575+ base_url = base_url or os .getenv ("APRIEL_API_URL" , "" )
582576 api_key = api_key or os .getenv ("APRIEL_API_KEY" )
583-
577+
584578 super ().__init__ (
585579 model_name = model_name ,
586580 api_key = api_key ,
@@ -597,7 +591,7 @@ def __init__(
597591@dataclass
598592class AprielModelArgs (BaseModelArgs ):
599593 """Serializable args for Apriel models."""
600-
594+
601595 base_url : str = None
602596 api_key : str = None
603597
@@ -619,6 +613,7 @@ def __init__(
619613 temperature = 0.5 ,
620614 max_tokens = 100 ,
621615 max_retry = 4 ,
616+ pricing_func = None ,
622617 ):
623618 self .model_name = model_name
624619 self .temperature = temperature
@@ -628,6 +623,22 @@ def __init__(
628623 api_key = api_key or os .getenv ("ANTHROPIC_API_KEY" )
629624 self .client = anthropic .Anthropic (api_key = api_key )
630625
626+ # Get pricing information
627+ if pricing_func :
628+ pricings = pricing_func ()
629+ try :
630+ self .input_cost = float (pricings [model_name ]["prompt" ])
631+ self .output_cost = float (pricings [model_name ]["completion" ])
632+ except KeyError :
633+ logging .warning (
634+ f"Model { model_name } not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
635+ )
636+ self .input_cost = 0.0
637+ self .output_cost = 0.0
638+ else :
639+ self .input_cost = 0.0
640+ self .output_cost = 0.0
641+
631642 def __call__ (self , messages : list [dict ], n_samples : int = 1 , temperature : float = None ) -> dict :
632643 # Convert OpenAI format to Anthropic format
633644 system_message = None
@@ -655,13 +666,28 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
655666
656667 response = self .client .messages .create (** kwargs )
657668
# Read token counts defensively: the response may lack `usage`, and
# individual fields may be missing or None, so every count falls back to 0
# to keep the cost arithmetic below from raising.
usage = getattr(response, "usage", None)
new_input_tokens = getattr(usage, "input_tokens", 0) or 0
output_tokens = getattr(usage, "output_tokens", 0) or 0
# Anthropic reports cache hits as `cache_read_input_tokens` and cache
# writes as `cache_creation_input_tokens`.
cache_read_tokens = getattr(usage, "cache_read_input_tokens", 0) or 0
cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0) or 0

# Cached tokens are priced as a multiple of the regular input price.
cache_read_cost = (
    self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"]
)
cache_write_cost = (
    self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"]
)
cost = (
    new_input_tokens * self.input_cost
    + output_tokens * self.output_cost
    + cache_read_tokens * cache_read_cost
    + cache_write_tokens * cache_write_cost
)

# Track usage only when an LLMTracker is actually installed.
if hasattr(tracking.TRACKER, "instance") and isinstance(
    tracking.TRACKER.instance, tracking.LLMTracker
):
    tracking.TRACKER.instance(new_input_tokens, output_tokens, cost)
665691
666692 return AIMessage (response .content [0 ].text )
667693
@@ -679,6 +705,7 @@ def make_model(self):
679705 model_name = self .model_name ,
680706 temperature = self .temperature ,
681707 max_tokens = self .max_new_tokens ,
708+ pricing_func = partial (tracking .get_pricing_litellm , model_name = self .model_name ),
682709 )
683710
684711
0 commit comments