Skip to content

Commit 78a0dd0

Browse files
authored
feat: pass component_info to StreamingChunk in OllamaChatGenerator (#2039)
* feat: pass component_info to StreamingChunk in OllamaChatGenerator * small tests improvements * pin to patch release
1 parent 2f0dcf8 commit 78a0dd0

3 files changed

Lines changed: 55 additions & 28 deletions

File tree

integrations/ollama/pyproject.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers = [
2727
"Programming Language :: Python :: Implementation :: CPython",
2828
"Programming Language :: Python :: Implementation :: PyPy",
2929
]
30-
dependencies = ["haystack-ai>=2.13.1", "ollama>=0.4.0", "pydantic"]
30+
dependencies = ["haystack-ai>=2.15.1", "ollama>=0.4.0", "pydantic"]
3131

3232
[project.urls]
3333
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"
@@ -158,9 +158,7 @@ show_missing = true
158158
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
159159

160160
[tool.pytest.ini_options]
161-
markers = [
162-
"integration: marks tests as slow (deselect with '-m \"not integration\"')",
163-
]
161+
markers = ["integration: integration tests"]
164162
log_cli = true
165163
addopts = ["--import-mode=importlib"]
166164
asyncio_mode = "auto"

integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Any, Callable, Dict, List, Literal, Optional, Union
1+
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union
22

33
from haystack import component, default_from_dict, default_to_dict
4-
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
4+
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk, ToolCall
55
from haystack.tools import (
66
Tool,
77
_check_duplicate_tool_names,
@@ -100,7 +100,7 @@ def _convert_ollama_meta_to_openai_format(input_response_dict: Dict) -> Dict:
100100
return meta
101101

102102

103-
def _convert_ollama_response_to_chatmessage(ollama_response: "ChatResponse") -> ChatMessage:
103+
def _convert_ollama_response_to_chatmessage(ollama_response: ChatResponse) -> ChatMessage:
104104
"""
105105
Convert non-streaming Ollama Chat API response to Haystack ChatMessage with the assistant role.
106106
"""
@@ -271,7 +271,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
271271
return default_from_dict(cls, data)
272272

273273
@staticmethod
274-
def _build_chunk(chunk_response: Any) -> StreamingChunk:
274+
def _build_chunk(chunk_response: ChatResponse, component_info: ComponentInfo) -> StreamingChunk:
275275
"""
276276
Convert one Ollama stream-chunk to Haystack StreamingChunk.
277277
"""
@@ -283,11 +283,11 @@ def _build_chunk(chunk_response: Any) -> StreamingChunk:
283283
if tool_calls := chunk_response_dict["message"].get("tool_calls"):
284284
meta["tool_calls"] = tool_calls
285285

286-
return StreamingChunk(content, meta)
286+
return StreamingChunk(content=content, meta=meta, component_info=component_info)
287287

288288
def _handle_streaming_response(
289289
self,
290-
response_iter: Any,
290+
response_iter: Iterator[ChatResponse],
291291
callback: Optional[Callable[[StreamingChunk], None]],
292292
) -> Dict[str, List[ChatMessage]]:
293293
"""
@@ -296,6 +296,8 @@ def _handle_streaming_response(
296296
or as full JSON dicts.
297297
"""
298298

299+
component_info = ComponentInfo.from_component(self)
300+
299301
chunks: List[StreamingChunk] = []
300302

301303
# Accumulators
@@ -305,7 +307,7 @@ def _handle_streaming_response(
305307

306308
# Stream
307309
for raw in response_iter:
308-
chunk = self._build_chunk(raw)
310+
chunk = self._build_chunk(chunk_response=raw, component_info=component_info)
309311
chunks.append(chunk)
310312

311313
if callback:
@@ -428,8 +430,8 @@ def run(
428430
format=self.response_format,
429431
)
430432

431-
if is_stream:
432-
return self._handle_streaming_response(response, callback)
433+
if isinstance(response, Iterator):
434+
return self._handle_streaming_response(response_iter=response, callback=callback)
433435

434436
# non-stream path
435-
return {"replies": [_convert_ollama_response_to_chatmessage(response)]}
437+
return {"replies": [_convert_ollama_response_to_chatmessage(ollama_response=response)]}

integrations/ollama/tests/test_chat_generator.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from haystack.dataclasses import (
77
ChatMessage,
88
ChatRole,
9+
ComponentInfo,
910
StreamingChunk,
1011
TextContent,
1112
ToolCall,
@@ -287,6 +288,9 @@ def test_to_dict(self):
287288
"type": "string",
288289
},
289290
},
291+
"outputs_to_string": None,
292+
"inputs_from_state": None,
293+
"outputs_to_state": None,
290294
},
291295
},
292296
],
@@ -297,15 +301,6 @@ def test_to_dict(self):
297301
},
298302
}
299303

