Skip to content

Commit 9a196a8

Browse files
authored
feat!: Update LLM component to dynamically set messages as required or optional based on init config (#11300)
1 parent 50b2141 commit 9a196a8

3 files changed

Lines changed: 162 additions & 31 deletions

File tree

haystack/components/generators/chat/llm.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
*,
4949
chat_generator: ChatGenerator,
5050
system_prompt: str | None = None,
51-
user_prompt: str,
51+
user_prompt: str | None = None,
5252
required_variables: list[str] | Literal["*"] = "*",
5353
streaming_callback: StreamingCallbackT | None = None,
5454
) -> None:
@@ -57,21 +57,18 @@ def __init__(
5757
5858
:param chat_generator: An instance of the chat generator that the LLM should use.
5959
:param system_prompt: System prompt for the LLM.
60-
:param user_prompt: User prompt for the LLM. Must contain at least one Jinja2 template variable
61-
(e.g., ``{{ variable_name }}``). This prompt is appended to the messages provided at runtime.
60+
:param user_prompt: User prompt for the LLM. This prompt is appended to the messages provided at
61+
runtime. If it contains Jinja2 template variables (e.g., `{{ variable_name }}`), they become
62+
inputs to the component. If omitted or if there are no template variables, `messages` must be
63+
provided at runtime instead.
6264
:param required_variables:
6365
Variables that must be provided as input to user_prompt.
6466
If a variable listed as required is not provided, an exception is raised.
65-
If set to ``"*"``, all variables found in the prompt are required. Defaults to ``"*"``.
67+
If set to `"*"`, all variables found in the prompt are required. Defaults to `"*"`.
68+
Only relevant when `user_prompt` contains template variables.
6669
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
67-
:raises ValueError: If user_prompt contains no template variables.
68-
:raises ValueError: If required_variables is an empty list.
70+
:raises ValueError: If user_prompt contains template variables but required_variables is an empty list.
6971
"""
70-
if isinstance(required_variables, list) and len(required_variables) == 0:
71-
raise ValueError(
72-
"required_variables must not be empty. Set it to '*' to require all variables, "
73-
"or provide a non-empty list of variable names."
74-
)
7572
super(LLM, self).__init__( # noqa: UP008
7673
chat_generator=chat_generator,
7774
system_prompt=system_prompt,
@@ -80,11 +77,17 @@ def __init__(
8077
streaming_callback=streaming_callback,
8178
)
8279
if self._user_chat_prompt_builder is None or len(self._user_chat_prompt_builder.variables) == 0:
83-
raise ValueError(
84-
"user_prompt must contain at least one template variable (e.g., '{{ variable_name }}'). "
85-
"The LLM component requires at least one required input variable to ensure proper "
86-
"pipeline scheduling."
87-
)
80+
# This means user_prompt is empty or has no template variables.
81+
# To ensure proper scheduling, we then require messages to be passed at runtime.
82+
component.set_input_type(self, "messages", list[ChatMessage])
83+
else:
84+
# user prompt was provided with variables
85+
if isinstance(required_variables, list) and len(required_variables) == 0:
86+
raise ValueError(
87+
"required_variables must not be empty. Set it to '*' to require all variables, "
88+
"or provide a non-empty list of variable names."
89+
)
90+
component.set_input_type(self, "messages", list[ChatMessage], None)
8891

8992
def to_dict(self) -> dict[str, Any]:
9093
"""
@@ -118,11 +121,10 @@ def from_dict(cls, data: dict[str, Any]) -> "LLM":
118121

119122
return default_from_dict(cls, data)
120123

121-
def run(
124+
def run( # type: ignore[override] # `messages` is in **kwargs to allow dynamic required/optional status
122125
self,
123-
messages: list[ChatMessage] | None = None,
124-
streaming_callback: StreamingCallbackT | None = None,
125126
*,
127+
streaming_callback: StreamingCallbackT | None = None,
126128
generation_kwargs: dict[str, Any] | None = None,
127129
system_prompt: str | None = None,
128130
user_prompt: str | None = None,
@@ -131,7 +133,9 @@ def run(
131133
"""
132134
Process messages and generate a response from the language model.
133135
134-
:param messages: List of Haystack ChatMessage objects to process.
136+
:param messages: List of ChatMessage objects to prepend to the conversation. Whether this is
137+
required or optional depends on the `user_prompt` configuration: if `user_prompt` has no template
138+
variables, `messages` must be provided. Passed via `**kwargs`.
135139
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
136140
:param generation_kwargs: Additional keyword arguments for the underlying chat generator. These parameters
137141
will override the parameters passed during component initialization.
@@ -145,6 +149,9 @@ def run(
145149
- "messages": List of all messages exchanged during the LLM's run.
146150
- "last_message": The last message exchanged during the LLM's run.
147151
"""
152+
# `messages` is intentionally omitted from the signature so the framework can treat it as required
153+
# or optional depending on init configuration. See __init__ for details.
154+
messages = kwargs.pop("messages", None)
148155
return super(LLM, self).run( # noqa: UP008
149156
messages=messages or [],
150157
streaming_callback=streaming_callback,
@@ -154,11 +161,10 @@ def run(
154161
**kwargs,
155162
)
156163

157-
async def run_async(
164+
async def run_async( # type: ignore[override] # `messages` is in **kwargs to allow dynamic required/optional status
158165
self,
159-
messages: list[ChatMessage] | None = None,
160-
streaming_callback: StreamingCallbackT | None = None,
161166
*,
167+
streaming_callback: StreamingCallbackT | None = None,
162168
generation_kwargs: dict[str, Any] | None = None,
163169
system_prompt: str | None = None,
164170
user_prompt: str | None = None,
@@ -167,7 +173,9 @@ async def run_async(
167173
"""
168174
Asynchronously process messages and generate a response from the language model.
169175
170-
:param messages: List of Haystack ChatMessage objects to process.
176+
:param messages: List of ChatMessage objects to prepend to the conversation. Whether this is
177+
required or optional depends on the `user_prompt` configuration: if `user_prompt` has no template
178+
variables, `messages` must be provided. Passed via `**kwargs`.
171179
:param streaming_callback: An asynchronous callback that will be invoked when a response is streamed
172180
from the LLM.
173181
:param generation_kwargs: Additional keyword arguments for the underlying chat generator. These parameters
@@ -182,6 +190,9 @@ async def run_async(
182190
- "messages": List of all messages exchanged during the LLM's run.
183191
- "last_message": The last message exchanged during the LLM's run.
184192
"""
193+
# `messages` is intentionally omitted from the signature so the framework can treat it as required
194+
# or optional depending on init configuration. See __init__ for details.
195+
messages = kwargs.pop("messages", None)
185196
return await super(LLM, self).run_async( # noqa: UP008
186197
messages=messages or [],
187198
streaming_callback=streaming_callback,
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
---
2+
upgrade:
3+
- |
4+
``LLM.run`` and ``LLM.run_async`` no longer accept ``messages`` and ``streaming_callback`` as positional
5+
arguments — they must now be passed as keyword arguments. Update any direct calls accordingly:
6+
7+
.. code:: python
8+
9+
# Before
10+
llm.run([message], my_callback)
11+
12+
# After
13+
llm.run(messages=[message], streaming_callback=my_callback)
14+
15+
enhancements:
16+
- |
17+
``LLM`` now supports two usage modes:
18+
19+
1. **Template-variable mode**: provide a ``user_prompt`` with Jinja2 variables (e.g. ``{{ query }}``).
20+
Those variables become pipeline inputs and ``messages`` is optional. The rendered ``user_prompt``
21+
is always appended after any ``messages`` provided at runtime.
22+
2. **Pass-through mode**: omit ``user_prompt`` or provide one with no template variables. ``messages``
23+
becomes a required input, allowing a fully-constructed list of ``ChatMessage`` objects to be passed from upstream.

test/components/generators/chat/test_llm.py

Lines changed: 104 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,30 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from typing import Any
6+
from unittest.mock import MagicMock
67

78
import pytest
89

910
from haystack import Document, Pipeline, component
1011
from haystack.components.agents.agent import Agent
1112
from haystack.components.generators.chat import LLM
1213
from haystack.components.generators.chat.openai import OpenAIChatGenerator
14+
from haystack.components.joiners.branch import BranchJoiner
1315
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
14-
from haystack.core.component.types import OutputSocket
16+
from haystack.components.routers.conditional_router import ConditionalRouter
17+
from haystack.core.component.types import InputSocket, OutputSocket
1518
from haystack.dataclasses import ChatMessage
1619
from haystack.dataclasses.chat_message import ChatRole
20+
from haystack.dataclasses.streaming_chunk import StreamingChunk
1721
from haystack.document_stores.in_memory import InMemoryDocumentStore
1822
from haystack.tools import Tool
1923
from haystack.tools.toolset import Toolset
2024

2125

26+
def sync_streaming_callback(chunk: StreamingChunk) -> None:
27+
pass
28+
29+
2230
@component
2331
class MockChatGeneratorWithTools:
2432
"""A mock chat generator that accepts a tools parameter."""
@@ -93,12 +101,19 @@ def test_detects_tools_support(self):
93101
llm = LLM(chat_generator=MockChatGeneratorWithTools(), user_prompt=self.USER_PROMPT)
94102
assert llm._chat_generator_supports_tools is True
95103

96-
def test_raises_if_user_prompt_has_no_variables(self):
97-
with pytest.raises(ValueError, match="at least one template variable"):
98-
LLM(
99-
chat_generator=MockChatGenerator(),
100-
user_prompt='{% message role="user" %}Hello world{% endmessage %}',
101-
)
104+
def test_messages_required_when_no_prompt_variables(self):
105+
llm = LLM(
106+
chat_generator=MockChatGenerator(), user_prompt='{% message role="user" %}Hello world{% endmessage %}'
107+
)
108+
messages_socket = llm.__haystack_input__._sockets_dict["messages"]
109+
assert isinstance(messages_socket, InputSocket)
110+
assert messages_socket.is_mandatory
111+
112+
def test_messages_optional_when_prompt_has_variables(self):
113+
llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
114+
messages_socket = llm.__haystack_input__._sockets_dict["messages"]
115+
assert isinstance(messages_socket, InputSocket)
116+
assert not messages_socket.is_mandatory
102117

103118
def test_raises_if_required_variables_empty(self):
104119
with pytest.raises(ValueError, match="required_variables must not be empty"):
@@ -195,6 +210,31 @@ def test_roundtrip(self, monkeypatch):
195210
assert restored.system_prompt == original.system_prompt
196211
assert restored.tools == []
197212

213+
class TestRun:
214+
USER_PROMPT = '{% message role="user" %}{{ query }}{% endmessage %}'
215+
216+
def test_run_accepts_messages_via_kwargs(self):
217+
llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
218+
prior_message = ChatMessage.from_user("Some prior context")
219+
result = llm.run(query="What is 2+2?", messages=[prior_message])
220+
assert result["last_message"].text == "Sync reply"
221+
assert prior_message in result["messages"]
222+
223+
def test_run_without_messages(self):
224+
llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
225+
result = llm.run(query="What is 2+2?")
226+
assert result["last_message"].text == "Sync reply"
227+
user_messages = [m for m in result["messages"] if m.is_from(ChatRole.USER)]
228+
assert any("What is 2+2?" in m.text for m in user_messages)
229+
230+
@pytest.mark.asyncio
231+
async def test_run_async_accepts_messages_via_kwargs(self):
232+
llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
233+
prior_message = ChatMessage.from_user("Some prior context")
234+
result = await llm.run_async(query="What is 2+2?", messages=[prior_message])
235+
assert result["last_message"].text == "Async reply"
236+
assert prior_message in result["messages"]
237+
198238
class TestPipelineIntegration:
199239
@pytest.fixture()
200240
def document_store_with_docs(self):
@@ -250,3 +290,60 @@ def test_rag_pipeline(self, document_store_with_docs):
250290

251291
assert llm_output["last_message"].is_from(ChatRole.ASSISTANT)
252292
assert llm_output["last_message"].text == "Sync reply"
293+
294+
295+
class TestLLMNotTriggeredByInjectedInput:
296+
"""
297+
Regression guard for the optional-messages scheduling hazard described in
298+
https://github.com/deepset-ai/haystack/issues/11109.
299+
300+
When `user_prompt` contains template variables, `messages` is optional on the LLM.
301+
An optional input with `sender=None` (i.e., injected directly via `pipeline.run`)
302+
would flip `has_user_input()` to True and incorrectly trigger the component even
303+
when its required inputs (e.g. `query`) never arrive.
304+
"""
305+
306+
def test_llm_not_triggered_by_injected_streaming_callback(self):
307+
308+
@component
309+
class Planner:
310+
@component.output_types(messages=list[ChatMessage], last_role=str)
311+
def run(self) -> dict:
312+
return {"messages": [ChatMessage.from_user("hello")], "last_role": "assistant"}
313+
314+
chat_generator = MockChatGenerator()
315+
llm = LLM(chat_generator=chat_generator)
316+
chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("x")]})
317+
318+
router = ConditionalRouter(
319+
routes=[
320+
{
321+
"condition": "{{ last_role == 'tool' }}",
322+
"output": "{{ messages }}",
323+
"output_name": "processing",
324+
"output_type": list[ChatMessage],
325+
},
326+
{
327+
"condition": "{{ True }}",
328+
"output": "{{ messages }}",
329+
"output_name": "planning",
330+
"output_type": list[ChatMessage],
331+
},
332+
],
333+
unsafe=True,
334+
)
335+
336+
pipeline = Pipeline()
337+
pipeline.add_component("planner", Planner())
338+
pipeline.add_component("router", router)
339+
pipeline.add_component("branch_joiner", BranchJoiner(type_=list[ChatMessage]))
340+
pipeline.add_component("llm", llm)
341+
pipeline.connect("planner.messages", "router.messages")
342+
pipeline.connect("planner.last_role", "router.last_role")
343+
pipeline.connect("router.processing", "branch_joiner.value")
344+
pipeline.connect("branch_joiner.value", "llm.messages")
345+
346+
result = pipeline.run(data={"llm": {"streaming_callback": sync_streaming_callback}})
347+
348+
assert "llm" not in result
349+
chat_generator.run.assert_not_called()

0 commit comments

Comments
 (0)