Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions haystack/components/generators/chat/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,43 @@ def __init__(
*,
chat_generator: ChatGenerator,
system_prompt: str | None = None,
user_prompt: str | None = None,
required_variables: list[str] | Literal["*"] | None = None,
user_prompt: str,
required_variables: list[str] | Literal["*"] = "*",
streaming_callback: StreamingCallbackT | None = None,
) -> None:
"""
Initialize the LLM component.

:param chat_generator: An instance of the chat generator that the LLM should use.
:param system_prompt: System prompt for the LLM.
:param user_prompt: User prompt for the LLM. If provided this is appended to the messages provided at runtime.
:param user_prompt: User prompt for the LLM. Must contain at least one Jinja2 template variable
(e.g., ``{{ variable_name }}``). This prompt is appended to the messages provided at runtime.
:param required_variables:
List variables that must be provided as input to user_prompt.
Variables that must be provided as input to user_prompt.
If a variable listed as required is not provided, an exception is raised.
If set to `"*"`, all variables found in the prompt are required. Optional.
If set to ``"*"``, all variables found in the prompt are required. Defaults to ``"*"``.
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
:raises ValueError: If user_prompt contains no template variables.
:raises ValueError: If required_variables is an empty list.
"""
if isinstance(required_variables, list) and len(required_variables) == 0:
raise ValueError(
"required_variables must not be empty. Set it to '*' to require all variables, "
"or provide a non-empty list of variable names."
)
super(LLM, self).__init__( # noqa: UP008
chat_generator=chat_generator,
system_prompt=system_prompt,
user_prompt=user_prompt,
required_variables=required_variables,
streaming_callback=streaming_callback,
)
if self._user_chat_prompt_builder is None or len(self._user_chat_prompt_builder.variables) == 0:
raise ValueError(
"user_prompt must contain at least one template variable (e.g., '{{ variable_name }}'). "
"The LLM component requires at least one required input variable to ensure proper "
"pipeline scheduling."
)

def to_dict(self) -> dict[str, Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
---
upgrade:
- |
    The ``LLM`` component now requires ``user_prompt`` to be provided at initialization, and the
    prompt must contain at least one Jinja2 template variable (e.g. ``{{ variable_name }}``). This
    ensures the component always exposes at least one required input socket, which is necessary
    for correct pipeline scheduling.

``required_variables`` now defaults to ``"*"`` (all variables in ``user_prompt`` are required),
and passing an empty list raises a ``ValueError``.

**If you are affected**: update any code that instantiates ``LLM`` without a ``user_prompt``,
or with a ``user_prompt`` that has no template variables, to include at least one variable.

Before:

.. code:: python

llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.")

After:

.. code:: python

llm = LLM(
chat_generator=OpenAIChatGenerator(),
system_prompt="You are helpful.",
user_prompt='{% message role="user" %}{{ query }}{% endmessage %}',
)
36 changes: 27 additions & 9 deletions test/components/generators/chat/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,37 +63,52 @@ async def run_async(self, messages: list[ChatMessage], **kwargs) -> dict[str, An

class TestLLM:
class TestInit:
USER_PROMPT = '{% message role="user" %}{{ query }}{% endmessage %}'

def test_is_subclass_of_agent(self):
assert issubclass(LLM, Agent)

def test_defaults(self):
llm = LLM(chat_generator=MockChatGenerator())
llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
assert llm.chat_generator is not None
assert llm.tools == []
assert llm.system_prompt is None
assert llm.user_prompt is None
assert llm.user_prompt == self.USER_PROMPT
assert llm.required_variables == "*"
assert llm.streaming_callback is None
assert llm._tool_invoker is None

def test_output_sockets(self):
llm = LLM(chat_generator=MockChatGenerator())
llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
assert llm.__haystack_output__._sockets_dict == {
"messages": OutputSocket(name="messages", type=list[ChatMessage], receivers=[]),
"last_message": OutputSocket(name="last_message", type=ChatMessage, receivers=[]),
}

def test_detects_no_tools_support(self):
llm = LLM(chat_generator=MockChatGenerator())
llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
assert llm._chat_generator_supports_tools is False

def test_detects_tools_support(self):
llm = LLM(chat_generator=MockChatGeneratorWithTools())
llm = LLM(chat_generator=MockChatGeneratorWithTools(), user_prompt=self.USER_PROMPT)
assert llm._chat_generator_supports_tools is True

def test_raises_if_user_prompt_has_no_variables(self):
with pytest.raises(ValueError, match="at least one template variable"):
LLM(
chat_generator=MockChatGenerator(),
user_prompt='{% message role="user" %}Hello world{% endmessage %}',
)

def test_raises_if_required_variables_empty(self):
with pytest.raises(ValueError, match="required_variables must not be empty"):
LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT, required_variables=[])

class TestSerialization:
def test_to_dict_excludes_agent_only_params(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.")
user_prompt = '{% message role="user" %}{{ query }}{% endmessage %}'
llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.", user_prompt=user_prompt)

serialized = llm.to_dict()

Expand Down Expand Up @@ -153,8 +168,8 @@ def test_from_dict(self, monkeypatch):
},
},
"system_prompt": "You are helpful.",
"user_prompt": None,
"required_variables": None,
"user_prompt": '{% message role="user" %}{{ query }}{% endmessage %}',
"required_variables": "*",
"streaming_callback": None,
},
}
Expand All @@ -168,7 +183,10 @@ def test_from_dict(self, monkeypatch):

def test_roundtrip(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
original = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are a poet.")
user_prompt = '{% message role="user" %}{{ query }}{% endmessage %}'
original = LLM(
chat_generator=OpenAIChatGenerator(), system_prompt="You are a poet.", user_prompt=user_prompt
)

restored = LLM.from_dict(original.to_dict())

Expand Down
Loading