Skip to content

Commit 9ddd737

Browse files
committed
Addded tests
1 parent 49606b2 commit 9ddd737

7 files changed

Lines changed: 627 additions & 10 deletions

File tree

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2-
#
3-
# SPDX-License-Identifier: Apache-2.0
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2-
#
3-
# SPDX-License-Identifier: Apache-2.0
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2-
#
3-
# SPDX-License-Identifier: Apache-2.0

integrations/dspy/src/haystack_integrations/components/generators/dspy/chat/chat_generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def __init__(
6464
:param input_mapping: Maps DSPy signature input field names to run kwarg names.
6565
:param streaming_callback: Callback for streaming responses.
6666
"""
67-
super().__init__(
67+
DSPyGenerator.__init__(
68+
self,
6869
signature=signature,
6970
model=model,
7071
api_key=api_key,
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
import os
2+
from unittest.mock import MagicMock, patch
3+
4+
import pytest
5+
from haystack.dataclasses import ChatMessage
6+
from haystack.utils.auth import Secret
7+
8+
from haystack_integrations.components.generators.dspy.chat.chat_generator import DSPyChatGenerator
9+
10+
11+
@pytest.fixture
12+
def mock_dspy_module():
13+
"""
14+
Mock DSPy LM, configure, and module classes to avoid real API calls.
15+
"""
16+
with patch("dspy.LM") as mock_lm_class, \
17+
patch("dspy.configure"), \
18+
patch("dspy.ChainOfThought") as mock_cot_class, \
19+
patch("dspy.Predict") as mock_predict_class, \
20+
patch("dspy.ReAct") as mock_react_class:
21+
mock_lm = MagicMock()
22+
mock_lm_class.return_value = mock_lm
23+
24+
mock_module = MagicMock()
25+
mock_module.return_value = MagicMock(answer="Hello world!")
26+
mock_cot_class.return_value = mock_module
27+
mock_predict_class.return_value = mock_module
28+
mock_react_class.return_value = mock_module
29+
30+
yield mock_module
31+
32+
33+
@pytest.fixture
34+
def chat_messages():
35+
return [
36+
ChatMessage.from_system("You are a helpful assistant"),
37+
ChatMessage.from_user("What's the capital of France"),
38+
]
39+
40+
41+
class TestDSPyChatGenerator:
42+
def test_init_default(self, monkeypatch, mock_dspy_module):
43+
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
44+
component = DSPyChatGenerator(signature="question -> answer")
45+
assert component.model == "openai/gpt-5-mini"
46+
assert component.signature == "question -> answer"
47+
assert component.module_type == "ChainOfThought"
48+
assert component.output_field == "answer"
49+
assert component.streaming_callback is None
50+
assert not component.generation_kwargs
51+
assert component.input_mapping is None
52+
53+
def test_init_fail_wo_api_key(self, monkeypatch):
54+
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
55+
with pytest.raises(ValueError, match=r"None of the .* environment variables are set"):
56+
DSPyChatGenerator(signature="question -> answer")
57+
58+
def test_init_with_parameters(self, mock_dspy_module):
59+
component = DSPyChatGenerator(
60+
signature="context, question -> answer",
61+
model="openai/gpt-4o",
62+
api_key=Secret.from_token("test-api-key"),
63+
module_type="Predict",
64+
output_field="response",
65+
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
66+
input_mapping={"context": "context", "question": "question"},
67+
)
68+
assert component.model == "openai/gpt-4o"
69+
assert component.signature == "context, question -> answer"
70+
assert component.module_type == "Predict"
71+
assert component.output_field == "response"
72+
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
73+
assert component.input_mapping == {"context": "context", "question": "question"}
74+
75+
def test_init_with_invalid_module_type(self, mock_dspy_module):
76+
with pytest.raises(ValueError, match="Invalid module_type"):
77+
DSPyChatGenerator(
78+
signature="question -> answer",
79+
api_key=Secret.from_token("test-api-key"),
80+
module_type="InvalidModule",
81+
)
82+
83+
def test_to_dict_default(self, monkeypatch, mock_dspy_module):
84+
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
85+
component = DSPyChatGenerator(
86+
signature="question -> answer",
87+
api_key=Secret.from_env_var("OPENAI_API_KEY"),
88+
)
89+
data = component.to_dict()
90+
assert data == {
91+
"type": "haystack_integrations.components.generators.dspy.chat.chat_generator.DSPyChatGenerator",
92+
"init_parameters": {
93+
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
94+
"signature": "question -> answer",
95+
"model": "openai/gpt-5-mini",
96+
"module_type": "ChainOfThought",
97+
"output_field": "answer",
98+
"generation_kwargs": {},
99+
"input_mapping": None,
100+
},
101+
}
102+
103+
def test_to_dict_with_parameters(self, monkeypatch, mock_dspy_module):
104+
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
105+
component = DSPyChatGenerator(
106+
signature="context, question -> answer",
107+
model="openai/gpt-4o",
108+
api_key=Secret.from_env_var("OPENAI_API_KEY"),
109+
module_type="Predict",
110+
output_field="response",
111+
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
112+
input_mapping={"context": "context", "question": "question"},
113+
)
114+
data = component.to_dict()
115+
assert data == {
116+
"type": "haystack_integrations.components.generators.dspy.chat.chat_generator.DSPyChatGenerator",
117+
"init_parameters": {
118+
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
119+
"signature": "context, question -> answer",
120+
"model": "openai/gpt-4o",
121+
"module_type": "Predict",
122+
"output_field": "response",
123+
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
124+
"input_mapping": {"context": "context", "question": "question"},
125+
},
126+
}
127+
128+
def test_from_dict(self, monkeypatch, mock_dspy_module):
129+
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
130+
data = {
131+
"type": "haystack_integrations.components.generators.dspy.chat.chat_generator.DSPyChatGenerator",
132+
"init_parameters": {
133+
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
134+
"signature": "question -> answer",
135+
"model": "openai/gpt-4o",
136+
"module_type": "Predict",
137+
"output_field": "response",
138+
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
139+
"input_mapping": None,
140+
},
141+
}
142+
component = DSPyChatGenerator.from_dict(data)
143+
assert component.model == "openai/gpt-4o"
144+
assert component.signature == "question -> answer"
145+
assert component.module_type == "Predict"
146+
assert component.output_field == "response"
147+
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
148+
assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
149+
assert component.input_mapping is None
150+
151+
def test_from_dict_fail_wo_env_var(self, monkeypatch):
152+
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
153+
data = {
154+
"type": "haystack_integrations.components.generators.dspy.chat.chat_generator.DSPyChatGenerator",
155+
"init_parameters": {
156+
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
157+
"signature": "question -> answer",
158+
"model": "openai/gpt-4o",
159+
"module_type": "Predict",
160+
"output_field": "response",
161+
"generation_kwargs": {},
162+
"input_mapping": None,
163+
},
164+
}
165+
with pytest.raises(ValueError, match=r"None of the .* environment variables are set"):
166+
DSPyChatGenerator.from_dict(data)
167+
168+
def test_run(self, chat_messages, mock_dspy_module):
169+
component = DSPyChatGenerator(
170+
signature="question -> answer",
171+
api_key=Secret.from_token("test-api-key"),
172+
)
173+
response = component.run(chat_messages)
174+
175+
# Verify the mock was called
176+
mock_dspy_module.assert_called_once()
177+
178+
# Check that the component returns the correct ChatMessage response
179+
assert isinstance(response, dict)
180+
assert "replies" in response
181+
assert isinstance(response["replies"], list)
182+
assert len(response["replies"]) == 1
183+
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
184+
185+
def test_run_with_params(self, chat_messages, mock_dspy_module):
186+
component = DSPyChatGenerator(
187+
signature="question -> answer",
188+
api_key=Secret.from_token("test-api-key"),
189+
generation_kwargs={"max_tokens": 10, "temperature": 0.5},
190+
)
191+
response = component.run(chat_messages, generation_kwargs={"temperature": 0.9})
192+
193+
# Check that the component calls the DSPy module with the correct parameters
194+
_, kwargs = mock_dspy_module.call_args
195+
assert kwargs["config"] == {"temperature": 0.9}
196+
197+
# Check that the component returns the correct response
198+
assert isinstance(response, dict)
199+
assert "replies" in response
200+
assert isinstance(response["replies"], list)
201+
assert len(response["replies"]) == 1
202+
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
203+
204+
def test_run_with_multiple_messages(self, mock_dspy_module):
205+
component = DSPyChatGenerator(
206+
signature="question -> answer",
207+
api_key=Secret.from_token("test-api-key"),
208+
)
209+
messages = [
210+
ChatMessage.from_user("Hello"),
211+
ChatMessage.from_assistant("Hi there!"),
212+
ChatMessage.from_user("What is the capital of Germany?"),
213+
]
214+
response = component.run(messages=messages)
215+
216+
# Verify the last user message was used as input
217+
args, _ = mock_dspy_module.call_args
218+
# The first positional kwarg should be the question from the last user message
219+
call_kwargs = mock_dspy_module.call_args.kwargs
220+
assert call_kwargs.get("question") == "What is the capital of Germany?"
221+
222+
assert "replies" in response
223+
assert len(response["replies"]) == 1
224+
assert isinstance(response["replies"][0], ChatMessage)
225+
226+
def test_run_with_empty_messages(self, mock_dspy_module):
227+
component = DSPyChatGenerator(
228+
signature="question -> answer",
229+
api_key=Secret.from_token("test-api-key"),
230+
)
231+
with pytest.raises(ValueError, match="messages"):
232+
component.run(messages=[])
233+
234+
def test_run_with_wrong_model(self, mock_dspy_module):
235+
mock_dspy_module.side_effect = Exception("Invalid model name")
236+
237+
generator = DSPyChatGenerator(
238+
signature="question -> answer",
239+
api_key=Secret.from_token("test-api-key"),
240+
model="something-obviously-wrong",
241+
)
242+
243+
with pytest.raises(Exception, match="Invalid model name"):
244+
generator.run(messages=[ChatMessage.from_user("Whatever")])
245+
246+
@pytest.mark.skipif(
247+
not os.environ.get("OPENAI_API_KEY", None),
248+
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
249+
)
250+
@pytest.mark.integration
251+
def test_live_run(self):
252+
chat_messages = [ChatMessage.from_user("What's the capital of France")]
253+
component = DSPyChatGenerator(signature="question -> answer")
254+
results = component.run(chat_messages)
255+
assert len(results["replies"]) == 1
256+
message: ChatMessage = results["replies"][0]
257+
assert "Paris" in message.text
258+
259+
metadata = results["meta"][0]
260+
assert metadata["model"] == "openai/gpt-5-mini"
261+
assert metadata["module_type"] == "ChainOfThought"
262+
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from unittest.mock import AsyncMock, MagicMock, patch
2+
3+
import pytest
4+
from haystack.dataclasses import ChatMessage
5+
from haystack.utils.auth import Secret
6+
7+
from haystack_integrations.components.generators.dspy.chat.chat_generator import DSPyChatGenerator
8+
9+
10+
@pytest.fixture
11+
def mock_dspy_module():
12+
"""
13+
Mock DSPy LM, configure, and module classes to avoid real API calls.
14+
"""
15+
with patch("dspy.LM") as mock_lm_class, \
16+
patch("dspy.configure"), \
17+
patch("dspy.ChainOfThought") as mock_cot_class, \
18+
patch("dspy.Predict") as mock_predict_class, \
19+
patch("dspy.ReAct") as mock_react_class:
20+
mock_lm = MagicMock()
21+
mock_lm_class.return_value = mock_lm
22+
23+
mock_module = MagicMock()
24+
mock_module.return_value = MagicMock(answer="Hello world!")
25+
mock_module.acall = AsyncMock(return_value=MagicMock(answer="Hello world!"))
26+
27+
mock_cot_class.return_value = mock_module
28+
mock_predict_class.return_value = mock_module
29+
mock_react_class.return_value = mock_module
30+
31+
yield mock_module
32+
33+
34+
@pytest.fixture
35+
def chat_messages():
36+
return [
37+
ChatMessage.from_system("You are a helpful assistant"),
38+
ChatMessage.from_user("What's the capital of France"),
39+
]
40+
41+
42+
class TestDSPyChatGeneratorAsync:
43+
@pytest.mark.asyncio
44+
async def test_run_async(self, chat_messages, mock_dspy_module):
45+
component = DSPyChatGenerator(
46+
signature="question -> answer",
47+
api_key=Secret.from_token("test-api-key"),
48+
)
49+
response = await component.run_async(messages=chat_messages)
50+
51+
# Verify the async mock was called
52+
mock_dspy_module.acall.assert_called_once()
53+
54+
# Check that the component returns the correct ChatMessage response
55+
assert isinstance(response, dict)
56+
assert "replies" in response
57+
assert isinstance(response["replies"], list)
58+
assert len(response["replies"]) == 1
59+
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
60+
61+
@pytest.mark.asyncio
62+
async def test_run_async_with_params(self, chat_messages, mock_dspy_module):
63+
component = DSPyChatGenerator(
64+
signature="question -> answer",
65+
api_key=Secret.from_token("test-api-key"),
66+
)
67+
response = await component.run_async(
68+
messages=chat_messages,
69+
generation_kwargs={"temperature": 0.9},
70+
)
71+
72+
# Check that acall was called with the correct parameters
73+
_, kwargs = mock_dspy_module.acall.call_args
74+
assert kwargs["config"] == {"temperature": 0.9}
75+
76+
# Check that the component returns the correct response
77+
assert isinstance(response, dict)
78+
assert "replies" in response
79+
assert len(response["replies"]) == 1
80+
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
81+
82+
@pytest.mark.asyncio
83+
async def test_run_async_with_empty_messages(self, mock_dspy_module):
84+
component = DSPyChatGenerator(
85+
signature="question -> answer",
86+
api_key=Secret.from_token("test-api-key"),
87+
)
88+
with pytest.raises(ValueError, match="messages"):
89+
await component.run_async(messages=[])

0 commit comments

Comments
 (0)