Skip to content

Commit 115217d

Browse files
authored
fix: Bring Mistral integration up to date with changes made to OpenAIChatGenerator and OpenAI Embedders (#1774)
* Bringing Mistral up to date * Fix Mistral Embedders to be deserializable * Fix lint * Fix lint * Bump minimum haystack version
1 parent 242bba8 commit 115217d

7 files changed

Lines changed: 186 additions & 20 deletions

File tree

integrations/mistral/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai>=2.9.0"]
26+
dependencies = ["haystack-ai>=2.13.0"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/mistral#readme"

integrations/mistral/src/haystack_integrations/components/embedders/mistral/document_embedder.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
from typing import List, Optional
4+
from typing import Any, Dict, List, Optional
55

6-
from haystack import component
6+
from haystack import component, default_to_dict
77
from haystack.components.embedders import OpenAIDocumentEmbedder
88
from haystack.utils.auth import Secret
99

@@ -41,6 +41,10 @@ def __init__(
4141
progress_bar: bool = True,
4242
meta_fields_to_embed: Optional[List[str]] = None,
4343
embedding_separator: str = "\n",
44+
*,
45+
timeout: Optional[float] = None,
46+
max_retries: Optional[int] = None,
47+
http_client_kwargs: Optional[Dict[str, Any]] = None,
4448
):
4549
"""
4650
Creates a MistralDocumentEmbedder component.
@@ -64,6 +68,15 @@ def __init__(
6468
List of meta fields that should be embedded along with the Document text.
6569
:param embedding_separator:
6670
Separator used to concatenate the meta fields to the Document text.
71+
:param timeout:
72+
Timeout for Mistral client calls. If not set, it defaults to either the `OPENAI_TIMEOUT` environment
73+
variable, or 30 seconds.
74+
:param max_retries:
75+
Maximum number of retries to contact Mistral after an internal error.
76+
If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or 5.
77+
:param http_client_kwargs:
78+
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
79+
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
6780
"""
6881
super(MistralDocumentEmbedder, self).__init__( # noqa: UP008
6982
api_key=api_key,
@@ -77,4 +90,33 @@ def __init__(
7790
progress_bar=progress_bar,
7891
meta_fields_to_embed=meta_fields_to_embed,
7992
embedding_separator=embedding_separator,
93+
timeout=timeout,
94+
max_retries=max_retries,
95+
http_client_kwargs=http_client_kwargs,
96+
)
97+
# We add these since they were only added in Haystack 2.14.0
98+
self.timeout = timeout
99+
self.max_retries = max_retries
100+
101+
def to_dict(self) -> Dict[str, Any]:
102+
"""
103+
Serializes the component to a dictionary.
104+
105+
:returns:
106+
Dictionary with serialized data.
107+
"""
108+
return default_to_dict(
109+
self,
110+
api_key=self.api_key.to_dict(),
111+
model=self.model,
112+
api_base_url=self.api_base_url,
113+
prefix=self.prefix,
114+
suffix=self.suffix,
115+
batch_size=self.batch_size,
116+
progress_bar=self.progress_bar,
117+
meta_fields_to_embed=self.meta_fields_to_embed,
118+
embedding_separator=self.embedding_separator,
119+
timeout=self.timeout,
120+
max_retries=self.max_retries,
121+
http_client_kwargs=self.http_client_kwargs,
80122
)

integrations/mistral/src/haystack_integrations/components/embedders/mistral/text_embedder.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
from typing import Optional
4+
from typing import Any, Dict, Optional
55

6-
from haystack import component
6+
from haystack import component, default_to_dict
77
from haystack.components.embedders import OpenAITextEmbedder
88
from haystack.utils.auth import Secret
99

@@ -35,6 +35,10 @@ def __init__(
3535
api_base_url: Optional[str] = "https://api.mistral.ai/v1",
3636
prefix: str = "",
3737
suffix: str = "",
38+
*,
39+
timeout: Optional[float] = None,
40+
max_retries: Optional[int] = None,
41+
http_client_kwargs: Optional[Dict[str, Any]] = None,
3842
):
3943
"""
4044
Creates a MistralTextEmbedder component.
@@ -50,6 +54,15 @@ def __init__(
5054
A string to add to the beginning of each text.
5155
:param suffix:
5256
A string to add to the end of each text.
57+
:param timeout:
58+
Timeout for Mistral client calls. If not set, it defaults to either the `OPENAI_TIMEOUT` environment
59+
variable, or 30 seconds.
60+
:param max_retries:
61+
Maximum number of retries to contact Mistral after an internal error.
62+
If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or 5.
63+
:param http_client_kwargs:
64+
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
65+
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
5366
"""
5467
super(MistralTextEmbedder, self).__init__( # noqa: UP008
5568
api_key=api_key,
@@ -59,4 +72,29 @@ def __init__(
5972
organization=None,
6073
prefix=prefix,
6174
suffix=suffix,
75+
timeout=timeout,
76+
max_retries=max_retries,
77+
http_client_kwargs=http_client_kwargs,
78+
)
79+
# We add these since they were only added in Haystack 2.14.0
80+
self.timeout = timeout
81+
self.max_retries = max_retries
82+
83+
def to_dict(self) -> Dict[str, Any]:
84+
"""
85+
Serializes the component to a dictionary.
86+
87+
:returns:
88+
Dictionary with serialized data.
89+
"""
90+
return default_to_dict(
91+
self,
92+
api_key=self.api_key.to_dict(),
93+
model=self.model,
94+
api_base_url=self.api_base_url,
95+
prefix=self.prefix,
96+
suffix=self.suffix,
97+
timeout=self.timeout,
98+
max_retries=self.max_retries,
99+
http_client_kwargs=self.http_client_kwargs,
62100
)

integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any, Callable, Dict, List, Optional
5+
from typing import Any, Dict, List, Optional, Union
66

77
from haystack import component, default_to_dict, logging
88
from haystack.components.generators.chat import OpenAIChatGenerator
9-
from haystack.dataclasses import StreamingChunk
10-
from haystack.tools import Tool
9+
from haystack.dataclasses import StreamingCallbackT
10+
from haystack.tools import Tool, Toolset
1111
from haystack.utils import serialize_callable
1212
from haystack.utils.auth import Secret
1313

@@ -61,10 +61,14 @@ def __init__(
6161
self,
6262
api_key: Secret = Secret.from_env_var("MISTRAL_API_KEY"),
6363
model: str = "mistral-small-latest",
64-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
64+
streaming_callback: Optional[StreamingCallbackT] = None,
6565
api_base_url: Optional[str] = "https://api.mistral.ai/v1",
6666
generation_kwargs: Optional[Dict[str, Any]] = None,
67-
tools: Optional[List[Tool]] = None,
67+
tools: Optional[Union[List[Tool], Toolset]] = None,
68+
*,
69+
timeout: Optional[float] = None,
70+
max_retries: Optional[int] = None,
71+
http_client_kwargs: Optional[Dict[str, Any]] = None,
6872
):
6973
"""
7074
Creates an instance of MistralChatGenerator. Unless specified otherwise in the `model`, this is for Mistral's
@@ -95,7 +99,17 @@ def __init__(
9599
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
96100
- `random_seed`: The seed to use for random sampling.
97101
:param tools:
98-
A list of tools for which the model can prepare calls.
102+
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
103+
list of `Tool` objects or a `Toolset` instance.
104+
:param timeout:
105+
The timeout for the Mistral API call. If not set, it defaults to either the `OPENAI_TIMEOUT`
106+
environment variable, or 30 seconds.
107+
:param max_retries:
108+
Maximum number of retries to contact Mistral after an internal error.
109+
If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or 5.
110+
:param http_client_kwargs:
111+
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
112+
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
99113
"""
100114
super(MistralChatGenerator, self).__init__( # noqa: UP008
101115
api_key=api_key,
@@ -105,6 +119,9 @@ def __init__(
105119
organization=None,
106120
generation_kwargs=generation_kwargs,
107121
tools=tools,
122+
timeout=timeout,
123+
max_retries=max_retries,
124+
http_client_kwargs=http_client_kwargs,
108125
)
109126

110127
def to_dict(self) -> Dict[str, Any]:
@@ -119,7 +136,6 @@ def to_dict(self) -> Dict[str, Any]:
119136
# if we didn't implement the to_dict method here then the to_dict method of the superclass would be used
120137
# which would serialize some fields that we don't want to serialize (e.g. the ones we don't have in
121138
# the __init__)
122-
# it would be hard to maintain the compatibility as superclass changes
123139
return default_to_dict(
124140
self,
125141
model=self.model,
@@ -128,4 +144,7 @@ def to_dict(self) -> Dict[str, Any]:
128144
generation_kwargs=self.generation_kwargs,
129145
api_key=self.api_key.to_dict(),
130146
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
147+
timeout=self.timeout,
148+
max_retries=self.max_retries,
149+
http_client_kwargs=self.http_client_kwargs,
131150
)

integrations/mistral/tests/test_mistral_chat_generator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,9 @@ def test_serde_in_pipeline(self, monkeypatch):
487487
},
488488
}
489489
],
490+
"http_client_kwargs": None,
491+
"timeout": None,
492+
"max_retries": None,
490493
},
491494
}
492495
},

