
Commit 5af0a4b

charliegilletclaude authored and committed
feat(nodes): improve Milvus vector DB node — address all TODOs (#562)
* feat(vscode): improve stop button feedback in Pipeline Observability screen

  Handle TASK_STATE.STOPPING in the control button to show "Stopping..." with a disabled state and distinct orange styling, preventing duplicate clicks and giving immediate visual feedback during pipeline shutdown.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(nodes): improve Milvus vector DB node — address all TODOs

  - Add configurable timeout (default 60s) replacing hardcoded timeout=20, read from node config via 'timeout' key (TODO lines 101, 483)
  - Add connection error handling with meaningful failure messages instead of raw pymilvus exceptions propagating
  - Implement bulk insert with configurable batch size (default 50) for addChunks(), replacing one-at-a-time upserts (TODO lines 449, 464)
  - Add _batchUpsertResults() helper to batch-update markDeleted/markActive operations, eliminating the per-vector upsert loop bottleneck (TODO lines 514-515, 546-547)
  - Add timeout parameter to remove() delete call (TODO line 483)
  - Document Milvus COSINE distance score range [0,2] rescaling to [0,1] for codebase consistency (TODO line 253)
  - Fix typos in docstrings and comments

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: remove unrelated PageStatus "Stopping..." changes from Milvus PR

  The PageStatus changes belong in a separate PR (#549) and were accidentally included here.
* fix(nodes): address CodeRabbit feedback on Milvus PR #562

  - Remove dead protocol check (host already stripped of scheme at init)
  - Add exception chaining with 'from e' for connection errors (B904)
  - Add output_fields to markDeleted/markActive queries to prevent data loss during upsert (was only returning primary key)
  - Add output_fields to renderChunks query to prevent KeyError on content/chunkId access

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(nodes): address remaining review feedback on Milvus PR #562

  - Remove unrelated PageStatus changes that were re-introduced
  - Validate timeout and bulkInsertBatchSize to ensure positive values
  - Make isDeleted a keyword-only argument in _batchUpsertResults

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3aeda91 commit 5af0a4b

1 file changed

Lines changed: 83 additions & 33 deletions

File tree

nodes/src/nodes/milvus/milvus.py

@@ -47,6 +47,12 @@
 from ai.common.store import DocumentStoreBase
 from ai.common.config import Config
 
+# Default batch size for bulk upsert operations
+DEFAULT_BULK_INSERT_BATCH_SIZE = 50
+
+# Default connection timeout in seconds
+DEFAULT_TIMEOUT = 60
+
 
 def _escape_milvus_str(value: object) -> str:
     """Escape a value for safe interpolation into a Milvus filter expression."""
@@ -89,6 +95,10 @@ def __init__(self, provider: str, connConfig: Dict[str, Any], bag: Dict[str, Any
         self.renderChunkSize = config.get('renderChunkSize', self.renderChunkSize)
         self.threshold_search = config.get('score', 0.5)
 
+        # Configurable timeout (seconds) and bulk insert batch size
+        self.timeout = max(int(config.get('timeout', DEFAULT_TIMEOUT)), 1)
+        self.bulkInsertBatchSize = max(int(config.get('bulkInsertBatchSize', DEFAULT_BULK_INSERT_BATCH_SIZE)), 1)
+
         profile = config.get('mode')
 
         # check if the similarity matches milvus configuration options
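The positive-integer clamping added in this hunk can be sketched in isolation; `config` below is a hypothetical node-config dict standing in for the parsed config.json:

```python
DEFAULT_TIMEOUT = 60
DEFAULT_BULK_INSERT_BATCH_SIZE = 50

def read_limits(config: dict) -> tuple[int, int]:
    # Fall back to defaults and clamp to at least 1, so a zero or negative
    # value in the config cannot disable timeouts or batching
    timeout = max(int(config.get('timeout', DEFAULT_TIMEOUT)), 1)
    batch_size = max(int(config.get('bulkInsertBatchSize', DEFAULT_BULK_INSERT_BATCH_SIZE)), 1)
    return timeout, batch_size
```

Note that `int(...)` also coerces string values such as `'120'`, which is convenient when the config comes from JSON edited by hand.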
@@ -98,15 +108,16 @@ def __init__(self, provider: str, connConfig: Dict[str, Any], bag: Dict[str, Any
         else:
             raise Exception('The metric you provided in the config.json does not match required milvus configurations')
 
-        # Establish a connection // TODO: Revise alternative setup as this connection action is only necessary for the flush() method
-        if profile != 'local':
-            # Init the store
-            if self.host.startswith('https:') or self.host.startswith('http:'):
-                self.client = MilvusClient(uri=self.host, token=self.apikey, timeout=20)
+        # Establish a connection to the Milvus instance with configurable timeout
+        try:
+            if profile != 'local':
+                # Init the store (host was stripped of protocol at line 87, so always add https://)
+                self.client = MilvusClient(uri=f'https://{self.host}', token=self.apikey, timeout=self.timeout)
             else:
-                self.client = MilvusClient(uri=f'https://{self.host}', token=self.apikey, timeout=20)
-        else:
-            self.client = MilvusClient(uri=f'http://{self.host}:{self.port}', timeout=20)
+                self.client = MilvusClient(uri=f'http://{self.host}:{self.port}', timeout=self.timeout)
+        except Exception as e:
+            self.client = None
+            raise Exception(f'Failed to connect to Milvus at {self.host}: {e}') from e
 
         return
 
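The wrap-and-chain pattern used for connection failures can be shown generically; `make_client` below is a hypothetical stand-in for the `MilvusClient` constructor:

```python
def connect(make_client, host: str):
    """Wrap low-level connection errors in one meaningful exception."""
    try:
        return make_client(host)
    except Exception as e:
        # 'from e' chains the original exception so its traceback and type
        # survive (this is what ruff rule B904 asks for)
        raise Exception(f'Failed to connect to Milvus at {host}: {e}') from e
```

The chained original is available on the raised exception's `__cause__`, so callers and log aggregators still see the underlying pymilvus error.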
@@ -250,7 +261,9 @@ def _convertToDocs(self, points: List[dict]) -> List[Doc]:
             entity = point
             score = 0
         else:
-            # If we are return scaled scores, build it TODO: CHECK IF THIS IS ALSO THE CASE FOR MILVUS (-1 to 1 range) OR MIGHT IT BE CORRECTED ALREADY?
+            # Milvus COSINE distance returns values in the range [0, 2] where 0 is
+            # identical. We rescale to [0, 1] with 1 meaning most similar to stay
+            # consistent with the rest of the codebase score convention.
             if self.similarity == 'COSINE':
                 score = (point.get('distance') + 1) / 2
             else:
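A standalone check of the rescaling formula above. Worth noting: `(d + 1) / 2` maps a raw value in [-1, 1] onto [0, 1], whereas a value in [0, 2] would land in [0.5, 1.5], so the raw score range the deployed Milvus version actually returns is worth verifying against this comment:

```python
def rescale_cosine(distance: float) -> float:
    # Mirrors the node's formula: shift and halve so 1.0 means most similar
    return (distance + 1) / 2
```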
@@ -267,7 +280,7 @@ def _convertToDocs(self, points: List[dict]) -> List[Doc]:
             # Get the payload content and metadata
             metadata = cast(DocMetadata, metadata)
 
-            # Create asearc new document
+            # Create a new document
             doc = Doc(score=score, page_content=content, metadata=metadata)
 
             # Append it to this documents chunks
@@ -419,7 +432,7 @@ def getPaths(self, parent: str | None = None, offset: int = 0, limit: int = 1000
 
     def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
         """
-        Addsdocument chunks to the document store.
+        Add document chunks to the document store using batched bulk upsert.
         """
         # If no documents present, get out
         if not len(chunks):
@@ -437,7 +450,8 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
             # Save this object id
             objectIds[chunk.metadata.objectId] = True
 
-        # Erase all documents/chunks associated with that ObjectId in one operation (TODO: Start discussion about better use of upsert() method to increase performance)
+        # Erase all documents/chunks associated with that ObjectId in one operation
+        # so we can cleanly insert the new version
         if len(objectIds.keys()):
            filter_condition = f"meta['objectId'] in [{', '.join(json.dumps(k) for k in objectIds.keys())}]"
            try:
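The deletion filter built in this hunk can be exercised on its own; `json.dumps` double-quotes each id, which Milvus string filter expressions accept:

```python
import json

# Illustrative object ids; the node collects these from chunk metadata
objectIds = {'doc-1': True, 'doc-2': True}

# Same expression as in addChunks: a Milvus 'in' filter over objectId
filter_condition = f"meta['objectId'] in [{', '.join(json.dumps(k) for k in objectIds.keys())}]"
print(filter_condition)  # meta['objectId'] in ["doc-1", "doc-2"]
```

Using `json.dumps` here also escapes any embedded quotes in an id, which is why it is preferred over naive f-string quoting.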
@@ -446,8 +460,20 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
             except Exception as e:
                 engLib.debug(f'Error deleting old chunks: {e}')
 
-        # TODO: Consider implementing a bulk insertion https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/utility/do_bulk_insert.md
-        # Disatvantage here is that is will require to reformat interation data into a JSON file format
+        # Collect chunks into batches for bulk upsert instead of one-at-a-time
+        batch: List[dict] = []
+
+        def flush_batch():
+            nonlocal batch
+            if not batch:
+                return
+            try:
+                self.client.upsert(collection_name=self.collection, data=batch)
+                engLib.debug(f'Milvus bulk upsert: {len(batch)} chunks inserted')
+            except Exception as e:
+                engLib.debug(f'Error during bulk upsert ({len(batch)} chunks): {e}')
+                raise
+            batch = []
 
         # For each document
         for chunk in chunks:
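The `flush_batch` closure above implements a standard chunked-write pattern: accumulate, flush when full, flush the remainder at the end. A generic version (with a hypothetical `upsert` callable in place of the Milvus client) looks like:

```python
from typing import Callable, List

def batched_upsert(items: List[dict], batch_size: int,
                   upsert: Callable[[List[dict]], None]) -> int:
    """Write items in batches of batch_size; returns the total written."""
    batch: List[dict] = []
    flushed = 0
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:  # flush each full batch
            upsert(batch)
            flushed += len(batch)
            batch = []
    if batch:  # flush the remainder
        upsert(batch)
        flushed += len(batch)
    return flushed
```

With 7 items and a batch size of 3, this issues three writes of sizes 3, 3, and 1, versus seven one-row upserts in the old loop.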
@@ -461,8 +487,14 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
             # Append the points // create a unique identifier that fits into an int64 id field
             tmp_struct = {'id': np.int64(((uuid.uuid1().time & 0x1FFFFFFFF) << 27) | random.getrandbits(27)), 'vector': embedding, 'content': chunk.page_content, 'meta': chunk.metadata}
 
-            # TODO: Consider printing out upsert count for debugging and imprement bulk insert
-            self.client.upsert(collection_name=self.collection, data=[tmp_struct])
+            batch.append(tmp_struct)
+
+            # Flush when batch reaches configured size
+            if len(batch) >= self.bulkInsertBatchSize:
+                flush_batch()
+
+        # Flush any remaining chunks
+        flush_batch()
 
     def remove(self, objectIds: List[str]) -> None:
         """
@@ -480,13 +512,39 @@ def remove(self, objectIds: List[str]) -> None:
         objectIdsJoint = ', '.join(f"'{_escape_milvus_str(o)}'" for o in objectIds)
         must_conditions.append(f"meta['objectId'] in [{objectIdsJoint}]")
 
-        # TODO: Add time out
         filter_expression = ' and '.join(must_conditions) if must_conditions else None
         if filter_expression:
-            self.client.delete(collection_name=self.collection, filter=filter_expression)
+            try:
+                self.client.delete(collection_name=self.collection, filter=filter_expression, timeout=self.timeout)
+            except Exception as e:
+                engLib.debug(f'Error removing documents: {e}')
+                raise
 
         return
 
+    def _batchUpsertResults(self, results: List[dict], *, isDeleted: bool) -> None:
+        """
+        Batch-update the isDeleted metadata field on a list of query results.
+
+        Collects results into batches of bulkInsertBatchSize and upserts them
+        together, avoiding the performance bottleneck of one-at-a-time upserts.
+        """
+        batch: List[dict] = []
+
+        for result in results:
+            meta = result.get('meta', {})
+            meta['isDeleted'] = isDeleted
+            result['meta'] = meta
+            batch.append(result)
+
+            if len(batch) >= self.bulkInsertBatchSize:
+                self.client.upsert(collection_name=self.collection, data=batch)
+                batch = []
+
+        # Flush remaining
+        if batch:
+            self.client.upsert(collection_name=self.collection, data=batch)
+
     def markDeleted(self, objectIds: List[str]) -> None:
         """
         Mark the set of documents with the given objectId as deleted.
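Why the later hunks add `output_fields=['id', 'vector', 'content', 'meta']` to the queries feeding this helper: an upsert replaces the whole entity, so the query must return every field or the rewrite silently drops vectors and content. The meta toggle itself, isolated from the client (a sketch, not the node's method):

```python
def toggle_deleted(results: list[dict], *, is_deleted: bool) -> list[dict]:
    # Each result must already carry id/vector/content, otherwise the
    # subsequent upsert would overwrite the entity with only the meta field
    for result in results:
        meta = result.get('meta', {})
        meta['isDeleted'] = is_deleted
        result['meta'] = meta
    return results
```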
@@ -509,14 +567,10 @@ def markDeleted(self, objectIds: List[str]) -> None:
         if not filter_expression:
             return
 
-        results = self.client.query(collection_name=self.collection, filter=filter_expression)
+        results = self.client.query(collection_name=self.collection, filter=filter_expression, output_fields=['id', 'vector', 'content', 'meta'])
 
-        # Update the 'isDeleted' field for each result -> TODO: Might there be a better way to do this? Looping over the
-        # vecotrs can be a performance bottleneck and additionally whats the oint if all entries will be deleled shortly after?
-        for result in results:
-            result['isDeleted'] = True
-            # Assuming there's a method to update the document in the client
-            self.client.upsert(collection_name=self.collection, data=result)
+        # Batch-update instead of one-at-a-time to avoid performance bottleneck
+        self._batchUpsertResults(results, isDeleted=True)
         return
 
     def markActive(self, objectIds: List[str]) -> None:
@@ -541,14 +595,10 @@ def markActive(self, objectIds: List[str]) -> None:
         if not filter_expression:
             return
 
-        results = self.client.query(collection_name=self.collection, filter=filter_expression)
+        results = self.client.query(collection_name=self.collection, filter=filter_expression, output_fields=['id', 'vector', 'content', 'meta'])
 
-        # Update the 'isDeleted' field for each result -> TODO: Might there be a better way to do this? Looping over the
-        # vecotrs can be a performance bottleneck and additionally whats the oint if all entries will be deleled shortly after?
-        for result in results:
-            result['isDeleted'] = False
-            # Assuming there's a method to update the document in the client
-            self.client.upsert(collection_name=self.collection, data=result)
+        # Batch-update instead of one-at-a-time to avoid performance bottleneck
+        self._batchUpsertResults(results, isDeleted=False)
         return
 
     def render(self, objectId: str, callback: Callable[[str], None]) -> None:
@@ -573,7 +623,7 @@ def render(self, objectId: str, callback: Callable[[str], None]) -> None:
         # Build filter for getting a set of chunks within the offset range
         must_condition = f"(meta['objectId'] == '{_escape_milvus_str(objectId)}') && ({offset - 1} < meta['chunkId'] < {offset + self.renderChunkSize})"
 
-        results = self.client.query(collection_name=self.collection, filter=must_condition)
+        results = self.client.query(collection_name=self.collection, filter=must_condition, output_fields=['meta', 'content'])
 
         # Create a renderChunkSize array with empty
         # entries. This will allow us to join even when
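The window filter in `render` can be checked standalone. The values below are illustrative, and the `_escape_milvus_str` call from the node is omitted for brevity:

```python
# Illustrative inputs; in the node these come from render() arguments and config
objectId = 'doc-1'
offset = 1
renderChunkSize = 10

# Same shape as the node's filter: exact objectId match plus a chunkId window
must_condition = (
    f"(meta['objectId'] == '{objectId}') && "
    f"({offset - 1} < meta['chunkId'] < {offset + renderChunkSize})"
)
print(must_condition)  # (meta['objectId'] == 'doc-1') && (0 < meta['chunkId'] < 11)
```

The open bounds `offset - 1 < chunkId < offset + renderChunkSize` select exactly `renderChunkSize` consecutive integer chunk ids starting at `offset`.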
