Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion haystack/core/super_component/super_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,13 @@ def _to_super_component_dict(self) -> Dict[str, Any]:
:return: Dictionary containing serialized SuperComponent data
"""
serialized_pipeline = self.pipeline.to_dict()
is_pipeline_async = isinstance(self.pipeline, AsyncPipeline)
serialized = default_to_dict(
self,
pipeline=serialized_pipeline,
input_mapping=self._original_input_mapping,
output_mapping=self._original_output_mapping,
is_pipeline_async=is_pipeline_async,
)
serialized["type"] = generate_qualified_class_name(SuperComponent)
return serialized
Expand Down Expand Up @@ -462,7 +464,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "SuperComponent":
:returns:
The deserialized SuperComponent.
"""
pipeline = Pipeline.from_dict(data["init_parameters"]["pipeline"])
is_pipeline_async = data["init_parameters"].pop("is_pipeline_async", False)
pipeline_class = AsyncPipeline if is_pipeline_async else Pipeline
pipeline = pipeline_class.from_dict(data["init_parameters"]["pipeline"])
data["init_parameters"]["pipeline"] = pipeline
return default_from_dict(cls, data)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
The SuperComponent class can now correctly serialize and deserialize a SuperComponent based on an async pipeline.
Previously, the SuperComponent class always assumed the underlying pipeline was synchronous.
31 changes: 31 additions & 0 deletions test/core/super_component/test_super_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def test_wrapper_serialization(self, document_store):
assert "type" in serialized
assert "init_parameters" in serialized
assert "pipeline" in serialized["init_parameters"]
assert serialized["init_parameters"]["is_pipeline_async"] is False

# Test deserialization
deserialized = SuperComponent.from_dict(serialized)
Expand Down Expand Up @@ -434,3 +435,33 @@ def run(self, specific: str, generic: Any):
input_sockets = wrapper.__haystack_input__._sockets_dict
assert "text" in input_sockets
assert input_sockets["text"].type == str

@pytest.mark.asyncio
async def test_super_component_async_serialization_deserialization(self):
"""
Test that when using the SuperComponent class, a SuperComponent based on an async pipeline can be serialized and
deserialized correctly.
"""

@component
class AsyncComponent:
@component.output_types(output=str)
def run(self):
return {"output": "irrelevant"}

@component.output_types(output=str)
async def run_async(self):
return {"output": "Hello world"}

pipeline = AsyncPipeline()
pipeline.add_component("hello", AsyncComponent())

async_super_component = SuperComponent(pipeline=pipeline)
serialized_super_component = async_super_component.to_dict()
assert serialized_super_component["init_parameters"]["is_pipeline_async"] is True

deserialized_super_component = SuperComponent.from_dict(serialized_super_component)
assert isinstance(deserialized_super_component.pipeline, AsyncPipeline)

result = await deserialized_super_component.run_async()
assert result == {"output": "Hello world"}
Loading