11from typing import Any , Callable , Dict , List , Optional , Type , Union
22
33import dspy
4- from haystack import component
4+ from haystack import component , default_from_dict , default_to_dict
55from haystack .dataclasses import ChatMessage , ChatRole
6- from haystack .utils import Secret
6+ from haystack .utils import Secret , deserialize_secrets_inplace
77
8- from haystack_integrations .components .generators .dspy .generator import DSPyGenerator
# Names of the DSPy module types this integration accepts for ``module_type``.
# Checked up front in ``DSPyChatGenerator.__init__`` before any DSPy object is built.
VALID_MODULE_TYPES = {
    "ChainOfThought",
    "Predict",
    "ReAct",
}
9+
10+
def configure_dspy_lm(model: str, api_key: str, **kwargs: Any) -> dspy.LM:
    """
    Build a ``dspy.LM`` and register it through ``dspy.configure``.

    Note that ``dspy.configure(lm=...)`` is called on every invocation, so the
    most recently created language model is the one DSPy is configured with.

    :param model: Model identifier (e.g. ``"openai/gpt-5-mini"``).
    :param api_key: Already-resolved API key string (not a ``Secret``).
    :param kwargs: Extra keyword arguments forwarded unchanged to ``dspy.LM``.
    :returns: The ``dspy.LM`` instance that was created and configured.
    """
    language_model = dspy.LM(model=model, api_key=api_key, **kwargs)
    dspy.configure(lm=language_model)
    return language_model
23+
24+
def get_dspy_module_class(module_type: str) -> Type[dspy.Module]:
    """
    Map a module type string to the corresponding DSPy module class.

    The list of valid names in the error message is derived from the same
    mapping that is used for the lookup, so the two can never disagree.

    :param module_type: One of ``"Predict"``, ``"ChainOfThought"``, or ``"ReAct"``.
    :returns: The DSPy module class (not an instance).
    :raises ValueError: If the module type is not recognized.
    """
    mapping: Dict[str, Type[dspy.Module]] = {
        "Predict": dspy.Predict,
        "ChainOfThought": dspy.ChainOfThought,
        "ReAct": dspy.ReAct,
    }
    try:
        return mapping[module_type]
    except KeyError:
        msg = f"Invalid module_type '{module_type}'. Must be one of {sorted(mapping)}"
        # The KeyError carries no extra information for the caller; hide it.
        raise ValueError(msg) from None
