Skip to content

Commit 557cba2

Browse files
feat: add output_passthrough to ConditionalRouter for non-Jinja2 routing (#11555)
Co-authored-by: David S. Batista <dsbatista@gmail.com>
1 parent 70beefb commit 557cba2

3 files changed

Lines changed: 374 additions & 30 deletions

File tree

haystack/components/routers/conditional_router.py

Lines changed: 119 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from jinja2 import Environment, TemplateSyntaxError
1111
from jinja2.nativetypes import NativeEnvironment
1212
from jinja2.sandbox import SandboxedEnvironment
13+
from typing_extensions import NotRequired
1314

1415
from haystack import component, default_from_dict, default_to_dict, logging
1516
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
@@ -32,6 +33,7 @@ class Route(TypedDict):
3233
output: str | list[str]
3334
output_name: str | list[str]
3435
output_type: type | list[type]
36+
output_passthrough: NotRequired[bool]
3537

3638

3739
@component
@@ -47,6 +49,10 @@ class ConditionalRouter:
4749
- `output_name`: The name you want to use to publish `output`. This name is used to connect
4850
the router to other components in the pipeline.
4951
52+
An optional field `output_passthrough` can be set to `True` to treat `output` as a variable name
53+
instead of a Jinja2 template, passing the variable value directly. This is useful for routing
54+
complex non-basic types (dataclasses, Pydantic models, etc.) without Jinja2 processing.
55+
5056
### Usage example
5157
5258
```python
@@ -116,6 +122,64 @@ class ConditionalRouter:
116122
print(result)
117123
# >> {'router': {'few_items': 'Processing few items'}}
118124
```
125+
126+
### Passthrough routing for non-basic types
127+
128+
Without `output_passthrough`, the router renders `output` as a Jinja2 template, which converts
129+
the value to its string representation. Custom types cannot survive that round-trip:
130+
131+
```python
132+
# Without output_passthrough — the object is silently converted to a string
133+
routes = [
134+
{
135+
"condition": "{{True}}",
136+
"output": "{{query}}",
137+
"output_name": "out",
138+
"output_type": ParsedQuery,
139+
}
140+
]
141+
router = ConditionalRouter(routes)
142+
result = router.run(query=ParsedQuery(text="hello", intent="search", entities=[]))
143+
# result["out"] == "ParsedQuery(text='hello', intent='search', entities=[])"
144+
# ^^^ str, not ParsedQuery — the object was destroyed
145+
```
146+
147+
Set `output_passthrough: True` to skip Jinja2 entirely and pass the value directly from kwargs:
148+
149+
```python
150+
from haystack.components.routers import ConditionalRouter
151+
from dataclasses import dataclass, field
152+
153+
@dataclass
154+
class ParsedQuery:
155+
text: str
156+
intent: str # "search" | "chat"
157+
entities: list[str] = field(default_factory=list)
158+
159+
routes = [
160+
{
161+
"condition": "{{query.intent == 'search'}}",
162+
"output": "query", # variable name, not a Jinja2 template
163+
"output_name": "search_query",
164+
"output_type": ParsedQuery,
165+
"output_passthrough": True,
166+
},
167+
{
168+
"condition": "{{query.intent == 'chat'}}",
169+
"output": "query",
170+
"output_name": "chat_query",
171+
"output_type": ParsedQuery,
172+
"output_passthrough": True,
173+
},
174+
]
175+
176+
router = ConditionalRouter(routes)
177+
query = ParsedQuery(text="What is Haystack?", intent="search", entities=["Haystack"])
178+
result = router.run(query=query)
179+
180+
assert isinstance(result["search_query"], ParsedQuery) # type preserved
181+
assert result["search_query"] is query # same object, no copying
182+
```
119183
"""
120184

121185
def __init__(
@@ -132,10 +196,16 @@ def __init__(
132196
:param routes: A list of dictionaries, each defining a route.
133197
Each route has these four elements:
134198
- `condition`: A Jinja2 string expression that determines if the route is selected.
135-
- `output`: A Jinja2 expression defining the route's output value.
199+
- `output`: A Jinja2 expression defining the route's output value, or a plain variable name
200+
if `output_passthrough` is `True`.
136201
- `output_type`: The type of the output data (for example, `str`, `list[int]`).
137202
- `output_name`: The name you want to use to publish `output`. This name is used to connect
138203
the router to other components in the pipeline.
204+
- `output_passthrough` (optional): If `True`, treats `output` as a plain variable name and
205+
passes the value directly from the input kwargs, skipping all Jinja2 processing. Useful
206+
for routing complex non-basic types without template transformation.
207+
Note: if the variable named in `output` is also listed in `optional_variables`, a missing
208+
value at runtime will route `None` downstream rather than raising a `ValueError`.
139209
:param custom_filters: A dictionary of custom Jinja2 filters used in the condition expressions.
140210
For example, passing `{"my_filter": my_filter_fcn}` where:
141211
- `my_filter` is the name of the custom filter.
@@ -214,11 +284,17 @@ def __init__(
214284
output_types: dict[str, type | list[type]] = {}
215285

216286
for route in routes:
217-
# extract inputs
218-
route_input_names = self._extract_variables(
219-
self._env,
220-
[route["condition"]] + (route["output"] if isinstance(route["output"], list) else [route["output"]]),
221-
)
287+
output_passthrough = route.get("output_passthrough", False)
288+
outputs = route["output"] if isinstance(route["output"], list) else [route["output"]]
289+
290+
if output_passthrough:
291+
# For passthrough routes, output values are plain variable names — treat them as inputs
292+
route_input_names = self._extract_variables(self._env, [route["condition"]])
293+
route_input_names.update(outputs)
294+
else:
295+
# For normal routes, extract variables from both condition and output templates
296+
route_input_names = self._extract_variables(self._env, [route["condition"]] + outputs)
297+
222298
input_types.update(route_input_names)
223299

224300
# extract outputs
@@ -322,9 +398,9 @@ def run(self, **kwargs: Any) -> dict[str, Any]:
322398
:raises RouteConditionException:
323399
If there is an error parsing or evaluating the `condition` expression in the routes.
324400
:raises ValueError:
325-
If type validation is enabled and route type doesn't match actual value type.
401+
If type validation is enabled and the route output doesn't match the declared type, or if
402+
`output_passthrough` is `True` and the variable named in `output` is not found in kwargs.
326403
"""
327-
# Create a Jinja native environment to evaluate the condition templates as Python expressions
328404
for route in self.routes:
329405
try:
330406
t = self._env.from_string(route["condition"])
@@ -342,20 +418,30 @@ def run(self, **kwargs: Any) -> dict[str, Any]:
342418
output_names = (
343419
route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
344420
)
421+
output_passthrough = route.get("output_passthrough", False)
345422

346423
result = {}
347424
for output, output_type, output_name in zip(outputs, output_types, output_names, strict=True):
348-
# Evaluate output template
349-
t_output = self._env.from_string(output)
350-
output_value = t_output.render(**kwargs)
351-
352-
# We suppress the exception in case the output is already a string, otherwise
353-
# we try to evaluate it and would fail.
354-
# This must be done cause the output could be different literal structures.
355-
# This doesn't support any user types.
356-
with contextlib.suppress(Exception):
357-
if not self._unsafe:
358-
output_value = ast.literal_eval(output_value)
425+
if output_passthrough:
426+
# output is a plain variable name — retrieve directly from kwargs, no Jinja2 processing
427+
if output not in kwargs:
428+
raise ValueError( # noqa: TRY301
429+
f"Variable '{output}' not found in inputs for passthrough route '{output_name}'. "
430+
f"Ensure '{output}' is passed as an input to the router."
431+
)
432+
output_value = kwargs[output]
433+
else:
434+
# Standard Jinja2 template evaluation
435+
t_output = self._env.from_string(output)
436+
output_value = t_output.render(**kwargs)
437+
438+
# We suppress the exception in case the output is already a string, otherwise
439+
# we try to evaluate it and would fail.
440+
# This must be done cause the output could be different literal structures.
441+
# This doesn't support any user types.
442+
with contextlib.suppress(Exception):
443+
if not self._unsafe:
444+
output_value = ast.literal_eval(output_value)
359445

360446
# Validate output type if needed
361447
if self._validate_output_type and not self._output_matches_type(output_value, output_type):
@@ -366,7 +452,7 @@ def run(self, **kwargs: Any) -> dict[str, Any]:
366452
return result
367453

368454
except Exception as e:
369-
# If this was a typevalidation failure, let it propagate as a ValueError
455+
# If this was a type-validation failure or missing passthrough variable, let it propagate
370456
if isinstance(e, ValueError):
371457
raise
372458
msg = f"Error evaluating condition for route '{route}': {e}"
@@ -402,7 +488,7 @@ def _validate_routes(self, routes: list[Route]) -> None:
402488
if not len(outputs) == len(output_types) == len(output_names):
403489
raise ValueError(f"Route output, output_type and output_name must have same length: {route}")
404490

405-
# Validate templates
491+
# Condition is always a Jinja2 template — validate it
406492
if not self._validate_template(self._env, route["condition"]):
407493
condition_value = route["condition"]
408494
if not isinstance(condition_value, str):
@@ -413,15 +499,18 @@ def _validate_routes(self, routes: list[Route]) -> None:
413499
)
414500
raise ValueError(f"Invalid template for condition: {condition_value}")
415501

