Skip to content

Commit 49a40be

Browse files
committed
fix(core): address batch indexing review issues
Signed-off-by: phernandez <paul@basicmachines.co>
1 parent 2f3cf95 commit 49a40be

File tree

9 files changed

+517
-265
lines changed

9 files changed

+517
-265
lines changed

src/basic_memory/indexing/batch_indexer.py

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import asyncio
6-
import time
76
from dataclasses import dataclass
87
from datetime import datetime
98
from pathlib import Path
@@ -13,18 +12,18 @@
1312
from sqlalchemy.exc import IntegrityError
1413

1514
from basic_memory.config import BasicMemoryConfig
16-
from basic_memory.file_utils import compute_checksum, has_frontmatter, remove_frontmatter
15+
from basic_memory.file_utils import compute_checksum, has_frontmatter
1716
from basic_memory.markdown.schemas import EntityMarkdown
1817
from basic_memory.indexing.models import (
1918
IndexedEntity,
2019
IndexFileWriter,
2120
IndexFrontmatterUpdate,
2221
IndexingBatchResult,
2322
IndexInputFile,
24-
IndexProgress,
2523
)
2624
from basic_memory.models import Entity, Relation
2725
from basic_memory.services import EntityService
26+
from basic_memory.services.exceptions import SyncFatalError
2827
from basic_memory.services.search_service import SearchService
2928
from basic_memory.repository import EntityRepository, RelationRepository
3029

@@ -76,28 +75,17 @@ async def index_files(
7675
*,
7776
max_concurrent: int,
7877
parse_max_concurrent: int | None = None,
79-
progress_callback: Callable[[IndexProgress], Awaitable[None]] | None = None,
78+
existing_permalink_by_path: dict[str, str | None] | None = None,
8079
) -> IndexingBatchResult:
8180
"""Index one batch of loaded files with bounded concurrency."""
8281
if max_concurrent <= 0:
8382
raise ValueError("max_concurrent must be greater than zero")
8483

8584
ordered_paths = sorted(files)
8685
if not ordered_paths:
87-
result = IndexingBatchResult()
88-
if progress_callback is not None:
89-
await progress_callback(
90-
IndexProgress(
91-
files_total=0,
92-
files_processed=0,
93-
batches_total=0,
94-
batches_completed=0,
95-
)
96-
)
97-
return result
86+
return IndexingBatchResult()
9887

9988
parse_limit = parse_max_concurrent or max_concurrent
100-
batch_start = time.monotonic()
10189
error_by_path: dict[str, str] = {}
10290

