diff --git a/haystack/components/generators/chat/llm.py b/haystack/components/generators/chat/llm.py index fc4002481a..5c97fa0f20 100644 --- a/haystack/components/generators/chat/llm.py +++ b/haystack/components/generators/chat/llm.py @@ -48,8 +48,8 @@ def __init__( *, chat_generator: ChatGenerator, system_prompt: str | None = None, - user_prompt: str | None = None, - required_variables: list[str] | Literal["*"] | None = None, + user_prompt: str, + required_variables: list[str] | Literal["*"] = "*", streaming_callback: StreamingCallbackT | None = None, ) -> None: """ @@ -57,13 +57,21 @@ def __init__( :param chat_generator: An instance of the chat generator that the LLM should use. :param system_prompt: System prompt for the LLM. - :param user_prompt: User prompt for the LLM. If provided this is appended to the messages provided at runtime. + :param user_prompt: User prompt for the LLM. Must contain at least one Jinja2 template variable + (e.g., ``{{ variable_name }}``). This prompt is appended to the messages provided at runtime. :param required_variables: - List variables that must be provided as input to user_prompt. + Variables that must be provided as input to user_prompt. If a variable listed as required is not provided, an exception is raised. - If set to `"*"`, all variables found in the prompt are required. Optional. + If set to ``"*"``, all variables found in the prompt are required. Defaults to ``"*"``. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. + :raises ValueError: If user_prompt contains no template variables. + :raises ValueError: If required_variables is an empty list. """ + if isinstance(required_variables, list) and len(required_variables) == 0: + raise ValueError( + "required_variables must not be empty. Set it to '*' to require all variables, " + "or provide a non-empty list of variable names." + ) super(LLM, self).__init__( # noqa: UP008 chat_generator=chat_generator, system_prompt=system_prompt, @@ -71,6 +79,12 @@ def __init__( required_variables=required_variables, streaming_callback=streaming_callback, ) + if self._user_chat_prompt_builder is None or len(self._user_chat_prompt_builder.variables) == 0: + raise ValueError( + "user_prompt must contain at least one template variable (e.g., '{{ variable_name }}'). " + "The LLM component requires at least one required input variable to ensure proper " + "pipeline scheduling." + ) def to_dict(self) -> dict[str, Any]: """ diff --git a/releasenotes/notes/llm-require-user-prompt-variables-46c69997acf72e05.yaml b/releasenotes/notes/llm-require-user-prompt-variables-46c69997acf72e05.yaml new file mode 100644 index 0000000000..7bae8b3413 --- /dev/null +++ b/releasenotes/notes/llm-require-user-prompt-variables-46c69997acf72e05.yaml @@ -0,0 +1,29 @@ +--- +upgrade: + - | + The ``LLM`` component now requires ``user_prompt`` to be provided at initialization and it must + contain at least one Jinja2 template variable (e.g. ``{{ variable_name }}``). This ensures the + component always exposes at least one required input socket, which is necessary for correct + pipeline scheduling. + + ``required_variables`` now defaults to ``"*"`` (all variables in ``user_prompt`` are required), + and passing an empty list raises a ``ValueError``. + + **If you are affected**: update any code that instantiates ``LLM`` without a ``user_prompt``, + or with a ``user_prompt`` that has no template variables, to include at least one variable. + + Before: + + .. code:: python + + llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.") + + After: + + .. code:: python + + llm = LLM( + chat_generator=OpenAIChatGenerator(), + system_prompt="You are helpful.", + user_prompt='{% message role="user" %}{{ query }}{% endmessage %}', + ) diff --git a/test/components/generators/chat/test_llm.py b/test/components/generators/chat/test_llm.py index 72e727cfbd..4898890fe9 100644 --- a/test/components/generators/chat/test_llm.py +++ b/test/components/generators/chat/test_llm.py @@ -63,37 +63,52 @@ async def run_async(self, messages: list[ChatMessage], **kwargs) -> dict[str, An class TestLLM: class TestInit: + USER_PROMPT = '{% message role="user" %}{{ query }}{% endmessage %}' + def test_is_subclass_of_agent(self): assert issubclass(LLM, Agent) def test_defaults(self): - llm = LLM(chat_generator=MockChatGenerator()) + llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT) assert llm.chat_generator is not None assert llm.tools == [] assert llm.system_prompt is None - assert llm.user_prompt is None + assert llm.user_prompt == self.USER_PROMPT + assert llm.required_variables == "*" assert llm.streaming_callback is None assert llm._tool_invoker is None def test_output_sockets(self): - llm = LLM(chat_generator=MockChatGenerator()) + llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT) assert llm.__haystack_output__._sockets_dict == { "messages": OutputSocket(name="messages", type=list[ChatMessage], receivers=[]), "last_message": OutputSocket(name="last_message", type=ChatMessage, receivers=[]), } def test_detects_no_tools_support(self): - llm = LLM(chat_generator=MockChatGenerator()) + llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT) assert llm._chat_generator_supports_tools is False def test_detects_tools_support(self): - llm = LLM(chat_generator=MockChatGeneratorWithTools()) + llm = LLM(chat_generator=MockChatGeneratorWithTools(), user_prompt=self.USER_PROMPT) assert llm._chat_generator_supports_tools is True + def test_raises_if_user_prompt_has_no_variables(self): + with pytest.raises(ValueError, match="at least one template variable"): + LLM( + chat_generator=MockChatGenerator(), + user_prompt='{% message role="user" %}Hello world{% endmessage %}', + ) + + def test_raises_if_required_variables_empty(self): + with pytest.raises(ValueError, match="required_variables must not be empty"): + LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT, required_variables=[]) + class TestSerialization: def test_to_dict_excludes_agent_only_params(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-key") - llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.") + user_prompt = '{% message role="user" %}{{ query }}{% endmessage %}' + llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.", user_prompt=user_prompt) serialized = llm.to_dict() @@ -153,8 +168,8 @@ def test_from_dict(self, monkeypatch): }, }, "system_prompt": "You are helpful.", - "user_prompt": None, - "required_variables": None, + "user_prompt": '{% message role="user" %}{{ query }}{% endmessage %}', + "required_variables": "*", "streaming_callback": None, }, } @@ -168,7 +183,10 @@ def test_from_dict(self, monkeypatch): def test_roundtrip(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-key") - original = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are a poet.") + user_prompt = '{% message role="user" %}{{ query }}{% endmessage %}' + original = LLM( + chat_generator=OpenAIChatGenerator(), system_prompt="You are a poet.", user_prompt=user_prompt + ) restored = LLM.from_dict(original.to_dict())