Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
eaf8121
fix: some params were lost (provider and api_key)
ea-rus Mar 30, 2026
4f6f768
fix max_tokens is not supported
ea-rus Mar 30, 2026
b03d081
fix readable error if embedding model is not in default config and is…
ea-rus Mar 30, 2026
aec266c
ruff
ea-rus Mar 30, 2026
4a7afc4
snowflake provider
ea-rus Apr 2, 2026
0c8cb44
remove litellm
ea-rus Apr 2, 2026
6a2da23
ollama fixes
ea-rus Apr 2, 2026
9017cae
fix bedrock
ea-rus Apr 2, 2026
53c4d34
Merge branch 'azure-fixes' into litellm-replace
ea-rus Apr 2, 2026
51c95c9
dockstrings
ea-rus Apr 3, 2026
7d72417
unit tests
ea-rus Apr 3, 2026
3823774
check reqs
ea-rus Apr 3, 2026
bef4fce
check reqs
ea-rus Apr 3, 2026
6f93970
fix listwise rerank
ea-rus Apr 3, 2026
84321c4
Merge branch 'releases/26.1.0' into litellm-replace
ea-rus Apr 6, 2026
bcfcf71
remove pragma
ea-rus Apr 6, 2026
90f7342
fix
ea-rus Apr 6, 2026
07e3812
fix
ea-rus Apr 8, 2026
add4946
Merge branch 'releases/26.1.0' into litellm-replace
ea-rus Apr 9, 2026
eb3be42
Merge branch 'releases/26.1.0' into litellm-replace
ea-rus Apr 9, 2026
5808da3
replace dependency aioboto3 -> aiobotocore
ea-rus Apr 9, 2026
b084a50
Merge branch 'releases/26.1.0' into litellm-replace
ea-rus Apr 10, 2026
df7b468
Merge branch 'releases/26.1.0' into litellm-replace
ea-rus Apr 13, 2026
6290b1e
don't save params
ea-rus Apr 13, 2026
ca82005
fix ollama
ea-rus Apr 13, 2026
c101439
Merge remote-tracking branch 'origin/litellm-replace' into litellm-re…
ea-rus Apr 13, 2026
800f0ed
fix init
ea-rus Apr 14, 2026
e3ca04a
ruff
ea-rus Apr 14, 2026
5f276ad
fix to keep error message
ea-rus Apr 14, 2026
daa7cc1
Merge branch 'releases/26.1.0' into litellm-replace
ea-rus Apr 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 29 additions & 28 deletions mindsdb/integrations/utilities/rag/rerankers/base_reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
DEFAULT_VALID_CLASS_TOKENS,
RerankerMode,
)
from mindsdb.integrations.libs.base import BaseMLEngine

from mindsdb.interfaces.knowledge_base.providers.bedrock import AsyncBedrockClient
from mindsdb.interfaces.knowledge_base.providers.gemini import GeminiClient
from mindsdb.interfaces.knowledge_base.providers.snowflake import SnowflakeClient


log = logging.getLogger(__name__)

Expand All @@ -46,10 +50,10 @@ class BaseLLMReranker(BaseModel):
base_url: Optional[str] = None
api_version: Optional[str] = None
num_docs_to_keep: Optional[int] = None # How many of the top documents to keep after reranking & compressing.
method: str = "multi-class" # Scoring method: 'multi-class' or 'binary'
method: str = "no-logprobs" # Scoring method: 'multi-class' or 'no-logprobs'
mode: RerankerMode = RerankerMode.POINTWISE
_api_key_var: str = "OPENAI_API_KEY"
client: Optional[AsyncOpenAI | BaseMLEngine] = None
client: Optional[AsyncOpenAI | AsyncBedrockClient | GeminiClient | SnowflakeClient] = None
_semaphore: Optional[asyncio.Semaphore] = None
max_concurrent_requests: int = 20
max_retries: int = 4
Expand All @@ -74,6 +78,9 @@ def __init__(self, **kwargs):