integrations/mistral/tests/test_mistral_document_embedder.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ def test_to_dict(self, monkeypatch):
6060
"api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"},
6161
"model": "mistral-embed",
6262
"api_base_url": "https://api.mistral.ai/v1",
63-
"dimensions": None,
64-
"organization": None,
6563
"prefix": "",
6664
"suffix": "",
6765
"batch_size": 32,
6866
"progress_bar": True,
6967
"meta_fields_to_embed": [],
7068
"embedding_separator": "\n",
69+
"timeout": None,
70+
"max_retries": None,
7171
"http_client_kwargs": None,
7272
},
7373
}
@@ -84,25 +84,61 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
8484
progress_bar=False,
8585
meta_fields_to_embed=["test_field"],
8686
embedding_separator="-",
87+
timeout=10.0,
88+
max_retries=2,
89+
http_client_kwargs={"proxy": "http://localhost:8080"},
8790
)
8891
component_dict = embedder.to_dict()
8992
assert component_dict == {
9093
"type": "haystack_integrations.components.embedders.mistral.document_embedder.MistralDocumentEmbedder",
9194
"init_parameters": {
9295
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
9396
"model": "mistral-embed-v2",
94-
"dimensions": None,
9597
"api_base_url": "https://custom-api-base-url.com",
96-
"organization": None,
9798
"prefix": "START",
9899
"suffix": "END",
99100
"batch_size": 64,
100101
"progress_bar": False,
101102
"meta_fields_to_embed": ["test_field"],
102103
"embedding_separator": "-",
104+
"timeout": 10.0,
105+
"max_retries": 2,
106+
"http_client_kwargs": {"proxy": "http://localhost:8080"},
107+
},
108+
}
109+
110+
def test_from_dict(self, monkeypatch):
111+
monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key")
112+
data = {
113+
"type": "haystack_integrations.components.embedders.mistral.document_embedder.MistralDocumentEmbedder",
114+
"init_parameters": {
115+
"api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"},
116+
"model": "mistral-embed",
117+
"api_base_url": "https://api.mistral.ai/v1",
118+
"prefix": "",
119+
"suffix": "",
120+
"batch_size": 32,
121+
"progress_bar": True,
122+
"meta_fields_to_embed": [],
123+
"embedding_separator": "\n",
124+
"timeout": None,
125+
"max_retries": None,
103126
"http_client_kwargs": None,
104127
},
105128
}
129+
component = MistralDocumentEmbedder.from_dict(data)
130+
assert component.api_key == Secret.from_env_var(["MISTRAL_API_KEY"])
131+
assert component.model == "mistral-embed"
132+
assert component.api_base_url == "https://api.mistral.ai/v1"
133+
assert component.prefix == ""
134+
assert component.suffix == ""
135+
assert component.batch_size == 32
136+
assert component.progress_bar is True
137+
assert component.meta_fields_to_embed == []
138+
assert component.embedding_separator == "\n"
139+
assert component.timeout is None
140+
assert component.max_retries is None
141+
assert component.http_client_kwargs is None
106142

