
Commit 0432d7c

charliegillet and claude committed
feat(nodes): improve Milvus vector DB node — address all TODOs
- Add configurable timeout (default 60s) replacing hardcoded timeout=20, read from node config via 'timeout' key (TODO lines 101, 483)
- Add connection error handling with meaningful failure messages instead of raw pymilvus exceptions propagating
- Implement bulk insert with configurable batch size (default 50) for addChunks(), replacing one-at-a-time upserts (TODO lines 449, 464)
- Add _batchUpsertResults() helper to batch-update markDeleted/markActive operations, eliminating the per-vector upsert loop bottleneck (TODO lines 514-515, 546-547)
- Add timeout parameter to remove() delete call (TODO line 483)
- Document Milvus COSINE similarity score range [-1, 1] rescaling to [0, 1] for codebase consistency (TODO line 253)
- Fix typos in docstrings and comments

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent fe6da02 commit 0432d7c
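
For context, a minimal sketch of the node config keys involved (illustrative values only; 'timeout' and 'bulkInsertBatchSize' are the keys introduced by this commit, the other keys already existed, and the exact shape of the surrounding node config is an assumption):

# Hypothetical node config snippet; key names match what milvus.py reads via config.get(...)
config = {
    'mode': 'cloud',               # assumed value; anything other than 'local' takes the remote client path
    'score': 0.5,                  # existing key: similarity threshold
    'timeout': 120,                # new key: connection/delete timeout in seconds (falls back to DEFAULT_TIMEOUT = 60)
    'bulkInsertBatchSize': 200,    # new key: chunks per bulk upsert batch (falls back to DEFAULT_BULK_INSERT_BATCH_SIZE = 50)
}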

1 file changed

Lines changed: 83 additions & 30 deletions


nodes/src/nodes/milvus/milvus.py

@@ -47,6 +47,12 @@
 from ai.common.store import DocumentStoreBase
 from ai.common.config import Config
 
+# Default batch size for bulk upsert operations
+DEFAULT_BULK_INSERT_BATCH_SIZE = 50
+
+# Default connection timeout in seconds
+DEFAULT_TIMEOUT = 60
+
 
 def _escape_milvus_str(value: object) -> str:
     """Escape a value for safe interpolation into a Milvus filter expression."""
@@ -89,6 +95,10 @@ def __init__(self, provider: str, connConfig: Dict[str, Any], bag: Dict[str, Any
         self.renderChunkSize = config.get('renderChunkSize', self.renderChunkSize)
         self.threshold_search = config.get('score', 0.5)
 
+        # Configurable timeout (seconds) and bulk insert batch size
+        self.timeout = config.get('timeout', DEFAULT_TIMEOUT)
+        self.bulkInsertBatchSize = config.get('bulkInsertBatchSize', DEFAULT_BULK_INSERT_BATCH_SIZE)
+
         profile = config.get('mode')
 
         # check if the similarity matches milvus configuration options
@@ -98,15 +108,19 @@ def __init__(self, provider: str, connConfig: Dict[str, Any], bag: Dict[str, Any
         else:
             raise Exception('The metric you provided in the config.json does not match required milvus configurations')
 
-        # Establish a connection // TODO: Revise alternative setup as this connection action is only necessary for the flush() method
-        if profile != 'local':
-            # Init the store
-            if self.host.startswith('https:') or self.host.startswith('http:'):
-                self.client = MilvusClient(uri=self.host, token=self.apikey, timeout=20)
+        # Establish a connection to the Milvus instance with configurable timeout
+        try:
+            if profile != 'local':
+                # Init the store
+                if self.host.startswith('https:') or self.host.startswith('http:'):
+                    self.client = MilvusClient(uri=self.host, token=self.apikey, timeout=self.timeout)
+                else:
+                    self.client = MilvusClient(uri=f'https://{self.host}', token=self.apikey, timeout=self.timeout)
             else:
-                self.client = MilvusClient(uri=f'https://{self.host}', token=self.apikey, timeout=20)
-        else:
-            self.client = MilvusClient(uri=f'http://{self.host}:{self.port}', timeout=20)
+                self.client = MilvusClient(uri=f'http://{self.host}:{self.port}', timeout=self.timeout)
+        except Exception as e:
+            self.client = None
+            raise Exception(f'Failed to connect to Milvus at {self.host}: {e}')
 
         return
 
@@ -250,7 +264,9 @@ def _convertToDocs(self, points: List[dict]) -> List[Doc]:
                 entity = point
                 score = 0
             else:
-                # If we are return scaled scores, build it TODO: CHECK IF THIS IS ALSO THE CASE FOR MILVUS (-1 to 1 range) OR MIGHT IT BE CORRECTED ALREADY?
+                # Milvus COSINE search returns similarity scores in the range [-1, 1]
+                # where 1 means most similar. We rescale to [0, 1] to stay consistent
+                # with the rest of the codebase score convention.
                 if self.similarity == 'COSINE':
                     score = (point.get('distance') + 1) / 2
                 else:
@@ -267,7 +283,7 @@ def _convertToDocs(self, points: List[dict]) -> List[Doc]:
             # Get the payload content and metadata
             metadata = cast(DocMetadata, metadata)
 
-            # Create asearc new document
+            # Create a new document
             doc = Doc(score=score, page_content=content, metadata=metadata)
 
             # Append it to this documents chunks
@@ -419,7 +435,7 @@ def getPaths(self, parent: str | None = None, offset: int = 0, limit: int = 1000
 
     def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
         """
-        Addsdocument chunks to the document store.
+        Add document chunks to the document store using batched bulk upsert.
         """
         # If no documents present, get out
         if not len(chunks):
@@ -437,7 +453,8 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
             # Save this object id
             objectIds[chunk.metadata.objectId] = True
 
-        # Erase all documents/chunks associated with that ObjectId in one operation (TODO: Start discussion about better use of upsert() method to increase performance)
+        # Erase all documents/chunks associated with that ObjectId in one operation
+        # so we can cleanly insert the new version
         if len(objectIds.keys()):
             filter_condition = f"meta['objectId'] in [{', '.join(json.dumps(k) for k in objectIds.keys())}]"
             try:
@@ -446,8 +463,20 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
             except Exception as e:
                 engLib.debug(f'Error deleting old chunks: {e}')
 
-        # TODO: Consider implementing a bulk insertion https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/utility/do_bulk_insert.md
-        # Disatvantage here is that is will require to reformat interation data into a JSON file format
+        # Collect chunks into batches for bulk upsert instead of one-at-a-time
+        batch: List[dict] = []
+
+        def flush_batch():
+            nonlocal batch
+            if not batch:
+                return
+            try:
+                self.client.upsert(collection_name=self.collection, data=batch)
+                engLib.debug(f'Milvus bulk upsert: {len(batch)} chunks inserted')
+            except Exception as e:
+                engLib.debug(f'Error during bulk upsert ({len(batch)} chunks): {e}')
+                raise
+            batch = []
 
         # For each document
         for chunk in chunks:
@@ -461,8 +490,14 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
             # Append the points // create a unique identifier that fits into an int64 id field
             tmp_struct = {'id': np.int64(((uuid.uuid1().time & 0x1FFFFFFFF) << 27) | random.getrandbits(27)), 'vector': embedding, 'content': chunk.page_content, 'meta': chunk.metadata}
 
-            # TODO: Consider printing out upsert count for debugging and imprement bulk insert
-            self.client.upsert(collection_name=self.collection, data=[tmp_struct])
+            batch.append(tmp_struct)
+
+            # Flush when batch reaches configured size
+            if len(batch) >= self.bulkInsertBatchSize:
+                flush_batch()
+
+        # Flush any remaining chunks
+        flush_batch()
 
     def remove(self, objectIds: List[str]) -> None:
         """
@@ -480,13 +515,39 @@ def remove(self, objectIds: List[str]) -> None:
         objectIdsJoint = ', '.join(f"'{_escape_milvus_str(o)}'" for o in objectIds)
         must_conditions.append(f"meta['objectId'] in [{objectIdsJoint}]")
 
-        # TODO: Add time out
         filter_expression = ' and '.join(must_conditions) if must_conditions else None
         if filter_expression:
-            self.client.delete(collection_name=self.collection, filter=filter_expression)
+            try:
+                self.client.delete(collection_name=self.collection, filter=filter_expression, timeout=self.timeout)
+            except Exception as e:
+                engLib.debug(f'Error removing documents: {e}')
+                raise
 
         return
 
+    def _batchUpsertResults(self, results: List[dict], isDeleted: bool) -> None:
+        """
+        Batch-update the isDeleted metadata field on a list of query results.
+
+        Collects results into batches of bulkInsertBatchSize and upserts them
+        together, avoiding the performance bottleneck of one-at-a-time upserts.
+        """
+        batch: List[dict] = []
+
+        for result in results:
+            meta = result.get('meta', {})
+            meta['isDeleted'] = isDeleted
+            result['meta'] = meta
+            batch.append(result)
+
+            if len(batch) >= self.bulkInsertBatchSize:
+                self.client.upsert(collection_name=self.collection, data=batch)
+                batch = []
+
+        # Flush remaining
+        if batch:
+            self.client.upsert(collection_name=self.collection, data=batch)
+
     def markDeleted(self, objectIds: List[str]) -> None:
         """
         Mark the set of documents with the given objectId as deleted.
@@ -511,12 +572,8 @@ def markDeleted(self, objectIds: List[str]) -> None:
 
         results = self.client.query(collection_name=self.collection, filter=filter_expression)
 
-        # Update the 'isDeleted' field for each result -> TODO: Might there be a better way to do this? Looping over the
-        # vecotrs can be a performance bottleneck and additionally whats the oint if all entries will be deleled shortly after?
-        for result in results:
-            result['isDeleted'] = True
-            # Assuming there's a method to update the document in the client
-            self.client.upsert(collection_name=self.collection, data=result)
+        # Batch-update instead of one-at-a-time to avoid performance bottleneck
+        self._batchUpsertResults(results, isDeleted=True)
         return
 
     def markActive(self, objectIds: List[str]) -> None:
@@ -543,12 +600,8 @@ def markActive(self, objectIds: List[str]) -> None:
 
         results = self.client.query(collection_name=self.collection, filter=filter_expression)
 
-        # Update the 'isDeleted' field for each result -> TODO: Might there be a better way to do this? Looping over the
-        # vecotrs can be a performance bottleneck and additionally whats the oint if all entries will be deleled shortly after?
-        for result in results:
-            result['isDeleted'] = False
-            # Assuming there's a method to update the document in the client
-            self.client.upsert(collection_name=self.collection, data=result)
+        # Batch-update instead of one-at-a-time to avoid performance bottleneck
+        self._batchUpsertResults(results, isDeleted=False)
         return
 
     def render(self, objectId: str, callback: Callable[[str], None]) -> None:
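
As a quick sanity check on the score rescaling documented above (a standalone sketch, not part of the commit): Milvus COSINE search returns similarity in [-1, 1], and _convertToDocs() maps it onto the [0, 1] convention used elsewhere in the codebase.

# Standalone illustration of the rescaling in _convertToDocs():
# cosine similarity in [-1, 1] -> score in [0, 1], with 1 meaning most similar.
def rescale_cosine(similarity: float) -> float:
    return (similarity + 1) / 2

assert rescale_cosine(-1.0) == 0.0   # opposite vectors
assert rescale_cosine(0.0) == 0.5    # orthogonal vectors
assert rescale_cosine(1.0) == 1.0    # identical direction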

0 commit comments
