Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
# Copyright (c) Microsoft. All rights reserved.

from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from ollama import AsyncClient

import semantic_kernel.connectors.ai.ollama.services.ollama_chat_completion as occ_module
from semantic_kernel.connectors.ai.completion_usage import CompletionUsage
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.ollama.ollama_prompt_execution_settings import OllamaChatPromptExecutionSettings
from semantic_kernel.connectors.ai.ollama.services.ollama_chat_completion import OllamaChatCompletion
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.contents.streaming_text_content import StreamingTextContent
from semantic_kernel.exceptions.service_exceptions import (
ServiceInitializationError,
ServiceInvalidExecutionSettingsError,
Expand Down Expand Up @@ -260,3 +267,187 @@ async def test_prepare_chat_history_for_request(setup_ollama_chat_completion):

prepared_history = ollama_chat_completion._prepare_chat_history_for_request(chat_history)
assert prepared_history == []


async def test_service_url_with_httpx_client(model_id: str) -> None:
"""
Test that service_url returns the base_url of the underlying httpx.AsyncClient.
"""
# Initialize an AsyncClient and manually set its _client attribute to an httpx.AsyncClient
client = AsyncClient(host="unused")
base = httpx.AsyncClient(base_url="http://example.com:8000")
client._client = base # simulate underlying httpx client

ollama = OllamaChatCompletion(ai_model_id=model_id, client=client)
# service_url should reflect the base_url of the httpx client
assert ollama.service_url() == "http://example.com:8000"


@patch("ollama.AsyncClient.chat", new_callable=AsyncMock)
async def test_chat_response_branch(
mock_chat: AsyncMock,
model_id: str,
service_id: str,
default_options: dict,
chat_history,
monkeypatch,
) -> None:
"""
Test get_chat_message_contents when AsyncClient.chat returns a ChatResponse instance.
"""

class DummyFunction:
def __init__(self, name, arguments):
self.name = name
self.arguments = arguments

class DummyToolCall:
def __init__(self, function):
self.function = function

class DummyMessage:
def __init__(self, content: str, tool_calls=None) -> None:
self.content = content
self.tool_calls = tool_calls or []

class DummyChatResponse:
def __init__(
self,
content: str,
model: str,
prompt_eval_count: int,
eval_count: int,
tool_calls=None,
) -> None:
function_calls = [
DummyToolCall(DummyFunction(tc["function"]["name"], tc["function"]["arguments"])) for tc in tool_calls
]
self.message = DummyMessage(content, function_calls)
self.model = model
self.prompt_eval_count = prompt_eval_count
self.eval_count = eval_count

# Monkeypatch the ChatResponse type in the module so isinstance works
monkeypatch.setattr(occ_module, "ChatResponse", DummyChatResponse)

# Prepare a dummy ChatResponse return value
dummy_resp = DummyChatResponse(
content="resp_text",
model="mdl",
prompt_eval_count=2,
eval_count=3,
tool_calls=[{"function": {"name": "fn", "arguments": {"x": 1}}}],
)
mock_chat.return_value = dummy_resp

ollama = OllamaChatCompletion(ai_model_id=model_id)
settings = OllamaChatPromptExecutionSettings(service_id=service_id, options=default_options)

results = await ollama.get_chat_message_contents(chat_history, settings)
# Only one response expected
assert len(results) == 1
msg = results[0]
# Assert it's a ChatMessageContent
assert isinstance(msg, ChatMessageContent)
# The content property should return the response text
assert msg.content == "resp_text"

# The second item should be a FunctionCallContent
func_item = msg.items[1]
assert isinstance(func_item, FunctionCallContent)
# Validate function call details
assert func_item.name == "fn"
assert func_item.arguments == {"x": 1}

# Check metadata
assert "model" in msg.metadata and msg.metadata["model"] == "mdl"
# Access usage directly, key should exist
usage = msg.metadata["usage"]
assert isinstance(usage, CompletionUsage)
assert usage.prompt_tokens == 2 and usage.completion_tokens == 3


@patch("ollama.AsyncClient.chat", new_callable=AsyncMock)
async def test_streaming_chat_response_branch(
mock_chat: AsyncMock,
model_id: str,
service_id: str,
default_options: dict,
chat_history,
monkeypatch,
) -> None:
"""
Test get_streaming_chat_message_contents when AsyncClient.chat yields ChatResponse instances.
"""

class DummyFunction:
def __init__(self, name, arguments):
self.name = name
self.arguments = arguments

class DummyToolCall:
def __init__(self, function):
self.function = function

class DummyMessage:
def __init__(self, content: str, tool_calls=None) -> None:
self.content = content
self.tool_calls = tool_calls or []

class DummyChatResponse:
def __init__(
self,
content: str,
model: str,
prompt_eval_count: int,
eval_count: int,
tool_calls=None,
) -> None:
function_calls = [
DummyToolCall(DummyFunction(tc["function"]["name"], tc["function"]["arguments"])) for tc in tool_calls
]
self.message = DummyMessage(content, function_calls)
self.model = model
self.prompt_eval_count = prompt_eval_count
self.eval_count = eval_count

# Monkeypatch ChatResponse type
monkeypatch.setattr(occ_module, "ChatResponse", DummyChatResponse)

# Prepare an async generator yielding DummyChatResponse
async def fake_stream() -> AsyncGenerator[DummyChatResponse, None]:
yield DummyChatResponse(
content="stream_text",
model="m2",
prompt_eval_count=1,
eval_count=1,
tool_calls=[{"function": {"name": "f2", "arguments": {}}}],
)

mock_chat.return_value = fake_stream()

ollama = OllamaChatCompletion(ai_model_id=model_id)
settings = OllamaChatPromptExecutionSettings(service_id=service_id, options=default_options)

collected = []
# Iterate over streamed batches
async for batch in ollama.get_streaming_chat_message_contents(chat_history, settings):
# We expect a list with a single StreamingChatMessageContent
assert len(batch) == 1
sc = batch[0]
assert isinstance(sc, StreamingChatMessageContent)

# First item should be text content
text_item = sc.items[0]
assert isinstance(text_item, StreamingTextContent)
assert text_item.text == "stream_text"

# Next item should be a FunctionCallContent
func_item = sc.items[1]
assert isinstance(func_item, FunctionCallContent)
assert func_item.name == "f2"

collected.append(sc)

# Only one batch should be collected
assert len(collected) == 1
182 changes: 182 additions & 0 deletions python/tests/unit/connectors/memory/chroma/test_chroma_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
from typing import Any, cast

import numpy as np
import pytest
from chromadb.api.types import QueryResult

from semantic_kernel.connectors.memory.chroma.utils import (
chroma_compute_similarity_scores,
query_results_to_records,
)

# Possible bug found: the function camel_to_snake is not doing correctly the conversion.
# def test_camel_to_snake_basic_cases() -> None:
Comment thread
gaudyb marked this conversation as resolved.
# """
# Test converting various camelCase strings to snake_case.

# This covers standard conversions, acronyms, numbers, and edge cases.
# """
# # Standard conversions
# assert camel_to_snake("camelCaseTest") == "camel_case_test"
# assert camel_to_snake("CamelCase") == "camel_case"
# # Already snake or lowercase remains unchanged
# assert camel_to_snake("already_snake") == "already_snake"
# assert camel_to_snake("lowercase") == "lowercase"
# # Acronyms and consecutive uppercase letters
# assert camel_to_snake("XMLHttpRequest") == "xml_http_request"
# assert camel_to_snake("HTTPServerError") == "http_server_error"
# # Numbers in string
# assert camel_to_snake("Test123Case") == "test123_case"
# # Empty string edge case
# assert camel_to_snake("") == ""


def test_query_results_to_records_shallow_and_embedding_flags() -> None:
"""
Test query_results_to_records with a single (flat) record and both embedding flags.

Validate transformation of shallow dict input and check both with_embedding True/False.
"""
# Prepare flat-style results (shallow lists)
meta: Any = {
"is_reference": "True",
"external_source_name": "sourceA",
"id": "meta123",
"description": "a record",
"additional_metadata": "extra",
"timestamp": "2022-05-01T12:00:00",
}
results_flat: Any = {
"ids": ["key1"],
"documents": ["document text"],
"embeddings": [[0.1, 0.2, 0.3]],
"metadatas": [meta],
}

# With embedding included
records_with_emb = query_results_to_records(cast(QueryResult, results_flat.copy()), with_embedding=True)
assert len(records_with_emb) == 1
rec1 = records_with_emb[0]
# Public properties
assert rec1.id == "meta123"
assert rec1.text == "document text"
np.testing.assert_array_equal(rec1.embedding, np.array([0.1, 0.2, 0.3]))
assert rec1.additional_metadata == "extra"
assert rec1.description == "a record"
assert rec1.timestamp == "2022-05-01T12:00:00"
# Private/internal attributes
assert rec1._is_reference is True
assert rec1._external_source_name == "sourceA"
assert rec1._key == "key1"

# Without embedding (embedding should be None)
records_no_emb = query_results_to_records(cast(QueryResult, results_flat.copy()), with_embedding=False)
assert len(records_no_emb) == 1
rec2 = records_no_emb[0]
assert rec2.embedding is None
assert rec2._key == "key1"


def test_query_results_to_records_nested_multiple() -> None:
"""
Test query_results_to_records with nested lists and multiple records.

Ensure nested lists are handled correctly and multiple records extracted.
"""
# Prepare nested-style results
meta1: Any = {
"is_reference": "False",
"external_source_name": "src1",
"id": "m1",
"description": "desc1",
"additional_metadata": "md1",
"timestamp": "t1",
}
meta2: Any = {
"is_reference": "True",
"external_source_name": "src2",
"id": "m2",
"description": "desc2",
"additional_metadata": "md2",
"timestamp": "t2",
}
results_nested: Any = {
"ids": [["k1", "k2"]],
"documents": [["doc1", "doc2"]],
"embeddings": [[[1, 0], [0, 1]]],
"metadatas": [[meta1, meta2]],
}
records = query_results_to_records(cast(QueryResult, results_nested.copy()), with_embedding=True)
# Should produce two records
assert len(records) == 2

# First record checks
r0 = records[0]
assert r0._key == "k1"
assert r0.id == "m1"
assert r0._is_reference is False
np.testing.assert_array_equal(r0.embedding, np.array([1, 0]))

# Second record checks
r1 = records[1]
assert r1._key == "k2"
assert r1.id == "m2"
assert r1._is_reference is True
np.testing.assert_array_equal(r1.embedding, np.array([0, 1]))


def test_query_results_to_records_empty_ids_returns_empty_list() -> None:
"""
If results contain no ids, function should return an empty list without raising.
"""
empties: Any = {"ids": [], "documents": [], "embeddings": [], "metadatas": []}
result = query_results_to_records(cast(QueryResult, empties), with_embedding=True)
assert result == []


def test_chroma_compute_similarity_scores_all_valid_no_warnings(caplog) -> None:
"""
Test similarity scores when all vectors are non-zero; no warnings should be emitted.
"""
caplog.set_level(logging.WARNING)
emb_q = np.array([1.0, 0.0])
emb_arr = np.array([[1.0, 0.0], [0.0, 1.0]])
scores = chroma_compute_similarity_scores(emb_q, emb_arr)
# Expect cosine similarities [1, 0]
assert pytest.approx(scores.tolist()) == [1.0, 0.0]
# No warnings should be logged
assert "Some vectors in the embedding collection" not in caplog.text


def test_chroma_compute_similarity_scores_with_zero_vectors_warns(caplog) -> None:
"""
Test similarity scores when some embedding vectors are zero; should generate warning and leave -1 for zero vectors.
"""
caplog.set_level(logging.WARNING)
emb_q = np.array([1.0, 0.0])
emb_arr = np.array([[1.0, 0.0], [0.0, 0.0]])
scores = chroma_compute_similarity_scores(emb_q, emb_arr)
# First index valid: 1.0, second invalid: -1.0
assert pytest.approx(scores.tolist()) == [1.0, -1.0]
# Warning logged for zero vectors
assert "Some vectors in the embedding collection are zero vectors" in caplog.text


def test_chroma_compute_similarity_scores_all_invalid_raises() -> None:
"""
Test that if all vectors are invalid (zero), a ValueError is raised.
"""
emb_q = np.array([0.0, 0.0])
emb_arr = np.array([[1.0, 0.0], [0.0, 1.0]]) # non-zero collection
with pytest.raises(ValueError) as excinfo1:
chroma_compute_similarity_scores(emb_q, emb_arr)
assert "Invalid vectors" in str(excinfo1.value)

emb_q2 = np.array([1.0, 0.0])
emb_arr2 = np.array([[0.0, 0.0], [0.0, 0.0]]) # all zero collection
with pytest.raises(ValueError) as excinfo2:
chroma_compute_similarity_scores(emb_q2, emb_arr2)
assert "Invalid vectors" in str(excinfo2.value)
Loading