Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import asdict, is_dataclass
from typing import Any

from haystack import component, default_from_dict, default_to_dict, logging, tracing
from pydantic import BaseModel

from haystack_integrations.tracing.weave import WeaveTracer
from weave.trace.settings import UserSettings
Expand Down Expand Up @@ -110,7 +112,10 @@ def to_dict(self) -> dict[str, Any]:
weave_init_kwargs = self.weave_init_kwargs.copy()
settings = weave_init_kwargs.get("settings", None)
if isinstance(settings, UserSettings):
weave_init_kwargs["settings"] = settings.model_dump(mode="json", exclude_defaults=True)
if isinstance(settings, BaseModel):
weave_init_kwargs["settings"] = settings.model_dump(mode="json", exclude_defaults=True)
elif is_dataclass(settings):
weave_init_kwargs["settings"] = asdict(settings)

return default_to_dict(self, pipeline_name=self.pipeline_name, weave_init_kwargs=weave_init_kwargs)

Expand Down
6 changes: 2 additions & 4 deletions integrations/weave/tests/test_weave_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ def test_serialization_of_weave_init_kwargs_with_user_settings(self) -> None:
serialized: dict[str, Any] = connector.to_dict()

assert serialized["init_parameters"]["pipeline_name"] == "test_pipeline"
assert serialized["init_parameters"]["weave_init_kwargs"] == {
"settings": {"implicitly_patch_integrations": False}
}
assert serialized["init_parameters"]["weave_init_kwargs"]["settings"]["implicitly_patch_integrations"] is False
assert "type" in serialized
assert serialized["type"] == "haystack_integrations.components.connectors.weave.weave_connector.WeaveConnector"

Expand All @@ -108,7 +106,7 @@ def test_serialization_of_weave_init_kwargs_with_user_settings(self) -> None:
assert isinstance(deserialized, WeaveConnector)
assert deserialized.pipeline_name == "test_pipeline"
assert deserialized.tracer is None # tracer is only initialized with warm_up
assert deserialized.weave_init_kwargs == {"settings": {"implicitly_patch_integrations": False}}
assert deserialized.weave_init_kwargs["settings"]["implicitly_patch_integrations"] is False

def test_pipeline_tracing(self, mock_weave_client: Mock, sample_pipeline: Pipeline) -> None:
"""Test that pipeline operations are correctly traced"""
Expand Down
Loading