Skip to content

Commit 8356e3f

Browse files
authored
feat: Add to_dict to STACKITDocumentEmbedder and STACKITTextEmbedder and more init parameters from underlying OpenAI classes (#1779)
* Add to_dicts and more tests * Bumpy haystack version * Add changes to chat generator as well
1 parent e19d1e5 commit 8356e3f

7 files changed

Lines changed: 186 additions & 16 deletions

File tree

integrations/stackit/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai"]
26+
dependencies = ["haystack-ai>=2.13.0"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/stackit#readme"

integrations/stackit/src/haystack_integrations/components/embedders/stackit/document_embedder.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
from typing import List, Optional
4+
from typing import Any, Dict, List, Optional
55

6-
from haystack import component
6+
from haystack import component, default_to_dict
77
from haystack.components.embedders import OpenAIDocumentEmbedder
88
from haystack.utils.auth import Secret
99

@@ -41,6 +41,10 @@ def __init__(
4141
progress_bar: bool = True,
4242
meta_fields_to_embed: Optional[List[str]] = None,
4343
embedding_separator: str = "\n",
44+
*,
45+
timeout: Optional[float] = None,
46+
max_retries: Optional[int] = None,
47+
http_client_kwargs: Optional[Dict[str, Any]] = None,
4448
):
4549
"""
4650
Creates a STACKITDocumentEmbedder component.
@@ -65,6 +69,15 @@ def __init__(
6569
List of meta fields that should be embedded along with the Document text.
6670
:param embedding_separator:
6771
Separator used to concatenate the meta fields to the Document text.
72+
:param timeout:
73+
Timeout for STACKIT client calls. If not set, it defaults to either the `OPENAI_TIMEOUT` environment
74+
variable, or 30 seconds.
75+
:param max_retries:
76+
Maximum number of retries to contact STACKIT after an internal error.
77+
If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or set to 5.
78+
:param http_client_kwargs:
79+
A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
80+
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
6881
"""
6982
super(STACKITDocumentEmbedder, self).__init__( # noqa: UP008
7083
api_key=api_key,
@@ -78,4 +91,32 @@ def __init__(
7891
progress_bar=progress_bar,
7992
meta_fields_to_embed=meta_fields_to_embed,
8093
embedding_separator=embedding_separator,
94+
timeout=timeout,
95+
max_retries=max_retries,
96+
http_client_kwargs=http_client_kwargs,
97+
)
98+
# We add these since they were only added in Haystack 2.14.0
99+
self.timeout = timeout
100+
self.max_retries = max_retries
101+
102+
def to_dict(self) -> Dict[str, Any]:
103+
"""
104+
Serializes the component to a dictionary.
105+
:returns:
106+
Dictionary with serialized data.
107+
"""
108+
return default_to_dict(
109+
self,
110+
model=self.model,
111+
api_key=self.api_key.to_dict(),
112+
api_base_url=self.api_base_url,
113+
prefix=self.prefix,
114+
suffix=self.suffix,
115+
batch_size=self.batch_size,
116+
progress_bar=self.progress_bar,
117+
meta_fields_to_embed=self.meta_fields_to_embed,
118+
embedding_separator=self.embedding_separator,
119+
timeout=self.timeout,
120+
max_retries=self.max_retries,
121+
http_client_kwargs=self.http_client_kwargs,
81122
)

integrations/stackit/src/haystack_integrations/components/embedders/stackit/text_embedder.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
from typing import Optional
4+
from typing import Any, Dict, Optional
55

6-
from haystack import component
6+
from haystack import component, default_to_dict
77
from haystack.components.embedders import OpenAITextEmbedder
88
from haystack.utils.auth import Secret
99

@@ -30,6 +30,10 @@ def __init__(
3030
api_base_url: Optional[str] = "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1",
3131
prefix: str = "",
3232
suffix: str = "",
33+
*,
34+
timeout: Optional[float] = None,
35+
max_retries: Optional[int] = None,
36+
http_client_kwargs: Optional[Dict[str, Any]] = None,
3337
):
3438
"""
3539
Creates a STACKITTextEmbedder component.
@@ -45,6 +49,15 @@ def __init__(
4549
A string to add to the beginning of each text.
4650
:param suffix:
4751
A string to add to the end of each text.
52+
:param timeout:
53+
Timeout for STACKIT client calls. If not set, it defaults to either the `OPENAI_TIMEOUT` environment
54+
variable, or 30 seconds.
55+
:param max_retries:
56+
Maximum number of retries to contact STACKIT after an internal error.
57+
If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or set to 5.
58+
:param http_client_kwargs:
59+
A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
60+
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
4861
"""
4962
super(STACKITTextEmbedder, self).__init__( # noqa: UP008
5063
api_key=api_key,
@@ -54,4 +67,28 @@ def __init__(
5467
organization=None,
5568
prefix=prefix,
5669
suffix=suffix,
70+
timeout=timeout,
71+
max_retries=max_retries,
72+
http_client_kwargs=http_client_kwargs,
73+
)
74+
# We add these since they were only added in Haystack 2.14.0
75+
self.timeout = timeout
76+
self.max_retries = max_retries
77+
78+
def to_dict(self) -> Dict[str, Any]:
79+
"""
80+
Serializes the component to a dictionary.
81+
:returns:
82+
Dictionary with serialized data.
83+
"""
84+
return default_to_dict(
85+
self,
86+
api_key=self.api_key.to_dict(),
87+
model=self.model,
88+
api_base_url=self.api_base_url,
89+
prefix=self.prefix,
90+
suffix=self.suffix,
91+
timeout=self.timeout,
92+
max_retries=self.max_retries,
93+
http_client_kwargs=self.http_client_kwargs,
5794
)

integrations/stackit/src/haystack_integrations/components/generators/stackit/chat/chat_generator.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# SPDX-FileCopyrightText: 2025-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
from typing import Any, Callable, Dict, Optional
4+
from typing import Any, Dict, Optional
55

66
from haystack import component, default_to_dict
77
from haystack.components.generators.chat import OpenAIChatGenerator
8-
from haystack.dataclasses import StreamingChunk
8+
from haystack.dataclasses import StreamingCallbackT
99
from haystack.utils import serialize_callable
1010
from haystack.utils.auth import Secret
1111

@@ -40,9 +40,13 @@ def __init__(
4040
self,
4141
model: str,
4242
api_key: Secret = Secret.from_env_var("STACKIT_API_KEY"),
43-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
43+
streaming_callback: Optional[StreamingCallbackT] = None,
4444
api_base_url: Optional[str] = "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1",
4545
generation_kwargs: Optional[Dict[str, Any]] = None,
46+
*,
47+
timeout: Optional[float] = None,
48+
max_retries: Optional[int] = None,
49+
http_client_kwargs: Optional[Dict[str, Any]] = None,
4650
):
4751
"""
4852
Creates an instance of STACKITChatGenerator class.
@@ -70,6 +74,15 @@ def __init__(
7074
events as they become available, with the stream terminated by a data: [DONE] message.
7175
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
7276
- `random_seed`: The seed to use for random sampling.
77+
:param timeout:
78+
Timeout for STACKIT client calls. If not set, it defaults to either the `OPENAI_TIMEOUT` environment
79+
variable, or 30 seconds.
80+
:param max_retries:
81+
Maximum number of retries to contact STACKIT after an internal error.
82+
If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or set to 5.
83+
:param http_client_kwargs:
84+
A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
85+
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
7386
"""
7487
super(STACKITChatGenerator, self).__init__( # noqa: UP008
7588
model=model,
@@ -78,6 +91,9 @@ def __init__(
7891
api_base_url=api_base_url,
7992
organization=None,
8093
generation_kwargs=generation_kwargs,
94+
timeout=timeout,
95+
max_retries=max_retries,
96+
http_client_kwargs=http_client_kwargs,
8197
)
8298

8399
def to_dict(self) -> Dict[str, Any]:
@@ -100,4 +116,7 @@ def to_dict(self) -> Dict[str, Any]:
100116
api_base_url=self.api_base_url,
101117
generation_kwargs=self.generation_kwargs,
102118
api_key=self.api_key.to_dict(),
119+
timeout=self.timeout,
120+
max_retries=self.max_retries,
121+
http_client_kwargs=self.http_client_kwargs,
103122
)

integrations/stackit/tests/test_stackit_chat_generator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ def test_to_dict_default(self, monkeypatch):
9393
"streaming_callback": None,
9494
"api_base_url": "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1",
9595
"generation_kwargs": {},
96+
"timeout": None,
97+
"max_retries": None,
98+
"http_client_kwargs": None,
9699
}
97100

98101
for key, value in expected_params.items():
@@ -106,6 +109,9 @@ def test_to_dict_with_parameters(self, monkeypatch):
106109
streaming_callback=print_streaming_chunk,
107110
api_base_url="test-base-url",
108111
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
112+
timeout=10.0,
113+
max_retries=2,
114+
http_client_kwargs={"proxy": "https://proxy.example.com:8080"},
109115
)
110116
data = component.to_dict()
111117

@@ -120,6 +126,9 @@ def test_to_dict_with_parameters(self, monkeypatch):
120126
"api_base_url": "test-base-url",
121127
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
122128
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
129+
"timeout": 10.0,
130+
"max_retries": 2,
131+
"http_client_kwargs": {"proxy": "https://proxy.example.com:8080"},
123132
}
124133

125134
for key, value in expected_params.items():

integrations/stackit/tests/test_stackit_document_embedder.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ def test_to_dict(self, monkeypatch):
5858
"api_key": {"env_vars": ["STACKIT_API_KEY"], "strict": True, "type": "env_var"},
5959
"model": "intfloat/e5-mistral-7b-instruct",
6060
"api_base_url": "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1",
61-
"dimensions": None,
62-
"organization": None,
6361
"prefix": "",
6462
"suffix": "",
6563
"batch_size": 32,
6664
"progress_bar": True,
6765
"meta_fields_to_embed": [],
6866
"embedding_separator": "\n",
67+
"timeout": None,
68+
"max_retries": None,
6969
"http_client_kwargs": None,
7070
},
7171
}
@@ -82,25 +82,61 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
8282
progress_bar=False,
8383
meta_fields_to_embed=["test_field"],
8484
embedding_separator="-",
85+
timeout=10.0,
86+
max_retries=2,
87+
http_client_kwargs={"proxy": "https://proxy.example.com"},
8588
)
8689
component_dict = embedder.to_dict()
8790
assert component_dict == {
8891
"type": "haystack_integrations.components.embedders.stackit.document_embedder.STACKITDocumentEmbedder",
8992
"init_parameters": {
9093
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
9194
"model": "intfloat/e5-mistral-7b-instruct",
92-
"dimensions": None,
9395
"api_base_url": "https://custom-api-base-url.com",
94-
"organization": None,
9596
"prefix": "START",
9697
"suffix": "END",
9798
"batch_size": 64,
9899
"progress_bar": False,
99100
"meta_fields_to_embed": ["test_field"],
100101
"embedding_separator": "-",
102+
"timeout": 10.0,
103+
"max_retries": 2,
104+
"http_client_kwargs": {"proxy": "https://proxy.example.com"},
105+
},
106+
}
107+
108+
def test_from_dict(self, monkeypatch):
109+
monkeypatch.setenv("STACKIT_API_KEY", "test-api-key")
110+
data = {
111+
"type": "haystack_integrations.components.embedders.stackit.document_embedder.STACKITDocumentEmbedder",
112+
"init_parameters": {
113+
"api_key": {"env_vars": ["STACKIT_API_KEY"], "strict": True, "type": "env_var"},
114+
"model": "intfloat/e5-mistral-7b-instruct",
115+
"api_base_url": "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1",
116+
"prefix": "",
117+
"suffix": "",
118+
"batch_size": 32,
119+
"progress_bar": True,
120+
"meta_fields_to_embed": [],
121+
"embedding_separator": "\n",
122+
"timeout": None,
123+
"max_retries": None,
101124
"http_client_kwargs": None,
102125
},
103126
}
127+
embedder = STACKITDocumentEmbedder.from_dict(data)
128+
assert embedder.api_key == Secret.from_env_var(["STACKIT_API_KEY"])
129+
assert embedder.model == "intfloat/e5-mistral-7b-instruct"
130+
assert embedder.api_base_url == "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1"
131+
assert embedder.prefix == ""
132+
assert embedder.suffix == ""
133+
assert embedder.batch_size == 32
134+
assert embedder.progress_bar is True
135+
assert embedder.meta_fields_to_embed == []
136+
assert embedder.embedding_separator == "\n"
137+
assert embedder.timeout is None
138+
assert embedder.max_retries is None
139+
assert embedder.http_client_kwargs is None
104140

