Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions midrasai/vectordb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@
_, __ = Qdrant, AsyncQdrant # Redundant alias for F401

__all__.extend(["Qdrant", "AsyncQdrant"])

# Expose the AstraDB backend only when the optional `astrapy` package is
# importable; otherwise the name is simply absent from this package.
if importlib.util.find_spec("astrapy") is not None:
    from midrasai.vectordb._astradb import AstraDB

    __all__.append("AstraDB")

    _ = AstraDB  # Redundant alias so linters do not flag the import (F401)
211 changes: 211 additions & 0 deletions midrasai/vectordb/_astradb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import datetime
import json
from typing import Any, Dict, List, Set, Tuple

from astrapy import DataAPIClient
from astrapy.authentication import TokenProvider
from astrapy.info import CollectionDefinition
from astrapy.constants import VectorMetric

from midrasai._abc import VectorDB
from midrasai.types import ColBERT, QueryResult


class AstraDB(VectorDB):
    """Vector store backed by an Astra DB collection, ranked with MaxSim.

    Every point is stored as one metadata document (``embedding_id == -1``)
    plus one document per embedding vector (``embedding_id >= 0``); all of
    them share the same ``doc_id``.

    The set of ``doc_id`` values belonging to each index is tracked only in
    memory on this instance (``index_doc_ids``), so ``search`` only considers
    points that were saved through this same instance.
    """

    # Dimension of every stored vector; passed to the collection definition.
    _VECTOR_DIMENSION = 128

    def __init__(
        self,
        access_token: str | TokenProvider,
        api_endpoint: str,
        **kwargs: Any,
    ):
        """
        Connect to an Astra DB database.

        Args:
            access_token: Token (or token provider) handed to ``DataAPIClient``
            api_endpoint: Endpoint handed to ``DataAPIClient.get_database``
            **kwargs: Accepted for interface compatibility; currently unused
        """
        self.client = DataAPIClient(access_token)
        self.database = self.client.get_database(api_endpoint)
        # index name -> doc_ids saved into that index through this instance
        self.index_doc_ids: Dict[str, Set[str]] = {}

    def create_index(self, name: str) -> bool:
        """
        Create an AstraDB collection under the default database

        Args:
            name: Name of the collection, must be used in save_points and search

        Returns:
            bool: True if the collection was created, False on any error
        """
        collection_definition = (
            CollectionDefinition.builder()
            .set_vector_dimension(self._VECTOR_DIMENSION)
            .set_vector_metric(VectorMetric.COSINE)
            .build()
        )
        try:
            created = bool(
                self.database.create_collection(name, definition=collection_definition)
            )
        except Exception as e:
            print(f"Error creating collection {name}: {e}")
            created = False

        # setdefault (not assignment) so that calling create_index again for an
        # index that is already tracked does not forget its saved doc_ids.
        self.index_doc_ids.setdefault(name, set())
        return created

    def create_point(
        self, id: int | str, embedding: ColBERT, data: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """
        Create the documents for a single point

        Args:
            id: Document ID
            embedding: ColBERT embedding (list of vectors)
            data: Metadata for the document

        Returns:
            List of documents to be inserted: the metadata document first,
            followed by one document per embedding vector
        """
        doc_id = str(id)

        # The metadata document carries no vector and is marked with
        # embedding_id = -1 so vector searches can filter it out.
        metadata_doc: Dict[str, Any] = {
            "_id": f"{doc_id}:-1",
            "doc_id": doc_id,
            "embedding_id": -1,
            "metadata": data,
        }

        # One document per embedding vector
        embedding_docs: List[Dict[str, Any]] = [
            {
                "_id": f"{doc_id}:{i}",
                "doc_id": doc_id,
                "embedding_id": i,
                "$vector": vector,
            }
            for i, vector in enumerate(embedding)
        ]

        return [metadata_doc, *embedding_docs]

    def save_points(
        self,
        index: str,
        points: List[List[Dict[str, Any]]],
    ) -> bool:
        """
        Save points to AstraDB

        Args:
            index: Name of the collection
            points: List of point structures generated by create_point

        Returns:
            bool: True if every document was inserted (or there was nothing
            to insert), False otherwise
        """
        # Tolerate indexes that were not created through this instance.
        tracked_ids = self.index_doc_ids.setdefault(index, set())
        documents_to_insert: List[Dict[str, Any]] = []

        for point in points:
            if not point:
                continue  # Nothing to store for an empty point structure
            doc_id = point[0].get("doc_id")
            if doc_id:
                tracked_ids.add(doc_id)
            documents_to_insert.extend(point)

        if not documents_to_insert:
            return True

        collection = self.database.get_collection(index)
        expected = len(documents_to_insert)
        try:
            result = collection.insert_many(documents_to_insert)
            # Guard against a falsy result before touching inserted_ids
            inserted = len(result.inserted_ids) if result else 0
        except Exception as e:
            print(f"Error inserting documents: {e}")
            return False

        if inserted != expected:
            print(f"Incomplete insert, {inserted} inserted of {expected}.")
            return False
        return True

    def delete_index(self, name: str) -> bool:
        """
        Drop an AstraDB collection and forget its tracked documents

        Args:
            name: Name of the collection

        Returns:
            bool: True if the collection was dropped, False on any error
        """
        try:
            self.database.drop_collection(name)
        except Exception as e:
            print(f"Error deleting collection {name}: {e}")
            return False

        # Only forget the tracked doc_ids once the drop has succeeded, so a
        # failed drop leaves the index searchable.
        self.index_doc_ids.pop(name, None)
        return True

    def search(
        self,
        index: str,
        query_embedding: ColBERT,
        quantity: int,
    ) -> List[QueryResult]:
        """
        Search for documents similar to a ColBERT query embedding using MaxSim

        Only documents saved through this instance are considered. One
        ``find_one`` request is issued per (document, query vector) pair.

        Args:
            index: Name of the collection
            query_embedding: ColBERT embedding (list of vectors) to search with
            quantity: Maximum number of results to return

        Returns:
            List of query results (id, score, metadata), best match first.
            Empty if ``quantity`` is not positive or no documents are tracked
            for ``index``.
        """
        if quantity <= 0:
            return []

        doc_ids = self.index_doc_ids.get(index)
        if not doc_ids:
            return []

        collection = self.database.get_collection(index)

        # MaxSim: for each query vector take the highest similarity among the
        # document's vectors, then sum those maxima into the document score.
        doc_scores: Dict[str, float] = {}
        for doc_id in doc_ids:
            score = 0.0
            for query_vector in query_embedding:
                # Sorting by $vector returns the most similar document first
                best_match = collection.find_one(
                    filter={
                        "$and": [
                            {"doc_id": doc_id},
                            {
                                "embedding_id": {"$gte": 0}
                            },  # Don't search metadata, which has embedding_id of -1
                        ]
                    },
                    include_similarity=True,
                    sort={"$vector": query_vector},
                )
                similarity = best_match.get("$similarity") if best_match else None
                if similarity is not None:
                    score += float(similarity)
            doc_scores[doc_id] = score

        # Sort documents by score and limit to the requested number of results
        sorted_docs: List[Tuple[str, float]] = sorted(
            doc_scores.items(), key=lambda x: x[1], reverse=True
        )[:quantity]

        # Fetch metadata for each selected doc_id in a single request
        metadata_docs = collection.find(
            filter={
                "doc_id": {"$in": [doc_id for doc_id, _ in sorted_docs]},
                "embedding_id": -1,
            },
            projection={"doc_id": True, "metadata": True},
        )
        metadata_dict = {
            doc.get("doc_id"): doc.get("metadata") for doc in metadata_docs
        }

        # Create query results, already in sorted order.
        # NOTE(review): int(doc_id) raises ValueError for non-numeric string
        # ids, although create_point accepts them - confirm QueryResult.id type.
        return [
            QueryResult(
                id=int(doc_id),
                score=similarity,
                # A tracked doc_id may have no metadata document (e.g. after a
                # failed insert); fall back to empty metadata instead of raising.
                data=metadata_dict.get(doc_id, {}),
            )
            for doc_id, similarity in sorted_docs
        ]
Loading