Skip to content

Commit d04c10c

Browse files
committed
allow user prompt to be optional and update regression test to be correct
1 parent 4d3e406 commit d04c10c

2 files changed

Lines changed: 19 additions & 14 deletions

File tree

haystack/components/generators/chat/llm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
*,
4949
chat_generator: ChatGenerator,
5050
system_prompt: str | None = None,
51-
user_prompt: str,
51+
user_prompt: str | None = None,
5252
required_variables: list[str] | Literal["*"] = "*",
5353
streaming_callback: StreamingCallbackT | None = None,
5454
) -> None:
@@ -59,12 +59,13 @@ def __init__(
5959
:param system_prompt: System prompt for the LLM.
6060
:param user_prompt: User prompt for the LLM. This prompt is appended to the messages provided at
6161
runtime. If it contains Jinja2 template variables (e.g., `{{ variable_name }}`), they become
62-
inputs to the component. If there are no template variables, `messages` must be provided at
63-
runtime instead.
62+
inputs to the component. If omitted or if there are no template variables, `messages` must be
63+
provided at runtime instead.
6464
:param required_variables:
6565
Variables that must be provided as input to user_prompt.
6666
If a variable listed as required is not provided, an exception is raised.
6767
If set to `"*"`, all variables found in the prompt are required. Defaults to `"*"`.
68+
Only relevant when `user_prompt` contains template variables.
6869
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
6970
:raises ValueError: If user_prompt contains template variables but required_variables is an empty list.
7071
"""

test/components/generators/chat/test_llm.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -301,41 +301,45 @@ class TestLLMNotTriggeredByInjectedInput:
301301
when its required inputs (e.g. `query`) never arrive.
302302
"""
303303

304-
USER_PROMPT = '{% message role="user" %}{{ query }}{% endmessage %}'
305-
306304
def test_llm_not_triggered_by_injected_streaming_callback(self):
305+
307306
@component
308307
class Planner:
309-
@component.output_types(query=str, last_role=str)
308+
@component.output_types(messages=list[ChatMessage], last_role=str)
310309
def run(self) -> dict:
311-
return {"query": "What is 2+2?", "last_role": "assistant"}
310+
return {"messages": [ChatMessage.from_user("hello")], "last_role": "assistant"}
312311

313312
chat_generator = MockChatGenerator()
314-
llm = LLM(chat_generator=chat_generator, user_prompt=self.USER_PROMPT)
313+
llm = LLM(chat_generator=chat_generator)
315314
chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("x")]})
316315

317316
router = ConditionalRouter(
318317
routes=[
319318
{
320319
"condition": "{{ last_role == 'tool' }}",
321-
"output": "{{ query }}",
320+
"output": "{{ messages }}",
322321
"output_name": "processing",
323-
"output_type": str,
322+
"output_type": list[ChatMessage],
323+
},
324+
{
325+
"condition": "{{ True }}",
326+
"output": "{{ messages }}",
327+
"output_name": "planning",
328+
"output_type": list[ChatMessage],
324329
},
325-
{"condition": "{{ True }}", "output": "{{ query }}", "output_name": "planning", "output_type": str},
326330
],
327331
unsafe=True,
328332
)
329333

330334
pipeline = Pipeline()
331335
pipeline.add_component("planner", Planner())
332336
pipeline.add_component("router", router)
333-
pipeline.add_component("branch_joiner", BranchJoiner(type_=str))
337+
pipeline.add_component("branch_joiner", BranchJoiner(type_=list[ChatMessage]))
334338
pipeline.add_component("llm", llm)
335-
pipeline.connect("planner.query", "router.query")
339+
pipeline.connect("planner.messages", "router.messages")
336340
pipeline.connect("planner.last_role", "router.last_role")
337341
pipeline.connect("router.processing", "branch_joiner.value")
338-
pipeline.connect("branch_joiner.value", "llm.query")
342+
pipeline.connect("branch_joiner.value", "llm.messages")
339343

340344
result = pipeline.run(data={"llm": {"streaming_callback": sync_streaming_callback}})
341345

0 commit comments

Comments
 (0)