Skip to content

Commit d9ab2a1

Browse files
committed
perf(core): shard oversized vector sync work
Signed-off-by: phernandez <paul@basicmachines.co>
1 parent d2055f3 commit d9ab2a1

File tree

4 files changed

+446
-16
lines changed

4 files changed

+446
-16
lines changed

src/basic_memory/repository/postgres_search_repository.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ async def sync_entity_vectors_batch(
533533
pending_jobs: list[_PendingEmbeddingJob] = []
534534
entity_runtime: dict[int, _EntitySyncRuntime] = {}
535535
failed_entity_ids: set[int] = set()
536+
deferred_entity_ids: set[int] = set()
536537
synced_entity_ids: set[int] = set()
537538

538539
for window_start in range(0, total_entities, POSTGRES_VECTOR_PREPARE_CONCURRENCY):
@@ -572,7 +573,10 @@ async def sync_entity_vectors_batch(
572573
result.prepare_seconds_total += prepared_sync.prepare_seconds
573574

574575
if embedding_jobs_count == 0:
575-
synced_entity_ids.add(entity_id)
576+
if prepared_sync.entity_complete:
577+
synced_entity_ids.add(entity_id)
578+
else:
579+
deferred_entity_ids.add(entity_id)
576580
total_seconds = time.perf_counter() - prepared_sync.sync_start
577581
queue_wait_seconds = max(0.0, total_seconds - prepared_sync.prepare_seconds)
578582
result.queue_wait_seconds_total += queue_wait_seconds
@@ -588,6 +592,12 @@ async def sync_entity_vectors_batch(
588592
chunks_skipped=prepared_sync.chunks_skipped,
589593
embedding_jobs_count=0,
590594
entity_skipped=prepared_sync.entity_skipped,
595+
entity_complete=prepared_sync.entity_complete,
596+
oversized_entity=prepared_sync.oversized_entity,
597+
pending_jobs_total=prepared_sync.pending_jobs_total,
598+
shard_index=prepared_sync.shard_index,
599+
shard_count=prepared_sync.shard_count,
600+
remaining_jobs_after_shard=prepared_sync.remaining_jobs_after_shard,
591601
)
592602
continue
593603

@@ -599,6 +609,12 @@ async def sync_entity_vectors_batch(
599609
chunks_total=prepared_sync.chunks_total,
600610
chunks_skipped=prepared_sync.chunks_skipped,
601611
entity_skipped=prepared_sync.entity_skipped,
612+
entity_complete=prepared_sync.entity_complete,
613+
oversized_entity=prepared_sync.oversized_entity,
614+
pending_jobs_total=prepared_sync.pending_jobs_total,
615+
shard_index=prepared_sync.shard_index,
616+
shard_count=prepared_sync.shard_count,
617+
remaining_jobs_after_shard=prepared_sync.remaining_jobs_after_shard,
602618
prepare_seconds=prepared_sync.prepare_seconds,
603619
)
604620
pending_jobs.extend(
@@ -624,10 +640,13 @@ async def sync_entity_vectors_batch(
624640
result.queue_wait_seconds_total += self._finalize_completed_entity_syncs(
625641
entity_runtime=entity_runtime,
626642
synced_entity_ids=synced_entity_ids,
643+
deferred_entity_ids=deferred_entity_ids,
627644
)
628645
except Exception as exc:
629646
affected_entity_ids = sorted({job.entity_id for job in flush_jobs})
630647
failed_entity_ids.update(affected_entity_ids)
648+
synced_entity_ids.difference_update(affected_entity_ids)
649+
deferred_entity_ids.difference_update(affected_entity_ids)
631650
for failed_entity_id in affected_entity_ids:
632651
entity_runtime.pop(failed_entity_id, None)
633652
logger.warning(
@@ -654,10 +673,13 @@ async def sync_entity_vectors_batch(
654673
result.queue_wait_seconds_total += self._finalize_completed_entity_syncs(
655674
entity_runtime=entity_runtime,
656675
synced_entity_ids=synced_entity_ids,
676+
deferred_entity_ids=deferred_entity_ids,
657677
)
658678
except Exception as exc:
659679
affected_entity_ids = sorted({job.entity_id for job in flush_jobs})
660680
failed_entity_ids.update(affected_entity_ids)
681+
synced_entity_ids.difference_update(affected_entity_ids)
682+
deferred_entity_ids.difference_update(affected_entity_ids)
661683
for failed_entity_id in affected_entity_ids:
662684
entity_runtime.pop(failed_entity_id, None)
663685
logger.warning(
@@ -672,6 +694,8 @@ async def sync_entity_vectors_batch(
672694
if entity_runtime:
673695
orphan_runtime_entities = sorted(entity_runtime.keys())
674696
failed_entity_ids.update(orphan_runtime_entities)
697+
synced_entity_ids.difference_update(orphan_runtime_entities)
698+
deferred_entity_ids.difference_update(orphan_runtime_entities)
675699
logger.warning(
676700
"Vector batch sync left unfinished entities after flushes: "
677701
"project_id={project_id} unfinished_entities={unfinished_entities}",
@@ -680,13 +704,17 @@ async def sync_entity_vectors_batch(
680704
)
681705

682706
synced_entity_ids.difference_update(failed_entity_ids)
707+
deferred_entity_ids.difference_update(failed_entity_ids)
708+
deferred_entity_ids.difference_update(synced_entity_ids)
683709
result.failed_entity_ids = sorted(failed_entity_ids)
684710
result.entities_failed = len(result.failed_entity_ids)
711+
result.entities_deferred = len(deferred_entity_ids)
685712
result.entities_synced = len(synced_entity_ids)
686713

687714
logger.info(
688715
"Vector batch sync complete: project_id={project_id} entities_total={entities_total} "
689716
"entities_synced={entities_synced} entities_failed={entities_failed} "
717+
"entities_deferred={entities_deferred} "
690718
"entities_skipped={entities_skipped} chunks_total={chunks_total} "
691719
"chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} "
692720
"prepare_seconds_total={prepare_seconds_total:.3f} "
@@ -696,6 +724,7 @@ async def sync_entity_vectors_batch(
696724
entities_total=result.entities_total,
697725
entities_synced=result.entities_synced,
698726
entities_failed=result.entities_failed,
727+
entities_deferred=result.entities_deferred,
699728
entities_skipped=result.entities_skipped,
700729
chunks_total=result.chunks_total,
701730
chunks_skipped=result.chunks_skipped,
@@ -849,14 +878,13 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
849878
prepare_seconds=prepare_seconds,
850879
)
851880

852-
upsert_records: list[dict[str, str]] = []
853-
embedding_jobs: list[tuple[int, str]] = []
881+
pending_records: list[dict[str, str]] = []
854882
skipped_chunks_count = 0
855883

856884
for record in chunk_records:
857885
current = existing_by_key.get(record["chunk_key"])
858886
if current is None:
859-
upsert_records.append(record)
887+
pending_records.append(record)
860888
continue
861889

862890
row_id = int(current["id"])
@@ -886,7 +914,19 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
886914
skipped_chunks_count += 1
887915
continue
888916

889-
upsert_records.append(record)
917+
pending_records.append(record)
918+
919+
shard_plan = self._plan_entity_vector_shard(pending_records)
920+
self._log_vector_shard_plan(entity_id=entity_id, shard_plan=shard_plan)
921+
922+
scheduled_records = [
923+
record
924+
for record in sorted(pending_records, key=lambda record: record["chunk_key"])
925+
if record["chunk_key"] in shard_plan.scheduled_chunk_keys
926+
]
927+
928+
embedding_jobs: list[tuple[int, str]] = []
929+
upsert_records = list(scheduled_records)
890930

891931
if upsert_records:
892932
upsert_params: dict[str, object] = {
@@ -943,14 +983,23 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
943983
"stale_chunks_count={stale_chunks_count} "
944984
"orphan_chunks_count={orphan_chunks_count} "
945985
"chunks_skipped={chunks_skipped} "
946-
"embedding_jobs_count={embedding_jobs_count}",
986+
"embedding_jobs_count={embedding_jobs_count} "
987+
"pending_jobs_total={pending_jobs_total} shard_index={shard_index} "
988+
"shard_count={shard_count} remaining_jobs_after_shard={remaining_jobs_after_shard} "
989+
"oversized_entity={oversized_entity} entity_complete={entity_complete}",
947990
project_id=self.project_id,
948991
entity_id=entity_id,
949992
existing_chunks_count=existing_chunks_count,
950993
stale_chunks_count=stale_chunks_count,
951994
orphan_chunks_count=orphan_chunks_count,
952995
chunks_skipped=skipped_chunks_count,
953996
embedding_jobs_count=len(embedding_jobs),
997+
pending_jobs_total=shard_plan.pending_jobs_total,
998+
shard_index=shard_plan.shard_index,
999+
shard_count=shard_plan.shard_count,
1000+
remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard,
1001+
oversized_entity=shard_plan.oversized_entity,
1002+
entity_complete=shard_plan.entity_complete,
9541003
)
9551004

9561005
prepare_seconds = time.perf_counter() - sync_start
@@ -961,6 +1010,12 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
9611010
embedding_jobs=embedding_jobs,
9621011
chunks_total=built_chunk_records_count,
9631012
chunks_skipped=skipped_chunks_count,
1013+
entity_complete=shard_plan.entity_complete,
1014+
oversized_entity=shard_plan.oversized_entity,
1015+
pending_jobs_total=shard_plan.pending_jobs_total,
1016+
shard_index=shard_plan.shard_index,
1017+
shard_count=shard_plan.shard_count,
1018+
remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard,
9641019
prepare_seconds=prepare_seconds,
9651020
)
9661021

0 commit comments

Comments (0)