Skip to content

Commit 60570b5

Browse files
sjrljulian-risch
authored and committed
feat: Add raise_on_failure boolean parameter to OpenAIDocumentEmbedder and AzureOpenAIDocumentEmbedder (#9474)
* Add raise_on_failure to OpenAIDocumentEmbedder
* Add reno
* Add parameter to Azure Doc embedder as well
* Fix bug
* Update reno
* PR comments
* update reno
1 parent 045603d commit 60570b5

5 files changed

Lines changed: 112 additions & 17 deletions

File tree

haystack/components/embedders/azure_document_embedder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
5959
default_headers: Optional[Dict[str, str]] = None,
6060
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
6161
http_client_kwargs: Optional[Dict[str, Any]] = None,
62+
raise_on_failure: bool = False,
6263
):
6364
"""
6465
Creates an AzureOpenAIDocumentEmbedder component.
@@ -109,6 +110,9 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
109110
:param http_client_kwargs:
110111
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
111112
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
113+
:param raise_on_failure:
114+
Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
115+
and continue processing the remaining documents. If `True`, it will raise an exception on failure.
112116
"""
113117
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
114118
# with the API.
@@ -140,6 +144,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
140144
self.default_headers = default_headers or {}
141145
self.azure_ad_token_provider = azure_ad_token_provider
142146
self.http_client_kwargs = http_client_kwargs
147+
self.raise_on_failure = raise_on_failure
143148

144149
client_args: Dict[str, Any] = {
145150
"api_version": api_version,
@@ -191,6 +196,7 @@ def to_dict(self) -> Dict[str, Any]:
191196
default_headers=self.default_headers,
192197
azure_ad_token_provider=azure_ad_token_provider_name,
193198
http_client_kwargs=self.http_client_kwargs,
199+
raise_on_failure=self.raise_on_failure,
194200
)
195201

196202
@classmethod

haystack/components/embedders/openai_document_embedder.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class OpenAIDocumentEmbedder:
3939
```
4040
"""
4141

42-
def __init__( # pylint: disable=too-many-positional-arguments
42+
def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments
4343
self,
4444
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
4545
model: str = "text-embedding-ada-002",
@@ -55,6 +55,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
5555
timeout: Optional[float] = None,
5656
max_retries: Optional[int] = None,
5757
http_client_kwargs: Optional[Dict[str, Any]] = None,
58+
*,
59+
raise_on_failure: bool = False,
5860
):
5961
"""
6062
Creates an OpenAIDocumentEmbedder component.
@@ -100,6 +102,9 @@ def __init__( # pylint: disable=too-many-positional-arguments
100102
:param http_client_kwargs:
101103
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
102104
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
105+
:param raise_on_failure:
106+
Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
107+
and continue processing the remaining documents. If `True`, it will raise an exception on failure.
103108
"""
104109
self.api_key = api_key
105110
self.model = model
@@ -115,6 +120,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
115120
self.timeout = timeout
116121
self.max_retries = max_retries
117122
self.http_client_kwargs = http_client_kwargs
123+
self.raise_on_failure = raise_on_failure
118124

119125
if timeout is None:
120126
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
@@ -163,6 +169,7 @@ def to_dict(self) -> Dict[str, Any]:
163169
timeout=self.timeout,
164170
max_retries=self.max_retries,
165171
http_client_kwargs=self.http_client_kwargs,
172+
raise_on_failure=self.raise_on_failure,
166173
)
167174

168175
@classmethod
@@ -194,12 +201,14 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]:
194201

195202
return texts_to_embed
196203

197-
def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
204+
def _embed_batch(
205+
self, texts_to_embed: Dict[str, str], batch_size: int
206+
) -> Tuple[Dict[str, List[float]], Dict[str, Any]]:
198207
"""
199208
Embed a list of texts in batches.
200209
"""
201210

202-
all_embeddings = []
211+
doc_ids_to_embeddings: Dict[str, List[float]] = {}
203212
meta: Dict[str, Any] = {}
204213
for batch in tqdm(
205214
batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
@@ -215,10 +224,12 @@ def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple
215224
ids = ", ".join(b[0] for b in batch)
216225
msg = "Failed embedding of documents {ids} caused by {exc}"
217226
logger.exception(msg, ids=ids, exc=exc)
227+
if self.raise_on_failure:
228+
raise exc
218229
continue
219230

220231
embeddings = [el.embedding for el in response.data]
221-
all_embeddings.extend(embeddings)
232+
doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings)))
222233

223234
if "model" not in meta:
224235
meta["model"] = response.model
@@ -228,16 +239,16 @@ def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple
228239
meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
229240
meta["usage"]["total_tokens"] += response.usage.total_tokens
230241

231-
return all_embeddings, meta
242+
return doc_ids_to_embeddings, meta
232243

233244
async def _embed_batch_async(
234245
self, texts_to_embed: Dict[str, str], batch_size: int
235-
) -> Tuple[List[List[float]], Dict[str, Any]]:
246+
) -> Tuple[Dict[str, List[float]], Dict[str, Any]]:
236247
"""
237248
Embed a list of texts in batches asynchronously.
238249
"""
239250

240-
all_embeddings = []
251+
doc_ids_to_embeddings: Dict[str, List[float]] = {}
241252
meta: Dict[str, Any] = {}
242253

243254
batches = list(batched(texts_to_embed.items(), batch_size))
@@ -256,10 +267,12 @@ async def _embed_batch_async(
256267
ids = ", ".join(b[0] for b in batch)
257268
msg = "Failed embedding of documents {ids} caused by {exc}"
258269
logger.exception(msg, ids=ids, exc=exc)
270+
if self.raise_on_failure:
271+
raise exc
259272
continue
260273

261274
embeddings = [el.embedding for el in response.data]
262-
all_embeddings.extend(embeddings)
275+
doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings)))
263276

264277
if "model" not in meta:
265278
meta["model"] = response.model
@@ -269,7 +282,7 @@ async def _embed_batch_async(
269282
meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
270283
meta["usage"]["total_tokens"] += response.usage.total_tokens
271284

272-
return all_embeddings, meta
285+
return doc_ids_to_embeddings, meta
273286

274287
@component.output_types(documents=List[Document], meta=Dict[str, Any])
275288
def run(self, documents: List[Document]):
@@ -292,12 +305,13 @@ def run(self, documents: List[Document]):
292305

293306
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
294307

295-
embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
308+
doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
296309

297-
for doc, emb in zip(documents, embeddings):
298-
doc.embedding = emb
310+
doc_id_to_document = {doc.id: doc for doc in documents}
311+
for doc_id, emb in doc_ids_to_embeddings.items():
312+
doc_id_to_document[doc_id].embedding = emb
299313

300-
return {"documents": documents, "meta": meta}
314+
return {"documents": list(doc_id_to_document.values()), "meta": meta}
301315

302316
@component.output_types(documents=List[Document], meta=Dict[str, Any])
303317
async def run_async(self, documents: List[Document]):
@@ -320,9 +334,12 @@ async def run_async(self, documents: List[Document]):
320334

321335
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
322336

323-
embeddings, meta = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
337+
doc_ids_to_embeddings, meta = await self._embed_batch_async(
338+
texts_to_embed=texts_to_embed, batch_size=self.batch_size
339+
)
324340

325-
for doc, emb in zip(documents, embeddings):
326-
doc.embedding = emb
341+
doc_id_to_document = {doc.id: doc for doc in documents}
342+
for doc_id, emb in doc_ids_to_embeddings.items():
343+
doc_id_to_document[doc_id].embedding = emb
327344

328-
return {"documents": documents, "meta": meta}
345+
return {"documents": list(doc_id_to_document.values()), "meta": meta}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
features:
3+
- |
4+
Added a raise_on_failure boolean parameter to OpenAIDocumentEmbedder and AzureOpenAIDocumentEmbedder. If set to True, the component will raise an exception when there is an error with the API request. It is set to False by default, so the previous behavior of logging an exception and continuing is still the default.
5+
fixes:
6+
- |
7+
Fix a bug where, if raise_on_failure=False and an error occurred mid-batch, the following embeddings would be paired with the wrong documents.

test/components/embedders/test_azure_document_embedder.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def test_to_dict(self, monkeypatch):
7777
"default_headers": {},
7878
"azure_ad_token_provider": None,
7979
"http_client_kwargs": None,
80+
"raise_on_failure": False,
8081
},
8182
}
8283

@@ -94,6 +95,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
9495
default_headers={"x-custom-header": "custom-value"},
9596
azure_ad_token_provider=default_azure_ad_token_provider,
9697
http_client_kwargs={"proxy": "http://example.com:3128", "verify": False},
98+
raise_on_failure=True,
9799
)
98100
data = component.to_dict()
99101
assert data == {
@@ -117,6 +119,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
117119
"default_headers": {"x-custom-header": "custom-value"},
118120
"azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
119121
"http_client_kwargs": {"proxy": "http://example.com:3128", "verify": False},
122+
"raise_on_failure": True,
120123
},
121124
}
122125

@@ -143,6 +146,7 @@ def test_from_dict(self, monkeypatch):
143146
"default_headers": {},
144147
"azure_ad_token_provider": None,
145148
"http_client_kwargs": None,
149+
"raise_on_failure": False,
146150
},
147151
}
148152
component = AzureOpenAIDocumentEmbedder.from_dict(data)
@@ -156,6 +160,7 @@ def test_from_dict(self, monkeypatch):
156160
assert component.default_headers == {}
157161
assert component.azure_ad_token_provider is None
158162
assert component.http_client_kwargs is None
163+
assert component.raise_on_failure is False
159164

160165
def test_from_dict_with_parameters(self, monkeypatch):
161166
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
@@ -180,6 +185,7 @@ def test_from_dict_with_parameters(self, monkeypatch):
180185
"default_headers": {"x-custom-header": "custom-value"},
181186
"azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
182187
"http_client_kwargs": {"proxy": "http://example.com:3128", "verify": False},
188+
"raise_on_failure": True,
183189
},
184190
}
185191
component = AzureOpenAIDocumentEmbedder.from_dict(data)
@@ -193,6 +199,7 @@ def test_from_dict_with_parameters(self, monkeypatch):
193199
assert component.default_headers == {"x-custom-header": "custom-value"}
194200
assert component.azure_ad_token_provider is not None
195201
assert component.http_client_kwargs == {"proxy": "http://example.com:3128", "verify": False}
202+
assert component.raise_on_failure is True
196203

197204
def test_embed_batch_handles_exceptions_gracefully(self, caplog):
198205
embedder = AzureOpenAIDocumentEmbedder(
@@ -214,6 +221,22 @@ def test_embed_batch_handles_exceptions_gracefully(self, caplog):
214221
assert len(caplog.records) == 1
215222
assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.text
216223

224+
def test_embed_batch_raises_exception_on_failure(self):
225+
embedder = AzureOpenAIDocumentEmbedder(
226+
azure_endpoint="https://test.openai.azure.com",
227+
api_key=Secret.from_token("fake-api-key"),
228+
azure_deployment="text-embedding-ada-002",
229+
raise_on_failure=True,
230+
)
231+
fake_texts_to_embed = {"1": "text1", "2": "text2"}
232+
with patch.object(
233+
embedder.client.embeddings,
234+
"create",
235+
side_effect=APIError(message="Mocked error", request=Mock(), body=None),
236+
):
237+
with pytest.raises(APIError, match="Mocked error"):
238+
embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=2)
239+
217240
@pytest.mark.integration
218241
@pytest.mark.skipif(
219242
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

test/components/embedders/test_openai_document_embedder.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def test_to_dict(self, monkeypatch):
124124
"embedding_separator": "\n",
125125
"timeout": None,
126126
"max_retries": None,
127+
"raise_on_failure": False,
127128
},
128129
}
129130

@@ -142,6 +143,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
142143
embedding_separator=" | ",
143144
timeout=10.0,
144145
max_retries=2,
146+
raise_on_failure=True,
145147
)
146148
data = component.to_dict()
147149
assert data == {
@@ -161,6 +163,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
161163
"embedding_separator": " | ",
162164
"timeout": 10.0,
163165
"max_retries": 2,
166+
"raise_on_failure": True,
164167
},
165168
}
166169

@@ -236,6 +239,45 @@ def test_embed_batch_handles_exceptions_gracefully(self, caplog):
236239
assert len(caplog.records) == 1
237240
assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.records[0].msg
238241

242+
def test_run_handles_exceptions_gracefully(self, caplog):
243+
embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake_api_key"), batch_size=1)
244+
docs = [
245+
Document(content="I love cheese", meta={"topic": "Cuisine"}),
246+
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
247+
]
248+
249+
# Create a successful response for the second call
250+
successful_response = Mock()
251+
successful_response.data = [
252+
Mock(embedding=[0.4, 0.5, 0.6]) # Mock embedding for second doc
253+
]
254+
successful_response.model = "text-embedding-ada-002"
255+
successful_response.usage = {"prompt_tokens": 10, "total_tokens": 10}
256+
257+
with patch.object(
258+
embedder.client.embeddings,
259+
"create",
260+
side_effect=[
261+
APIError(message="Mocked error", request=Mock(), body=None), # First call fails
262+
successful_response, # Second call succeeds
263+
],
264+
):
265+
result = embedder.run(documents=docs)
266+
assert len(result["documents"]) == 2
267+
assert result["documents"][0].embedding is None
268+
assert result["documents"][1].embedding == [0.4, 0.5, 0.6]
269+
270+
def test_embed_batch_raises_exception_on_failure(self):
271+
embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake_api_key"), raise_on_failure=True)
272+
fake_texts_to_embed = {"1": "text1", "2": "text2"}
273+
with patch.object(
274+
embedder.client.embeddings,
275+
"create",
276+
side_effect=APIError(message="Mocked error", request=Mock(), body=None),
277+
):
278+
with pytest.raises(APIError, match="Mocked error"):
279+
embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=2)
280+
239281
@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OPENAI_API_KEY is not set")
240282
@pytest.mark.integration
241283
def test_run(self):

0 commit comments

Comments
 (0)