diff --git a/integrations/dspy/src/haystack_integrations/components/generators/dspy/chat/chat_generator.py b/integrations/dspy/src/haystack_integrations/components/generators/dspy/chat/chat_generator.py index 26464a42b8..b09c9ac6ea 100644 --- a/integrations/dspy/src/haystack_integrations/components/generators/dspy/chat/chat_generator.py +++ b/integrations/dspy/src/haystack_integrations/components/generators/dspy/chat/chat_generator.py @@ -228,7 +228,7 @@ def from_dict(cls, data: dict[str, Any]) -> "DSPySignatureChatGenerator": return default_from_dict(cls, data) - @component.output_types(replies=list[ChatMessage]) + @component.output_types(replies=list[ChatMessage], metadata=list[dict[str, Any]]) def run( self, messages: list[ChatMessage], @@ -246,22 +246,36 @@ def run( if not messages: msg = "The 'messages' parameter cannot be empty." raise ValueError(msg) - + # 1. Extract prompt and handle model-specific formatting prompt = self._extract_last_user_message(messages) - dspy_inputs = self._build_dspy_inputs(prompt, **kwargs) - prediction = self._module(**dspy_inputs, config=generation_kwargs or {}) + if "mistral" in self.model.lower(): + prompt = f"[INST] {prompt} [/INST]" + # This fixes the issue for Mistral users + # 2. Prepare inputs and merge generation config + dspy_inputs = self._build_dspy_inputs(prompt, **kwargs) + # Merge component-level kwargs with runtime-level kwargs + config = {**self.generation_kwargs, **(generation_kwargs or {})} + # 3. Execute the DSPy module + prediction = self._module(**dspy_inputs, config=config) + # 4. Extract result with safety check if not hasattr(prediction, self.output_field): available = list(prediction.keys()) msg = f"Output field '{self.output_field}' not found in prediction. Available fields: {available}" raise ValueError(msg) output_text = getattr(prediction, self.output_field) + # 5. Build rich metadata (Important for Haystack) + metadata = { + "model": self.model, + "module_type": self.module_type, + "signature": str(self.signature), + # DSPy predictions often store reasoning in 'rationale' for CoT + "rationale": getattr(prediction, "rationale", None), + } + return {"replies": [ChatMessage.from_assistant(text=output_text)], "metadata": [metadata]} - replies = [ChatMessage.from_assistant(text=output_text)] - return {"replies": replies} - - @component.output_types(replies=list[ChatMessage]) + @component.output_types(replies=list[ChatMessage], metadata=list[dict[str, Any]]) async def run_async( self, messages: list[ChatMessage], @@ -284,14 +298,17 @@ async def run_async( prompt = self._extract_last_user_message(messages) dspy_inputs = self._build_dspy_inputs(prompt, **kwargs) - - prediction = await self._module.acall(**dspy_inputs, config=generation_kwargs or {}) + config = {**self.generation_kwargs, **(generation_kwargs or {})} + prediction = await self._module.acall(**dspy_inputs, config=config) if not hasattr(prediction, self.output_field): available = list(prediction.keys()) msg = f"Output field '{self.output_field}' not found in prediction. Available fields: {available}" raise ValueError(msg) output_text = getattr(prediction, self.output_field) + metadata = { + "model": self.model, + "rationale": getattr(prediction, "rationale", None), + } - replies = [ChatMessage.from_assistant(text=output_text)] - return {"replies": replies} + return {"replies": [ChatMessage.from_assistant(text=output_text)], "metadata": [metadata]} diff --git a/integrations/dspy/tests/test_chat_generator.py b/integrations/dspy/tests/test_chat_generator.py index 477b9ff9c9..7bbc12a6f7 100644 --- a/integrations/dspy/tests/test_chat_generator.py +++ b/integrations/dspy/tests/test_chat_generator.py @@ -345,10 +345,11 @@ def test_run_with_generation_kwargs(self, chat_messages, mock_dspy_module): response = component.run(chat_messages, generation_kwargs={"temperature": 0.9}) _, kwargs = mock_dspy_module.call_args - assert kwargs["config"] == {"temperature": 0.9} + assert kwargs["config"] == {"max_tokens": 10, "temperature": 0.9} assert isinstance(response, dict) assert "replies" in response + assert "metadata" in response assert len(response["replies"]) == 1 assert all(isinstance(reply, ChatMessage) for reply in response["replies"]) @@ -439,6 +440,50 @@ def test_run_with_wrong_model(self, mock_dspy_module): with pytest.raises(Exception, match="Invalid model name"): generator.run(messages=[ChatMessage.from_user("Whatever")]) + def test_run_mistral_prompt_formatting(self, mock_dspy_module): + component = DSPySignatureChatGenerator(signature="question -> answer", model="mistral-small") + messages = [ChatMessage.from_user("Test mistral prompt")] + component.run(messages=messages) + call_kwargs = mock_dspy_module.call_args.kwargs + assert "[INST] Test mistral prompt [/INST]" in call_kwargs.get("question") + + def test_run_without_input_mapping_but_with_kwargs(self, mock_dspy_module): + component = DSPySignatureChatGenerator(signature="context, question -> answer") + messages = [ChatMessage.from_user("The prompt text")] + component.run(messages=messages, question="The question kwarg") + call_kwargs = mock_dspy_module.call_args.kwargs + assert call_kwargs.get("context") == "The prompt text" + assert call_kwargs.get("question") == "The question kwarg" + + def test_run_without_input_mapping_and_missing_kwargs(self, mock_dspy_module): + component = DSPySignatureChatGenerator(signature="context, question -> answer") + messages = [ChatMessage.from_user("The context text")] + component.run(messages=messages) + call_kwargs = mock_dspy_module.call_args.kwargs + assert call_kwargs.get("context") == "The context text" + assert "question" not in call_kwargs + + def test_run_with_signature_class_without_mapping(self, mock_dspy_module, sample_qa_signature): + component = DSPySignatureChatGenerator( + signature=sample_qa_signature, + ) + messages = [ChatMessage.from_user("The question")] + component.run(messages=messages) + call_kwargs = mock_dspy_module.call_args.kwargs + assert call_kwargs.get("question") == "The question" + + def test_from_dict_without_signature_dict_in_init_params(self, mock_dspy_module): + """Test deserialization when signature is a plain string instead of a dict.""" + data = { + "type": "haystack_integrations.components.generators.dspy.chat.chat_generator.DSPySignatureChatGenerator", + "init_parameters": { + "signature": "question -> answer", + "model": "openai/gpt-4o", + }, + } + component = DSPySignatureChatGenerator.from_dict(data) + assert component.signature == "question -> answer" + @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", diff --git a/integrations/dspy/tests/test_chat_generator_async.py b/integrations/dspy/tests/test_chat_generator_async.py index e83a8b1255..382dd05673 100644 --- a/integrations/dspy/tests/test_chat_generator_async.py +++ b/integrations/dspy/tests/test_chat_generator_async.py @@ -74,6 +74,7 @@ async def test_run_async_always_passes_config(self, chat_messages, mock_dspy_mod async def test_run_async_with_params(self, chat_messages, mock_dspy_module): component = DSPySignatureChatGenerator( signature="question -> answer", + generation_kwargs={"max_tokens": 10, "temperature": 0.5}, ) response = await component.run_async( messages=chat_messages, @@ -81,10 +82,11 @@ async def test_run_async_with_params(self, chat_messages, mock_dspy_module): ) _, kwargs = mock_dspy_module.acall.call_args - assert kwargs["config"] == {"temperature": 0.9} + assert kwargs["config"] == {"max_tokens": 10, "temperature": 0.9} assert isinstance(response, dict) assert "replies" in response + assert "metadata" in response assert len(response["replies"]) == 1 assert all(isinstance(reply, ChatMessage) for reply in response["replies"]) @@ -95,3 +97,16 @@ async def test_run_async_with_empty_messages(self, mock_dspy_module): ) with pytest.raises(ValueError, match="messages"): await component.run_async(messages=[]) + + @pytest.mark.asyncio + async def test_run_async_with_wrong_output_field(self, mock_dspy_module): + prediction = MagicMock(spec=["answer", "keys"]) + prediction.keys.return_value = ["answer"] + mock_dspy_module.acall.return_value = prediction + component = DSPySignatureChatGenerator( + signature="question -> answer", + output_field="nonexistent", + ) + messages = [ChatMessage.from_user("Hello")] + with pytest.raises(ValueError, match="Output field 'nonexistent' not found"): + await component.run_async(messages=messages) diff --git a/integrations/dspy/tests/test_serialization.py b/integrations/dspy/tests/test_serialization.py new file mode 100644 index 0000000000..7783e0e124 --- /dev/null +++ b/integrations/dspy/tests/test_serialization.py @@ -0,0 +1,20 @@ +from haystack_integrations.components.generators.dspy.chat.chat_generator import DSPySignatureChatGenerator + + +class TestSerialization: + def test_to_dict(self): + component = DSPySignatureChatGenerator(signature="question -> answer", model="gpt-4o-mini") + data = component.to_dict() + base_path = "haystack_integrations.components.generators.dspy.chat.chat_generator" + expected_path = f"{base_path}.DSPySignatureChatGenerator" + assert data["type"] == expected_path + assert data["init_parameters"]["signature"] == {"type": "str", "value": "question -> answer"} + + def test_from_dict(self): + data = { + "type": "haystack_integrations.components.generators.dspy.chat.chat_generator.DSPySignatureChatGenerator", + "init_parameters": {"signature": {"type": "str", "value": "question -> answer"}, "model": "gpt-4o-mini"}, + } + component = DSPySignatureChatGenerator.from_dict(data) + assert component.model == "gpt-4o-mini" + assert component.signature == "question -> answer"