942
1043
1144@component
12- class DSPyChatGenerator ( DSPyGenerator ) :
45+ class DSPyChatGenerator :
1346 """
1447 A Haystack chat generator component that uses DSPy signatures and modules
1548 for structured generation.
@@ -64,18 +97,94 @@ def __init__(
6497 :param input_mapping: Maps DSPy signature input field names to run kwarg names.
6598 :param streaming_callback: Callback for streaming responses.
6699 """
67- DSPyGenerator .__init__ (
68- self ,
69- signature = signature ,
70- model = model ,
71- api_key = api_key ,
72- module_type = module_type ,
73- output_field = output_field ,
74- generation_kwargs = generation_kwargs ,
75- input_mapping = input_mapping ,
76- streaming_callback = streaming_callback ,
100+ if module_type not in VALID_MODULE_TYPES :
101+ msg = f"Invalid module_type '{ module_type } '. Must be one of { sorted (VALID_MODULE_TYPES )} "
102+ raise ValueError (msg )
103+
104+ self .signature = signature
105+ self .model = model
106+ self .api_key = api_key
107+ self .module_type = module_type
108+ self .output_field = output_field
109+ self .generation_kwargs = generation_kwargs or {}
110+ self .input_mapping = input_mapping
111+ self .streaming_callback = streaming_callback
112+
113+ self ._lm = configure_dspy_lm (
114+ model = self .model ,
115+ api_key = self .api_key .resolve_value (),
116+ ** self .generation_kwargs ,
77117 )
78118
119+ module_class = get_dspy_module_class (self .module_type )
120+ self ._module = module_class (self .signature )
121+
def _build_dspy_inputs(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
    """
    Build the keyword arguments passed to the DSPy module call.

    With an ``input_mapping`` configured, every signature field in the mapping
    receives the run kwarg named by its mapped source, falling back to
    ``prompt`` when that kwarg was not supplied.

    Without a mapping, the first input field of the signature receives
    ``prompt`` and each remaining input field is filled from the run kwarg of
    the same name, if present. Fields with no matching kwarg are left out.

    :param prompt: Text used for the primary input field and as the fallback value.
    :param kwargs: Extra run-time values that may feed other input fields.
    :returns: Dict mapping signature input field names to values.
    :raises ValueError: If no mapping is set and the signature has no input fields.
    """
    if self.input_mapping:
        return {sig_field: kwargs.get(source, prompt) for sig_field, source in self.input_mapping.items()}

    input_fields = self._get_input_field_names()
    if not input_fields:
        # Without this guard the indexing below would fail with a bare IndexError.
        msg = "The DSPy signature declares no input fields, so the prompt cannot be mapped to an input."
        raise ValueError(msg)

    primary_field, *other_fields = input_fields
    dspy_inputs: Dict[str, Any] = {primary_field: prompt}
    dspy_inputs.update({field: kwargs[field] for field in other_fields if field in kwargs})
    return dspy_inputs
141+
def _get_input_field_names(self) -> List[str]:
    """
    Return the input field names declared by the signature, in order.

    For a signature object the names come from its ``input_fields`` mapping.
    For a string signature (``"a, b -> c"``) the part before ``->`` is parsed:

    * commas nested inside brackets (``dict[str, int]``) do not separate fields;
    * a type annotation after ``:`` is dropped, so ``"question: str"`` yields
      ``"question"``;
    * blank entries are skipped, so ``"-> answer"`` yields an empty list.

    :returns: List of input field names.
    """
    if not isinstance(self.signature, str):
        return list(self.signature.input_fields.keys())

    input_part = self.signature.split("->")[0]

    # Split on top-level commas only; track bracket depth so that commas
    # inside a type annotation stay with their field.
    raw_fields: List[str] = []
    current: List[str] = []
    depth = 0
    for char in input_part:
        if char in "[(":
            depth += 1
        elif char in "])":
            depth = max(depth - 1, 0)
        if char == "," and depth == 0:
            raw_fields.append("".join(current))
            current = []
        else:
            current.append(char)
    raw_fields.append("".join(current))

    # Keep only the name in front of an optional ": type" annotation.
    names = (raw.split(":", 1)[0].strip() for raw in raw_fields)
    return [name for name in names if name]
148+
@staticmethod
def _extract_last_user_message(messages: List[ChatMessage]) -> str:
    """
    Return the text of the most recent user message.

    The list is scanned from the end. If no message has the ``USER`` role,
    the text of the final message in the list is returned instead.
    """
    latest_user = next(
        (message for message in reversed(messages) if message.role == ChatRole.USER),
        None,
    )
    chosen = latest_user if latest_user is not None else messages[-1]
    return chosen.text
156+
def _signature_to_string(self) -> str:
    """
    Render the signature as an ``"inputs -> outputs"`` string for serialization.

    A string signature is returned unchanged. Otherwise the string is built
    from the key names of the signature's ``input_fields`` and
    ``output_fields``; nothing else from the signature object is included.
    """
    signature = self.signature
    if isinstance(signature, str):
        return signature

    inputs_text = ", ".join(list(signature.input_fields.keys()))
    outputs_text = ", ".join(list(signature.output_fields.keys()))
    return inputs_text + " -> " + outputs_text
164+
def to_dict(self) -> Dict[str, Any]:
    """
    Serialize this component to a dictionary.

    The signature is stored in its string form (see ``_signature_to_string``).
    ``streaming_callback`` is not part of the serialized parameters.

    If ``self.api_key.to_dict()`` raises ``ValueError``, the key is left out of
    the output and a warning is logged, so the omission is visible instead of
    silent. Serialization itself still succeeds.

    :returns: Dictionary produced by ``default_to_dict`` with the init parameters.
    """
    # Function-scope import keeps this method self-contained.
    import logging

    init_parameters: Dict[str, Any] = {
        "signature": self._signature_to_string(),
        "model": self.model,
        "module_type": self.module_type,
        "output_field": self.output_field,
        "generation_kwargs": self.generation_kwargs,
        "input_mapping": self.input_mapping,
    }
    try:
        init_parameters["api_key"] = self.api_key.to_dict()
    except ValueError as err:
        # Best effort: keep serializing, but tell the user the key is missing.
        logging.getLogger(__name__).warning(
            "api_key could not be serialized and is omitted from the output of to_dict(): %s",
            err,
        )
    return default_to_dict(self, **init_parameters)
180+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DSPyChatGenerator":
    """
    Build a component from its dictionary form.

    The ``api_key`` entry under ``init_parameters`` is passed through
    ``deserialize_secrets_inplace`` first, which modifies ``data`` in place,
    and then ``default_from_dict`` constructs the instance.

    :param data: Dictionary to deserialize from.
    :returns: The deserialized component.
    """
    secret_fields = ["api_key"]
    parameters = data.get("init_parameters", {})
    deserialize_secrets_inplace(parameters, secret_fields)
    return default_from_dict(cls, data)
187+
79188 @component .output_types (replies = List [ChatMessage ])
80189 def run (
81190 self ,
@@ -96,11 +205,17 @@ def run(
96205 raise ValueError (msg )
97206
98207 prompt = self ._extract_last_user_message (messages )
99- result = DSPyGenerator . run ( self , prompt = prompt , generation_kwargs = generation_kwargs , ** kwargs )
208+ dspy_inputs = self . _build_dspy_inputs ( prompt , ** kwargs )
100209
101- replies = [ChatMessage .from_assistant (text = text ) for text in result ["replies" ]]
210+ if generation_kwargs :
211+ prediction = self ._module (** dspy_inputs , config = generation_kwargs )
212+ else :
213+ prediction = self ._module (** dspy_inputs )
102214
103- return {"replies" : replies , "meta" : result ["meta" ]}
215+ output_text = getattr (prediction , self .output_field , str (prediction ))
216+
217+ replies = [ChatMessage .from_assistant (text = output_text )]
218+ return {"replies" : replies }
104219
105220 @component .output_types (replies = List [ChatMessage ])
106221 async def run_async (
@@ -124,18 +239,14 @@ async def run_async(
124239 raise ValueError (msg )
125240
126241 prompt = self ._extract_last_user_message (messages )
127- result = await DSPyGenerator . run_async ( self , prompt = prompt , generation_kwargs = generation_kwargs , ** kwargs )
242+ dspy_inputs = self . _build_dspy_inputs ( prompt , ** kwargs )
128243
129- replies = [ChatMessage .from_assistant (text = text ) for text in result ["replies" ]]
244+ if generation_kwargs :
245+ prediction = await self ._module .acall (** dspy_inputs , config = generation_kwargs )
246+ else :
247+ prediction = await self ._module .acall (** dspy_inputs )
130248
131- return { "replies" : replies , "meta" : result [ "meta" ]}
249+ output_text = getattr ( prediction , self . output_field , str ( prediction ))
132250
133- @staticmethod
134- def _extract_last_user_message (messages : List [ChatMessage ]) -> str :
135- """Extract the text of the last user message from a list of chat messages."""
136- for msg in reversed (messages ):
137- if msg .role == ChatRole .USER :
138- return msg .text
139-
140- # Fallback to last message if no user message found
141- return messages [- 1 ].text
251+ replies = [ChatMessage .from_assistant (text = output_text )]
252+ return {"replies" : replies }
0 commit comments