10391
markdown_paths = [path for path in ordered_paths if self._is_markdown(files[path])]
@@ -111,7 +99,8 @@ async def index_files(
11199
error_by_path.update(parse_errors)
112100

113101
prepared_markdown, normalization_errors = await self._normalize_markdown_batch(
114-
prepared_markdown
102+
prepared_markdown,
103+
existing_permalink_by_path=existing_permalink_by_path,
115104
)
116105
error_by_path.update(normalization_errors)
117106

@@ -171,21 +160,6 @@ async def index_files(
171160

172161
search_indexed = len(indexed_entities)
173162

174-
if progress_callback is not None:
175-
elapsed_seconds = max(time.monotonic() - batch_start, 0.001)
176-
files_per_minute = len(ordered_paths) / elapsed_seconds * 60
177-
await progress_callback(
178-
IndexProgress(
179-
files_total=len(ordered_paths),
180-
files_processed=len(ordered_paths),
181-
batches_total=1,
182-
batches_completed=1,
183-
current_batch_bytes=sum(max(files[path].size, 0) for path in ordered_paths),
184-
files_per_minute=files_per_minute,
185-
eta_seconds=0.0,
186-
)
187-
)
188-
189163
return IndexingBatchResult(
190164
indexed=indexed_entities,
191165
errors=[(path, error_by_path[path]) for path in ordered_paths if path in error_by_path],
@@ -221,12 +195,21 @@ async def _prepare_markdown_file(self, file: IndexInputFile) -> _PreparedMarkdow
221195
async def _normalize_markdown_batch(
222196
self,
223197
prepared_markdown: dict[str, _PreparedMarkdownFile],
198+
*,
199+
existing_permalink_by_path: dict[str, str | None] | None = None,
224200
) -> tuple[dict[str, _PreparedMarkdownFile], dict[str, str]]:
225201
if not prepared_markdown:
226202
return {}, {}
227203

204+
if existing_permalink_by_path is None:
205+
existing_permalink_by_path = {
206+
path: permalink
207+
for path, permalink in (
208+
await self.entity_repository.get_file_path_to_permalink_map()
209+
).items()
210+
}
211+
228212
batch_paths = set(prepared_markdown)
229-
existing_permalink_by_path = await self.entity_repository.get_file_path_to_permalink_map()
230213
reserved_permalinks = {
231214
permalink
232215
for path, permalink in existing_permalink_by_path.items()
@@ -242,6 +225,7 @@ async def _normalize_markdown_batch(
242225
prepared_markdown[path],
243226
reserved_permalinks,
244227
)
228+
existing_permalink_by_path[path] = normalized[path].markdown.frontmatter.permalink
245229
except Exception as exc:
246230
errors[path] = str(exc)
247231
logger.warning("Batch markdown normalization failed", path=path, error=str(exc))
@@ -357,13 +341,18 @@ async def _upsert_markdown_file(self, prepared: _PreparedMarkdownFile) -> _Prepa
357341
entity_id=updated.id,
358342
checksum=prepared.final_checksum,
359343
content_type=prepared.file.content_type,
360-
search_content=remove_frontmatter(prepared.content),
344+
search_content=(
345+
prepared.markdown.content
346+
if prepared.markdown.content is not None
347+
else prepared.content
348+
),
361349
markdown_content=prepared.content,
362350
)
363351

364352
async def _upsert_regular_file(self, file: IndexInputFile) -> _PreparedEntity:
365353
checksum = await self._resolve_checksum(file)
366354
existing = await self.entity_repository.get_by_file_path(file.path, load_relations=False)
355+
is_new_entity = existing is None
367356

368357
if existing is None:
369358
await self.entity_service.resolve_permalink(file.path, skip_conflict_check=True)
@@ -408,7 +397,7 @@ async def _upsert_regular_file(self, file: IndexInputFile) -> _PreparedEntity:
408397

409398
updated = await self.entity_repository.update(
410399
entity_id,
411-
self._entity_metadata_updates(file, checksum, include_created_at=existing is None),
400+
self._entity_metadata_updates(file, checksum, include_created_at=is_new_entity),
412401
)
413402
if updated is None:
414403
raise ValueError(f"Failed to update file entity metadata for {file.path}")
@@ -430,11 +419,15 @@ async def _resolve_batch_relations(
430419
*,
431420
max_concurrent: int,
432421
) -> tuple[int, int]:
433-
unresolved_relations: list[Relation] = []
434-
for entity_id in entity_ids:
435-
unresolved_relations.extend(
436-
await self.relation_repository.find_unresolved_relations_for_entity(entity_id)
422+
unresolved_relation_lists = await asyncio.gather(
423+
*(
424+
self.relation_repository.find_unresolved_relations_for_entity(entity_id)
425+
for entity_id in entity_ids
437426
)
427+
)
428+
unresolved_relations = [
429+
relation for relation_list in unresolved_relation_lists for relation in relation_list
430+
]
438431

439432
if not unresolved_relations:
440433
return 0, 0
@@ -475,11 +468,13 @@ async def resolve_relation(relation: Relation) -> int:
475468
*(resolve_relation(relation) for relation in unresolved_relations)
476469
)
477470

478-
remaining_unresolved = 0
479-
for entity_id in entity_ids:
480-
remaining_unresolved += len(
481-
await self.relation_repository.find_unresolved_relations_for_entity(entity_id)
471+
remaining_relation_lists = await asyncio.gather(
472+
*(
473+
self.relation_repository.find_unresolved_relations_for_entity(entity_id)
474+
for entity_id in entity_ids
482475
)
476+
)
477+
remaining_unresolved = sum(len(relations) for relations in remaining_relation_lists)
483478

484479
return sum(resolved_counts), remaining_unresolved
485480

@@ -552,6 +547,8 @@ async def run(path: str) -> None:
552547
try:
553548
results[path] = await worker(path)
554549
except Exception as exc:
550+
if isinstance(exc, SyncFatalError) or isinstance(exc.__cause__, SyncFatalError):
551+
raise
555552
errors[path] = str(exc)
556553
logger.warning("Batch indexing failed", path=path, error=str(exc))
557554

src/basic_memory/indexing/batching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def build_index_batches(
5252
current_paths.append(path)
5353
current_bytes += file_bytes
5454

55-
if len(current_paths) >= max_files or current_bytes >= max_bytes:
55+
if len(current_paths) >= max_files or current_bytes == max_bytes:
5656
batches.append(IndexBatch(paths=current_paths, total_bytes=current_bytes))
5757
current_paths = []
5858
current_bytes = 0

src/basic_memory/repository/postgres_search_repository.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -512,12 +512,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
512512
"""Prepare chunk mutations with Postgres-specific bulk upserts."""
513513
sync_start = time.perf_counter()
514514

515-
logger.info(
516-
"Vector sync start: project_id={project_id} entity_id={entity_id}",
517-
project_id=self.project_id,
518-
entity_id=entity_id,
519-
)
520-
521515
async with db.scoped_session(self.session_maker) as session:
522516
await self._prepare_vector_session(session)
523517

@@ -546,13 +540,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
546540
source_rows_count = len(rows)
547541

548542
if not rows:
549-
logger.info(
550-
"Vector sync source prepared: project_id={project_id} entity_id={entity_id} "
551-
"source_rows_count={source_rows_count} built_chunk_records_count=0",
552-
project_id=self.project_id,
553-
entity_id=entity_id,
554-
source_rows_count=source_rows_count,
555-
)
556543
await self._delete_entity_chunks(session, entity_id)
557544
await session.commit()
558545
prepare_seconds = time.perf_counter() - sync_start
@@ -568,15 +555,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
568555
built_chunk_records_count = len(chunk_records)
569556
current_entity_fingerprint = self._build_entity_fingerprint(chunk_records)
570557
current_embedding_model = self._embedding_model_key()
571-
logger.info(
572-
"Vector sync source prepared: project_id={project_id} entity_id={entity_id} "
573-
"source_rows_count={source_rows_count} "
574-
"built_chunk_records_count={built_chunk_records_count}",
575-
project_id=self.project_id,
576-
entity_id=entity_id,
577-
source_rows_count=source_rows_count,
578-
built_chunk_records_count=built_chunk_records_count,
579-
)
580558
if not chunk_records:
581559
await self._delete_entity_chunks(session, entity_id)
582560
await session.commit()
@@ -629,16 +607,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
629607
)
630608
)
631609
if skip_unchanged_entity:
632-
logger.info(
633-
"Vector sync skipped unchanged entity: project_id={project_id} "
634-
"entity_id={entity_id} chunks_skipped={chunks_skipped} "
635-
"entity_fingerprint={entity_fingerprint} embedding_model={embedding_model}",
636-
project_id=self.project_id,
637-
entity_id=entity_id,
638-
chunks_skipped=built_chunk_records_count,
639-
entity_fingerprint=current_entity_fingerprint,
640-
embedding_model=current_embedding_model,
641-
)
642610
prepare_seconds = time.perf_counter() - sync_start
643611
return _PreparedEntityVectorSync(
644612
entity_id=entity_id,
@@ -752,31 +720,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
752720
row_id = upserted_ids_by_key[record["chunk_key"]]
753721
embedding_jobs.append((row_id, record["chunk_text"]))
754722

755-
logger.info(
756-
"Vector sync diff complete: project_id={project_id} entity_id={entity_id} "
757-
"existing_chunks_count={existing_chunks_count} "
758-
"stale_chunks_count={stale_chunks_count} "
759-
"orphan_chunks_count={orphan_chunks_count} "
760-
"chunks_skipped={chunks_skipped} "
761-
"embedding_jobs_count={embedding_jobs_count} "
762-
"pending_jobs_total={pending_jobs_total} shard_index={shard_index} "
763-
"shard_count={shard_count} remaining_jobs_after_shard={remaining_jobs_after_shard} "
764-
"oversized_entity={oversized_entity} entity_complete={entity_complete}",
765-
project_id=self.project_id,
766-
entity_id=entity_id,
767-
existing_chunks_count=existing_chunks_count,
768-
stale_chunks_count=stale_chunks_count,
769-
orphan_chunks_count=orphan_chunks_count,
770-
chunks_skipped=skipped_chunks_count,
771-
embedding_jobs_count=len(embedding_jobs),
772-
pending_jobs_total=shard_plan.pending_jobs_total,
773-
shard_index=shard_plan.shard_index,
774-
shard_count=shard_plan.shard_count,
775-
remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard,
776-
oversized_entity=shard_plan.oversized_entity,
777-
entity_complete=shard_plan.entity_complete,
778-
)
779-
780723
prepare_seconds = time.perf_counter() - sync_start
781724
return _PreparedEntityVectorSync(
782725
entity_id=entity_id,

0 commit comments

Comments
 (0)