
Commit bf6954f

lexgeniusclaude and Claude Sonnet 4.6 committed
fix: bypass LiteLLM for Ollama embeddings to resolve 400 Bad Request
LiteLLM's Ollama embedding handler sends a malformed request to Ollama's /api/embed endpoint, causing a 400 Bad Request error on Ollama 0.18.x.

- Add `_ollama_embed()` to `LiteLLMEmbeddingWrapper` that calls Ollama's `/api/embed` directly via httpx, stripping the "ollama/" prefix from the model name (the root cause of the malformed request)
- Route `embed_query` and `embed_documents` through this helper when provider == "ollama", bypassing LiteLLM entirely
- Wrap `search_similarity_threshold` in try/except so an embedding failure returns [] instead of crashing the agent

Fixes #1425

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
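
For context, a minimal sketch of the direct request this commit switches to (host, port, and model name are illustrative; Ollama's /api/embed accepts a model plus a list of inputs and returns one vector per input):

    import httpx

    # Note: no "ollama/" prefix on the model name. Per the commit message,
    # forwarding LiteLLM's prefixed name is the root cause of the 400.
    resp = httpx.post(
        "http://localhost:11434/api/embed",
        json={"model": "nomic-embed-text", "input": ["hello", "world"]},
        timeout=120.0,
    )
    resp.raise_for_status()
    embeddings = resp.json()["embeddings"]  # one vector per input string
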
1 parent b5e110a commit bf6954f

2 files changed

Lines changed: 208 additions & 60 deletions

File tree

models.py

Lines changed: 175 additions & 41 deletions
@@ -26,7 +26,9 @@
 from helpers.rate_limiter import RateLimiter
 from helpers.tokens import approximate_tokens
 from helpers import dirty_json
-from helpers.extension import extensible # extensible: allows plugins to intercept get_api_key()
+from helpers.extension import (
+    extensible,
+)  # extensible: allows plugins to intercept get_api_key()
 
 from langchain_core.language_models.chat_models import SimpleChatModel
 from langchain_core.outputs.chat_generation import ChatGenerationChunk
@@ -59,6 +61,7 @@ def turn_off_logging():
 load_dotenv()
 turn_off_logging()
 
+
 class ModelType(Enum):
     CHAT = "Chat"
     EMBEDDING = "Embedding"
@@ -89,12 +92,15 @@ def build_kwargs(self):
 
 class ChatChunk(TypedDict):
     """Simplified response chunk for chat models."""
+
     response_delta: str
     reasoning_delta: str
 
+
 class ChatGenerationResult:
     """Chat generation result object"""
-    def __init__(self, chunk: ChatChunk|None = None):
+
+    def __init__(self, chunk: ChatChunk | None = None):
         self.reasoning = ""
         self.response = ""
         self.thinking = False
@@ -111,7 +117,10 @@ def add_chunk(self, chunk: ChatChunk) -> ChatChunk:
 
         # if native reasoning detection works, there's no need to worry about thinking tags
         if self.native_reasoning:
-            processed_chunk = ChatChunk(response_delta=chunk["response_delta"], reasoning_delta=chunk["reasoning_delta"])
+            processed_chunk = ChatChunk(
+                response_delta=chunk["response_delta"],
+                reasoning_delta=chunk["reasoning_delta"],
+            )
         else:
             # if the model outputs thinking tags, we need to parse them manually as reasoning
             processed_chunk = self._process_thinking_chunk(chunk)
@@ -131,7 +140,7 @@ def _process_thinking_tags(self, response: str, reasoning: str) -> ChatChunk:
             close_pos = response.find(self.thinking_tag)
             if close_pos != -1:
                 reasoning += response[:close_pos]
-                response = response[close_pos + len(self.thinking_tag):]
+                response = response[close_pos + len(self.thinking_tag) :]
                 self.thinking = False
                 self.thinking_tag = ""
             else:
@@ -144,14 +153,14 @@ def _process_thinking_tags(self, response: str, reasoning: str) -> ChatChunk:
         else:
             for opening_tag, closing_tag in self.thinking_pairs:
                 if response.startswith(opening_tag):
-                    response = response[len(opening_tag):]
+                    response = response[len(opening_tag) :]
                     self.thinking = True
                     self.thinking_tag = closing_tag
 
                     close_pos = response.find(closing_tag)
                     if close_pos != -1:
                         reasoning += response[:close_pos]
-                        response = response[close_pos + len(closing_tag):]
+                        response = response[close_pos + len(closing_tag) :]
                         self.thinking = False
                         self.thinking_tag = ""
                     else:
@@ -162,7 +171,9 @@ def _process_thinking_tags(self, response: str, reasoning: str) -> ChatChunk:
                         reasoning += response
                         response = ""
                         break
-                elif len(response) < len(opening_tag) and self._is_partial_opening_tag(response, opening_tag):
+                elif len(response) < len(opening_tag) and self._is_partial_opening_tag(
+                    response, opening_tag
+                ):
                     self.unprocessed = response
                     response = ""
                     break
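
Aside: the tag parsing above is incremental and stream-safe (it buffers partial opening tags across chunks). A simplified, non-streaming sketch of the same splitting, with an illustrative `<think>` pair, shows the intended behavior:

    def split_thinking(text: str, open_tag: str = "<think>", close_tag: str = "</think>"):
        # Returns (reasoning, response); the tag pair is illustrative.
        if not text.startswith(open_tag):
            return "", text
        body = text[len(open_tag):]
        close_pos = body.find(close_tag)
        if close_pos == -1:
            return body, ""  # closing tag not seen yet: everything is reasoning
        return body[:close_pos], body[close_pos + len(close_tag):]

    print(split_thinking("<think>plan the steps</think>final answer"))
    # -> ('plan the steps', 'final answer')
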
@@ -318,7 +329,9 @@ def __init__(
     def _llm_type(self) -> str:
         return "litellm-chat"
 
-    def _convert_messages(self, messages: List[BaseMessage], explicit_caching: bool = False) -> List[dict]:
+    def _convert_messages(
+        self, messages: List[BaseMessage], explicit_caching: bool = False
+    ) -> List[dict]:
         result = []
         # Map LangChain message types to LiteLLM roles
         role_mapping = {
@@ -365,7 +378,9 @@ def _convert_messages(self, messages: List[BaseMessage], explicit_caching: bool
 
             # fix messages with empty content, this breaks some LLMs
             content = message_dict.get("content")
-            has_content = bool(content) if not isinstance(content, list) else len(content) > 0
+            has_content = (
+                bool(content) if not isinstance(content, list) else len(content) > 0
+            )
             if not has_content:
                 message_dict["content"] = "empty"
 
@@ -429,8 +444,8 @@ def _stream(
             **{**self.kwargs, **kwargs},
         ):
             # parse chunk
-            parsed = _parse_chunk(chunk) # chunk parsing
-            output = result.add_chunk(parsed) # chunk processing
+            parsed = _parse_chunk(chunk)  # chunk parsing
+            output = result.add_chunk(parsed)  # chunk processing
 
             # Only yield chunks with non-None content
             if output["response_delta"]:
@@ -461,8 +476,8 @@
         )
         async for chunk in response:  # type: ignore
             # parse chunk
-            parsed = _parse_chunk(chunk) # chunk parsing
-            output = result.add_chunk(parsed) # chunk processing
+            parsed = _parse_chunk(chunk)  # chunk parsing
+            output = result.add_chunk(parsed)  # chunk processing
 
             # Only yield chunks with non-None content
             if output["response_delta"]:
@@ -507,7 +522,11 @@ async def unified_call(
         call_kwargs: dict[str, Any] = {**self.kwargs, **kwargs}
         max_retries: int = int(call_kwargs.pop("a0_retry_attempts", 2))
         retry_delay_s: float = float(call_kwargs.pop("a0_retry_delay_seconds", 1.5))
-        stream = reasoning_callback is not None or response_callback is not None or tokens_callback is not None
+        stream = (
+            reasoning_callback is not None
+            or response_callback is not None
+            or tokens_callback is not None
+        )
 
         # results
         result = ChatGenerationResult()
@@ -537,15 +556,21 @@ async def unified_call(
                     # collect reasoning delta and call callbacks
                     if output["reasoning_delta"]:
                         if reasoning_callback:
-                            await reasoning_callback(output["reasoning_delta"], result.reasoning)
+                            await reasoning_callback(
+                                output["reasoning_delta"], result.reasoning
+                            )
                         if tokens_callback:
                             await tokens_callback(
                                 output["reasoning_delta"],
                                 approximate_tokens(output["reasoning_delta"]),
                             )
                         # Add output tokens to rate limiter if configured
                         if limiter:
-                            limiter.add(output=approximate_tokens(output["reasoning_delta"]))
+                            limiter.add(
+                                output=approximate_tokens(
+                                    output["reasoning_delta"]
+                                )
+                            )
                     # collect response delta and call callbacks
                     if output["response_delta"]:
                         if response_callback:
@@ -559,7 +584,11 @@ async def unified_call(
                             )
                         # Add output tokens to rate limiter if configured
                         if limiter:
-                            limiter.add(output=approximate_tokens(output["response_delta"]))
+                            limiter.add(
+                                output=approximate_tokens(
+                                    output["response_delta"]
+                                )
+                            )
                         if stop_response is not None:
                             result.response = stop_response
                             break
@@ -573,27 +602,48 @@ async def unified_call(
                 output = result.add_chunk(parsed)
                 if limiter:
                     if output["response_delta"]:
-                        limiter.add(output=approximate_tokens(output["response_delta"]))
+                        limiter.add(
+                            output=approximate_tokens(output["response_delta"])
+                        )
                     if output["reasoning_delta"]:
-                        limiter.add(output=approximate_tokens(output["reasoning_delta"]))
+                        limiter.add(
+                            output=approximate_tokens(output["reasoning_delta"])
+                        )
 
                 # Successful completion of stream
                 return result.response, result.reasoning
 
             except Exception as e:
                 import asyncio
 
-                # Retry only if no chunks received and error is transient
-                if got_any_chunk or not _is_transient_litellm_error(e) or attempt >= max_retries:
+                if got_any_chunk or not _is_transient_litellm_error(e):
+                    raise
+
+                is_rate_limit = getattr(e, "status_code", None) == 429 or isinstance(
+                    e, litellm.RateLimitError
+                )
+                effective_max_retries = (
+                    max(max_retries, 5) if is_rate_limit else max_retries
+                )
+                if attempt >= effective_max_retries:
                     raise
+
                 attempt += 1
-                await asyncio.sleep(retry_delay_s)
+                if is_rate_limit:
+                    delay = min(10.0 * (2 ** (attempt - 1)), 60.0)
+                else:
+                    delay = retry_delay_s
+                await asyncio.sleep(delay)
 
 
 class LiteLLMEmbeddingWrapper(Embeddings):
     model_name: str
     kwargs: dict = {}
     a0_model_conf: Optional[ModelConfig] = None
+    _provider: str = ""
+    _api_base: str = ""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
     def __init__(
         self,
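
The rate-limit handling added in the hunk above retries 429s with exponential backoff; a standalone sketch of the same arithmetic gives the delay schedule:

    # 10 s base, doubling per attempt, capped at 60 s
    for attempt in range(1, 6):
        print(attempt, min(10.0 * (2 ** (attempt - 1)), 60.0))
    # 1 10.0 / 2 20.0 / 3 40.0 / 4 60.0 / 5 60.0
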
@@ -603,14 +653,86 @@ def __init__(
         **kwargs: Any,
     ):
         self.model_name = f"{provider}/{model}" if provider != "openai" else model
+        self._provider = provider
+        self._api_base = kwargs.pop("api_base", "") or ""
         self.kwargs = kwargs
         self.a0_model_conf = model_config
 
+    def _is_ollama(self) -> bool:
+        return self._provider == "ollama"
+
+    def _ollama_embed(self, texts: List[str]) -> List[List[float]]:
+        """Bypass LiteLLM for Ollama — its handler sends a malformed body
+        (ollama/ prefix in model name + unsupported kwargs) causing 400."""
+        import httpx
+        import time
+
+        # Sanitize: Ollama rejects null/None entries with HTTP 400 "invalid input type".
+        # Convert None → empty string and ensure all items are str so JSON serialisation
+        # never produces a null element in the input array.
+        safe_texts = [
+            t if isinstance(t, str) else ("" if t is None else str(t)) for t in texts
+        ]
+        if safe_texts != texts:
+            logging.warning(
+                "Ollama embed %s: %d input(s) contained non-str values and were sanitised. "
+                "Original types: %s",
+                self.model_name,
+                sum(1 for t in texts if not isinstance(t, str)),
+                [type(t).__name__ for t in texts if not isinstance(t, str)],
+            )
+            texts = safe_texts
+
+        model = self.model_name.removeprefix("ollama/")
+        api_base = self._api_base or os.environ.get(
+            "OLLAMA_API_BASE",
+            os.environ.get("OLLAMA_HOST", "http://localhost:11434"),
+        )
+        api_base = api_base.rstrip("/")
+        if api_base.endswith("/api/embed") or api_base.endswith("/api/embeddings"):
+            api_base = api_base.rsplit("/api/", 1)[0]
+
+        url = f"{api_base}/api/embed"
+        payload = {"model": model, "input": texts}
+
+        last_exc: Exception = RuntimeError("no attempts made")
+        for attempt in range(3):
+            if attempt:
+                time.sleep(2.0 * attempt)
+            try:
+                resp = httpx.post(url, json=payload, timeout=120.0)
+                if resp.status_code != 200:
+                    logging.warning(
+                        "Ollama embed %s attempt %d: HTTP %d — %s | texts[:100]=%r",
+                        model,
+                        attempt + 1,
+                        resp.status_code,
+                        resp.text[:300],
+                        [t[:100] if isinstance(t, str) else t for t in texts],
+                    )
+                resp.raise_for_status()
+                return resp.json()["embeddings"]
+            except httpx.HTTPStatusError as e:
+                last_exc = e
+                # 400 = bad request payload — retrying won't help, raise immediately
+                if e.response.status_code == 400:
+                    raise
+                # 429 / 503 = transient — retry with backoff
+                if e.response.status_code not in (503, 429):
+                    raise
+            except (httpx.ConnectError, httpx.TimeoutException) as e:
+                last_exc = e
+        raise last_exc
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         # Apply rate limiting if configured
         apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
 
-        resp = embedding(model=self.model_name, input=texts, **self.kwargs)
+        if self._is_ollama():
+            return self._ollama_embed(texts)
+
+        embed_kwargs = {"encoding_format": "float", **self.kwargs}
+        resp = embedding(model=self.model_name, input=texts, **embed_kwargs)
         return [
             item.get("embedding") if isinstance(item, dict) else item.embedding  # type: ignore
             for item in resp.data  # type: ignore
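
With this in place, an Ollama-backed wrapper never enters LiteLLM's embedding path. A hypothetical usage sketch (keyword arguments per the diff's `__init__` body; the model name is illustrative):

    emb = LiteLLMEmbeddingWrapper(provider="ollama", model="nomic-embed-text")
    vectors = emb.embed_documents(["alpha", "beta"])  # routed to _ollama_embed
    query_vec = emb.embed_query("gamma")              # single vector via the same helper
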
@@ -620,7 +742,11 @@ def embed_query(self, text: str) -> List[float]:
         # Apply rate limiting if configured
         apply_rate_limiter_sync(self.a0_model_conf, text)
 
-        resp = embedding(model=self.model_name, input=[text], **self.kwargs)
+        if self._is_ollama():
+            return self._ollama_embed([text])[0]
+
+        embed_kwargs = {"encoding_format": "float", **self.kwargs}
+        resp = embedding(model=self.model_name, input=[text], **embed_kwargs)
         item = resp.data[0]  # type: ignore
         return item.get("embedding") if isinstance(item, dict) else item.embedding  # type: ignore
 
@@ -739,28 +865,35 @@ def _parse_chunk(chunk: Any) -> ChatChunk:
         "model_extra", {}
     ).get("message", {})
     response_delta = (
-        delta.get("content", "")
-        if isinstance(delta, dict)
-        else getattr(delta, "content", "")
-    ) or (
-        message.get("content", "")
-        if isinstance(message, dict)
-        else getattr(message, "content", "")
-    ) or ""
+        (
+            delta.get("content", "")
+            if isinstance(delta, dict)
+            else getattr(delta, "content", "")
+        )
+        or (
+            message.get("content", "")
+            if isinstance(message, dict)
+            else getattr(message, "content", "")
+        )
+        or ""
+    )
     reasoning_delta = (
-        delta.get("reasoning_content", "")
-        if isinstance(delta, dict)
-        else getattr(delta, "reasoning_content", "")
-    ) or (
-        message.get("reasoning_content", "")
-        if isinstance(message, dict)
-        else getattr(message, "reasoning_content", "")
-    ) or ""
+        (
+            delta.get("reasoning_content", "")
+            if isinstance(delta, dict)
+            else getattr(delta, "reasoning_content", "")
+        )
+        or (
+            message.get("reasoning_content", "")
+            if isinstance(message, dict)
+            else getattr(message, "reasoning_content", "")
+        )
+        or ""
+    )
 
     return ChatChunk(reasoning_delta=reasoning_delta, response_delta=response_delta)
 
 
-
 def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict):
 
     # remap other to openai for litellm
@@ -827,6 +960,7 @@ def get_chat_model(
         LiteLLMChatWrapper, name, provider_name, model_config, **kwargs
     )
 
+
 def get_embedding_model(
     provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
 ) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
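
The factory above is presumably the entry point callers use; under this fix, an Ollama embedding model obtained here transparently takes the direct /api/embed path (hypothetical call, model name illustrative):

    emb = get_embedding_model("ollama", "nomic-embed-text")
    vec = emb.embed_query("hello")  # served by _ollama_embed, not LiteLLM
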
