22#
33# SPDX-License-Identifier: Apache-2.0
44
5+ import json
56from copy import deepcopy
67from typing import Any , Dict , List , Literal , Optional , Set , Union
78
1011
1112from haystack import component , default_from_dict , default_to_dict , logging
1213from haystack .dataclasses .chat_message import ChatMessage , ChatRole , TextContent
14+ from haystack .lazy_imports import LazyImport
1315from haystack .utils import Jinja2TimeExtension
16+ from haystack .utils .jinja2_chat_extension import ChatMessageExtension , templatize_part
1417
1518logger = logging .getLogger (__name__ )
1619
20+ with LazyImport ("Run 'pip install \" arrow>=1.3.0\" '" ) as arrow_import :
21+ import arrow # pylint: disable=unused-import
22+
23+ NO_TEXT_ERROR_MESSAGE = "ChatMessages from {role} role must contain text. Received ChatMessage with no text: {message}"
24+
25+ FILTER_NOT_ALLOWED_ERROR_MESSAGE = (
26+ "The templatize_part filter cannot be used with a template containing a list of"
27+ "ChatMessage objects. Use a string template or remove the templatize_part filter "
28+ "from the template."
29+ )
30+
1731
1832@component
1933class ChatPromptBuilder :
2034 """
21- Renders a chat prompt from a template string using Jinja2 syntax.
35+ Renders a chat prompt from a template using Jinja2 syntax.
36+
37+ A template can be a list of `ChatMessage` objects, or a special string, as shown in the usage examples.
2238
2339 It constructs prompts using static or dynamic templates, which you can update for each pipeline run.
2440
@@ -28,15 +44,15 @@ class ChatPromptBuilder:
2844
2945 ### Usage examples
3046
31- #### With static prompt template
47+ #### Static ChatMessage prompt template
3248
3349 ```python
3450 template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
3551 builder = ChatPromptBuilder(template=template)
3652 builder.run(target_language="spanish", snippet="I can't speak spanish.")
3753 ```
3854
39- #### Overriding static template at runtime
55+ #### Overriding static ChatMessage template at runtime
4056
4157 ```python
4258 template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
@@ -48,7 +64,7 @@ class ChatPromptBuilder:
4864 builder.run(target_language="spanish", snippet="I can't speak spanish.", template=summary_template)
4965 ```
5066
51- #### With dynamic prompt template
67+ #### Dynamic ChatMessage prompt template
5268
5369 ```python
5470 from haystack.components.builders import ChatPromptBuilder
@@ -97,19 +113,42 @@ class ChatPromptBuilder:
97113 'total_tokens': 238}})]}}
98114 ```
99115
116+ #### String prompt template
117+ ```python
118+ from haystack.components.builders import ChatPromptBuilder
119+ from haystack.dataclasses.image_content import ImageContent
120+
121+ template = \" \" \"
122+ {% message role="system" %}
123+ You are a helpful assistant.
124+ {% endmessage %}
125+
126+ {% message role="user" %}
127+ Hello! I am {{user_name}}. What's the difference between the following images?
128+ {% for image in images %}
129+ {{ image | templatize_part }}
130+ {% endfor %}
131+ {% endmessage %}
132+ \" \" \"
133+
134+ images = [ImageContent.from_file_path("apple.jpg"), ImageContent.from_file_path("orange.jpg")]
135+
136+ builder = ChatPromptBuilder(template=template)
137+ builder.run(user_name="John", images=images)
138+ ```
100139 """
101140
102141 def __init__ (
103142 self ,
104- template : Optional [List [ChatMessage ]] = None ,
143+ template : Optional [Union [ List [ChatMessage ], str ]] = None ,
105144 required_variables : Optional [Union [List [str ], Literal ["*" ]]] = None ,
106145 variables : Optional [List [str ]] = None ,
107146 ):
108147 """
109148 Constructs a ChatPromptBuilder component.
110149
111150 :param template:
112- A list of `ChatMessage` objects. The component looks for Jinja2 template syntax and
151+ A list of `ChatMessage` objects or a string template . The component looks for Jinja2 template syntax and
113152 renders the prompt with the provided variables. Provide the template in either
114153 the `init` method` or the `run` method.
115154 :param required_variables:
@@ -123,26 +162,32 @@ def __init__(
123162 """
124163 self ._variables = variables
125164 self ._required_variables = required_variables
126- self .required_variables = required_variables or []
127165 self .template = template
128- variables = variables or []
129- try :
130- # The Jinja2TimeExtension needs an optional dependency to be installed.
131- # If it's not available we can do without it and use the ChatPromptBuilder as is.
132- self ._env = SandboxedEnvironment (extensions = [Jinja2TimeExtension ])
133- except ImportError :
134- self ._env = SandboxedEnvironment ()
135166
167+ self ._env = SandboxedEnvironment (extensions = [ChatMessageExtension ])
168+ self ._env .filters ["templatize_part" ] = templatize_part
169+ if arrow_import .is_successful ():
170+ self ._env .add_extension (Jinja2TimeExtension )
171+
172+ extracted_variables = []
136173 if template and not variables :
137- for message in template :
138- if message .is_from (ChatRole .USER ) or message .is_from (ChatRole .SYSTEM ):
139- # infer variables from template
140- if message .text is None :
141- raise ValueError (f"The provided ChatMessage has no text. ChatMessage: { message } " )
142- ast = self ._env .parse (message .text )
143- template_variables = meta .find_undeclared_variables (ast )
144- variables += list (template_variables )
145- self .variables = variables
174+ if isinstance (template , list ):
175+ for message in template :
176+ if message .is_from (ChatRole .USER ) or message .is_from (ChatRole .SYSTEM ):
177+ # infer variables from template
178+ if message .text is None :
179+ raise ValueError (NO_TEXT_ERROR_MESSAGE .format (role = message .role .value , message = message ))
180+ if message .text and "templatize_part" in message .text :
181+ raise ValueError (FILTER_NOT_ALLOWED_ERROR_MESSAGE )
182+ ast = self ._env .parse (message .text )
183+ template_variables = meta .find_undeclared_variables (ast )
184+ extracted_variables += list (template_variables )
185+ elif isinstance (template , str ):
186+ ast = self ._env .parse (template )
187+ extracted_variables = list (meta .find_undeclared_variables (ast ))
188+
189+ self .variables = variables or extracted_variables
190+ self .required_variables = required_variables or []
146191
147192 if len (self .variables ) > 0 and required_variables is None :
148193 logger .warning (
@@ -163,7 +208,7 @@ def __init__(
163208 @component .output_types (prompt = List [ChatMessage ])
164209 def run (
165210 self ,
166- template : Optional [List [ChatMessage ]] = None ,
211+ template : Optional [Union [ List [ChatMessage ], str ]] = None ,
167212 template_variables : Optional [Dict [str , Any ]] = None ,
168213 ** kwargs ,
169214 ):
@@ -175,7 +220,8 @@ def run(
175220 To overwrite pipeline kwargs, you can set the `template_variables` parameter.
176221
177222 :param template:
178- An optional list of `ChatMessage` objects to overwrite ChatPromptBuilder's default template.
223+ An optional list of `ChatMessage` objects or string template to overwrite ChatPromptBuilder's default
224+ template.
179225 If `None`, the default template provided at initialization is used.
180226 :param template_variables:
181227 An optional dictionary of template variables to overwrite the pipeline variables.
@@ -200,30 +246,56 @@ def run(
200246 f"Please provide a valid list of ChatMessage instances to render the prompt."
201247 )
202248
203- if not all (isinstance (message , ChatMessage ) for message in template ):
249+ if isinstance ( template , list ) and not all (isinstance (message , ChatMessage ) for message in template ):
204250 raise ValueError (
205251 f"The { self .__class__ .__name__ } expects a list containing only ChatMessage instances. "
206252 f"The provided list contains other types. Please ensure that all elements in the list "
207253 f"are ChatMessage instances."
208254 )
209255
210256 processed_messages = []
211- for message in template :
212- if message .is_from (ChatRole .USER ) or message .is_from (ChatRole .SYSTEM ):
213- self ._validate_variables (set (template_variables_combined .keys ()))
214- if message .text is None :
215- raise ValueError (f"The provided ChatMessage has no text. ChatMessage: { message } " )
216- compiled_template = self ._env .from_string (message .text )
217- rendered_text = compiled_template .render (template_variables_combined )
218- # deep copy the message to avoid modifying the original message
219- rendered_message : ChatMessage = deepcopy (message )
220- rendered_message ._content = [TextContent (text = rendered_text )]
221- processed_messages .append (rendered_message )
222- else :
223- processed_messages .append (message )
257+ if isinstance (template , list ):
258+ for message in template :
259+ if message .is_from (ChatRole .USER ) or message .is_from (ChatRole .SYSTEM ):
260+ self ._validate_variables (set (template_variables_combined .keys ()))
261+ if message .text is None :
262+ raise ValueError (NO_TEXT_ERROR_MESSAGE .format (role = message .role .value , message = message ))
263+ if message .text and "templatize_part" in message .text :
264+ raise ValueError (FILTER_NOT_ALLOWED_ERROR_MESSAGE )
265+ compiled_template = self ._env .from_string (message .text )
266+ rendered_text = compiled_template .render (template_variables_combined )
267+ # deep copy the message to avoid modifying the original message
268+ rendered_message : ChatMessage = deepcopy (message )
269+ rendered_message ._content = [TextContent (text = rendered_text )]
270+ processed_messages .append (rendered_message )
271+ else :
272+ processed_messages .append (message )
273+ elif isinstance (template , str ):
274+ self ._validate_variables (set (template_variables_combined .keys ()))
275+ processed_messages = self ._render_chat_messages_from_str_template (template , template_variables_combined )
224276
225277 return {"prompt" : processed_messages }
226278
279+ def _render_chat_messages_from_str_template (
280+ self , template : str , template_variables : Dict [str , Any ]
281+ ) -> List [ChatMessage ]:
282+ """
283+ Renders a chat message from a string template.
284+
285+ This must be used in conjunction with the `ChatMessageExtension` Jinja2 extension
286+ and the `templatize_part` filter.
287+ """
288+ compiled_template = self ._env .from_string (template )
289+ rendered = compiled_template .render (template_variables )
290+
291+ messages = []
292+ for line in rendered .strip ().split ("\n " ):
293+ line = line .strip ()
294+ if line :
295+ messages .append (ChatMessage .from_dict (json .loads (line )))
296+
297+ return messages
298+
227299 def _validate_variables (self , provided_variables : Set [str ]):
228300 """
229301 Checks if all the required template variables are provided.
@@ -252,10 +324,11 @@ def to_dict(self) -> Dict[str, Any]:
252324 :returns:
253325 Serialized dictionary representation of the component.
254326 """
255- if self .template is not None :
327+ template : Optional [Union [List [Dict [str , Any ]], str ]] = None
328+ if isinstance (self .template , list ):
256329 template = [m .to_dict () for m in self .template ]
257- else :
258- template = None
330+ elif isinstance ( self . template , str ) :
331+ template = self . template
259332
260333 return default_to_dict (
261334 self , template = template , variables = self ._variables , required_variables = self ._required_variables
@@ -275,6 +348,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatPromptBuilder":
275348 init_parameters = data ["init_parameters" ]
276349 template = init_parameters .get ("template" )
277350 if template :
278- init_parameters ["template" ] = [ChatMessage .from_dict (d ) for d in template ]
351+ if isinstance (template , list ):
352+ init_parameters ["template" ] = [ChatMessage .from_dict (d ) for d in template ]
353+ elif isinstance (template , str ):
354+ init_parameters ["template" ] = template
279355
280356 return default_from_dict (cls , data )
0 commit comments