def _init_client(self):
if self.client is None:
if self.provider == "google":
self.provider = "gemini"

if self.provider == "azure_openai":
azure_api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
azure_api_endpoint = self.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
Expand All @@ -85,11 +92,21 @@ def _init_client(self):
timeout=self.request_timeout,
max_retries=2,
)
self.method = "multi-class"
elif self.provider == "bedrock":
kwargs = self.model_extra.copy()
self.client = AsyncBedrockClient(**kwargs)
elif self.provider == "gemini":
self.client = GeminiClient(api_key=self.api_key)
elif self.provider == "snowflake":
kwargs = self.model_extra.copy()
self.client = SnowflakeClient(api_key=self.api_key, **kwargs)
elif self.provider in ("openai", "ollama"):
if self.provider == "ollama":
self.method = "no-logprobs"
if self.api_key is None:
self.api_key = "n/a"
else:
self.method = "multi-class"

api_key_var: str = "OPENAI_API_KEY"
openai_api_key = self.api_key or os.getenv(api_key_var)
Expand All @@ -101,31 +118,17 @@ def _init_client(self):
api_key=openai_api_key, base_url=base_url, timeout=self.request_timeout, max_retries=2
)
else:
# try to use litellm
from mindsdb.api.executor.controllers.session_controller import SessionController

session = SessionController()
module = session.integration_controller.get_handler_module("litellm")

if module is None or module.Handler is None:
raise ValueError(f'Unable to use "{self.provider}" provider. Litellm handler is not installed')

self.client = module.Handler
self.method = "no-logprobs"
raise NotImplementedError(f'Provider "{self.provider}" is not supported')

async def _call_llm(self, messages):
async def _call_llm(self, messages) -> str:
if self.provider in ("azure_openai", "openai", "ollama"):
return await self.client.chat.completions.create(
response = await self.client.chat.completions.create(
model=self.model,
messages=messages,
)
return response.choices[0].message.content
else:
kwargs = self.model_extra.copy()

if self.api_key is not None:
kwargs["api_key"] = self.api_key

return await self.client.acompletion(self.provider, model=self.model, messages=messages, args=kwargs)
return await self.client.acompletion(model_name=self.model, messages=messages)

async def _rank(self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None) -> List[Tuple[str, float]]:
ranked_results = []
Expand Down Expand Up @@ -236,12 +239,10 @@ async def search_relevancy_no_logprob(self, query: str, document: str) -> Any:
f"Search query: {query}"
)

