Skip to content

Commit b22b397

Browse files
authored
Update models.py
1 parent 32e211b commit b22b397

1 file changed

Lines changed: 50 additions & 37 deletions

File tree

sweagent/agent/models.py

Lines changed: 50 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -676,9 +676,7 @@ def _sleep(self) -> None:
676676
with GLOBAL_STATS_LOCK:
677677
GLOBAL_STATS.last_query_timestamp = time.time()
678678

679-
def _single_query(
680-
self, messages: list[dict[str, str]], n: int | None = None, temperature: float | None = None
681-
) -> list[dict]:
679+
def _single_query(self, messages: list[dict[str, str]], n: int | None = None, temperature: float | None = None) -> list[dict]:
682680
self._sleep()
683681
# Workaround for litellm bug https://github.com/SWE-agent/SWE-agent/issues/1109
684682
messages_no_cache_control = copy.deepcopy(messages)
@@ -687,16 +685,9 @@ def _single_query(
687685
del message["cache_control"]
688686
if "thinking_blocks" in message:
689687
del message["thinking_blocks"]
690-
input_tokens: int = litellm.utils.token_counter(
691-
messages=messages_no_cache_control,
692-
model=self.custom_tokenizer["identifier"] if self.custom_tokenizer is not None else self.config.name,
693-
custom_tokenizer=self.custom_tokenizer,
694-
)
688+
input_tokens: int = litellm.utils.token_counter(messages=messages_no_cache_control, model=self.custom_tokenizer["identifier"] if self.custom_tokenizer is not None else self.config.name, custom_tokenizer=self.custom_tokenizer)
695689
if self.model_max_input_tokens is None:
696-
msg = (
697-
f"No max input tokens found for model {self.config.name!r}. "
698-
"If you are using a local model, you can set `max_input_token` in the model config to override this."
699-
)
690+
msg = (f"No max input tokens found for model {self.config.name!r}. If you are using a local model, you can set `max_input_token` in the model config to override this.")
700691
self.logger.warning(msg)
701692
elif input_tokens > self.model_max_input_tokens > 0:
702693
msg = f"Input tokens {input_tokens} exceed max tokens {self.model_max_input_tokens}"
@@ -707,13 +698,18 @@ def _single_query(
707698
extra_args["api_base"] = self.config.api_base
708699
if self.tools.use_function_calling:
709700
extra_args["tools"] = self.tools.tools
710-
# We need to always set max_tokens for anthropic models
701+
self.logger.info(f"api_base:{extra_args['api_base']}")
711702
completion_kwargs = self.config.completion_kwargs
703+
completion_kwargs['extra_headers'] = {'anthropic-beta': 'output-128k-2025-02-19'}
704+
model_name = self.config.name
705+
if "deepseek" or "glm" in self.config.name.lower():
706+
completion_kwargs.pop('reasoning_effort', None)
707+
self.logger.debug(f"Using official DeepSeek format: {model_name}")
712708
if self.lm_provider == "anthropic":
713709
completion_kwargs["max_tokens"] = self.model_max_output_tokens
714710
try:
715-
response: litellm.types.utils.ModelResponse = litellm.completion( # type: ignore
716-
model=self.config.name,
711+
response: litellm.types.utils.ModelResponse = litellm.completion(
712+
model=model_name,
717713
messages=messages,
718714
temperature=self.config.temperature if temperature is None else temperature,
719715
top_p=self.config.top_p,
@@ -732,43 +728,60 @@ def _single_query(
732728
if "is longer than the model's context length" in str(e):
733729
raise ContextWindowExceededError from e
734730
raise
735-
self.logger.debug(f"Response: {response}")
736731
try:
737732
cost = litellm.cost_calculator.completion_cost(response, model=self.config.name)
738733
except Exception as e:
739-
self.logger.debug(f"Error calculating cost: {e}, setting cost to 0.")
734+
self.logger.debug(f"Error calculating cost: {e}, attempting fallback cost calculation.")
735+
fallback_cost = 0
736+
if "deepseek" in self.config.name.lower():
737+
try:
738+
fallback_models = ["deepseek-chat", "deepseek/deepseek-chat"]
739+
for fallback_model in fallback_models:
740+
try:
741+
fallback_cost = litellm.cost_calculator.completion_cost(response, model=fallback_model)
742+
self.logger.info(f"Using fallback model '{fallback_model}' for cost calculation: ${fallback_cost:.6f}")
743+
break
744+
except Exception:
745+
continue
746+
if fallback_cost == 0:
747+
input_tokens = litellm.utils.token_counter(messages=messages_no_cache_control, model=self.config.name)
748+
output_tokens = sum(litellm.utils.token_counter(text=choice.message.content or "", model=self.config.name) for choice in response.choices)
749+
fallback_cost = (input_tokens * 1.0 + output_tokens * 3.0) / 1000000
750+
self.logger.info(f"Using estimated DeepSeek pricing for cost calculation: ${fallback_cost:.6f}")
751+
except Exception as fallback_error:
752+
self.logger.debug(f"Fallback cost calculation also failed: {fallback_error}")
753+
fallback_cost = 0
740754
if self.config.per_instance_cost_limit > 0 or self.config.total_cost_limit > 0:
741-
msg = (
742-
f"Error calculating cost: {e} for your model {self.config.name}. If this is ok "
743-
"(local models, etc.), please make sure you set `per_instance_cost_limit` and "
744-
"`total_cost_limit` to 0 to disable this safety check."
745-
)
746-
self.logger.error(msg)
747-
raise ModelConfigurationError(msg)
748-
cost = 0
755+
if fallback_cost == 0:
756+
msg = (
757+
f"Error calculating cost: {e} for your model {self.config.name}. "
758+
f"Fallback cost calculation also failed. If this is ok (local models, etc.), "
759+
f"please set `per_instance_cost_limit` and `total_cost_limit` to 0 to disable this safety check."
760+
)
761+
self.logger.error(msg)
762+
raise ModelConfigurationError(msg)
763+
else:
764+
self.logger.warning(f"Using fallback cost calculation due to: {e}")
765+
cost = fallback_cost
766+
else:
767+
cost = fallback_cost
768+
749769
choices: litellm.types.utils.Choices = response.choices # type: ignore
750770
n_choices = n if n is not None else 1
751771
outputs = []
752772
output_tokens = 0
753773
for i in range(n_choices):
754774
output = choices[i].message.content or ""
755-
output_tokens += litellm.utils.token_counter(
756-
text=output,
757-
model=self.custom_tokenizer["identifier"] if self.custom_tokenizer is not None else self.config.name,
758-
custom_tokenizer=self.custom_tokenizer,
759-
)
775+
output_tokens += litellm.utils.token_counter(text=output, model=self.custom_tokenizer["identifier"] if self.custom_tokenizer is not None else self.config.name, custom_tokenizer=self.custom_tokenizer)
760776
output_dict = {"message": output}
761777
if self.tools.use_function_calling:
762-
if response.choices[i].message.tool_calls: # type: ignore
763-
tool_calls = [call.to_dict() for call in response.choices[i].message.tool_calls] # type: ignore
778+
if response.choices[i].message.tool_calls:
779+
tool_calls = [call.to_dict() for call in response.choices[i].message.tool_calls]
764780
else:
765781
tool_calls = []
766782
output_dict["tool_calls"] = tool_calls
767-
if (
768-
hasattr(response.choices[i].message, "thinking_blocks") # type: ignore
769-
and response.choices[i].message.thinking_blocks # type: ignore
770-
):
771-
output_dict["thinking_blocks"] = response.choices[i].message.thinking_blocks # type: ignore
783+
if (hasattr(response.choices[i].message, "thinking_blocks") and response.choices[i].message.thinking_blocks):
784+
output_dict["thinking_blocks"] = response.choices[i].message.thinking_blocks
772785
outputs.append(output_dict)
773786
self._update_stats(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost)
774787
return outputs

0 commit comments

Comments
 (0)