416-
for output in outputs:
417-
if not self._validate_template(self._env, output):
418-
if not isinstance(output, str):
419-
raise ValueError(
420-
f"Invalid template for output: {output!r} (type: {type(output).__name__}). "
421-
f"Output must be a string representing a valid Jinja2 template. "
422-
f"For example, use {str(output)!r} instead of {output!r}."
423-
)
424-
raise ValueError(f"Invalid template for output: {output}")
502+
# Only validate output as Jinja2 template when output_passthrough is False (default)
503+
output_passthrough = route.get("output_passthrough", False)
504+
if not output_passthrough:
505+
for output in outputs:
506+
if not self._validate_template(self._env, output):
507+
if not isinstance(output, str):
508+
raise ValueError(
509+
f"Invalid template for output: {output!r} (type: {type(output).__name__}). "
510+
f"Output must be a string representing a valid Jinja2 template. "
511+
f"For example, use {str(output)!r} instead of {output!r}."
512+
)
513+
raise ValueError(f"Invalid template for output: {output}")
425514

426515
@staticmethod
427516
def _extract_variables(env: Environment, templates: list[str]) -> set[str]:
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
---
2+
enhancements:
3+
- |
4+
Add ``output_passthrough`` option to ``ConditionalRouter``.
5+
When ``output_passthrough: True`` is set in a route, the ``output`` field is treated as a plain
6+
variable name instead of a Jinja2 template, and the value is passed directly from the pipeline
7+
inputs to the route output. This allows routing of complex non-basic types such as dataclasses
8+
and Pydantic models without unwanted Jinja2 template processing.
9+
10+
Without ``output_passthrough``, the router renders ``output`` as a Jinja2 template, which converts
11+
the value to its string representation. Custom types cannot survive that round-trip:
12+
13+
.. code:: python
14+
15+
# Without output_passthrough — the object is silently converted to a string
16+
routes = [
17+
{
18+
"condition": "{{True}}",
19+
"output": "{{query}}",
20+
"output_name": "out",
21+
"output_type": ParsedQuery,
22+
}
23+
]
24+
router = ConditionalRouter(routes)
25+
result = router.run(query=ParsedQuery(text="hello", intent="search", entities=[]))
26+
# result["out"] == "ParsedQuery(text='hello', intent='search', entities=[])"
27+
# ^^^ str, not ParsedQuery — the object was destroyed
28+
29+
Set ``output_passthrough: True`` to skip Jinja2 entirely and pass the value directly from kwargs:
30+
31+
.. code:: python
32+
33+
from haystack.components.routers import ConditionalRouter
34+
from dataclasses import dataclass, field
35+
36+
@dataclass
37+
class ParsedQuery:
38+
text: str
39+
intent: str # "search" | "chat"
40+
entities: list[str] = field(default_factory=list)
41+
42+
routes = [
43+
{
44+
"condition": "{{query.intent == 'search'}}",
45+
"output": "query", # variable name, not a Jinja2 template
46+
"output_name": "search_query",
47+
"output_type": ParsedQuery,
48+
"output_passthrough": True,
49+
},
50+
{
51+
"condition": "{{query.intent == 'chat'}}",
52+
"output": "query",
53+
"output_name": "chat_query",
54+
"output_type": ParsedQuery,
55+
"output_passthrough": True,
56+
},
57+
]
58+
59+
router = ConditionalRouter(routes)
60+
query = ParsedQuery(text="What is Haystack?", intent="search", entities=["Haystack"])
61+
result = router.run(query=query)
62+
63+
assert isinstance(result["search_query"], ParsedQuery) # type preserved
64+
assert result["search_query"] is query # same object, no copying

0 commit comments

Comments
 (0)