from helpers.rate_limiter import RateLimiter
from helpers.tokens import approximate_tokens
from helpers import dirty_json
-from helpers.extension import extensible  # extensible: allows plugins to intercept get_api_key()
+from helpers.extension import (
+    extensible,
+)  # extensible: allows plugins to intercept get_api_key()

from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.outputs.chat_generation import ChatGenerationChunk
@@ -59,6 +61,7 @@ def turn_off_logging():
load_dotenv()
turn_off_logging()

+
class ModelType(Enum):
    CHAT = "Chat"
    EMBEDDING = "Embedding"
@@ -89,12 +92,15 @@ def build_kwargs(self):

class ChatChunk(TypedDict):
    """Simplified response chunk for chat models."""
+
    response_delta: str
    reasoning_delta: str

+
class ChatGenerationResult:
    """Chat generation result object"""
-    def __init__(self, chunk: ChatChunk | None = None):
+
+    def __init__(self, chunk: ChatChunk | None = None):
        self.reasoning = ""
        self.response = ""
        self.thinking = False
@@ -111,7 +117,10 @@ def add_chunk(self, chunk: ChatChunk) -> ChatChunk:

        # if native reasoning detection works, there's no need to worry about thinking tags
        if self.native_reasoning:
-            processed_chunk = ChatChunk(response_delta=chunk["response_delta"], reasoning_delta=chunk["reasoning_delta"])
+            processed_chunk = ChatChunk(
+                response_delta=chunk["response_delta"],
+                reasoning_delta=chunk["reasoning_delta"],
+            )
        else:
            # if the model outputs thinking tags, we need to parse them manually as reasoning
            processed_chunk = self._process_thinking_chunk(chunk)
@@ -131,7 +140,7 @@ def _process_thinking_tags(self, response: str, reasoning: str) -> ChatChunk:
            close_pos = response.find(self.thinking_tag)
            if close_pos != -1:
                reasoning += response[:close_pos]
-                response = response[close_pos + len(self.thinking_tag):]
+                response = response[close_pos + len(self.thinking_tag) :]
                self.thinking = False
                self.thinking_tag = ""
            else:
@@ -144,14 +153,14 @@ def _process_thinking_tags(self, response: str, reasoning: str) -> ChatChunk:
        else:
            for opening_tag, closing_tag in self.thinking_pairs:
                if response.startswith(opening_tag):
-                    response = response[len(opening_tag):]
+                    response = response[len(opening_tag) :]
                    self.thinking = True
                    self.thinking_tag = closing_tag

                    close_pos = response.find(closing_tag)
                    if close_pos != -1:
                        reasoning += response[:close_pos]
-                        response = response[close_pos + len(closing_tag):]
+                        response = response[close_pos + len(closing_tag) :]
                        self.thinking = False
                        self.thinking_tag = ""
                    else:
@@ -162,7 +171,9 @@ def _process_thinking_tags(self, response: str, reasoning: str) -> ChatChunk:
                        reasoning += response
                        response = ""
                    break
-                elif len(response) < len(opening_tag) and self._is_partial_opening_tag(response, opening_tag):
+                elif len(response) < len(opening_tag) and self._is_partial_opening_tag(
+                    response, opening_tag
+                ):
                    self.unprocessed = response
                    response = ""
                    break
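
As an aside on the tag parsing above: the splitter consumes an opening tag, accumulates everything up to the matching closing tag as reasoning, and treats the remainder as the response. A minimal single-buffer sketch of that behavior, assuming a `("<think>", "</think>")` pair and ignoring the partial-tag buffering the real method handles via `self.unprocessed`:

```python
# Minimal sketch, not the wrapper's API: split one fully buffered string
# into (reasoning, response) using a single <think>...</think> pair.
def split_thinking(buffer: str) -> tuple[str, str]:
    open_tag, close_tag = "<think>", "</think>"
    if not buffer.startswith(open_tag):
        return "", buffer  # no thinking tags: everything is response
    buffer = buffer[len(open_tag) :]
    close_pos = buffer.find(close_tag)
    if close_pos == -1:
        return buffer, ""  # stream ended mid-thought: all reasoning
    return buffer[:close_pos], buffer[close_pos + len(close_tag) :]


assert split_thinking("<think>plan steps</think>final answer") == (
    "plan steps",
    "final answer",
)
```
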
@@ -318,7 +329,9 @@ def __init__(
    def _llm_type(self) -> str:
        return "litellm-chat"

-    def _convert_messages(self, messages: List[BaseMessage], explicit_caching: bool = False) -> List[dict]:
+    def _convert_messages(
+        self, messages: List[BaseMessage], explicit_caching: bool = False
+    ) -> List[dict]:
        result = []
        # Map LangChain message types to LiteLLM roles
        role_mapping = {
@@ -365,7 +378,9 @@ def _convert_messages(self, messages: List[BaseMessage], explicit_caching: bool

            # fix messages with empty content, since empty content breaks some LLMs
            content = message_dict.get("content")
-            has_content = bool(content) if not isinstance(content, list) else len(content) > 0
+            has_content = (
+                bool(content) if not isinstance(content, list) else len(content) > 0
+            )
            if not has_content:
                message_dict["content"] = "empty"

@@ -429,8 +444,8 @@ def _stream(
                **{**self.kwargs, **kwargs},
            ):
                # parse chunk
-                parsed = _parse_chunk(chunk) # chunk parsing
-                output = result.add_chunk(parsed) # chunk processing
+                parsed = _parse_chunk(chunk)  # chunk parsing
+                output = result.add_chunk(parsed)  # chunk processing

                # Only yield chunks with non-None content
                if output["response_delta"]:
@@ -461,8 +476,8 @@ async def _astream(
            )
            async for chunk in response:  # type: ignore
                # parse chunk
-                parsed = _parse_chunk(chunk) # chunk parsing
-                output = result.add_chunk(parsed) # chunk processing
+                parsed = _parse_chunk(chunk)  # chunk parsing
+                output = result.add_chunk(parsed)  # chunk processing

                # Only yield chunks with non-None content
                if output["response_delta"]:
@@ -507,7 +522,11 @@ async def unified_call(
        call_kwargs: dict[str, Any] = {**self.kwargs, **kwargs}
        max_retries: int = int(call_kwargs.pop("a0_retry_attempts", 2))
        retry_delay_s: float = float(call_kwargs.pop("a0_retry_delay_seconds", 1.5))
-        stream = reasoning_callback is not None or response_callback is not None or tokens_callback is not None
+        stream = (
+            reasoning_callback is not None
+            or response_callback is not None
+            or tokens_callback is not None
+        )

        # results
        result = ChatGenerationResult()
@@ -537,15 +556,21 @@ async def unified_call(
                            # collect reasoning delta and call callbacks
                            if output["reasoning_delta"]:
                                if reasoning_callback:
-                                    await reasoning_callback(output["reasoning_delta"], result.reasoning)
+                                    await reasoning_callback(
+                                        output["reasoning_delta"], result.reasoning
+                                    )
                                if tokens_callback:
                                    await tokens_callback(
                                        output["reasoning_delta"],
                                        approximate_tokens(output["reasoning_delta"]),
                                    )
                                # Add output tokens to rate limiter if configured
                                if limiter:
-                                    limiter.add(output=approximate_tokens(output["reasoning_delta"]))
+                                    limiter.add(
+                                        output=approximate_tokens(
+                                            output["reasoning_delta"]
+                                        )
+                                    )
                            # collect response delta and call callbacks
                            if output["response_delta"]:
                                if response_callback:
@@ -559,7 +584,11 @@ async def unified_call(
                                    )
                                # Add output tokens to rate limiter if configured
                                if limiter:
-                                    limiter.add(output=approximate_tokens(output["response_delta"]))
+                                    limiter.add(
+                                        output=approximate_tokens(
+                                            output["response_delta"]
+                                        )
+                                    )
                            if stop_response is not None:
                                result.response = stop_response
                                break
@@ -573,27 +602,48 @@ async def unified_call(
                    output = result.add_chunk(parsed)
                    if limiter:
                        if output["response_delta"]:
-                            limiter.add(output=approximate_tokens(output["response_delta"]))
+                            limiter.add(
+                                output=approximate_tokens(output["response_delta"])
+                            )
                        if output["reasoning_delta"]:
-                            limiter.add(output=approximate_tokens(output["reasoning_delta"]))
+                            limiter.add(
+                                output=approximate_tokens(output["reasoning_delta"])
+                            )

                # Successful completion of stream
                return result.response, result.reasoning

            except Exception as e:
                import asyncio

-                # Retry only if no chunks received and error is transient
-                if got_any_chunk or not _is_transient_litellm_error(e) or attempt >= max_retries:
+                if got_any_chunk or not _is_transient_litellm_error(e):
+                    raise
+
+                is_rate_limit = getattr(e, "status_code", None) == 429 or isinstance(
+                    e, litellm.RateLimitError
+                )
+                effective_max_retries = (
+                    max(max_retries, 5) if is_rate_limit else max_retries
+                )
+                if attempt >= effective_max_retries:
                    raise
+
                attempt += 1
-                await asyncio.sleep(retry_delay_s)
+                if is_rate_limit:
+                    delay = min(10.0 * (2 ** (attempt - 1)), 60.0)
+                else:
+                    delay = retry_delay_s
+                await asyncio.sleep(delay)


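To make the new retry policy concrete: rate-limit errors (HTTP 429 or `litellm.RateLimitError`) now get at least five attempts with an exponential delay that starts at 10 s and caps at 60 s, while other transient errors keep the flat `a0_retry_delay_seconds`. A small sketch of the resulting schedule; the helper name is illustrative, not part of the module:

```python
# Illustrative helper mirroring the retry branch above.
def backoff_delay(attempt: int, is_rate_limit: bool, retry_delay_s: float = 1.5) -> float:
    if is_rate_limit:
        return min(10.0 * (2 ** (attempt - 1)), 60.0)  # 10, 20, 40, 60, 60, ...
    return retry_delay_s


assert [backoff_delay(a, True) for a in range(1, 6)] == [10.0, 20.0, 40.0, 60.0, 60.0]
assert backoff_delay(3, False) == 1.5
```
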
class LiteLLMEmbeddingWrapper(Embeddings):
    model_name: str
    kwargs: dict = {}
    a0_model_conf: Optional[ModelConfig] = None
+    _provider: str = ""
+    _api_base: str = ""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(
        self,
@@ -603,14 +653,86 @@ def __init__(
        **kwargs: Any,
    ):
        self.model_name = f"{provider}/{model}" if provider != "openai" else model
+        self._provider = provider
+        self._api_base = kwargs.pop("api_base", "") or ""
        self.kwargs = kwargs
        self.a0_model_conf = model_config

+    def _is_ollama(self) -> bool:
+        return self._provider == "ollama"
+
+    def _ollama_embed(self, texts: List[str]) -> List[List[float]]:
+        """Bypass LiteLLM for Ollama — its handler sends a malformed body
+        (ollama/ prefix in model name + unsupported kwargs) causing 400."""
+        import httpx
+        import time
+
+        # Sanitize: Ollama rejects null/None entries with HTTP 400 "invalid input type".
+        # Convert None → empty string and ensure all items are str so JSON serialisation
+        # never produces a null element in the input array.
+        safe_texts = [
+            t if isinstance(t, str) else ("" if t is None else str(t)) for t in texts
+        ]
+        if safe_texts != texts:
+            logging.warning(
+                "Ollama embed %s: %d input(s) contained non-str values and were sanitised. "
+                "Original types: %s",
+                self.model_name,
+                sum(1 for t in texts if not isinstance(t, str)),
+                [type(t).__name__ for t in texts if not isinstance(t, str)],
+            )
+        texts = safe_texts
+
+        model = self.model_name.removeprefix("ollama/")
+        api_base = self._api_base or os.environ.get(
+            "OLLAMA_API_BASE",
+            os.environ.get("OLLAMA_HOST", "http://localhost:11434"),
+        )
+        api_base = api_base.rstrip("/")
+        if api_base.endswith("/api/embed") or api_base.endswith("/api/embeddings"):
+            api_base = api_base.rsplit("/api/", 1)[0]
+
+        url = f"{api_base}/api/embed"
+        payload = {"model": model, "input": texts}
+
+        last_exc: Exception = RuntimeError("no attempts made")
+        for attempt in range(3):
+            if attempt:
+                time.sleep(2.0 * attempt)
+            try:
+                resp = httpx.post(url, json=payload, timeout=120.0)
+                if resp.status_code != 200:
+                    logging.warning(
+                        "Ollama embed %s attempt %d: HTTP %d — %s | texts[:100]=%r",
+                        model,
+                        attempt + 1,
+                        resp.status_code,
+                        resp.text[:300],
+                        [t[:100] if isinstance(t, str) else t for t in texts],
+                    )
+                resp.raise_for_status()
+                return resp.json()["embeddings"]
+            except httpx.HTTPStatusError as e:
+                last_exc = e
+                # 400 = bad request payload; retrying won't help, raise immediately
+                if e.response.status_code == 400:
+                    raise
+                # 429 / 503 = transient; retry with backoff
+                if e.response.status_code not in (503, 429):
+                    raise
+            except (httpx.ConnectError, httpx.TimeoutException) as e:
+                last_exc = e
+        raise last_exc
+
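
For reference, the request/response shape `_ollama_embed` relies on: Ollama's native `/api/embed` endpoint accepts the bare model name (no `ollama/` prefix) plus a list of input strings, and returns one vector per input under `embeddings`. A standalone sketch, assuming a local Ollama server with an embedding model pulled; the model name here is illustrative:

```python
import httpx

# Assumes Ollama is running on the default port with nomic-embed-text pulled.
resp = httpx.post(
    "http://localhost:11434/api/embed",
    json={"model": "nomic-embed-text", "input": ["hello", "world"]},
    timeout=120.0,
)
resp.raise_for_status()
vectors = resp.json()["embeddings"]  # one float vector per input string
assert len(vectors) == 2 and isinstance(vectors[0][0], float)
```
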
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # Apply rate limiting if configured
        apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))

-        resp = embedding(model=self.model_name, input=texts, **self.kwargs)
+        if self._is_ollama():
+            return self._ollama_embed(texts)
+
+        embed_kwargs = {"encoding_format": "float", **self.kwargs}
+        resp = embedding(model=self.model_name, input=texts, **embed_kwargs)
        return [
            item.get("embedding") if isinstance(item, dict) else item.embedding  # type: ignore
            for item in resp.data  # type: ignore
@@ -620,7 +742,11 @@ def embed_query(self, text: str) -> List[float]:
        # Apply rate limiting if configured
        apply_rate_limiter_sync(self.a0_model_conf, text)

-        resp = embedding(model=self.model_name, input=[text], **self.kwargs)
+        if self._is_ollama():
+            return self._ollama_embed([text])[0]
+
+        embed_kwargs = {"encoding_format": "float", **self.kwargs}
+        resp = embedding(model=self.model_name, input=[text], **embed_kwargs)
        item = resp.data[0]  # type: ignore
        return item.get("embedding") if isinstance(item, dict) else item.embedding  # type: ignore

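Putting the two embedding paths together, usage looks the same either way: an Ollama provider is routed through the raw HTTP fallback, anything else goes through `litellm.embedding` with `encoding_format="float"` pinned. A hedged usage sketch; the provider, model, and `api_base` values are examples, not project defaults:

```python
# Hypothetical usage of the wrapper above; argument values are illustrative.
embedder = LiteLLMEmbeddingWrapper(
    provider="ollama",
    model="nomic-embed-text",
    api_base="http://localhost:11434",  # popped into _api_base, kept out of kwargs
)
doc_vectors = embedder.embed_documents(["first chunk", "second chunk"])
query_vector = embedder.embed_query("what is in the first chunk?")
```
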
@@ -739,28 +865,35 @@ def _parse_chunk(chunk: Any) -> ChatChunk:
        "model_extra", {}
    ).get("message", {})
    response_delta = (
-        delta.get("content", "")
-        if isinstance(delta, dict)
-        else getattr(delta, "content", "")
-    ) or (
-        message.get("content", "")
-        if isinstance(message, dict)
-        else getattr(message, "content", "")
-    ) or ""
+        (
+            delta.get("content", "")
+            if isinstance(delta, dict)
+            else getattr(delta, "content", "")
+        )
+        or (
+            message.get("content", "")
+            if isinstance(message, dict)
+            else getattr(message, "content", "")
+        )
+        or ""
+    )
    reasoning_delta = (
-        delta.get("reasoning_content", "")
-        if isinstance(delta, dict)
-        else getattr(delta, "reasoning_content", "")
-    ) or (
-        message.get("reasoning_content", "")
-        if isinstance(message, dict)
-        else getattr(message, "reasoning_content", "")
-    ) or ""
+        (
+            delta.get("reasoning_content", "")
+            if isinstance(delta, dict)
+            else getattr(delta, "reasoning_content", "")
+        )
+        or (
+            message.get("reasoning_content", "")
+            if isinstance(message, dict)
+            else getattr(message, "reasoning_content", "")
+        )
+        or ""
+    )

    return ChatChunk(reasoning_delta=reasoning_delta, response_delta=response_delta)
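
To spell out the fallback chain that the reformatted expressions implement: `_parse_chunk` prefers the streaming `delta`, falls back to a non-streaming `message`, and coerces anything missing to an empty string. A sketch with plain dicts standing in for LiteLLM's chunk objects; the shapes are illustrative:

```python
# Plain-dict stand-ins for a streamed chunk's delta and a full message.
delta = {"content": "", "reasoning_content": "step 1: read the question"}
message = {"content": "final answer", "reasoning_content": ""}

# Same or-chain as above: delta first, then message, then empty string.
response_delta = delta.get("content", "") or message.get("content", "") or ""
reasoning_delta = (
    delta.get("reasoning_content", "") or message.get("reasoning_content", "") or ""
)
assert (response_delta, reasoning_delta) == ("final answer", "step 1: read the question")
```
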


-
def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict):

    # remap other to openai for litellm
@@ -827,6 +960,7 @@ def get_chat_model(
        LiteLLMChatWrapper, name, provider_name, model_config, **kwargs
    )

+
def get_embedding_model(
    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper: