Skip to content

Commit d59b974

Browse files
committed
refactor(core): address review nits
Signed-off-by: phernandez <paul@basicmachines.co>
1 parent a68933a commit d59b974

6 files changed

Lines changed: 168 additions & 123 deletions

File tree

src/basic_memory/repository/embedding_provider.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Embedding provider protocol for pluggable semantic backends."""
22

3-
from typing import Protocol
3+
from typing import Any, Protocol
44

55

66
class EmbeddingProvider(Protocol):
@@ -16,3 +16,7 @@ async def embed_query(self, text: str) -> list[float]:
1616
async def embed_documents(self, texts: list[str]) -> list[list[float]]:
1717
"""Embed a list of document chunks."""
1818
...
19+
20+
def runtime_log_attrs(self) -> dict[str, Any]:
21+
"""Return provider-specific runtime settings suitable for startup logs."""
22+
...

src/basic_memory/repository/embedding_provider_factory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
_EMBEDDING_PROVIDER_CACHE: dict[ProviderCacheKey, EmbeddingProvider] = {}
2121
_EMBEDDING_PROVIDER_CACHE_LOCK = Lock()
22+
_FASTEMBED_MAX_THREADS = 8
2223

2324

2425
def _available_cpu_count() -> int | None:
@@ -55,7 +56,7 @@ def _resolve_fastembed_runtime_knobs(
5556
if available_cpus <= 2:
5657
return available_cpus, 1
5758

58-
threads = min(8, max(2, available_cpus - 2))
59+
threads = min(_FASTEMBED_MAX_THREADS, max(2, available_cpus - 2))
5960
return threads, 1
6061

6162

src/basic_memory/repository/fastembed_provider.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@ class FastEmbedEmbeddingProvider(EmbeddingProvider):
2424
def _effective_parallel(self) -> int | None:
2525
return self.parallel if self.parallel is not None and self.parallel > 1 else None
2626

27+
def runtime_log_attrs(self) -> dict[str, int | str | None]:
28+
"""Return the resolved runtime knobs that shape FastEmbed throughput."""
29+
return {
30+
"provider_batch_size": self.batch_size,
31+
"threads": self.threads,
32+
"configured_parallel": self.parallel,
33+
"effective_parallel": self._effective_parallel(),
34+
}
35+
2736
def __init__(
2837
self,
2938
model_name: str = "bge-small-en-v1.5",

src/basic_memory/repository/openai_provider.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ def __init__(
3434
self._client: Any | None = None
3535
self._client_lock = asyncio.Lock()
3636

37+
def runtime_log_attrs(self) -> dict[str, int]:
38+
"""Return the request fan-out knobs that shape API embedding batches."""
39+
return {
40+
"provider_batch_size": self.batch_size,
41+
"request_concurrency": self.request_concurrency,
42+
}
43+
3744
async def _get_client(self) -> Any:
3845
if self._client is not None:
3946
return self._client

src/basic_memory/repository/search_repository_base.py

Lines changed: 123 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
HEADER_LINE_PATTERN = re.compile(r"^\s*#{1,6}\s+")
3838
BULLET_PATTERN = re.compile(r"^[\-\*]\s+")
3939
OVERSIZED_ENTITY_VECTOR_SHARD_SIZE = 256
40+
_SQLITE_MAX_PREPARE_WINDOW = 8
4041

4142

4243
@dataclass
@@ -800,7 +801,9 @@ async def _sync_entity_vectors_internal(
800801
batch_start = time.perf_counter()
801802
backend_name = type(self).__name__.removesuffix("SearchRepository").lower()
802803

803-
self._log_vector_sync_runtime_settings(backend_name=backend_name, entities_total=total_entities)
804+
self._log_vector_sync_runtime_settings(
805+
backend_name=backend_name, entities_total=total_entities
806+
)
804807
logger.info(
805808
"Vector batch sync start: project_id={project_id} entities_total={entities_total} "
806809
"sync_batch_size={sync_batch_size} prepare_window_size={prepare_window_size}",
@@ -1006,102 +1009,104 @@ def emit_progress(entity_id: int) -> None:
10061009
for failed_entity_id in affected_entity_ids:
10071010
emit_progress(failed_entity_id)
10081011

1009-
# Trigger: this should never happen after all flushes succeed.
1010-
# Why: remaining jobs mean runtime tracking drifted from queued jobs.
1011-
# Outcome: fail-safe marks these entities as failed to avoid false positives.
1012-
if entity_runtime:
1013-
orphan_runtime_entities = sorted(entity_runtime.keys())
1014-
failed_entity_ids.update(orphan_runtime_entities)
1015-
synced_entity_ids.difference_update(orphan_runtime_entities)
1016-
deferred_entity_ids.difference_update(orphan_runtime_entities)
1017-
logger.warning(
1018-
"Vector batch sync left unfinished entities after flushes: "
1019-
"project_id={project_id} unfinished_entities={unfinished_entities}",
1012+
# Trigger: this should never happen after all flushes succeed.
1013+
# Why: remaining jobs mean runtime tracking drifted from queued jobs.
1014+
# Outcome: fail-safe marks these entities as failed to avoid false positives.
1015+
if entity_runtime:
1016+
orphan_runtime_entities = sorted(entity_runtime.keys())
1017+
failed_entity_ids.update(orphan_runtime_entities)
1018+
synced_entity_ids.difference_update(orphan_runtime_entities)
1019+
deferred_entity_ids.difference_update(orphan_runtime_entities)
1020+
logger.warning(
1021+
"Vector batch sync left unfinished entities after flushes: "
1022+
"project_id={project_id} unfinished_entities={unfinished_entities}",
1023+
project_id=self.project_id,
1024+
unfinished_entities=orphan_runtime_entities,
1025+
)
1026+
for failed_entity_id in orphan_runtime_entities:
1027+
emit_progress(failed_entity_id)
1028+
1029+
# Keep result counters aligned with successful/failed terminal states.
1030+
synced_entity_ids.difference_update(failed_entity_ids)
1031+
deferred_entity_ids.difference_update(failed_entity_ids)
1032+
deferred_entity_ids.difference_update(synced_entity_ids)
1033+
result.failed_entity_ids = sorted(failed_entity_ids)
1034+
result.entities_failed = len(result.failed_entity_ids)
1035+
result.entities_deferred = len(deferred_entity_ids)
1036+
result.entities_synced = len(synced_entity_ids)
1037+
1038+
logger.info(
1039+
"Vector batch sync complete: project_id={project_id} entities_total={entities_total} "
1040+
"entities_synced={entities_synced} entities_failed={entities_failed} "
1041+
"entities_deferred={entities_deferred} "
1042+
"entities_skipped={entities_skipped} chunks_total={chunks_total} "
1043+
"chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} "
1044+
"prepare_seconds_total={prepare_seconds_total:.3f} "
1045+
"queue_wait_seconds_total={queue_wait_seconds_total:.3f} "
1046+
"embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}",
10201047
project_id=self.project_id,
1021-
unfinished_entities=orphan_runtime_entities,
1048+
entities_total=result.entities_total,
1049+
entities_synced=result.entities_synced,
1050+
entities_failed=result.entities_failed,
1051+
entities_deferred=result.entities_deferred,
1052+
entities_skipped=result.entities_skipped,
1053+
chunks_total=result.chunks_total,
1054+
chunks_skipped=result.chunks_skipped,
1055+
embedding_jobs_total=result.embedding_jobs_total,
1056+
prepare_seconds_total=result.prepare_seconds_total,
1057+
queue_wait_seconds_total=result.queue_wait_seconds_total,
1058+
embed_seconds_total=result.embed_seconds_total,
1059+
write_seconds_total=result.write_seconds_total,
10221060
)
1023-
for failed_entity_id in orphan_runtime_entities:
1024-
emit_progress(failed_entity_id)
1025-
1026-
# Keep result counters aligned with successful/failed terminal states.
1027-
synced_entity_ids.difference_update(failed_entity_ids)
1028-
deferred_entity_ids.difference_update(failed_entity_ids)
1029-
deferred_entity_ids.difference_update(synced_entity_ids)
1030-
result.failed_entity_ids = sorted(failed_entity_ids)
1031-
result.entities_failed = len(result.failed_entity_ids)
1032-
result.entities_deferred = len(deferred_entity_ids)
1033-
result.entities_synced = len(synced_entity_ids)
1034-
1035-
logger.info(
1036-
"Vector batch sync complete: project_id={project_id} entities_total={entities_total} "
1037-
"entities_synced={entities_synced} entities_failed={entities_failed} "
1038-
"entities_deferred={entities_deferred} "
1039-
"entities_skipped={entities_skipped} chunks_total={chunks_total} "
1040-
"chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} "
1041-
"prepare_seconds_total={prepare_seconds_total:.3f} "
1042-
"queue_wait_seconds_total={queue_wait_seconds_total:.3f} "
1043-
"embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}",
1044-
project_id=self.project_id,
1045-
entities_total=result.entities_total,
1046-
entities_synced=result.entities_synced,
1047-
entities_failed=result.entities_failed,
1048-
entities_deferred=result.entities_deferred,
1049-
entities_skipped=result.entities_skipped,
1050-
chunks_total=result.chunks_total,
1051-
chunks_skipped=result.chunks_skipped,
1052-
embedding_jobs_total=result.embedding_jobs_total,
1053-
prepare_seconds_total=result.prepare_seconds_total,
1054-
queue_wait_seconds_total=result.queue_wait_seconds_total,
1055-
embed_seconds_total=result.embed_seconds_total,
1056-
write_seconds_total=result.write_seconds_total,
1057-
)
1058-
batch_total_seconds = time.perf_counter() - batch_start
1059-
metric_attrs = {
1060-
"backend": backend_name,
1061-
"skip_only_batch": result.embedding_jobs_total == 0,
1062-
}
1063-
telemetry.record_histogram(
1064-
"vector_sync_batch_total_seconds",
1065-
batch_total_seconds,
1066-
unit="s",
1067-
**metric_attrs,
1068-
)
1069-
telemetry.add_counter("vector_sync_entities_total", result.entities_total, **metric_attrs)
1070-
telemetry.add_counter(
1071-
"vector_sync_entities_skipped",
1072-
result.entities_skipped,
1073-
**metric_attrs,
1074-
)
1075-
telemetry.add_counter(
1076-
"vector_sync_entities_deferred",
1077-
result.entities_deferred,
1078-
**metric_attrs,
1079-
)
1080-
telemetry.add_counter(
1081-
"vector_sync_embedding_jobs_total",
1082-
result.embedding_jobs_total,
1083-
**metric_attrs,
1084-
)
1085-
telemetry.add_counter("vector_sync_chunks_total", result.chunks_total, **metric_attrs)
1086-
telemetry.add_counter(
1087-
"vector_sync_chunks_skipped",
1088-
result.chunks_skipped,
1089-
**metric_attrs,
1090-
)
1091-
if batch_span is not None:
1092-
batch_span.set_attributes(
1093-
{
1094-
"backend": backend_name,
1095-
"entities_synced": result.entities_synced,
1096-
"entities_failed": result.entities_failed,
1097-
"entities_deferred": result.entities_deferred,
1098-
"entities_skipped": result.entities_skipped,
1099-
"embedding_jobs_total": result.embedding_jobs_total,
1100-
"chunks_total": result.chunks_total,
1101-
"chunks_skipped": result.chunks_skipped,
1102-
"batch_total_seconds": batch_total_seconds,
1103-
}
1061+
batch_total_seconds = time.perf_counter() - batch_start
1062+
metric_attrs = {
1063+
"backend": backend_name,
1064+
"skip_only_batch": result.embedding_jobs_total == 0,
1065+
}
1066+
telemetry.record_histogram(
1067+
"vector_sync_batch_total_seconds",
1068+
batch_total_seconds,
1069+
unit="s",
1070+
**metric_attrs,
1071+
)
1072+
telemetry.add_counter(
1073+
"vector_sync_entities_total", result.entities_total, **metric_attrs
1074+
)
1075+
telemetry.add_counter(
1076+
"vector_sync_entities_skipped",
1077+
result.entities_skipped,
1078+
**metric_attrs,
1079+
)
1080+
telemetry.add_counter(
1081+
"vector_sync_entities_deferred",
1082+
result.entities_deferred,
1083+
**metric_attrs,
1084+
)
1085+
telemetry.add_counter(
1086+
"vector_sync_embedding_jobs_total",
1087+
result.embedding_jobs_total,
1088+
**metric_attrs,
1089+
)
1090+
telemetry.add_counter("vector_sync_chunks_total", result.chunks_total, **metric_attrs)
1091+
telemetry.add_counter(
1092+
"vector_sync_chunks_skipped",
1093+
result.chunks_skipped,
1094+
**metric_attrs,
11041095
)
1096+
if batch_span is not None:
1097+
batch_span.set_attributes(
1098+
{
1099+
"backend": backend_name,
1100+
"entities_synced": result.entities_synced,
1101+
"entities_failed": result.entities_failed,
1102+
"entities_deferred": result.entities_deferred,
1103+
"entities_skipped": result.entities_skipped,
1104+
"embedding_jobs_total": result.embedding_jobs_total,
1105+
"chunks_total": result.chunks_total,
1106+
"chunks_skipped": result.chunks_skipped,
1107+
"batch_total_seconds": batch_total_seconds,
1108+
}
1109+
)
11051110

11061111
return result
11071112

@@ -1113,7 +1118,10 @@ def _vector_prepare_window_size(self) -> int:
11131118
# explode to the full embed batch size creates unnecessary write contention.
11141119
# Outcome: local backends get a small bounded window, while Postgres keeps
11151120
# its explicit higher concurrency override.
1116-
return max(1, min(self._semantic_embedding_sync_batch_size, 8))
1121+
return max(
1122+
1,
1123+
min(self._semantic_embedding_sync_batch_size, _SQLITE_MAX_PREPARE_WINDOW),
1124+
)
11171125

11181126
@asynccontextmanager
11191127
async def _prepare_entity_write_scope(self):
@@ -1223,14 +1231,18 @@ async def _prepare_entity_vector_jobs_window(
12231231
session, entity_ids
12241232
)
12251233
except Exception as exc:
1234+
# Trigger: the shared read pass failed before we had entity-level diffs.
1235+
# Why: once the window-level read session breaks, we cannot safely
1236+
# distinguish one entity from another inside that window.
1237+
# Outcome: every entity in the window gets the same failure object.
12261238
return [exc for _ in entity_ids]
12271239

12281240
# Trigger: prepare now does one shared read pass per window instead of
12291241
# paying the same select/join round-trips per entity.
12301242
# Why: both SQLite and Postgres were still burning wall clock in read-side
12311243
# fingerprint/orphan checks even when every entity ended up skipped.
1232-
# Outcome: we batch the reads once, then fan back out over entities while
1233-
# preserving input order in the gathered results.
1244+
# Outcome: we batch the reads once, close that shared read session, and
1245+
# then fan back out over entities while preserving input order.
12341246
prepared_window = await asyncio.gather(
12351247
*(
12361248
self._prepare_entity_vector_jobs_prefetched(
@@ -1264,7 +1276,8 @@ async def _prepare_entity_vector_jobs_prefetched(
12641276
prepare_start = sync_start
12651277
source_rows_count = len(source_rows)
12661278

1267-
if not source_rows:
1279+
async def _delete_entity_chunks_and_finish() -> _PreparedEntityVectorSync:
1280+
"""Delete derived rows and return the empty prepare result."""
12681281
async with self._prepare_entity_write_scope():
12691282
async with db.scoped_session(self.session_maker) as session:
12701283
await self._prepare_vector_session(session)
@@ -1279,22 +1292,13 @@ async def _prepare_entity_vector_jobs_prefetched(
12791292
prepare_seconds=prepare_seconds,
12801293
)
12811294

1295+
if not source_rows:
1296+
return await _delete_entity_chunks_and_finish()
1297+
12821298
chunk_records = self._build_chunk_records(source_rows)
12831299
built_chunk_records_count = len(chunk_records)
12841300
if not chunk_records:
1285-
async with self._prepare_entity_write_scope():
1286-
async with db.scoped_session(self.session_maker) as session:
1287-
await self._prepare_vector_session(session)
1288-
await self._delete_entity_chunks(session, entity_id)
1289-
await session.commit()
1290-
prepare_seconds = time.perf_counter() - prepare_start
1291-
return _PreparedEntityVectorSync(
1292-
entity_id=entity_id,
1293-
sync_start=sync_start,
1294-
source_rows_count=source_rows_count,
1295-
embedding_jobs=[],
1296-
prepare_seconds=prepare_seconds,
1297-
)
1301+
return await _delete_entity_chunks_and_finish()
12981302

12991303
current_entity_fingerprint = self._build_entity_fingerprint(chunk_records)
13001304
current_embedding_model = self._embedding_model_key()
@@ -1607,27 +1611,25 @@ def _log_vector_sync_runtime_settings(self, *, backend_name: str, entities_total
16071611
"""
16081612
assert self._embedding_provider is not None
16091613

