From 79d3c1b09b599a41bd01305a6a433343c3f20e0e Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Fri, 29 May 2026 14:06:11 +0200 Subject: [PATCH 1/3] feat!: Gate pipeline deserialization through a module allowlist --- MIGRATION.md | 55 +++ haystack/core/pipeline/base.py | 67 +++- haystack/core/serialization.py | 52 ++- haystack/core/serialization_security.py | 142 ++++++++ haystack/utils/callable_serialization.py | 13 +- haystack/utils/type_serialization.py | 11 +- ...ialization-allowlist-d878a69ac69ee667.yaml | 52 +++ .../caching/test_url_cache_checker.py | 10 +- .../test_in_memory_bm25_retriever.py | 8 +- .../test_in_memory_embedding_retriever.py | 8 +- .../writers/test_document_writer.py | 10 +- test/conftest.py | 16 + test/core/pipeline/test_pipeline_base.py | 42 ++- test/core/test_serialization.py | 46 ++- test/core/test_serialization_security.py | 339 ++++++++++++++++++ 15 files changed, 847 insertions(+), 24 deletions(-) create mode 100644 haystack/core/serialization_security.py create mode 100644 releasenotes/notes/pipeline-deserialization-allowlist-d878a69ac69ee667.yaml create mode 100644 test/core/test_serialization_security.py diff --git a/MIGRATION.md b/MIGRATION.md index 17ee14b2d5..36a78d6390 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -423,3 +423,58 @@ builder = PromptBuilder( ) builder.run(name="John") # greeting renders as "" ``` + +### Pipeline deserialization is gated by a module allowlist + +**What changed:** `Pipeline.load`, `Pipeline.loads`, and `Pipeline.from_dict` now refuse to import classes from modules outside a trusted-module allowlist and raise a `DeserializationError` instead. The default allowlist contains `haystack`, `haystack_integrations`, `haystack_experimental`, `builtins`, `typing`, and `collections`. Pipelines that reference custom components, callables, or types in other packages will fail to load until those modules are explicitly allowed. + +In addition, `default_from_dict` now rejects nested `{"type": "..."}` dictionaries whose key is not an `__init__` parameter of the parent class — this can surface pre-existing YAML bugs (typos, leftovers from removed parameters, stale snapshots). + +**Why:** Loading a pipeline from YAML used to dynamically import any class referenced in the file, which made a crafted YAML capable of causing arbitrary classes to be imported and instantiated. Gating imports through an allowlist closes that gap while leaving Haystack's own packages working out of the box. + +**How to migrate:** + +If your pipeline only references components from `haystack`, `haystack_integrations`, or `haystack_experimental`, no action is needed. + +Otherwise, extend the allowlist via one of the four mechanisms below. + +Before (v2.x), all modules implicitly trusted: +```python +from haystack import Pipeline + +# Worked for any class on the import path, including third-party packages. +with open("pipeline.yaml") as fp: + pipeline = Pipeline.load(fp) +``` + +After (v3.0), pick one of the following options. The first two scope the trust to a single call; the others extend it process-wide. + +```python +# 1. Per-call kwarg — recommended for application code that knows exactly which extra +# packages a given YAML needs. +from haystack import Pipeline + +with open("pipeline.yaml") as fp: + pipeline_a = Pipeline.load(fp, allowed_modules=["mypkg.*", "anotherpkg.components.*"]) + +# 2. Per-call bypass — equivalent to "I fully trust this YAML; skip the allowlist". +# Mirrors the `yaml.safe_load` / `yaml.unsafe_load` convention. +with open("pipeline.yaml") as fp: + pipeline_b = Pipeline.load(fp, unsafe=True) + +# 3. Process-wide programmatic — call once at startup, e.g. in your application's +# entry point or a custom integration package's __init__. +from haystack.core.serialization import allow_deserialization_module + +allow_deserialization_module("mypkg.*") +with open("pipeline.yaml") as fp: + pipeline_c = Pipeline.load(fp) # `mypkg.*` is now trusted for every load in this process. +``` + +```bash +# 4. Environment variable — useful for ops/deployments where code shouldn't change. +# Comma-separated patterns; read at runtime on every deserialization call. +export HAYSTACK_DESERIALIZATION_ALLOWLIST="mypkg.*,otherpkg.*" +``` + +Patterns are matched as prefixes by default (`"mypkg"` matches `mypkg` and any submodule), or as `fnmatch` globs if they contain `*`, `?`, or `[` somewhere other than a trailing `.*`. diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 8734e927c1..7180e1d828 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -40,6 +40,7 @@ component_to_dict, generate_qualified_class_name, ) +from haystack.core.serialization_security import _check_module_allowed, _deserialization_context from haystack.core.type_utils import ( ConversionStrategyType, _convert_value, @@ -173,7 +174,13 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict( - cls: type[T], data: dict[str, Any], callbacks: DeserializationCallbacks | None = None, **kwargs: Any + cls: type[T], + data: dict[str, Any], + callbacks: DeserializationCallbacks | None = None, + *, + allowed_modules: list[str] | None = None, + unsafe: bool = False, + **kwargs: Any, ) -> T: """ Deserializes the pipeline from a dictionary. @@ -182,12 +189,28 @@ def from_dict( Dictionary to deserialize from. :param callbacks: Callbacks to invoke during deserialization. + :param allowed_modules: + Additional module patterns whose classes may be imported during deserialization. + By default, only modules under `haystack`, `haystack_integrations`, `haystack_experimental`, + `builtins`, `typing`, and `collections` are trusted. See + `haystack.core.serialization.allow_deserialization_module` for the matching semantics. + :param unsafe: + If `True`, bypass the deserialization allowlist entirely. Only use this when you fully + trust the source of the serialized data — any class in any importable module can be + instantiated. :param kwargs: `components`: a dictionary of `{name: instance}` to reuse instances of components instead of creating new ones. :returns: Deserialized component. """ + with _deserialization_context(allowed_modules=allowed_modules, unsafe=unsafe): + return cls._from_dict_impl(data, callbacks, **kwargs) + + @classmethod + def _from_dict_impl( + cls: type[T], data: dict[str, Any], callbacks: DeserializationCallbacks | None = None, **kwargs: Any + ) -> T: data_copy = _deepcopy_with_exceptions(data) # to prevent modification of original data metadata = data_copy.get("metadata", {}) max_runs_per_component = data_copy.get("max_runs_per_component", 100) @@ -206,28 +229,32 @@ def from_dict( if "type" not in component_data: raise PipelineError(f"Missing 'type' in component '{name}'") - if component_data["type"] not in component.registry: + component_type = component_data["type"] + if isinstance(component_type, str) and "." in component_type: + _check_module_allowed(component_type.rsplit(".", 1)[0]) + + if component_type not in component.registry: try: # Import the module first... - module, _ = component_data["type"].rsplit(".", 1) + module, _ = component_type.rsplit(".", 1) logger.debug("Trying to import module {module_name}", module_name=module) type_serialization.thread_safe_import(module) # ...then try again - if component_data["type"] not in component.registry: + if component_type not in component.registry: raise PipelineError( # noqa: TRY301 f"Successfully imported module '{module}' but couldn't find " - f"'{component_data['type']}' in the component registry.\n" + f"'{component_type}' in the component registry.\n" f"The component might be registered under a different path. " f"Here are the registered components:\n {list(component.registry.keys())}\n" ) except (ImportError, PipelineError, ValueError) as e: raise PipelineError( - f"Component '{component_data['type']}' (name: '{name}') not imported. Please " + f"Component '{component_type}' (name: '{name}') not imported. Please " f"check that the package is installed and the component path is correct." ) from e # Create a new one - component_class = component.registry[component_data["type"]] + component_class = component.registry[component_type] try: instance = component_from_dict(component_class, component_data, name, callbacks) @@ -287,6 +314,9 @@ def loads( data: str | bytes | bytearray, marshaller: Marshaller = DEFAULT_MARSHALLER, callbacks: DeserializationCallbacks | None = None, + *, + allowed_modules: list[str] | None = None, + unsafe: bool = False, ) -> T: """ Creates a `Pipeline` object from the string representation passed in the `data` argument. @@ -297,6 +327,14 @@ def loads( The Marshaller used to create the string representation. Defaults to `YamlMarshaller`. :param callbacks: Callbacks to invoke during deserialization. + :param allowed_modules: + Additional module patterns whose classes may be imported during deserialization. + By default, only modules under `haystack`, `haystack_integrations`, `haystack_experimental`, + `builtins`, `typing`, and `collections` are trusted. + :param unsafe: + If `True`, bypass the deserialization allowlist entirely. Only use this when you fully + trust the source of the serialized data — any class in any importable module can be + instantiated. :raises DeserializationError: If an error occurs during deserialization. :returns: @@ -310,7 +348,7 @@ def loads( "caused by malformed or invalid syntax in the serialized representation." ) from e - return cls.from_dict(deserialized_data, callbacks) + return cls.from_dict(deserialized_data, callbacks, allowed_modules=allowed_modules, unsafe=unsafe) @classmethod def load( @@ -318,6 +356,9 @@ def load( fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER, callbacks: DeserializationCallbacks | None = None, + *, + allowed_modules: list[str] | None = None, + unsafe: bool = False, ) -> T: """ Creates a `Pipeline` object a string representation. @@ -331,12 +372,20 @@ def load( The Marshaller used to create the string representation. Defaults to `YamlMarshaller`. :param callbacks: Callbacks to invoke during deserialization. + :param allowed_modules: + Additional module patterns whose classes may be imported during deserialization. + By default, only modules under `haystack`, `haystack_integrations`, `haystack_experimental`, + `builtins`, `typing`, and `collections` are trusted. + :param unsafe: + If `True`, bypass the deserialization allowlist entirely. Only use this when you fully + trust the source of the serialized data — any class in any importable module can be + instantiated. :raises DeserializationError: If an error occurs during deserialization. :returns: A `Pipeline` object. """ - return cls.loads(fp.read(), marshaller, callbacks) + return cls.loads(fp.read(), marshaller, callbacks, allowed_modules=allowed_modules, unsafe=unsafe) def add_component(self, name: str, instance: Component) -> None: """ diff --git a/haystack/core/serialization.py b/haystack/core/serialization.py index 082eb492e7..900ecfe382 100644 --- a/haystack/core/serialization.py +++ b/haystack/core/serialization.py @@ -10,10 +10,22 @@ from haystack import logging from haystack.core.component.component import _hook_component_init from haystack.core.errors import DeserializationError, SerializationError +from haystack.core.serialization_security import _check_module_allowed, allow_deserialization_module from haystack.utils.auth import Secret from haystack.utils.device import ComponentDevice from haystack.utils.type_serialization import thread_safe_import +__all__ = [ + "DeserializationCallbacks", + "allow_deserialization_module", + "component_from_dict", + "component_to_dict", + "default_from_dict", + "default_to_dict", + "generate_qualified_class_name", + "import_class_by_name", +] + logger = logging.getLogger(__name__) @@ -287,6 +299,8 @@ def default_from_dict(cls: type[T], data: dict[str, Any]) -> T: if data["type"] != generate_qualified_class_name(cls): raise DeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'") + valid_init_param_names = _init_parameter_names(cls) + # Automatically detect and deserialize objects with from_dict methods for key, value in init_params.items(): if isinstance(value, dict) and "type" in value: @@ -299,6 +313,12 @@ def default_from_dict(cls: type[T], data: dict[str, Any]) -> T: init_params[key] = ComponentDevice.from_dict(value) # If type looks like a fully qualified class name, try to import it and deserialize elif isinstance(type_value, str) and "." in type_value: + # Reject before importing if the parent class does not accept this parameter. + # This blocks YAML that smuggles untrusted classes into unused parameter slots. + if valid_init_param_names is not None and key not in valid_init_param_names: + raise DeserializationError( + f"Refusing to deserialize unknown parameter '{key}' for '{cls.__name__}'." + ) try: imported_class = import_class_by_name(type_value) if hasattr(imported_class, "from_dict") and callable(imported_class.from_dict): @@ -311,6 +331,30 @@ def default_from_dict(cls: type[T], data: dict[str, Any]) -> T: return cls(**init_params) +def _init_parameter_names(cls: type[object]) -> set[str] | None: + """ + Return the set of init parameter names accepted by `cls`. + + Returns `None` if the constructor accepts arbitrary keyword arguments (`**kwargs`) — in + which case we cannot validate keys. + """ + try: + signature = inspect.signature(cls.__init__) + except (TypeError, ValueError): + return None + names: set[str] = set() + for name, param in signature.parameters.items(): + if name == "self": + continue + if param.kind is inspect.Parameter.VAR_KEYWORD: + # Constructor accepts **kwargs; we cannot tell whether `key` is a real parameter. + return None + if param.kind is inspect.Parameter.VAR_POSITIONAL: + continue + names.add(name) + return names + + def import_class_by_name(fully_qualified_name: str) -> type[object]: """ Utility function to import (load) a class object based on its fully qualified class name. @@ -319,12 +363,18 @@ def import_class_by_name(fully_qualified_name: str) -> type[object]: It splits the name into module path and class name, imports the module, and returns the class object. + For security, the module path is checked against the deserialization allowlist + (see :mod:`haystack.core.serialization_security`). Modules outside the allowlist + are rejected with a :class:`DeserializationError`. + :param fully_qualified_name: the fully qualified class name as a string :returns: the class object. :raises ImportError: If the class cannot be imported or found. + :raises DeserializationError: If the module is not on the deserialization allowlist. """ + module_path, class_name = fully_qualified_name.rsplit(".", 1) + _check_module_allowed(module_path) try: - module_path, class_name = fully_qualified_name.rsplit(".", 1) logger.debug( "Attempting to import class '{cls_name}' from module '{md_path}'", cls_name=class_name, md_path=module_path ) diff --git a/haystack/core/serialization_security.py b/haystack/core/serialization_security.py new file mode 100644 index 0000000000..f036b9d50d --- /dev/null +++ b/haystack/core/serialization_security.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Security primitives for pipeline deserialization. + +This module provides an allowlist mechanism that gates arbitrary imports. + +Three ways to extend the allowlist: +- Per-call kwarg: `Pipeline.load(..., allowed_modules=["mypkg.*"])` +- Process-wide programmatic API: :func:`allow_deserialization_module` +- Environment variable: `HAYSTACK_DESERIALIZATION_ALLOWLIST="mypkg.*,otherpkg.*"` + +The two-mode loading API (`unsafe=True`) bypasses the allowlist entirely. +""" + +import contextvars +import fnmatch +import os +from collections.abc import Iterable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass, field + +from haystack.core.errors import DeserializationError + +# The default allowlist covers Haystack's own packages plus a small set of standard-library type modules +# that are commonly referenced in serialized type annotations (e.g. `typing.List[str]`, +# `collections.deque`). Importing these modules has no meaningful side effects on its own. +DEFAULT_ALLOWED_MODULES: tuple[str, ...] = ( + "haystack", + "haystack_integrations", + "haystack_experimental", + "builtins", + "typing", + "collections", +) +DESERIALIZATION_ALLOWLIST_ENV_VAR = "HAYSTACK_DESERIALIZATION_ALLOWLIST" + + +@dataclass(frozen=True) +class _DeserializationContext: + extra_allowed: tuple[str, ...] = field(default_factory=tuple) + unsafe: bool = False + + +_current_context: contextvars.ContextVar[_DeserializationContext | None] = contextvars.ContextVar( + "haystack_deserialization_context", default=None +) + + +def _get_context() -> _DeserializationContext: + ctx = _current_context.get() + return ctx if ctx is not None else _DeserializationContext() + + +# Process-wide patterns set via allow_deserialization_module. +_extra_allowed_modules: list[str] = [] + + +def allow_deserialization_module(pattern: str) -> None: + """ + Add a module pattern to the process-wide deserialization allowlist. + + Once added, classes from modules matching the pattern can be deserialized from YAML / dict + representations until the process exits. + + A pattern matches a module name if: + - The pattern contains `*`, `?` or `[` — :mod:`fnmatch` semantics are used. + - Otherwise the pattern is treated as a prefix: a module matches if it equals the pattern or + is a submodule of it (i.e. starts with `pattern + "."`). A trailing `.*` is stripped + before this comparison, so `"mypkg"` and `"mypkg.*"` behave identically. + + :param pattern: + The module pattern to allow. + """ + if pattern not in _extra_allowed_modules: + _extra_allowed_modules.append(pattern) + + +def _module_matches(module_name: str, pattern: str) -> bool: + """Return whether `module_name` matches the given allowlist `pattern`.""" + # `pkg.*` is treated as a prefix match (matches `pkg` and any submodule); this is the most + # common form, and we want it to match the bare top-level package too, which fnmatch wouldn't. + if pattern.endswith(".*"): + pattern = pattern[:-2] + return module_name == pattern or module_name.startswith(pattern + ".") + if any(c in pattern for c in "*?["): + return fnmatch.fnmatchcase(module_name, pattern) + return module_name == pattern or module_name.startswith(pattern + ".") + + +def _patterns_from_env() -> list[str]: + raw = os.environ.get(DESERIALIZATION_ALLOWLIST_ENV_VAR, "") + return [p.strip() for p in raw.split(",") if p.strip()] + + +def _is_module_allowed(module_name: str) -> bool: + """Return whether `module_name` is on the active deserialization allowlist.""" + ctx = _get_context() + if ctx.unsafe: + return True + patterns: list[str] = [] + patterns.extend(DEFAULT_ALLOWED_MODULES) + patterns.extend(_extra_allowed_modules) + patterns.extend(_patterns_from_env()) + patterns.extend(ctx.extra_allowed) + return any(_module_matches(module_name, p) for p in patterns) + + +def _check_module_allowed(module_name: str) -> None: + """Raise :class:`DeserializationError` if `module_name` is not on the allowlist.""" + if _is_module_allowed(module_name): + return + raise DeserializationError( + f"Refusing to deserialize a class from module '{module_name}': the module is not on the " + f"trusted-module allowlist. If you trust the source of this serialized data, you can either:\n" + f" - extend the allowlist for this call: " + f"Pipeline.load(..., allowed_modules=['{module_name}']),\n" + f" - extend it process-wide via haystack.core.serialization.allow_deserialization_module" + f"('{module_name}') or the {DESERIALIZATION_ALLOWLIST_ENV_VAR} environment variable,\n" + f" - or bypass the allowlist entirely: Pipeline.load(..., unsafe=True)." + ) + + +@contextmanager +def _deserialization_context(allowed_modules: Iterable[str] | None = None, unsafe: bool = False) -> Iterator[None]: + """ + Context manager that activates a per-call deserialization context. + + Patterns from `allowed_modules` are appended to the parent context's patterns, and `unsafe` + is OR-ed with the parent's `unsafe` flag — so this never narrows the active permissions. + The previous context is restored on exit. + """ + parent = _get_context() + extra = parent.extra_allowed + (tuple(allowed_modules) if allowed_modules else ()) + merged_unsafe = parent.unsafe or unsafe + token = _current_context.set(_DeserializationContext(extra_allowed=extra, unsafe=merged_unsafe)) + try: + yield + finally: + _current_context.reset(token) diff --git a/haystack/utils/callable_serialization.py b/haystack/utils/callable_serialization.py index f242c39a3a..7286bc61d9 100644 --- a/haystack/utils/callable_serialization.py +++ b/haystack/utils/callable_serialization.py @@ -7,6 +7,7 @@ from typing import Any from haystack.core.errors import DeserializationError, SerializationError +from haystack.core.serialization_security import _check_module_allowed from haystack.utils.type_serialization import thread_safe_import @@ -47,9 +48,18 @@ def deserialize_callable(callable_handle: str) -> Callable: """ Deserializes a callable given its full import path as a string. + Every module path tried during resolution is checked against the + deserialization allowlist (see `haystack.core.serialization_security`). Callables in modules + outside the allowlist are rejected with a `DeserializationError` before any import is + attempted. To allow a third-party module, extend the allowlist via + `Pipeline.load(..., allowed_modules=[...])`, `allow_deserialization_module(...)`, or the + `HAYSTACK_DESERIALIZATION_ALLOWLIST` environment variable. + :param callable_handle: The full path of the callable_handle :return: The callable - :raises DeserializationError: If the callable cannot be found + :raises DeserializationError: + If the module path is not on the deserialization allowlist, or if the callable cannot + be found. """ # Import here to avoid circular imports from haystack.tools.tool import Tool @@ -58,6 +68,7 @@ def deserialize_callable(callable_handle: str) -> Callable: for i in range(len(parts), 0, -1): module_name = ".".join(parts[:i]) + _check_module_allowed(module_name) try: mod: Any = thread_safe_import(module_name) except Exception: diff --git a/haystack/utils/type_serialization.py b/haystack/utils/type_serialization.py index 973b105ef4..6699889f70 100644 --- a/haystack/utils/type_serialization.py +++ b/haystack/utils/type_serialization.py @@ -12,6 +12,7 @@ from typing import Any, Union, get_args from haystack.core.errors import DeserializationError +from haystack.core.serialization_security import _check_module_allowed _import_lock = Lock() @@ -147,12 +148,18 @@ def deserialize_type(type_str: str) -> Any: and then retrieve the type object from it. It also handles nested generic types like `list[dict[int, str]]`. + Every module path with a `.` prefix is checked against the deserialization + allowlist (see `haystack.core.serialization_security`) before being imported. Modules outside + the allowlist are rejected with a `DeserializationError`. Builtin and `typing`/`collections` + names without a module prefix bypass this check. + :param type_str: The string representation of the type's full import path. :returns: The deserialized type object. :raises DeserializationError: - If the type cannot be deserialized due to missing module or type. + If the module is not on the deserialization allowlist, or if the type cannot be + deserialized due to a missing module or type. """ # Handle PEP 604 union syntax at the top level (e.g., "str | int", "str | None") pep604_union_args = _parse_pep604_union_args(type_str) @@ -181,6 +188,8 @@ def deserialize_type(type_str: str) -> Any: module_name = ".".join(parts[:-1]) type_name = parts[-1] + _check_module_allowed(module_name) + module = sys.modules.get(module_name) if module is None: try: diff --git a/releasenotes/notes/pipeline-deserialization-allowlist-d878a69ac69ee667.yaml b/releasenotes/notes/pipeline-deserialization-allowlist-d878a69ac69ee667.yaml new file mode 100644 index 0000000000..6c1d205a45 --- /dev/null +++ b/releasenotes/notes/pipeline-deserialization-allowlist-d878a69ac69ee667.yaml @@ -0,0 +1,52 @@ +--- +upgrade: + - | + Pipeline deserialization (``Pipeline.load`` / ``Pipeline.loads`` / ``Pipeline.from_dict``) now + refuses to import classes from modules outside a trusted-module allowlist. By default the + allowlist contains ``haystack``, ``haystack_integrations``, ``haystack_experimental``, + ``builtins``, ``typing``, and ``collections``. Pipelines that reference custom components or callables in + other packages will fail to load with a ``DeserializationError`` until the additional modules + are added to the allowlist. + + To restore loading of such pipelines, extend the allowlist via one of: + + - per-call kwarg: ``Pipeline.load(fp, allowed_modules=["mypkg.*"])``; + - process-wide programmatic API: + ``from haystack.core.serialization import allow_deserialization_module``; + - environment variable: ``HAYSTACK_DESERIALIZATION_ALLOWLIST="mypkg.*,otherpkg.*"``; + - or, if the source of the serialized data is fully trusted, bypass the allowlist with + ``Pipeline.load(fp, unsafe=True)``. + - | + During pipeline deserialization, ``default_from_dict`` now validates the keys of + ``init_parameters`` against the parent class's ``__init__`` signature before recursing into + any nested ``{"type": "...", "init_parameters": {...}}`` dictionary. A nested dict whose key + is not an accepted parameter name on the parent class is rejected with a ``DeserializationError`` + *before* the nested type is imported. Classes whose constructor takes ``**kwargs`` are + exempt, since their accepted parameter set cannot be statically determined. + + This may surface pre-existing YAML bugs — e.g. typos, leftovers from renamed or removed + parameters, or stale snapshots from older Haystack versions. The fix is to update the YAML + so each nested-component key matches a real ``__init__`` parameter on the parent class. +security: + - | + Loading a pipeline from YAML used to dynamically import any class referenced in the file, + which made a crafted YAML capable of causing arbitrary classes to be imported and instantiated + by ``Pipeline.from_dict``. Deserialization is now gated by an allowlist of trusted module + namespaces (see the upgrade note for how to extend it). In addition, + ``default_from_dict`` now refuses to recurse into nested ``{"type": "..."}`` dictionaries + keyed by parameter names that the parent class does not actually accept — blocking attempts + to smuggle untrusted classes into unused parameter slots. +features: + - | + ``Pipeline.load`` / ``Pipeline.loads`` / ``Pipeline.from_dict`` now accept two new keyword + arguments: + + - ``allowed_modules: list[str] | None`` — additional module patterns that may be imported + during this deserialization call. Patterns are matched as prefixes (``"mypkg"`` matches + ``mypkg`` and any submodule) or, if they contain ``*``/``?``/``[``, using ``fnmatch`` rules. + - ``unsafe: bool = False`` — when set to ``True``, the deserialization allowlist is bypassed + entirely. Only use this when you fully trust the source of the serialized data. + + A new public helper, ``haystack.core.serialization.allow_deserialization_module(pattern)``, + extends the process-wide allowlist. The ``HAYSTACK_DESERIALIZATION_ALLOWLIST`` environment + variable (comma-separated patterns) is read at runtime on every deserialization call. diff --git a/test/components/caching/test_url_cache_checker.py b/test/components/caching/test_url_cache_checker.py index 7501e26460..f3330c0b62 100644 --- a/test/components/caching/test_url_cache_checker.py +++ b/test/components/caching/test_url_cache_checker.py @@ -60,11 +60,17 @@ def test_from_dict_without_docstore(self): CacheChecker.from_dict(data) def test_from_dict_nonexisting_docstore(self): + # Use a type whose module passes the deserialization allowlist (haystack.*) but cannot be + # resolved, so we still exercise the "import failed" code path rather than the allowlist gate. data = { "type": "haystack.components.caching.cache_checker.CacheChecker", - "init_parameters": {"document_store": {"type": "Nonexisting.DocumentStore", "init_parameters": {}}}, + "init_parameters": { + "document_store": {"type": "haystack.does.not.exist.DocumentStore", "init_parameters": {}} + }, } - with pytest.raises(ImportError, match=r"Failed to deserialize 'document_store':.*Nonexisting\.DocumentStore"): + with pytest.raises( + ImportError, match=r"Failed to deserialize 'document_store':.*haystack\.does\.not\.exist\.DocumentStore" + ): CacheChecker.from_dict(data) def test_run(self, in_memory_doc_store): diff --git a/test/components/retrievers/test_in_memory_bm25_retriever.py b/test/components/retrievers/test_in_memory_bm25_retriever.py index 5946ce9b7c..42cbcb56ce 100644 --- a/test/components/retrievers/test_in_memory_bm25_retriever.py +++ b/test/components/retrievers/test_in_memory_bm25_retriever.py @@ -118,11 +118,15 @@ def test_from_dict_without_docstore_type(self): InMemoryBM25Retriever.from_dict(data) def test_from_dict_nonexisting_docstore(self): + # Use a type whose module passes the deserialization allowlist (haystack.*) but cannot be + # resolved, so we still exercise the "import failed" code path rather than the allowlist gate. data = { "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever", - "init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}}, + "init_parameters": {"document_store": {"type": "haystack.does.not.exist.Docstore", "init_parameters": {}}}, } - with pytest.raises(ImportError, match=r"Failed to deserialize 'document_store':.*Nonexisting\.Docstore"): + with pytest.raises( + ImportError, match=r"Failed to deserialize 'document_store':.*haystack\.does\.not\.exist\.Docstore" + ): InMemoryBM25Retriever.from_dict(data) def test_retriever_valid_run(self, in_memory_doc_store, mock_docs): diff --git a/test/components/retrievers/test_in_memory_embedding_retriever.py b/test/components/retrievers/test_in_memory_embedding_retriever.py index a14d699f32..024495c312 100644 --- a/test/components/retrievers/test_in_memory_embedding_retriever.py +++ b/test/components/retrievers/test_in_memory_embedding_retriever.py @@ -113,11 +113,15 @@ def test_from_dict_without_docstore_type(self): InMemoryEmbeddingRetriever.from_dict(data) def test_from_dict_nonexisting_docstore(self): + # Use a type whose module passes the deserialization allowlist (haystack.*) but cannot be + # resolved, so we still exercise the "import failed" code path rather than the allowlist gate. data = { "type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever", - "init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}}, + "init_parameters": {"document_store": {"type": "haystack.does.not.exist.Docstore", "init_parameters": {}}}, } - with pytest.raises(ImportError, match=r"Failed to deserialize 'document_store':.*Nonexisting\.Docstore"): + with pytest.raises( + ImportError, match=r"Failed to deserialize 'document_store':.*haystack\.does\.not\.exist\.Docstore" + ): InMemoryEmbeddingRetriever.from_dict(data) def test_valid_run(self): diff --git a/test/components/writers/test_document_writer.py b/test/components/writers/test_document_writer.py index 9bd4c90f44..9c12c9d992 100644 --- a/test/components/writers/test_document_writer.py +++ b/test/components/writers/test_document_writer.py @@ -71,11 +71,17 @@ def test_from_dict_without_docstore(self): DocumentWriter.from_dict(data) def test_from_dict_nonexisting_docstore(self): + # Use a type whose module passes the deserialization allowlist (haystack.*) but cannot be + # resolved, so we still exercise the "import failed" code path rather than the allowlist gate. data = { "type": "haystack.components.writers.document_writer.DocumentWriter", - "init_parameters": {"document_store": {"type": "Nonexisting.DocumentStore", "init_parameters": {}}}, + "init_parameters": { + "document_store": {"type": "haystack.does.not.exist.DocumentStore", "init_parameters": {}} + }, } - with pytest.raises(ImportError, match=r"Failed to deserialize 'document_store':.*Nonexisting\.DocumentStore"): + with pytest.raises( + ImportError, match=r"Failed to deserialize 'document_store':.*haystack\.does\.not\.exist\.DocumentStore" + ): DocumentWriter.from_dict(data) def test_run(self, in_memory_doc_store): diff --git a/test/conftest.py b/test/conftest.py index 796880688e..a46a61308a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -12,6 +12,7 @@ import pytest from haystack import component, tracing +from haystack.core.serialization import allow_deserialization_module from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.testing.test_utils import set_all_seeds from test.tracing.utils import SpyingTracer @@ -23,6 +24,21 @@ tracing.disable_tracing() +# Tests legitimately deserialize callables/components/types from a handful of modules that aren't +# part of the default Haystack allowlist. We extend the allowlist explicitly. +# +# Tests that exercise the rejection path themselves install a clean context (and clear the +# process-wide patterns); see `test/core/test_serialization_security.py`. +for _pattern in ( + "test_*", # top-level `test_` modules (pytest rootdir-level files) + "*.test_*", # `.test_` modules (pytest treats sub-packages this way) + "test.*", # modules inside the proper `test` package (with __init__.py) + "pydantic", # pydantic models used in base-serialization tests + "httpx", # used in callable-serialization tests +): + allow_deserialization_module(_pattern) + + @pytest.fixture() def in_memory_doc_store(): store = InMemoryDocumentStore() diff --git a/test/core/pipeline/test_pipeline_base.py b/test/core/pipeline/test_pipeline_base.py index ccc2ed532b..0aaecfd2ff 100644 --- a/test/core/pipeline/test_pipeline_base.py +++ b/test/core/pipeline/test_pipeline_base.py @@ -626,9 +626,11 @@ def test_from_dict_without_component_type(self): err.match("Missing 'type' in component 'add_two'") def test_from_dict_without_registered_component_type(self): + # A component type whose module passes the allowlist but cannot be imported should + # surface as a `PipelineError` ("not imported"). data = { "metadata": {"test": "test"}, - "components": {"add_two": {"type": "foo.bar.baz", "init_parameters": {"add": 2}}}, + "components": {"add_two": {"type": "haystack.does.not.exist.Component", "init_parameters": {"add": 2}}}, "connections": [], } with pytest.raises(PipelineError) as err: @@ -636,6 +638,44 @@ def test_from_dict_without_registered_component_type(self): err.match(r"Component .+ not imported.") + def test_from_dict_rejects_untrusted_component_module(self): + data = { + "metadata": {"test": "test"}, + "components": {"add_two": {"type": "foo.bar.baz", "init_parameters": {"add": 2}}}, + "connections": [], + } + with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): + PipelineBase.from_dict(data) + + def test_from_dict_with_unsafe_bypasses_allowlist(self): + # `unsafe=True` bypasses the allowlist but the import itself still fails because the module + # is nonexistent — proving that the allowlist check (not the import) is what changes. + data = { + "metadata": {"test": "test"}, + "components": {"add_two": {"type": "foo.bar.baz", "init_parameters": {"add": 2}}}, + "connections": [], + } + # Sanity check: without ``unsafe=True`` we'd get the allowlist rejection. + with pytest.raises(DeserializationError): + PipelineBase.from_dict(data) + # With ``unsafe=True`` the allowlist is bypassed; we fall through to a normal import error. + with pytest.raises(PipelineError, match="not imported"): + PipelineBase.from_dict(data, unsafe=True) + + def test_from_dict_with_allowed_modules_kwarg(self): + # Passing the third-party module via `allowed_modules` should make the allowlist check pass. + data = { + "metadata": {"test": "test"}, + "components": {"add_two": {"type": "foo.bar.baz", "init_parameters": {"add": 2}}}, + "connections": [], + } + # Without an extension, the allowlist rejects the module. + with pytest.raises(DeserializationError): + PipelineBase.from_dict(data) + # Passing the matching pattern lets us hit the actual import failure instead. + with pytest.raises(PipelineError, match="not imported"): + PipelineBase.from_dict(data, allowed_modules=["foo.*"]) + def test_from_dict_with_invalid_type(self): data = { "metadata": {"test": "test"}, diff --git a/test/core/test_serialization.py b/test/core/test_serialization.py index 6cab3d5eae..dce5a7f89e 100644 --- a/test/core/test_serialization.py +++ b/test/core/test_serialization.py @@ -108,11 +108,18 @@ def test_import_class_by_name(): def test_import_class_by_name_no_valid_class(): - data = "some.invalid.class" + # A name that passes the deserialization allowlist but cannot be resolved should raise ImportError. + data = "haystack.does.not.exist.Class" with pytest.raises(ImportError): import_class_by_name(data) +def test_import_class_by_name_rejects_untrusted_module(): + # A module outside the default allowlist is rejected before the import is attempted. + with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): + import_class_by_name("some.invalid.class") + + class CustomData: def __init__(self, a: int, b: str) -> None: self.a = a @@ -557,10 +564,43 @@ def test_default_from_dict_with_invalid_class_name(): data = { "type": generate_qualified_class_name(CustomComponentWithDocumentStore), "init_parameters": { - "document_store": {"type": "nonexistent.module.Class", "init_parameters": {}}, + # Use a class name that passes the allowlist (haystack.*) but cannot be resolved. + "document_store": {"type": "haystack.does.not.exist.Class", "init_parameters": {}}, "name": "test", }, } # Verify the error message includes the parameter key and original error - with pytest.raises(ImportError, match=r"Failed to deserialize 'document_store':.*nonexistent\.module\.Class"): + with pytest.raises( + ImportError, match=r"Failed to deserialize 'document_store':.*haystack\.does\.not\.exist\.Class" + ): + default_from_dict(CustomComponentWithDocumentStore, data) + + +def test_default_from_dict_rejects_untrusted_nested_class(): + """A nested class with a module outside the allowlist should be rejected.""" + data = { + "type": generate_qualified_class_name(CustomComponentWithDocumentStore), + "init_parameters": { + "document_store": {"type": "nonexistent.module.Class", "init_parameters": {}}, + "name": "test", + }, + } + with pytest.raises( + DeserializationError, match=r"Failed to deserialize 'document_store':.*not on the trusted-module allowlist" + ): + default_from_dict(CustomComponentWithDocumentStore, data) + + +def test_default_from_dict_rejects_unknown_nested_parameter(): + """A nested ``{type: ...}`` dict on a parameter that the class does not accept must be rejected + before the smuggled type is imported (Option 3: type-aware deserialization).""" + data = { + "type": generate_qualified_class_name(CustomComponentWithDocumentStore), + "init_parameters": { + # `payload` is not an init parameter of CustomComponentWithDocumentStore. + "payload": {"type": "haystack.testing.factory.MyComponent", "init_parameters": {}}, + "name": "test", + }, + } + with pytest.raises(DeserializationError, match=r"Refusing to deserialize unknown parameter 'payload'"): default_from_dict(CustomComponentWithDocumentStore, data) diff --git a/test/core/test_serialization_security.py b/test/core/test_serialization_security.py new file mode 100644 index 0000000000..c6619d23f6 --- /dev/null +++ b/test/core/test_serialization_security.py @@ -0,0 +1,339 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from haystack.core.errors import DeserializationError +from haystack.core.serialization import allow_deserialization_module, import_class_by_name +from haystack.core.serialization_security import ( + DESERIALIZATION_ALLOWLIST_ENV_VAR, + _check_module_allowed, + _current_context, + _deserialization_context, + _DeserializationContext, + _extra_allowed_modules, + _is_module_allowed, + _module_matches, +) + + +@pytest.fixture(autouse=True) +def _reset_allowlist_state(): + """ + Force a clean (safe-default, no extra patterns) state for every test in this module so we are + testing the actual security model. The top-level test conftest extends the process-wide + allowlist with test-only patterns (`test_*`, `pydantic`, ...); we must clear those here so + "untrusted" really means untrusted. + """ + snapshot = list(_extra_allowed_modules) + _extra_allowed_modules.clear() + token = _current_context.set(_DeserializationContext()) + try: + yield + finally: + _extra_allowed_modules.clear() + _extra_allowed_modules.extend(snapshot) + _current_context.reset(token) + + +class TestModuleMatches: + def test_prefix_match_equal(self): + assert _module_matches("haystack", "haystack") + + def test_prefix_match_submodule(self): + assert _module_matches("haystack.components.builders", "haystack") + + def test_prefix_match_strips_trailing_wildcard(self): + assert _module_matches("haystack.components", "haystack.*") + assert _module_matches("haystack", "haystack.*") + + def test_prefix_match_not_a_partial_word(self): + assert not _module_matches("haystack_other", "haystack") + + def test_trailing_star_matches_submodules(self): + assert _module_matches("mypkg.components.foo", "mypkg.*") + assert _module_matches("mypkg.foo.bar", "mypkg.*") + + def test_trailing_star_does_not_match_unrelated(self): + assert not _module_matches("other.foo", "mypkg.*") + + def test_fnmatch_glob_in_middle(self): + assert _module_matches("pkg.foo.utils", "pkg.*.utils") + assert _module_matches("pkg.bar.utils", "pkg.*.utils") + + def test_fnmatch_glob_in_middle_no_match(self): + assert not _module_matches("pkg.foo.helpers", "pkg.*.utils") + + def test_fnmatch_single_char(self): + # `?` is an fnmatch wildcard for a single character. + assert _module_matches("pkga", "pkg?") + assert not _module_matches("pkgab", "pkg?") + + def test_fnmatch_character_class(self): + assert _module_matches("data_3", "data_[0-9]") + assert not _module_matches("data_x", "data_[0-9]") + + +class TestAllowlistDefaults: + def test_haystack_allowed(self): + assert _is_module_allowed("haystack") + assert _is_module_allowed("haystack.components.builders.prompt_builder") + + def test_haystack_integrations_allowed(self): + assert _is_module_allowed("haystack_integrations.components.retrievers") + + def test_haystack_experimental_allowed(self): + assert _is_module_allowed("haystack_experimental") + + def test_typing_allowed(self): + assert _is_module_allowed("typing") + + def test_collections_allowed(self): + assert _is_module_allowed("collections") + assert _is_module_allowed("collections.abc") + + def test_builtins_allowed(self): + assert _is_module_allowed("builtins") + + def test_arbitrary_third_party_not_allowed(self): + assert not _is_module_allowed("subprocess") + assert not _is_module_allowed("os") + + +class TestAllowDeserializationModule: + def test_extends_allowlist(self): + assert not _is_module_allowed("mypkg.components") + allow_deserialization_module("mypkg") + assert _is_module_allowed("mypkg") + assert _is_module_allowed("mypkg.components") + + def test_pattern_with_wildcard(self): + allow_deserialization_module("mypkg.components.*") + assert _is_module_allowed("mypkg.components.foo") + + def test_duplicate_pattern_only_added_once(self): + allow_deserialization_module("mypkg") + allow_deserialization_module("mypkg") + assert _extra_allowed_modules.count("mypkg") == 1 + + +class TestDeserializationContext: + def test_extra_allowed_modules_via_context(self): + assert not _is_module_allowed("mypkg.thing") + with _deserialization_context(allowed_modules=["mypkg"]): + assert _is_module_allowed("mypkg.thing") + # The per-call extension is reset on exit. + assert not _is_module_allowed("mypkg.thing") + + def test_unsafe_bypasses_allowlist(self): + assert not _is_module_allowed("subprocess") + with _deserialization_context(unsafe=True): + assert _is_module_allowed("subprocess") + assert _is_module_allowed("any.arbitrary.module") + assert not _is_module_allowed("subprocess") + + +class TestEnvVar: + def test_env_var_extends_allowlist(self, monkeypatch): + monkeypatch.setenv(DESERIALIZATION_ALLOWLIST_ENV_VAR, "mypkg.components.*,otherpkg") + assert _is_module_allowed("mypkg.components.foo") + assert _is_module_allowed("otherpkg") + assert _is_module_allowed("otherpkg.sub") + assert not _is_module_allowed("yetanother") + + def test_env_var_ignores_empty_entries(self, monkeypatch): + monkeypatch.setenv(DESERIALIZATION_ALLOWLIST_ENV_VAR, ", ,mypkg,,") + assert _is_module_allowed("mypkg.sub") + + +class TestCheckModuleAllowed: + def test_passes_silently_for_allowed_module(self): + _check_module_allowed("haystack.foo") + + def test_raises_for_disallowed_module(self): + with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): + _check_module_allowed("subprocess") + + def test_error_message_suggests_remediations(self): + with pytest.raises(DeserializationError) as exc_info: + _check_module_allowed("mypkg.evil") + message = str(exc_info.value) + assert "allowed_modules" in message + assert "allow_deserialization_module" in message + assert DESERIALIZATION_ALLOWLIST_ENV_VAR in message + assert "unsafe=True" in message + + +class TestImportClassByNameAllowlist: + def test_allowlisted_class(self): + cls = import_class_by_name("haystack.core.pipeline.Pipeline") + from haystack.core.pipeline import Pipeline + + assert cls is Pipeline + + def test_rejects_untrusted_module(self): + with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): + import_class_by_name("subprocess.Popen") + + def test_per_call_extension(self): + # subprocess is normally blocked + with pytest.raises(DeserializationError): + import_class_by_name("subprocess.Popen") + # ... but extending the allowlist for a single call lets it through. + with _deserialization_context(allowed_modules=["subprocess"]): + cls = import_class_by_name("subprocess.Popen") + import subprocess + + assert cls is subprocess.Popen + + +@pytest.fixture +def _registered_untrusted_component(): + """ + Set up a fake component class registered under a fully-qualified name in an untrusted module + (`evilpkg.evilmod.EvilComponent`). Yields a dict payload referencing it. The fixture cleans + up the registry on teardown. + """ + from haystack import component as component_module + + fake_type = "evilpkg.evilmod.EvilComponent" + + @component_module + class EvilComponent: + @component_module.output_types(value=int) + def run(self, value: int) -> dict[str, int]: + return {"value": value} + + registry = component_module.registry + original = registry.get(fake_type) + registry[fake_type] = EvilComponent + try: + yield { + "fake_type": fake_type, + "data": { + "metadata": {}, + "components": {"evil": {"type": fake_type, "init_parameters": {}}}, + "connections": [], + }, + } + finally: + if original is None: + registry.pop(fake_type, None) + else: + registry[fake_type] = original + + +class TestPipelineFromDictAllowlistBypass: + def test_pre_registered_untrusted_component_is_rejected(self, _registered_untrusted_component): + from haystack.core.pipeline import Pipeline + + with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): + Pipeline.from_dict(_registered_untrusted_component["data"]) + + def test_pre_registered_component_loadable_with_allowed_modules(self, _registered_untrusted_component): + """ + Counterpart to the bypass test: once the user opts the module into the allowlist, the + load gets past the allowlist gate. (It still fails downstream because the fake type name + doesn't match the test class's real qualified name — that's expected and proves the + allowlist gate, not a downstream check, is what changed.) + """ + from haystack.core.pipeline import Pipeline + + data = _registered_untrusted_component["data"] + # Without allowed_modules, this is rejected as untrusted. + with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): + Pipeline.from_dict(data) + # With the matching pattern, the allowlist gate passes; the failure now comes from + # the qualified-name mismatch in default_from_dict — a downstream check. + with pytest.raises(DeserializationError, match="can't be deserialized as"): + Pipeline.from_dict(data, allowed_modules=["evilpkg.*"]) + + +class TestPipelineLoadAndLoadsPropagation: + """ + Verify that the security kwargs added to `Pipeline.from_dict` are propagated correctly + through the `Pipeline.loads` (string) and `Pipeline.load` (file-like) entry points, and that + they produce equivalent behavior to calling `from_dict` directly. + """ + + @staticmethod + def _yaml_for(data: dict) -> str: + # We can't round-trip through `Pipeline.from_dict` + `dumps` because the registered + # `EvilComponent`'s real qualified name doesn't match the fake type — the inner + # `default_from_dict` would reject it. Build the YAML directly via the marshaller instead. + from haystack.marshal import YamlMarshaller + + return YamlMarshaller().marshal(data) + + def test_loads_rejects_untrusted_by_default(self, _registered_untrusted_component): + from haystack.core.pipeline import Pipeline + + yaml_str = self._yaml_for(_registered_untrusted_component["data"]) + with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): + Pipeline.loads(yaml_str) + + def test_loads_propagates_allowed_modules(self, _registered_untrusted_component): + from haystack.core.pipeline import Pipeline + + yaml_str = self._yaml_for(_registered_untrusted_component["data"]) + # With the matching pattern, the allowlist gate passes; downstream we get the type + # mismatch — proving the kwarg reached the gate. + with pytest.raises(DeserializationError, match="can't be deserialized as"): + Pipeline.loads(yaml_str, allowed_modules=["evilpkg.*"]) + + def test_loads_propagates_unsafe(self, _registered_untrusted_component): + from haystack.core.pipeline import Pipeline + + yaml_str = self._yaml_for(_registered_untrusted_component["data"]) + # `unsafe=True` bypasses the allowlist entirely; downstream we still get the type mismatch. + with pytest.raises(DeserializationError, match="can't be deserialized as"): + Pipeline.loads(yaml_str, unsafe=True) + + def test_load_rejects_untrusted_by_default(self, _registered_untrusted_component): + import io + + from haystack.core.pipeline import Pipeline + + yaml_str = self._yaml_for(_registered_untrusted_component["data"]) + with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): + Pipeline.load(io.StringIO(yaml_str)) + + def test_load_propagates_allowed_modules(self, _registered_untrusted_component): + import io + + from haystack.core.pipeline import Pipeline + + yaml_str = self._yaml_for(_registered_untrusted_component["data"]) + with pytest.raises(DeserializationError, match="can't be deserialized as"): + Pipeline.load(io.StringIO(yaml_str), allowed_modules=["evilpkg.*"]) + + def test_load_propagates_unsafe(self, _registered_untrusted_component): + import io + + from haystack.core.pipeline import Pipeline + + yaml_str = self._yaml_for(_registered_untrusted_component["data"]) + with pytest.raises(DeserializationError, match="can't be deserialized as"): + Pipeline.load(io.StringIO(yaml_str), unsafe=True) + + def test_load_loads_from_dict_equivalent_on_rejection(self, _registered_untrusted_component): + """All three entry points produce the same rejection message for the same untrusted payload.""" + import io + + from haystack.core.pipeline import Pipeline + + data = _registered_untrusted_component["data"] + yaml_str = self._yaml_for(data) + + def _capture(callable_) -> str: + with pytest.raises(DeserializationError) as exc_info: + callable_() + return str(exc_info.value) + + from_dict_msg = _capture(lambda: Pipeline.from_dict(data)) + loads_msg = _capture(lambda: Pipeline.loads(yaml_str)) + load_msg = _capture(lambda: Pipeline.load(io.StringIO(yaml_str))) + + assert "not on the trusted-module allowlist" in from_dict_msg + assert from_dict_msg == loads_msg == load_msg From 9e75eae927dd1546d615cfcdb7350e316f247b33 Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Fri, 29 May 2026 14:37:42 +0200 Subject: [PATCH 2/3] fix mypy --- test/core/test_serialization_security.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/core/test_serialization_security.py b/test/core/test_serialization_security.py index c6619d23f6..5953666012 100644 --- a/test/core/test_serialization_security.py +++ b/test/core/test_serialization_security.py @@ -320,13 +320,14 @@ def test_load_propagates_unsafe(self, _registered_untrusted_component): def test_load_loads_from_dict_equivalent_on_rejection(self, _registered_untrusted_component): """All three entry points produce the same rejection message for the same untrusted payload.""" import io + from collections.abc import Callable from haystack.core.pipeline import Pipeline data = _registered_untrusted_component["data"] yaml_str = self._yaml_for(data) - def _capture(callable_) -> str: + def _capture(callable_: Callable[[], object]) -> str: with pytest.raises(DeserializationError) as exc_info: callable_() return str(exc_info.value) From 0d9789b2b79354bf682b6952f5aaac6895a7955f Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Sat, 30 May 2026 12:31:25 +0200 Subject: [PATCH 3/3] honor fnmatch patterns on module prefixes in the deserialization allowlist --- haystack/core/serialization_security.py | 13 ++-- haystack/utils/callable_serialization.py | 8 ++- test/core/test_serialization_security.py | 75 +++++++++++++----------- 3 files changed, 54 insertions(+), 42 deletions(-) diff --git a/haystack/core/serialization_security.py b/haystack/core/serialization_security.py index f036b9d50d..7041acc81b 100644 --- a/haystack/core/serialization_security.py +++ b/haystack/core/serialization_security.py @@ -80,11 +80,14 @@ def allow_deserialization_module(pattern: str) -> None: def _module_matches(module_name: str, pattern: str) -> bool: """Return whether `module_name` matches the given allowlist `pattern`.""" - # `pkg.*` is treated as a prefix match (matches `pkg` and any submodule); this is the most - # common form, and we want it to match the bare top-level package too, which fnmatch wouldn't. - if pattern.endswith(".*"): - pattern = pattern[:-2] - return module_name == pattern or module_name.startswith(pattern + ".") + # `pkg.*` (where the part before `.*` has no other wildcards) is treated as a prefix match — + # matches `pkg` and any submodule. This is the most common form, and we want it to match + # the bare top-level package too (which true fnmatch wouldn't, since `pkg.*` requires a + # literal `.` to follow). Patterns like `j*on.*` keep their wildcards and fall through to + # fnmatch so the semantics stay consistent. + if pattern.endswith(".*") and not any(c in pattern[:-2] for c in "*?["): + prefix = pattern[:-2] + return module_name == prefix or module_name.startswith(prefix + ".") if any(c in pattern for c in "*?["): return fnmatch.fnmatchcase(module_name, pattern) return module_name == pattern or module_name.startswith(pattern + ".") diff --git a/haystack/utils/callable_serialization.py b/haystack/utils/callable_serialization.py index 7286bc61d9..02513b1cd1 100644 --- a/haystack/utils/callable_serialization.py +++ b/haystack/utils/callable_serialization.py @@ -7,7 +7,7 @@ from typing import Any from haystack.core.errors import DeserializationError, SerializationError -from haystack.core.serialization_security import _check_module_allowed +from haystack.core.serialization_security import _check_module_allowed, _is_module_allowed from haystack.utils.type_serialization import thread_safe_import @@ -66,9 +66,13 @@ def deserialize_callable(callable_handle: str) -> Callable: parts = callable_handle.split(".") + # Allow if any prefix is on the allowlist; checking each one individually would wrongly + # reject patterns like `j*on` against `json.dumps` (matches `json`, not the full handle). + if not any(_is_module_allowed(".".join(parts[:i])) for i in range(1, len(parts) + 1)): + _check_module_allowed(callable_handle) # raises with the standard help message + for i in range(len(parts), 0, -1): module_name = ".".join(parts[:i]) - _check_module_allowed(module_name) try: mod: Any = thread_safe_import(module_name) except Exception: diff --git a/test/core/test_serialization_security.py b/test/core/test_serialization_security.py index 5953666012..e656e79361 100644 --- a/test/core/test_serialization_security.py +++ b/test/core/test_serialization_security.py @@ -2,9 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 +import io +import json +import subprocess +from collections.abc import Callable + import pytest +from haystack import component as component_module from haystack.core.errors import DeserializationError +from haystack.core.pipeline import Pipeline from haystack.core.serialization import allow_deserialization_module, import_class_by_name from haystack.core.serialization_security import ( DESERIALIZATION_ALLOWLIST_ENV_VAR, @@ -16,6 +23,8 @@ _is_module_allowed, _module_matches, ) +from haystack.marshal import YamlMarshaller +from haystack.utils import deserialize_callable @pytest.fixture(autouse=True) @@ -74,6 +83,15 @@ def test_fnmatch_character_class(self): assert _module_matches("data_3", "data_[0-9]") assert not _module_matches("data_x", "data_[0-9]") + def test_trailing_star_with_wildcards_in_prefix_uses_fnmatch(self): + # `j*on.*` has a `*` before the trailing `.*`, so it must NOT be short-circuited to a + # prefix match against the literal `j*on`. It should fall through to fnmatch. + assert _module_matches("json.tool", "j*on.*") + assert _module_matches("jaeon.subpkg.foo", "j*on.*") + # Pure fnmatch doesn't match the bare `json` for the pattern `j*on.*` (the `.*` requires + # a `.X` part). + assert not _module_matches("json", "j*on.*") + class TestAllowlistDefaults: def test_haystack_allowed(self): @@ -168,8 +186,6 @@ def test_error_message_suggests_remediations(self): class TestImportClassByNameAllowlist: def test_allowlisted_class(self): cls = import_class_by_name("haystack.core.pipeline.Pipeline") - from haystack.core.pipeline import Pipeline - assert cls is Pipeline def test_rejects_untrusted_module(self): @@ -183,11 +199,31 @@ def test_per_call_extension(self): # ... but extending the allowlist for a single call lets it through. with _deserialization_context(allowed_modules=["subprocess"]): cls = import_class_by_name("subprocess.Popen") - import subprocess - assert cls is subprocess.Popen +class TestDeserializeCallableAllowlist: + """ + `deserialize_callable` walks progressively-shorter module prefixes when resolving a dotted + name. The allowlist check must apply to "is *any* prefix on the allowlist?", not to each + individual candidate — otherwise fnmatch patterns that match the actual module but not the + full handle (e.g. `j*on` matches `json` but not `json.dumps`) would be wrongly rejected. + """ + + def test_fnmatch_pattern_matches_shorter_prefix(self): + # `j*on` matches `json` (the actual module) but not `json.dumps` (the full handle). + # The deferred allowlist check should still accept this. + with _deserialization_context(allowed_modules=["j*on"]): + fn = deserialize_callable("json.dumps") + assert fn is json.dumps + + def test_rejects_when_no_prefix_matches(self): + # No prefix of `subprocess.Popen` matches the default allowlist (or `unrelated`). + with _deserialization_context(allowed_modules=["unrelated"]): + with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): + deserialize_callable("subprocess.Popen") + + @pytest.fixture def _registered_untrusted_component(): """ @@ -195,8 +231,6 @@ def _registered_untrusted_component(): (`evilpkg.evilmod.EvilComponent`). Yields a dict payload referencing it. The fixture cleans up the registry on teardown. """ - from haystack import component as component_module - fake_type = "evilpkg.evilmod.EvilComponent" @component_module @@ -226,8 +260,6 @@ def run(self, value: int) -> dict[str, int]: class TestPipelineFromDictAllowlistBypass: def test_pre_registered_untrusted_component_is_rejected(self, _registered_untrusted_component): - from haystack.core.pipeline import Pipeline - with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): Pipeline.from_dict(_registered_untrusted_component["data"]) @@ -238,8 +270,6 @@ def test_pre_registered_component_loadable_with_allowed_modules(self, _registere doesn't match the test class's real qualified name — that's expected and proves the allowlist gate, not a downstream check, is what changed.) """ - from haystack.core.pipeline import Pipeline - data = _registered_untrusted_component["data"] # Without allowed_modules, this is rejected as untrusted. with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): @@ -262,20 +292,14 @@ def _yaml_for(data: dict) -> str: # We can't round-trip through `Pipeline.from_dict` + `dumps` because the registered # `EvilComponent`'s real qualified name doesn't match the fake type — the inner # `default_from_dict` would reject it. Build the YAML directly via the marshaller instead. - from haystack.marshal import YamlMarshaller - return YamlMarshaller().marshal(data) def test_loads_rejects_untrusted_by_default(self, _registered_untrusted_component): - from haystack.core.pipeline import Pipeline - yaml_str = self._yaml_for(_registered_untrusted_component["data"]) with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): Pipeline.loads(yaml_str) def test_loads_propagates_allowed_modules(self, _registered_untrusted_component): - from haystack.core.pipeline import Pipeline - yaml_str = self._yaml_for(_registered_untrusted_component["data"]) # With the matching pattern, the allowlist gate passes; downstream we get the type # mismatch — proving the kwarg reached the gate. @@ -283,47 +307,28 @@ def test_loads_propagates_allowed_modules(self, _registered_untrusted_component) Pipeline.loads(yaml_str, allowed_modules=["evilpkg.*"]) def test_loads_propagates_unsafe(self, _registered_untrusted_component): - from haystack.core.pipeline import Pipeline - yaml_str = self._yaml_for(_registered_untrusted_component["data"]) # `unsafe=True` bypasses the allowlist entirely; downstream we still get the type mismatch. with pytest.raises(DeserializationError, match="can't be deserialized as"): Pipeline.loads(yaml_str, unsafe=True) def test_load_rejects_untrusted_by_default(self, _registered_untrusted_component): - import io - - from haystack.core.pipeline import Pipeline - yaml_str = self._yaml_for(_registered_untrusted_component["data"]) with pytest.raises(DeserializationError, match="not on the trusted-module allowlist"): Pipeline.load(io.StringIO(yaml_str)) def test_load_propagates_allowed_modules(self, _registered_untrusted_component): - import io - - from haystack.core.pipeline import Pipeline - yaml_str = self._yaml_for(_registered_untrusted_component["data"]) with pytest.raises(DeserializationError, match="can't be deserialized as"): Pipeline.load(io.StringIO(yaml_str), allowed_modules=["evilpkg.*"]) def test_load_propagates_unsafe(self, _registered_untrusted_component): - import io - - from haystack.core.pipeline import Pipeline - yaml_str = self._yaml_for(_registered_untrusted_component["data"]) with pytest.raises(DeserializationError, match="can't be deserialized as"): Pipeline.load(io.StringIO(yaml_str), unsafe=True) def test_load_loads_from_dict_equivalent_on_rejection(self, _registered_untrusted_component): """All three entry points produce the same rejection message for the same untrusted payload.""" - import io - from collections.abc import Callable - - from haystack.core.pipeline import Pipeline - data = _registered_untrusted_component["data"] yaml_str = self._yaml_for(data)