diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index ec12a30893..61cd23c6ce 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -10,6 +10,7 @@ from jinja2 import Environment, TemplateSyntaxError from jinja2.nativetypes import NativeEnvironment from jinja2.sandbox import SandboxedEnvironment +from typing_extensions import NotRequired from haystack import component, default_from_dict, default_to_dict, logging from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type @@ -32,6 +33,7 @@ class Route(TypedDict): output: str | list[str] output_name: str | list[str] output_type: type | list[type] + output_passthrough: NotRequired[bool] @component @@ -47,6 +49,10 @@ class ConditionalRouter: - `output_name`: The name you want to use to publish `output`. This name is used to connect the router to other components in the pipeline. + An optional field `output_passthrough` can be set to `True` to treat `output` as a variable name + instead of a Jinja2 template, passing the variable value directly. This is useful for routing + complex non-basic types (dataclasses, Pydantic models, etc.) without Jinja2 processing. + ### Usage example ```python @@ -116,6 +122,64 @@ class ConditionalRouter: print(result) # >> {'router': {'few_items': 'Processing few items'}} ``` + + ### Passthrough routing for non-basic types + + Without `output_passthrough`, the router renders `output` as a Jinja2 template, which converts + the value to its string representation. 
Custom types cannot survive that round-trip: + + ```python + # Without output_passthrough — the object is silently converted to a string + routes = [ + { + "condition": "{{True}}", + "output": "{{query}}", + "output_name": "out", + "output_type": ParsedQuery, + } + ] + router = ConditionalRouter(routes) + result = router.run(query=ParsedQuery(text="hello", intent="search", entities=[])) + # result["out"] == "ParsedQuery(text='hello', intent='search', entities=[])" + # ^^^ str, not ParsedQuery — the object was destroyed + ``` + + Set `output_passthrough: True` to skip Jinja2 entirely and pass the value directly from kwargs: + + ```python + from haystack.components.routers import ConditionalRouter + from dataclasses import dataclass, field + + @dataclass + class ParsedQuery: + text: str + intent: str # "search" | "chat" + entities: list[str] = field(default_factory=list) + + routes = [ + { + "condition": "{{query.intent == 'search'}}", + "output": "query", # variable name, not a Jinja2 template + "output_name": "search_query", + "output_type": ParsedQuery, + "output_passthrough": True, + }, + { + "condition": "{{query.intent == 'chat'}}", + "output": "query", + "output_name": "chat_query", + "output_type": ParsedQuery, + "output_passthrough": True, + }, + ] + + router = ConditionalRouter(routes) + query = ParsedQuery(text="What is Haystack?", intent="search", entities=["Haystack"]) + result = router.run(query=query) + + assert isinstance(result["search_query"], ParsedQuery) # type preserved + assert result["search_query"] is query # same object, no copying + ``` """ def __init__( @@ -132,10 +196,16 @@ def __init__( :param routes: A list of dictionaries, each defining a route. Each route has these four elements: - `condition`: A Jinja2 string expression that determines if the route is selected. - - `output`: A Jinja2 expression defining the route's output value. 
+ - `output`: A Jinja2 expression defining the route's output value, or a plain variable name + if `output_passthrough` is `True`. - `output_type`: The type of the output data (for example, `str`, `list[int]`). - `output_name`: The name you want to use to publish `output`. This name is used to connect the router to other components in the pipeline. + - `output_passthrough` (optional): If `True`, treats `output` as a plain variable name and + passes the value directly from the input kwargs, skipping all Jinja2 processing. Useful + for routing complex non-basic types without template transformation. + Note: if the variable named in `output` is also listed in `optional_variables`, a missing + value at runtime will route `None` downstream rather than raising a `ValueError`. :param custom_filters: A dictionary of custom Jinja2 filters used in the condition expressions. For example, passing `{"my_filter": my_filter_fcn}` where: - `my_filter` is the name of the custom filter. @@ -214,11 +284,17 @@ def __init__( output_types: dict[str, type | list[type]] = {} for route in routes: - # extract inputs - route_input_names = self._extract_variables( - self._env, - [route["condition"]] + (route["output"] if isinstance(route["output"], list) else [route["output"]]), - ) + output_passthrough = route.get("output_passthrough", False) + outputs = route["output"] if isinstance(route["output"], list) else [route["output"]] + + if output_passthrough: + # For passthrough routes, output values are plain variable names — treat them as inputs + route_input_names = self._extract_variables(self._env, [route["condition"]]) + route_input_names.update(outputs) + else: + # For normal routes, extract variables from both condition and output templates + route_input_names = self._extract_variables(self._env, [route["condition"]] + outputs) + input_types.update(route_input_names) # extract outputs @@ -322,9 +398,9 @@ def run(self, **kwargs: Any) -> dict[str, Any]: :raises RouteConditionException: If 
there is an error parsing or evaluating the `condition` expression in the routes. :raises ValueError: - If type validation is enabled and route type doesn't match actual value type. + If type validation is enabled and the route output doesn't match the declared type, or if + `output_passthrough` is `True` and the variable named in `output` is not found in kwargs. """ - # Create a Jinja native environment to evaluate the condition templates as Python expressions for route in self.routes: try: t = self._env.from_string(route["condition"]) @@ -342,20 +418,30 @@ def run(self, **kwargs: Any) -> dict[str, Any]: output_names = ( route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]] ) + output_passthrough = route.get("output_passthrough", False) result = {} for output, output_type, output_name in zip(outputs, output_types, output_names, strict=True): - # Evaluate output template - t_output = self._env.from_string(output) - output_value = t_output.render(**kwargs) - - # We suppress the exception in case the output is already a string, otherwise - # we try to evaluate it and would fail. - # This must be done cause the output could be different literal structures. - # This doesn't support any user types. - with contextlib.suppress(Exception): - if not self._unsafe: - output_value = ast.literal_eval(output_value) + if output_passthrough: + # output is a plain variable name — retrieve directly from kwargs, no Jinja2 processing + if output not in kwargs: + raise ValueError( # noqa: TRY301 + f"Variable '{output}' not found in inputs for passthrough route '{output_name}'. " + f"Ensure '{output}' is passed as an input to the router." + ) + output_value = kwargs[output] + else: + # Standard Jinja2 template evaluation + t_output = self._env.from_string(output) + output_value = t_output.render(**kwargs) + + # We suppress the exception in case the output is already a string, otherwise + # we try to evaluate it and would fail. 
+ # This must be done because the output could be different literal structures. + # This doesn't support any user types. + with contextlib.suppress(Exception): + if not self._unsafe: + output_value = ast.literal_eval(output_value) # Validate output type if needed if self._validate_output_type and not self._output_matches_type(output_value, output_type): @@ -366,7 +452,7 @@ def run(self, **kwargs: Any) -> dict[str, Any]: return result except Exception as e: - # If this was a type‐validation failure, let it propagate as a ValueError + # If this was a type-validation failure or missing passthrough variable, let it propagate if isinstance(e, ValueError): raise msg = f"Error evaluating condition for route '{route}': {e}" @@ -402,7 +488,7 @@ def _validate_routes(self, routes: list[Route]) -> None: if not len(outputs) == len(output_types) == len(output_names): raise ValueError(f"Route output, output_type and output_name must have same length: {route}") - # Validate templates + # Condition is always a Jinja2 template — validate it if not self._validate_template(self._env, route["condition"]): condition_value = route["condition"] if not isinstance(condition_value, str): @@ -413,15 +499,18 @@ def _validate_routes(self, routes: list[Route]) -> None: ) raise ValueError(f"Invalid template for condition: {condition_value}") - for output in outputs: - if not self._validate_template(self._env, output): - if not isinstance(output, str): - raise ValueError( - f"Invalid template for output: {output!r} (type: {type(output).__name__}). " - f"Output must be a string representing a valid Jinja2 template. " - f"For example, use {str(output)!r} instead of {output!r}."
- ) - raise ValueError(f"Invalid template for output: {output}") + # Only validate output as Jinja2 template when output_passthrough is False (default) + output_passthrough = route.get("output_passthrough", False) + if not output_passthrough: + for output in outputs: + if not self._validate_template(self._env, output): + if not isinstance(output, str): + raise ValueError( + f"Invalid template for output: {output!r} (type: {type(output).__name__}). " + f"Output must be a string representing a valid Jinja2 template. " + f"For example, use {str(output)!r} instead of {output!r}." + ) + raise ValueError(f"Invalid template for output: {output}") @staticmethod def _extract_variables(env: Environment, templates: list[str]) -> set[str]: diff --git a/releasenotes/notes/add-output-passthrough-to-conditional-router-d6e53b417916362e.yaml b/releasenotes/notes/add-output-passthrough-to-conditional-router-d6e53b417916362e.yaml new file mode 100644 index 0000000000..ad692474e7 --- /dev/null +++ b/releasenotes/notes/add-output-passthrough-to-conditional-router-d6e53b417916362e.yaml @@ -0,0 +1,64 @@ +--- +enhancements: + - | + Add ``output_passthrough`` option to ``ConditionalRouter``. + When ``output_passthrough: True`` is set in a route, the ``output`` field is treated as a plain + variable name instead of a Jinja2 template, and the value is passed directly from the pipeline + inputs to the route output. This allows routing of complex non-basic types such as dataclasses + and Pydantic models without unwanted Jinja2 template processing. + + Without ``output_passthrough``, the router renders ``output`` as a Jinja2 template, which converts + the value to its string representation. Custom types cannot survive that round-trip: + + .. 
code:: python + + # Without output_passthrough — the object is silently converted to a string + routes = [ + { + "condition": "{{True}}", + "output": "{{query}}", + "output_name": "out", + "output_type": ParsedQuery, + } + ] + router = ConditionalRouter(routes) + result = router.run(query=ParsedQuery(text="hello", intent="search", entities=[])) + # result["out"] == "ParsedQuery(text='hello', intent='search', entities=[])" + # ^^^ str, not ParsedQuery — the object was destroyed + + Set ``output_passthrough: True`` to skip Jinja2 entirely and pass the value directly from kwargs: + + .. code:: python + + from haystack.components.routers import ConditionalRouter + from dataclasses import dataclass, field + + @dataclass + class ParsedQuery: + text: str + intent: str # "search" | "chat" + entities: list[str] = field(default_factory=list) + + routes = [ + { + "condition": "{{query.intent == 'search'}}", + "output": "query", # variable name, not a Jinja2 template + "output_name": "search_query", + "output_type": ParsedQuery, + "output_passthrough": True, + }, + { + "condition": "{{query.intent == 'chat'}}", + "output": "query", + "output_name": "chat_query", + "output_type": ParsedQuery, + "output_passthrough": True, + }, + ] + + router = ConditionalRouter(routes) + query = ParsedQuery(text="What is Haystack?", intent="search", entities=["Haystack"]) + result = router.run(query=query) + + assert isinstance(result["search_query"], ParsedQuery) # type preserved + assert result["search_query"] is query # same object, no copying diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index d5ad89b93d..8d0f76f377 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -730,3 +730,194 @@ def test_extract_variables_correct_with_assignment(self): templates = [condition, "{{query}}"] extracted_variables = 
ConditionalRouter._extract_variables(env=NativeEnvironment(), templates=templates) assert extracted_variables == {"control", "query"} + + def test_conditional_router_passthrough_serialization_roundtrip(self): + """Test that output_passthrough survives to_dict/from_dict.""" + routes = [ + { + "condition": "{{flag}}", + "output": "value", + "output_name": "matched", + "output_type": str, + "output_passthrough": True, + }, + { + "condition": "{{not flag}}", + "output": "value", + "output_name": "unmatched", + "output_type": str, + "output_passthrough": True, + }, + ] + + router = ConditionalRouter(routes) + reloaded = ConditionalRouter.from_dict(router.to_dict()) + + assert reloaded.routes == router.routes + assert reloaded.routes[0].get("output_passthrough") is True + assert reloaded.routes[1].get("output_passthrough") is True + + assert reloaded.run(flag=True, value="hello") == {"matched": "hello"} + assert reloaded.run(flag=False, value="hello") == {"unmatched": "hello"} + + def test_conditional_router_passthrough_with_custom_type(self): + """Test passthrough routing for custom types without Jinja2.""" + from dataclasses import dataclass + + @dataclass + class CustomDocument: + content: str + metadata: dict + + routes = [ + { + "condition": "{{is_important}}", + "output": "document", + "output_name": "important", + "output_type": CustomDocument, + "output_passthrough": True, + }, + { + "condition": "{{not is_important}}", + "output": "document", + "output_name": "regular", + "output_type": CustomDocument, + "output_passthrough": True, + }, + ] + + router = ConditionalRouter(routes) + doc = CustomDocument(content="Important", metadata={"priority": "high"}) + + result = router.run(is_important=True, document=doc) + assert "important" in result + assert result["important"] == doc + assert result["important"].content == "Important" + + result = router.run(is_important=False, document=doc) + assert "regular" in result + assert result["regular"] == doc + + def 
test_conditional_router_passthrough_missing_variable(self): + """Test that passthrough routing raises ValueError when the named variable is not provided.""" + routes = [ + { + "condition": "{{True}}", + "output": "missing_var", + "output_name": "out", + "output_type": str, + "output_passthrough": True, + } + ] + + router = ConditionalRouter(routes) + + with pytest.raises(ValueError, match="Variable 'missing_var' not found in inputs"): + router.run(other_var="value") + + def test_conditional_router_passthrough_mixed(self): + """Test mixing passthrough and Jinja2 routes in the same router.""" + routes = [ + { + "condition": "{{mode == 'direct'}}", + "output": "data", + "output_name": "direct_route", + "output_type": list, + "output_passthrough": True, + }, + { + "condition": "{{mode == 'transform'}}", + "output": "{{data | reverse | list}}", + "output_name": "transformed_route", + "output_type": list, + }, + ] + + router = ConditionalRouter(routes) + test_list = [1, 2, 3] + + result = router.run(mode="direct", data=test_list) + assert result["direct_route"] == test_list + + result = router.run(mode="transform", data=test_list) + assert result["transformed_route"] == [3, 2, 1] + + def test_conditional_router_passthrough_multi_output(self): + """Test output_passthrough with a list of output variable names.""" + from dataclasses import dataclass + + @dataclass + class Payload: + body: str + + routes = [ + { + "condition": "{{flag}}", + "output": ["label", "payload"], + "output_name": ["out_label", "out_payload"], + "output_type": [str, Payload], + "output_passthrough": True, + } + ] + + router = ConditionalRouter(routes) + p = Payload(body="test") + result = router.run(flag=True, label="hello", payload=p) + assert result == {"out_label": "hello", "out_payload": p} + assert isinstance(result["out_payload"], Payload) + + def test_conditional_router_passthrough_validate_output_type_mismatch(self): + """Test that validate_output_type catches a type mismatch on a passthrough 
route.""" + routes = [ + { + "condition": "{{True}}", + "output": "value", + "output_name": "out", + "output_type": int, + "output_passthrough": True, + } + ] + + router = ConditionalRouter(routes, validate_output_type=True) + + with pytest.raises(ValueError, match="type doesn't match"): + router.run(value="not_an_int") + + def test_conditional_router_passthrough_optional_variable_routes_none(self): + """Test that a passthrough variable in optional_variables routes None when the pipeline omits it. + + optional_variables registers the input with default=None. Inside a pipeline, missing optional + inputs are filled with their default before run() is called. We simulate that here by passing + maybe_value=None explicitly. + """ + routes = [ + { + "condition": "{{True}}", + "output": "maybe_value", + "output_name": "out", + "output_type": str, + "output_passthrough": True, + } + ] + + router = ConditionalRouter(routes, optional_variables=["maybe_value"]) + # Simulate pipeline behaviour: optional input not connected → filled with default None + result = router.run(maybe_value=None) + assert result == {"out": None} + + def test_conditional_router_passthrough_skips_output_template_validation(self): + """Test that an invalid Jinja2 string in output is accepted when output_passthrough is True.""" + routes = [ + { + "condition": "{{True}}", + "output": "{{unclosed", # would be rejected as a Jinja2 template + "output_name": "out", + "output_type": str, + "output_passthrough": True, + } + ] + + # Construction must not raise even though the output string is not valid Jinja2 + router = ConditionalRouter(routes) + result = router.run(**{"{{unclosed": "value"}) + assert result == {"out": "value"}