1610-
from basic_memory.repository.fastembed_provider import FastEmbedEmbeddingProvider
1611-
16121614
provider = self._embedding_provider
1613-
if isinstance(provider, FastEmbedEmbeddingProvider):
1615+
runtime_attrs = (
1616+
provider.runtime_log_attrs() if hasattr(provider, "runtime_log_attrs") else {}
1617+
)
1618+
if runtime_attrs:
16141619
logger.info(
16151620
"Vector batch runtime settings: project_id={project_id} backend={backend} "
16161621
"entities_total={entities_total} provider={provider} model_name={model_name} "
1617-
"dimensions={dimensions} provider_batch_size={provider_batch_size} "
1618-
"sync_batch_size={sync_batch_size} threads={threads} "
1619-
"configured_parallel={configured_parallel} effective_parallel={effective_parallel}",
1622+
"dimensions={dimensions} sync_batch_size={sync_batch_size} "
1623+
"{runtime_attrs}",
16201624
project_id=self.project_id,
16211625
backend=backend_name,
16221626
entities_total=entities_total,
16231627
provider=type(provider).__name__,
16241628
model_name=provider.model_name,
16251629
dimensions=provider.dimensions,
1626-
provider_batch_size=provider.batch_size,
16271630
sync_batch_size=self._semantic_embedding_sync_batch_size,
1628-
threads=provider.threads,
1629-
configured_parallel=provider.parallel,
1630-
effective_parallel=provider._effective_parallel(),
1631+
runtime_attrs=" ".join(f"{key}={value}" for key, value in runtime_attrs.items()),
1632+
**runtime_attrs,
16311633
)
16321634
return
16331635

0 commit comments

Comments
 (0)