diff --git a/integrations/dspy/src/haystack_integrations/components/generators/dspy/chat/chat_generator.py b/integrations/dspy/src/haystack_integrations/components/generators/dspy/chat/chat_generator.py index 26464a42b8..ac191b94da 100644 --- a/integrations/dspy/src/haystack_integrations/components/generators/dspy/chat/chat_generator.py +++ b/integrations/dspy/src/haystack_integrations/components/generators/dspy/chat/chat_generator.py @@ -86,6 +86,7 @@ def __init__( generation_kwargs: dict[str, Any] | None = None, module_kwargs: dict[str, Any] | None = None, input_mapping: dict[str, str] | None = None, + pipeline_inputs: list[str] | None = None, ): """ Initialize the DSPySignatureChatGenerator. @@ -101,9 +102,13 @@ def __init__( For example, use `{"tools": [tool1, tool2]}` when using the `"ReAct"` module type. :param input_mapping: Maps DSPy signature input field names to `run()` kwarg names. For example, if your signature has an input field `"context"` but your pipeline - provides it as `"documents"`, use `{"context": "documents"}`. When not provided, - the first input field receives the last user message text, and remaining fields - are matched by name from `**kwargs`. + provides it as `"documents"`, use `{"context": "documents"}`. When neither + `input_mapping` nor `pipeline_inputs` is provided, the first input field receives + the last user message text, and remaining fields are matched by name from `**kwargs`. + :param pipeline_inputs: Signature input fields exposed as Haystack pipeline input + sockets, so upstream components (e.g. a retriever or a text input) can connect + to them. Each name in this list must be a signature input field and becomes a + real `str` socket on the component. """ if module_type not in VALID_MODULE_TYPES: msg = f"Invalid module_type '{module_type}'. Must be one of {sorted(VALID_MODULE_TYPES)}" @@ -117,6 +122,7 @@ def __init__( self.generation_kwargs = generation_kwargs or {} self.module_kwargs = module_kwargs or {} self.input_mapping = input_mapping + self.pipeline_inputs = pipeline_inputs self._lm = _create_dspy_lm( model=self.model, @@ -128,6 +134,9 @@ def __init__( self._module = module_class(self.signature, **self.module_kwargs) self._module.set_lm(self._lm) + for extra_input in self.pipeline_inputs or []: + component.set_input_type(self, extra_input, str, "") + def _build_dspy_inputs(self, prompt: str, **kwargs: Any) -> dict[str, Any]: """Build the input dict for the DSPy module call.""" if self.input_mapping: diff --git a/integrations/dspy/tests/test_chat_generator.py b/integrations/dspy/tests/test_chat_generator.py index 477b9ff9c9..5958aae2cc 100644 --- a/integrations/dspy/tests/test_chat_generator.py +++ b/integrations/dspy/tests/test_chat_generator.py @@ -3,7 +3,10 @@ import dspy import pytest +from haystack import Document, Pipeline, component +from haystack.components.retrievers.in_memory import InMemoryBM25Retriever from haystack.dataclasses import ChatMessage +from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack_integrations.components.generators.dspy.chat.chat_generator import ( VALID_MODULE_TYPES, @@ -428,6 +431,51 @@ def test_run_with_input_mapping(self, mock_dspy_module): assert call_kwargs.get("context") == "Machine learning is a subset of AI." assert call_kwargs.get("question") == "What is ML?" + def test_rag_pipeline_question_with_dependent_context(self, mock_dspy_module): + """ + Test case where the context passed to DSPy is dynamic and depends on the question asked by the user. + """ + + doc_store = InMemoryDocumentStore() + doc_store.write_documents( + [ + Document(content="Paris is the capital of France."), + Document(content="Tokyo is the capital of Japan."), + Document(content="Bananas are yellow fruits."), + ] + ) + + @component + class DocsToString: + @component.output_types(text=str) + def run(self, documents: list[Document]) -> dict: + return {"text": "\n".join(d.content for d in documents)} + + generator = DSPySignatureChatGenerator( + signature="question, context -> answer", + pipeline_inputs=["context"], + ) + + pipeline = Pipeline() + pipeline.add_component("retriever", InMemoryBM25Retriever(doc_store, top_k=1)) + pipeline.add_component("docs_to_text", DocsToString()) + pipeline.add_component("llm", generator) + pipeline.connect("retriever.documents", "docs_to_text.documents") + pipeline.connect("docs_to_text.text", "llm.context") + + question = "What is the capital of Japan?" # context should be Japan-related + pipeline.run( + { + "retriever": {"query": question}, + "llm": {"messages": [ChatMessage.from_user(question)]}, + } + ) + + call_kwargs = mock_dspy_module.call_args.kwargs + assert call_kwargs["question"] == question + assert "Tokyo is the capital of Japan." in call_kwargs["context"] + assert "Paris" not in call_kwargs["context"] + def test_run_with_wrong_model(self, mock_dspy_module): mock_dspy_module.side_effect = Exception("Invalid model name")