Skip to content
16 changes: 10 additions & 6 deletions haystack/components/builders/chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from copy import deepcopy
from typing import Any, Literal, Optional, Union

from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent
from haystack.lazy_imports import LazyImport
from haystack.utils import Jinja2TimeExtension
from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part
from haystack.utils.jinja2_extensions import _extract_template_variables_and_assignments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -179,13 +179,17 @@ def __init__(
raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message))
if message.text and "templatize_part" in message.text:
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)
ast = self._env.parse(message.text)
template_variables = meta.find_undeclared_variables(ast)
extracted_variables += list(template_variables)
assigned_variables, template_variables = _extract_template_variables_and_assignments(
env=self._env, template=message.text
)
extracted_variables += list(template_variables - assigned_variables)
elif isinstance(template, str):
ast = self._env.parse(template)
extracted_variables = list(meta.find_undeclared_variables(ast))
assigned_variables, template_variables = _extract_template_variables_and_assignments(
env=self._env, template=template
)
extracted_variables = list(template_variables - assigned_variables)

extracted_variables = extracted_variables or []
self.variables = variables or extracted_variables
self.required_variables = required_variables or []

Expand Down
12 changes: 7 additions & 5 deletions haystack/components/builders/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from typing import Any, Literal, Optional, Union

from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment

from haystack import component, default_to_dict, logging
from haystack.utils import Jinja2TimeExtension
from haystack.utils.jinja2_extensions import _extract_template_variables_and_assignments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -174,11 +174,13 @@ def __init__(
self._env = SandboxedEnvironment()

self.template = self._env.from_string(template)

if not variables:
# infer variables from template
ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(ast)
variables = list(template_variables)
assigned_variables, template_variables = _extract_template_variables_and_assignments(
env=self._env, template=template
)
variables = list(template_variables - assigned_variables)

variables = variables or []
self.variables = variables

Expand Down
20 changes: 7 additions & 13 deletions haystack/components/converters/output_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from typing import Any, Callable, Optional

import jinja2.runtime
from jinja2 import Environment, TemplateSyntaxError, meta
from jinja2 import TemplateSyntaxError
from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment
from typing_extensions import TypeAlias

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
from haystack.utils.jinja2_extensions import _extract_template_variables_and_assignments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(
output_type: TypeAlias,
custom_filters: Optional[dict[str, Callable]] = None,
unsafe: bool = False,
):
) -> None:
"""
Create an OutputAdapter component.

Expand Down Expand Up @@ -92,7 +93,10 @@ def __init__(
self._env.filters[name] = filter_func

# b) extract variables in the template
route_input_names = self._extract_variables(self._env)
assigned_variables, template_variables = _extract_template_variables_and_assignments(
env=self._env, template=self.template
)
route_input_names = template_variables - assigned_variables
input_types.update(route_input_names)

# the env is not needed, discarded automatically
Expand Down Expand Up @@ -173,13 +177,3 @@ def from_dict(cls, data: dict[str, Any]) -> "OutputAdapter":
for name, filter_func in custom_filters.items()
}
return default_from_dict(cls, data)

def _extract_variables(self, env: Environment) -> set[str]:
"""
Extracts all variables from a list of Jinja template strings.

:param env: A Jinja environment.
:return: A set of variable names extracted from the template strings.
"""
ast = env.parse(self.template)
return meta.find_undeclared_variables(ast)
11 changes: 8 additions & 3 deletions haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import contextlib
from typing import Any, Callable, Mapping, Optional, Sequence, TypedDict, Union, get_args, get_origin

from jinja2 import Environment, TemplateSyntaxError, meta
from jinja2 import Environment, TemplateSyntaxError
from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
from haystack.utils.jinja2_extensions import _extract_template_variables_and_assignments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -403,7 +404,8 @@ def _validate_routes(self, routes: list[Route]):
if not self._validate_template(self._env, output):
raise ValueError(f"Invalid template for output: {output}")

def _extract_variables(self, env: Environment, templates: list[str]) -> set[str]:
@staticmethod
def _extract_variables(env: Environment, templates: list[str]) -> set[str]:
"""
Extracts all variables from a list of Jinja template strings.

