-
Notifications
You must be signed in to change notification settings - Fork 252
Expand file tree
/
Copy pathtest_stackit_text_embedder.py
More file actions
101 lines (89 loc) · 4.27 KB
/
test_stackit_text_embedder.py
File metadata and controls
101 lines (89 loc) · 4.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import os
import re

import pytest
from haystack.utils import Secret

from haystack_integrations.components.embedders.stackit.text_embedder import STACKITTextEmbedder
class TestSTACKITTextEmbedder:
    """Tests for ``STACKITTextEmbedder``: construction, serialization and ``run``."""

    def test_init_default(self, monkeypatch):
        """With only ``model`` given, the key comes from ``STACKIT_API_KEY`` and defaults apply."""
        monkeypatch.setenv("STACKIT_API_KEY", "test-api-key")

        embedder = STACKITTextEmbedder(model="intfloat/e5-mistral-7b-instruct")

        assert embedder.api_key == Secret.from_env_var(["STACKIT_API_KEY"])
        assert embedder.api_base_url == "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1"
        assert embedder.model == "intfloat/e5-mistral-7b-instruct"
        assert embedder.prefix == ""
        assert embedder.suffix == ""

    def test_init_with_parameters(self):
        """Explicit ``api_key``, ``prefix`` and ``suffix`` are stored; the base URL keeps its default."""
        embedder = STACKITTextEmbedder(
            api_key=Secret.from_token("test-api-key"),
            model="intfloat/e5-mistral-7b-instruct",
            prefix="START",
            suffix="END",
        )

        assert embedder.api_key == Secret.from_token("test-api-key")
        assert embedder.api_base_url == "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1"
        assert embedder.model == "intfloat/e5-mistral-7b-instruct"
        assert embedder.prefix == "START"
        assert embedder.suffix == "END"

    def test_to_dict(self, monkeypatch):
        """``to_dict`` of a default-configured embedder yields the expected serialized form."""
        monkeypatch.setenv("STACKIT_API_KEY", "test-api-key")

        embedder = STACKITTextEmbedder(model="intfloat/e5-mistral-7b-instruct")
        serialized = embedder.to_dict()

        assert serialized == {
            "type": "haystack_integrations.components.embedders.stackit.text_embedder.STACKITTextEmbedder",
            "init_parameters": {
                "api_key": {"env_vars": ["STACKIT_API_KEY"], "strict": True, "type": "env_var"},
                "model": "intfloat/e5-mistral-7b-instruct",
                "api_base_url": "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1",
                "dimensions": None,
                "organization": None,
                "prefix": "",
                "suffix": "",
                "http_client_kwargs": None,
            },
        }

    def test_to_dict_with_custom_init_parameters(self, monkeypatch):
        """``to_dict`` reflects a custom env-var secret, base URL, prefix and suffix."""
        monkeypatch.setenv("ENV_VAR", "test-secret-key")

        embedder = STACKITTextEmbedder(
            api_key=Secret.from_env_var("ENV_VAR", strict=False),
            model="intfloat/e5-mistral-7b-instruct",
            api_base_url="https://custom-api-base-url.com",
            prefix="START",
            suffix="END",
        )
        serialized = embedder.to_dict()

        assert serialized == {
            "type": "haystack_integrations.components.embedders.stackit.text_embedder.STACKITTextEmbedder",
            "init_parameters": {
                "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                "model": "intfloat/e5-mistral-7b-instruct",
                "api_base_url": "https://custom-api-base-url.com",
                "dimensions": None,
                "organization": None,
                "prefix": "START",
                "suffix": "END",
                "http_client_kwargs": None,
            },
        }

    @pytest.mark.skipif(
        not os.environ.get("STACKIT_API_KEY"),
        reason="Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test.",
    )
    @pytest.mark.integration
    def test_run(self):
        """Integration test: embedding a string returns a non-empty sequence of floats.

        Skipped unless ``STACKIT_API_KEY`` is set to a non-empty value.
        """
        embedder = STACKITTextEmbedder(model="intfloat/e5-mistral-7b-instruct")
        text = "The food was delicious"

        result = embedder.run(text)
        embedding = result["embedding"]

        # all() is vacuously true for an empty sequence, so require content first.
        assert len(embedding) > 0
        assert all(isinstance(x, float) for x in embedding)

    def test_run_wrong_input_format(self):
        """Passing a list instead of a string to ``run`` raises ``TypeError``.

        The expected message names ``OpenAITextEmbedder``; presumably the STACKIT
        embedder reuses the OpenAI embedder's input validation -- confirm against
        the implementation.
        """
        embedder = STACKITTextEmbedder(
            model="intfloat/e5-mistral-7b-instruct", api_key=Secret.from_token("test-api-key")
        )
        list_of_strings_input = ["text_snippet_1", "text_snippet_2"]
        expected_message = (
            "OpenAITextEmbedder expects a string as an input.In case you want to embed a list of Documents,"
            " please use the OpenAIDocumentEmbedder."
        )

        # ``match`` is treated as a regular expression; escape the literal message
        # so that its "." characters only match real periods.
        with pytest.raises(TypeError, match=re.escape(expected_message)):
            embedder.run(text=list_of_strings_input)