Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def from_dict(cls, data: dict[str, Any]) -> "DSPySignatureChatGenerator":

return default_from_dict(cls, data)

@component.output_types(replies=list[ChatMessage])
@component.output_types(replies=list[ChatMessage], metadata=list[dict[str, Any]])
def run(
self,
messages: list[ChatMessage],
Expand All @@ -246,22 +246,36 @@ def run(
if not messages:
msg = "The 'messages' parameter cannot be empty."
raise ValueError(msg)

# 1. Extract prompt and handle model-specific formatting
prompt = self._extract_last_user_message(messages)
dspy_inputs = self._build_dspy_inputs(prompt, **kwargs)

prediction = self._module(**dspy_inputs, config=generation_kwargs or {})
if "mistral" in self.model.lower():
prompt = f"<s>[INST] {prompt} [/INST]"
# Model names containing "mistral" get the prompt wrapped in <s>[INST] ... [/INST] instruction tags.
# 2. Prepare inputs and merge generation config
dspy_inputs = self._build_dspy_inputs(prompt, **kwargs)

# Merge component-level kwargs with runtime-level kwargs
config = {**self.generation_kwargs, **(generation_kwargs or {})}
# 3. Execute the DSPy module
prediction = self._module(**dspy_inputs, config=config)
# 4. Extract result with safety check
if not hasattr(prediction, self.output_field):
available = list(prediction.keys())
msg = f"Output field '{self.output_field}' not found in prediction. Available fields: {available}"
raise ValueError(msg)
output_text = getattr(prediction, self.output_field)
# 5. Build rich metadata (Important for Haystack)
metadata = {
"model": self.model,
"module_type": self.module_type,
"signature": str(self.signature),
# DSPy predictions often store reasoning in 'rationale' for CoT
"rationale": getattr(prediction, "rationale", None),
}
return {"replies": [ChatMessage.from_assistant(text=output_text)], "metadata": [metadata]}

replies = [ChatMessage.from_assistant(text=output_text)]
return {"replies": replies}

@component.output_types(replies=list[ChatMessage])
@component.output_types(replies=list[ChatMessage], metadata=list[dict[str, Any]])
async def run_async(
self,
messages: list[ChatMessage],
Expand All @@ -284,14 +298,17 @@ async def run_async(

prompt = self._extract_last_user_message(messages)
dspy_inputs = self._build_dspy_inputs(prompt, **kwargs)

prediction = await self._module.acall(**dspy_inputs, config=generation_kwargs or {})
config = {**self.generation_kwargs, **(generation_kwargs or {})}
prediction = await self._module.acall(**dspy_inputs, config=config)

if not hasattr(prediction, self.output_field):
available = list(prediction.keys())
msg = f"Output field '{self.output_field}' not found in prediction. Available fields: {available}"
raise ValueError(msg)
output_text = getattr(prediction, self.output_field)
metadata = {
"model": self.model,
"rationale": getattr(prediction, "rationale", None),
}

replies = [ChatMessage.from_assistant(text=output_text)]
return {"replies": replies}
return {"replies": [ChatMessage.from_assistant(text=output_text)], "metadata": [metadata]}
47 changes: 46 additions & 1 deletion integrations/dspy/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,11 @@ def test_run_with_generation_kwargs(self, chat_messages, mock_dspy_module):
response = component.run(chat_messages, generation_kwargs={"temperature": 0.9})

_, kwargs = mock_dspy_module.call_args
assert kwargs["config"] == {"temperature": 0.9}
assert kwargs["config"] == {"max_tokens": 10, "temperature": 0.9}

assert isinstance(response, dict)
assert "replies" in response
assert "metadata" in response
assert len(response["replies"]) == 1
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

Expand Down Expand Up @@ -439,6 +440,50 @@ def test_run_with_wrong_model(self, mock_dspy_module):
with pytest.raises(Exception, match="Invalid model name"):
generator.run(messages=[ChatMessage.from_user("Whatever")])

def test_run_mistral_prompt_formatting(self, mock_dspy_module):
    """A model name containing "mistral" makes ``run`` pass the prompt wrapped in [INST] tags."""
    generator = DSPySignatureChatGenerator(signature="question -> answer", model="mistral-small")
    user_messages = [ChatMessage.from_user("Test mistral prompt")]

    generator.run(messages=user_messages)

    # The wrapped prompt must reach the DSPy module under the signature's input field.
    forwarded = mock_dspy_module.call_args.kwargs
    expected_fragment = "<s>[INST] Test mistral prompt [/INST]"
    assert expected_fragment in forwarded.get("question")

def test_run_without_input_mapping_but_with_kwargs(self, mock_dspy_module):
    """Without an input mapping, the user message lands in ``context`` and extra kwargs pass through."""
    generator = DSPySignatureChatGenerator(signature="context, question -> answer")
    user_messages = [ChatMessage.from_user("The prompt text")]

    generator.run(messages=user_messages, question="The question kwarg")

    forwarded = mock_dspy_module.call_args.kwargs
    assert forwarded.get("context") == "The prompt text"
    assert forwarded.get("question") == "The question kwarg"

def test_run_without_input_mapping_and_missing_kwargs(self, mock_dspy_module):
    """Signature inputs that are not supplied at run time are left out of the module call."""
    generator = DSPySignatureChatGenerator(signature="context, question -> answer")
    user_messages = [ChatMessage.from_user("The context text")]

    generator.run(messages=user_messages)

    forwarded = mock_dspy_module.call_args.kwargs
    assert forwarded.get("context") == "The context text"
    # No ``question`` kwarg was given, so none may be forwarded.
    assert "question" not in forwarded

def test_run_with_signature_class_without_mapping(self, mock_dspy_module, sample_qa_signature):
    """With the ``sample_qa_signature`` fixture and no mapping, the user message is sent as ``question``."""
    generator = DSPySignatureChatGenerator(signature=sample_qa_signature)
    user_messages = [ChatMessage.from_user("The question")]

    generator.run(messages=user_messages)

    forwarded = mock_dspy_module.call_args.kwargs
    assert forwarded.get("question") == "The question"

def test_from_dict_without_signature_dict_in_init_params(self, mock_dspy_module):
    """``from_dict`` accepts a plain-string signature (not a dict) and keeps it on the component."""
    init_parameters = {
        "signature": "question -> answer",
        "model": "openai/gpt-4o",
    }
    serialized = {
        "type": "haystack_integrations.components.generators.dspy.chat.chat_generator.DSPySignatureChatGenerator",
        "init_parameters": init_parameters,
    }

    restored = DSPySignatureChatGenerator.from_dict(serialized)

    assert restored.signature == "question -> answer"

@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
Expand Down
17 changes: 16 additions & 1 deletion integrations/dspy/tests/test_chat_generator_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,19 @@ async def test_run_async_always_passes_config(self, chat_messages, mock_dspy_mod
async def test_run_async_with_params(self, chat_messages, mock_dspy_module):
component = DSPySignatureChatGenerator(
signature="question -> answer",
generation_kwargs={"max_tokens": 10, "temperature": 0.5},
)
response = await component.run_async(
messages=chat_messages,
generation_kwargs={"temperature": 0.9},
)

_, kwargs = mock_dspy_module.acall.call_args
assert kwargs["config"] == {"temperature": 0.9}
assert kwargs["config"] == {"max_tokens": 10, "temperature": 0.9}

assert isinstance(response, dict)
assert "replies" in response
assert "metadata" in response
assert len(response["replies"]) == 1
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

Expand All @@ -95,3 +97,16 @@ async def test_run_async_with_empty_messages(self, mock_dspy_module):
)
with pytest.raises(ValueError, match="messages"):
await component.run_async(messages=[])

@pytest.mark.asyncio
async def test_run_async_with_wrong_output_field(self, mock_dspy_module):
    """``run_async`` raises ``ValueError`` when the prediction lacks the configured output field."""
    # ``spec`` limits the mock to these attributes, so ``hasattr(..., "nonexistent")`` is False.
    fake_prediction = MagicMock(spec=["answer", "keys"])
    fake_prediction.keys.return_value = ["answer"]
    mock_dspy_module.acall.return_value = fake_prediction

    generator = DSPySignatureChatGenerator(
        signature="question -> answer",
        output_field="nonexistent",
    )
    user_messages = [ChatMessage.from_user("Hello")]

    with pytest.raises(ValueError, match="Output field 'nonexistent' not found"):
        await generator.run_async(messages=user_messages)
20 changes: 20 additions & 0 deletions integrations/dspy/tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from haystack_integrations.components.generators.dspy.chat.chat_generator import DSPySignatureChatGenerator


class TestSerialization:
    """Serialization checks for ``DSPySignatureChatGenerator`` (``to_dict`` / ``from_dict``)."""

    def test_to_dict(self):
        """``to_dict`` reports the component's import path and wraps a string signature in a typed dict."""
        serialized = DSPySignatureChatGenerator(signature="question -> answer", model="gpt-4o-mini").to_dict()

        base_path = "haystack_integrations.components.generators.dspy.chat.chat_generator"
        assert serialized["type"] == f"{base_path}.DSPySignatureChatGenerator"
        assert serialized["init_parameters"]["signature"] == {"type": "str", "value": "question -> answer"}

    def test_from_dict(self):
        """``from_dict`` restores the model name and unwraps the typed signature back to a string."""
        init_parameters = {"signature": {"type": "str", "value": "question -> answer"}, "model": "gpt-4o-mini"}
        serialized = {
            "type": "haystack_integrations.components.generators.dspy.chat.chat_generator.DSPySignatureChatGenerator",
            "init_parameters": init_parameters,
        }

        restored = DSPySignatureChatGenerator.from_dict(serialized)

        assert restored.model == "gpt-4o-mini"
        assert restored.signature == "question -> answer"
Loading