Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 119 additions & 30 deletions haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from jinja2 import Environment, TemplateSyntaxError
from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment
from typing_extensions import NotRequired

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
Expand All @@ -32,6 +33,7 @@ class Route(TypedDict):
output: str | list[str]
output_name: str | list[str]
output_type: type | list[type]
output_passthrough: NotRequired[bool]


@component
Expand All @@ -47,6 +49,10 @@ class ConditionalRouter:
- `output_name`: The name you want to use to publish `output`. This name is used to connect
the router to other components in the pipeline.

An optional field `output_passthrough` can be set to `True` to treat `output` as a variable name
instead of a Jinja2 template, passing the variable value directly. This is useful for routing
complex non-basic types (dataclasses, Pydantic models, etc.) without Jinja2 processing.

### Usage example

```python
Expand Down Expand Up @@ -116,6 +122,64 @@ class ConditionalRouter:
print(result)
# >> {'router': {'few_items': 'Processing few items'}}
```

### Passthrough routing for non-basic types

Without `output_passthrough`, the router renders `output` as a Jinja2 template, which converts
the value to its string representation. Custom types cannot survive that round-trip:

```python
# Without output_passthrough — the object is silently converted to a string
routes = [
{
"condition": "{{True}}",
"output": "{{query}}",
"output_name": "out",
"output_type": ParsedQuery,
}
]
router = ConditionalRouter(routes)
result = router.run(query=ParsedQuery(text="hello", intent="search", entities=[]))
# result["out"] == "ParsedQuery(text='hello', intent='search', entities=[])"
# ^^^ str, not ParsedQuery — the object was destroyed
```

Set `output_passthrough: True` to skip Jinja2 entirely and pass the value directly from kwargs:

```python
from haystack.components.routers import ConditionalRouter
from dataclasses import dataclass, field

@dataclass
class ParsedQuery:
text: str
intent: str # "search" | "chat"
entities: list[str] = field(default_factory=list)

routes = [
{
"condition": "{{query.intent == 'search'}}",
"output": "query", # variable name, not a Jinja2 template
"output_name": "search_query",
"output_type": ParsedQuery,
"output_passthrough": True,
},
{
"condition": "{{query.intent == 'chat'}}",
"output": "query",
"output_name": "chat_query",
"output_type": ParsedQuery,
"output_passthrough": True,
},
]

router = ConditionalRouter(routes)
query = ParsedQuery(text="What is Haystack?", intent="search", entities=["Haystack"])
result = router.run(query=query)

assert isinstance(result["search_query"], ParsedQuery) # type preserved
assert result["search_query"] is query # same object, no copying
```
"""

def __init__(
Expand All @@ -132,10 +196,16 @@ def __init__(
:param routes: A list of dictionaries, each defining a route.
Each route has these four elements:
- `condition`: A Jinja2 string expression that determines if the route is selected.
- `output`: A Jinja2 expression defining the route's output value.
- `output`: A Jinja2 expression defining the route's output value, or a plain variable name
if `output_passthrough` is `True`.
- `output_type`: The type of the output data (for example, `str`, `list[int]`).
- `output_name`: The name you want to use to publish `output`. This name is used to connect
the router to other components in the pipeline.
- `output_passthrough` (optional): If `True`, treats `output` as a plain variable name and
passes the value directly from the input kwargs, skipping all Jinja2 processing. Useful
for routing complex non-basic types without template transformation.
Note: if the variable named in `output` is also listed in `optional_variables`, a missing
value at runtime will route `None` downstream rather than raising a `ValueError`.
:param custom_filters: A dictionary of custom Jinja2 filters used in the condition expressions.
For example, passing `{"my_filter": my_filter_fcn}` where:
- `my_filter` is the name of the custom filter.
Expand Down Expand Up @@ -214,11 +284,17 @@ def __init__(
output_types: dict[str, type | list[type]] = {}

for route in routes:
# extract inputs
route_input_names = self._extract_variables(
self._env,
[route["condition"]] + (route["output"] if isinstance(route["output"], list) else [route["output"]]),
)
output_passthrough = route.get("output_passthrough", False)
outputs = route["output"] if isinstance(route["output"], list) else [route["output"]]

if output_passthrough:
# For passthrough routes, output values are plain variable names — treat them as inputs
route_input_names = self._extract_variables(self._env, [route["condition"]])
route_input_names.update(outputs)
else:
# For normal routes, extract variables from both condition and output templates
route_input_names = self._extract_variables(self._env, [route["condition"]] + outputs)

input_types.update(route_input_names)

# extract outputs
Expand Down Expand Up @@ -322,9 +398,9 @@ def run(self, **kwargs: Any) -> dict[str, Any]:
:raises RouteConditionException:
If there is an error parsing or evaluating the `condition` expression in the routes.
:raises ValueError:
If type validation is enabled and route type doesn't match actual value type.
If type validation is enabled and the route output doesn't match the declared type, or if
`output_passthrough` is `True` and the variable named in `output` is not found in kwargs.
"""
# Create a Jinja native environment to evaluate the condition templates as Python expressions
for route in self.routes:
try:
t = self._env.from_string(route["condition"])
Expand All @@ -342,20 +418,30 @@ def run(self, **kwargs: Any) -> dict[str, Any]:
output_names = (
route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
)
output_passthrough = route.get("output_passthrough", False)