107143
@pytest.mark.skipif(
108144
not os.environ.get("MISTRAL_API_KEY", None),

integrations/mistral/tests/test_mistral_text_embedder.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ def test_to_dict(self, monkeypatch):
4646
"api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"},
4747
"model": "mistral-embed",
4848
"api_base_url": "https://api.mistral.ai/v1",
49-
"dimensions": None,
50-
"organization": None,
5149
"prefix": "",
5250
"suffix": "",
51+
"timeout": None,
52+
"max_retries": None,
5353
"http_client_kwargs": None,
5454
},
5555
}
@@ -62,6 +62,9 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
6262
api_base_url="https://custom-api-base-url.com",
6363
prefix="START",
6464
suffix="END",
65+
timeout=10.0,
66+
max_retries=2,
67+
http_client_kwargs={"proxy": "http://localhost:8080"},
6568
)
6669
component_dict = embedder.to_dict()
6770
assert component_dict == {
@@ -70,13 +73,38 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
7073
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
7174
"model": "mistral-embed-v2",
7275
"api_base_url": "https://custom-api-base-url.com",
73-
"dimensions": None,
74-
"organization": None,
7576
"prefix": "START",
7677
"suffix": "END",
78+
"timeout": 10.0,
79+
"max_retries": 2,
80+
"http_client_kwargs": {"proxy": "http://localhost:8080"},
81+
},
82+
}
83+
84+
def test_from_dict(self, monkeypatch):
85+
monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key")
86+
data = {
87+
"type": "haystack_integrations.components.embedders.mistral.text_embedder.MistralTextEmbedder",
88+
"init_parameters": {
89+
"api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"},
90+
"model": "mistral-embed",
91+
"api_base_url": "https://api.mistral.ai/v1",
92+
"prefix": "",
93+
"suffix": "",
94+
"timeout": None,
95+
"max_retries": None,
7796
"http_client_kwargs": None,
7897
},
7998
}
99+
component = MistralTextEmbedder.from_dict(data)
100+
assert component.api_key == Secret.from_env_var(["MISTRAL_API_KEY"])
101+
assert component.api_base_url == "https://api.mistral.ai/v1"
102+
assert component.model == "mistral-embed"
103+
assert component.prefix == ""
104+
assert component.suffix == ""
105+
assert component.timeout is None
106+
assert component.max_retries is None
107+
assert component.http_client_kwargs is None
80108

81109
@pytest.mark.skipif(
82110
not os.environ.get("MISTRAL_API_KEY", None),

0 commit comments

Comments
 (0)