Expand All @@ -413,7 +415,10 @@ def _extract_variables(self, env: Environment, templates: list[str]) -> set[str]
"""
variables = set()
for template in templates:
variables.update(meta.find_undeclared_variables(env.parse(template)))
assigned_variables, template_variables = _extract_template_variables_and_assignments(
env=env, template=template
)
variables.update(template_variables - assigned_variables)
return variables

def _validate_template(self, env: Environment, template_text: str):
Expand Down
41 changes: 40 additions & 1 deletion haystack/utils/jinja2_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Any, Optional, Union

from jinja2 import Environment, nodes
from jinja2 import Environment, meta, nodes
from jinja2.ext import Extension

from haystack.lazy_imports import LazyImport
Expand Down Expand Up @@ -94,3 +94,42 @@ def parse(self, parser: Any) -> Union[nodes.Node, list[nodes.Node]]:
)

return nodes.Output([call_method], lineno=lineno)


def _collect_assigned_variables(ast: nodes.Template) -> set[str]:
"""
Extract variables assigned within the Jinja2 template AST.

:param ast: The Jinja2 Abstract Syntax Tree (AST) of the template.

:returns:
A set of variable names that are assigned within the template.
"""
# Collect all variables assigned inside the template via {% set %}
assigned_variables = set()

for node in ast.find_all(nodes.Assign):
if isinstance(node.target, nodes.Name):
assigned_variables.add(node.target.name)
elif isinstance(node.target, (nodes.List, nodes.Tuple)):
for name_node in node.target.items:
if isinstance(name_node, nodes.Name):
assigned_variables.add(name_node.name)

return assigned_variables


def _extract_template_variables_and_assignments(env: Environment, template: str) -> tuple[set[str], set[str]]:
"""
Extract variables from a Jinja2 template and variables assigned within it.

