116 changes: 83 additions & 33 deletions nodes/src/nodes/milvus/milvus.py
@@ -47,6 +47,12 @@
from ai.common.store import DocumentStoreBase
from ai.common.config import Config

# Default batch size for bulk upsert operations
DEFAULT_BULK_INSERT_BATCH_SIZE = 50

# Default connection timeout in seconds
DEFAULT_TIMEOUT = 60


def _escape_milvus_str(value: object) -> str:
"""Escape a value for safe interpolation into a Milvus filter expression."""
@@ -89,6 +95,10 @@ def __init__(self, provider: str, connConfig: Dict[str, Any], bag: Dict[str, Any
self.renderChunkSize = config.get('renderChunkSize', self.renderChunkSize)
self.threshold_search = config.get('score', 0.5)

# Configurable timeout (seconds) and bulk insert batch size
self.timeout = max(int(config.get('timeout', DEFAULT_TIMEOUT)), 1)
self.bulkInsertBatchSize = max(int(config.get('bulkInsertBatchSize', DEFAULT_BULK_INSERT_BATCH_SIZE)), 1)

profile = config.get('mode')

# check if the similarity matches milvus configuration options
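
Note: the timeout and bulkInsertBatchSize options read at the top of this hunk are plain config keys with a floor of 1. A minimal illustrative sketch of a config fragment exercising them is below; the 'mode' key and its value are assumptions about the existing schema, not part of this change.

# Illustrative sketch only, not part of the diff.
config = {
    'mode': 'cloud',             # assumed profile value; anything but 'local' takes the https:// path
    'timeout': 120,              # connection / delete timeout in seconds
    'bulkInsertBatchSize': 200,  # chunks per bulk upsert call
}

# Both values are clamped to a minimum of 1, so 0 or a negative value falls back to 1.
timeout = max(int(config.get('timeout', 60)), 1)                  # -> 120
batch_size = max(int(config.get('bulkInsertBatchSize', 50)), 1)   # -> 200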
@@ -98,15 +108,16 @@ def __init__(self, provider: str, connConfig: Dict[str, Any], bag: Dict[str, Any
else:
raise Exception('The metric you provided in the config.json does not match required milvus configurations')

# Establish a connection // TODO: Revise alternative setup as this connection action is only necessary for the flush() method
if profile != 'local':
# Init the store
if self.host.startswith('https:') or self.host.startswith('http:'):
self.client = MilvusClient(uri=self.host, token=self.apikey, timeout=20)
# Establish a connection to the Milvus instance with configurable timeout
try:
if profile != 'local':
# Init the store (host was stripped of protocol at line 87, so always add https://)
self.client = MilvusClient(uri=f'https://{self.host}', token=self.apikey, timeout=self.timeout)
else:
self.client = MilvusClient(uri=f'https://{self.host}', token=self.apikey, timeout=20)
else:
self.client = MilvusClient(uri=f'http://{self.host}:{self.port}', timeout=20)
self.client = MilvusClient(uri=f'http://{self.host}:{self.port}', timeout=self.timeout)
except Exception as e:
self.client = None
raise Exception(f'Failed to connect to Milvus at {self.host}: {e}') from e

return
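
With the try/except above, a connection failure now surfaces as an exception from the constructor rather than leaving a half-initialised client behind. A minimal sketch of how a caller might handle that, assuming the store class is named MilvusStore and takes the arguments shown in the hunk header; the fallback behaviour is illustrative only.

# Illustrative sketch only: 'MilvusStore' is an assumed class name.
def build_store(provider, connConfig, bag):
    try:
        return MilvusStore(provider, connConfig, bag)
    except Exception as exc:
        # The constructor now raises eagerly, so callers can degrade gracefully
        print(f'Milvus unavailable, continuing without vector search: {exc}')
        return None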

@@ -250,7 +261,9 @@ def _convertToDocs(self, points: List[dict]) -> List[Doc]:
entity = point
score = 0
else:
# If we are return scaled scores, build it TODO: CHECK IF THIS IS ALSO THE CASE FOR MILVUS (-1 to 1 range) OR MIGHT IT BE CORRECTED ALREADY?
# For the COSINE metric Milvus returns the raw cosine similarity as the
# 'distance', so values fall in [-1, 1] with 1 meaning identical. We rescale
# to [0, 1] with 1 meaning most similar to stay consistent with the rest of
# the codebase score convention.
if self.similarity == 'COSINE':
score = (point.get('distance') + 1) / 2
else:
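
A quick sanity check of the COSINE rescaling above on a few sample similarity values (illustrative only, not part of the change):

for distance in (-1.0, 0.0, 0.5, 1.0):
    score = (distance + 1) / 2
    print(distance, '->', score)   # -1.0 -> 0.0, 0.0 -> 0.5, 0.5 -> 0.75, 1.0 -> 1.0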
@@ -267,7 +280,7 @@ def _convertToDocs(self, points: List[dict]) -> List[Doc]:
# Get the payload content and metadata
metadata = cast(DocMetadata, metadata)

# Create asearc new document
# Create a new document
doc = Doc(score=score, page_content=content, metadata=metadata)

# Append it to this documents chunks
Expand Down Expand Up @@ -419,7 +432,7 @@ def getPaths(self, parent: str | None = None, offset: int = 0, limit: int = 1000

def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
"""
Addsdocument chunks to the document store.
Add document chunks to the document store using batched bulk upsert.
"""
# If no documents present, get out
if not len(chunks):
@@ -437,7 +450,8 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
# Save this object id
objectIds[chunk.metadata.objectId] = True

# Erase all documents/chunks associated with that ObjectId in one operation (TODO: Start discussion about better use of upsert() method to increase performance)
# Erase all documents/chunks associated with that ObjectId in one operation
# so we can cleanly insert the new version
if len(objectIds.keys()):
filter_condition = f"meta['objectId'] in [{', '.join(json.dumps(k) for k in objectIds.keys())}]"
try:
@@ -446,8 +460,20 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
except Exception as e:
engLib.debug(f'Error deleting old chunks: {e}')

# TODO: Consider implementing a bulk insertion https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/utility/do_bulk_insert.md
# Disatvantage here is that is will require to reformat interation data into a JSON file format
# Collect chunks into batches for bulk upsert instead of one-at-a-time
batch: List[dict] = []

def flush_batch():
nonlocal batch
if not batch:
return
try:
self.client.upsert(collection_name=self.collection, data=batch)
engLib.debug(f'Milvus bulk upsert: {len(batch)} chunks inserted')
except Exception as e:
engLib.debug(f'Error during bulk upsert ({len(batch)} chunks): {e}')
raise
batch = []

# For each document
for chunk in chunks:
@@ -461,8 +487,14 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
# Append the points // create a unique identifier that fits into an int64 id field
tmp_struct = {'id': np.int64(((uuid.uuid1().time & 0x1FFFFFFFF) << 27) | random.getrandbits(27)), 'vector': embedding, 'content': chunk.page_content, 'meta': chunk.metadata}

# TODO: Consider printing out upsert count for debugging and imprement bulk insert
self.client.upsert(collection_name=self.collection, data=[tmp_struct])
batch.append(tmp_struct)

# Flush when batch reaches configured size
if len(batch) >= self.bulkInsertBatchSize:
flush_batch()

# Flush any remaining chunks
flush_batch()
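
Net effect of the batching above: chunks are grouped into upserts of at most bulkInsertBatchSize rows instead of one network round trip per chunk. A rough, illustrative sketch of the call-count arithmetic:

import math

n_chunks, batch_size = 120, 50           # batch_size = DEFAULT_BULK_INSERT_BATCH_SIZE
print(math.ceil(n_chunks / batch_size))  # -> 3 upsert calls (50 + 50 + 20) instead of 120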

def remove(self, objectIds: List[str]) -> None:
"""
@@ -480,13 +512,39 @@ def remove(self, objectIds: List[str]) -> None:
objectIdsJoint = ', '.join(f"'{_escape_milvus_str(o)}'" for o in objectIds)
must_conditions.append(f"meta['objectId'] in [{objectIdsJoint}]")

# TODO: Add time out
filter_expression = ' and '.join(must_conditions) if must_conditions else None
if filter_expression:
self.client.delete(collection_name=self.collection, filter=filter_expression)
try:
self.client.delete(collection_name=self.collection, filter=filter_expression, timeout=self.timeout)
except Exception as e:
engLib.debug(f'Error removing documents: {e}')
raise

return
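
The only behavioural change in remove() is that the delete now carries the configurable timeout. A hedged sketch of the resulting call shape; the URI, collection name and objectIds are placeholders, not values from this repository.

from pymilvus import MilvusClient

client = MilvusClient(uri='http://localhost:19530')      # placeholder local instance
client.delete(
    collection_name='documents',                          # placeholder collection name
    filter="meta['objectId'] in ['doc-1', 'doc-2']",      # placeholder objectIds
    timeout=60,                                           # seconds, mirrors self.timeout
)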

def _batchUpsertResults(self, results: List[dict], *, isDeleted: bool) -> None:
"""
Batch-update the isDeleted metadata field on a list of query results.

Collects results into batches of bulkInsertBatchSize and upserts them
together, avoiding the performance bottleneck of one-at-a-time upserts.
"""
batch: List[dict] = []

for result in results:
meta = result.get('meta', {})
meta['isDeleted'] = isDeleted
result['meta'] = meta
batch.append(result)

if len(batch) >= self.bulkInsertBatchSize:
self.client.upsert(collection_name=self.collection, data=batch)
batch = []

# Flush remaining
if batch:
self.client.upsert(collection_name=self.collection, data=batch)
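
_batchUpsertResults only flips the isDeleted flag inside meta and re-upserts the whole row, which is why the callers below now query with output_fields that include 'vector'. A tiny illustrative sketch of the per-row transformation, with a made-up row standing in for a query result:

result = {'id': 1, 'vector': [0.1, 0.2], 'content': 'chunk text', 'meta': {'objectId': 'doc-1'}}

meta = result.get('meta', {})
meta['isDeleted'] = True
result['meta'] = meta
print(result['meta'])   # -> {'objectId': 'doc-1', 'isDeleted': True}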

def markDeleted(self, objectIds: List[str]) -> None:
"""
Mark the set of documents with the given objectId as deleted.
@@ -509,14 +567,10 @@ def markDeleted(self, objectIds: List[str]) -> None:
if not filter_expression:
return

results = self.client.query(collection_name=self.collection, filter=filter_expression)
results = self.client.query(collection_name=self.collection, filter=filter_expression, output_fields=['id', 'vector', 'content', 'meta'])

# Update the 'isDeleted' field for each result -> TODO: Might there be a better way to do this? Looping over the
# vecotrs can be a performance bottleneck and additionally whats the oint if all entries will be deleled shortly after?
for result in results:
result['isDeleted'] = True
# Assuming there's a method to update the document in the client
self.client.upsert(collection_name=self.collection, data=result)
# Batch-update instead of one-at-a-time to avoid performance bottleneck
self._batchUpsertResults(results, isDeleted=True)
return

def markActive(self, objectIds: List[str]) -> None:
@@ -541,14 +595,10 @@ def markActive(self, objectIds: List[str]) -> None:
if not filter_expression:
return

results = self.client.query(collection_name=self.collection, filter=filter_expression)
results = self.client.query(collection_name=self.collection, filter=filter_expression, output_fields=['id', 'vector', 'content', 'meta'])

# Update the 'isDeleted' field for each result -> TODO: Might there be a better way to do this? Looping over the
# vecotrs can be a performance bottleneck and additionally whats the oint if all entries will be deleled shortly after?
for result in results:
result['isDeleted'] = False
# Assuming there's a method to update the document in the client
self.client.upsert(collection_name=self.collection, data=result)
# Batch-update instead of one-at-a-time to avoid performance bottleneck
self._batchUpsertResults(results, isDeleted=False)
return

def render(self, objectId: str, callback: Callable[[str], None]) -> None:
@@ -573,7 +623,7 @@ def render(self, objectId: str, callback: Callable[[str], None]) -> None:
# Build filter for getting a set of chunks within the offset range
must_condition = f"(meta['objectId'] == '{_escape_milvus_str(objectId)}') && ({offset - 1} < meta['chunkId'] < {offset + self.renderChunkSize})"

results = self.client.query(collection_name=self.collection, filter=must_condition)
results = self.client.query(collection_name=self.collection, filter=must_condition, output_fields=['meta', 'content'])

# Create a renderChunkSize array with empty
# entries. This will allow us to join even when
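
The only change in render() is the explicit output_fields list, which keeps the embedding vector out of the query response when it is not needed. A hedged sketch of the same pattern in isolation; the URI, collection name and objectId are placeholders.

from pymilvus import MilvusClient

client = MilvusClient(uri='http://localhost:19530')   # placeholder local instance
rows = client.query(
    collection_name='documents',                       # placeholder collection name
    filter="meta['objectId'] == 'doc-1'",              # placeholder objectId
    output_fields=['meta', 'content'],                 # omit 'vector' to shrink the payload
)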