105141
@pytest.mark.skipif(
106142
not os.environ.get("STACKIT_API_KEY", None),

integrations/stackit/tests/test_stackit_text_embedder.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ def test_to_dict(self, monkeypatch):
4444
"api_key": {"env_vars": ["STACKIT_API_KEY"], "strict": True, "type": "env_var"},
4545
"model": "intfloat/e5-mistral-7b-instruct",
4646
"api_base_url": "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1",
47-
"dimensions": None,
48-
"organization": None,
4947
"prefix": "",
5048
"suffix": "",
49+
"timeout": None,
50+
"max_retries": None,
5151
"http_client_kwargs": None,
5252
},
5353
}
@@ -60,6 +60,9 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
6060
api_base_url="https://custom-api-base-url.com",
6161
prefix="START",
6262
suffix="END",
63+
timeout=10.0,
64+
max_retries=2,
65+
http_client_kwargs={"proxy": "https://proxy.example.com"},
6366
)
6467
component_dict = embedder.to_dict()
6568
assert component_dict == {
@@ -68,13 +71,38 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
6871
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
6972
"model": "intfloat/e5-mistral-7b-instruct",
7073
"api_base_url": "https://custom-api-base-url.com",
71-
"dimensions": None,
72-
"organization": None,
7374
"prefix": "START",
7475
"suffix": "END",
76+
"timeout": 10.0,
77+
"max_retries": 2,
78+
"http_client_kwargs": {"proxy": "https://proxy.example.com"},
79+
},
80+
}
81+
82+
def test_from_dict(self, monkeypatch):
83+
monkeypatch.setenv("STACKIT_API_KEY", "test-secret-key")
84+
data = {
85+
"type": "haystack_integrations.components.embedders.stackit.text_embedder.STACKITTextEmbedder",
86+
"init_parameters": {
87+
"api_key": {"env_vars": ["STACKIT_API_KEY"], "strict": True, "type": "env_var"},
88+
"model": "intfloat/e5-mistral-7b-instruct",
89+
"api_base_url": "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1",
90+
"prefix": "",
91+
"suffix": "",
92+
"timeout": None,
93+
"max_retries": None,
7594
"http_client_kwargs": None,
7695
},
7796
}
97+
embedder = STACKITTextEmbedder.from_dict(data)
98+
assert embedder.api_key == Secret.from_env_var(["STACKIT_API_KEY"])
99+
assert embedder.api_base_url == "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1"
100+
assert embedder.model == "intfloat/e5-mistral-7b-instruct"
101+
assert embedder.prefix == ""
102+
assert embedder.suffix == ""
103+
assert embedder.timeout is None
104+
assert embedder.max_retries is None
105+
assert embedder.http_client_kwargs is None
78106

79107
@pytest.mark.skipif(
80108
not os.environ.get("STACKIT_API_KEY", None),

0 commit comments

Comments
 (0)