From eaf8121a4137929d0548910105720e7d201a868d Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 30 Mar 2026 14:15:40 +0300 Subject: [PATCH 01/22] fix: some params were lost (provider and api_key) --- mindsdb/interfaces/functions/controller.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mindsdb/interfaces/functions/controller.py b/mindsdb/interfaces/functions/controller.py index 6503e6af402..03585d407bb 100644 --- a/mindsdb/interfaces/functions/controller.py +++ b/mindsdb/interfaces/functions/controller.py @@ -139,11 +139,6 @@ def llm_call_function(self, node): try: from mindsdb.interfaces.knowledge_base.llm_client import LLMClient - - llm_config = get_llm_config(chat_model_params["provider"], chat_model_params) - chat_model_params = llm_config.model_dump(by_alias=True) - chat_model_params = {k: v for k, v in chat_model_params.items() if v is not None} - llm = LLMClient(chat_model_params, session=self.session) except Exception as e: raise RuntimeError(f"Unable to use LLM function, check ENV variables: {e}") from e From 4f6f768120eaea7ed81c9eeb55989b49cd1e1efb Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 30 Mar 2026 14:16:11 +0300 Subject: [PATCH 02/22] fix max_tokens is not supported --- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py b/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py index b97b18898fc..496f645327c 100644 --- a/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +++ b/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py @@ -207,7 +207,7 @@ async def search_relevancy(self, query: str, document: str) -> Any: temperature=self.temperature, n=1, logprobs=True, - max_tokens=1, + max_completion_tokens=1, ) # Extract response and logprobs @@ -355,7 +355,7 @@ async def search_relevancy_score(self, query: str, document: str) -> Any: n=self.n, logprobs=self.logprobs, top_logprobs=self.top_logprobs, - max_tokens=self.max_tokens, + max_completion_tokens=self.max_tokens, ) # Extract response and logprobs From b03d081aa3fed85d4fc9ab56ce420b4bc0f91436 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 30 Mar 2026 14:17:09 +0300 Subject: [PATCH 03/22] fix readable error if embedding model is not in default config and isn't specified --- mindsdb/interfaces/knowledge_base/controller.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mindsdb/interfaces/knowledge_base/controller.py b/mindsdb/interfaces/knowledge_base/controller.py index ad7cdb7fbad..26ab06b85c1 100644 --- a/mindsdb/interfaces/knowledge_base/controller.py +++ b/mindsdb/interfaces/knowledge_base/controller.py @@ -180,7 +180,7 @@ def rotate_provider_api_key(params): :param params: input params, can be modified by this function :return: a new api key if it is refreshed """ - provider = params.get("provider").lower() + provider = params.get("provider", "").lower() if provider == "snowflake": if "snowflake_account_id" in params: @@ -1209,6 +1209,9 @@ def add( raise EntityExistsError("Knowledge base already exists", name) embedding_params = get_model_params(params.get("embedding_model", {}), "default_embedding_model") + if not bool(embedding_params): + raise ValueError("No embedding model parameters provided") + params["embedding_model"] = embedding_params rotate_provider_api_key(embedding_params) From aec266c29c1ab9560b7591a5baddd142f136ac88 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 30 Mar 2026 14:20:14 +0300 Subject: [PATCH 04/22] ruff --- mindsdb/interfaces/functions/controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb/interfaces/functions/controller.py b/mindsdb/interfaces/functions/controller.py index 03585d407bb..86abe4022a2 100644 --- a/mindsdb/interfaces/functions/controller.py +++ b/mindsdb/interfaces/functions/controller.py @@ -4,7 +4,6 @@ from duckdb.typing import BIGINT, DOUBLE, VARCHAR, BLOB, BOOLEAN from mindsdb.interfaces.storage.model_fs import HandlerStorage -from mindsdb.integrations.libs.llm.utils import get_llm_config from mindsdb.utilities.config import config @@ -139,6 +138,7 @@ def llm_call_function(self, node): try: from mindsdb.interfaces.knowledge_base.llm_client import LLMClient + llm = LLMClient(chat_model_params, session=self.session) except Exception as e: raise RuntimeError(f"Unable to use LLM function, check ENV variables: {e}") from e From 4a7afc494481225b8aa6443f70035dd88f9a66b0 Mon Sep 17 00:00:00 2001 From: andrew Date: Thu, 2 Apr 2026 17:31:29 +0300 Subject: [PATCH 05/22] snowflake provider --- .../knowledge_base/providers/snowflake.py | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 mindsdb/interfaces/knowledge_base/providers/snowflake.py diff --git a/mindsdb/interfaces/knowledge_base/providers/snowflake.py b/mindsdb/interfaces/knowledge_base/providers/snowflake.py new file mode 100644 index 00000000000..2a1f2a03990 --- /dev/null +++ b/mindsdb/interfaces/knowledge_base/providers/snowflake.py @@ -0,0 +1,113 @@ +from typing import List, Optional + +import requests +import httpx + + +def _raise_for_status(response): + # show response text in error + if 400 <= response.status_code < 600: + if hasattr(response, "reason"): + reason = response.reason + elif hasattr(response, "reason_phrase"): + reason = response.reason_phrase + else: + reason = "Error" + raise requests.HTTPError(f"{reason}: {response.text}", response=response) + + +class SnowflakeClient: + def __init__(self, account_id: str = None, api_key: str = None): + if account_id is None: + raise ValueError("account_id must be provided") + if api_key is None: + raise ValueError("api_key must be provided") + + self.account_id = account_id.lower() + self.api_key = api_key + + self.auth_type = "KEYPAIR_JWT" + if self.api_key.startswith("pat/"): + self.api_key = self.api_key[4:] + self.auth_type = "PROGRAMMATIC_ACCESS_TOKEN" + + def _get_base_url(self): + return f"https://{self.account_id}.snowflakecomputing.com/api/v2" + + def _get_headers(self): + return { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": "Bearer " + self.api_key, + "X-Snowflake-Authorization-Token-Type": self.auth_type, + } + + def embeddings(self, model_name: str, messages: List[str]): + url = f"{self._get_base_url()}/cortex/inference:embed" + + payload = {"text": messages, "model": model_name} + + response = requests.post(url, json=payload, headers=self._get_headers()) + _raise_for_status(response) + + embeddings = [] + for item in response.json()["data"]: + embeddings.append(item["embedding"][0]) + return embeddings + + def completion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ): + url = f"{self._get_base_url()}/cortex/inference:complete" + + payload = { + "model": model_name, + "stream": False, + "messages": messages, + } + + if temperature: + payload["temperature"] = temperature + if max_tokens: + payload["max_tokens"] = max_tokens + if top_p: + payload["top_p"] = top_p + + response = requests.post(url, json=payload, headers=self._get_headers()) + _raise_for_status(response) + data = response.json() + return data["choices"][0]["message"]["content"] + + async def acompletion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ): + url = f"{self._get_base_url()}/cortex/inference:complete" + + payload = { + "model": model_name, + "stream": False, + "messages": messages, + } + + if temperature: + payload["temperature"] = temperature + if max_tokens: + payload["max_tokens"] = max_tokens + if top_p: + payload["top_p"] = top_p + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=payload, headers=self._get_headers()) + _raise_for_status(response) + data = response.json() + return data["choices"][0]["message"]["content"] From 0c8cb44c8494fb65194bc7a45b8853126d573a97 Mon Sep 17 00:00:00 2001 From: andrew Date: Thu, 2 Apr 2026 18:20:05 +0300 Subject: [PATCH 06/22] remove litellm --- .../utilities/rag/rerankers/base_reranker.py | 57 ++++---- .../interfaces/knowledge_base/llm_client.py | 58 +++----- .../knowledge_base/providers/bedrock.py | 137 ++++++++++++++++++ .../knowledge_base/providers/gemini.py | 73 ++++++++++ requirements/requirements.txt | 5 + 5 files changed, 268 insertions(+), 62 deletions(-) create mode 100644 mindsdb/interfaces/knowledge_base/providers/bedrock.py create mode 100644 mindsdb/interfaces/knowledge_base/providers/gemini.py diff --git a/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py b/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py index b97b18898fc..014b44c9ad5 100644 --- a/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +++ b/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py @@ -23,7 +23,11 @@ DEFAULT_VALID_CLASS_TOKENS, RerankerMode, ) -from mindsdb.integrations.libs.base import BaseMLEngine + +from mindsdb.interfaces.knowledge_base.providers.bedrock import AsyncBedrockClient +from mindsdb.interfaces.knowledge_base.providers.gemini import GeminiClient +from mindsdb.interfaces.knowledge_base.providers.snowflake import SnowflakeClient + log = logging.getLogger(__name__) @@ -47,10 +51,10 @@ class BaseLLMReranker(BaseModel, ABC): base_url: Optional[str] = None api_version: Optional[str] = None num_docs_to_keep: Optional[int] = None # How many of the top documents to keep after reranking & compressing. - method: str = "multi-class" # Scoring method: 'multi-class' or 'binary' + method: str = "no-logprobs" # Scoring method: 'multi-class' or 'no-logprobs' mode: RerankerMode = RerankerMode.POINTWISE _api_key_var: str = "OPENAI_API_KEY" - client: Optional[AsyncOpenAI | BaseMLEngine] = None + client: Optional[AsyncOpenAI | AsyncBedrockClient | GeminiClient | SnowflakeClient] = None _semaphore: Optional[asyncio.Semaphore] = None max_concurrent_requests: int = 20 max_retries: int = 4 @@ -75,6 +79,9 @@ def __init__(self, **kwargs): def _init_client(self): if self.client is None: + if self.provider == "google": + self.provider = "gemini" + if self.provider == "azure_openai": azure_api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY") azure_api_endpoint = self.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT") @@ -86,11 +93,21 @@ def _init_client(self): timeout=self.request_timeout, max_retries=2, ) + self.method = "multi-class" + elif self.provider == "bedrock": + kwargs = self.model_extra.copy() + self.client = AsyncBedrockClient(**kwargs) + elif self.provider == "gemini": + self.client = GeminiClient(api_key=self.api_key) + elif self.provider == "snowflake": + kwargs = self.model_extra.copy() + self.client = SnowflakeClient(api_key=self.api_key, **kwargs) elif self.provider in ("openai", "ollama"): if self.provider == "ollama": - self.method = "no-logprobs" if self.api_key is None: self.api_key = "n/a" + else: + self.method = "multi-class" api_key_var: str = "OPENAI_API_KEY" openai_api_key = self.api_key or os.getenv(api_key_var) @@ -102,31 +119,17 @@ def _init_client(self): api_key=openai_api_key, base_url=base_url, timeout=self.request_timeout, max_retries=2 ) else: - # try to use litellm - from mindsdb.api.executor.controllers.session_controller import SessionController - - session = SessionController() - module = session.integration_controller.get_handler_module("litellm") - - if module is None or module.Handler is None: - raise ValueError(f'Unable to use "{self.provider}" provider. Litellm handler is not installed') - - self.client = module.Handler - self.method = "no-logprobs" + raise NotImplementedError(f'Provider "{self.provider}" is not supported') - async def _call_llm(self, messages): + async def _call_llm(self, messages) -> str: if self.provider in ("azure_openai", "openai", "ollama"): - return await self.client.chat.completions.create( + response = await self.client.chat.completions.create( model=self.model, messages=messages, ) + return response.choices[0].message.content else: - kwargs = self.model_extra.copy() - - if self.api_key is not None: - kwargs["api_key"] = self.api_key - - return await self.client.acompletion(self.provider, model=self.model, messages=messages, args=kwargs) + return await self.client.acompletion(model_name=self.model, messages=messages) async def _rank(self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None) -> List[Tuple[str, float]]: ranked_results = [] @@ -237,12 +240,10 @@ async def search_relevancy_no_logprob(self, query: str, document: str) -> Any: f"Search query: {query}" ) - response = await self._call_llm( + answer = await self._call_llm( messages=[{"role": "system", "content": prompt}, {"role": "user", "content": document}], ) - answer = response.choices[0].message.content - try: value = re.findall(r"[\d]+", answer)[0] score = float(value) / 100 @@ -484,8 +485,8 @@ async def _rank_single_batch( for attempt in range(self.max_retries): try: - response = await self._call_llm(messages) - content = response.choices[0].message.content + content = await self._call_llm(messages) + scores = self._extract_scores(content, len(documents)) return list(zip(documents, scores)) except Exception as exc: diff --git a/mindsdb/interfaces/knowledge_base/llm_client.py b/mindsdb/interfaces/knowledge_base/llm_client.py index ab044811b94..4dd30261d13 100644 --- a/mindsdb/interfaces/knowledge_base/llm_client.py +++ b/mindsdb/interfaces/knowledge_base/llm_client.py @@ -7,6 +7,10 @@ from mindsdb.integrations.utilities.handler_utils import get_api_key +from mindsdb.interfaces.knowledge_base.providers.bedrock import BedrockClient +from mindsdb.interfaces.knowledge_base.providers.gemini import GeminiClient +from mindsdb.interfaces.knowledge_base.providers.snowflake import SnowflakeClient + def retry_with_exponential_backoff(func): def decorator(*args, **kwargs): @@ -60,22 +64,23 @@ def wrapper(self, messages, *args, **kwargs): class LLMClient: """ Class for accession to LLM. - It chooses openai client or litellm handler depending on the config + It chooses openai provider client depending on the config """ def __init__(self, params: dict = None, session=None): self._session = session - self.params = params + self.params = params.copy() - self.provider = params.get("provider", "openai") + self.provider = self.params.pop("provider", "openai") + self.model_name = self.params.pop("model_name") + if self.provider == "google": + self.provider = "gemini" if "api_key" not in params: api_key = get_api_key(self.provider, params, strict=False) if api_key is not None: params["api_key"] = api_key - self.engine = "openai" - if self.provider == "azure_openai": azure_api_key = params.get("api_key") or os.getenv("AZURE_OPENAI_API_KEY") azure_api_endpoint = params.get("base_url") or os.environ.get("AZURE_OPENAI_ENDPOINT") @@ -97,56 +102,41 @@ def __init__(self, params: dict = None, session=None): if kwargs.get("api_key") is None: kwargs["api_key"] = "n/a" self.client = OpenAI(**kwargs) + elif self.provider == "bedrock": + self.client = BedrockClient(**self.params) + elif self.provider == "gemini": + self.client = GeminiClient(**self.params) + elif self.provider == "snowflake": + self.client = SnowflakeClient(**self.params) else: - # try to use litellm - if self._session is None: - from mindsdb.api.executor.controllers.session_controller import SessionController - - self._session = SessionController() - module = self._session.integration_controller.get_handler_module("litellm") - - if module is None or module.Handler is None: - raise ValueError(f'Unable to use "{self.provider}" provider. Litellm handler is not installed') - - self.client = module.Handler - self.engine = "litellm" + raise NotImplementedError(f'Provider "{self.provider}" is not supported') @run_in_batches(1000) @retry_with_exponential_backoff def embeddings(self, messages: List[str]): - params = self.params - if self.engine == "openai": + if self.provider in ("openai", "azure_openai"): response = self.client.embeddings.create( - model=params["model_name"], + model=self.model_name, input=messages, ) return [item.embedding for item in response.data] else: - kwargs = params.copy() - model = kwargs.pop("model_name") - kwargs.pop("provider", None) - - return self.client.embeddings(self.provider, model=model, messages=messages, args=kwargs) + return self.client.embeddings(self.model_name, messages) @run_in_batches(100) def completion(self, messages: List[dict], json_output: bool = False) -> List[str]: """ Call LLM completion and get response """ - params = self.params - params["json_output"] = json_output - if self.engine == "openai": + + if self.provider in ("openai", "azure_openai"): response = self.client.chat.completions.create( - model=params["model_name"], + model=self.model_name, messages=messages, ) return [item.message.content for item in response.choices] else: - kwargs = params.copy() - model = kwargs.pop("model_name") - kwargs.pop("provider", None) - response = self.client.completion(self.provider, model=model, messages=messages, args=kwargs) - return [item.message.content for item in response.choices] + return self.client.completion(self.model_name, messages) async def abatch(self, messages_list: List[List[dict]], json_output: bool = False) -> List[List[str]]: """ diff --git a/mindsdb/interfaces/knowledge_base/providers/bedrock.py b/mindsdb/interfaces/knowledge_base/providers/bedrock.py new file mode 100644 index 00000000000..26a134b916e --- /dev/null +++ b/mindsdb/interfaces/knowledge_base/providers/bedrock.py @@ -0,0 +1,137 @@ +import json +from typing import List, Optional + + +def prepare_conversation(messages): + conversation = [] + for message in messages: + content = message["content"] + role = message["role"] + if role == "system": + role = "assistant" + if role != "user": + if len(conversation) == 0: + # the first message has to be user message + content = message["role"] + ":\n" + content + role = "user" + + conversation.append( + { + "role": role, + "content": [{"text": content}], + } + ) + return conversation + + +class AsyncBedrockClient: + def __init__( + self, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + region_name: Optional[str] = None, + aws_session_token: Optional[str] = None, + ): + try: + import aioboto3 # type: ignore + except ImportError as exc: # pragma: no cover - environment specific + raise ImportError( + "aioboto3 is required for the Bedrock reranker client. Install it with `pip install aioboto3`." + ) from exc + + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.region_name = region_name + + self.session = aioboto3.Session() + self._client = None + + async def acompletion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ): + inferenceConfig = {} + if temperature: + inferenceConfig["temperature"] = temperature + if max_tokens: + inferenceConfig["max_tokens"] = max_tokens + if top_p: + inferenceConfig["top_p"] = top_p + + # convert messages + conversation = prepare_conversation(messages) + + async with self.session.client( + "bedrock-runtime", + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, + region_name=self.region_name, + ) as client: + response = await client.converse(modelId=model_name, messages=conversation, inferenceConfig=inferenceConfig) + + return response["output"]["message"]["content"][0]["text"] + + +class BedrockClient: + def __init__( + self, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + region_name: Optional[str] = None, + aws_session_token: Optional[str] = None, + ): + try: + import boto3 # type: ignore + except ImportError as exc: # pragma: no cover - environment specific + raise ImportError("boto3 is required for the Bedrock client. Install it with `pip install boto3`.") from exc + + self.client = boto3.client( + "bedrock-runtime", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=region_name, + ) + + def embeddings(self, model_name: str, messages: List[str]): + embeddings = [] + for message in messages: + native_request = {"inputText": message} + request = json.dumps(native_request) + + response = self.client.invoke_model(modelId=model_name, body=request) + model_response = json.loads(response["body"].read()) + + # Extract and print the generated embedding and the input text token count. + embeddings.append(model_response["embedding"]) + + return embeddings + + def completion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ): + inferenceConfig = {} + if temperature: + inferenceConfig["temperature"] = temperature + if max_tokens: + inferenceConfig["max_tokens"] = max_tokens + if top_p: + inferenceConfig["top_p"] = top_p + + # convert messages + conversation = prepare_conversation(messages) + + response = self.client.converse(modelId=model_name, messages=conversation, inferenceConfig=inferenceConfig) + + return response["output"]["message"]["content"][0]["text"] diff --git a/mindsdb/interfaces/knowledge_base/providers/gemini.py b/mindsdb/interfaces/knowledge_base/providers/gemini.py new file mode 100644 index 00000000000..8b37c3c819a --- /dev/null +++ b/mindsdb/interfaces/knowledge_base/providers/gemini.py @@ -0,0 +1,73 @@ +from typing import List, Optional + + +class GeminiClient: + def __init__(self, api_key: str): + try: + from google import genai + from google.genai import types + except ImportError as exc: # pragma: no cover - environment specific + raise ImportError("google.genai is required. Install it with `pip install google-genai`.") from exc + + self.client = genai.Client(api_key=api_key) + self.types = types + + def embeddings(self, model_name: str, messages: List[str]): + result = self.client.models.embed_content(model=model_name, contents=messages) + + return [item.values for item in result.embeddings] + + def _prepare_messages(self, messages): + contents = [] + for message in messages: + role = message["role"] + # system role is not supported + if role != "user": + role = "model" + + contents.append(self.types.Content(role=role, parts=[self.types.Part(text=message["content"])])) + return contents + + def completion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ): + config = {} + if temperature: + config["temperature"] = temperature + if max_tokens: + config["max_output_tokens"] = max_tokens + if top_p: + config["top_p"] = top_p + + contents = self._prepare_messages(messages) + + result = self.client.models.generate_content(model=model_name, contents=contents, config=config) + + return result.text + + async def acompletion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ): + config = {} + if temperature: + config["temperature"] = temperature + if max_tokens: + config["max_output_tokens"] = max_tokens + if top_p: + config["top_p"] = top_p + + contents = self._prepare_messages(messages) + + result = await self.client.aio.models.generate_content(model=model_name, contents=contents, config=config) + + return result.text diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 55c81279c0f..c10c1300f86 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -51,3 +51,8 @@ pydantic-ai>=0.0.14 # Required for Pydantic AI agents bs4 # for rag HTMLDocumentLoader urllib3>=2.6.3 # not directly required, pinned by Snyk to avoid a vulnerability + +# kb providers +aioboto3==15.5.0 +types-aioboto3[bedrock-runtime] +google-genai==1.70.0 From 6a2da23be62dbc7d9a4a3021ac6aad0ed7e09bce Mon Sep 17 00:00:00 2001 From: andrew Date: Thu, 2 Apr 2026 18:46:01 +0300 Subject: [PATCH 07/22] ollama fixes --- mindsdb/interfaces/functions/controller.py | 1 + mindsdb/interfaces/knowledge_base/llm_client.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mindsdb/interfaces/functions/controller.py b/mindsdb/interfaces/functions/controller.py index 86abe4022a2..7f63fa8b2de 100644 --- a/mindsdb/interfaces/functions/controller.py +++ b/mindsdb/interfaces/functions/controller.py @@ -139,6 +139,7 @@ def llm_call_function(self, node): try: from mindsdb.interfaces.knowledge_base.llm_client import LLMClient + chat_model_params.pop("api_keys", None) llm = LLMClient(chat_model_params, session=self.session) except Exception as e: raise RuntimeError(f"Unable to use LLM function, check ENV variables: {e}") from e diff --git a/mindsdb/interfaces/knowledge_base/llm_client.py b/mindsdb/interfaces/knowledge_base/llm_client.py index ab044811b94..365e84c10e5 100644 --- a/mindsdb/interfaces/knowledge_base/llm_client.py +++ b/mindsdb/interfaces/knowledge_base/llm_client.py @@ -134,7 +134,6 @@ def completion(self, messages: List[dict], json_output: bool = False) -> List[st Call LLM completion and get response """ params = self.params - params["json_output"] = json_output if self.engine == "openai": response = self.client.chat.completions.create( model=params["model_name"], @@ -143,6 +142,7 @@ def completion(self, messages: List[dict], json_output: bool = False) -> List[st return [item.message.content for item in response.choices] else: kwargs = params.copy() + params["json_output"] = json_output model = kwargs.pop("model_name") kwargs.pop("provider", None) response = self.client.completion(self.provider, model=model, messages=messages, args=kwargs) From 9017cae3aaddeca228e355475083058be71bd6ea Mon Sep 17 00:00:00 2001 From: andrew Date: Thu, 2 Apr 2026 18:52:46 +0300 Subject: [PATCH 08/22] fix bedrock --- mindsdb/interfaces/knowledge_base/llm_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mindsdb/interfaces/knowledge_base/llm_client.py b/mindsdb/interfaces/knowledge_base/llm_client.py index 4dd30261d13..aa292ff1671 100644 --- a/mindsdb/interfaces/knowledge_base/llm_client.py +++ b/mindsdb/interfaces/knowledge_base/llm_client.py @@ -103,6 +103,8 @@ def __init__(self, params: dict = None, session=None): kwargs["api_key"] = "n/a" self.client = OpenAI(**kwargs) elif self.provider == "bedrock": + if 'aws_region' in self.params: + self.params['region_name'] = self.params.pop('aws_region') self.client = BedrockClient(**self.params) elif self.provider == "gemini": self.client = GeminiClient(**self.params) From 51c95c9bf9c2efe2da055596ac42a82495163118 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 3 Apr 2026 13:49:00 +0300 Subject: [PATCH 09/22] dockstrings --- .../knowledge_base/providers/bedrock.py | 42 +++++++++++-------- .../knowledge_base/providers/gemini.py | 16 ++++--- .../knowledge_base/providers/snowflake.py | 23 ++++++---- 3 files changed, 50 insertions(+), 31 deletions(-) diff --git a/mindsdb/interfaces/knowledge_base/providers/bedrock.py b/mindsdb/interfaces/knowledge_base/providers/bedrock.py index 26a134b916e..81ddb6800a0 100644 --- a/mindsdb/interfaces/knowledge_base/providers/bedrock.py +++ b/mindsdb/interfaces/knowledge_base/providers/bedrock.py @@ -1,8 +1,9 @@ import json -from typing import List, Optional +from typing import Dict, List, Optional -def prepare_conversation(messages): +def prepare_conversation(messages: List[dict]) -> List[dict]: + """Convert chat messages to Bedrock `converse` message payload format.""" conversation = [] for message in messages: content = message["content"] @@ -25,6 +26,8 @@ def prepare_conversation(messages): class AsyncBedrockClient: + """Async Bedrock runtime client wrapper""" + def __init__( self, aws_access_key_id: Optional[str] = None, @@ -54,16 +57,16 @@ async def acompletion( temperature: Optional[float] = None, max_tokens: Optional[int] = None, top_p: Optional[float] = None, - ): - inferenceConfig = {} + ) -> str: + """Generate a chat completion asynchronously via Bedrock.""" + inference_config = {} if temperature: - inferenceConfig["temperature"] = temperature + inference_config["temperature"] = temperature if max_tokens: - inferenceConfig["max_tokens"] = max_tokens + inference_config["max_tokens"] = max_tokens if top_p: - inferenceConfig["top_p"] = top_p + inference_config["top_p"] = top_p - # convert messages conversation = prepare_conversation(messages) async with self.session.client( @@ -73,12 +76,16 @@ async def acompletion( aws_session_token=self.aws_session_token, region_name=self.region_name, ) as client: - response = await client.converse(modelId=model_name, messages=conversation, inferenceConfig=inferenceConfig) + response = await client.converse( + modelId=model_name, messages=conversation, inferenceConfig=inference_config + ) return response["output"]["message"]["content"][0]["text"] class BedrockClient: + """Synchronous Bedrock runtime client wrapper""" + def __init__( self, aws_access_key_id: Optional[str] = None, @@ -99,7 +106,8 @@ def __init__( region_name=region_name, ) - def embeddings(self, model_name: str, messages: List[str]): + def embeddings(self, model_name: str, messages: List[str]) -> List[List[float]]: + """Request embedding vectors for each text in `messages`.""" embeddings = [] for message in messages: native_request = {"inputText": message} @@ -120,18 +128,18 @@ def completion( temperature: Optional[float] = None, max_tokens: Optional[int] = None, top_p: Optional[float] = None, - ): - inferenceConfig = {} + ) -> str: + """Generate a chat completion synchronously via Bedrock.""" + inference_config: Dict[str, float | int] = {} if temperature: - inferenceConfig["temperature"] = temperature + inference_config["temperature"] = temperature if max_tokens: - inferenceConfig["max_tokens"] = max_tokens + inference_config["max_tokens"] = max_tokens if top_p: - inferenceConfig["top_p"] = top_p + inference_config["top_p"] = top_p - # convert messages conversation = prepare_conversation(messages) - response = self.client.converse(modelId=model_name, messages=conversation, inferenceConfig=inferenceConfig) + response = self.client.converse(modelId=model_name, messages=conversation, inferenceConfig=inference_config) return response["output"]["message"]["content"][0]["text"] diff --git a/mindsdb/interfaces/knowledge_base/providers/gemini.py b/mindsdb/interfaces/knowledge_base/providers/gemini.py index 8b37c3c819a..18a055a7773 100644 --- a/mindsdb/interfaces/knowledge_base/providers/gemini.py +++ b/mindsdb/interfaces/knowledge_base/providers/gemini.py @@ -1,7 +1,9 @@ -from typing import List, Optional +from typing import Any, List, Optional class GeminiClient: + """Wrapper around google-genai SDK""" + def __init__(self, api_key: str): try: from google import genai @@ -12,12 +14,14 @@ def __init__(self, api_key: str): self.client = genai.Client(api_key=api_key) self.types = types - def embeddings(self, model_name: str, messages: List[str]): + def embeddings(self, model_name: str, messages: List[str]) -> List[List[float]]: + """Generate embedding vectors for each text in `messages`.""" result = self.client.models.embed_content(model=model_name, contents=messages) return [item.values for item in result.embeddings] - def _prepare_messages(self, messages): + def _prepare_messages(self, messages: List[dict]) -> List[Any]: + """Convert chat messages into google-genai content payloads.""" contents = [] for message in messages: role = message["role"] @@ -35,7 +39,8 @@ def completion( temperature: Optional[float] = None, max_tokens: Optional[int] = None, top_p: Optional[float] = None, - ): + ) -> str: + """Produce a chat response""" config = {} if temperature: config["temperature"] = temperature @@ -57,7 +62,8 @@ async def acompletion( temperature: Optional[float] = None, max_tokens: Optional[int] = None, top_p: Optional[float] = None, - ): + ) -> str: + """Async variant of `completion` using the SDK aio client.""" config = {} if temperature: config["temperature"] = temperature diff --git a/mindsdb/interfaces/knowledge_base/providers/snowflake.py b/mindsdb/interfaces/knowledge_base/providers/snowflake.py index 2a1f2a03990..7bcc3fc6382 100644 --- a/mindsdb/interfaces/knowledge_base/providers/snowflake.py +++ b/mindsdb/interfaces/knowledge_base/providers/snowflake.py @@ -1,11 +1,11 @@ -from typing import List, Optional +from typing import Dict, List, Optional, Union import requests import httpx -def _raise_for_status(response): - # show response text in error +def _raise_for_status(response: Union[requests.Response, httpx.Response]) -> None: + """Raise an informative HTTPError when Snowflake responds with an error.""" if 400 <= response.status_code < 600: if hasattr(response, "reason"): reason = response.reason @@ -17,7 +17,9 @@ def _raise_for_status(response): class SnowflakeClient: - def __init__(self, account_id: str = None, api_key: str = None): + """Wrapper over Snowflake Cortex REST endpoints.""" + + def __init__(self, account_id: Optional[str] = None, api_key: Optional[str] = None): if account_id is None: raise ValueError("account_id must be provided") if api_key is None: @@ -31,10 +33,10 @@ def __init__(self, account_id: str = None, api_key: str = None): self.api_key = self.api_key[4:] self.auth_type = "PROGRAMMATIC_ACCESS_TOKEN" - def _get_base_url(self): + def _get_base_url(self) -> str: return f"https://{self.account_id}.snowflakecomputing.com/api/v2" - def _get_headers(self): + def _get_headers(self) -> Dict[str, str]: return { "Content-Type": "application/json", "Accept": "application/json", @@ -42,7 +44,8 @@ def _get_headers(self): "X-Snowflake-Authorization-Token-Type": self.auth_type, } - def embeddings(self, model_name: str, messages: List[str]): + def embeddings(self, model_name: str, messages: List[str]) -> List[List[float]]: + """Request embedding vectors for the provided `messages`.""" url = f"{self._get_base_url()}/cortex/inference:embed" payload = {"text": messages, "model": model_name} @@ -62,7 +65,8 @@ def completion( temperature: Optional[float] = None, max_tokens: Optional[int] = None, top_p: Optional[float] = None, - ): + ) -> str: + """Generate a chat completion with the Cortex complete endpoint.""" url = f"{self._get_base_url()}/cortex/inference:complete" payload = { @@ -90,7 +94,8 @@ async def acompletion( temperature: Optional[float] = None, max_tokens: Optional[int] = None, top_p: Optional[float] = None, - ): + ) -> str: + """Async variant of `completion` using httpx.""" url = f"{self._get_base_url()}/cortex/inference:complete" payload = { From 7d724172c06729b4a0527976630712be1ac42e28 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 3 Apr 2026 14:11:32 +0300 Subject: [PATCH 10/22] unit tests --- .../interfaces/knowledge_base/llm_client.py | 4 +- .../knowledge_base/providers/bedrock.py | 8 ++-- tests/unit/executor/test_knowledge_base.py | 41 +++++++++++++++++++ 3 files changed, 47 insertions(+), 6 deletions(-) diff --git a/mindsdb/interfaces/knowledge_base/llm_client.py b/mindsdb/interfaces/knowledge_base/llm_client.py index aa292ff1671..475e56cd6da 100644 --- a/mindsdb/interfaces/knowledge_base/llm_client.py +++ b/mindsdb/interfaces/knowledge_base/llm_client.py @@ -103,8 +103,8 @@ def __init__(self, params: dict = None, session=None): kwargs["api_key"] = "n/a" self.client = OpenAI(**kwargs) elif self.provider == "bedrock": - if 'aws_region' in self.params: - self.params['region_name'] = self.params.pop('aws_region') + if "aws_region" in self.params: + self.params["aws_region_name"] = self.params.pop("aws_region") self.client = BedrockClient(**self.params) elif self.provider == "gemini": self.client = GeminiClient(**self.params) diff --git a/mindsdb/interfaces/knowledge_base/providers/bedrock.py b/mindsdb/interfaces/knowledge_base/providers/bedrock.py index 81ddb6800a0..a8ce768775d 100644 --- a/mindsdb/interfaces/knowledge_base/providers/bedrock.py +++ b/mindsdb/interfaces/knowledge_base/providers/bedrock.py @@ -32,7 +32,7 @@ def __init__( self, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, - region_name: Optional[str] = None, + aws_region_name: Optional[str] = None, aws_session_token: Optional[str] = None, ): try: @@ -45,7 +45,7 @@ def __init__( self.aws_access_key_id = aws_access_key_id self.aws_secret_access_key = aws_secret_access_key self.aws_session_token = aws_session_token - self.region_name = region_name + self.region_name = aws_region_name self.session = aioboto3.Session() self._client = None @@ -90,7 +90,7 @@ def __init__( self, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, - region_name: Optional[str] = None, + aws_region_name: Optional[str] = None, aws_session_token: Optional[str] = None, ): try: @@ -103,7 +103,7 @@ def __init__( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, - region_name=region_name, + region_name=aws_region_name, ) def embeddings(self, model_name: str, messages: List[str]) -> List[List[float]]: diff --git a/tests/unit/executor/test_knowledge_base.py b/tests/unit/executor/test_knowledge_base.py index 991166e45ab..ff7d1be72e6 100644 --- a/tests/unit/executor/test_knowledge_base.py +++ b/tests/unit/executor/test_knowledge_base.py @@ -1352,6 +1352,47 @@ def test_create_index(self, mock_embedding): ret = self.run_sql("select * from kb_ral where content='white'") assert "white" in ret["chunk_content"].iloc[0] + def test_providers(self): + with patch("mindsdb.interfaces.knowledge_base.llm_client.BedrockClient.embeddings") as embed: + with patch( + "mindsdb.integrations.utilities.rag.rerankers.base_reranker.AsyncBedrockClient.acompletion" + ) as rerank: + embed.return_value = [[1, 1, 1]] + rerank.return_value = "100" + self._create_kb( + "kb_test", + embedding_model={ + "provider": "bedrock", + "model_name": "amazon.titan", + "aws_access_key_id": "-", + "aws_region_name": "us-east-2", + "aws_secret_access_key": "-", + }, + reranking_model={ + "provider": "bedrock", + "model_name": "llama3", + "aws_access_key_id": "-", + "aws_region_name": "us-east-2", + "aws_secret_access_key": "-", + }, + ) + assert embed.call_args_list[0][0][0] == "amazon.titan" + assert rerank.call_args_list[0][1]["model_name"] == "llama3" + + with patch("mindsdb.interfaces.knowledge_base.llm_client.SnowflakeClient.embeddings") as embed: + embed.return_value = [[1, 1, 1]] + self._create_kb( + "kb_test", + embedding_model={"provider": "snowflake", "model_name": "arctic", "account_id": "ABC", "api_key": "-"}, + ) + assert embed.call_args_list[0][0][0] == "arctic" + with patch("mindsdb.interfaces.knowledge_base.llm_client.GeminiClient.embeddings") as embed: + embed.return_value = [[1, 1, 1]] + self._create_kb( + "kb_test", embedding_model={"provider": "gemini", "model_name": "gemini-embedding", "api_key": "-"} + ) + assert embed.call_args_list[0][0][0] == "gemini-embedding" + class TestKBAutoBatch(BaseTestKB): @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") From 3823774cee191d4b68c8d7f7f8696d211329b692 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 3 Apr 2026 14:29:10 +0300 Subject: [PATCH 11/22] check reqs --- tests/scripts/check_requirements.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/scripts/check_requirements.py b/tests/scripts/check_requirements.py index 8aa0f0f8bf8..24314fe868e 100644 --- a/tests/scripts/check_requirements.py +++ b/tests/scripts/check_requirements.py @@ -110,6 +110,8 @@ def get_requirements_with_DEP002(path): "numba", # required in a few files for the hierarchicalforecast. Otherwise, uv may install an old version. "urllib3", # pinned by Snyk to avoid a vulnerability "faiss-cpu", + "types-aioboto3", # aioboto3 is imported + "google-genai", # google.genai is imported ], } From bef4fce186897302107532d99af50fbf5089a710 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 3 Apr 2026 14:31:48 +0300 Subject: [PATCH 12/22] check reqs --- tests/scripts/check_requirements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/scripts/check_requirements.py b/tests/scripts/check_requirements.py index 24314fe868e..5c1e8b2a429 100644 --- a/tests/scripts/check_requirements.py +++ b/tests/scripts/check_requirements.py @@ -110,8 +110,6 @@ def get_requirements_with_DEP002(path): "numba", # required in a few files for the hierarchicalforecast. Otherwise, uv may install an old version. "urllib3", # pinned by Snyk to avoid a vulnerability "faiss-cpu", - "types-aioboto3", # aioboto3 is imported - "google-genai", # google.genai is imported ], } @@ -203,6 +201,7 @@ def get_requirements_with_DEP002(path): "google-analytics-admin": ["google"], "google-auth": ["google"], "google-cloud-storage": ["google"], + "google-genai": ["google"], "google-auth-oauthlib": ["google_auth_oauthlib"], "google-api-python-client": ["googleapiclient"], "ibm-cos-sdk": ["ibm_boto3", "ibm_botocore"], @@ -260,6 +259,7 @@ def get_requirements_with_DEP002(path): "python-dotenv": ["dotenv"], "pyjwt": ["jwt"], "sklearn": ["scikit-learn"], + "types-aioboto3": ["aioboto3"], } # We use this to exit with a non-zero status code if any check fails From 6f93970985a8c8eaa0a434b4fda3612605859bf3 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 3 Apr 2026 14:52:46 +0300 Subject: [PATCH 13/22] fix listwise rerank --- tests/unit/executor/test_knowledge_base.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/tests/unit/executor/test_knowledge_base.py b/tests/unit/executor/test_knowledge_base.py index ff7d1be72e6..7646caca9ae 100644 --- a/tests/unit/executor/test_knowledge_base.py +++ b/tests/unit/executor/test_knowledge_base.py @@ -325,13 +325,8 @@ class _Choice: def __init__(self, content): self.message = _Msg(content) - class _Resp: - def __init__(self, content): - self.choices = [_Choice(content)] - async def _fake_call_llm(messages): - content = '{"ranking": [{"doc_index": 2, "score": 0.9}, {"doc_index": 1, "score": 0.6}, {"doc_index": 3, "score": 0.1}]}' - return _Resp(content) + return '{"ranking": [{"doc_index": 2, "score": 0.9}, {"doc_index": 1, "score": 0.6}, {"doc_index": 3, "score": 0.1}]}' # Bind the async method to this reranker instance reranker._call_llm = _fake_call_llm # type: ignore @@ -356,16 +351,11 @@ class _Choice: def __init__(self, content): self.message = _Msg(content) - class _Resp: - def __init__(self, content): - self.choices = [_Choice(content)] - async def _fake_call_llm(messages): # Returns code-fenced JSON, includes only two entries, one without score - content = """```json + return """```json {"ranking": [1, {"doc_index": 3, "score": 0.8}]} ```""" - return _Resp(content) reranker._call_llm = _fake_call_llm # type: ignore @@ -389,14 +379,9 @@ class _Choice: def __init__(self, content): self.message = _Msg(content) - class _Resp: - def __init__(self, content): - self.choices = [_Choice(content)] - async def _fake_call_llm(messages): # Invalid JSON forces fallback - content = "not-json" - return _Resp(content) + return "not-json" reranker._call_llm = _fake_call_llm # type: ignore From bcfcf71c58b88e29c20a1102c9353f81d8046872 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 6 Apr 2026 14:50:59 +0300 Subject: [PATCH 14/22] remove pragma --- mindsdb/interfaces/knowledge_base/providers/bedrock.py | 8 ++++---- mindsdb/interfaces/knowledge_base/providers/gemini.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mindsdb/interfaces/knowledge_base/providers/bedrock.py b/mindsdb/interfaces/knowledge_base/providers/bedrock.py index a8ce768775d..be6c8fe2fb8 100644 --- a/mindsdb/interfaces/knowledge_base/providers/bedrock.py +++ b/mindsdb/interfaces/knowledge_base/providers/bedrock.py @@ -36,8 +36,8 @@ def __init__( aws_session_token: Optional[str] = None, ): try: - import aioboto3 # type: ignore - except ImportError as exc: # pragma: no cover - environment specific + import aioboto3 + except ImportError as exc: raise ImportError( "aioboto3 is required for the Bedrock reranker client. Install it with `pip install aioboto3`." ) from exc @@ -94,8 +94,8 @@ def __init__( aws_session_token: Optional[str] = None, ): try: - import boto3 # type: ignore - except ImportError as exc: # pragma: no cover - environment specific + import boto3 + except ImportError as exc: raise ImportError("boto3 is required for the Bedrock client. Install it with `pip install boto3`.") from exc self.client = boto3.client( diff --git a/mindsdb/interfaces/knowledge_base/providers/gemini.py b/mindsdb/interfaces/knowledge_base/providers/gemini.py index 18a055a7773..33e2bc314d0 100644 --- a/mindsdb/interfaces/knowledge_base/providers/gemini.py +++ b/mindsdb/interfaces/knowledge_base/providers/gemini.py @@ -8,7 +8,7 @@ def __init__(self, api_key: str): try: from google import genai from google.genai import types - except ImportError as exc: # pragma: no cover - environment specific + except ImportError as exc: raise ImportError("google.genai is required. Install it with `pip install google-genai`.") from exc self.client = genai.Client(api_key=api_key) From 90f7342b42553c936f46e725795153fcd282c089 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 6 Apr 2026 15:14:30 +0300 Subject: [PATCH 15/22] fix --- mindsdb/interfaces/knowledge_base/llm_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb/interfaces/knowledge_base/llm_client.py b/mindsdb/interfaces/knowledge_base/llm_client.py index 475e56cd6da..6caf350f5e7 100644 --- a/mindsdb/interfaces/knowledge_base/llm_client.py +++ b/mindsdb/interfaces/knowledge_base/llm_client.py @@ -64,7 +64,7 @@ def wrapper(self, messages, *args, **kwargs): class LLMClient: """ Class for accession to LLM. - It chooses openai provider client depending on the config + It chooses provider client depending on the config """ def __init__(self, params: dict = None, session=None): From 07e38126879e11c995f77a678a6c9fd4b361abe4 Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 8 Apr 2026 19:27:45 +0300 Subject: [PATCH 16/22] fix --- mindsdb/interfaces/knowledge_base/llm_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb/interfaces/knowledge_base/llm_client.py b/mindsdb/interfaces/knowledge_base/llm_client.py index 6caf350f5e7..e5f33726f76 100644 --- a/mindsdb/interfaces/knowledge_base/llm_client.py +++ b/mindsdb/interfaces/knowledge_base/llm_client.py @@ -138,7 +138,7 @@ def completion(self, messages: List[dict], json_output: bool = False) -> List[st ) return [item.message.content for item in response.choices] else: - return self.client.completion(self.model_name, messages) + return [self.client.completion(self.model_name, messages)] async def abatch(self, messages_list: List[List[dict]], json_output: bool = False) -> List[List[str]]: """ From 5808da3d67560b5445b9d31ebaf5ada6cac65bed Mon Sep 17 00:00:00 2001 From: andrew Date: Thu, 9 Apr 2026 20:10:04 +0300 Subject: [PATCH 17/22] replace dependency aioboto3 -> aiobotocore --- .../knowledge_base/providers/bedrock.py | 36 ++++++++++--------- requirements/requirements.txt | 3 +- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/mindsdb/interfaces/knowledge_base/providers/bedrock.py b/mindsdb/interfaces/knowledge_base/providers/bedrock.py index be6c8fe2fb8..2f50c5c4c87 100644 --- a/mindsdb/interfaces/knowledge_base/providers/bedrock.py +++ b/mindsdb/interfaces/knowledge_base/providers/bedrock.py @@ -36,10 +36,10 @@ def __init__( aws_session_token: Optional[str] = None, ): try: - import aioboto3 + from aiobotocore.session import get_session except ImportError as exc: raise ImportError( - "aioboto3 is required for the Bedrock reranker client. Install it with `pip install aioboto3`." + "aiobotocore is required for the Bedrock reranker client. Install it with `pip install aiobotocore`." ) from exc self.aws_access_key_id = aws_access_key_id @@ -47,8 +47,7 @@ def __init__( self.aws_session_token = aws_session_token self.region_name = aws_region_name - self.session = aioboto3.Session() - self._client = None + self._session = get_session() async def acompletion( self, @@ -60,22 +59,25 @@ async def acompletion( ) -> str: """Generate a chat completion asynchronously via Bedrock.""" inference_config = {} - if temperature: + if temperature is not None: inference_config["temperature"] = temperature - if max_tokens: + if max_tokens is not None: inference_config["max_tokens"] = max_tokens - if top_p: + if top_p is not None: inference_config["top_p"] = top_p conversation = prepare_conversation(messages) - async with self.session.client( - "bedrock-runtime", - aws_access_key_id=self.aws_access_key_id, - aws_secret_access_key=self.aws_secret_access_key, - aws_session_token=self.aws_session_token, - region_name=self.region_name, - ) as client: + # Create client with credentials + client_kwargs = { + "service_name": "bedrock-runtime", + "region_name": self.region_name, + "aws_access_key_id": self.aws_access_key_id, + "aws_secret_access_key": self.aws_secret_access_key, + "aws_session_token": self.aws_session_token, + } + + async with self._session.create_client(**client_kwargs) as client: response = await client.converse( modelId=model_name, messages=conversation, inferenceConfig=inference_config ) @@ -131,11 +133,11 @@ def completion( ) -> str: """Generate a chat completion synchronously via Bedrock.""" inference_config: Dict[str, float | int] = {} - if temperature: + if temperature is not None: inference_config["temperature"] = temperature - if max_tokens: + if max_tokens is not None: inference_config["max_tokens"] = max_tokens - if top_p: + if top_p is not None: inference_config["top_p"] = top_p conversation = prepare_conversation(messages) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 2c4292f2740..ab5217a0416 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -53,6 +53,5 @@ bs4 # for rag HTMLDocumentLoader urllib3>=2.6.3 # not directly required, pinned by Snyk to avoid a vulnerability # kb providers -aioboto3==15.5.0 -types-aioboto3[bedrock-runtime] +aiobotocore==3.4.0 google-genai==1.70.0 From 6290b1ed7cc3efb5e5247f4af800118d374449cd Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 13 Apr 2026 13:04:42 +0300 Subject: [PATCH 18/22] don't save params --- .../interfaces/knowledge_base/llm_client.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/mindsdb/interfaces/knowledge_base/llm_client.py b/mindsdb/interfaces/knowledge_base/llm_client.py index e5f33726f76..d043800c6fd 100644 --- a/mindsdb/interfaces/knowledge_base/llm_client.py +++ b/mindsdb/interfaces/knowledge_base/llm_client.py @@ -69,10 +69,10 @@ class LLMClient: def __init__(self, params: dict = None, session=None): self._session = session - self.params = params.copy() + params = params.copy() - self.provider = self.params.pop("provider", "openai") - self.model_name = self.params.pop("model_name") + self.provider = params.pop("provider", "openai") + self.model_name = params.pop("model_name") if self.provider == "google": self.provider = "gemini" @@ -96,20 +96,17 @@ def __init__(self, params: dict = None, session=None): kwargs["base_url"] = base_url self.client = OpenAI(**kwargs) elif self.provider == "ollama": - kwargs = params.copy() - kwargs.pop("model_name") - kwargs.pop("provider", None) - if kwargs.get("api_key") is None: - kwargs["api_key"] = "n/a" - self.client = OpenAI(**kwargs) + if params.get("api_key") is None: + params["api_key"] = "n/a" + self.client = OpenAI(**params) elif self.provider == "bedrock": - if "aws_region" in self.params: - self.params["aws_region_name"] = self.params.pop("aws_region") - self.client = BedrockClient(**self.params) + if "aws_region" in params: + params["aws_region_name"] = params.pop("aws_region") + self.client = BedrockClient(**params) elif self.provider == "gemini": - self.client = GeminiClient(**self.params) + self.client = GeminiClient(**params) elif self.provider == "snowflake": - self.client = SnowflakeClient(**self.params) + self.client = SnowflakeClient(**params) else: raise NotImplementedError(f'Provider "{self.provider}" is not supported') From ca820050933f1d2dc503ba13365a64811e0c3535 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 13 Apr 2026 13:14:39 +0300 Subject: [PATCH 19/22] fix ollama --- mindsdb/interfaces/knowledge_base/llm_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindsdb/interfaces/knowledge_base/llm_client.py b/mindsdb/interfaces/knowledge_base/llm_client.py index d043800c6fd..61591b62c17 100644 --- a/mindsdb/interfaces/knowledge_base/llm_client.py +++ b/mindsdb/interfaces/knowledge_base/llm_client.py @@ -113,7 +113,7 @@ def __init__(self, params: dict = None, session=None): @run_in_batches(1000) @retry_with_exponential_backoff def embeddings(self, messages: List[str]): - if self.provider in ("openai", "azure_openai"): + if self.provider in ("openai", "azure_openai", "ollama"): response = self.client.embeddings.create( model=self.model_name, input=messages, @@ -128,7 +128,7 @@ def completion(self, messages: List[dict], json_output: bool = False) -> List[st Call LLM completion and get response """ - if self.provider in ("openai", "azure_openai"): + if self.provider in ("openai", "azure_openai", "ollama"): response = self.client.chat.completions.create( model=self.model_name, messages=messages, From 800f0eda376cd66e31c8700e9e8079f1bccbfa9e Mon Sep 17 00:00:00 2001 From: andrew Date: Tue, 14 Apr 2026 14:27:18 +0300 Subject: [PATCH 20/22] fix init --- mindsdb/interfaces/knowledge_base/providers/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 mindsdb/interfaces/knowledge_base/providers/__init__.py diff --git a/mindsdb/interfaces/knowledge_base/providers/__init__.py b/mindsdb/interfaces/knowledge_base/providers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d From e3ca04a0f4a8d8a1a1af4b0546ff331c943324e7 Mon Sep 17 00:00:00 2001 From: andrew Date: Tue, 14 Apr 2026 14:28:01 +0300 Subject: [PATCH 21/22] ruff --- mindsdb/interfaces/knowledge_base/providers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mindsdb/interfaces/knowledge_base/providers/__init__.py b/mindsdb/interfaces/knowledge_base/providers/__init__.py index e69de29bb2d..8b137891791 100644 --- a/mindsdb/interfaces/knowledge_base/providers/__init__.py +++ b/mindsdb/interfaces/knowledge_base/providers/__init__.py @@ -0,0 +1 @@ + From 5f276ad8e794314ad9d5368b4f3c6fcb751bc3f5 Mon Sep 17 00:00:00 2001 From: andrew Date: Tue, 14 Apr 2026 17:27:11 +0300 Subject: [PATCH 22/22] fix to keep error message --- mindsdb/interfaces/knowledge_base/controller.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindsdb/interfaces/knowledge_base/controller.py b/mindsdb/interfaces/knowledge_base/controller.py index 0ba95236119..dd7d7397ca8 100644 --- a/mindsdb/interfaces/knowledge_base/controller.py +++ b/mindsdb/interfaces/knowledge_base/controller.py @@ -1319,9 +1319,9 @@ def _check_embedding_model(self, project_name, params: dict = None, kb_name="") f"Wrong embedding provider: {params['provider']}. Available providers: {', '.join(avail_providers)}" ) - llm_client = LLMClient(params, session=self.session) - try: + llm_client = LLMClient(params, session=self.session) + resp = llm_client.embeddings(["test"]) return {"dimension": len(resp[0])} except Exception as e: