@@ -86,6 +86,7 @@ def __init__(
generation_kwargs: dict[str, Any] | None = None,
module_kwargs: dict[str, Any] | None = None,
input_mapping: dict[str, str] | None = None,
pipeline_inputs: list[str] | None = None,
):
"""
Initialize the DSPySignatureChatGenerator.
@@ -101,9 +102,13 @@ def __init__(
For example, use `{"tools": [tool1, tool2]}` when using the `"ReAct"` module type.
:param input_mapping: Maps DSPy signature input field names to `run()` kwarg names.
For example, if your signature has an input field `"context"` but your pipeline
provides it as `"documents"`, use `{"context": "documents"}`. When neither
`input_mapping` nor `pipeline_inputs` is provided, the first input field receives
the last user message text, and remaining fields are matched by name from `**kwargs`.
:param pipeline_inputs: Signature input fields exposed as Haystack pipeline input
sockets, so upstream components (e.g. a retriever or a text input) can connect
to them. Each name in this list must be a signature input field and becomes a
real `str` socket on the component.
"""
if module_type not in VALID_MODULE_TYPES:
msg = f"Invalid module_type '{module_type}'. Must be one of {sorted(VALID_MODULE_TYPES)}"
@@ -117,6 +122,7 @@
self.generation_kwargs = generation_kwargs or {}
self.module_kwargs = module_kwargs or {}
self.input_mapping = input_mapping
self.pipeline_inputs = pipeline_inputs

self._lm = _create_dspy_lm(
model=self.model,
@@ -128,6 +134,9 @@
self._module = module_class(self.signature, **self.module_kwargs)
self._module.set_lm(self._lm)

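# Expose each name in pipeline_inputs as an optional str input socket with a "" default.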
for extra_input in self.pipeline_inputs or []:
component.set_input_type(self, extra_input, str, "")

def _build_dspy_inputs(self, prompt: str, **kwargs: Any) -> dict[str, Any]:
"""Build the input dict for the DSPy module call."""
if self.input_mapping:
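For reference, a minimal usage sketch of the two options, inferred from the docstring above (the import path and `run()` kwargs follow the test file below; this is illustrative, not part of the diff):

```python
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.dspy.chat.chat_generator import (
    DSPySignatureChatGenerator,
)

messages = [ChatMessage.from_user("What is ML?")]

# input_mapping renames a run() kwarg to a signature input field:
# here run(documents=...) feeds the signature's "context" field.
mapped = DSPySignatureChatGenerator(
    signature="question, context -> answer",
    input_mapping={"context": "documents"},
)
mapped.run(messages=messages, documents="Machine learning is a subset of AI.")

# pipeline_inputs instead exposes "context" as a real str input socket,
# so an upstream component (e.g. a retriever) can be connected to it
# inside a Pipeline, as the new test below demonstrates.
socketed = DSPySignatureChatGenerator(
    signature="question, context -> answer",
    pipeline_inputs=["context"],
)
```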
48 changes: 48 additions & 0 deletions integrations/dspy/tests/test_chat_generator.py
@@ -3,7 +3,10 @@

import dspy
import pytest
from haystack import Document, Pipeline, component
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore

from haystack_integrations.components.generators.dspy.chat.chat_generator import (
VALID_MODULE_TYPES,
@@ -428,6 +431,51 @@ def test_run_with_input_mapping(self, mock_dspy_module):
assert call_kwargs.get("context") == "Machine learning is a subset of AI."
assert call_kwargs.get("question") == "What is ML?"

def test_rag_pipeline_question_with_dependent_context(self, mock_dspy_module):
"""
Test case where the context passed to DSPy is dynamic and depends on the question asked by the user.
"""

doc_store = InMemoryDocumentStore()
doc_store.write_documents(
[
Document(content="Paris is the capital of France."),
Document(content="Tokyo is the capital of Japan."),
Document(content="Bananas are yellow fruits."),
]
)

@component
class DocsToString:
@component.output_types(text=str)
def run(self, documents: list[Document]) -> dict:
return {"text": "\n".join(d.content for d in documents)}

generator = DSPySignatureChatGenerator(
signature="question, context -> answer",
pipeline_inputs=["context"],
)

pipeline = Pipeline()
pipeline.add_component("retriever", InMemoryBM25Retriever(doc_store, top_k=1))
pipeline.add_component("docs_to_text", DocsToString())
pipeline.add_component("llm", generator)
pipeline.connect("retriever.documents", "docs_to_text.documents")
pipeline.connect("docs_to_text.text", "llm.context")

question = "What is the capital of Japan?" # context should be Japan-related
pipeline.run(
{
"retriever": {"query": question},
"llm": {"messages": [ChatMessage.from_user(question)]},
}
)

call_kwargs = mock_dspy_module.call_args.kwargs
assert call_kwargs["question"] == question
assert "Tokyo is the capital of Japan." in call_kwargs["context"]
assert "Paris" not in call_kwargs["context"]

def test_run_with_wrong_model(self, mock_dspy_module):
mock_dspy_module.side_effect = Exception("Invalid model name")