result = {}
for output, output_type, output_name in zip(outputs, output_types, output_names, strict=True):
# Evaluate output template
t_output = self._env.from_string(output)
output_value = t_output.render(**kwargs)

# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe:
output_value = ast.literal_eval(output_value)
if output_passthrough:
# output is a plain variable name — retrieve directly from kwargs, no Jinja2 processing
if output not in kwargs:
raise ValueError( # noqa: TRY301
f"Variable '{output}' not found in inputs for passthrough route '{output_name}'. "
f"Ensure '{output}' is passed as an input to the router."
)
output_value = kwargs[output]
else:
# Standard Jinja2 template evaluation
t_output = self._env.from_string(output)
output_value = t_output.render(**kwargs)

# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe:
output_value = ast.literal_eval(output_value)

# Validate output type if needed
if self._validate_output_type and not self._output_matches_type(output_value, output_type):
Expand All @@ -366,7 +452,7 @@ def run(self, **kwargs: Any) -> dict[str, Any]:
return result

except Exception as e:
# If this was a typevalidation failure, let it propagate as a ValueError
# If this was a type-validation failure or missing passthrough variable, let it propagate
if isinstance(e, ValueError):
raise
msg = f"Error evaluating condition for route '{route}': {e}"
Expand Down Expand Up @@ -402,7 +488,7 @@ def _validate_routes(self, routes: list[Route]) -> None:
if not len(outputs) == len(output_types) == len(output_names):
raise ValueError(f"Route output, output_type and output_name must have same length: {route}")

# Validate templates
# Condition is always a Jinja2 template — validate it
if not self._validate_template(self._env, route["condition"]):
condition_value = route["condition"]
if not isinstance(condition_value, str):
Expand All @@ -413,15 +499,18 @@ def _validate_routes(self, routes: list[Route]) -> None:
)
raise ValueError(f"Invalid template for condition: {condition_value}")

for output in outputs:
if not self._validate_template(self._env, output):
if not isinstance(output, str):
raise ValueError(
f"Invalid template for output: {output!r} (type: {type(output).__name__}). "
f"Output must be a string representing a valid Jinja2 template. "
f"For example, use {str(output)!r} instead of {output!r}."
)
raise ValueError(f"Invalid template for output: {output}")
# Only validate output as Jinja2 template when output_passthrough is False (default)
output_passthrough = route.get("output_passthrough", False)
if not output_passthrough:
for output in outputs:
if not self._validate_template(self._env, output):
if not isinstance(output, str):
raise ValueError(
f"Invalid template for output: {output!r} (type: {type(output).__name__}). "
f"Output must be a string representing a valid Jinja2 template. "
f"For example, use {str(output)!r} instead of {output!r}."
)
raise ValueError(f"Invalid template for output: {output}")

@staticmethod
def _extract_variables(env: Environment, templates: list[str]) -> set[str]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
---
enhancements:
- |
Add ``output_passthrough`` option to ``ConditionalRouter``.
When ``output_passthrough: True`` is set in a route, the ``output`` field is treated as a plain
variable name instead of a Jinja2 template, and the value is passed directly from the pipeline
inputs to the route output. This allows routing of complex non-basic types such as dataclasses
and Pydantic models without unwanted Jinja2 template processing.

Without ``output_passthrough``, the router renders ``output`` as a Jinja2 template, which converts
the value to its string representation. Custom types cannot survive that round-trip:

.. code:: python

# Without output_passthrough — the object is silently converted to a string
routes = [
{
"condition": "{{True}}",
"output": "{{query}}",
"output_name": "out",
"output_type": ParsedQuery,
}
]
router = ConditionalRouter(routes)
result = router.run(query=ParsedQuery(text="hello", intent="search", entities=[]))
# result["out"] == "ParsedQuery(text='hello', intent='search', entities=[])"
# ^^^ str, not ParsedQuery — the object was destroyed

Set ``output_passthrough: True`` to skip Jinja2 entirely and pass the value directly from kwargs:

.. code:: python

from haystack.components.routers import ConditionalRouter
from dataclasses import dataclass, field

@dataclass
class ParsedQuery:
text: str
intent: str # "search" | "chat"
entities: list[str] = field(default_factory=list)

routes = [
{
"condition": "{{query.intent == 'search'}}",
"output": "query", # variable name, not a Jinja2 template
"output_name": "search_query",
"output_type": ParsedQuery,
"output_passthrough": True,
},
{
"condition": "{{query.intent == 'chat'}}",
"output": "query",
"output_name": "chat_query",
"output_type": ParsedQuery,
"output_passthrough": True,
},
]

router = ConditionalRouter(routes)
query = ParsedQuery(text="What is Haystack?", intent="search", entities=["Haystack"])
result = router.run(query=query)

assert isinstance(result["search_query"], ParsedQuery) # type preserved
assert result["search_query"] is query # same object, no copying
Loading
Loading