diff --git a/integrations/watsonx/pyproject.toml b/integrations/watsonx/pyproject.toml index a66eae071a..a0771eeb8f 100644 --- a/integrations/watsonx/pyproject.toml +++ b/integrations/watsonx/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", ] -dependencies = ["haystack-ai>=2.17.1", "ibm-watsonx-ai>=1.3.26", "pandas>=2.2.3"] +dependencies = ["haystack-ai>=2.24.1", "ibm-watsonx-ai>=1.3.26", "pandas>=2.2.3"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/watsonx#readme" diff --git a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py index ef184596f4..e31c36569b 100644 --- a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py +++ b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py @@ -2,21 +2,35 @@ # # SPDX-License-Identifier: Apache-2.0 +import json +from dataclasses import replace from datetime import datetime, timezone from typing import Any, Literal, get_args from haystack import component, default_from_dict, default_to_dict, logging +from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message from haystack.dataclasses import ( AsyncStreamingCallbackT, ChatMessage, ChatRole, + ComponentInfo, + FinishReason, ImageContent, StreamingCallbackT, StreamingChunk, SyncStreamingCallbackT, TextContent, + ToolCall, + ToolCallDelta, select_streaming_callback, ) +from haystack.tools import ( + ToolsType, + _check_duplicate_tool_names, + deserialize_tools_or_toolset_inplace, + flatten_tools_or_toolsets, + serialize_tools_or_toolset, +) from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable from ibm_watsonx_ai 
import Credentials from ibm_watsonx_ai.foundation_models import ModelInference @@ -29,6 +43,17 @@ ImageFormat = Literal["image/jpeg", "image/png"] IMAGE_SUPPORTED_FORMATS: list[ImageFormat] = list(get_args(ImageFormat)) +# See https://ibm.github.io/watsonx-ai-node-sdk/enums/1_6_x.WatsonXAI.TextChatResultChoiceStream.Constants.FinishReason.html +# for possible finish reasons +FINISH_REASON_MAPPING: dict[str, FinishReason] = { + "cancelled": "stop", + "error": "stop", + "length": "length", + "stop": "stop", + "time_limit": "stop", + "tool_calls": "tool_calls", +} + @component class WatsonxChatGenerator: @@ -100,6 +125,7 @@ def __init__( max_retries: int | None = None, verify: bool | str | None = None, streaming_callback: StreamingCallbackT | None = None, + tools: ToolsType | None = None, ) -> None: """ Creates an instance of WatsonxChatGenerator. @@ -136,6 +162,8 @@ def __init__( - False: Skip verification (insecure) - Path to CA bundle for custom certificates :param streaming_callback: A callback function for streaming responses. + :param tools: + A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. """ self.api_key = api_key self.model = model @@ -146,6 +174,7 @@ def __init__( self.max_retries = max_retries self.verify = verify self.streaming_callback = streaming_callback + self.tools = tools self._initialize_client() @@ -169,6 +198,7 @@ def to_dict(self) -> dict[str, Any]: The serialized component as a dictionary. 
""" callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = serialize_tools_or_toolset(self.tools) if self.tools else None return default_to_dict( self, model=self.model, @@ -180,6 +210,7 @@ def to_dict(self) -> dict[str, Any]: max_retries=self.max_retries, verify=self.verify, streaming_callback=callback_name, + tools=serialized_tools, ) @classmethod @@ -193,6 +224,7 @@ def from_dict(cls, data: dict[str, Any]) -> "WatsonxChatGenerator": The deserialized component instance. """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "project_id"]) + deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback = init_params.get("streaming_callback") if serialized_callback: @@ -206,6 +238,7 @@ def run( messages: list[ChatMessage], generation_kwargs: dict[str, Any] | None = None, streaming_callback: StreamingCallbackT | None = None, + tools: ToolsType | None = None, ) -> dict[str, list[ChatMessage]]: """ Generate chat completions synchronously. @@ -218,6 +251,9 @@ def run( :param streaming_callback: A callback function that is called when a new token is received from the stream. If provided this will override the `streaming_callback` set in the `__init__` method. + :param tools: + A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. + If set, it will override the `tools` parameter provided during initialization. :returns: A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. 
@@ -229,7 +265,7 @@ def run( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False ) - api_args = self._prepare_api_call(messages=messages, generation_kwargs=generation_kwargs) + api_args = self._prepare_api_call(messages=messages, generation_kwargs=generation_kwargs, tools=tools) if resolved_streaming_callback: return self._handle_streaming(api_args=api_args, callback=resolved_streaming_callback) @@ -243,6 +279,7 @@ async def run_async( messages: list[ChatMessage], generation_kwargs: dict[str, Any] | None = None, streaming_callback: StreamingCallbackT | None = None, + tools: ToolsType | None = None, ) -> dict[str, list[ChatMessage]]: """ Generate chat completions asynchronously. @@ -255,6 +292,9 @@ async def run_async( :param streaming_callback: A callback function that is called when a new token is received from the stream. If provided this will override the `streaming_callback` set in the `__init__` method. + :param tools: + A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. + If set, it will override the `tools` parameter provided during initialization. :returns: A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. 
@@ -266,7 +306,7 @@ async def run_async( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True ) - api_args = self._prepare_api_call(messages=messages, generation_kwargs=generation_kwargs) + api_args = self._prepare_api_call(messages=messages, generation_kwargs=generation_kwargs, tools=tools) if resolved_streaming_callback: return await self._handle_async_streaming(api_args=api_args, callback=resolved_streaming_callback) @@ -274,16 +314,25 @@ async def run_async( return await self._handle_async_standard(api_args) def _prepare_api_call( - self, *, messages: list[ChatMessage], generation_kwargs: dict[str, Any] | None = None + self, + *, + messages: list[ChatMessage], + generation_kwargs: dict[str, Any] | None = None, + tools: ToolsType | None = None, ) -> dict[str, Any]: merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} watsonx_messages = [] content: str | None | dict[str, Any] | list[dict[str, Any]] + flattened_tools = flatten_tools_or_toolsets(tools or self.tools) + _check_duplicate_tool_names(flattened_tools) + tool_definitions = [{"type": "function", "function": {**tool.tool_spec}} for tool in flattened_tools] + for msg in messages: - if msg.is_from("tool"): - logger.debug("Skipping tool message - tool calls are not currently supported") + # Watsonx tool call result messages are of the same format as OpenAI chat completions + if msg.tool_call_results: + watsonx_messages.append(msg.to_openai_dict_format(require_tool_call_ids=True)) continue # Check that images are only in user messages @@ -325,7 +374,57 @@ def _prepare_api_call( merged_kwargs.pop("stream", None) - return {"messages": watsonx_messages, "params": merged_kwargs} + api_args = {"messages": watsonx_messages, "params": merged_kwargs} + if tool_definitions: + api_args["tools"] = tool_definitions + + return api_args + + def _convert_chunk_to_streaming_chunk(self, chunk: dict[str, Any], component_info: ComponentInfo) -> StreamingChunk: + """ + 
Convert one Watsonx AI stream-chunk to Haystack StreamingChunk. + """ + choice = chunk["choices"][0] + chunk_meta = { + "model": self.model, + "model_id": chunk.get("model_id"), + "model_version": chunk.get("model_version"), + "created": chunk.get("created"), + "created_at": chunk.get("created_at"), + "received_at": datetime.now(timezone.utc).isoformat(), + } + + if choice["delta"] and (choice_delta_tool_calls := choice["delta"].get("tool_calls")): + # create a list of ToolCallDelta objects from the tool calls + tool_calls_deltas = [ + ToolCallDelta( + index=tool_call["index"], + id=tool_call.get("id"), + tool_name=tool_call.get("function", {}).get("name"), + arguments=tool_call.get("function", {}).get("arguments"), + ) + for tool_call in choice_delta_tool_calls + ] + return StreamingChunk( + content=choice.get("delta", {}).get("content", ""), + meta=chunk_meta, + component_info=component_info, + # We adopt the first tool_calls_deltas.index as the overall index of the chunk to match OpenAI + index=tool_calls_deltas[0].index, + tool_calls=tool_calls_deltas, + start=tool_calls_deltas[0].tool_name is not None, + finish_reason=FINISH_REASON_MAPPING.get(choice.get("finish_reason")), + ) + + index = choice.get("index", 0) + return StreamingChunk( + content=choice.get("delta", {}).get("content", ""), + meta=chunk_meta, + component_info=component_info, + index=index, + start=index == 0, + finish_reason=FINISH_REASON_MAPPING.get(choice.get("finish_reason")), + ) def _handle_streaming( self, @@ -342,29 +441,40 @@ def _handle_streaming( A dictionary with the generated responses as ChatMessage instances. 
""" chunks: list[StreamingChunk] = [] - stream = self.client.chat_stream(messages=api_args["messages"], params=api_args["params"]) + stream = self.client.chat_stream( + messages=api_args["messages"], params=api_args["params"], tools=api_args.get("tools") + ) + component_info = ComponentInfo.from_component(self) for chunk in stream: if not isinstance(chunk, dict) or not chunk.get("choices"): continue - content = chunk["choices"][0].get("delta", {}).get("content", "") - if content: - chunk_meta = { - "model": self.model, - "index": chunk["choices"][0].get("index", 0), - "finish_reason": chunk["choices"][0].get("finish_reason"), - "received_at": datetime.now(timezone.utc).isoformat(), - } - streaming_chunk = StreamingChunk(content=content, meta=chunk_meta) - chunks.append(streaming_chunk) - callback(streaming_chunk) + streaming_chunk = self._convert_chunk_to_streaming_chunk(chunk, component_info) + chunks.append(streaming_chunk) + callback(streaming_chunk) - return {"replies": [self._convert_streaming_chunks_to_chat_message(chunks)]} + chat_message = _convert_streaming_chunks_to_chat_message(chunks) + message_tool_calls = [ + replace(tool_call, arguments=self._parse_tool_call_json(tool_call.arguments)) + for tool_call in chat_message.tool_calls + ] + return { + "replies": [ + ChatMessage.from_assistant( + text=chat_message.text, + meta=chat_message.meta, + tool_calls=message_tool_calls, + reasoning=chat_message.reasoning, + ) + ] + } def _handle_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Handle synchronous standard response.""" - response = self.client.chat(messages=api_args["messages"], params=api_args["params"]) + response = self.client.chat( + messages=api_args["messages"], params=api_args["params"], tools=api_args.get("tools") + ) return self._process_response(response) async def _handle_async_streaming( @@ -375,60 +485,77 @@ async def _handle_async_streaming( ) -> dict[str, list[ChatMessage]]: """Handle asynchronous streaming 
response.""" chunks: list[StreamingChunk] = [] - stream_generator = await self.client.achat_stream(messages=api_args["messages"], params=api_args["params"]) + stream_generator = await self.client.achat_stream( + messages=api_args["messages"], params=api_args["params"], tools=api_args.get("tools") + ) + component_info = ComponentInfo.from_component(self) async for chunk in stream_generator: if not isinstance(chunk, dict) or not chunk.get("choices"): continue - content = chunk["choices"][0].get("delta", {}).get("content", "") - if content: - chunk_meta = { - "model": self.model, - "index": chunk["choices"][0].get("index", 0), - "finish_reason": chunk["choices"][0].get("finish_reason"), - "received_at": datetime.now(timezone.utc).isoformat(), - } - streaming_chunk = StreamingChunk(content=content, meta=chunk_meta) - chunks.append(streaming_chunk) - await callback(streaming_chunk) - - return {"replies": [self._convert_streaming_chunks_to_chat_message(chunks)]} - - def _convert_streaming_chunks_to_chat_message(self, chunks: list[StreamingChunk]) -> ChatMessage: - """Convert list of streaming chunks to a single ChatMessage.""" - if not chunks: - return ChatMessage.from_assistant("") - - content = "".join(chunk.content for chunk in chunks) - last_chunk_meta = chunks[-1].meta if chunks else {} - - return ChatMessage.from_assistant( - text=content, - meta={ - "model": self.model, - "finish_reason": last_chunk_meta.get("finish_reason"), - "usage": last_chunk_meta.get("usage", {}), - "chunks_count": len(chunks), - }, - ) + streaming_chunk = self._convert_chunk_to_streaming_chunk(chunk, component_info) + chunks.append(streaming_chunk) + await callback(streaming_chunk) + + chat_message = _convert_streaming_chunks_to_chat_message(chunks) + message_tool_calls = [ + replace(tool_call, arguments=self._parse_tool_call_json(tool_call.arguments)) + for tool_call in chat_message.tool_calls + ] + return { + "replies": [ + ChatMessage.from_assistant( + text=chat_message.text, + 
meta=chat_message.meta, + tool_calls=message_tool_calls, + reasoning=chat_message.reasoning, + ) + ] + } async def _handle_async_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Handle asynchronous standard response.""" - response = await self.client.achat(messages=api_args["messages"], params=api_args["params"]) + response = await self.client.achat( + messages=api_args["messages"], params=api_args["params"], tools=api_args.get("tools") + ) return self._process_response(response) + @staticmethod + def _parse_tool_call_json(tool_call: str | dict) -> dict[str, Any]: + """Parse tool call json from Watsonx tool calls.""" + if isinstance(tool_call, dict): + return tool_call + obj = json.loads(tool_call) + if isinstance(obj, str): + obj = json.loads(obj) + return obj + def _process_response(self, response: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Process standard response into Haystack format.""" if not response.get("choices"): return {"replies": []} - choice = response["choices"][0] - message = choice.get("message", {}) - return { - "replies": [ + choices = response["choices"] + chat_messages = [] + for choice in choices: + message = choice.get("message", {}) + + message_tool_calls: list[ToolCall] | None = None + if tool_calls := message.get("tool_calls", []): + message_tool_calls = [ + ToolCall( + id=tool_call["id"], + tool_name=tool_call["function"]["name"], + arguments=self._parse_tool_call_json(tool_call["function"]["arguments"]), + ) + for tool_call in tool_calls + ] + + chat_messages.append( ChatMessage.from_assistant( text=message.get("content", ""), + tool_calls=message_tool_calls, meta={ "model": self.model, "index": choice.get("index", 0), @@ -436,5 +563,6 @@ def _process_response(self, response: dict[str, Any]) -> dict[str, list[ChatMess "usage": response.get("usage", {}), }, ) - ] - } + ) + + return {"replies": chat_messages} diff --git a/integrations/watsonx/tests/test_chat_generator.py 
b/integrations/watsonx/tests/test_chat_generator.py index 41c36fbfd9..b85cf8211a 100644 --- a/integrations/watsonx/tests/test_chat_generator.py +++ b/integrations/watsonx/tests/test_chat_generator.py @@ -2,12 +2,14 @@ # # SPDX-License-Identifier: Apache-2.0 import os +from collections.abc import Generator from unittest.mock import AsyncMock, MagicMock, patch import pytest from haystack import logging from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage, ImageContent, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, ComponentInfo, ImageContent, StreamingChunk +from haystack.tools import Tool, Toolset from haystack.utils import Secret from haystack_integrations.components.generators.watsonx.chat.chat_generator import WatsonxChatGenerator @@ -15,9 +17,30 @@ logger = logging.getLogger(__name__) +def weather(city: str): + """Get weather information for a city.""" + return f"Weather in {city}: 22°C, sunny" + + +def population(city: str) -> str: + return f"The population of {city} is 2.2 million" + + +@pytest.fixture +def tools(): + return [ + Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + function=weather, + ) + ] + + class TestWatsonxChatGenerator: @pytest.fixture - def mock_watsonx(self, monkeypatch): + def mock_watsonx(self, monkeypatch) -> Generator[dict[str, AsyncMock | MagicMock], None]: """Fixture for setting up common mocks""" monkeypatch.setenv("WATSONX_API_KEY", "fake-api-key") monkeypatch.setenv("WATSONX_PROJECT_ID", "fake-project-id") @@ -41,7 +64,7 @@ def mock_watsonx(self, monkeypatch): { "message": {"content": "This is a generated response", "role": "assistant"}, "index": 0, - "finish_reason": "completed", + "finish_reason": "stop", } ], "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, @@ -53,7 
+76,7 @@ def mock_watsonx(self, monkeypatch): { "message": {"content": "Async generated response", "role": "assistant"}, "index": 0, - "finish_reason": "completed", + "finish_reason": "stop", } ], "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, @@ -62,11 +85,11 @@ def mock_watsonx(self, monkeypatch): mock_model_instance.chat_stream = MagicMock( return_value=[ {"choices": [{"delta": {"content": "Streaming"}, "index": 0, "finish_reason": None}]}, - {"choices": [{"delta": {"content": " response"}, "index": 0, "finish_reason": "completed"}]}, + {"choices": [{"delta": {"content": " response"}, "index": 0, "finish_reason": "stop"}]}, ] ) - async def mock_achat_stream(messages=None, params=None): + async def mock_achat_stream(messages=None, params=None, tools=None): class MockAsyncGenerator: def __init__(self): self._count = 0 @@ -84,9 +107,7 @@ async def __anext__(self): } elif self._count == 2: return { - "choices": [ - {"delta": {"content": " response"}, "finish_reason": "completed", "index": 0} - ] + "choices": [{"delta": {"content": " response"}, "finish_reason": "stop", "index": 0}] } else: raise StopAsyncIteration @@ -110,14 +131,18 @@ def test_init_default(self, mock_watsonx): assert isinstance(generator.project_id, Secret) assert generator.project_id.resolve_value() == "fake-project-id" assert generator.api_base_url == "https://us-south.ml.cloud.ibm.com" + assert generator.tools is None + + def test_init_with_all_params(self, mock_watsonx: dict[str, AsyncMock | MagicMock]) -> None: + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=weather) - def test_init_with_all_params(self, mock_watsonx): generator = WatsonxChatGenerator( api_key=Secret.from_token("test-api-key"), project_id=Secret.from_token("test-project"), api_base_url="https://custom-url.com", generation_kwargs={"max_tokens": 100, "temperature": 0.7, "top_p": 0.9}, verify=False, + tools=[tool], ) _, kwargs = 
mock_watsonx["model"].call_args @@ -127,6 +152,12 @@ def test_init_with_all_params(self, mock_watsonx): assert isinstance(generator.project_id, Secret) assert generator.project_id.resolve_value() == "test-project" + assert generator.tools == [tool] + + def test_init_with_toolset(self, mock_watsonx: dict[str, AsyncMock | MagicMock], tools: list[Tool]) -> None: + toolset = Toolset(tools) + generator = WatsonxChatGenerator(project_id=Secret.from_token("fake-project-id"), tools=toolset) + assert generator.tools == toolset def test_init_fails_without_project(self, mock_watsonx): os.environ.pop("WATSONX_PROJECT_ID", None) @@ -134,10 +165,9 @@ def test_init_fails_without_project(self, mock_watsonx): with pytest.raises(ValueError, match="None of the following authentication environment variables are set"): WatsonxChatGenerator(api_key=Secret.from_token("test-api-key")) - def test_to_dict(self, mock_watsonx): + def test_to_dict(self, mock_watsonx: dict[str, AsyncMock | MagicMock]) -> None: generator = WatsonxChatGenerator( - project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), - generation_kwargs={"max_tokens": 100}, + project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), generation_kwargs={"max_tokens": 100} ) data = generator.to_dict() @@ -154,15 +184,17 @@ def test_to_dict(self, mock_watsonx): "timeout": None, "max_retries": None, "streaming_callback": None, + "tools": None, }, } assert data == expected - def test_to_dict_with_params(self, mock_watsonx): + def test_to_dict_with_params(self, mock_watsonx: dict[str, AsyncMock | MagicMock], tools: list[Tool]) -> None: generator = WatsonxChatGenerator( project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), generation_kwargs={"max_tokens": 100}, streaming_callback=print_streaming_chunk, + tools=tools, ) data = generator.to_dict() @@ -179,6 +211,24 @@ def test_to_dict_with_params(self, mock_watsonx): "timeout": None, "max_retries": None, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + 
"tools": [ + { + "data": { + "description": "useful to determine the weather in a given location", + "function": "tests.test_chat_generator.weather", + "inputs_from_state": None, + "name": "weather", + "outputs_to_state": None, + "outputs_to_string": None, + "parameters": { + "properties": {"city": {"type": "string"}}, + "required": ["city"], + "type": "object", + }, + }, + "type": "haystack.tools.tool.Tool", + }, + ], }, } assert data == expected @@ -216,6 +266,39 @@ def test_from_dict_with_callback(self, mock_watsonx): generator = WatsonxChatGenerator.from_dict(data) assert generator.streaming_callback is print_streaming_chunk + def test_from_dict_with_tools(self, mock_watsonx: dict[str, AsyncMock | MagicMock], tools: list[Tool]) -> None: + data = { + "type": "haystack_integrations.components.generators.watsonx.chat.chat_generator.WatsonxChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"}, + "model": "ibm/granite-4-h-small", + "project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"}, + "tools": [ + { + "data": { + "description": "useful to determine the weather in a given location", + "function": "tests.test_chat_generator.weather", + "inputs_from_state": None, + "name": "weather", + "outputs_to_state": None, + "outputs_to_string": None, + "parameters": { + "properties": {"city": {"type": "string"}}, + "required": ["city"], + "type": "object", + }, + }, + "type": "haystack.tools.tool.Tool", + }, + ], + }, + } + + generator = WatsonxChatGenerator.from_dict(data) + assert isinstance(generator.tools, list) + assert len(generator.tools) == len(tools) + assert all(isinstance(tool, Tool) for tool in generator.tools) + def test_run_single_message(self, mock_watsonx): generator = WatsonxChatGenerator( api_key=Secret.from_token("test-api-key"), @@ -227,10 +310,10 @@ def test_run_single_message(self, mock_watsonx): assert len(result["replies"]) == 1 assert 
result["replies"][0].text == "This is a generated response" - assert result["replies"][0].meta["finish_reason"] == "completed" + assert result["replies"][0].meta["finish_reason"] == "stop" mock_watsonx["model_instance"].chat.assert_called_once_with( - messages=[{"role": "user", "content": "Test prompt"}], params={} + messages=[{"role": "user", "content": "Test prompt"}], params={}, tools=None ) def test_run_with_generation_params(self, mock_watsonx): @@ -247,6 +330,7 @@ def test_run_with_generation_params(self, mock_watsonx): mock_watsonx["model_instance"].chat.assert_called_once_with( messages=[{"role": "user", "content": "Test prompt"}], params={"max_tokens": 100, "temperature": 0.7, "top_p": 0.9}, + tools=None, ) def test_run_with_streaming(self, mock_watsonx): @@ -273,7 +357,7 @@ def test_run_with_streaming(self, mock_watsonx): assert len(result["replies"]) == 1 assert result["replies"][0].text == "Streaming response" - assert result["replies"][0].meta["finish_reason"] == "completed" + assert result["replies"][0].meta["finish_reason"] == "stop" def test_run_with_empty_messages(self, mock_watsonx): generator = WatsonxChatGenerator( @@ -284,19 +368,6 @@ def test_run_with_empty_messages(self, mock_watsonx): result = generator.run(messages=[]) assert result["replies"] == [] - def test_skips_tool_messages(self, mock_watsonx): - generator = WatsonxChatGenerator( - project_id=Secret.from_token("test-project"), - ) - - messages = [ChatMessage.from_user("User message"), ChatMessage.from_tool("Tool result", "test-origin")] - - generator.run(messages=messages) - - mock_watsonx["model_instance"].chat.assert_called_once_with( - messages=[{"role": "user", "content": "User message"}], params={} - ) - def test_init_with_streaming_callback(self, mock_watsonx): def custom_callback(chunk: StreamingChunk): pass @@ -338,7 +409,7 @@ async def test_run_async_single_message(self, mock_watsonx): assert len(result["replies"]) == 1 assert result["replies"][0].text == "Async generated 
response" - assert result["replies"][0].meta["finish_reason"] == "completed" + assert result["replies"][0].meta["finish_reason"] == "stop" @pytest.mark.asyncio async def test_run_async_streaming(self, mock_watsonx): @@ -451,6 +522,118 @@ def test_prepare_api_call_image_in_non_user_message(self, mock_watsonx): with pytest.raises(ValueError, match="Image content is only supported for user messages"): generator._prepare_api_call(messages=[message]) + def test_convert_chunk_to_streaming_chunk_real_example( + self, mock_watsonx: dict[str, AsyncMock | MagicMock] + ) -> None: + component = WatsonxChatGenerator( + project_id=Secret.from_token("test-project"), model="meta-llama/llama-3-2-11b-vision-instruct" + ) + component_info = ComponentInfo.from_component(component) + + # Chunk 1: Text only + chunk1 = { + "id": "chatcmpl-21e72dd9-ed65-49cc-9ea2-64d971707cda---2dedc26eab5af753744ed4eaa116a197---e0399d75-cd8c-486e-b907-dc211cb70eac", # noqa: E501 + "object": "chat.completion.chunk", + "model_id": "meta-llama/llama-3-2-11b-vision-instruct", + "model": "meta-llama/llama-3-2-11b-vision-instruct", + "choices": [ + { + "index": 0, + "finish_reason": None, + "delta": {"content": "I'll get the weather information for Paris and Berlin"}, + } + ], + "created": 1773250972, + "model_version": "3.2.0", + "created_at": "2026-03-11T17:42:52.921Z", + } + + streaming_chunk1 = component._convert_chunk_to_streaming_chunk(chunk=chunk1, component_info=component_info) + assert streaming_chunk1.content == "I'll get the weather information for Paris and Berlin" + assert streaming_chunk1.tool_calls is None + assert streaming_chunk1.finish_reason is None + assert streaming_chunk1.index == 0 + assert "created" in streaming_chunk1.meta + assert "created_at" in streaming_chunk1.meta + assert "received_at" in streaming_chunk1.meta + assert streaming_chunk1.meta["model"] == "meta-llama/llama-3-2-11b-vision-instruct" + assert streaming_chunk1.meta["model_id"] == 
"meta-llama/llama-3-2-11b-vision-instruct" + assert streaming_chunk1.meta["model_version"] == "3.2.0" + assert streaming_chunk1.component_info == component_info + + # Chunk 2: Text only + chunk2 = { + "id": "chatcmpl-21e72dd9-ed65-49cc-9ea2-64d971707cda---2dedc26eab5af753744ed4eaa116a197---e0399d75-cd8c-486e-b907-dc211cb70eac", # noqa: E501 + "object": "chat.completion.chunk", + "model_id": "meta-llama/llama-3-2-11b-vision-instruct", + "model": "meta-llama/llama-3-2-11b-vision-instruct", + "choices": [ + {"index": 0, "finish_reason": None, "delta": {"content": " and present it in a structured format."}} + ], + "created": 1773250972, + "model_version": "3.2.0", + "created_at": "2026-03-11T17:42:52.929Z", + } + + streaming_chunk2 = component._convert_chunk_to_streaming_chunk(chunk=chunk2, component_info=component_info) + assert streaming_chunk2.content == " and present it in a structured format." + assert streaming_chunk2.tool_calls is None + assert streaming_chunk2.finish_reason is None + assert streaming_chunk2.index == 0 + assert "created" in streaming_chunk2.meta + assert "created_at" in streaming_chunk2.meta + assert "received_at" in streaming_chunk2.meta + assert streaming_chunk2.meta["model"] == "meta-llama/llama-3-2-11b-vision-instruct" + assert streaming_chunk2.meta["model_id"] == "meta-llama/llama-3-2-11b-vision-instruct" + assert streaming_chunk2.meta["model_version"] == "3.2.0" + assert streaming_chunk2.component_info == component_info + + # Chunk 3: Start of a single tool call (id and name set, arguments still empty) + chunk3 = { + "id": "chatcmpl-6b615ca6-4aa7-4f79-832f-bedce4641c2b---87fdc1a1cd2032ff0c6776ecfc20b6a5---34576777-949d-4df1-b95f-56d14b848eca", # noqa: E501 + "object": "chat.completion.chunk", + "model_id": "meta-llama/llama-3-2-11b-vision-instruct", + "model": "meta-llama/llama-3-2-11b-vision-instruct", + "choices": [ + { + "index": 0, + "finish_reason": None, + "delta": { + "tool_calls": [ + { + "index": 0, + "id": 
"chatcmpl-tool-9646185282a54afc86c3572513b2dafa", + "type": "function", + "function": {"name": "weather", "arguments": ""}, + } + ] + }, + } + ], + "created": 1773252289, + "model_version": "3.2.0", + "created_at": "2026-03-11T18:04:49.696Z", + } + + streaming_chunk3 = component._convert_chunk_to_streaming_chunk(chunk=chunk3, component_info=component_info) + assert streaming_chunk3.content == "" + assert streaming_chunk3.tool_calls is not None + assert len(streaming_chunk3.tool_calls) == 1 + assert streaming_chunk3.finish_reason is None + assert streaming_chunk3.index == 0 + assert "created" in streaming_chunk3.meta + assert "created_at" in streaming_chunk3.meta + assert "received_at" in streaming_chunk3.meta + assert streaming_chunk3.meta["model"] == "meta-llama/llama-3-2-11b-vision-instruct" + assert streaming_chunk3.meta["model_id"] == "meta-llama/llama-3-2-11b-vision-instruct" + assert streaming_chunk3.meta["model_version"] == "3.2.0" + assert streaming_chunk3.component_info == component_info + + assert streaming_chunk3.tool_calls[0].tool_name == "weather" + assert streaming_chunk3.tool_calls[0].arguments == "" + assert streaming_chunk3.tool_calls[0].id == "chatcmpl-tool-9646185282a54afc86c3572513b2dafa" + assert streaming_chunk3.tool_calls[0].index == 0 + def test_multimodal_message_processing(self, mock_watsonx): """Test multimodal message processing with mocked model.""" base64_image = ( @@ -550,6 +733,45 @@ def test_live_run(self): assert len(results["replies"][0].text) > 0 assert isinstance(generator.project_id, Secret) + @pytest.mark.skipif( + not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), + reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", + ) + def test_live_run_with_toolset(self, tools: list[Tool]) -> None: + """Test that WatsonxChatGenerator can run with a Toolset.""" + toolset = Toolset(tools) + generator = WatsonxChatGenerator( + project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), + 
generation_kwargs={"max_tokens": 50, "temperature": 0.7, "top_p": 0.9}, + tools=toolset, + ) + messages = [ChatMessage.from_user("What's the weather like in Paris?")] + results = generator.run(messages=messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + + # Check if tool calls were made + assert message.tool_calls is not None, "Message has no tool calls" + assert len(message.tool_calls) == 1, "Message has multiple tool calls and it should only have one" + tool_call = message.tool_calls[0] + assert message.meta["finish_reason"] == "tool_calls" + + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + + # Test full conversation with tool result + tool_result_message = ChatMessage.from_tool(tool_result="22°C, sunny", origin=tool_call) + follow_up_messages = [*messages, message, tool_result_message] + final_results = generator.run(messages=follow_up_messages) + + assert len(final_results["replies"]) == 1 + final_message = final_results["replies"][0] + assert final_message.text + assert "paris" in final_message.text.lower() or "weather" in final_message.text.lower(), ( + "Response does not contain Paris or weather" + ) + @pytest.mark.skipif( not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", @@ -572,6 +794,140 @@ def callback(chunk: StreamingChunk): assert len(collected_chunks) > 0 assert all(isinstance(chunk, StreamingChunk) for chunk in collected_chunks) + @pytest.mark.skipif( + not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), + reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", + ) + def test_live_run_with_tools_streaming(self, tools: list[Tool]) -> None: + """ + Integration test that the WatsonxChatGenerator component can run with tools and streaming. 
+ """ + component = WatsonxChatGenerator( + project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), tools=tools, streaming_callback=print_streaming_chunk + ) + results = component.run(messages=[ChatMessage.from_user("What's the weather like in Paris?")]) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_calls: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert tool_message.tool_calls is not None, "Tool message has no tool calls" + assert len(tool_message.tool_calls) == 1, "Tool message has multiple tool calls" + assert tool_message.tool_calls[0].tool_name == "weather" + assert tool_message.tool_calls[0].arguments == {"city": "Paris"} + + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + assert tool_message.meta["finish_reason"] == "tool_calls" + + tool_call = tool_message.tool_calls[0] + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + + @pytest.mark.skipif( + not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), + reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", + ) + def test_live_run_with_mixed_tools(self) -> None: + """ + Integration test that verifies WatsonxChatGenerator works with mixed Tool and Toolset. + This tests that the LLM can correctly invoke tools from both a standalone Tool and a Toolset. + """ + weather_tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get weather for, e.g. 
Paris, London", + } + }, + "required": ["city"], + }, + function=weather, + ) + + population_tool = Tool( + name="population", + description="useful to determine the population of a given city", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get population for, e.g. Paris, Berlin", + } + }, + "required": ["city"], + }, + function=population, + ) + + # Create a toolset with the population tool + population_toolset = Toolset([population_tool]) + + # Mix standalone tool with toolset + mixed_tools = [weather_tool, population_toolset] + + initial_messages = [ + ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?") + ] + component = WatsonxChatGenerator( + model="meta-llama/llama-3-2-11b-vision-instruct", + project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), + tools=mixed_tools, + ) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + first_reply = results["replies"][0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert first_reply.tool_calls, "First reply has no tool calls" + + tool_calls = first_reply.tool_calls + assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}" + + # Verify we got calls to both weather and population tools + tool_names = {tc.tool_name for tc in tool_calls} + assert "weather" in tool_names, "Expected 'weather' tool call" + assert "population" in tool_names, "Expected 'population' tool call" + + # Verify tool call details + for tool_call in tool_calls: + assert tool_call.tool_name in ["weather", "population"] + assert "city" in tool_call.arguments + assert tool_call.arguments["city"] in ["Paris", "Berlin"] + assert first_reply.meta["finish_reason"] == "tool_calls" + + # Mock the response 
we'd get from ToolInvoker + tool_result_messages = [] + for tool_call in tool_calls: + if tool_call.tool_name == "weather": + result = "The weather in Paris is sunny and 32°C" + else: # population + result = "The population of Berlin is 2.2 million" + tool_result_messages.append(ChatMessage.from_tool(tool_result=result, origin=tool_call)) + + new_messages = [*initial_messages, first_reply, *tool_result_messages] + results = component.run(messages=new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + assert "berlin" in final_message.text.lower() + @pytest.mark.asyncio @pytest.mark.skipif( not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), @@ -589,6 +945,32 @@ async def test_live_run_async(self): assert isinstance(results["replies"][0], ChatMessage) assert len(results["replies"][0].text) > 0 + @pytest.mark.asyncio + @pytest.mark.skipif( + not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), + reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", + ) + async def test_live_run_async_with_tools(self, tools: list[Tool]) -> None: + """Test async version with tools.""" + component = WatsonxChatGenerator(project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), tools=tools) + results = await component.run_async(messages=[ChatMessage.from_user("What's the weather like in Paris?")]) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_calls: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert tool_message.tool_calls is not None, "Tool message has no tool calls" + assert len(tool_message.tool_calls) == 1, "Tool message has multiple tool calls" + assert 
tool_message.tool_calls[0].tool_name == "weather" + assert tool_message.tool_calls[0].arguments == {"city": "Paris"} + assert tool_message.meta["finish_reason"] == "tool_calls" + @pytest.mark.skipif( not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", diff --git a/integrations/watsonx/tests/test_generator.py b/integrations/watsonx/tests/test_generator.py index c32e594d1a..8bfcb474d5 100644 --- a/integrations/watsonx/tests/test_generator.py +++ b/integrations/watsonx/tests/test_generator.py @@ -135,6 +135,7 @@ def test_to_dict(self, mock_watsonx): "timeout": None, "max_retries": None, "streaming_callback": None, + "tools": None, }, } assert data == expected @@ -185,7 +186,7 @@ def test_run_with_prompt_only(self, mock_watsonx): assert "usage" in result["meta"][0] mock_watsonx["model_instance"].chat.assert_called_once_with( - messages=[{"role": "user", "content": "Test prompt"}], params={} + messages=[{"role": "user", "content": "Test prompt"}], params={}, tools=None ) def test_run_with_system_prompt(self, mock_watsonx): @@ -203,7 +204,7 @@ def test_run_with_system_prompt(self, mock_watsonx): {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Test prompt"}, ] - mock_watsonx["model_instance"].chat.assert_called_once_with(messages=expected_messages, params={}) + mock_watsonx["model_instance"].chat.assert_called_once_with(messages=expected_messages, params={}, tools=None) def test_run_with_generation_kwargs(self, mock_watsonx): generator = WatsonxGenerator( @@ -218,6 +219,7 @@ def test_run_with_generation_kwargs(self, mock_watsonx): mock_watsonx["model_instance"].chat.assert_called_once_with( messages=[{"role": "user", "content": "Test prompt"}], params={"max_tokens": 100, "temperature": 0.7, "top_p": 0.9}, + tools=None, ) def test_run_with_streaming(self, mock_watsonx): @@ -296,7 +298,7 @@ async def test_run_async_with_prompt_only(self, 
mock_watsonx): assert result["meta"][0]["finish_reason"] == "completed" mock_watsonx["model_instance"].achat.assert_called_once_with( - messages=[{"role": "user", "content": "Test prompt"}], params={} + messages=[{"role": "user", "content": "Test prompt"}], params={}, tools=None ) @pytest.mark.asyncio @@ -315,7 +317,7 @@ async def test_run_async_with_system_prompt(self, mock_watsonx): {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Test prompt"}, ] - mock_watsonx["model_instance"].achat.assert_called_once_with(messages=expected_messages, params={}) + mock_watsonx["model_instance"].achat.assert_called_once_with(messages=expected_messages, params={}, tools=None) @pytest.mark.asyncio async def test_run_async_streaming(self, mock_watsonx):