Skip to content

Commit 484f227

Browse files
sjrldavidsbatista
authored and committed
refactor!: Make user_prompt required in LLM (#11152)
1 parent 8d2dfa4 commit 484f227

File tree

3 files changed

+75
-14
lines changed

3 files changed

+75
-14
lines changed

haystack/components/generators/chat/llm.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,29 +48,43 @@ def __init__(
4848
*,
4949
chat_generator: ChatGenerator,
5050
system_prompt: str | None = None,
51-
user_prompt: str | None = None,
52-
required_variables: list[str] | Literal["*"] | None = None,
51+
user_prompt: str,
52+
required_variables: list[str] | Literal["*"] = "*",
5353
streaming_callback: StreamingCallbackT | None = None,
5454
) -> None:
5555
"""
5656
Initialize the LLM component.
5757
5858
:param chat_generator: An instance of the chat generator that the LLM should use.
5959
:param system_prompt: System prompt for the LLM.
60-
:param user_prompt: User prompt for the LLM. If provided this is appended to the messages provided at runtime.
60+
:param user_prompt: User prompt for the LLM. Must contain at least one Jinja2 template variable
61+
(e.g., ``{{ variable_name }}``). This prompt is appended to the messages provided at runtime.
6162
:param required_variables:
62-
List variables that must be provided as input to user_prompt.
63+
Variables that must be provided as input to user_prompt.
6364
If a variable listed as required is not provided, an exception is raised.
64-
If set to `"*"`, all variables found in the prompt are required. Optional.
65+
If set to ``"*"``, all variables found in the prompt are required. Defaults to ``"*"``.
6566
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
67+
:raises ValueError: If user_prompt contains no template variables.
68+
:raises ValueError: If required_variables is an empty list.
6669
"""
70+
if isinstance(required_variables, list) and len(required_variables) == 0:
71+
raise ValueError(
72+
"required_variables must not be empty. Set it to '*' to require all variables, "
73+
"or provide a non-empty list of variable names."
74+
)
6775
super(LLM, self).__init__( # noqa: UP008
6876
chat_generator=chat_generator,
6977
system_prompt=system_prompt,
7078
user_prompt=user_prompt,
7179
required_variables=required_variables,
7280
streaming_callback=streaming_callback,
7381
)
82+
if self._user_chat_prompt_builder is None or len(self._user_chat_prompt_builder.variables) == 0:
83+
raise ValueError(
84+
"user_prompt must contain at least one template variable (e.g., '{{ variable_name }}'). "
85+
"The LLM component requires at least one required input variable to ensure proper "
86+
"pipeline scheduling."
87+
)
7488

7589
def to_dict(self) -> dict[str, Any]:
7690
"""
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
---
2+
upgrade:
3+
- |
4+
The ``LLM`` component now requires ``user_prompt`` to be provided at initialization and it must
5+
contain at least one Jinja2 template variable (e.g. ``{{ variable_name }}``). This ensures the
6+
component always exposes at least one required input socket, which is necessary for correct
7+
pipeline scheduling.
8+
9+
``required_variables`` now defaults to ``"*"`` (all variables in ``user_prompt`` are required),
10+
and passing an empty list raises a ``ValueError``.
11+
12+
**If you are affected**: update any code that instantiates ``LLM`` without a ``user_prompt``,
13+
or with a ``user_prompt`` that has no template variables, to include at least one variable.
14+
15+
Before:
16+
17+
.. code:: python
18+
19+
llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.")
20+
21+
After:
22+
23+
.. code:: python
24+
25+
llm = LLM(
26+
chat_generator=OpenAIChatGenerator(),
27+
system_prompt="You are helpful.",
28+
user_prompt='{% message role="user" %}{{ query }}{% endmessage %}',
29+
)

test/components/generators/chat/test_llm.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,37 +63,52 @@ async def run_async(self, messages: list[ChatMessage], **kwargs) -> dict[str, An
6363

6464
class TestLLM:
6565
class TestInit:
66+
USER_PROMPT = '{% message role="user" %}{{ query }}{% endmessage %}'
67+
6668
def test_is_subclass_of_agent(self):
6769
assert issubclass(LLM, Agent)
6870

6971
def test_defaults(self):
70-
llm = LLM(chat_generator=MockChatGenerator())
72+
llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
7173
assert llm.chat_generator is not None
7274
assert llm.tools == []
7375
assert llm.system_prompt is None
74-
assert llm.user_prompt is None
76+
assert llm.user_prompt == self.USER_PROMPT
77+
assert llm.required_variables == "*"
7578
assert llm.streaming_callback is None
7679
assert llm._tool_invoker is None
7780

7881
def test_output_sockets(self):
79-
llm = LLM(chat_generator=MockChatGenerator())
82+
llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
8083
assert llm.__haystack_output__._sockets_dict == {
8184
"messages": OutputSocket(name="messages", type=list[ChatMessage], receivers=[]),
8285
"last_message": OutputSocket(name="last_message", type=ChatMessage, receivers=[]),
8386
}
8487

8588
def test_detects_no_tools_support(self):
86-
llm = LLM(chat_generator=MockChatGenerator())
89+
llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
8790
assert llm._chat_generator_supports_tools is False
8891

8992
def test_detects_tools_support(self):
90-
llm = LLM(chat_generator=MockChatGeneratorWithTools())
93+
llm = LLM(chat_generator=MockChatGeneratorWithTools(), user_prompt=self.USER_PROMPT)
9194
assert llm._chat_generator_supports_tools is True
9295

96+
def test_raises_if_user_prompt_has_no_variables(self):
97+
with pytest.raises(ValueError, match="at least one template variable"):
98+
LLM(
99+
chat_generator=MockChatGenerator(),
100+
user_prompt='{% message role="user" %}Hello world{% endmessage %}',
101+
)
102+
103+
def test_raises_if_required_variables_empty(self):
104+
with pytest.raises(ValueError, match="required_variables must not be empty"):
105+
LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT, required_variables=[])
106+
93107
class TestSerialization:
94108
def test_to_dict_excludes_agent_only_params(self, monkeypatch):
95109
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
96-
llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.")
110+
user_prompt = '{% message role="user" %}{{ query }}{% endmessage %}'
111+
llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.", user_prompt=user_prompt)
97112

98113
serialized = llm.to_dict()
99114

@@ -153,8 +168,8 @@ def test_from_dict(self, monkeypatch):
153168
},
154169
},
155170
"system_prompt": "You are helpful.",
156-
"user_prompt": None,
157-
"required_variables": None,
171+
"user_prompt": '{% message role="user" %}{{ query }}{% endmessage %}',
172+
"required_variables": "*",
158173
"streaming_callback": None,
159174
},
160175
}
@@ -168,7 +183,10 @@ def test_from_dict(self, monkeypatch):
168183

169184
def test_roundtrip(self, monkeypatch):
170185
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
171-
original = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are a poet.")
186+
user_prompt = '{% message role="user" %}{{ query }}{% endmessage %}'
187+
original = LLM(
188+
chat_generator=OpenAIChatGenerator(), system_prompt="You are a poet.", user_prompt=user_prompt
189+
)
172190

173191
restored = LLM.from_dict(original.to_dict())
174192

0 commit comments

Comments (0)