4747from ai .common .store import DocumentStoreBase
4848from ai .common .config import Config
4949
50+ # Default batch size for bulk upsert operations
51+ DEFAULT_BULK_INSERT_BATCH_SIZE = 50
52+
53+ # Default connection timeout in seconds
54+ DEFAULT_TIMEOUT = 60
55+
5056
5157def _escape_milvus_str (value : object ) -> str :
5258 """Escape a value for safe interpolation into a Milvus filter expression."""
@@ -89,6 +95,10 @@ def __init__(self, provider: str, connConfig: Dict[str, Any], bag: Dict[str, Any
8995 self .renderChunkSize = config .get ('renderChunkSize' , self .renderChunkSize )
9096 self .threshold_search = config .get ('score' , 0.5 )
9197
98+ # Configurable timeout (seconds) and bulk insert batch size
99+ self .timeout = config .get ('timeout' , DEFAULT_TIMEOUT )
100+ self .bulkInsertBatchSize = config .get ('bulkInsertBatchSize' , DEFAULT_BULK_INSERT_BATCH_SIZE )
101+
92102 profile = config .get ('mode' )
93103
94104 # check if the similarity matches milvus configuration options
@@ -98,15 +108,19 @@ def __init__(self, provider: str, connConfig: Dict[str, Any], bag: Dict[str, Any
98108 else :
99109 raise Exception ('The metric you provided in the config.json does not match required milvus configurations' )
100110
101- # Establish a connection // TODO: Revise alternative setup as this connection action is only necessary for the flush() method
102- if profile != 'local' :
103- # Init the store
104- if self .host .startswith ('https:' ) or self .host .startswith ('http:' ):
105- self .client = MilvusClient (uri = self .host , token = self .apikey , timeout = 20 )
111+ # Establish a connection to the Milvus instance with configurable timeout
112+ try :
113+ if profile != 'local' :
114+ # Init the store
115+ if self .host .startswith ('https:' ) or self .host .startswith ('http:' ):
116+ self .client = MilvusClient (uri = self .host , token = self .apikey , timeout = self .timeout )
117+ else :
118+ self .client = MilvusClient (uri = f'https://{ self .host } ' , token = self .apikey , timeout = self .timeout )
106119 else :
107- self .client = MilvusClient (uri = f'https://{ self .host } ' , token = self .apikey , timeout = 20 )
108- else :
109- self .client = MilvusClient (uri = f'http://{ self .host } :{ self .port } ' , timeout = 20 )
120+ self .client = MilvusClient (uri = f'http://{ self .host } :{ self .port } ' , timeout = self .timeout )
121+ except Exception as e :
122+ self .client = None
123+ raise Exception (f'Failed to connect to Milvus at { self .host } : { e } ' )
110124
111125 return
112126
@@ -250,7 +264,9 @@ def _convertToDocs(self, points: List[dict]) -> List[Doc]:
250264 entity = point
251265 score = 0
252266 else :
253- # If we are return scaled scores, build it TODO: CHECK IF THIS IS ALSO THE CASE FOR MILVUS (-1 to 1 range) OR MIGHT IT BE CORRECTED ALREADY?
 267+ # Milvus COSINE search returns a cosine similarity score in the range
 268+ # [-1, 1] where 1 means identical. We rescale to [0, 1] with 1 meaning
 269+ # most similar to stay consistent with the rest of the codebase score convention.
254270 if self .similarity == 'COSINE' :
255271 score = (point .get ('distance' ) + 1 ) / 2
256272 else :
@@ -267,7 +283,7 @@ def _convertToDocs(self, points: List[dict]) -> List[Doc]:
267283 # Get the payload content and metadata
268284 metadata = cast (DocMetadata , metadata )
269285
270- # Create asearc new document
286+ # Create a new document
271287 doc = Doc (score = score , page_content = content , metadata = metadata )
272288
273289 # Append it to this documents chunks
@@ -419,7 +435,7 @@ def getPaths(self, parent: str | None = None, offset: int = 0, limit: int = 1000
419435
420436 def addChunks (self , chunks : List [Doc ], checkCollection : bool = True ) -> None :
421437 """
422- Addsdocument chunks to the document store.
438+ Add document chunks to the document store using batched bulk upsert .
423439 """
424440 # If no documents present, get out
425441 if not len (chunks ):
@@ -437,7 +453,8 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
437453 # Save this object id
438454 objectIds [chunk .metadata .objectId ] = True
439455
440- # Erase all documents/chunks associated with that ObjectId in one operation (TODO: Start discussion about better use of upsert() method to increase performance)
456+ # Erase all documents/chunks associated with that ObjectId in one operation
457+ # so we can cleanly insert the new version
441458 if len (objectIds .keys ()):
442459 filter_condition = f"meta['objectId'] in [{ ', ' .join (json .dumps (k ) for k in objectIds .keys ())} ]"
443460 try :
@@ -446,8 +463,20 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
446463 except Exception as e :
447464 engLib .debug (f'Error deleting old chunks: { e } ' )
448465
449- # TODO: Consider implementing a bulk insertion https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/utility/do_bulk_insert.md
450- # Disatvantage here is that is will require to reformat interation data into a JSON file format
466+ # Collect chunks into batches for bulk upsert instead of one-at-a-time
467+ batch : List [dict ] = []
468+
469+ def flush_batch ():
470+ nonlocal batch
471+ if not batch :
472+ return
473+ try :
474+ self .client .upsert (collection_name = self .collection , data = batch )
475+ engLib .debug (f'Milvus bulk upsert: { len (batch )} chunks inserted' )
476+ except Exception as e :
477+ engLib .debug (f'Error during bulk upsert ({ len (batch )} chunks): { e } ' )
478+ raise
479+ batch = []
451480
452481 # For each document
453482 for chunk in chunks :
@@ -461,8 +490,14 @@ def addChunks(self, chunks: List[Doc], checkCollection: bool = True) -> None:
461490 # Append the points // create a unique identifier that fits into an int64 id field
462491 tmp_struct = {'id' : np .int64 (((uuid .uuid1 ().time & 0x1FFFFFFFF ) << 27 ) | random .getrandbits (27 )), 'vector' : embedding , 'content' : chunk .page_content , 'meta' : chunk .metadata }
463492
464- # TODO: Consider printing out upsert count for debugging and imprement bulk insert
465- self .client .upsert (collection_name = self .collection , data = [tmp_struct ])
493+ batch .append (tmp_struct )
494+
495+ # Flush when batch reaches configured size
496+ if len (batch ) >= self .bulkInsertBatchSize :
497+ flush_batch ()
498+
499+ # Flush any remaining chunks
500+ flush_batch ()
466501
467502 def remove (self , objectIds : List [str ]) -> None :
468503 """
@@ -480,13 +515,39 @@ def remove(self, objectIds: List[str]) -> None:
480515 objectIdsJoint = ', ' .join (f"'{ _escape_milvus_str (o )} '" for o in objectIds )
481516 must_conditions .append (f"meta['objectId'] in [{ objectIdsJoint } ]" )
482517
483- # TODO: Add time out
484518 filter_expression = ' and ' .join (must_conditions ) if must_conditions else None
485519 if filter_expression :
486- self .client .delete (collection_name = self .collection , filter = filter_expression )
520+ try :
521+ self .client .delete (collection_name = self .collection , filter = filter_expression , timeout = self .timeout )
522+ except Exception as e :
523+ engLib .debug (f'Error removing documents: { e } ' )
524+ raise
487525
488526 return
489527
528+ def _batchUpsertResults (self , results : List [dict ], isDeleted : bool ) -> None :
529+ """
530+ Batch-update the isDeleted metadata field on a list of query results.
531+
532+ Collects results into batches of bulkInsertBatchSize and upserts them
533+ together, avoiding the performance bottleneck of one-at-a-time upserts.
534+ """
535+ batch : List [dict ] = []
536+
537+ for result in results :
538+ meta = result .get ('meta' , {})
539+ meta ['isDeleted' ] = isDeleted
540+ result ['meta' ] = meta
541+ batch .append (result )
542+
543+ if len (batch ) >= self .bulkInsertBatchSize :
544+ self .client .upsert (collection_name = self .collection , data = batch )
545+ batch = []
546+
547+ # Flush remaining
548+ if batch :
549+ self .client .upsert (collection_name = self .collection , data = batch )
550+
490551 def markDeleted (self , objectIds : List [str ]) -> None :
491552 """
492553 Mark the set of documents with the given objectId as deleted.
@@ -511,12 +572,8 @@ def markDeleted(self, objectIds: List[str]) -> None:
511572
512573 results = self .client .query (collection_name = self .collection , filter = filter_expression )
513574
514- # Update the 'isDeleted' field for each result -> TODO: Might there be a better way to do this? Looping over the
515- # vecotrs can be a performance bottleneck and additionally whats the oint if all entries will be deleled shortly after?
516- for result in results :
517- result ['isDeleted' ] = True
518- # Assuming there's a method to update the document in the client
519- self .client .upsert (collection_name = self .collection , data = result )
575+ # Batch-update instead of one-at-a-time to avoid performance bottleneck
576+ self ._batchUpsertResults (results , isDeleted = True )
520577 return
521578
522579 def markActive (self , objectIds : List [str ]) -> None :
@@ -543,12 +600,8 @@ def markActive(self, objectIds: List[str]) -> None:
543600
544601 results = self .client .query (collection_name = self .collection , filter = filter_expression )
545602
546- # Update the 'isDeleted' field for each result -> TODO: Might there be a better way to do this? Looping over the
547- # vecotrs can be a performance bottleneck and additionally whats the oint if all entries will be deleled shortly after?
548- for result in results :
549- result ['isDeleted' ] = False
550- # Assuming there's a method to update the document in the client
551- self .client .upsert (collection_name = self .collection , data = result )
603+ # Batch-update instead of one-at-a-time to avoid performance bottleneck
604+ self ._batchUpsertResults (results , isDeleted = False )
552605 return
553606
554607 def render (self , objectId : str , callback : Callable [[str ], None ]) -> None :
0 commit comments