:param env: A Jinja2 environment.
:param template: A Jinja2 template string.
:returns: A tuple of (assigned_variables, template_variables) where:
- assigned_variables: Variables assigned within the template (e.g., via {% set %})
- template_variables: All undeclared variables used in the template
"""
jinja2_ast = env.parse(template)
template_variables = meta.find_undeclared_variables(jinja2_ast)
assigned_variables = _collect_assigned_variables(jinja2_ast)
return assigned_variables, template_variables
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
Fixes jinja2 variable detection in ``ConditionalRouter``, ``ChatPromptBuilder``, ``PromptBuilder`` and ``OutputAdapter`` by properly
skipping variables that are assigned within the template.
Previously under specific scenarios variables assigned within a template would falsely be picked up as input variables to the component.
For more information you can check out the parent issue in the Jinja2 library here: https://github.com/pallets/jinja/issues/2069
50 changes: 50 additions & 0 deletions test/components/builders/test_chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,3 +957,53 @@ def test_from_dict(self):
assert builder.template == template
assert builder.variables == ["name", "assistant_name"]
assert builder.required_variables == ["name"]

def test_variables_correct_with_assignment(self):
template = """{% message role="user" %}
{% if existing_documents is not none -%}
{% set x = existing_documents|length -%}
{% else -%}
{% set x = 0 -%}
{% endif -%}
The number is {{ x }}!
{% endmessage %}
"""
builder = ChatPromptBuilder(template=template, required_variables="*")
assert builder.variables == ["existing_documents"]
assert builder.required_variables == "*"
res = builder.run(existing_documents=None)
assert res["prompt"][0].text == "The number is 0!"

def test_variables_correct_with_tuple_assignment(self):
template = """{% message role="user" %}
{% if name is not none -%}
{% set x, y = (0, 1) %}
{% else -%}
{% set x, y = (2, 3) %}
{% endif -%}
x={{ x }}, y={{ y }}
Hello, my name is {{name}}!
{% endmessage %}
"""
builder = ChatPromptBuilder(template=template, required_variables="*")
assert builder.variables == ["name"]
assert builder.required_variables == "*"
res = builder.run(name="John")
assert res["prompt"][0].text == "x=0, y=1\nHello, my name is John!"

def test_variables_correct_with_list_assignment(self):
template = """{% message role="user" %}
{% if name is not none -%}
{% set x, y = [0, 1] %}
{% else -%}
{% set x, y = [2, 3] %}
{% endif -%}
x={{ x }}, y={{ y }}
Hello, my name is {{name}}!
{% endmessage %}
"""
builder = ChatPromptBuilder(template=template, required_variables="*")
assert builder.variables == ["name"]
assert builder.required_variables == "*"
res = builder.run(name="John")
assert res["prompt"][0].text == "x=0, y=1\nHello, my name is John!"
44 changes: 44 additions & 0 deletions test/components/builders/test_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,47 @@ def test_warning_no_required_variables(self, caplog):
with caplog.at_level(logging.WARNING):
_ = PromptBuilder(template="This is a {{ variable }}")
assert "but `required_variables` is not set." in caplog.text

def test_variables_correct_with_assignment(self) -> None:
template = """{% if existing_documents is not none %}
{% set existing_doc_len = existing_documents|length %}
{% else %}
{% set existing_doc_len = 0 %}
{% endif %}
{% for doc in docs %}
<document reference="{{loop.index + existing_doc_len}}">
{{ doc.content }}
</document>
{% endfor %}
"""
builder = PromptBuilder(template=template, required_variables="*")
assert set(builder.variables) == {"docs", "existing_documents"}
assert builder.required_variables == "*"

def test_variables_correct_with_tuple_assignment(self):
template = """{% if existing_documents is not none -%}
{% set x, y = (existing_documents|length, 1) -%}
{% else -%}
{% set x, y = (0, 1) -%}
{% endif -%}
x={{ x }}, y={{ y }}
"""
builder = PromptBuilder(template=template, required_variables="*")
assert builder.variables == ["existing_documents"]
assert builder.required_variables == "*"
res = builder.run(existing_documents=None)
assert res["prompt"] == "x=0, y=1"

def test_variables_correct_with_list_assignment(self):
template = """{% if existing_documents is not none -%}
{% set x, y = [existing_documents|length, 1] -%}
{% else -%}
{% set x, y = [0, 1] -%}
{% endif -%}
x={{ x }}, y={{ y }}
"""
builder = PromptBuilder(template=template, required_variables="*")
assert builder.variables == ["existing_documents"]
assert builder.required_variables == "*"
res = builder.run(existing_documents=None)
assert res["prompt"] == "x=0, y=1"
16 changes: 15 additions & 1 deletion test/components/converters/test_output_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
# SPDX-License-Identifier: Apache-2.0

import json
from typing import List
from typing import Any, List

import pytest

from haystack import Pipeline, component
from haystack.components.converters import OutputAdapter
from haystack.components.converters.output_adapter import OutputAdaptationException
from haystack.core.component.sockets import InputSocket
from haystack.dataclasses import Document


Expand Down Expand Up @@ -203,3 +204,16 @@ def test_unsafe(self):
]
res = adapter.run(documents=documents)
assert res["output"] == documents[0]

def test_variables_correct_with_assignment(self) -> None:
template = """{% if control == 'something' %}
{% set output = 1 %}
{% else %}
{% set output = 3 %}
{% endif %}
{{ output }}
"""
adapter = OutputAdapter(template=template, output_type=int)
assert adapter.__haystack_input__._sockets_dict == {"control": InputSocket(name="control", type=Any)}
res = adapter.run(control="something")
assert res["output"] == 1
13 changes: 13 additions & 0 deletions test/components/routers/test_conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from unittest import mock

import pytest
from jinja2.nativetypes import NativeEnvironment

from haystack import Pipeline
from haystack.components.routers import ConditionalRouter
Expand Down Expand Up @@ -636,3 +637,15 @@ def test_sede_multiple_outputs(self):
reloaded_router = ConditionalRouter.from_dict(router.to_dict())
assert reloaded_router.custom_filters == router.custom_filters
assert reloaded_router.routes == router.routes

def test_extract_variables_correct_with_assignment(self):
condition = """{%- if control == 'something' -%}
{% set streams = 1 %}
{%- else -%}
{% set streams = 2 %}
{%- endif -%}
{{streams == 1}}
"""
templates = [condition, "{{query}}"]
extracted_variables = ConditionalRouter._extract_variables(env=NativeEnvironment(), templates=templates)
assert extracted_variables == {"control", "query"}
Loading