diff --git a/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py b/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py index c88385f7e2e..1af44388872 100644 --- a/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +++ b/mindsdb/integrations/utilities/rag/rerankers/base_reranker.py @@ -22,7 +22,11 @@ DEFAULT_VALID_CLASS_TOKENS, RerankerMode, ) -from mindsdb.integrations.libs.base import BaseMLEngine + +from mindsdb.interfaces.knowledge_base.providers.bedrock import AsyncBedrockClient +from mindsdb.interfaces.knowledge_base.providers.gemini import GeminiClient +from mindsdb.interfaces.knowledge_base.providers.snowflake import SnowflakeClient + log = logging.getLogger(__name__) @@ -46,10 +50,10 @@ class BaseLLMReranker(BaseModel): base_url: Optional[str] = None api_version: Optional[str] = None num_docs_to_keep: Optional[int] = None # How many of the top documents to keep after reranking & compressing. - method: str = "multi-class" # Scoring method: 'multi-class' or 'binary' + method: str = "no-logprobs" # Scoring method: 'multi-class' or 'no-logprobs' mode: RerankerMode = RerankerMode.POINTWISE _api_key_var: str = "OPENAI_API_KEY" - client: Optional[AsyncOpenAI | BaseMLEngine] = None + client: Optional[AsyncOpenAI | AsyncBedrockClient | GeminiClient | SnowflakeClient] = None _semaphore: Optional[asyncio.Semaphore] = None max_concurrent_requests: int = 20 max_retries: int = 4 @@ -74,6 +78,9 @@ def __init__(self, **kwargs): def _init_client(self): if self.client is None: + if self.provider == "google": + self.provider = "gemini" + if self.provider == "azure_openai": azure_api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY") azure_api_endpoint = self.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT") @@ -85,11 +92,21 @@ def _init_client(self): timeout=self.request_timeout, max_retries=2, ) + self.method = "multi-class" + elif self.provider == "bedrock": + kwargs = self.model_extra.copy() + self.client = AsyncBedrockClient(**kwargs) + elif self.provider == "gemini": + self.client = GeminiClient(api_key=self.api_key) + elif self.provider == "snowflake": + kwargs = self.model_extra.copy() + self.client = SnowflakeClient(api_key=self.api_key, **kwargs) elif self.provider in ("openai", "ollama"): if self.provider == "ollama": - self.method = "no-logprobs" if self.api_key is None: self.api_key = "n/a" + else: + self.method = "multi-class" api_key_var: str = "OPENAI_API_KEY" openai_api_key = self.api_key or os.getenv(api_key_var) @@ -101,31 +118,17 @@ def _init_client(self): api_key=openai_api_key, base_url=base_url, timeout=self.request_timeout, max_retries=2 ) else: - # try to use litellm - from mindsdb.api.executor.controllers.session_controller import SessionController - - session = SessionController() - module = session.integration_controller.get_handler_module("litellm") - - if module is None or module.Handler is None: - raise ValueError(f'Unable to use "{self.provider}" provider. Litellm handler is not installed') - - self.client = module.Handler - self.method = "no-logprobs" + raise NotImplementedError(f'Provider "{self.provider}" is not supported') - async def _call_llm(self, messages): + async def _call_llm(self, messages) -> str: if self.provider in ("azure_openai", "openai", "ollama"): - return await self.client.chat.completions.create( + response = await self.client.chat.completions.create( model=self.model, messages=messages, ) + return response.choices[0].message.content else: - kwargs = self.model_extra.copy() - - if self.api_key is not None: - kwargs["api_key"] = self.api_key - - return await self.client.acompletion(self.provider, model=self.model, messages=messages, args=kwargs) + return await self.client.acompletion(model_name=self.model, messages=messages) async def _rank(self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None) -> List[Tuple[str, float]]: ranked_results = [] @@ -236,12 +239,10 @@ async def search_relevancy_no_logprob(self, query: str, document: str) -> Any: f"Search query: {query}" ) - response = await self._call_llm( + answer = await self._call_llm( messages=[{"role": "system", "content": prompt}, {"role": "user", "content": document}], ) - answer = response.choices[0].message.content - try: value = re.findall(r"[\d]+", answer)[0] score = float(value) / 100 @@ -483,8 +484,8 @@ async def _rank_single_batch( for attempt in range(self.max_retries): try: - response = await self._call_llm(messages) - content = response.choices[0].message.content + content = await self._call_llm(messages) + scores = self._extract_scores(content, len(documents)) return list(zip(documents, scores)) except Exception as exc: diff --git a/mindsdb/interfaces/knowledge_base/controller.py b/mindsdb/interfaces/knowledge_base/controller.py index 0ba95236119..dd7d7397ca8 100644 --- a/mindsdb/interfaces/knowledge_base/controller.py +++ b/mindsdb/interfaces/knowledge_base/controller.py @@ -1319,9 +1319,9 @@ def _check_embedding_model(self, project_name, params: dict = None, kb_name="") f"Wrong embedding provider: {params['provider']}. Available providers: {', '.join(avail_providers)}" ) - llm_client = LLMClient(params, session=self.session) - try: + llm_client = LLMClient(params, session=self.session) + resp = llm_client.embeddings(["test"]) return {"dimension": len(resp[0])} except Exception as e: diff --git a/mindsdb/interfaces/knowledge_base/llm_client.py b/mindsdb/interfaces/knowledge_base/llm_client.py index 365e84c10e5..61591b62c17 100644 --- a/mindsdb/interfaces/knowledge_base/llm_client.py +++ b/mindsdb/interfaces/knowledge_base/llm_client.py @@ -7,6 +7,10 @@ from mindsdb.integrations.utilities.handler_utils import get_api_key +from mindsdb.interfaces.knowledge_base.providers.bedrock import BedrockClient +from mindsdb.interfaces.knowledge_base.providers.gemini import GeminiClient +from mindsdb.interfaces.knowledge_base.providers.snowflake import SnowflakeClient + def retry_with_exponential_backoff(func): def decorator(*args, **kwargs): @@ -60,22 +64,23 @@ def wrapper(self, messages, *args, **kwargs): class LLMClient: """ Class for accession to LLM. - It chooses openai client or litellm handler depending on the config + It chooses provider client depending on the config """ def __init__(self, params: dict = None, session=None): self._session = session - self.params = params + params = params.copy() - self.provider = params.get("provider", "openai") + self.provider = params.pop("provider", "openai") + self.model_name = params.pop("model_name") + if self.provider == "google": + self.provider = "gemini" if "api_key" not in params: api_key = get_api_key(self.provider, params, strict=False) if api_key is not None: params["api_key"] = api_key - self.engine = "openai" - if self.provider == "azure_openai": azure_api_key = params.get("api_key") or os.getenv("AZURE_OPENAI_API_KEY") azure_api_endpoint = params.get("base_url") or os.environ.get("AZURE_OPENAI_ENDPOINT") @@ -91,62 +96,46 @@ def __init__(self, params: dict = None, session=None): kwargs["base_url"] = base_url self.client = OpenAI(**kwargs) elif self.provider == "ollama": - kwargs = params.copy() - kwargs.pop("model_name") - kwargs.pop("provider", None) - if kwargs.get("api_key") is None: - kwargs["api_key"] = "n/a" - self.client = OpenAI(**kwargs) + if params.get("api_key") is None: + params["api_key"] = "n/a" + self.client = OpenAI(**params) + elif self.provider == "bedrock": + if "aws_region" in params: + params["aws_region_name"] = params.pop("aws_region") + self.client = BedrockClient(**params) + elif self.provider == "gemini": + self.client = GeminiClient(**params) + elif self.provider == "snowflake": + self.client = SnowflakeClient(**params) else: - # try to use litellm - if self._session is None: - from mindsdb.api.executor.controllers.session_controller import SessionController - - self._session = SessionController() - module = self._session.integration_controller.get_handler_module("litellm") - - if module is None or module.Handler is None: - raise ValueError(f'Unable to use "{self.provider}" provider. Litellm handler is not installed') - - self.client = module.Handler - self.engine = "litellm" + raise NotImplementedError(f'Provider "{self.provider}" is not supported') @run_in_batches(1000) @retry_with_exponential_backoff def embeddings(self, messages: List[str]): - params = self.params - if self.engine == "openai": + if self.provider in ("openai", "azure_openai", "ollama"): response = self.client.embeddings.create( - model=params["model_name"], + model=self.model_name, input=messages, ) return [item.embedding for item in response.data] else: - kwargs = params.copy() - model = kwargs.pop("model_name") - kwargs.pop("provider", None) - - return self.client.embeddings(self.provider, model=model, messages=messages, args=kwargs) + return self.client.embeddings(self.model_name, messages) @run_in_batches(100) def completion(self, messages: List[dict], json_output: bool = False) -> List[str]: """ Call LLM completion and get response """ - params = self.params - if self.engine == "openai": + + if self.provider in ("openai", "azure_openai", "ollama"): response = self.client.chat.completions.create( - model=params["model_name"], + model=self.model_name, messages=messages, ) return [item.message.content for item in response.choices] else: - kwargs = params.copy() - params["json_output"] = json_output - model = kwargs.pop("model_name") - kwargs.pop("provider", None) - response = self.client.completion(self.provider, model=model, messages=messages, args=kwargs) - return [item.message.content for item in response.choices] + return [self.client.completion(self.model_name, messages)] async def abatch(self, messages_list: List[List[dict]], json_output: bool = False) -> List[List[str]]: """ diff --git a/mindsdb/interfaces/knowledge_base/providers/__init__.py b/mindsdb/interfaces/knowledge_base/providers/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/mindsdb/interfaces/knowledge_base/providers/__init__.py @@ -0,0 +1 @@ + diff --git a/mindsdb/interfaces/knowledge_base/providers/bedrock.py b/mindsdb/interfaces/knowledge_base/providers/bedrock.py new file mode 100644 index 00000000000..2f50c5c4c87 --- /dev/null +++ b/mindsdb/interfaces/knowledge_base/providers/bedrock.py @@ -0,0 +1,147 @@ +import json +from typing import Dict, List, Optional + + +def prepare_conversation(messages: List[dict]) -> List[dict]: + """Convert chat messages to Bedrock `converse` message payload format.""" + conversation = [] + for message in messages: + content = message["content"] + role = message["role"] + if role == "system": + role = "assistant" + if role != "user": + if len(conversation) == 0: + # the first message has to be user message + content = message["role"] + ":\n" + content + role = "user" + + conversation.append( + { + "role": role, + "content": [{"text": content}], + } + ) + return conversation + + +class AsyncBedrockClient: + """Async Bedrock runtime client wrapper""" + + def __init__( + self, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_session_token: Optional[str] = None, + ): + try: + from aiobotocore.session import get_session + except ImportError as exc: + raise ImportError( + "aiobotocore is required for the Bedrock reranker client. Install it with `pip install aiobotocore`." + ) from exc + + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.region_name = aws_region_name + + self._session = get_session() + + async def acompletion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ) -> str: + """Generate a chat completion asynchronously via Bedrock.""" + inference_config = {} + if temperature is not None: + inference_config["temperature"] = temperature + if max_tokens is not None: + inference_config["max_tokens"] = max_tokens + if top_p is not None: + inference_config["top_p"] = top_p + + conversation = prepare_conversation(messages) + + # Create client with credentials + client_kwargs = { + "service_name": "bedrock-runtime", + "region_name": self.region_name, + "aws_access_key_id": self.aws_access_key_id, + "aws_secret_access_key": self.aws_secret_access_key, + "aws_session_token": self.aws_session_token, + } + + async with self._session.create_client(**client_kwargs) as client: + response = await client.converse( + modelId=model_name, messages=conversation, inferenceConfig=inference_config + ) + + return response["output"]["message"]["content"][0]["text"] + + +class BedrockClient: + """Synchronous Bedrock runtime client wrapper""" + + def __init__( + self, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_session_token: Optional[str] = None, + ): + try: + import boto3 + except ImportError as exc: + raise ImportError("boto3 is required for the Bedrock client. Install it with `pip install boto3`.") from exc + + self.client = boto3.client( + "bedrock-runtime", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=aws_region_name, + ) + + def embeddings(self, model_name: str, messages: List[str]) -> List[List[float]]: + """Request embedding vectors for each text in `messages`.""" + embeddings = [] + for message in messages: + native_request = {"inputText": message} + request = json.dumps(native_request) + + response = self.client.invoke_model(modelId=model_name, body=request) + model_response = json.loads(response["body"].read()) + + # Extract and print the generated embedding and the input text token count. + embeddings.append(model_response["embedding"]) + + return embeddings + + def completion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ) -> str: + """Generate a chat completion synchronously via Bedrock.""" + inference_config: Dict[str, float | int] = {} + if temperature is not None: + inference_config["temperature"] = temperature + if max_tokens is not None: + inference_config["max_tokens"] = max_tokens + if top_p is not None: + inference_config["top_p"] = top_p + + conversation = prepare_conversation(messages) + + response = self.client.converse(modelId=model_name, messages=conversation, inferenceConfig=inference_config) + + return response["output"]["message"]["content"][0]["text"] diff --git a/mindsdb/interfaces/knowledge_base/providers/gemini.py b/mindsdb/interfaces/knowledge_base/providers/gemini.py new file mode 100644 index 00000000000..33e2bc314d0 --- /dev/null +++ b/mindsdb/interfaces/knowledge_base/providers/gemini.py @@ -0,0 +1,79 @@ +from typing import Any, List, Optional + + +class GeminiClient: + """Wrapper around google-genai SDK""" + + def __init__(self, api_key: str): + try: + from google import genai + from google.genai import types + except ImportError as exc: + raise ImportError("google.genai is required. Install it with `pip install google-genai`.") from exc + + self.client = genai.Client(api_key=api_key) + self.types = types + + def embeddings(self, model_name: str, messages: List[str]) -> List[List[float]]: + """Generate embedding vectors for each text in `messages`.""" + result = self.client.models.embed_content(model=model_name, contents=messages) + + return [item.values for item in result.embeddings] + + def _prepare_messages(self, messages: List[dict]) -> List[Any]: + """Convert chat messages into google-genai content payloads.""" + contents = [] + for message in messages: + role = message["role"] + # system role is not supported + if role != "user": + role = "model" + + contents.append(self.types.Content(role=role, parts=[self.types.Part(text=message["content"])])) + return contents + + def completion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ) -> str: + """Produce a chat response""" + config = {} + if temperature: + config["temperature"] = temperature + if max_tokens: + config["max_output_tokens"] = max_tokens + if top_p: + config["top_p"] = top_p + + contents = self._prepare_messages(messages) + + result = self.client.models.generate_content(model=model_name, contents=contents, config=config) + + return result.text + + async def acompletion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ) -> str: + """Async variant of `completion` using the SDK aio client.""" + config = {} + if temperature: + config["temperature"] = temperature + if max_tokens: + config["max_output_tokens"] = max_tokens + if top_p: + config["top_p"] = top_p + + contents = self._prepare_messages(messages) + + result = await self.client.aio.models.generate_content(model=model_name, contents=contents, config=config) + + return result.text diff --git a/mindsdb/interfaces/knowledge_base/providers/snowflake.py b/mindsdb/interfaces/knowledge_base/providers/snowflake.py new file mode 100644 index 00000000000..7bcc3fc6382 --- /dev/null +++ b/mindsdb/interfaces/knowledge_base/providers/snowflake.py @@ -0,0 +1,118 @@ +from typing import Dict, List, Optional, Union + +import requests +import httpx + + +def _raise_for_status(response: Union[requests.Response, httpx.Response]) -> None: + """Raise an informative HTTPError when Snowflake responds with an error.""" + if 400 <= response.status_code < 600: + if hasattr(response, "reason"): + reason = response.reason + elif hasattr(response, "reason_phrase"): + reason = response.reason_phrase + else: + reason = "Error" + raise requests.HTTPError(f"{reason}: {response.text}", response=response) + + +class SnowflakeClient: + """Wrapper over Snowflake Cortex REST endpoints.""" + + def __init__(self, account_id: Optional[str] = None, api_key: Optional[str] = None): + if account_id is None: + raise ValueError("account_id must be provided") + if api_key is None: + raise ValueError("api_key must be provided") + + self.account_id = account_id.lower() + self.api_key = api_key + + self.auth_type = "KEYPAIR_JWT" + if self.api_key.startswith("pat/"): + self.api_key = self.api_key[4:] + self.auth_type = "PROGRAMMATIC_ACCESS_TOKEN" + + def _get_base_url(self) -> str: + return f"https://{self.account_id}.snowflakecomputing.com/api/v2" + + def _get_headers(self) -> Dict[str, str]: + return { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": "Bearer " + self.api_key, + "X-Snowflake-Authorization-Token-Type": self.auth_type, + } + + def embeddings(self, model_name: str, messages: List[str]) -> List[List[float]]: + """Request embedding vectors for the provided `messages`.""" + url = f"{self._get_base_url()}/cortex/inference:embed" + + payload = {"text": messages, "model": model_name} + + response = requests.post(url, json=payload, headers=self._get_headers()) + _raise_for_status(response) + + embeddings = [] + for item in response.json()["data"]: + embeddings.append(item["embedding"][0]) + return embeddings + + def completion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ) -> str: + """Generate a chat completion with the Cortex complete endpoint.""" + url = f"{self._get_base_url()}/cortex/inference:complete" + + payload = { + "model": model_name, + "stream": False, + "messages": messages, + } + + if temperature: + payload["temperature"] = temperature + if max_tokens: + payload["max_tokens"] = max_tokens + if top_p: + payload["top_p"] = top_p + + response = requests.post(url, json=payload, headers=self._get_headers()) + _raise_for_status(response) + data = response.json() + return data["choices"][0]["message"]["content"] + + async def acompletion( + self, + model_name: str, + messages: List[dict], + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + ) -> str: + """Async variant of `completion` using httpx.""" + url = f"{self._get_base_url()}/cortex/inference:complete" + + payload = { + "model": model_name, + "stream": False, + "messages": messages, + } + + if temperature: + payload["temperature"] = temperature + if max_tokens: + payload["max_tokens"] = max_tokens + if top_p: + payload["top_p"] = top_p + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=payload, headers=self._get_headers()) + _raise_for_status(response) + data = response.json() + return data["choices"][0]["message"]["content"] diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 759f05a0bcd..ab5217a0416 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -51,3 +51,7 @@ pydantic-ai==1.77.0 # Required for Pydantic AI agents bs4 # for rag HTMLDocumentLoader urllib3>=2.6.3 # not directly required, pinned by Snyk to avoid a vulnerability + +# kb providers +aiobotocore==3.4.0 +google-genai==1.70.0 diff --git a/tests/scripts/check_requirements.py b/tests/scripts/check_requirements.py index f3bcc6382dc..f3ec6de303e 100644 --- a/tests/scripts/check_requirements.py +++ b/tests/scripts/check_requirements.py @@ -202,6 +202,7 @@ def get_requirements_with_DEP002(path): "google-analytics-admin": ["google"], "google-auth": ["google"], "google-cloud-storage": ["google"], + "google-genai": ["google"], "google-auth-oauthlib": ["google_auth_oauthlib"], "google-api-python-client": ["googleapiclient"], "ibm-cos-sdk": ["ibm_boto3", "ibm_botocore"], @@ -259,6 +260,7 @@ def get_requirements_with_DEP002(path): "python-dotenv": ["dotenv"], "pyjwt": ["jwt"], "sklearn": ["scikit-learn"], + "types-aioboto3": ["aioboto3"], "ag2": ["autogen"], } diff --git a/tests/unit/executor/test_knowledge_base.py b/tests/unit/executor/test_knowledge_base.py index 991166e45ab..7646caca9ae 100644 --- a/tests/unit/executor/test_knowledge_base.py +++ b/tests/unit/executor/test_knowledge_base.py @@ -325,13 +325,8 @@ class _Choice: def __init__(self, content): self.message = _Msg(content) - class _Resp: - def __init__(self, content): - self.choices = [_Choice(content)] - async def _fake_call_llm(messages): - content = '{"ranking": [{"doc_index": 2, "score": 0.9}, {"doc_index": 1, "score": 0.6}, {"doc_index": 3, "score": 0.1}]}' - return _Resp(content) + return '{"ranking": [{"doc_index": 2, "score": 0.9}, {"doc_index": 1, "score": 0.6}, {"doc_index": 3, "score": 0.1}]}' # Bind the async method to this reranker instance reranker._call_llm = _fake_call_llm # type: ignore @@ -356,16 +351,11 @@ class _Choice: def __init__(self, content): self.message = _Msg(content) - class _Resp: - def __init__(self, content): - self.choices = [_Choice(content)] - async def _fake_call_llm(messages): # Returns code-fenced JSON, includes only two entries, one without score - content = """```json + return """```json {"ranking": [1, {"doc_index": 3, "score": 0.8}]} ```""" - return _Resp(content) reranker._call_llm = _fake_call_llm # type: ignore @@ -389,14 +379,9 @@ class _Choice: def __init__(self, content): self.message = _Msg(content) - class _Resp: - def __init__(self, content): - self.choices = [_Choice(content)] - async def _fake_call_llm(messages): # Invalid JSON forces fallback - content = "not-json" - return _Resp(content) + return "not-json" reranker._call_llm = _fake_call_llm # type: ignore @@ -1352,6 +1337,47 @@ def test_create_index(self, mock_embedding): ret = self.run_sql("select * from kb_ral where content='white'") assert "white" in ret["chunk_content"].iloc[0] + def test_providers(self): + with patch("mindsdb.interfaces.knowledge_base.llm_client.BedrockClient.embeddings") as embed: + with patch( + "mindsdb.integrations.utilities.rag.rerankers.base_reranker.AsyncBedrockClient.acompletion" + ) as rerank: + embed.return_value = [[1, 1, 1]] + rerank.return_value = "100" + self._create_kb( + "kb_test", + embedding_model={ + "provider": "bedrock", + "model_name": "amazon.titan", + "aws_access_key_id": "-", + "aws_region_name": "us-east-2", + "aws_secret_access_key": "-", + }, + reranking_model={ + "provider": "bedrock", + "model_name": "llama3", + "aws_access_key_id": "-", + "aws_region_name": "us-east-2", + "aws_secret_access_key": "-", + }, + ) + assert embed.call_args_list[0][0][0] == "amazon.titan" + assert rerank.call_args_list[0][1]["model_name"] == "llama3" + + with patch("mindsdb.interfaces.knowledge_base.llm_client.SnowflakeClient.embeddings") as embed: + embed.return_value = [[1, 1, 1]] + self._create_kb( + "kb_test", + embedding_model={"provider": "snowflake", "model_name": "arctic", "account_id": "ABC", "api_key": "-"}, + ) + assert embed.call_args_list[0][0][0] == "arctic" + with patch("mindsdb.interfaces.knowledge_base.llm_client.GeminiClient.embeddings") as embed: + embed.return_value = [[1, 1, 1]] + self._create_kb( + "kb_test", embedding_model={"provider": "gemini", "model_name": "gemini-embedding", "api_key": "-"} + ) + assert embed.call_args_list[0][0][0] == "gemini-embedding" + class TestKBAutoBatch(BaseTestKB): @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient")