Skip to content

Commit 8c90f9d

Browse files
committed
Add missing license headers, auto call warmup, don't modify docs in place
1 parent 78ec9d3 commit 8c90f9d

8 files changed

Lines changed: 27 additions & 28 deletions

File tree

integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import os
66
import warnings
7+
from dataclasses import replace
78
from typing import Any, Optional, Union
89

910
from haystack import Document, component, default_from_dict, default_to_dict, logging
@@ -49,7 +50,7 @@ def __init__(
4950
embedding_separator: str = "\n",
5051
truncate: Optional[Union[EmbeddingTruncateMode, str]] = None,
5152
timeout: Optional[float] = None,
52-
):
53+
) -> None:
5354
"""
5455
Create a NvidiaTextEmbedder component.
5556
@@ -108,7 +109,7 @@ def __init__(
108109
def class_name(cls) -> str:
109110
return "NvidiaDocumentEmbedder"
110111

111-
def default_model(self):
112+
def default_model(self) -> None:
112113
"""Set default model in local NIM mode."""
113114
valid_models = [
114115
model.id for model in self.available_models if not model.base_model or model.base_model == model.id
@@ -129,7 +130,7 @@ def default_model(self):
129130
error_message = "No locally hosted model was found."
130131
raise ValueError(error_message)
131132

132-
def warm_up(self):
133+
def warm_up(self) -> None:
133134
"""
134135
Initializes the component.
135136
"""
@@ -267,7 +268,9 @@ def run(self, documents: list[Document]) -> dict[str, Union[list[Document], dict
267268

268269
texts_to_embed = self._prepare_texts_to_embed(documents)
269270
embeddings, metadata = self._embed_batch(texts_to_embed, self.batch_size)
271+
272+
new_documents = []
270273
for doc, emb in zip(documents, embeddings):
271-
doc.embedding = emb
274+
new_documents.append(replace(doc, embedding=emb))
272275

273-
return {"documents": documents, "meta": metadata}
276+
return {"documents": new_documents, "meta": metadata}

integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
timeout: Optional[float] = None,
6161
max_retries: Optional[int] = None,
6262
http_client_kwargs: Optional[dict[str, Any]] = None,
63-
):
63+
) -> None:
6464
"""
6565
Creates an instance of NvidiaChatGenerator.
6666

integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"),
5050
model_arguments: Optional[dict[str, Any]] = None,
5151
timeout: Optional[float] = None,
52-
):
52+
) -> None:
5353
"""
5454
Create a NvidiaGenerator component.
5555
@@ -90,7 +90,7 @@ def __init__(
9090
def class_name(cls) -> str:
9191
return "NvidiaGenerator"
9292

93-
def default_model(self):
93+
def default_model(self) -> None:
9494
"""Set default model in local NIM mode."""
9595
valid_models = [
9696
model.id for model in self.available_models if not model.base_model or model.base_model == model.id
@@ -111,7 +111,7 @@ def default_model(self):
111111
error_message = "No locally hosted model was found."
112112
raise ValueError(error_message)
113113

114-
def warm_up(self):
114+
def warm_up(self) -> None:
115115
"""
116116
Initializes the component.
117117
"""
@@ -183,8 +183,7 @@ def run(self, prompt: str) -> dict[str, Union[list[str], list[dict[str, Any]]]]:
183183
- `meta` - Metadata for each reply.
184184
"""
185185
if self.backend is None:
186-
msg = "The generation model has not been loaded. Call warm_up() before running."
187-
raise RuntimeError(msg)
186+
self.warm_up()
188187

189188
assert self.backend is not None
190189
replies, meta = self.backend.generate(prompt=prompt)

integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(
5757
meta_fields_to_embed: Optional[list[str]] = None,
5858
embedding_separator: str = "\n",
5959
timeout: Optional[float] = None,
60-
):
60+
) -> None:
6161
"""
6262
Create a NvidiaRanker component.
6363
@@ -155,7 +155,7 @@ def from_dict(cls, data: dict[str, Any]) -> "NvidiaRanker":
155155
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
156156
return default_from_dict(cls, data)
157157

158-
def warm_up(self):
158+
def warm_up(self) -> None:
159159
"""
160160
Initialize the ranker.
161161
@@ -192,27 +192,21 @@ def _prepare_documents_to_embed(self, documents: list[Document]) -> list[str]:
192192
return document_texts
193193

194194
@component.output_types(documents=list[Document])
195-
def run(
196-
self,
197-
query: str,
198-
documents: list[Document],
199-
top_k: Optional[int] = None,
200-
) -> dict[str, list[Document]]:
195+
def run(self, query: str, documents: list[Document], top_k: Optional[int] = None) -> dict[str, list[Document]]:
201196
"""
202197
Rank a list of documents based on a given query.
203198
204199
:param query: The query to rank the documents against.
205200
:param documents: The list of documents to rank.
206201
:param top_k: The number of documents to return.
207202
208-
:raises RuntimeError: If the ranker has not been loaded.
209203
:raises TypeError: If the arguments are of the wrong type.
210204
211205
:returns: A dictionary containing the ranked documents.
212206
"""
213207
if not self._initialized:
214-
msg = "The ranker has not been loaded. Please call warm_up() before running."
215-
raise RuntimeError(msg)
208+
self.warm_up()
209+
216210
if not isinstance(query, str):
217211
msg = "NvidiaRanker expects the `query` parameter to be a string."
218212
raise TypeError(msg)

integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
15
from enum import Enum
26

37

integrations/nvidia/src/haystack_integrations/utils/nvidia/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
15
from dataclasses import dataclass
26
from typing import Literal, Optional, Union
37

integrations/nvidia/tests/test_nim_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
45
import json
56
from unittest.mock import patch
67

integrations/nvidia/tests/test_ranker.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,6 @@ def test_init_pass_wo_api_key_w_api_url(self):
4646
client = NvidiaRanker(api_url=url)
4747
assert client.api_url == url
4848

49-
def test_warm_up_required(self):
50-
client = NvidiaRanker()
51-
with pytest.raises(RuntimeError) as e:
52-
client.run("query", [Document(content="doc")])
53-
assert "not been loaded" in str(e.value)
54-
5549
@pytest.mark.parametrize(
5650
"truncate",
5751
[

0 commit comments

Comments
 (0)