From ddb7b5ca9772732cb4bcc678c4a668f5469e215d Mon Sep 17 00:00:00 2001 From: Max Swain <89113255+maxdswain@users.noreply.github.com> Date: Sun, 1 Mar 2026 18:11:22 +0000 Subject: [PATCH 01/10] feat: Update WatsonXChatGenerator to use the StreamingChunk fields --- .../generators/watsonx/chat/chat_generator.py | 67 +++++++++---------- .../watsonx/tests/test_chat_generator.py | 14 ++-- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py index ef184596f4..2226da7943 100644 --- a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py +++ b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py @@ -6,10 +6,12 @@ from typing import Any, Literal, get_args from haystack import component, default_from_dict, default_to_dict, logging +from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message from haystack.dataclasses import ( AsyncStreamingCallbackT, ChatMessage, ChatRole, + FinishReason, ImageContent, StreamingCallbackT, StreamingChunk, @@ -29,6 +31,17 @@ ImageFormat = Literal["image/jpeg", "image/png"] IMAGE_SUPPORTED_FORMATS: list[ImageFormat] = list(get_args(ImageFormat)) +# See https://ibm.github.io/watsonx-ai-node-sdk/enums/1_6_x.WatsonXAI.TextChatResultChoiceStream.Constants.FinishReason.html +# for possible finish reasons +FINISH_REASON_MAPPING: dict[str, FinishReason] = { + "cancelled": "stop", + "error": "stop", + "length": "length", + "stop": "stop", + "time_limit": "stop", + "tool_calls": "tool_calls", +} + @component class WatsonxChatGenerator: @@ -327,6 +340,22 @@ def _prepare_api_call( return {"messages": watsonx_messages, "params": merged_kwargs} + def _convert_chunk_to_streaming_chunk(self, content: str, chunk: dict[str, Any]) -> StreamingChunk: + """ + Convert one Watsonx AI stream-chunk to Haystack StreamingChunk. 
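+        :param content: The text delta already extracted from this chunk.
+        :param chunk: The raw chunk dictionary returned by the watsonx.ai SDK.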
+ """ + chunk_meta = { + "model": self.model, + "received_at": datetime.now(timezone.utc).isoformat(), + } + streaming_chunk = StreamingChunk( + content=content, + meta=chunk_meta, + index=chunk["choices"][0].get("index", 0), + finish_reason=FINISH_REASON_MAPPING.get(chunk["choices"][0].get("finish_reason")), + ) + return streaming_chunk + def _handle_streaming( self, *, @@ -350,17 +379,11 @@ def _handle_streaming( content = chunk["choices"][0].get("delta", {}).get("content", "") if content: - chunk_meta = { - "model": self.model, - "index": chunk["choices"][0].get("index", 0), - "finish_reason": chunk["choices"][0].get("finish_reason"), - "received_at": datetime.now(timezone.utc).isoformat(), - } - streaming_chunk = StreamingChunk(content=content, meta=chunk_meta) + streaming_chunk = self._convert_chunk_to_streaming_chunk(content, chunk) chunks.append(streaming_chunk) callback(streaming_chunk) - return {"replies": [self._convert_streaming_chunks_to_chat_message(chunks)]} + return {"replies": [_convert_streaming_chunks_to_chat_message(chunks)]} def _handle_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Handle synchronous standard response.""" @@ -383,35 +406,11 @@ async def _handle_async_streaming( content = chunk["choices"][0].get("delta", {}).get("content", "") if content: - chunk_meta = { - "model": self.model, - "index": chunk["choices"][0].get("index", 0), - "finish_reason": chunk["choices"][0].get("finish_reason"), - "received_at": datetime.now(timezone.utc).isoformat(), - } - streaming_chunk = StreamingChunk(content=content, meta=chunk_meta) + streaming_chunk = self._convert_chunk_to_streaming_chunk(content, chunk) chunks.append(streaming_chunk) await callback(streaming_chunk) - return {"replies": [self._convert_streaming_chunks_to_chat_message(chunks)]} - - def _convert_streaming_chunks_to_chat_message(self, chunks: list[StreamingChunk]) -> ChatMessage: - """Convert list of streaming chunks to a single ChatMessage.""" - if not chunks: - return ChatMessage.from_assistant("") - - content = "".join(chunk.content for chunk in chunks) - last_chunk_meta = chunks[-1].meta if chunks else {} - - return ChatMessage.from_assistant( - text=content, - meta={ - "model": self.model, - "finish_reason": last_chunk_meta.get("finish_reason"), - "usage": last_chunk_meta.get("usage", {}), - "chunks_count": len(chunks), - }, - ) + return {"replies": [_convert_streaming_chunks_to_chat_message(chunks)]} async def _handle_async_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Handle asynchronous standard response.""" diff --git a/integrations/watsonx/tests/test_chat_generator.py b/integrations/watsonx/tests/test_chat_generator.py index 41c36fbfd9..9062412ac0 100644 --- a/integrations/watsonx/tests/test_chat_generator.py +++ b/integrations/watsonx/tests/test_chat_generator.py @@ -41,7 +41,7 @@ def mock_watsonx(self, monkeypatch): { "message": {"content": "This is a generated response", "role": "assistant"}, "index": 0, - "finish_reason": "completed", + "finish_reason": "stop", } ], "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, @@ -53,7 +53,7 @@ def mock_watsonx(self, monkeypatch): { "message": {"content": "Async generated response", "role": "assistant"}, "index": 0, - "finish_reason": "completed", + "finish_reason": "stop", } ], "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, @@ -62,7 +62,7 @@ def mock_watsonx(self, monkeypatch): mock_model_instance.chat_stream = MagicMock( return_value=[ 
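                # Two simulated stream chunks: a text delta, then a final delta that also carries the finish_reason.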
{"choices": [{"delta": {"content": "Streaming"}, "index": 0, "finish_reason": None}]}, - {"choices": [{"delta": {"content": " response"}, "index": 0, "finish_reason": "completed"}]}, + {"choices": [{"delta": {"content": " response"}, "index": 0, "finish_reason": "stop"}]}, ] ) @@ -85,7 +85,7 @@ async def __anext__(self): elif self._count == 2: return { "choices": [ - {"delta": {"content": " response"}, "finish_reason": "completed", "index": 0} + {"delta": {"content": " response"}, "finish_reason": "stop", "index": 0} ] } else: @@ -227,7 +227,7 @@ def test_run_single_message(self, mock_watsonx): assert len(result["replies"]) == 1 assert result["replies"][0].text == "This is a generated response" - assert result["replies"][0].meta["finish_reason"] == "completed" + assert result["replies"][0].meta["finish_reason"] == "stop" mock_watsonx["model_instance"].chat.assert_called_once_with( messages=[{"role": "user", "content": "Test prompt"}], params={} @@ -273,7 +273,7 @@ def test_run_with_streaming(self, mock_watsonx): assert len(result["replies"]) == 1 assert result["replies"][0].text == "Streaming response" - assert result["replies"][0].meta["finish_reason"] == "completed" + assert result["replies"][0].meta["finish_reason"] == "stop" def test_run_with_empty_messages(self, mock_watsonx): generator = WatsonxChatGenerator( @@ -338,7 +338,7 @@ async def test_run_async_single_message(self, mock_watsonx): assert len(result["replies"]) == 1 assert result["replies"][0].text == "Async generated response" - assert result["replies"][0].meta["finish_reason"] == "completed" + assert result["replies"][0].meta["finish_reason"] == "stop" @pytest.mark.asyncio async def test_run_async_streaming(self, mock_watsonx): From e8726eb764167692fdd5fc20b358b6aabc198320 Mon Sep 17 00:00:00 2001 From: Max Swain <89113255+maxdswain@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:48:56 +0000 Subject: [PATCH 02/10] fix: Formatting issue --- integrations/watsonx/tests/test_chat_generator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/integrations/watsonx/tests/test_chat_generator.py b/integrations/watsonx/tests/test_chat_generator.py index 9062412ac0..0c0525b6af 100644 --- a/integrations/watsonx/tests/test_chat_generator.py +++ b/integrations/watsonx/tests/test_chat_generator.py @@ -84,9 +84,7 @@ async def __anext__(self): } elif self._count == 2: return { - "choices": [ - {"delta": {"content": " response"}, "finish_reason": "stop", "index": 0} - ] + "choices": [{"delta": {"content": " response"}, "finish_reason": "stop", "index": 0}] } else: raise StopAsyncIteration From ada14e1a0b2fc342de14f6ebcff58c389aabd3c2 Mon Sep 17 00:00:00 2001 From: Max Swain <89113255+maxdswain@users.noreply.github.com> Date: Tue, 10 Mar 2026 18:21:36 +0000 Subject: [PATCH 03/10] feat(watsonx): Add support for tool calling --- .../generators/watsonx/chat/chat_generator.py | 138 ++++++++++++++---- 1 file changed, 106 insertions(+), 32 deletions(-) diff --git a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py index 2226da7943..2ca7efe518 100644 --- a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py +++ b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py @@ -11,14 +11,18 @@ AsyncStreamingCallbackT, ChatMessage, ChatRole, + ComponentInfo, FinishReason, ImageContent, 
StreamingCallbackT, StreamingChunk, SyncStreamingCallbackT, TextContent, + ToolCall, + ToolCallDelta, select_streaming_callback, ) +from haystack.tools import ToolsType, _check_duplicate_tool_names, flatten_tools_or_toolsets from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable from ibm_watsonx_ai import Credentials from ibm_watsonx_ai.foundation_models import ModelInference @@ -113,6 +117,7 @@ def __init__( max_retries: int | None = None, verify: bool | str | None = None, streaming_callback: StreamingCallbackT | None = None, + tools: ToolsType | None = None, ) -> None: """ Creates an instance of WatsonxChatGenerator. @@ -149,6 +154,8 @@ def __init__( - False: Skip verification (insecure) - Path to CA bundle for custom certificates :param streaming_callback: A callback function for streaming responses. + :param tools: + A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. """ self.api_key = api_key self.model = model @@ -219,6 +226,7 @@ def run( messages: list[ChatMessage], generation_kwargs: dict[str, Any] | None = None, streaming_callback: StreamingCallbackT | None = None, + tools: ToolsType | None = None, ) -> dict[str, list[ChatMessage]]: """ Generate chat completions synchronously. @@ -231,6 +239,9 @@ def run( :param streaming_callback: A callback function that is called when a new token is received from the stream. If provided this will override the `streaming_callback` set in the `__init__` method. + :param tools: + A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. + If set, it will override the `tools` parameter provided during initialization. :returns: A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. @@ -242,7 +253,7 @@ def run( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False ) - api_args = self._prepare_api_call(messages=messages, generation_kwargs=generation_kwargs) + api_args = self._prepare_api_call(messages=messages, generation_kwargs=generation_kwargs, tools=tools) if resolved_streaming_callback: return self._handle_streaming(api_args=api_args, callback=resolved_streaming_callback) @@ -256,6 +267,7 @@ async def run_async( messages: list[ChatMessage], generation_kwargs: dict[str, Any] | None = None, streaming_callback: StreamingCallbackT | None = None, + tools: ToolsType | None = None, ) -> dict[str, list[ChatMessage]]: """ Generate chat completions asynchronously. @@ -268,6 +280,9 @@ async def run_async( :param streaming_callback: A callback function that is called when a new token is received from the stream. If provided this will override the `streaming_callback` set in the `__init__` method. + :param tools: + A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. + If set, it will override the `tools` parameter provided during initialization. :returns: A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. 
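
        A minimal illustrative call with a runtime tool override (assuming a
        `weather_tool` built like the Tool fixture in the tests):

            result = await generator.run_async(
                messages=[ChatMessage.from_user("What's the weather like in Paris?")],
                tools=[weather_tool],
            )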
@@ -279,7 +294,7 @@ async def run_async( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True ) - api_args = self._prepare_api_call(messages=messages, generation_kwargs=generation_kwargs) + api_args = self._prepare_api_call(messages=messages, generation_kwargs=generation_kwargs, tools=tools) if resolved_streaming_callback: return await self._handle_async_streaming(api_args=api_args, callback=resolved_streaming_callback) @@ -287,16 +302,27 @@ async def run_async( return await self._handle_async_standard(api_args) def _prepare_api_call( - self, *, messages: list[ChatMessage], generation_kwargs: dict[str, Any] | None = None + self, + *, + messages: list[ChatMessage], + generation_kwargs: dict[str, Any] | None = None, + tools: ToolsType | None = None, ) -> dict[str, Any]: merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} watsonx_messages = [] content: str | None | dict[str, Any] | list[dict[str, Any]] + flattened_tools = flatten_tools_or_toolsets(tools or self.tools) + _check_duplicate_tool_names(flattened_tools) + tool_definitions = [{"type": "function", "function": {**tool.tool_spec}} for tool in flattened_tools] + for msg in messages: - if msg.is_from("tool"): - logger.debug("Skipping tool message - tool calls are not currently supported") + # Watsonx tool call result messages are of the same format as OpenAI chat completions + if msg.tool_call_results: + watsonx_messages.append( + msg._tool_result_message_to_openai({"role": msg.role.value}, require_tool_call_ids=True) + ) continue # Check that images are only in user messages @@ -338,23 +364,53 @@ def _prepare_api_call( merged_kwargs.pop("stream", None) - return {"messages": watsonx_messages, "params": merged_kwargs} + api_args = {"messages": watsonx_messages, "params": merged_kwargs} + if tool_definitions: + api_args["tools"] = tool_definitions - def _convert_chunk_to_streaming_chunk(self, content: str, chunk: dict[str, Any]) -> StreamingChunk: + return api_args + + def _convert_chunk_to_streaming_chunk(self, chunk: dict[str, Any], component_info: ComponentInfo) -> StreamingChunk: """ Convert one Watsonx AI stream-chunk to Haystack StreamingChunk. 
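+        :param chunk: The raw chunk dictionary from the watsonx.ai SDK, e.g.
+            {"choices": [{"index": 0, "delta": {"content": "..."}, "finish_reason": None}]}.
+        :param component_info: `ComponentInfo` for this generator, attached to the emitted chunk.
+        :returns: The converted `StreamingChunk`.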
""" + choice = chunk["choices"][0] chunk_meta = { "model": self.model, "received_at": datetime.now(timezone.utc).isoformat(), } - streaming_chunk = StreamingChunk( - content=content, + + if choice["delta"] and (choice_delta_tool_calls := choice["delta"]["tool_calls"]): + # create a list of ToolCallDelta objects from the tool calls + tool_calls_deltas = [ + ToolCallDelta( + index=tool_call["index"], + id=tool_call["id"], + tool_name=tool_call.get("function", {}).get("name"), + arguments=tool_call.get("function", {}).get("arguments"), + ) + for tool_call in choice_delta_tool_calls + ] + return StreamingChunk( + content=choice.get("delta", {}).get("content", ""), + meta=chunk_meta, + component_info=component_info, + # We adopt the first tool_calls_deltas.index as the overall index of the chunk to match OpenAI + index=tool_calls_deltas[0].index, + tool_calls=tool_calls_deltas, + start=tool_calls_deltas[0].tool_name is not None, + finish_reason=FINISH_REASON_MAPPING.get(choice.get("finish_reason")), + ) + + index = choice.get("index", 0) + return StreamingChunk( + content=choice.get("delta", {}).get("content", ""), meta=chunk_meta, - index=chunk["choices"][0].get("index", 0), - finish_reason=FINISH_REASON_MAPPING.get(chunk["choices"][0].get("finish_reason")), + component_info=component_info, + index=index, + start=index == 0, + finish_reason=FINISH_REASON_MAPPING.get(choice.get("finish_reason")), ) - return streaming_chunk def _handle_streaming( self, @@ -371,23 +427,24 @@ def _handle_streaming( A dictionary with the generated responses as ChatMessage instances. """ chunks: list[StreamingChunk] = [] - stream = self.client.chat_stream(messages=api_args["messages"], params=api_args["params"]) + stream = self.client.chat_stream( + messages=api_args["messages"], params=api_args["params"], tools=api_args["tools"] + ) + component_info = ComponentInfo.from_component(self) for chunk in stream: if not isinstance(chunk, dict) or not chunk.get("choices"): continue - content = chunk["choices"][0].get("delta", {}).get("content", "") - if content: - streaming_chunk = self._convert_chunk_to_streaming_chunk(content, chunk) - chunks.append(streaming_chunk) - callback(streaming_chunk) + streaming_chunk = self._convert_chunk_to_streaming_chunk(chunk, component_info) + chunks.append(streaming_chunk) + callback(streaming_chunk) return {"replies": [_convert_streaming_chunks_to_chat_message(chunks)]} def _handle_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Handle synchronous standard response.""" - response = self.client.chat(messages=api_args["messages"], params=api_args["params"]) + response = self.client.chat(messages=api_args["messages"], params=api_args["params"], tools=api_args["tools"]) return self._process_response(response) async def _handle_async_streaming( @@ -398,23 +455,26 @@ async def _handle_async_streaming( ) -> dict[str, list[ChatMessage]]: """Handle asynchronous streaming response.""" chunks: list[StreamingChunk] = [] - stream_generator = await self.client.achat_stream(messages=api_args["messages"], params=api_args["params"]) + stream_generator = await self.client.achat_stream( + messages=api_args["messages"], params=api_args["params"], tools=api_args["tools"] + ) + component_info = ComponentInfo.from_component(self) async for chunk in stream_generator: if not isinstance(chunk, dict) or not chunk.get("choices"): continue - content = chunk["choices"][0].get("delta", {}).get("content", "") - if content: - streaming_chunk = self._convert_chunk_to_streaming_chunk(content, 
chunk) - chunks.append(streaming_chunk) - await callback(streaming_chunk) + streaming_chunk = self._convert_chunk_to_streaming_chunk(chunk, component_info) + chunks.append(streaming_chunk) + await callback(streaming_chunk) return {"replies": [_convert_streaming_chunks_to_chat_message(chunks)]} async def _handle_async_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Handle asynchronous standard response.""" - response = await self.client.achat(messages=api_args["messages"], params=api_args["params"]) + response = await self.client.achat( + messages=api_args["messages"], params=api_args["params"], tools=api_args["tools"] + ) return self._process_response(response) def _process_response(self, response: dict[str, Any]) -> dict[str, list[ChatMessage]]: @@ -422,12 +482,25 @@ def _process_response(self, response: dict[str, Any]) -> dict[str, list[ChatMess if not response.get("choices"): return {"replies": []} - choice = response["choices"][0] - message = choice.get("message", {}) - return { - "replies": [ + choices = response["choices"] + chat_messages = [] + for choice in choices: + message = choice.get("message", {}) + + if tool_calls := message.get("tool_calls", []): + message_tool_calls = [ + ToolCall( + id=tool_call["id"], + tool_name=tool_call["function"]["name"], + arguments=tool_call["function"]["arguments"], + ) + for tool_call in tool_calls + ] + + chat_messages.append( ChatMessage.from_assistant( text=message.get("content", ""), + tool_calls=message_tool_calls, meta={ "model": self.model, "index": choice.get("index", 0), @@ -435,5 +508,6 @@ def _process_response(self, response: dict[str, Any]) -> dict[str, list[ChatMess "usage": response.get("usage", {}), }, ) - ] - } + ) + + return {"replies": chat_messages} From eef119469ac4c3852dd92e41b6901ee40ffb6f7d Mon Sep 17 00:00:00 2001 From: Max Swain <89113255+maxdswain@users.noreply.github.com> Date: Wed, 11 Mar 2026 12:04:47 +0000 Subject: [PATCH 04/10] fix: Fix dict key error bugs and update existing tests --- .../generators/watsonx/chat/chat_generator.py | 14 +++++++++----- .../watsonx/tests/test_chat_generator.py | 18 +++--------------- integrations/watsonx/tests/test_generator.py | 9 +++++---- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py index 2ca7efe518..0b4908f806 100644 --- a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py +++ b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py @@ -166,6 +166,7 @@ def __init__( self.max_retries = max_retries self.verify = verify self.streaming_callback = streaming_callback + self.tools = tools self._initialize_client() @@ -380,7 +381,7 @@ def _convert_chunk_to_streaming_chunk(self, chunk: dict[str, Any], component_inf "received_at": datetime.now(timezone.utc).isoformat(), } - if choice["delta"] and (choice_delta_tool_calls := choice["delta"]["tool_calls"]): + if choice["delta"] and (choice_delta_tool_calls := choice["delta"].get("tool_calls")): # create a list of ToolCallDelta objects from the tool calls tool_calls_deltas = [ ToolCallDelta( @@ -428,7 +429,7 @@ def _handle_streaming( """ chunks: list[StreamingChunk] = [] stream = self.client.chat_stream( - messages=api_args["messages"], params=api_args["params"], tools=api_args["tools"] + 
messages=api_args["messages"], params=api_args["params"], tools=api_args.get("tools") ) component_info = ComponentInfo.from_component(self) @@ -444,7 +445,9 @@ def _handle_streaming( def _handle_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Handle synchronous standard response.""" - response = self.client.chat(messages=api_args["messages"], params=api_args["params"], tools=api_args["tools"]) + response = self.client.chat( + messages=api_args["messages"], params=api_args["params"], tools=api_args.get("tools") + ) return self._process_response(response) async def _handle_async_streaming( @@ -456,7 +459,7 @@ async def _handle_async_streaming( """Handle asynchronous streaming response.""" chunks: list[StreamingChunk] = [] stream_generator = await self.client.achat_stream( - messages=api_args["messages"], params=api_args["params"], tools=api_args["tools"] + messages=api_args["messages"], params=api_args["params"], tools=api_args.get("tools") ) component_info = ComponentInfo.from_component(self) @@ -473,7 +476,7 @@ async def _handle_async_streaming( async def _handle_async_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Handle asynchronous standard response.""" response = await self.client.achat( - messages=api_args["messages"], params=api_args["params"], tools=api_args["tools"] + messages=api_args["messages"], params=api_args["params"], tools=api_args.get("tools") ) return self._process_response(response) @@ -487,6 +490,7 @@ def _process_response(self, response: dict[str, Any]) -> dict[str, list[ChatMess for choice in choices: message = choice.get("message", {}) + message_tool_calls: list[ToolCall] | None = None if tool_calls := message.get("tool_calls", []): message_tool_calls = [ ToolCall( diff --git a/integrations/watsonx/tests/test_chat_generator.py b/integrations/watsonx/tests/test_chat_generator.py index 0c0525b6af..e320654048 100644 --- a/integrations/watsonx/tests/test_chat_generator.py +++ b/integrations/watsonx/tests/test_chat_generator.py @@ -66,7 +66,7 @@ def mock_watsonx(self, monkeypatch): ] ) - async def mock_achat_stream(messages=None, params=None): + async def mock_achat_stream(messages=None, params=None, tools=None): class MockAsyncGenerator: def __init__(self): self._count = 0 @@ -228,7 +228,7 @@ def test_run_single_message(self, mock_watsonx): assert result["replies"][0].meta["finish_reason"] == "stop" mock_watsonx["model_instance"].chat.assert_called_once_with( - messages=[{"role": "user", "content": "Test prompt"}], params={} + messages=[{"role": "user", "content": "Test prompt"}], params={}, tools=None ) def test_run_with_generation_params(self, mock_watsonx): @@ -245,6 +245,7 @@ def test_run_with_generation_params(self, mock_watsonx): mock_watsonx["model_instance"].chat.assert_called_once_with( messages=[{"role": "user", "content": "Test prompt"}], params={"max_tokens": 100, "temperature": 0.7, "top_p": 0.9}, + tools=None, ) def test_run_with_streaming(self, mock_watsonx): @@ -282,19 +283,6 @@ def test_run_with_empty_messages(self, mock_watsonx): result = generator.run(messages=[]) assert result["replies"] == [] - def test_skips_tool_messages(self, mock_watsonx): - generator = WatsonxChatGenerator( - project_id=Secret.from_token("test-project"), - ) - - messages = [ChatMessage.from_user("User message"), ChatMessage.from_tool("Tool result", "test-origin")] - - generator.run(messages=messages) - - mock_watsonx["model_instance"].chat.assert_called_once_with( - messages=[{"role": "user", "content": "User message"}], 
params={} - ) - def test_init_with_streaming_callback(self, mock_watsonx): def custom_callback(chunk: StreamingChunk): pass diff --git a/integrations/watsonx/tests/test_generator.py b/integrations/watsonx/tests/test_generator.py index c32e594d1a..260617b63d 100644 --- a/integrations/watsonx/tests/test_generator.py +++ b/integrations/watsonx/tests/test_generator.py @@ -185,7 +185,7 @@ def test_run_with_prompt_only(self, mock_watsonx): assert "usage" in result["meta"][0] mock_watsonx["model_instance"].chat.assert_called_once_with( - messages=[{"role": "user", "content": "Test prompt"}], params={} + messages=[{"role": "user", "content": "Test prompt"}], params={}, tools=None ) def test_run_with_system_prompt(self, mock_watsonx): @@ -203,7 +203,7 @@ def test_run_with_system_prompt(self, mock_watsonx): {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Test prompt"}, ] - mock_watsonx["model_instance"].chat.assert_called_once_with(messages=expected_messages, params={}) + mock_watsonx["model_instance"].chat.assert_called_once_with(messages=expected_messages, params={}, tools=None) def test_run_with_generation_kwargs(self, mock_watsonx): generator = WatsonxGenerator( @@ -218,6 +218,7 @@ def test_run_with_generation_kwargs(self, mock_watsonx): mock_watsonx["model_instance"].chat.assert_called_once_with( messages=[{"role": "user", "content": "Test prompt"}], params={"max_tokens": 100, "temperature": 0.7, "top_p": 0.9}, + tools=None, ) def test_run_with_streaming(self, mock_watsonx): @@ -296,7 +297,7 @@ async def test_run_async_with_prompt_only(self, mock_watsonx): assert result["meta"][0]["finish_reason"] == "completed" mock_watsonx["model_instance"].achat.assert_called_once_with( - messages=[{"role": "user", "content": "Test prompt"}], params={} + messages=[{"role": "user", "content": "Test prompt"}], params={}, tools=None ) @pytest.mark.asyncio @@ -315,7 +316,7 @@ async def test_run_async_with_system_prompt(self, mock_watsonx): {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Test prompt"}, ] - mock_watsonx["model_instance"].achat.assert_called_once_with(messages=expected_messages, params={}) + mock_watsonx["model_instance"].achat.assert_called_once_with(messages=expected_messages, params={}, tools=None) @pytest.mark.asyncio async def test_run_async_streaming(self, mock_watsonx): From b6144e7a5be29f69b0040ea79d908bcaf57ea5b1 Mon Sep 17 00:00:00 2001 From: Max Swain <89113255+maxdswain@users.noreply.github.com> Date: Wed, 11 Mar 2026 12:07:28 +0000 Subject: [PATCH 05/10] Increase minimum version required --- integrations/watsonx/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/watsonx/pyproject.toml b/integrations/watsonx/pyproject.toml index a66eae071a..a0771eeb8f 100644 --- a/integrations/watsonx/pyproject.toml +++ b/integrations/watsonx/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", ] -dependencies = ["haystack-ai>=2.17.1", "ibm-watsonx-ai>=1.3.26", "pandas>=2.2.3"] +dependencies = ["haystack-ai>=2.24.1", "ibm-watsonx-ai>=1.3.26", "pandas>=2.2.3"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/watsonx#readme" From 3d5d11c3e50c53ce307f30dfd05d1b2176d5665b Mon Sep 17 00:00:00 2001 From: Max Swain <89113255+maxdswain@users.noreply.github.com> Date: Wed, 11 Mar 2026 14:39:10 +0000 Subject: [PATCH 06/10] test: Add tests for 
tool calling and improve robustness of tool calling --- .../generators/watsonx/chat/chat_generator.py | 16 +- .../watsonx/tests/test_chat_generator.py | 294 +++++++++++++++++- integrations/watsonx/tests/test_generator.py | 1 + 3 files changed, 301 insertions(+), 10 deletions(-) diff --git a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py index 0b4908f806..bf4dee90e7 100644 --- a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py +++ b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import json from datetime import datetime, timezone from typing import Any, Literal, get_args @@ -22,7 +23,13 @@ ToolCallDelta, select_streaming_callback, ) -from haystack.tools import ToolsType, _check_duplicate_tool_names, flatten_tools_or_toolsets +from haystack.tools import ( + ToolsType, + _check_duplicate_tool_names, + deserialize_tools_or_toolset_inplace, + flatten_tools_or_toolsets, + serialize_tools_or_toolset, +) from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable from ibm_watsonx_ai import Credentials from ibm_watsonx_ai.foundation_models import ModelInference @@ -190,6 +197,7 @@ def to_dict(self) -> dict[str, Any]: The serialized component as a dictionary. """ callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = serialize_tools_or_toolset(self.tools) if self.tools else None return default_to_dict( self, model=self.model, @@ -201,6 +209,7 @@ def to_dict(self) -> dict[str, Any]: max_retries=self.max_retries, verify=self.verify, streaming_callback=callback_name, + tools=serialized_tools, ) @classmethod @@ -214,6 +223,7 @@ def from_dict(cls, data: dict[str, Any]) -> "WatsonxChatGenerator": The deserialized component instance. 
""" deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "project_id"]) + deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback = init_params.get("streaming_callback") if serialized_callback: @@ -386,7 +396,7 @@ def _convert_chunk_to_streaming_chunk(self, chunk: dict[str, Any], component_inf tool_calls_deltas = [ ToolCallDelta( index=tool_call["index"], - id=tool_call["id"], + id=tool_call.get("id"), tool_name=tool_call.get("function", {}).get("name"), arguments=tool_call.get("function", {}).get("arguments"), ) @@ -496,7 +506,7 @@ def _process_response(self, response: dict[str, Any]) -> dict[str, list[ChatMess ToolCall( id=tool_call["id"], tool_name=tool_call["function"]["name"], - arguments=tool_call["function"]["arguments"], + arguments=json.loads(tool_call["function"]["arguments"]), ) for tool_call in tool_calls ] diff --git a/integrations/watsonx/tests/test_chat_generator.py b/integrations/watsonx/tests/test_chat_generator.py index e320654048..f7a8f9464c 100644 --- a/integrations/watsonx/tests/test_chat_generator.py +++ b/integrations/watsonx/tests/test_chat_generator.py @@ -2,12 +2,14 @@ # # SPDX-License-Identifier: Apache-2.0 import os +from collections.abc import Generator from unittest.mock import AsyncMock, MagicMock, patch import pytest from haystack import logging from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage, ImageContent, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, ImageContent, StreamingChunk +from haystack.tools import Tool, Toolset from haystack.utils import Secret from haystack_integrations.components.generators.watsonx.chat.chat_generator import WatsonxChatGenerator @@ -15,9 +17,30 @@ logger = logging.getLogger(__name__) +def weather(city: str): + """Get weather information for a city.""" + return f"Weather in {city}: 22°C, sunny" + + +def population(city: str) -> str: + return f"The population of {city} is 2.2 million" + + +@pytest.fixture +def tools(): + return [ + Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + function=weather, + ) + ] + + class TestWatsonxChatGenerator: @pytest.fixture - def mock_watsonx(self, monkeypatch): + def mock_watsonx(self, monkeypatch) -> Generator[dict[str, AsyncMock | MagicMock], None]: """Fixture for setting up common mocks""" monkeypatch.setenv("WATSONX_API_KEY", "fake-api-key") monkeypatch.setenv("WATSONX_PROJECT_ID", "fake-project-id") @@ -108,14 +131,18 @@ def test_init_default(self, mock_watsonx): assert isinstance(generator.project_id, Secret) assert generator.project_id.resolve_value() == "fake-project-id" assert generator.api_base_url == "https://us-south.ml.cloud.ibm.com" + assert generator.tools is None + + def test_init_with_all_params(self, mock_watsonx: dict[str, AsyncMock | MagicMock]) -> None: + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=weather) - def test_init_with_all_params(self, mock_watsonx): generator = WatsonxChatGenerator( api_key=Secret.from_token("test-api-key"), project_id=Secret.from_token("test-project"), api_base_url="https://custom-url.com", generation_kwargs={"max_tokens": 100, "temperature": 0.7, "top_p": 0.9}, verify=False, + tools=[tool], ) _, kwargs = mock_watsonx["model"].call_args @@ -125,6 +152,12 @@ 
def test_init_with_all_params(self, mock_watsonx): assert isinstance(generator.project_id, Secret) assert generator.project_id.resolve_value() == "test-project" + assert generator.tools == [tool] + + def test_init_with_toolset(self, mock_watsonx: dict[str, AsyncMock | MagicMock], tools: list[Tool]) -> None: + toolset = Toolset(tools) + generator = WatsonxChatGenerator(project_id=Secret.from_token("fake-project-id"), tools=toolset) + assert generator.tools == toolset def test_init_fails_without_project(self, mock_watsonx): os.environ.pop("WATSONX_PROJECT_ID", None) @@ -132,10 +165,9 @@ def test_init_fails_without_project(self, mock_watsonx): with pytest.raises(ValueError, match="None of the following authentication environment variables are set"): WatsonxChatGenerator(api_key=Secret.from_token("test-api-key")) - def test_to_dict(self, mock_watsonx): + def test_to_dict(self, mock_watsonx: dict[str, AsyncMock | MagicMock]) -> None: generator = WatsonxChatGenerator( - project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), - generation_kwargs={"max_tokens": 100}, + project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), generation_kwargs={"max_tokens": 100} ) data = generator.to_dict() @@ -152,15 +184,17 @@ def test_to_dict(self, mock_watsonx): "timeout": None, "max_retries": None, "streaming_callback": None, + "tools": None, }, } assert data == expected - def test_to_dict_with_params(self, mock_watsonx): + def test_to_dict_with_params(self, mock_watsonx: dict[str, AsyncMock | MagicMock], tools: list[Tool]) -> None: generator = WatsonxChatGenerator( project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), generation_kwargs={"max_tokens": 100}, streaming_callback=print_streaming_chunk, + tools=tools, ) data = generator.to_dict() @@ -177,6 +211,24 @@ def test_to_dict_with_params(self, mock_watsonx): "timeout": None, "max_retries": None, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "tools": [ + { + "data": { + "description": "useful to determine the weather in a given location", + "function": "tests.test_chat_generator.weather", + "inputs_from_state": None, + "name": "weather", + "outputs_to_state": None, + "outputs_to_string": None, + "parameters": { + "properties": {"city": {"type": "string"}}, + "required": ["city"], + "type": "object", + }, + }, + "type": "haystack.tools.tool.Tool", + }, + ], }, } assert data == expected @@ -214,6 +266,39 @@ def test_from_dict_with_callback(self, mock_watsonx): generator = WatsonxChatGenerator.from_dict(data) assert generator.streaming_callback is print_streaming_chunk + def test_from_dict_with_tools(self, mock_watsonx: dict[str, AsyncMock | MagicMock], tools: list[Tool]) -> None: + data = { + "type": "haystack_integrations.components.generators.watsonx.chat.chat_generator.WatsonxChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"}, + "model": "ibm/granite-4-h-small", + "project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"}, + "tools": [ + { + "data": { + "description": "useful to determine the weather in a given location", + "function": "tests.test_chat_generator.weather", + "inputs_from_state": None, + "name": "weather", + "outputs_to_state": None, + "outputs_to_string": None, + "parameters": { + "properties": {"city": {"type": "string"}}, + "required": ["city"], + "type": "object", + }, + }, + "type": "haystack.tools.tool.Tool", + }, + ], + }, + } + + generator = WatsonxChatGenerator.from_dict(data) + assert 
isinstance(generator.tools, list) + assert len(generator.tools) == len(tools) + assert all(isinstance(tool, Tool) for tool in generator.tools) + def test_run_single_message(self, mock_watsonx): generator = WatsonxChatGenerator( api_key=Secret.from_token("test-api-key"), @@ -536,6 +621,45 @@ def test_live_run(self): assert len(results["replies"][0].text) > 0 assert isinstance(generator.project_id, Secret) + @pytest.mark.skipif( + not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), + reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", + ) + def test_live_run_with_toolset(self, tools: list[Tool]) -> None: + """Test that WatsonxChatGenerator can run with a Toolset.""" + toolset = Toolset(tools) + generator = WatsonxChatGenerator( + project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), + generation_kwargs={"max_tokens": 50, "temperature": 0.7, "top_p": 0.9}, + tools=toolset, + ) + messages = [ChatMessage.from_user("What's the weather like in Paris?")] + results = generator.run(messages=messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + + # Check if tool calls were made + assert message.tool_calls is not None, "Message has no tool calls" + assert len(message.tool_calls) == 1, "Message has multiple tool calls and it should only have one" + tool_call = message.tool_calls[0] + assert message.meta["finish_reason"] == "tool_calls" + + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + + # Test full conversation with tool result + tool_result_message = ChatMessage.from_tool(tool_result="22°C, sunny", origin=tool_call) + follow_up_messages = [*messages, message, tool_result_message] + final_results = generator.run(messages=follow_up_messages) + + assert len(final_results["replies"]) == 1 + final_message = final_results["replies"][0] + assert final_message.text + assert "paris" in final_message.text.lower() or "weather" in final_message.text.lower(), ( + "Response does not contain Paris or weather" + ) + @pytest.mark.skipif( not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", @@ -558,6 +682,136 @@ def callback(chunk: StreamingChunk): assert len(collected_chunks) > 0 assert all(isinstance(chunk, StreamingChunk) for chunk in collected_chunks) + @pytest.mark.skipif( + not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), + reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", + ) + def test_live_run_with_tools_streaming(self, tools: list[Tool]) -> None: + """ + Integration test that the WatsonxChatGenerator component can run with tools and streaming. 
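+        Requires the WATSONX_API_KEY and WATSONX_PROJECT_ID environment variables to be set.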
+ """ + component = WatsonxChatGenerator( + project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), tools=tools, streaming_callback=print_streaming_chunk + ) + results = component.run(messages=[ChatMessage.from_user("What's the weather like in Paris?")]) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_calls: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert tool_message.tool_calls is not None, "Tool message has no tool calls" + assert len(tool_message.tool_calls) == 1, "Tool message has multiple tool calls" + assert tool_message.tool_calls[0].tool_name == "weather" + assert tool_message.tool_calls[0].arguments == {"city": "Paris"} + + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + assert tool_message.meta["finish_reason"] == "tool_calls" + + tool_call = tool_message.tool_calls[0] + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + + @pytest.mark.skipif( + not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), + reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", + ) + def test_live_run_with_mixed_tools(self) -> None: + """ + Integration test that verifies WatsonxChatGenerator works with mixed Tool and Toolset. + This tests that the LLM can correctly invoke tools from both a standalone Tool and a Toolset. + """ + weather_tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get weather for, e.g. Paris, London", + } + }, + "required": ["city"], + }, + function=weather, + ) + + population_tool = Tool( + name="population", + description="useful to determine the population of a given city", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get population for, e.g. 
Paris, Berlin", + } + }, + "required": ["city"], + }, + function=population, + ) + + # Create a toolset with the population tool + population_toolset = Toolset([population_tool]) + + # Mix standalone tool with toolset + mixed_tools = [weather_tool, population_toolset] + + initial_messages = [ + ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?") + ] + component = WatsonxChatGenerator(project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), tools=mixed_tools) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + first_reply = results["replies"][0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert first_reply.tool_calls, "First reply has no tool calls" + + tool_calls = first_reply.tool_calls + assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}" + + # Verify we got calls to both weather and population tools + tool_names = {tc.tool_name for tc in tool_calls} + assert "weather" in tool_names, "Expected 'weather' tool call" + assert "population" in tool_names, "Expected 'population' tool call" + + # Verify tool call details + for tool_call in tool_calls: + assert tool_call.tool_name in ["weather", "population"] + assert "city" in tool_call.arguments + assert tool_call.arguments["city"] in ["Paris", "Berlin"] + assert first_reply.meta["finish_reason"] == "tool_calls" + + # Mock the response we'd get from ToolInvoker + tool_result_messages = [] + for tool_call in tool_calls: + if tool_call.tool_name == "weather": + result = "The weather in Paris is sunny and 32°C" + else: # population + result = "The population of Berlin is 2.2 million" + tool_result_messages.append(ChatMessage.from_tool(tool_result=result, origin=tool_call)) + + new_messages = [*initial_messages, first_reply, *tool_result_messages] + results = component.run(messages=new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + assert "berlin" in final_message.text.lower() + @pytest.mark.asyncio @pytest.mark.skipif( not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), @@ -575,6 +829,32 @@ async def test_live_run_async(self): assert isinstance(results["replies"][0], ChatMessage) assert len(results["replies"][0].text) > 0 + @pytest.mark.asyncio + @pytest.mark.skipif( + not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), + reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", + ) + async def test_live_run_async_with_tools(self, tools: list[Tool]) -> None: + """Test async version with tools.""" + component = WatsonxChatGenerator(project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), tools=tools) + results = await component.run_async(messages=[ChatMessage.from_user("What's the weather like in Paris?")]) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_calls: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert tool_message.tool_calls is not None, "Tool message has no tool calls" + assert len(tool_message.tool_calls) == 1, "Tool message has multiple tool 
calls" + assert tool_message.tool_calls[0].tool_name == "weather" + assert tool_message.tool_calls[0].arguments == {"city": "Paris"} + assert tool_message.meta["finish_reason"] == "tool_calls" + @pytest.mark.skipif( not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"), reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set", diff --git a/integrations/watsonx/tests/test_generator.py b/integrations/watsonx/tests/test_generator.py index 260617b63d..8bfcb474d5 100644 --- a/integrations/watsonx/tests/test_generator.py +++ b/integrations/watsonx/tests/test_generator.py @@ -135,6 +135,7 @@ def test_to_dict(self, mock_watsonx): "timeout": None, "max_retries": None, "streaming_callback": None, + "tools": None, }, } assert data == expected From 0c22a7399981165946027ad4f65fb62785a1d599 Mon Sep 17 00:00:00 2001 From: Max Swain <89113255+maxdswain@users.noreply.github.com> Date: Wed, 11 Mar 2026 18:13:06 +0000 Subject: [PATCH 07/10] test: Add test with real chunks and add more chunk metadata --- .../generators/watsonx/chat/chat_generator.py | 4 + .../watsonx/tests/test_chat_generator.py | 114 +++++++++++++++++- 2 files changed, 117 insertions(+), 1 deletion(-) diff --git a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py index bf4dee90e7..213fd8e59b 100644 --- a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py +++ b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py @@ -388,6 +388,10 @@ def _convert_chunk_to_streaming_chunk(self, chunk: dict[str, Any], component_inf choice = chunk["choices"][0] chunk_meta = { "model": self.model, + "model_id": chunk.get("model_id"), + "model_version": chunk.get("model_version"), + "created": chunk.get("created"), + "created_at": chunk.get("created_at"), "received_at": datetime.now(timezone.utc).isoformat(), } diff --git a/integrations/watsonx/tests/test_chat_generator.py b/integrations/watsonx/tests/test_chat_generator.py index f7a8f9464c..6fc20e9b8a 100644 --- a/integrations/watsonx/tests/test_chat_generator.py +++ b/integrations/watsonx/tests/test_chat_generator.py @@ -8,7 +8,7 @@ import pytest from haystack import logging from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage, ChatRole, ImageContent, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, ComponentInfo, ImageContent, StreamingChunk from haystack.tools import Tool, Toolset from haystack.utils import Secret @@ -522,6 +522,118 @@ def test_prepare_api_call_image_in_non_user_message(self, mock_watsonx): with pytest.raises(ValueError, match="Image content is only supported for user messages"): generator._prepare_api_call(messages=[message]) + def test_convert_chunk_to_streaming_chunk_real_example( + self, mock_watsonx: dict[str, AsyncMock | MagicMock] + ) -> None: + component = WatsonxChatGenerator( + project_id=Secret.from_token("test-project"), model="meta-llama/llama-3-2-11b-vision-instruct" + ) + component_info = ComponentInfo.from_component(component) + + # Chunk 1: Text only + chunk1 = { + "id": "chatcmpl-21e72dd9-ed65-49cc-9ea2-64d971707cda---2dedc26eab5af753744ed4eaa116a197---e0399d75-cd8c-486e-b907-dc211cb70eac", # noqa: E501 + "object": "chat.completion.chunk", + "model_id": "meta-llama/llama-3-2-11b-vision-instruct", + "model": 
"meta-llama/llama-3-2-11b-vision-instruct", + "choices": [ + { + "index": 0, + "finish_reason": None, + "delta": {"content": "I'll get the weather information for Paris and Berlin"}, + } + ], + "created": 1773250972, + "model_version": "3.2.0", + "created_at": "2026-03-11T17:42:52.921Z", + } + + streaming_chunk1 = component._convert_chunk_to_streaming_chunk(chunk=chunk1, component_info=component_info) + assert streaming_chunk1.content == "I'll get the weather information for Paris and Berlin" + assert streaming_chunk1.tool_calls is None + assert streaming_chunk1.finish_reason is None + assert streaming_chunk1.index == 0 + assert "created" in streaming_chunk1.meta + assert "created_at" in streaming_chunk1.meta + assert "received_at" in streaming_chunk1.meta + assert streaming_chunk1.meta["model"] == "meta-llama/llama-3-2-11b-vision-instruct" + assert streaming_chunk1.meta["model_id"] == "meta-llama/llama-3-2-11b-vision-instruct" + assert streaming_chunk1.meta["model_version"] == "3.2.0" + assert streaming_chunk1.component_info == component_info + + # Chunk 2: Text only + chunk2 = { + "id": "chatcmpl-21e72dd9-ed65-49cc-9ea2-64d971707cda---2dedc26eab5af753744ed4eaa116a197---e0399d75-cd8c-486e-b907-dc211cb70eac", # noqa: E501 + "object": "chat.completion.chunk", + "model_id": "meta-llama/llama-3-2-11b-vision-instruct", + "model": "meta-llama/llama-3-2-11b-vision-instruct", + "choices": [ + {"index": 0, "finish_reason": None, "delta": {"content": " and present it in a structured format."}} + ], + "created": 1773250972, + "model_version": "3.2.0", + "created_at": "2026-03-11T17:42:52.929Z", + } + + streaming_chunk2 = component._convert_chunk_to_streaming_chunk(chunk=chunk2, component_info=component_info) + assert streaming_chunk2.content == " and present it in a structured format." 
+        assert streaming_chunk2.tool_calls is None
+        assert streaming_chunk2.finish_reason is None
+        assert streaming_chunk2.index == 0
+        assert "created" in streaming_chunk2.meta
+        assert "created_at" in streaming_chunk2.meta
+        assert "received_at" in streaming_chunk2.meta
+        assert streaming_chunk2.meta["model"] == "meta-llama/llama-3-2-11b-vision-instruct"
+        assert streaming_chunk2.meta["model_id"] == "meta-llama/llama-3-2-11b-vision-instruct"
+        assert streaming_chunk2.meta["model_version"] == "3.2.0"
+        assert streaming_chunk2.component_info == component_info
+
+        # Chunk 3: first tool-call delta of a captured stream that contained six function calls
+        # (two cities x three tools); it names the weather function, with arguments streamed in later chunks
+        chunk3 = {
+            "id": "chatcmpl-6b615ca6-4aa7-4f79-832f-bedce4641c2b---87fdc1a1cd2032ff0c6776ecfc20b6a5---34576777-949d-4df1-b95f-56d14b848eca",  # noqa: E501
+            "object": "chat.completion.chunk",
+            "model_id": "meta-llama/llama-3-2-11b-vision-instruct",
+            "model": "meta-llama/llama-3-2-11b-vision-instruct",
+            "choices": [
+                {
+                    "index": 0,
+                    "finish_reason": None,
+                    "delta": {
+                        "tool_calls": [
+                            {
+                                "index": 0,
+                                "id": "chatcmpl-tool-9646185282a54afc86c3572513b2dafa",
+                                "type": "function",
+                                "function": {"name": "weather", "arguments": ""},
+                            }
+                        ]
+                    },
+                }
+            ],
+            "created": 1773252289,
+            "model_version": "3.2.0",
+            "created_at": "2026-03-11T18:04:49.696Z",
+        }
+
+        streaming_chunk3 = component._convert_chunk_to_streaming_chunk(chunk=chunk3, component_info=component_info)
+        assert streaming_chunk3.content == ""
+        assert streaming_chunk3.tool_calls is not None
+        assert len(streaming_chunk3.tool_calls) == 1
+        assert streaming_chunk3.finish_reason is None
+        assert streaming_chunk3.index == 0
+        assert "created" in streaming_chunk3.meta
+        assert "created_at" in streaming_chunk3.meta
+        assert "received_at" in streaming_chunk3.meta
+        assert streaming_chunk3.meta["model"] == "meta-llama/llama-3-2-11b-vision-instruct"
+        assert streaming_chunk3.meta["model_id"] == "meta-llama/llama-3-2-11b-vision-instruct"
+        assert streaming_chunk3.meta["model_version"] == "3.2.0"
+        assert streaming_chunk3.component_info == component_info
+
+        assert streaming_chunk3.tool_calls[0].tool_name == "weather"
+        assert streaming_chunk3.tool_calls[0].arguments == ""
+        assert streaming_chunk3.tool_calls[0].id == "chatcmpl-tool-9646185282a54afc86c3572513b2dafa"
+        assert streaming_chunk3.tool_calls[0].index == 0
+
     def test_multimodal_message_processing(self, mock_watsonx):
         """Test multimodal message processing with mocked model."""
         base64_image = (

From 404aca67e240794e8ff5a81123565e3d59e8442f Mon Sep 17 00:00:00 2001
From: Max Swain <89113255+maxdswain@users.noreply.github.com>
Date: Thu, 12 Mar 2026 09:45:40 +0000
Subject: [PATCH 08/10] Use public method for converting to openai dict format

---
 .../components/generators/watsonx/chat/chat_generator.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py
index 213fd8e59b..c4bda36c19 100644
--- a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py
+++ b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py
@@ -331,9 +331,7 @@ def _prepare_api_call(
         for msg in messages:
             # Watsonx tool call result messages are of the same format as OpenAI chat completions
             if msg.tool_call_results:
-                watsonx_messages.append(
msg._tool_result_message_to_openai({"role": msg.role.value}, require_tool_call_ids=True) - ) + watsonx_messages.append(msg.to_openai_dict_format(require_tool_call_ids=True)) continue # Check that images are only in user messages From 6c86ca477879a8e4be8a1e99b9bff9e9a63bbd0b Mon Sep 17 00:00:00 2001 From: Max Swain <89113255+maxdswain@users.noreply.github.com> Date: Thu, 12 Mar 2026 12:23:28 +0000 Subject: [PATCH 09/10] Fix tool call parsing errors --- .../generators/watsonx/chat/chat_generator.py | 45 +++++++++++++++++-- .../watsonx/tests/test_chat_generator.py | 6 ++- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py index c4bda36c19..5c13a14b09 100644 --- a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py +++ b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +from dataclasses import replace from datetime import datetime, timezone from typing import Any, Literal, get_args @@ -453,7 +454,21 @@ def _handle_streaming( chunks.append(streaming_chunk) callback(streaming_chunk) - return {"replies": [_convert_streaming_chunks_to_chat_message(chunks)]} + chat_message = _convert_streaming_chunks_to_chat_message(chunks) + message_tool_calls = [ + replace(tool_call, arguments=self._parse_tool_call_json(tool_call.arguments)) + for tool_call in chat_message.tool_calls + ] + return { + "replies": [ + ChatMessage.from_assistant( + text=chat_message.text, + meta=chat_message.meta, + tool_calls=message_tool_calls, + reasoning=chat_message.reasoning, + ) + ] + } def _handle_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Handle synchronous standard response.""" @@ -483,7 +498,21 @@ async def _handle_async_streaming( chunks.append(streaming_chunk) await callback(streaming_chunk) - return {"replies": [_convert_streaming_chunks_to_chat_message(chunks)]} + chat_message = _convert_streaming_chunks_to_chat_message(chunks) + message_tool_calls = [ + replace(tool_call, arguments=self._parse_tool_call_json(tool_call.arguments)) + for tool_call in chat_message.tool_calls + ] + return { + "replies": [ + ChatMessage.from_assistant( + text=chat_message.text, + meta=chat_message.meta, + tool_calls=message_tool_calls, + reasoning=chat_message.reasoning, + ) + ] + } async def _handle_async_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Handle asynchronous standard response.""" @@ -492,6 +521,16 @@ async def _handle_async_standard(self, api_args: dict[str, Any]) -> dict[str, li ) return self._process_response(response) + @staticmethod + def _parse_tool_call_json(tool_call: object) -> dict[str, Any]: + """Parse tool call json from Watsonx tool calls.""" + if isinstance(tool_call, dict): + return tool_call + obj = json.loads(tool_call) + if isinstance(obj, str): + obj = json.loads(obj) + return obj + def _process_response(self, response: dict[str, Any]) -> dict[str, list[ChatMessage]]: """Process standard response into Haystack format.""" if not response.get("choices"): @@ -508,7 +547,7 @@ def _process_response(self, response: dict[str, Any]) -> dict[str, list[ChatMess ToolCall( id=tool_call["id"], tool_name=tool_call["function"]["name"], - 
arguments=json.loads(tool_call["function"]["arguments"]), + arguments=self._parse_tool_call_json(tool_call["function"]["arguments"]), ) for tool_call in tool_calls ] diff --git a/integrations/watsonx/tests/test_chat_generator.py b/integrations/watsonx/tests/test_chat_generator.py index 6fc20e9b8a..b85cf8211a 100644 --- a/integrations/watsonx/tests/test_chat_generator.py +++ b/integrations/watsonx/tests/test_chat_generator.py @@ -880,7 +880,11 @@ def test_live_run_with_mixed_tools(self) -> None: initial_messages = [ ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?") ] - component = WatsonxChatGenerator(project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), tools=mixed_tools) + component = WatsonxChatGenerator( + model="meta-llama/llama-3-2-11b-vision-instruct", + project_id=Secret.from_env_var("WATSONX_PROJECT_ID"), + tools=mixed_tools, + ) results = component.run(messages=initial_messages) assert len(results["replies"]) > 0, "No replies received" From da3cabaa7e426725ebd0fac7c2d2782040356c14 Mon Sep 17 00:00:00 2001 From: Max Swain <89113255+maxdswain@users.noreply.github.com> Date: Thu, 12 Mar 2026 12:24:46 +0000 Subject: [PATCH 10/10] Fix type error --- .../components/generators/watsonx/chat/chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py index 5c13a14b09..e31c36569b 100644 --- a/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py +++ b/integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py @@ -522,7 +522,7 @@ async def _handle_async_standard(self, api_args: dict[str, Any]) -> dict[str, li return self._process_response(response) @staticmethod - def _parse_tool_call_json(tool_call: object) -> dict[str, Any]: + def _parse_tool_call_json(tool_call: str | dict) -> dict[str, Any]: """Parse tool call json from Watsonx tool calls.""" if isinstance(tool_call, dict): return tool_call
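
A minimal end-to-end sketch of the tool-calling flow exercised by
test_live_run_with_toolset above. The weather tool, prompts, and
environment-variable authentication mirror the tests; the exact tool call the
model returns (e.g. the {"city": "Paris"} arguments) is illustrative:

    from haystack.dataclasses import ChatMessage
    from haystack.tools import Tool
    from haystack.utils import Secret

    from haystack_integrations.components.generators.watsonx.chat.chat_generator import WatsonxChatGenerator


    def weather(city: str):
        """Get weather information for a city."""
        return f"Weather in {city}: 22°C, sunny"


    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=weather,
    )

    # WATSONX_API_KEY is read from the environment by default.
    generator = WatsonxChatGenerator(
        project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
        tools=[weather_tool],
    )

    messages = [ChatMessage.from_user("What's the weather like in Paris?")]
    reply = generator.run(messages=messages)["replies"][0]

    if reply.tool_calls:
        # After the parsing fixes above, arguments arrive as a dict, e.g. {"city": "Paris"}.
        tool_call = reply.tool_calls[0]
        tool_message = ChatMessage.from_tool(tool_result=weather(**tool_call.arguments), origin=tool_call)
        final = generator.run(messages=[*messages, reply, tool_message])["replies"][0]
        print(final.text)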