Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions haystack/components/embedders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"openai_text_embedder": ["OpenAITextEmbedder"],
"sentence_transformers_document_embedder": ["SentenceTransformersDocumentEmbedder"],
"sentence_transformers_text_embedder": ["SentenceTransformersTextEmbedder"],
"watsonx_document_embedder": ["WatsonXDocumentEmbedder"],
"watsonx_text_embedder": ["WatsonXTextEmbedder"],
}

if TYPE_CHECKING:
Expand All @@ -27,6 +29,8 @@
from .openai_text_embedder import OpenAITextEmbedder
from .sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from .sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
from .watsonx_document_embedder import WatsonXDocumentEmbedder
from .watsonx_text_embedder import WatsonXTextEmbedder

else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
191 changes: 191 additions & 0 deletions haystack/components/embedders/watsonx_document_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# SPDX-FileCopyrightText: 2023-present IBM Corporation
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional

from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.foundation_models import Embeddings

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace


@component
class WatsonXDocumentEmbedder:
    """
    Computes document embeddings using IBM watsonx.ai models.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.components.embedders import WatsonXDocumentEmbedder

    documents = [Document(content="I love pizza!"), Document(content="Pasta is great too")]

    document_embedder = WatsonXDocumentEmbedder(
        model="ibm/slate-30m-english-rtrvr",
        api_key=Secret.from_env_var("WATSONX_API_KEY"),
        url="https://us-south.ml.cloud.ibm.com",
        project_id="your-project-id"
    )

    result = document_embedder.run(documents=documents)
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(
        self,
        model: str = "ibm/slate-30m-english-rtrvr",
        api_key: Secret = Secret.from_env_var("WATSONX_API_KEY"),
        url: str = "https://us-south.ml.cloud.ibm.com",
        project_id: Optional[str] = None,
        space_id: Optional[str] = None,
        truncate_input_tokens: Optional[int] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 1000,
        concurrency_limit: int = 5,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
    ):
        """
        Creates a WatsonXDocumentEmbedder component.

        :param model:
            The name of the model to use for calculating embeddings.
            Default is "ibm/slate-30m-english-rtrvr".
        :param api_key:
            The WATSONX API key. Can be set via environment variable WATSONX_API_KEY.
        :param url:
            The WATSONX URL for the watsonx.ai service.
            Default is "https://us-south.ml.cloud.ibm.com".
        :param project_id:
            The ID of the Watson Studio project. Either project_id or space_id must be provided.
        :param space_id:
            The ID of the Watson Studio space. Either project_id or space_id must be provided.
        :param truncate_input_tokens:
            Maximum number of tokens to use from the input text.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed in one API call. Default is 1000.
        :param concurrency_limit:
            Number of parallel requests to make. Default is 5.
        :param timeout:
            Timeout for API requests in seconds.
        :param max_retries:
            Maximum number of retries for API requests. Defaults to 10 when not set.
        :raises ValueError:
            If neither project_id nor space_id is provided.
        """
        if not project_id and not space_id:
            raise ValueError("Either project_id or space_id must be provided")

        self.model = model
        self.api_key = api_key
        self.url = url
        self.project_id = project_id
        self.space_id = space_id
        self.truncate_input_tokens = truncate_input_tokens
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.concurrency_limit = concurrency_limit
        self.timeout = timeout
        self.max_retries = max_retries

        # Initialize the embeddings client
        credentials = Credentials(api_key=api_key.resolve_value(), url=url)

        params = {}
        if truncate_input_tokens is not None:
            params["truncate_input_tokens"] = truncate_input_tokens

        # NOTE(review): `timeout` is stored and serialized above but never forwarded
        # to the Embeddings client — confirm whether ibm_watsonx_ai supports it and
        # wire it through if so.
        self.embedder = Embeddings(
            model_id=model,
            credentials=credentials,
            project_id=project_id,
            space_id=space_id,
            params=params or None,
            batch_size=batch_size,
            concurrency_limit=concurrency_limit,
            # An explicit None check: `max_retries or 10` would silently replace a
            # deliberate max_retries=0 (no retries) with 10.
            max_retries=10 if max_retries is None else max_retries,
        )

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns: Dictionary with serialized data, including the api_key as a Secret dict.
        """
        return default_to_dict(
            self,
            model=self.model,
            api_key=self.api_key.to_dict(),
            url=self.url,
            project_id=self.project_id,
            space_id=self.space_id,
            truncate_input_tokens=self.truncate_input_tokens,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            concurrency_limit=self.concurrency_limit,
            timeout=self.timeout,
            max_retries=self.max_retries,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WatsonXDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data: Dictionary produced by `to_dict`.
        :returns: The deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
        return default_from_dict(cls, data)

    def _prepare_text(self, text: str) -> str:
        """
        Prepares text for embedding by adding prefix and suffix.
        """
        return self.prefix + text + self.suffix

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    def run(self, documents: List[Document]):
        """
        Embeds a list of documents.

        :param documents:
            A list of documents to embed.
        :returns:
            A dictionary with:
            - 'documents': List of Documents with embeddings added
            - 'meta': Information about the model usage
        :raises TypeError:
            If the input is not a list of Documents.
        """
        # Parenthesized for clarity; semantics unchanged from the precedence-dependent
        # `not A or B and C` form.
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError(
                "WatsonXDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the WatsonXTextEmbedder."
            )

        # Documents with no content are embedded as prefix + "" + suffix.
        texts_to_embed = [self._prepare_text(doc.content or "") for doc in documents]
        embeddings = self.embedder.embed_documents(texts_to_embed)

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {
            "documents": documents,
            "meta": {
                "model": self.model,
                "truncate_input_tokens": self.truncate_input_tokens,
                "batch_size": self.batch_size,
            },
        }
167 changes: 167 additions & 0 deletions haystack/components/embedders/watsonx_text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import os
from typing import Any, Dict, List, Optional

from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.foundation_models import Embeddings
from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes

from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace


@component
class WatsonXTextEmbedder:
    """
    Embeds strings using IBM watsonx.ai foundation models.

    You can use it to embed user query and send it to an embedding Retriever.

    ### Usage example

    ```python
    from haystack.components.embedders import WatsonXTextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = WatsonXTextEmbedder(
        model="ibm/slate-30m-english-rtrvr",
        api_key=Secret.from_env_var("WATSONX_API_KEY"),
        url="https://us-south.ml.cloud.ibm.com",
        project_id="your-project-id"
    )

    print(text_embedder.run(text_to_embed))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    # 'meta': {'model': 'ibm/slate-30m-english-rtrvr',
    #          'truncated_input_tokens': 3}}
    ```
    """

    def __init__(
        self,
        model: str = "ibm/slate-30m-english-rtrvr",
        api_key: Secret = Secret.from_env_var("WATSONX_API_KEY"),
        url: str = "https://us-south.ml.cloud.ibm.com",
        project_id: Optional[str] = None,
        space_id: Optional[str] = None,
        truncate_input_tokens: Optional[int] = None,
        prefix: str = "",
        suffix: str = "",
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
    ):
        """
        Creates an WatsonXTextEmbedder component.

        :param model:
            The name of the IBM watsonx model to use for calculating embeddings.
            Default is "ibm/slate-30m-english-rtrvr".
        :param api_key:
            The WATSONX API key. Can be set via environment variable WATSONX_API_KEY.
        :param url:
            The WATSONX URL for the watsonx.ai service.
            Default is "https://us-south.ml.cloud.ibm.com".
        :param project_id:
            The ID of the Watson Studio project. Either project_id or space_id must be provided.
        :param space_id:
            The ID of the Watson Studio space. Either project_id or space_id must be provided.
        :param truncate_input_tokens:
            Maximum number of tokens to use from the input text.
        :param prefix:
            A string to add at the beginning of each text to embed.
        :param suffix:
            A string to add at the end of each text to embed.
        :param timeout:
            Timeout for API requests in seconds.
        :param max_retries:
            Maximum number of retries for API requests. Defaults to 10 when not set.
        :raises ValueError:
            If neither project_id nor space_id is provided.
        """
        if not project_id and not space_id:
            raise ValueError("Either project_id or space_id must be provided")

        self.model = model
        self.api_key = api_key
        self.url = url
        self.project_id = project_id
        self.space_id = space_id
        self.truncate_input_tokens = truncate_input_tokens
        self.prefix = prefix
        self.suffix = suffix
        self.timeout = timeout
        self.max_retries = max_retries

        # Initialize the embeddings client
        credentials = Credentials(api_key=api_key.resolve_value(), url=url)

        params = {}
        if truncate_input_tokens is not None:
            params["truncate_input_tokens"] = truncate_input_tokens

        # NOTE(review): `timeout` is stored and serialized above but never forwarded
        # to the Embeddings client — confirm whether ibm_watsonx_ai supports it and
        # wire it through if so.
        self.embedder = Embeddings(
            model_id=model,
            credentials=credentials,
            project_id=project_id,
            space_id=space_id,
            params=params or None,
            # Previously max_retries was accepted and serialized but never used;
            # forward it like WatsonXDocumentEmbedder does, with an explicit None
            # check so max_retries=0 (no retries) is honored.
            max_retries=10 if max_retries is None else max_retries,
        )

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns: Dictionary with serialized data, including the api_key as a Secret dict.
        """
        return default_to_dict(
            self,
            model=self.model,
            api_key=self.api_key.to_dict(),
            url=self.url,
            project_id=self.project_id,
            space_id=self.space_id,
            truncate_input_tokens=self.truncate_input_tokens,
            prefix=self.prefix,
            suffix=self.suffix,
            timeout=self.timeout,
            max_retries=self.max_retries,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WatsonXTextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data: Dictionary produced by `to_dict`.
        :returns: The deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
        return default_from_dict(cls, data)

    def _prepare_input(self, text: str) -> str:
        """
        Validates the input is a string and wraps it with prefix/suffix.
        """
        if not isinstance(text, str):
            raise TypeError(
                "WatsonXTextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use the WatsonXDocumentEmbedder."
            )
        return self.prefix + text + self.suffix

    def _prepare_output(self, embedding: List[float]) -> Dict[str, Any]:
        """
        Packages the embedding together with model usage metadata.
        """
        return {
            "embedding": embedding,
            "meta": {"model": self.model, "truncated_input_tokens": self.truncate_input_tokens},
        }

    @component.output_types(embedding=List[float], meta=Dict[str, Any])
    def run(self, text: str):
        """
        Embeds a single string.

        :param text: Text to embed.
        :returns: A dictionary with:
            - 'embedding': The embedding of the input text
            - 'meta': Information about the model usage
        :raises TypeError: If the input is not a string.
        """
        text_to_embed = self._prepare_input(text=text)
        embedding = self.embedder.embed_query(text_to_embed)
        return self._prepare_output(embedding)
2 changes: 2 additions & 0 deletions haystack/components/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"hugging_face_local": ["HuggingFaceLocalGenerator"],
"hugging_face_api": ["HuggingFaceAPIGenerator"],
"openai_dalle": ["DALLEImageGenerator"],
"watsonx": ["WatsonxGenerator"],
}

if TYPE_CHECKING:
Expand All @@ -21,6 +22,7 @@
from .hugging_face_local import HuggingFaceLocalGenerator
from .openai import OpenAIGenerator
from .openai_dalle import DALLEImageGenerator
from .watsonx import WatsonxGenerator

else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
Loading