300-
# add outputs_to_string, inputs_from_state and outputs_to_state tool parameters for compatibility with
301-
# haystack-ai>=2.12.0
302-
if hasattr(tool, "outputs_to_string"):
303-
expected_dict["init_parameters"]["tools"][0]["data"]["outputs_to_string"] = tool.outputs_to_string
304-
if hasattr(tool, "inputs_from_state"):
305-
expected_dict["init_parameters"]["tools"][0]["data"]["inputs_from_state"] = tool.inputs_from_state
306-
if hasattr(tool, "outputs_to_state"):
307-
expected_dict["init_parameters"]["tools"][0]["data"]["outputs_to_state"] = tool.outputs_to_state
308-
309304
assert data == expected_dict
310305

311306
def test_from_dict(self):
@@ -365,6 +360,30 @@ def test_from_dict(self):
365360
"properties": {"name": {"type": "string"}, "age": {"type": "number"}},
366361
}
367362

363+
def test_build_chunk(self):
364+
generator = OllamaChatGenerator()
365+
366+
mock_chunk_response = Mock()
367+
mock_chunk_response.model_dump.return_value = {
368+
"message": {"role": "assistant", "content": "Hello world"},
369+
"model": "llama2",
370+
"created_at": "2023-12-12T14:13:43.416799Z",
371+
"done": False,
372+
}
373+
374+
component_info = ComponentInfo.from_component(generator)
375+
376+
chunk = generator._build_chunk(mock_chunk_response, component_info)
377+
378+
assert isinstance(chunk, StreamingChunk)
379+
assert chunk.content == "Hello world"
380+
assert chunk.component_info == component_info
381+
assert chunk.meta["role"] == "assistant"
382+
assert chunk.meta["model"] == "llama2"
383+
assert chunk.meta["created_at"] == "2023-12-12T14:13:43.416799Z"
384+
assert chunk.meta["done"] is False
385+
assert "tool_calls" not in chunk.meta
386+
368387
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
369388
def test_run(self, mock_client):
370389
generator = OllamaChatGenerator()
@@ -407,11 +426,10 @@ def test_run(self, mock_client):
407426

408427
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
409428
def test_run_streaming(self, mock_client):
410-
streaming_callback_called = False
429+
collected_chunks = []
411430

412-
def streaming_callback(_: StreamingChunk) -> None:
413-
nonlocal streaming_callback_called
414-
streaming_callback_called = True
431+
def streaming_callback(chunk: StreamingChunk) -> None:
432+
collected_chunks.append(chunk)
415433

416434
generator = OllamaChatGenerator(streaming_callback=streaming_callback)
417435

@@ -443,7 +461,16 @@ def streaming_callback(_: StreamingChunk) -> None:
443461

444462
result = generator.run(messages=[ChatMessage.from_user("irrelevant")])
445463

446-
assert streaming_callback_called
464+
assert len(collected_chunks) == 2
465+
assert collected_chunks[0].content == "first chunk "
466+
assert collected_chunks[1].content == "second chunk"
467+
468+
for chunk in collected_chunks:
469+
assert (
470+
chunk.component_info.type
471+
== "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator"
472+
)
473+
assert chunk.component_info.name is None # not in a pipeline
447474

448475
assert "replies" in result
449476
assert len(result["replies"]) == 1

0 commit comments

Comments
 (0)