Skip to content

Commit 27fe361

Browse files
committed
Remove generator, tests, and updates as discussed
1 parent 9ddd737 commit 27fe361

10 files changed

Lines changed: 325 additions & 552 deletions

File tree

integrations/dspy/README.md

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# dspy-haystack
2+
3+
[![PyPI - Version](https://img.shields.io/pypi/v/dspy-haystack.svg)](https://pypi.org/project/dspy-haystack)
4+
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/dspy-haystack.svg)](https://pypi.org/project/dspy-haystack)
5+
6+
An integration between [DSPy](https://github.com/stanfordnlp/dspy) and [Haystack](https://haystack.deepset.ai/).
7+
8+
DSPy is a framework for algorithmically optimizing prompts for Language Models by applying classical machine learning concepts (training data, evaluation metrics, optimization).
9+
10+
This integration provides:
11+
- **DSPyChatGenerator** — a Haystack ChatGenerator component that uses DSPy signatures and modules for structured generation
12+
13+
## Installation
14+
15+
```bash
16+
pip install dspy-haystack
17+
```
18+
19+
## Quick Start
20+
21+
### DSPyChatGenerator
22+
23+
A Haystack chat generator that uses DSPy signatures for structured generation with built-in reasoning patterns (Chain-of-Thought, Predict, ReAct).
24+
25+
```python
26+
from haystack import Pipeline
27+
from haystack.dataclasses import ChatMessage
28+
from haystack_integrations.components.generators.dspy import DSPyChatGenerator
29+
import dspy
30+
31+
# Define a DSPy signature
32+
class QASignature(dspy.Signature):
33+
"""Answer questions accurately and concisely."""
34+
question = dspy.InputField(desc="The user's question")
35+
answer = dspy.OutputField(desc="A clear, concise answer")
36+
37+
# Create the generator
38+
generator = DSPyChatGenerator(
39+
model="openai/gpt-5-mini",
40+
signature=QASignature,
41+
module_type="ChainOfThought"
42+
)
43+
44+
# Use in pipeline
45+
pipeline = Pipeline()
46+
pipeline.add_component("llm", generator)
47+
48+
messages = [ChatMessage.from_user("What is the capital of France?")]
49+
result = pipeline.run({"llm": {"messages": messages}})
50+
print(result["llm"]["replies"][0].text)
51+
```
52+
53+
You can also use string signatures for quick prototyping:
54+
55+
```python
56+
generator = DSPyChatGenerator(
57+
model="openai/gpt-5-mini",
58+
signature="question -> answer",
59+
module_type="Predict"
60+
)
61+
```
62+
63+
## License
64+
65+
`dspy-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license.

integrations/dspy/examples/chat_generator_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def basic_qa_example():
1515
"""Simple question-answering with Chain-of-Thought reasoning."""
1616

1717
generator = DSPyChatGenerator(
18-
model="openai/gpt-4o-mini",
18+
model="openai/gpt-5-mini",
1919
signature=QASignature,
2020
module_type="ChainOfThought",
2121
output_field="answer",
@@ -34,7 +34,7 @@ def basic_qa_example():
3434
def string_signature_example():
3535
"""Using a simple string signature instead of a class."""
3636
generator = DSPyChatGenerator(
37-
model="openai/gpt-4o-mini",
37+
model="openai/gpt-5-mini",
3838
signature="question -> answer",
3939
module_type="Predict",
4040
output_field="answer",

integrations/dspy/pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ integration = 'pytest -m "integration" {args:tests}'
6565
all = 'pytest {args:tests}'
6666
cov-retry = 'pytest --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x {args:tests}'
6767

68-
types = """mypy -p haystack_integrations.components.generators.dspy \
69-
-p haystack_integrations.utils.dspy {args}"""
68+
types = "mypy -p haystack_integrations.components.generators.dspy {args}"
7069

7170
[tool.mypy]
7271
install_types = true
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
11
from haystack_integrations.components.generators.dspy.chat.chat_generator import DSPyChatGenerator
2-
from haystack_integrations.components.generators.dspy.generator import DSPyGenerator
3-
from haystack_integrations.components.generators.dspy.program_runner import DSPyProgramRunner
42

5-
__all__ = ["DSPyChatGenerator", "DSPyGenerator", "DSPyProgramRunner"]
3+
__all__ = ["DSPyChatGenerator"]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from haystack_integrations.components.generators.dspy.chat.chat_generator import DSPyChatGenerator
2+
3+
__all__ = ["DSPyChatGenerator"]

integrations/dspy/src/haystack_integrations/components/generators/dspy/chat/chat_generator.py

Lines changed: 140 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,48 @@
11
from typing import Any, Callable, Dict, List, Optional, Type, Union
22

33
import dspy
4-
from haystack import component
4+
from haystack import component, default_from_dict, default_to_dict
55
from haystack.dataclasses import ChatMessage, ChatRole
6-
from haystack.utils import Secret
6+
from haystack.utils import Secret, deserialize_secrets_inplace
77

8-
from haystack_integrations.components.generators.dspy.generator import DSPyGenerator
8+
VALID_MODULE_TYPES = {"Predict", "ChainOfThought", "ReAct"}
9+
10+
11+
def configure_dspy_lm(model: str, api_key: str, **kwargs: Any) -> dspy.LM:
12+
"""
13+
Create and configure a DSPy language model.
14+
15+
:param model: Model identifier (e.g. ``"openai/gpt-5-mini"``).
16+
:param api_key: Resolved API key string.
17+
:param kwargs: Additional keyword arguments passed to ``dspy.LM``.
18+
:returns: The configured ``dspy.LM`` instance.
19+
"""
20+
lm = dspy.LM(model=model, api_key=api_key, **kwargs)
21+
dspy.configure(lm=lm)
22+
return lm
23+
24+
25+
def get_dspy_module_class(module_type: str):
26+
"""
27+
Map a module type string to the corresponding DSPy module class.
28+
29+
:param module_type: One of ``"Predict"``, ``"ChainOfThought"``, or ``"ReAct"``.
30+
:returns: The DSPy module class.
31+
:raises ValueError: If the module type is not recognized.
32+
"""
33+
mapping = {
34+
"Predict": dspy.Predict,
35+
"ChainOfThought": dspy.ChainOfThought,
36+
"ReAct": dspy.ReAct,
37+
}
38+
if module_type not in mapping:
39+
msg = f"Invalid module_type '{module_type}'. Must be one of {sorted(VALID_MODULE_TYPES)}"
40+
raise ValueError(msg)
41+
return mapping[module_type]
942

1043

1144
@component
12-
class DSPyChatGenerator(DSPyGenerator):
45+
class DSPyChatGenerator:
1346
"""
1447
A Haystack chat generator component that uses DSPy signatures and modules
1548
for structured generation.
@@ -64,18 +97,94 @@ def __init__(
6497
:param input_mapping: Maps DSPy signature input field names to run kwarg names.
6598
:param streaming_callback: Callback for streaming responses.
6699
"""
67-
DSPyGenerator.__init__(
68-
self,
69-
signature=signature,
70-
model=model,
71-
api_key=api_key,
72-
module_type=module_type,
73-
output_field=output_field,
74-
generation_kwargs=generation_kwargs,
75-
input_mapping=input_mapping,
76-
streaming_callback=streaming_callback,
100+
if module_type not in VALID_MODULE_TYPES:
101+
msg = f"Invalid module_type '{module_type}'. Must be one of {sorted(VALID_MODULE_TYPES)}"
102+
raise ValueError(msg)
103+
104+
self.signature = signature
105+
self.model = model
106+
self.api_key = api_key
107+
self.module_type = module_type
108+
self.output_field = output_field
109+
self.generation_kwargs = generation_kwargs or {}
110+
self.input_mapping = input_mapping
111+
self.streaming_callback = streaming_callback
112+
113+
self._lm = configure_dspy_lm(
114+
model=self.model,
115+
api_key=self.api_key.resolve_value(),
116+
**self.generation_kwargs,
77117
)
78118

119+
module_class = get_dspy_module_class(self.module_type)
120+
self._module = module_class(self.signature)
121+
122+
def _build_dspy_inputs(self, prompt: str, **kwargs) -> Dict[str, Any]:
123+
"""Build the input dict for the DSPy module call."""
124+
if self.input_mapping:
125+
dspy_inputs = {}
126+
for sig_field, source in self.input_mapping.items():
127+
if source in kwargs:
128+
dspy_inputs[sig_field] = kwargs[source]
129+
else:
130+
dspy_inputs[sig_field] = prompt
131+
return dspy_inputs
132+
133+
input_fields = self._get_input_field_names()
134+
dspy_inputs = {input_fields[0]: prompt}
135+
136+
for field in input_fields[1:]:
137+
if field in kwargs:
138+
dspy_inputs[field] = kwargs[field]
139+
140+
return dspy_inputs
141+
142+
def _get_input_field_names(self) -> List[str]:
143+
"""Get input field names from the signature."""
144+
if isinstance(self.signature, str):
145+
input_part = self.signature.split("->")[0].strip()
146+
return [f.strip() for f in input_part.split(",")]
147+
return list(self.signature.input_fields.keys())
148+
149+
@staticmethod
150+
def _extract_last_user_message(messages: List[ChatMessage]) -> str:
151+
"""Extract the text of the last user message from a list of chat messages."""
152+
for msg in reversed(messages):
153+
if msg.role == ChatRole.USER:
154+
return msg.text
155+
return messages[-1].text
156+
157+
def _signature_to_string(self) -> str:
158+
"""Convert the signature to a string representation for serialization."""
159+
if isinstance(self.signature, str):
160+
return self.signature
161+
input_names = list(self.signature.input_fields.keys())
162+
output_names = list(self.signature.output_fields.keys())
163+
return ", ".join(input_names) + " -> " + ", ".join(output_names)
164+
165+
def to_dict(self) -> Dict[str, Any]:
166+
"""Serialize this component to a dictionary."""
167+
kwargs: Dict[str, Any] = {
168+
"signature": self._signature_to_string(),
169+
"model": self.model,
170+
"module_type": self.module_type,
171+
"output_field": self.output_field,
172+
"generation_kwargs": self.generation_kwargs,
173+
"input_mapping": self.input_mapping,
174+
}
175+
try:
176+
kwargs["api_key"] = self.api_key.to_dict()
177+
except ValueError:
178+
pass
179+
return default_to_dict(self, **kwargs)
180+
181+
@classmethod
182+
def from_dict(cls, data: Dict[str, Any]) -> "DSPyChatGenerator":
183+
"""Deserialize a component from a dictionary."""
184+
init_params = data.get("init_parameters", {})
185+
deserialize_secrets_inplace(init_params, ["api_key"])
186+
return default_from_dict(cls, data)
187+
79188
@component.output_types(replies=List[ChatMessage])
80189
def run(
81190
self,
@@ -96,11 +205,17 @@ def run(
96205
raise ValueError(msg)
97206

98207
prompt = self._extract_last_user_message(messages)
99-
result = DSPyGenerator.run(self, prompt=prompt, generation_kwargs=generation_kwargs, **kwargs)
208+
dspy_inputs = self._build_dspy_inputs(prompt, **kwargs)
100209

101-
replies = [ChatMessage.from_assistant(text=text) for text in result["replies"]]
210+
if generation_kwargs:
211+
prediction = self._module(**dspy_inputs, config=generation_kwargs)
212+
else:
213+
prediction = self._module(**dspy_inputs)
102214

103-
return {"replies": replies, "meta": result["meta"]}
215+
output_text = getattr(prediction, self.output_field, str(prediction))
216+
217+
replies = [ChatMessage.from_assistant(text=output_text)]
218+
return {"replies": replies}
104219

105220
@component.output_types(replies=List[ChatMessage])
106221
async def run_async(
@@ -124,18 +239,14 @@ async def run_async(
124239
raise ValueError(msg)
125240

126241
prompt = self._extract_last_user_message(messages)
127-
result = await DSPyGenerator.run_async(self, prompt=prompt, generation_kwargs=generation_kwargs, **kwargs)
242+
dspy_inputs = self._build_dspy_inputs(prompt, **kwargs)
128243

129-
replies = [ChatMessage.from_assistant(text=text) for text in result["replies"]]
244+
if generation_kwargs:
245+
prediction = await self._module.acall(**dspy_inputs, config=generation_kwargs)
246+
else:
247+
prediction = await self._module.acall(**dspy_inputs)
130248

131-
return {"replies": replies, "meta": result["meta"]}
249+
output_text = getattr(prediction, self.output_field, str(prediction))
132250

133-
@staticmethod
134-
def _extract_last_user_message(messages: List[ChatMessage]) -> str:
135-
"""Extract the text of the last user message from a list of chat messages."""
136-
for msg in reversed(messages):
137-
if msg.role == ChatRole.USER:
138-
return msg.text
139-
140-
# Fallback to last message if no user message found
141-
return messages[-1].text
251+
replies = [ChatMessage.from_assistant(text=output_text)]
252+
return {"replies": replies}

0 commit comments

Comments
 (0)