response = await self._call_llm(
answer = await self._call_llm(
messages=[{"role": "system", "content": prompt}, {"role": "user", "content": document}],
)

answer = response.choices[0].message.content

try:
value = re.findall(r"[\d]+", answer)[0]
score = float(value) / 100
Expand Down Expand Up @@ -483,8 +484,8 @@ async def _rank_single_batch(

for attempt in range(self.max_retries):
try:
response = await self._call_llm(messages)
content = response.choices[0].message.content
content = await self._call_llm(messages)

scores = self._extract_scores(content, len(documents))
return list(zip(documents, scores))
except Exception as exc:
Expand Down
4 changes: 2 additions & 2 deletions mindsdb/interfaces/knowledge_base/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,9 +1319,9 @@ def _check_embedding_model(self, project_name, params: dict = None, kb_name="")
f"Wrong embedding provider: {params['provider']}. Available providers: {', '.join(avail_providers)}"
)

llm_client = LLMClient(params, session=self.session)

try:
llm_client = LLMClient(params, session=self.session)

resp = llm_client.embeddings(["test"])
return {"dimension": len(resp[0])}
except Exception as e:
Expand Down
69 changes: 29 additions & 40 deletions mindsdb/interfaces/knowledge_base/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@

from mindsdb.integrations.utilities.handler_utils import get_api_key

from mindsdb.interfaces.knowledge_base.providers.bedrock import BedrockClient
from mindsdb.interfaces.knowledge_base.providers.gemini import GeminiClient
from mindsdb.interfaces.knowledge_base.providers.snowflake import SnowflakeClient


def retry_with_exponential_backoff(func):
def decorator(*args, **kwargs):
Expand Down Expand Up @@ -60,22 +64,23 @@ def wrapper(self, messages, *args, **kwargs):
class LLMClient:
"""
Class for accession to LLM.
It chooses openai client or litellm handler depending on the config
It chooses provider client depending on the config
"""

def __init__(self, params: dict = None, session=None):
self._session = session
self.params = params
params = params.copy()

self.provider = params.get("provider", "openai")
self.provider = params.pop("provider", "openai")
self.model_name = params.pop("model_name")
if self.provider == "google":
self.provider = "gemini"

if "api_key" not in params:
api_key = get_api_key(self.provider, params, strict=False)
if api_key is not None:
params["api_key"] = api_key

self.engine = "openai"

if self.provider == "azure_openai":
azure_api_key = params.get("api_key") or os.getenv("AZURE_OPENAI_API_KEY")
azure_api_endpoint = params.get("base_url") or os.environ.get("AZURE_OPENAI_ENDPOINT")
Expand All @@ -91,62 +96,46 @@ def __init__(self, params: dict = None, session=None):
kwargs["base_url"] = base_url
self.client = OpenAI(**kwargs)
elif self.provider == "ollama":
kwargs = params.copy()
kwargs.pop("model_name")
kwargs.pop("provider", None)
if kwargs.get("api_key") is None:
kwargs["api_key"] = "n/a"
self.client = OpenAI(**kwargs)
if params.get("api_key") is None:
params["api_key"] = "n/a"
self.client = OpenAI(**params)
elif self.provider == "bedrock":
if "aws_region" in params:
params["aws_region_name"] = params.pop("aws_region")
self.client = BedrockClient(**params)
elif self.provider == "gemini":
self.client = GeminiClient(**params)
elif self.provider == "snowflake":
self.client = SnowflakeClient(**params)
else:
# try to use litellm
if self._session is None:
from mindsdb.api.executor.controllers.session_controller import SessionController

self._session = SessionController()
module = self._session.integration_controller.get_handler_module("litellm")

if module is None or module.Handler is None:
raise ValueError(f'Unable to use "{self.provider}" provider. Litellm handler is not installed')

self.client = module.Handler
self.engine = "litellm"
raise NotImplementedError(f'Provider "{self.provider}" is not supported')

@run_in_batches(1000)
@retry_with_exponential_backoff
def embeddings(self, messages: List[str]):
params = self.params
if self.engine == "openai":
if self.provider in ("openai", "azure_openai", "ollama"):
response = self.client.embeddings.create(
model=params["model_name"],
model=self.model_name,
input=messages,
)
return [item.embedding for item in response.data]
else:
kwargs = params.copy()
model = kwargs.pop("model_name")
kwargs.pop("provider", None)

return self.client.embeddings(self.provider, model=model, messages=messages, args=kwargs)
return self.client.embeddings(self.model_name, messages)

@run_in_batches(100)
def completion(self, messages: List[dict], json_output: bool = False) -> List[str]:
"""
Call LLM completion and get response
"""
params = self.params
if self.engine == "openai":

if self.provider in ("openai", "azure_openai", "ollama"):
response = self.client.chat.completions.create(
model=params["model_name"],
model=self.model_name,
messages=messages,
)
return [item.message.content for item in response.choices]
else:
kwargs = params.copy()
params["json_output"] = json_output
model = kwargs.pop("model_name")
kwargs.pop("provider", None)
response = self.client.completion(self.provider, model=model, messages=messages, args=kwargs)
return [item.message.content for item in response.choices]
return [self.client.completion(self.model_name, messages)]

async def abatch(self, messages_list: List[List[dict]], json_output: bool = False) -> List[List[str]]:
"""
Expand Down
1 change: 1 addition & 0 deletions mindsdb/interfaces/knowledge_base/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Loading
Loading