From 58dd6963bdfcf7682cbb7666c1223cf75773295a Mon Sep 17 00:00:00 2001 From: phernandez Date: Mon, 13 Apr 2026 01:42:22 -0500 Subject: [PATCH 1/2] chore(core): use ty for typechecking Signed-off-by: phernandez --- justfile | 8 +- .../api/v2/routers/prompt_router.py | 71 ++++---- .../api/v2/routers/resource_router.py | 6 +- src/basic_memory/api/v2/utils.py | 117 ++++++++----- .../cli/commands/cloud/api_client.py | 70 ++++---- src/basic_memory/cli/commands/tool.py | 2 +- src/basic_memory/config.py | 7 +- src/basic_memory/importers/utils.py | 13 +- src/basic_memory/markdown/schemas.py | 40 ++++- src/basic_memory/mcp/prompts/utils.py | 14 +- src/basic_memory/mcp/tools/read_note.py | 27 ++- src/basic_memory/models/base.py | 5 +- .../repository/fastembed_provider.py | 4 +- .../repository/openai_provider.py | 2 +- src/basic_memory/repository/repository.py | 7 +- .../repository/search_repository_base.py | 47 +++--- .../repository/sqlite_search_repository.py | 2 +- src/basic_memory/schemas/memory.py | 2 +- src/basic_memory/schemas/response.py | 2 +- src/basic_memory/services/entity_service.py | 20 ++- src/basic_memory/sync/watch_service.py | 18 +- .../test_output_format_json_integration.py | 3 +- test-int/mcp/test_pagination_integration.py | 3 +- test-int/semantic/test_search_diagnostics.py | 15 +- test-int/semantic/test_semantic_coverage.py | 4 +- test-int/test_db_wal_mode.py | 17 +- tests/api/v2/conftest.py | 3 +- .../api/v2/test_knowledge_router_telemetry.py | 35 ++-- tests/api/v2/test_memory_hydration.py | 29 +++- tests/api/v2/test_project_router.py | 26 ++- tests/api/v2/test_search_hydration.py | 5 +- tests/api/v2/test_search_router_telemetry.py | 18 +- .../cloud/test_cloud_api_client_and_utils.py | 1 + tests/cli/test_auto_update.py | 25 +-- tests/cli/test_cli_telemetry.py | 4 +- tests/cli/test_cloud_authentication.py | 11 +- tests/cli/test_cloud_promo.py | 10 +- tests/cli/test_json_output.py | 14 +- .../test_issue_254_foreign_key_constraints.py | 1 + .../markdown/test_date_frontmatter_parsing.py | 1 + .../test_entity_parser_error_handling.py | 2 + tests/markdown/test_markdown_processor.py | 1 + tests/markdown/test_observation_edge_cases.py | 7 + tests/markdown/test_parser_edge_cases.py | 2 + tests/markdown/test_relation_edge_cases.py | 6 + tests/mcp/conftest.py | 4 +- ...test_permalink_collision_file_overwrite.py | 2 + tests/mcp/test_project_context.py | 33 ++-- tests/mcp/test_project_context_telemetry.py | 3 +- .../mcp/test_recent_activity_prompt_modes.py | 3 +- tests/mcp/test_tool_build_context.py | 4 + tests/mcp/test_tool_contracts.py | 4 +- tests/mcp/test_tool_project_management.py | 2 + tests/mcp/test_tool_read_note.py | 8 +- tests/mcp/test_tool_recent_activity.py | 3 +- tests/mcp/test_tool_utils.py | 34 ++-- tests/mcp/test_tool_utils_cloud_auth.py | 10 +- tests/mcp/test_tool_workspace_management.py | 10 +- tests/mcp/test_tool_write_note.py | 4 +- tests/mcp/test_ui_sdk.py | 7 +- tests/repository/test_entity_repository.py | 6 +- tests/repository/test_fastembed_provider.py | 10 +- tests/repository/test_hybrid_fusion.py | 43 ++++- .../repository/test_observation_repository.py | 4 +- tests/repository/test_openai_provider.py | 10 +- .../test_postgres_search_repository.py | 6 + tests/repository/test_relation_repository.py | 7 +- .../test_search_repository_edit_bug_fix.py | 2 + tests/repository/test_semantic_search_base.py | 27 ++- .../test_sqlite_vector_search_repository.py | 19 ++- tests/repository/test_vector_pagination.py | 37 +++- tests/repository/test_vector_threshold.py | 49 ++++-- tests/schema/test_resolver.py | 3 + tests/schema/test_validator.py | 2 + tests/schemas/test_search.py | 5 +- tests/services/test_entity_service.py | 158 +++++++++--------- .../test_entity_service_disable_permalinks.py | 3 +- tests/services/test_project_service.py | 21 ++- tests/services/test_search_service.py | 4 + tests/services/test_semantic_search.py | 4 +- .../services/test_task_scheduler_semantic.py | 25 +-- tests/sync/test_sync_service.py | 8 +- tests/sync/test_sync_service_incremental.py | 79 ++++----- tests/sync/test_watch_service_reload.py | 35 ++-- tests/test_config.py | 13 +- tests/test_production_cascade_delete.py | 67 ++++++-- tests/utils/test_file_utils.py | 3 +- tests/utils/test_parse_tags.py | 4 +- 88 files changed, 961 insertions(+), 551 deletions(-) diff --git a/justfile b/justfile index e12fcbe97..04c5b25af 100644 --- a/justfile +++ b/justfile @@ -170,13 +170,17 @@ lint: fix fix: uv run ruff check --fix --unsafe-fixes src tests test-int -# Type check code (pyright) +# Type check code (ty) typecheck: + uv run ty check src tests test-int + +# Type check code (pyright) +typecheck-pyright: uv run pyright # Type check code (ty) typecheck-ty: - uv run ty check src/ + just typecheck # Clean build artifacts and cache files clean: diff --git a/src/basic_memory/api/v2/routers/prompt_router.py b/src/basic_memory/api/v2/routers/prompt_router.py index c6e58ec0f..69cc1c3cc 100644 --- a/src/basic_memory/api/v2/routers/prompt_router.py +++ b/src/basic_memory/api/v2/routers/prompt_router.py @@ -6,6 +6,7 @@ """ from datetime import datetime, timezone +from typing import Any from fastapi import APIRouter, HTTPException, status, Path from loguru import logger @@ -59,6 +60,7 @@ async def continue_conversation( # Initialize search results search_results = [] + hierarchical_results_for_count = [] # Get data needed for template if request.topic: @@ -91,7 +93,8 @@ async def continue_conversation( # Limit to a reasonable number of total results all_hierarchical_results = all_hierarchical_results[:10] - template_context = { + hierarchical_results_for_count = all_hierarchical_results + template_context: dict[str, Any] = { "topic": request.topic, "timeframe": request.timeframe, "hierarchical_results": all_hierarchical_results, @@ -110,6 +113,7 @@ async def continue_conversation( hierarchical_results = recent_context.results[:5] # Limit to top 5 recent items + hierarchical_results_for_count = hierarchical_results template_context = { "topic": f"Recent Activity from ({request.timeframe})", "timeframe": request.timeframe, @@ -129,9 +133,6 @@ async def continue_conversation( relation_count = 0 entity_count = 0 - # Get the hierarchical results from the template context - hierarchical_results_for_count = template_context.get("hierarchical_results", []) - # For topic-based search if request.topic: for item in hierarchical_results_for_count: @@ -159,29 +160,24 @@ async def continue_conversation( elif related.type == "entity": # pragma: no cover entity_count += 1 # pragma: no cover - # Build metadata - metadata = { - "query": request.topic, - "timeframe": request.timeframe, - "search_count": len(search_results) - if request.topic - else 0, # Original search results count - "context_count": len(hierarchical_results_for_count), - "observation_count": observation_count, - "relation_count": relation_count, - "total_items": ( + prompt_metadata = PromptMetadata( + query=request.topic, + timeframe=request.timeframe, + search_count=len(search_results) if request.topic else 0, + context_count=len(hierarchical_results_for_count), + observation_count=observation_count, + relation_count=relation_count, + total_items=( len(hierarchical_results_for_count) + observation_count + relation_count + entity_count ), - "search_limit": request.search_items_limit, - "context_depth": request.depth, - "related_limit": request.related_items_limit, - "generated_at": datetime.now(timezone.utc).isoformat(), - } - - prompt_metadata = PromptMetadata(**metadata) + search_limit=request.search_items_limit, + context_depth=request.depth, + related_limit=request.related_items_limit, + generated_at=datetime.now(timezone.utc).isoformat(), + ) return PromptResponse( prompt=rendered_prompt, context=template_context, metadata=prompt_metadata @@ -229,7 +225,7 @@ async def search_prompt( results = await search_service.search(query, limit=limit, offset=offset) search_results = await to_search_results(entity_service, results) - template_context = { + template_context: dict[str, Any] = { "query": request.query, "timeframe": request.timeframe, "results": search_results, @@ -241,22 +237,19 @@ async def search_prompt( # Render template rendered_prompt = await template_loader.render("prompts/search.hbs", template_context) - # Build metadata - metadata = { - "query": request.query, - "timeframe": request.timeframe, - "search_count": len(search_results), - "context_count": len(search_results), - "observation_count": 0, # Search results don't include observations - "relation_count": 0, # Search results don't include relations - "total_items": len(search_results), - "search_limit": limit, - "context_depth": 0, # No context depth for basic search - "related_limit": 0, # No related items for basic search - "generated_at": datetime.now(timezone.utc).isoformat(), - } - - prompt_metadata = PromptMetadata(**metadata) + prompt_metadata = PromptMetadata( + query=request.query, + timeframe=request.timeframe, + search_count=len(search_results), + context_count=len(search_results), + observation_count=0, + relation_count=0, + total_items=len(search_results), + search_limit=limit, + context_depth=0, + related_limit=0, + generated_at=datetime.now(timezone.utc).isoformat(), + ) return PromptResponse( prompt=rendered_prompt, context=template_context, metadata=prompt_metadata diff --git a/src/basic_memory/api/v2/routers/resource_router.py b/src/basic_memory/api/v2/routers/resource_router.py index d459bb9d9..4ebf94506 100644 --- a/src/basic_memory/api/v2/routers/resource_router.py +++ b/src/basic_memory/api/v2/routers/resource_router.py @@ -211,7 +211,7 @@ async def create_resource( action="create", phase="search_index", ): - await search_service.index_entity(entity) # pyright: ignore + await search_service.index_entity(entity) return ResourceResponse( entity_id=entity.id, @@ -326,6 +326,8 @@ async def update_resource( "updated_at": file_metadata.modified_at, }, ) + if updated_entity is None: + raise HTTPException(status_code=404, detail=f"Entity {entity_id} not found") with telemetry.scope( "api.resource.update.search_index", @@ -333,7 +335,7 @@ async def update_resource( action="update", phase="search_index", ): - await search_service.index_entity(updated_entity) # pyright: ignore + await search_service.index_entity(updated_entity) return ResourceResponse( entity_id=entity.id, diff --git a/src/basic_memory/api/v2/utils.py b/src/basic_memory/api/v2/utils.py index 7c0ac65d4..2977e3ef8 100644 --- a/src/basic_memory/api/v2/utils.py +++ b/src/basic_memory/api/v2/utils.py @@ -1,8 +1,6 @@ -from typing import Optional, List +from typing import Any, Protocol, Optional, List, Sequence from basic_memory import telemetry -from basic_memory.models import Entity as EntityModel -from basic_memory.repository import EntityRepository from basic_memory.repository.search_repository import SearchIndexRow from basic_memory.schemas.memory import ( EntitySummary, @@ -13,19 +11,38 @@ ContextResult, ) from basic_memory.schemas.search import SearchItemType, SearchResult -from basic_memory.services import EntityService from basic_memory.services.context_service import ( ContextResultRow, ContextResult as ServiceContextResult, ) +class EntityBatchLookup(Protocol): + async def find_by_ids(self, ids: List[int]) -> Sequence[Any]: ... + + +class EntityServiceBatchLookup(Protocol): + async def get_entities_by_id(self, ids: List[int]) -> Sequence[Any]: ... + + +def _required_str(value: str | None, field_name: str) -> str: + """Return a required search field or fail before producing invalid response data.""" + if value is None: + raise ValueError(f"Search result is missing required field: {field_name}") + return value + + +def _search_item_type(value: str | SearchItemType) -> SearchItemType: + """Normalize repository row type strings into the public search enum.""" + return value if isinstance(value, SearchItemType) else SearchItemType(value) + + async def to_graph_context( context_result: ServiceContextResult, - entity_repository: EntityRepository, + entity_repository: EntityBatchLookup, page: Optional[int] = None, page_size: Optional[int] = None, -): +) -> GraphContext: with telemetry.scope( "memory.hydrate_context", domain="memory", @@ -44,17 +61,18 @@ async def to_graph_context( + context_item.observations + context_item.related_results ): - if item.type == SearchItemType.ENTITY: + item_type = _search_item_type(item.type) + if item_type == SearchItemType.ENTITY: # Entity's own ID for its external_id entity_ids_needed.add(item.id) - elif item.type == SearchItemType.OBSERVATION: + elif item_type == SearchItemType.OBSERVATION: # Parent entity ID for entity_external_id - if item.entity_id: # pyright: ignore - entity_ids_needed.add(item.entity_id) # pyright: ignore - elif item.type == SearchItemType.RELATION: + if item.entity_id: + entity_ids_needed.add(item.entity_id) + elif item_type == SearchItemType.RELATION: # Source and target entity IDs for external_ids - if item.from_id: # pyright: ignore - entity_ids_needed.add(item.from_id) # pyright: ignore + if item.from_id: + entity_ids_needed.add(item.from_id) if item.to_id: entity_ids_needed.add(item.to_id) @@ -75,57 +93,60 @@ async def to_graph_context( entity_external_id_lookup[e.id] = e.external_id # Helper function to convert items to summaries - def to_summary(item: SearchIndexRow | ContextResultRow): - match item.type: + def to_summary( + item: SearchIndexRow | ContextResultRow, + ) -> EntitySummary | ObservationSummary | RelationSummary: + item_type = _search_item_type(item.type) + match item_type: case SearchItemType.ENTITY: return EntitySummary( external_id=entity_external_id_lookup.get(item.id, ""), entity_id=item.id, - title=item.title, # pyright: ignore + title=_required_str(item.title, "title"), permalink=item.permalink, content=item.content, - file_path=item.file_path, + file_path=_required_str(item.file_path, "file_path"), created_at=item.created_at, ) case SearchItemType.OBSERVATION: entity_ext_id = None - if item.entity_id: # pyright: ignore - entity_ext_id = entity_external_id_lookup.get(item.entity_id) # pyright: ignore + entity_title = None + if item.entity_id: + entity_ext_id = entity_external_id_lookup.get(item.entity_id) + entity_title = entity_title_lookup.get(item.entity_id) return ObservationSummary( observation_id=item.id, - entity_id=item.entity_id, # pyright: ignore + entity_id=item.entity_id, entity_external_id=entity_ext_id, - title=entity_title_lookup.get(item.entity_id), # pyright: ignore - file_path=item.file_path, - category=item.category, # pyright: ignore - content=item.content, # pyright: ignore - permalink=item.permalink, # pyright: ignore + title=entity_title, + file_path=_required_str(item.file_path, "file_path"), + category=_required_str(item.category, "category"), + content=_required_str(item.content, "content"), + permalink=_required_str(item.permalink, "permalink"), created_at=item.created_at, ) case SearchItemType.RELATION: - from_title = entity_title_lookup.get(item.from_id) if item.from_id else None # pyright: ignore + from_title = entity_title_lookup.get(item.from_id) if item.from_id else None to_title = entity_title_lookup.get(item.to_id) if item.to_id else None from_ext_id = ( entity_external_id_lookup.get(item.from_id) if item.from_id else None - ) # pyright: ignore + ) to_ext_id = entity_external_id_lookup.get(item.to_id) if item.to_id else None return RelationSummary( relation_id=item.id, - entity_id=item.entity_id, # pyright: ignore - title=item.title, # pyright: ignore - file_path=item.file_path, - permalink=item.permalink, # pyright: ignore - relation_type=item.relation_type, # pyright: ignore + entity_id=item.entity_id, + title=_required_str(item.title, "title"), + file_path=_required_str(item.file_path, "file_path"), + permalink=_required_str(item.permalink, "permalink"), + relation_type=_required_str(item.relation_type, "relation_type"), from_entity=from_title, - from_entity_id=item.from_id, # pyright: ignore + from_entity_id=item.from_id, from_entity_external_id=from_ext_id, to_entity=to_title, to_entity_id=item.to_id, to_entity_external_id=to_ext_id, created_at=item.created_at, ) - case _: # pragma: no cover - raise ValueError(f"Unexpected type: {item.type}") with telemetry.scope( "memory.hydrate_context.shape_results", @@ -137,12 +158,16 @@ def to_summary(item: SearchIndexRow | ContextResultRow): hierarchical_results = [] for context_item in context_result.results: primary_result = to_summary(context_item.primary_result) - observations = [to_summary(obs) for obs in context_item.observations] + observations = [ + summary + for summary in (to_summary(obs) for obs in context_item.observations) + if isinstance(summary, ObservationSummary) + ] related = [to_summary(rel) for rel in context_item.related_results] hierarchical_results.append( ContextResult( primary_result=primary_result, - observations=observations, # pyright: ignore[reportArgumentType] + observations=observations, related_results=related, ) ) @@ -170,7 +195,9 @@ def to_summary(item: SearchIndexRow | ContextResultRow): ) -async def to_search_results(entity_service: EntityService, results: List[SearchIndexRow]): +async def to_search_results( + entity_service: EntityServiceBatchLookup, results: List[SearchIndexRow] +) -> list[SearchResult]: with telemetry.scope( "search.hydrate_results", domain="search", @@ -187,7 +214,7 @@ async def to_search_results(entity_service: EntityService, results: List[SearchI all_entity_ids.add(eid) # Single batch fetch for all entities - entities_by_id: dict[int, EntityModel] = {} + entities_by_id: dict[int, Any] = {} with telemetry.scope( "search.hydrate_results.fetch_entities", domain="search", @@ -222,20 +249,20 @@ async def to_search_results(entity_service: EntityService, results: List[SearchI entity_id = result.entity_id # Look up entities by their specific IDs - parent_entity = entities_by_id.get(result.entity_id) if result.entity_id else None # pyright: ignore - from_entity = entities_by_id.get(result.from_id) if result.from_id else None # pyright: ignore + parent_entity = entities_by_id.get(result.entity_id) if result.entity_id else None + from_entity = entities_by_id.get(result.from_id) if result.from_id else None to_entity = entities_by_id.get(result.to_id) if result.to_id else None search_results.append( SearchResult( - title=result.title, # pyright: ignore - type=result.type, # pyright: ignore + title=_required_str(result.title, "title"), + type=_search_item_type(result.type), permalink=result.permalink, - score=result.score, # pyright: ignore + score=result.score if result.score is not None else 0.0, entity=parent_entity.permalink if parent_entity else None, content=result.content, matched_chunk=result.matched_chunk_text, - file_path=result.file_path, + file_path=_required_str(result.file_path, "file_path"), metadata=result.metadata, entity_id=entity_id, observation_id=observation_id, diff --git a/src/basic_memory/cli/commands/cloud/api_client.py b/src/basic_memory/cli/commands/cloud/api_client.py index 8f77963b8..98dacc2f7 100644 --- a/src/basic_memory/cli/commands/cloud/api_client.py +++ b/src/basic_memory/cli/commands/cloud/api_client.py @@ -99,41 +99,39 @@ async def make_api_request( response = await client.request(method=method, url=url, headers=headers, json=json_data) response.raise_for_status() return response + except httpx.HTTPStatusError as e: + response = e.response + + # Try to parse error detail from response + error_detail = None + try: + error_detail = response.json() + except Exception: + # If JSON parsing fails, we'll handle it as a generic error + pass + + # Check for subscription_required error (403) + if response.status_code == 403 and isinstance(error_detail, dict): + # Handle both FastAPI HTTPException format (nested under "detail") + # and direct format + detail_obj = error_detail.get("detail", error_detail) + if ( + isinstance(detail_obj, dict) + and detail_obj.get("error") == "subscription_required" + ): + message = detail_obj.get("message", "Active subscription required") + subscribe_url = detail_obj.get( + "subscribe_url", "https://basicmemory.com/subscribe" + ) + raise SubscriptionRequiredError( + message=message, subscribe_url=subscribe_url + ) from e + + # Raise generic CloudAPIError with status code and detail + raise CloudAPIError( + f"API request failed: {e}", + status_code=response.status_code, + detail=error_detail if isinstance(error_detail, dict) else {}, + ) from e except httpx.HTTPError as e: - # Check if this is a response error with response details - if hasattr(e, "response") and e.response is not None: # pyright: ignore [reportAttributeAccessIssue] - response = e.response # type: ignore - - # Try to parse error detail from response - error_detail = None - try: - error_detail = response.json() - except Exception: - # If JSON parsing fails, we'll handle it as a generic error - pass - - # Check for subscription_required error (403) - if response.status_code == 403 and isinstance(error_detail, dict): - # Handle both FastAPI HTTPException format (nested under "detail") - # and direct format - detail_obj = error_detail.get("detail", error_detail) - if ( - isinstance(detail_obj, dict) - and detail_obj.get("error") == "subscription_required" - ): - message = detail_obj.get("message", "Active subscription required") - subscribe_url = detail_obj.get( - "subscribe_url", "https://basicmemory.com/subscribe" - ) - raise SubscriptionRequiredError( - message=message, subscribe_url=subscribe_url - ) from e - - # Raise generic CloudAPIError with status code and detail - raise CloudAPIError( - f"API request failed: {e}", - status_code=response.status_code, - detail=error_detail if isinstance(error_detail, dict) else {}, - ) from e - raise CloudAPIError(f"API request failed: {e}") from e diff --git a/src/basic_memory/cli/commands/tool.py b/src/basic_memory/cli/commands/tool.py index 6486b4c25..8b4457901 100644 --- a/src/basic_memory/cli/commands/tool.py +++ b/src/basic_memory/cli/commands/tool.py @@ -345,7 +345,7 @@ def recent_activity( with force_routing(local=local, cloud=cloud): result = run_with_cleanup( mcp_recent_activity( - type=type, # pyright: ignore[reportArgumentType] + type=type or "", depth=depth if depth is not None else 1, timeframe=timeframe if timeframe is not None else "7d", page=page, diff --git a/src/basic_memory/config.py b/src/basic_memory/config.py index 3fa875806..61eb4556d 100644 --- a/src/basic_memory/config.py +++ b/src/basic_memory/config.py @@ -8,7 +8,7 @@ from datetime import datetime from enum import Enum from pathlib import Path -from typing import Any, Dict, Literal, Optional, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, List, Tuple from loguru import logger from pydantic import AliasChoices, BaseModel, Field, model_validator @@ -122,6 +122,11 @@ class ProjectEntry(BaseModel): class BasicMemoryConfig(BaseSettings): """Pydantic model for Basic Memory global configuration.""" + if TYPE_CHECKING: + # Pydantic accepts raw constructor data and validates/coerces it at runtime. + # Model attributes remain strongly typed after initialization. + def __init__(self, **data: Any) -> None: ... + env: Environment = Field(default="dev", description="Environment name") projects: Dict[str, ProjectEntry] = Field( diff --git a/src/basic_memory/importers/utils.py b/src/basic_memory/importers/utils.py index 9c2f987bf..ef07a0f9d 100644 --- a/src/basic_memory/importers/utils.py +++ b/src/basic_memory/importers/utils.py @@ -39,23 +39,24 @@ def format_timestamp(timestamp: Any) -> str: # pragma: no cover Returns: A formatted string representation of the timestamp. """ + parsed_timestamp = timestamp if isinstance(timestamp, str): try: # Try ISO format - timestamp = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) + parsed_timestamp = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) except ValueError: try: # Try unix timestamp as string - timestamp = datetime.fromtimestamp(float(timestamp)).astimezone() + parsed_timestamp = datetime.fromtimestamp(float(timestamp)).astimezone() except ValueError: # Return as is if we can't parse it return timestamp elif isinstance(timestamp, (int, float)): # Unix timestamp - timestamp = datetime.fromtimestamp(timestamp).astimezone() + parsed_timestamp = datetime.fromtimestamp(timestamp).astimezone() - if isinstance(timestamp, datetime): - return timestamp.strftime("%Y-%m-%d %H:%M:%S") + if isinstance(parsed_timestamp, datetime): + return parsed_timestamp.strftime("%Y-%m-%d %H:%M:%S") # Return as is if we can't format it - return str(timestamp) # pragma: no cover + return str(parsed_timestamp) # pragma: no cover diff --git a/src/basic_memory/markdown/schemas.py b/src/basic_memory/markdown/schemas.py index f7c2d078c..cde881cc1 100644 --- a/src/basic_memory/markdown/schemas.py +++ b/src/basic_memory/markdown/schemas.py @@ -1,9 +1,9 @@ """Schema models for entity markdown files.""" from datetime import datetime -from typing import List, Optional +from typing import TYPE_CHECKING, Any, List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field, model_validator class Observation(BaseModel): @@ -38,23 +38,47 @@ def __str__(self) -> str: class EntityFrontmatter(BaseModel): """Required frontmatter fields for an entity.""" - metadata: dict = {} + if TYPE_CHECKING: + # Frontmatter may be built from raw YAML keys. The validator below + # gathers those keys into the metadata mapping used at runtime. + def __init__(self, **data: Any) -> None: ... + + metadata: dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def collect_metadata(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + + if "metadata" not in data: + return {"metadata": data} + + metadata = data.get("metadata") or {} + extras = {key: value for key, value in data.items() if key != "metadata"} + if extras: + return {"metadata": {**extras, **metadata}} + return data @property def tags(self) -> List[str]: - return self.metadata.get("tags") if self.metadata else None # pyright: ignore + tags = self.metadata.get("tags") + return [str(tag) for tag in tags] if isinstance(tags, list) else [] @property def title(self) -> str: - return self.metadata.get("title") if self.metadata else None # pyright: ignore + title = self.metadata.get("title") + return title if isinstance(title, str) else "" @property def type(self) -> str: - return self.metadata.get("type", "note") if self.metadata else "note" # pyright: ignore + note_type = self.metadata.get("type", "note") + return note_type if isinstance(note_type, str) else "note" @property - def permalink(self) -> str: - return self.metadata.get("permalink") if self.metadata else None # pyright: ignore + def permalink(self) -> Optional[str]: + permalink = self.metadata.get("permalink") + return permalink if isinstance(permalink, str) else None class EntityMarkdown(BaseModel): diff --git a/src/basic_memory/mcp/prompts/utils.py b/src/basic_memory/mcp/prompts/utils.py index 0e680b61e..ba0fe4643 100644 --- a/src/basic_memory/mcp/prompts/utils.py +++ b/src/basic_memory/mcp/prompts/utils.py @@ -95,8 +95,8 @@ def format_prompt_context(context: PromptContext) -> str: sections = [] # Process each context - for context in context.results: # pyright: ignore - for primary in context.primary_results: # pyright: ignore + for context_item in context.results: + for primary in context_item.primary_results: if primary.permalink not in added_permalinks: primary_permalink = primary.permalink @@ -121,8 +121,8 @@ def format_prompt_context(context: PromptContext) -> str: section += f"- **Created**: {primary.created_at.strftime('%Y-%m-%d %H:%M')}\n" # Add content snippet - if hasattr(primary, "content") and primary.content: # pyright: ignore - content = primary.content or "" # pyright: ignore # pragma: no cover + if hasattr(primary, "content") and primary.content: + content = primary.content or "" # pragma: no cover if content: # pragma: no cover section += f"\n**Excerpt**:\n{content}\n" # pragma: no cover @@ -132,14 +132,14 @@ def format_prompt_context(context: PromptContext) -> str: """) sections.append(section) - if context.related_results: # pyright: ignore - section += dedent( # pyright: ignore + if context_item.related_results: + section += dedent( """ ## Related Context """ ) - for related in context.related_results: # pyright: ignore + for related in context_item.related_results: section_content = dedent(f""" - type: **{related.type}** - title: {related.title} diff --git a/src/basic_memory/mcp/tools/read_note.py b/src/basic_memory/mcp/tools/read_note.py index 09dc8916c..229bc28c7 100644 --- a/src/basic_memory/mcp/tools/read_note.py +++ b/src/basic_memory/mcp/tools/read_note.py @@ -1,7 +1,7 @@ """Read note tool for Basic Memory MCP server.""" from textwrap import dedent -from typing import Optional, Literal +from typing import Optional, Literal, cast import yaml @@ -235,13 +235,22 @@ def _empty_json_payload() -> dict: "frontmatter": None, } - def _search_results(payload: object) -> list[dict]: + def _search_results(payload: object) -> list[dict[str, object]]: if not isinstance(payload, dict): return [] - results = payload.get("results") - return results if isinstance(results, list) else [] + payload_dict = cast(dict[str, object], payload) + results = payload_dict.get("results") + if not isinstance(results, list): + return [] + return [ + cast(dict[str, object], result) + for result in results + if isinstance(result, dict) + ] - async def _search_candidates(identifier_text: str, *, title_only: bool) -> dict: + async def _search_candidates( + identifier_text: str, *, title_only: bool + ) -> dict[str, object]: # Trigger: direct entity resolution failed for the caller's identifier. # Why: search_notes applies the same memory:// normalization and tool-level # query handling as the rest of MCP routing, which raw client calls skip. @@ -257,16 +266,16 @@ async def _search_candidates(identifier_text: str, *, title_only: bool) -> dict: output_format="json", context=context, ) - return response if isinstance(response, dict) else {} + return cast(dict[str, object], response) if isinstance(response, dict) else {} - def _result_title(item: dict) -> str: + def _result_title(item: dict[str, object]) -> str: return str(item.get("title") or "") - def _result_permalink(item: dict) -> Optional[str]: + def _result_permalink(item: dict[str, object]) -> Optional[str]: value = item.get("permalink") return str(value) if value else None - def _result_file_path(item: dict) -> Optional[str]: + def _result_file_path(item: dict[str, object]) -> Optional[str]: value = item.get("file_path") return str(value) if value else None diff --git a/src/basic_memory/models/base.py b/src/basic_memory/models/base.py index 471f14e24..ea1e17c3e 100644 --- a/src/basic_memory/models/base.py +++ b/src/basic_memory/models/base.py @@ -1,5 +1,7 @@ """Base model class for SQLAlchemy models.""" +from typing import TYPE_CHECKING + from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import DeclarativeBase @@ -7,4 +9,5 @@ class Base(AsyncAttrs, DeclarativeBase): """Base class for all models""" - pass + if TYPE_CHECKING: + id: int diff --git a/src/basic_memory/repository/fastembed_provider.py b/src/basic_memory/repository/fastembed_provider.py index e36351487..5ade579e9 100644 --- a/src/basic_memory/repository/fastembed_provider.py +++ b/src/basic_memory/repository/fastembed_provider.py @@ -11,7 +11,7 @@ from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError if TYPE_CHECKING: - from fastembed import TextEmbedding # type: ignore[import-not-found] # pragma: no cover + from fastembed import TextEmbedding # pragma: no cover class FastEmbedEmbeddingProvider(EmbeddingProvider): @@ -62,7 +62,7 @@ async def _load_model(self) -> "TextEmbedding": def _create_model() -> "TextEmbedding": try: - from fastembed import TextEmbedding # type: ignore[import-not-found] + from fastembed import TextEmbedding except ( ImportError ) as exc: # pragma: no cover - exercised via tests with monkeypatch diff --git a/src/basic_memory/repository/openai_provider.py b/src/basic_memory/repository/openai_provider.py index b44e13b77..24a5ac04b 100644 --- a/src/basic_memory/repository/openai_provider.py +++ b/src/basic_memory/repository/openai_provider.py @@ -50,7 +50,7 @@ async def _get_client(self) -> Any: return self._client try: - from openai import AsyncOpenAI # type: ignore[import-not-found] + from openai import AsyncOpenAI except ImportError as exc: # pragma: no cover - covered via monkeypatch tests raise SemanticDependenciesMissingError( "OpenAI dependency is missing. " diff --git a/src/basic_memory/repository/repository.py b/src/basic_memory/repository/repository.py index 63cba2d21..558321d87 100644 --- a/src/basic_memory/repository/repository.py +++ b/src/basic_memory/repository/repository.py @@ -268,7 +268,7 @@ async def create_all(self, data_list: List[dict]) -> Sequence[T]: return await self.select_by_ids(session, [model.id for model in model_list]) # pyright: ignore [reportAttributeAccessIssue] - async def update(self, entity_id: int, entity_data: dict | T) -> Optional[T]: + async def update(self, entity_id: int, entity_data: dict[str, Any] | T) -> Optional[T]: """Update an entity with the given data.""" logger.debug(f"Updating {self.Model.__name__} {entity_id} with data: {entity_data}") async with db.scoped_session(self.session_maker) as session: @@ -279,12 +279,13 @@ async def update(self, entity_id: int, entity_data: dict | T) -> Optional[T]: entity = result.scalars().one() if isinstance(entity_data, dict): - for key, value in entity_data.items(): + update_data = cast(dict[str, Any], entity_data) + for key, value in update_data.items(): if key in self.valid_columns: setattr(entity, key, value) elif isinstance(entity_data, self.Model): - for column in self.Model.__table__.columns.keys(): + for column in self.valid_columns: setattr(entity, column, getattr(entity_data, column)) await session.flush() # Make sure changes are flushed diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 7783aa282..f44db34ba 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -1068,39 +1068,48 @@ def emit_progress(entity_id: int) -> None: write_seconds_total=result.write_seconds_total, ) batch_total_seconds = time.perf_counter() - batch_start - metric_attrs = { - "backend": backend_name, - "skip_only_batch": result.embedding_jobs_total == 0, - } telemetry.record_histogram( "vector_sync_batch_total_seconds", batch_total_seconds, unit="s", - **metric_attrs, + backend=backend_name, + skip_only_batch=result.embedding_jobs_total == 0, ) telemetry.add_counter( - "vector_sync_entities_total", result.entities_total, **metric_attrs + "vector_sync_entities_total", + result.entities_total, + backend=backend_name, + skip_only_batch=result.embedding_jobs_total == 0, ) telemetry.add_counter( "vector_sync_entities_skipped", result.entities_skipped, - **metric_attrs, + backend=backend_name, + skip_only_batch=result.embedding_jobs_total == 0, ) telemetry.add_counter( "vector_sync_entities_deferred", result.entities_deferred, - **metric_attrs, + backend=backend_name, + skip_only_batch=result.embedding_jobs_total == 0, ) telemetry.add_counter( "vector_sync_embedding_jobs_total", result.embedding_jobs_total, - **metric_attrs, + backend=backend_name, + skip_only_batch=result.embedding_jobs_total == 0, + ) + telemetry.add_counter( + "vector_sync_chunks_total", + result.chunks_total, + backend=backend_name, + skip_only_batch=result.embedding_jobs_total == 0, ) - telemetry.add_counter("vector_sync_chunks_total", result.chunks_total, **metric_attrs) telemetry.add_counter( "vector_sync_chunks_skipped", result.chunks_skipped, - **metric_attrs, + backend=backend_name, + skip_only_batch=result.embedding_jobs_total == 0, ) if batch_span is not None: batch_span.set_attributes( @@ -1675,33 +1684,33 @@ def _log_vector_sync_complete( ) -> None: """Log completion and slow-entity warnings with a consistent format.""" backend_name = type(self).__name__.removesuffix("SearchRepository").lower() - metric_attrs = { - "backend": backend_name, - "skip_only_entity": entity_skipped and embedding_jobs_count == 0, - } telemetry.record_histogram( "vector_sync_prepare_seconds", prepare_seconds, unit="s", - **metric_attrs, + backend=backend_name, + skip_only_entity=entity_skipped and embedding_jobs_count == 0, ) telemetry.record_histogram( "vector_sync_queue_wait_seconds", queue_wait_seconds, unit="s", - **metric_attrs, + backend=backend_name, + skip_only_entity=entity_skipped and embedding_jobs_count == 0, ) telemetry.record_histogram( "vector_sync_embed_seconds", embed_seconds, unit="s", - **metric_attrs, + backend=backend_name, + skip_only_entity=entity_skipped and embedding_jobs_count == 0, ) telemetry.record_histogram( "vector_sync_write_seconds", write_seconds, unit="s", - **metric_attrs, + backend=backend_name, + skip_only_entity=entity_skipped and embedding_jobs_count == 0, ) if total_seconds > 10: logger.warning( diff --git a/src/basic_memory/repository/sqlite_search_repository.py b/src/basic_memory/repository/sqlite_search_repository.py index 06eb41c07..474c8c5ca 100644 --- a/src/basic_memory/repository/sqlite_search_repository.py +++ b/src/basic_memory/repository/sqlite_search_repository.py @@ -350,7 +350,7 @@ async def _ensure_sqlite_vec_loaded(self, session) -> None: pass try: - import sqlite_vec # type: ignore[import-not-found] + import sqlite_vec except ImportError as exc: raise SemanticDependenciesMissingError( "sqlite-vec package is missing. " diff --git a/src/basic_memory/schemas/memory.py b/src/basic_memory/schemas/memory.py index f1a27e974..a0af709ce 100644 --- a/src/basic_memory/schemas/memory.py +++ b/src/basic_memory/schemas/memory.py @@ -103,7 +103,7 @@ def normalize_memory_url(url: str | None) -> str: memory_url = TypeAdapter(MemoryUrl) -def memory_url_path(url: memory_url) -> str: # pyright: ignore +def memory_url_path(url: str) -> str: """ Returns the uri for a url value by removing the prefix "memory://" from a given MemoryUrl. diff --git a/src/basic_memory/schemas/response.py b/src/basic_memory/schemas/response.py index 0feb78953..f9343c82d 100644 --- a/src/basic_memory/schemas/response.py +++ b/src/basic_memory/schemas/response.py @@ -194,7 +194,7 @@ class EntityResponse(SQLAlchemyModel): note_type: NoteType # COMPAT(v0.18): old clients expect entity_type; remove when no longer needed - @computed_field # type: ignore[prop-decorator] + @computed_field @property def entity_type(self) -> str: return self.note_type diff --git a/src/basic_memory/services/entity_service.py b/src/basic_memory/services/entity_service.py index 67de97375..9d7f294bf 100644 --- a/src/basic_memory/services/entity_service.py +++ b/src/basic_memory/services/entity_service.py @@ -288,7 +288,9 @@ async def create_entity_with_content(self, schema: EntitySchema) -> EntityWriteR if "permalink" in content_frontmatter: content_markdown = self._build_frontmatter_markdown( - schema.title, schema.note_type, content_frontmatter["permalink"] + schema.title, + schema.note_type, + _coerce_to_string(content_frontmatter["permalink"]), ) # Get unique permalink (prioritizing content frontmatter) unless disabled @@ -393,7 +395,9 @@ async def update_entity_with_content( if "permalink" in content_frontmatter: content_markdown = self._build_frontmatter_markdown( - schema.title, schema.note_type, content_frontmatter["permalink"] + schema.title, + schema.note_type, + _coerce_to_string(content_frontmatter["permalink"]), ) # Check if we need to update the permalink based on content frontmatter (unless disabled) @@ -522,7 +526,9 @@ async def fast_write_entity( if "permalink" in content_frontmatter: content_markdown = self._build_frontmatter_markdown( - schema.title, schema.note_type, content_frontmatter["permalink"] + schema.title, + schema.note_type, + _coerce_to_string(content_frontmatter["permalink"]), ) # --- Permalink Resolution --- @@ -663,9 +669,9 @@ async def fast_edit_entity( if "permalink" in content_frontmatter: content_markdown = self._build_frontmatter_markdown( - update_data.get("title", entity.title), - update_data.get("note_type", entity.note_type), - content_frontmatter["permalink"], + _coerce_to_string(update_data.get("title", entity.title)), + _coerce_to_string(update_data.get("note_type", entity.note_type)), + _coerce_to_string(content_frontmatter["permalink"]), ) metadata = normalize_frontmatter_metadata(content_frontmatter or {}) @@ -1002,7 +1008,7 @@ async def update_entity_relations( target_entity: Optional[Entity] = None if not isinstance(resolved, Exception): # Type narrowing: resolved is Optional[Entity] here, not Exception - target_entity = resolved # type: ignore + target_entity = resolved # if the target is found, store the id target_id = target_entity.id if target_entity else None diff --git a/src/basic_memory/sync/watch_service.py b/src/basic_memory/sync/watch_service.py index 9a53b8d13..4a7d29b7d 100644 --- a/src/basic_memory/sync/watch_service.py +++ b/src/basic_memory/sync/watch_service.py @@ -149,7 +149,7 @@ async def _watch_projects_cycle(self, projects: Sequence[Project], stop_event: a # create coroutines to handle changes change_handlers = [ - self.handle_changes(project, changes) # pyright: ignore + self.handle_changes(project, set(changes)) for project, changes in project_changes.items() ] @@ -502,19 +502,19 @@ async def handle_changes(self, project: Project, changes: Set[FileChange]) -> No # Add a concise summary instead of a divider if processed: - changes = [] # pyright: ignore + change_summary: list[str] = [] if add_count > 0: - changes.append(f"[green]{add_count} added[/green]") # pyright: ignore + change_summary.append(f"[green]{add_count} added[/green]") if modify_count > 0: - changes.append(f"[yellow]{modify_count} modified[/yellow]") # pyright: ignore + change_summary.append(f"[yellow]{modify_count} modified[/yellow]") if moved_count > 0: - changes.append(f"[blue]{moved_count} moved[/blue]") # pyright: ignore + change_summary.append(f"[blue]{moved_count} moved[/blue]") if delete_count > 0: - changes.append(f"[red]{delete_count} deleted[/red]") # pyright: ignore + change_summary.append(f"[red]{delete_count} deleted[/red]") - if changes: - self.console.print(f"{', '.join(changes)}", style="dim") # pyright: ignore - logger.info(f"changes: {len(changes)}") + if change_summary: + self.console.print(f"{', '.join(change_summary)}", style="dim") + logger.info(f"changes: {len(change_summary)}") duration_ms = int((time.time() - start_time) * 1000) self.state.last_scan = datetime.now() diff --git a/test-int/mcp/test_output_format_json_integration.py b/test-int/mcp/test_output_format_json_integration.py index 220818d41..4f2b0568e 100644 --- a/test-int/mcp/test_output_format_json_integration.py +++ b/test-int/mcp/test_output_format_json_integration.py @@ -4,6 +4,7 @@ import json from pathlib import Path +from typing import Any import pytest from fastmcp import Client @@ -11,7 +12,7 @@ from basic_memory.mcp.clients.knowledge import KnowledgeClient -def _json_content(tool_result) -> dict | list: +def _json_content(tool_result) -> Any: """Parse a FastMCP tool result content block into JSON.""" assert len(tool_result.content) == 1 assert tool_result.content[0].type == "text" diff --git a/test-int/mcp/test_pagination_integration.py b/test-int/mcp/test_pagination_integration.py index 0d6928f76..cfed3fb62 100644 --- a/test-int/mcp/test_pagination_integration.py +++ b/test-int/mcp/test_pagination_integration.py @@ -7,12 +7,13 @@ """ import json +from typing import Any import pytest from fastmcp import Client -def _json_content(tool_result) -> dict | list: +def _json_content(tool_result) -> Any: """Parse a FastMCP tool result content block into JSON.""" assert len(tool_result.content) == 1 assert tool_result.content[0].type == "text" diff --git a/test-int/semantic/test_search_diagnostics.py b/test-int/semantic/test_search_diagnostics.py index ef6ba45ce..664e9972c 100644 --- a/test-int/semantic/test_search_diagnostics.py +++ b/test-int/semantic/test_search_diagnostics.py @@ -9,6 +9,8 @@ from __future__ import annotations +from typing import Any, cast + import pytest from basic_memory.config import DatabaseBackend @@ -335,11 +337,10 @@ async def test_similarity_formula_analysis(sqlite_engine_factory, tmp_path): from basic_memory import db as bm_db - async with bm_db.scoped_session(service.repository.session_maker) as session: - await service.repository._prepare_vector_session(session) - raw_rows = await service.repository._run_vector_query( - session, query_embedding, candidate_limit=20 - ) + repo = cast(Any, service.repository) + async with bm_db.scoped_session(repo.session_maker) as session: + await repo._prepare_vector_session(session) + raw_rows = await repo._run_vector_query(session, query_embedding, candidate_limit=20) print(f"\nQuery: '{query_text}'") print(f" {'chunk_key':<40} {'distance':>10} {'sim_old':>12} {'sim_new':>12}") @@ -347,7 +348,7 @@ async def test_similarity_formula_analysis(sqlite_engine_factory, tmp_path): dist = float(row["best_distance"]) sim_old = 1.0 / (1.0 + max(dist, 0.0)) # New formula: L2 distance → cosine similarity for normalized embeddings - sim_new = service.repository._distance_to_similarity(dist) + sim_new = repo._distance_to_similarity(dist) print(f" {row['chunk_key']:<40} {dist:>10.4f} {sim_old:>12.4f} {sim_new:>12.4f}") @@ -431,7 +432,7 @@ async def test_chunking_produces_reasonable_chunks(sqlite_engine_factory, tmp_pa service = await create_search_service( sqlite_engine_factory, DIAG_COMBO, tmp_path, embedding_provider=provider ) - repo = service.repository + repo = cast(Any, service.repository) # Simulate a typical entity with observations text_input = ( diff --git a/test-int/semantic/test_semantic_coverage.py b/test-int/semantic/test_semantic_coverage.py index a5d72faca..43179a7a3 100644 --- a/test-int/semantic/test_semantic_coverage.py +++ b/test-int/semantic/test_semantic_coverage.py @@ -13,6 +13,8 @@ from __future__ import annotations +from typing import Any, cast + import pytest from basic_memory.config import DatabaseBackend @@ -194,7 +196,7 @@ async def test_postgres_vector_dimension_detection(postgres_engine_factory, tmp_ postgres_engine_factory, PG_FASTEMBED, tmp_path, embedding_provider=provider ) - repo = search_service.repository + repo = cast(Any, search_service.repository) # First entity triggers _ensure_vector_tables entity = await search_service.entity_repository.create( diff --git a/test-int/test_db_wal_mode.py b/test-int/test_db_wal_mode.py index a95626f81..acd8993bd 100644 --- a/test-int/test_db_wal_mode.py +++ b/test-int/test_db_wal_mode.py @@ -8,6 +8,11 @@ from sqlalchemy import text +def _first_value(row): + assert row is not None + return row[0] + + @pytest.mark.asyncio async def test_wal_mode_enabled(engine_factory, db_backend): """Test that WAL mode is enabled on filesystem database connections.""" @@ -19,7 +24,7 @@ async def test_wal_mode_enabled(engine_factory, db_backend): # Execute a query to verify WAL mode is enabled async with engine.connect() as conn: result = await conn.execute(text("PRAGMA journal_mode")) - journal_mode = result.fetchone()[0] + journal_mode = _first_value(result.fetchone()) # WAL mode should be enabled for filesystem databases assert journal_mode.upper() == "WAL" @@ -35,7 +40,7 @@ async def test_busy_timeout_configured(engine_factory, db_backend): async with engine.connect() as conn: result = await conn.execute(text("PRAGMA busy_timeout")) - busy_timeout = result.fetchone()[0] + busy_timeout = _first_value(result.fetchone()) # Busy timeout should be 10 seconds (10000 milliseconds) assert busy_timeout == 10000 @@ -51,7 +56,7 @@ async def test_synchronous_mode_configured(engine_factory, db_backend): async with engine.connect() as conn: result = await conn.execute(text("PRAGMA synchronous")) - synchronous = result.fetchone()[0] + synchronous = _first_value(result.fetchone()) # Synchronous should be NORMAL (1) assert synchronous == 1 @@ -67,7 +72,7 @@ async def test_cache_size_configured(engine_factory, db_backend): async with engine.connect() as conn: result = await conn.execute(text("PRAGMA cache_size")) - cache_size = result.fetchone()[0] + cache_size = _first_value(result.fetchone()) # Cache size should be -64000 (64MB) assert cache_size == -64000 @@ -83,7 +88,7 @@ async def test_temp_store_configured(engine_factory, db_backend): async with engine.connect() as conn: result = await conn.execute(text("PRAGMA temp_store")) - temp_store = result.fetchone()[0] + temp_store = _first_value(result.fetchone()) # temp_store should be MEMORY (2) assert temp_store == 2 @@ -114,7 +119,7 @@ async def test_windows_locking_mode_when_on_windows(tmp_path, monkeypatch, confi ): async with engine.connect() as conn: result = await conn.execute(text("PRAGMA locking_mode")) - locking_mode = result.fetchone()[0] + locking_mode = _first_value(result.fetchone()) # Locking mode should be NORMAL on Windows assert locking_mode.upper() == "NORMAL" diff --git a/tests/api/v2/conftest.py b/tests/api/v2/conftest.py index af6cedd8f..633415d6c 100644 --- a/tests/api/v2/conftest.py +++ b/tests/api/v2/conftest.py @@ -1,5 +1,6 @@ """Fixtures for V2 API tests.""" +from collections.abc import Generator from typing import Any, AsyncGenerator import pytest @@ -30,7 +31,7 @@ async def client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]: @pytest.fixture(autouse=True) -def task_scheduler_spy(app: FastAPI) -> list[dict[str, Any]]: +def task_scheduler_spy(app: FastAPI) -> Generator[list[dict[str, Any]], None, None]: """Capture scheduled task specs without executing them.""" scheduled: list[dict[str, Any]] = [] diff --git a/tests/api/v2/test_knowledge_router_telemetry.py b/tests/api/v2/test_knowledge_router_telemetry.py index 21090dc7f..05e5ee2af 100644 --- a/tests/api/v2/test_knowledge_router_telemetry.py +++ b/tests/api/v2/test_knowledge_router_telemetry.py @@ -6,6 +6,7 @@ from contextlib import contextmanager from datetime import datetime, timezone from types import SimpleNamespace +from typing import Any, cast import pytest from fastapi import BackgroundTasks, Response @@ -85,7 +86,7 @@ async def read_file_content(self, path): raise AssertionError("non-fast create should not re-read file content") result = await knowledge_router_module.create_entity( - project_id="project-123", + project_id=123, data=Entity( title="Telemetry Entity", directory="notes", @@ -94,11 +95,11 @@ async def read_file_content(self, path): content="telemetry content", ), background_tasks=BackgroundTasks(), - entity_service=FakeEntityService(), - search_service=FakeSearchService(), + entity_service=cast(Any, FakeEntityService()), + search_service=cast(Any, FakeSearchService()), task_scheduler=FakeTaskScheduler(), - file_service=FakeFileService(), - app_config=SimpleNamespace(semantic_search_enabled=False), + file_service=cast(Any, FakeFileService()), + app_config=cast(Any, SimpleNamespace(semantic_search_enabled=False)), fast=False, ) @@ -159,13 +160,13 @@ async def read_file_content(self, path): ), response=response, background_tasks=BackgroundTasks(), - project_id="project-123", - entity_service=FakeEntityService(), - search_service=FakeSearchService(), - entity_repository=FakeEntityRepository(), + project_id=123, + entity_service=cast(Any, FakeEntityService()), + search_service=cast(Any, FakeSearchService()), + entity_repository=cast(Any, FakeEntityRepository()), task_scheduler=FakeTaskScheduler(), - file_service=FakeFileService(), - app_config=SimpleNamespace(semantic_search_enabled=False), + file_service=cast(Any, FakeFileService()), + app_config=cast(Any, SimpleNamespace(semantic_search_enabled=False)), entity_id=entity.external_id, fast=False, ) @@ -220,13 +221,13 @@ async def read_file_content(self, path): result = await knowledge_router_module.edit_entity_by_id( data=EditEntityRequest(operation="append", content="edited telemetry content"), background_tasks=BackgroundTasks(), - project_id="project-123", - entity_service=FakeEntityService(), - search_service=FakeSearchService(), - entity_repository=FakeEntityRepository(), + project_id=123, + entity_service=cast(Any, FakeEntityService()), + search_service=cast(Any, FakeSearchService()), + entity_repository=cast(Any, FakeEntityRepository()), task_scheduler=FakeTaskScheduler(), - file_service=FakeFileService(), - app_config=SimpleNamespace(semantic_search_enabled=False), + file_service=cast(Any, FakeFileService()), + app_config=cast(Any, SimpleNamespace(semantic_search_enabled=False)), entity_id=entity.external_id, fast=False, ) diff --git a/tests/api/v2/test_memory_hydration.py b/tests/api/v2/test_memory_hydration.py index 2c6b2337c..5154c3942 100644 --- a/tests/api/v2/test_memory_hydration.py +++ b/tests/api/v2/test_memory_hydration.py @@ -8,10 +8,12 @@ from datetime import datetime, timezone from types import SimpleNamespace +from typing import Any import pytest from basic_memory.api.v2.utils import to_graph_context +from basic_memory.schemas.memory import EntitySummary, ObservationSummary, RelationSummary from basic_memory.schemas.search import SearchItemType from basic_memory.services.context_service import ( ContextMetadata, @@ -28,9 +30,9 @@ def _make_entity(id: int, title: str, external_id: str) -> SimpleNamespace: return SimpleNamespace(id=id, title=title, external_id=external_id) -def _make_row(*, type: str, id: int, root_id: int, **kwargs) -> ContextResultRow: +def _make_row(*, type: str, id: int, root_id: int, **kwargs: Any) -> ContextResultRow: now = kwargs.pop("created_at", datetime.now(timezone.utc)) - defaults = dict( + defaults: dict[str, Any] = dict( title=f"Item {id}", permalink=f"notes/{id}", file_path=f"notes/{id}.md", @@ -159,20 +161,31 @@ async def test_to_graph_context_batches_entity_hydration_for_recent_activity(): assert set(repo.calls[0]) == {1, 2, 3} first_result = graph.results[0] - assert first_result.primary_result.external_id == "ext-root" - assert first_result.observations[0].entity_external_id == "ext-root" - assert first_result.observations[0].title == "Root" + first_primary = first_result.primary_result + assert isinstance(first_primary, EntitySummary) + assert first_primary.external_id == "ext-root" + + first_observation = first_result.observations[0] + assert isinstance(first_observation, ObservationSummary) + assert first_observation.entity_external_id == "ext-root" + assert first_observation.title == "Root" relation = first_result.related_results[0] + assert isinstance(relation, RelationSummary) assert relation.from_entity == "Root" assert relation.from_entity_external_id == "ext-root" assert relation.to_entity == "Child" assert relation.to_entity_external_id == "ext-child" second_result = graph.results[1] - assert second_result.primary_result.entity_external_id == "ext-child" - assert second_result.primary_result.title == "Child" - assert second_result.related_results[0].external_id == "ext-peer" + second_primary = second_result.primary_result + assert isinstance(second_primary, ObservationSummary) + assert second_primary.entity_external_id == "ext-child" + assert second_primary.title == "Child" + + peer_result = second_result.related_results[0] + assert isinstance(peer_result, EntitySummary) + assert peer_result.external_id == "ext-peer" @pytest.mark.asyncio diff --git a/tests/api/v2/test_project_router.py b/tests/api/v2/test_project_router.py index dec6e3438..921a6babe 100644 --- a/tests/api/v2/test_project_router.py +++ b/tests/api/v2/test_project_router.py @@ -11,6 +11,11 @@ from basic_memory.schemas.v2 import ProjectResolveResponse +def _project_item(project: ProjectItem | None) -> ProjectItem: + assert project is not None + return project + + @pytest.mark.asyncio async def test_list_projects(client: AsyncClient, test_project: Project, v2_projects_url): """Test listing projects returns default_project from the database.""" @@ -67,10 +72,12 @@ async def test_update_project_path_by_id( assert response.status_code == 200 status_response = ProjectStatusResponse.model_validate(response.json()) assert status_response.status == "success" - assert status_response.new_project.external_id == test_project.external_id + new_project = _project_item(status_response.new_project) + old_project = _project_item(status_response.old_project) + assert new_project.external_id == test_project.external_id # Normalize paths for cross-platform comparison (Windows uses backslashes, API returns forward slashes) - assert Path(status_response.new_project.path) == Path(new_path) - assert status_response.old_project.external_id == test_project.external_id + assert Path(new_project.path) == Path(new_path) + assert old_project.external_id == test_project.external_id @pytest.mark.asyncio @@ -122,10 +129,12 @@ async def test_set_default_project_by_id( status_response = ProjectStatusResponse.model_validate(response.json()) assert status_response.status == "success" assert status_response.default is True - assert status_response.new_project.external_id == created_project.external_id - assert status_response.new_project.is_default is True - assert status_response.old_project.external_id == test_project.external_id - assert status_response.old_project.is_default is False + new_project = _project_item(status_response.new_project) + old_project = _project_item(status_response.old_project) + assert new_project.external_id == created_project.external_id + assert new_project.is_default is True + assert old_project.external_id == test_project.external_id + assert old_project.is_default is False @pytest.mark.asyncio @@ -155,7 +164,8 @@ async def test_delete_project_by_id( assert response.status_code == 200 status_response = ProjectStatusResponse.model_validate(response.json()) assert status_response.status == "success" - assert status_response.old_project.external_id == created_project.external_id + old_project = _project_item(status_response.old_project) + assert old_project.external_id == created_project.external_id assert status_response.new_project is None # Verify it's deleted - trying to get it should return 404 diff --git a/tests/api/v2/test_search_hydration.py b/tests/api/v2/test_search_hydration.py index 2a067e929..0d3858cf8 100644 --- a/tests/api/v2/test_search_hydration.py +++ b/tests/api/v2/test_search_hydration.py @@ -8,6 +8,7 @@ from datetime import datetime, timezone from types import SimpleNamespace +from typing import Any import pytest @@ -22,9 +23,9 @@ def _make_entity(id: int, permalink: str) -> SimpleNamespace: return SimpleNamespace(id=id, permalink=permalink) -def _make_row(*, type: str, id: int, **kwargs) -> SearchIndexRow: +def _make_row(*, type: str, id: int, **kwargs: Any) -> SearchIndexRow: now = datetime.now(timezone.utc) - defaults = dict( + defaults: dict[str, Any] = dict( project_id=1, file_path=f"notes/{id}.md", created_at=now, diff --git a/tests/api/v2/test_search_router_telemetry.py b/tests/api/v2/test_search_router_telemetry.py index bd8507306..eb53bd553 100644 --- a/tests/api/v2/test_search_router_telemetry.py +++ b/tests/api/v2/test_search_router_telemetry.py @@ -4,6 +4,7 @@ import importlib from contextlib import contextmanager +from typing import Any, cast import pytest @@ -14,6 +15,7 @@ @pytest.mark.asyncio async def test_search_router_wraps_request_in_manual_operation() -> None: + router = cast(Any, search_router_module) operations: list[tuple[str, dict]] = [] class FakeSearchService: @@ -28,22 +30,22 @@ def fake_operation(name: str, **attrs): async def fake_to_search_results(entity_service, results): return [] - original_operation = search_router_module.telemetry.operation - original_to_search_results = search_router_module.to_search_results - search_router_module.telemetry.operation = fake_operation - search_router_module.to_search_results = fake_to_search_results + original_operation = router.telemetry.operation + original_to_search_results = router.to_search_results + router.telemetry.operation = fake_operation + router.to_search_results = fake_to_search_results try: - response = await search_router_module.search( + response = await router.search( SearchQuery(text="hello world"), FakeSearchService(), object(), - project_id="project-123", + project_id=123, page=2, page_size=5, ) finally: - search_router_module.telemetry.operation = original_operation - search_router_module.to_search_results = original_to_search_results + router.telemetry.operation = original_operation + router.to_search_results = original_to_search_results assert response.current_page == 2 assert operations == [ diff --git a/tests/cli/cloud/test_cloud_api_client_and_utils.py b/tests/cli/cloud/test_cloud_api_client_and_utils.py index b17da47db..0325a0c8d 100644 --- a/tests/cli/cloud/test_cloud_api_client_and_utils.py +++ b/tests/cli/cloud/test_cloud_api_client_and_utils.py @@ -163,6 +163,7 @@ async def api_request(**kwargs): created = await create_cloud_project("My Project", api_request=api_request) assert created.new_project is not None assert created.new_project["name"] == "My Project" + assert seen["create_payload"] is not None # Path should be permalink-like (kebab) assert seen["create_payload"]["path"] == "my-project" assert seen["create_payload"]["visibility"] == "workspace" diff --git a/tests/cli/test_auto_update.py b/tests/cli/test_auto_update.py index 0e4044922..662da549e 100644 --- a/tests/cli/test_auto_update.py +++ b/tests/cli/test_auto_update.py @@ -5,6 +5,7 @@ import subprocess from datetime import datetime, timedelta, timezone from io import StringIO +from typing import Any, cast from rich.console import Console @@ -36,6 +37,10 @@ def save_config(self, config: BasicMemoryConfig) -> None: self.save_calls += 1 +def _config_manager(manager: StubConfigManager) -> Any: + return cast(Any, manager) + + def _capture_console() -> tuple[Console, StringIO]: """Create a Console that writes to an in-memory buffer.""" buf = StringIO() @@ -91,7 +96,7 @@ def test_interval_gate_skips_check_when_recent(tmp_path): config.update_check_interval = 3600 manager = StubConfigManager(config) - result = run_auto_update(config_manager=manager) + result = run_auto_update(config_manager=_config_manager(manager)) assert result.status == AutoUpdateStatus.SKIPPED assert result.checked is False @@ -103,7 +108,7 @@ def test_auto_update_disabled_skips_periodic(tmp_path): config.auto_update = False manager = StubConfigManager(config) - result = run_auto_update(config_manager=manager) + result = run_auto_update(config_manager=_config_manager(manager)) assert result.status == AutoUpdateStatus.SKIPPED assert result.checked is False @@ -121,7 +126,7 @@ def test_force_bypasses_auto_update_disabled(monkeypatch, tmp_path): result = run_auto_update( force=True, - config_manager=manager, + config_manager=_config_manager(manager), executable="/Users/me/.local/share/uv/tools/basic-memory/bin/python", ) @@ -171,7 +176,7 @@ def _fake_run_subprocess(command, **kwargs): monkeypatch.setattr("basic_memory.cli.auto_update._run_subprocess", _fake_run_subprocess) result = run_auto_update( - config_manager=manager, + config_manager=_config_manager(manager), executable="/opt/homebrew/Cellar/basic-memory/0.18.0/bin/python", ) @@ -196,7 +201,7 @@ def _fake_run_subprocess(command, **kwargs): monkeypatch.setattr("basic_memory.cli.auto_update._run_subprocess", _fake_run_subprocess) result = run_auto_update( - config_manager=manager, + config_manager=_config_manager(manager), executable="/Users/me/.local/share/uv/tools/basic-memory/bin/python", ) @@ -215,7 +220,7 @@ def test_unknown_manager_returns_manual_update_guidance(monkeypatch, tmp_path): result = run_auto_update( force=True, - config_manager=manager, + config_manager=_config_manager(manager), executable="/usr/local/bin/python3", ) @@ -229,7 +234,7 @@ def test_uvx_runtime_is_skipped(monkeypatch, tmp_path): manager = StubConfigManager(config) result = run_auto_update( - config_manager=manager, + config_manager=_config_manager(manager), executable="/Users/me/.cache/uv/archive-v0/abc123/bin/python", ) @@ -256,7 +261,7 @@ def _fake_run_subprocess(command, **kwargs): monkeypatch.setattr("basic_memory.cli.auto_update._run_subprocess", _fake_run_subprocess) result = run_auto_update( - config_manager=manager, + config_manager=_config_manager(manager), executable="/Users/me/.local/share/uv/tools/basic-memory/bin/python", silent=True, ) @@ -281,7 +286,7 @@ def _raise_oserror(command, **kwargs): monkeypatch.setattr("basic_memory.cli.auto_update._run_subprocess", _raise_oserror) result = run_auto_update( - config_manager=manager, + config_manager=_config_manager(manager), executable="/Users/me/.local/share/uv/tools/basic-memory/bin/python", ) @@ -300,7 +305,7 @@ def test_mixed_timezone_timestamp_does_not_crash_interval_gate(monkeypatch, tmp_ ) result = run_auto_update( - config_manager=manager, + config_manager=_config_manager(manager), executable="/Users/me/.local/share/uv/tools/basic-memory/bin/python", ) diff --git a/tests/cli/test_cli_telemetry.py b/tests/cli/test_cli_telemetry.py index a5b2fd2ea..c1330f9ed 100644 --- a/tests/cli/test_cli_telemetry.py +++ b/tests/cli/test_cli_telemetry.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Any, cast + from basic_memory.cli import app as cli_app @@ -38,7 +40,7 @@ def fake_operation(name: str, **attrs): monkeypatch.setattr(cli_app.telemetry, "operation", fake_operation) ctx = FakeContext(invoked_subcommand="status") - cli_app.app_callback(ctx, version=None) + cli_app.app_callback(cast(Any, ctx), version=None) assert ctx.resources == [resource] assert operations == [ diff --git a/tests/cli/test_cloud_authentication.py b/tests/cli/test_cloud_authentication.py index 0fe73205e..10c1526ca 100644 --- a/tests/cli/test_cloud_authentication.py +++ b/tests/cli/test_cloud_authentication.py @@ -3,6 +3,7 @@ from __future__ import annotations from contextlib import asynccontextmanager +from typing import Any, cast import httpx import pytest @@ -28,6 +29,10 @@ async def login(self) -> bool: return self._login_ok +def _auth(auth: _StubAuth) -> Any: + return cast(Any, auth) + + def _make_http_client_factory(handler): @asynccontextmanager async def _factory(): @@ -61,7 +66,7 @@ async def handler(request: httpx.Request) -> httpx.Response: await make_api_request( "GET", "https://test.com/api/endpoint", - auth=auth, + auth=_auth(auth), http_client_factory=_make_http_client_factory(handler), ) @@ -88,7 +93,7 @@ async def handler(request: httpx.Request) -> httpx.Response: await make_api_request( "GET", "https://test.com/api/endpoint", - auth=auth, + auth=_auth(auth), http_client_factory=_make_http_client_factory(handler), ) @@ -110,7 +115,7 @@ async def handler(request: httpx.Request) -> httpx.Response: await make_api_request( "GET", "https://test.com/api/endpoint", - auth=auth, + auth=_auth(auth), http_client_factory=_make_http_client_factory(handler), ) diff --git a/tests/cli/test_cloud_promo.py b/tests/cli/test_cloud_promo.py index 823703d71..2167838c4 100644 --- a/tests/cli/test_cloud_promo.py +++ b/tests/cli/test_cloud_promo.py @@ -263,7 +263,10 @@ def save_config(self, config): assert result.exit_code == 0 assert "Cloud promo messages disabled" in result.stdout assert len(instances) == 1 - assert instances[0].saved_config.cloud_promo_opt_out is True + manager = instances[0] + assert isinstance(manager, _StubConfigManager) + assert manager.saved_config is not None + assert manager.saved_config.cloud_promo_opt_out is True def test_cloud_promo_command_on_clears_opt_out(monkeypatch): @@ -294,7 +297,10 @@ def save_config(self, config): assert result.exit_code == 0 assert "Cloud promo messages enabled" in result.stdout assert len(instances) == 1 - assert instances[0].saved_config.cloud_promo_opt_out is False + manager = instances[0] + assert isinstance(manager, _StubConfigManager) + assert manager.saved_config is not None + assert manager.saved_config.cloud_promo_opt_out is False # --- _is_interactive_session tests --- diff --git a/tests/cli/test_json_output.py b/tests/cli/test_json_output.py index 7a33988ea..5468ed643 100644 --- a/tests/cli/test_json_output.py +++ b/tests/cli/test_json_output.py @@ -17,7 +17,7 @@ from basic_memory.cli.main import app as cli_app from basic_memory.mcp.clients.project import ProjectClient from basic_memory.schemas.project_info import ProjectList -from basic_memory.schemas.sync_report import SyncReportResponse +from basic_memory.schemas.sync_report import SkippedFileResponse, SyncReportResponse # Importing registers subcommands on the shared app instance. import basic_memory.cli.commands.project as project_cmd # noqa: F401 @@ -76,12 +76,12 @@ def _mock_config_manager(): moves={}, checksums={}, skipped_files=[ - { - "path": "bad/file.md", - "reason": "parse error", - "failure_count": 3, - "first_failed": datetime(2025, 6, 15, 12, 0, 0), - } + SkippedFileResponse( + path="bad/file.md", + reason="parse error", + failure_count=3, + first_failed=datetime(2025, 6, 15, 12, 0, 0), + ) ], total=0, ) diff --git a/tests/db/test_issue_254_foreign_key_constraints.py b/tests/db/test_issue_254_foreign_key_constraints.py index 58cf528da..347eaa77d 100644 --- a/tests/db/test_issue_254_foreign_key_constraints.py +++ b/tests/db/test_issue_254_foreign_key_constraints.py @@ -137,6 +137,7 @@ async def test_issue_254_reproduction(project_service: ProjectService): # Create project and entity await project_service.add_project(test_project_name, test_project_path) project = await project_service.get_project(test_project_name) + assert project is not None from basic_memory.repository.entity_repository import EntityRepository diff --git a/tests/markdown/test_date_frontmatter_parsing.py b/tests/markdown/test_date_frontmatter_parsing.py index 47273a31e..7bbd67220 100644 --- a/tests/markdown/test_date_frontmatter_parsing.py +++ b/tests/markdown/test_date_frontmatter_parsing.py @@ -358,4 +358,5 @@ async def test_parse_file_with_reserved_frontmatter_field_content(tmp_path): assert entity_markdown.frontmatter.metadata.get("content") == "Template for topic notes" assert entity_markdown.frontmatter.metadata.get("handler") == "some-handler-value" # The actual body content should be parsed correctly + assert entity_markdown.content is not None assert "Template Body" in entity_markdown.content diff --git a/tests/markdown/test_entity_parser_error_handling.py b/tests/markdown/test_entity_parser_error_handling.py index 5b3528f3f..fbe6b6b4f 100644 --- a/tests/markdown/test_entity_parser_error_handling.py +++ b/tests/markdown/test_entity_parser_error_handling.py @@ -75,6 +75,7 @@ async def test_parse_file_with_completely_invalid_yaml(tmp_path): assert result.frontmatter.title == "broken_yaml" # Default from filename assert result.frontmatter.type == "note" # Default type # Content should include the whole file since frontmatter parsing failed + assert result.content is not None assert "# Content" in result.content @@ -374,6 +375,7 @@ async def test_frontmatter_roundtrip_preserves_user_metadata(tmp_path): assert result.frontmatter.type == "litnote" # NOT overwritten to "note" assert "citekey" in result.frontmatter.metadata assert result.frontmatter.metadata["citekey"] == "authorTitleYear2024" + assert result.content is not None # Simulate what write_frontmatter does post = frontmatter.Post(result.content, **result.frontmatter.metadata) diff --git a/tests/markdown/test_markdown_processor.py b/tests/markdown/test_markdown_processor.py index 21be72543..ac4727a49 100644 --- a/tests/markdown/test_markdown_processor.py +++ b/tests/markdown/test_markdown_processor.py @@ -141,6 +141,7 @@ async def test_update_preserves_content(markdown_processor: MarkdownProcessor, t result = await markdown_processor.read_file(path) # Original content preserved + assert result.content is not None assert "Original content here." in result.content # Both observations present diff --git a/tests/markdown/test_observation_edge_cases.py b/tests/markdown/test_observation_edge_cases.py index b587de804..5e1bf3619 100644 --- a/tests/markdown/test_observation_edge_cases.py +++ b/tests/markdown/test_observation_edge_cases.py @@ -28,6 +28,7 @@ def test_invalid_context(): tokens = md.parse("- [test] Content (unclosed") token = next(t for t in tokens if t.type == "inline") obs = parse_observation(token) + assert obs is not None assert obs["content"] == "Content (unclosed" assert obs["context"] is None @@ -35,6 +36,7 @@ def test_invalid_context(): tokens = md.parse("- [test] Content (with) extra) parens)") token = next(t for t in tokens if t.type == "inline") obs = parse_observation(token) + assert obs is not None assert obs["content"] == "Content" assert obs["context"] == "with) extra) parens" @@ -48,6 +50,7 @@ def test_complex_format(): token = next(t for t in tokens if t.type == "inline") obs = parse_observation(token) + assert obs is not None assert obs["category"] == "complex test" assert set(obs["tags"]) == {"tag1", "tag2", "tag3"} assert obs["content"] == "This is #tag1#tag2 with #tag3 content" @@ -55,6 +58,7 @@ def test_complex_format(): # Pydantic model validation observation = Observation.model_validate(obs) assert observation.category == "complex test" + assert observation.tags is not None assert set(observation.tags) == {"tag1", "tag2", "tag3"} assert observation.content == "This is #tag1#tag2 with #tag3 content" @@ -99,6 +103,7 @@ def test_unicode_content(): tokens = md.parse("- [test] Emoji test 👍 #emoji #test (Testing emoji)") token = next(t for t in tokens if t.type == "inline") obs = parse_observation(token) + assert obs is not None assert "👍" in obs["content"] assert "emoji" in obs["tags"] @@ -106,6 +111,7 @@ def test_unicode_content(): tokens = md.parse("- [中文] Chinese text 测试 #language (Script test)") token = next(t for t in tokens if t.type == "inline") obs = parse_observation(token) + assert obs is not None assert obs["category"] == "中文" assert "测试" in obs["content"] @@ -113,6 +119,7 @@ def test_unicode_content(): tokens = md.parse("- [test] Mixed 中文 and 👍 #mixed") token = next(t for t in tokens if t.type == "inline") obs = parse_observation(token) + assert obs is not None assert "中文" in obs["content"] assert "👍" in obs["content"] diff --git a/tests/markdown/test_parser_edge_cases.py b/tests/markdown/test_parser_edge_cases.py index 00ae5d238..702b65e37 100644 --- a/tests/markdown/test_parser_edge_cases.py +++ b/tests/markdown/test_parser_edge_cases.py @@ -41,6 +41,7 @@ async def test_unicode_content(tmp_path): assert "测试" in entity.frontmatter.metadata["tags"] assert "chinese" not in entity.frontmatter.metadata["tags"] + assert entity.content is not None assert "🧪" in entity.content # Verify Unicode in observations @@ -191,6 +192,7 @@ async def test_null_bytes_stripped(tmp_path): content=content, ) + assert entity.content is not None assert "\x00" not in entity.content assert "Some content" in entity.content assert "with nulls" in entity.content diff --git a/tests/markdown/test_relation_edge_cases.py b/tests/markdown/test_relation_edge_cases.py index 0d7e1bb0c..2547e4412 100644 --- a/tests/markdown/test_relation_edge_cases.py +++ b/tests/markdown/test_relation_edge_cases.py @@ -58,18 +58,21 @@ def test_context_handling(): tokens = md.parse("- type [[Target]] (unclosed") token = next(t for t in tokens if t.type == "inline") rel = parse_relation(token) + assert rel is not None assert rel["context"] is None # Multiple parens tokens = md.parse("- type [[Target]] (with (nested) parens)") token = next(t for t in tokens if t.type == "inline") rel = parse_relation(token) + assert rel is not None assert rel["context"] == "with (nested) parens" # Empty context tokens = md.parse("- type [[Target]] ()") token = next(t for t in tokens if t.type == "inline") rel = parse_relation(token) + assert rel is not None assert rel["context"] is None @@ -103,18 +106,21 @@ def test_unicode_targets(): tokens = md.parse("- type [[测试]]") token = next(t for t in tokens if t.type == "inline") rel = parse_relation(token) + assert rel is not None assert rel["target"] == "测试" # Unicode in type tokens = md.parse("- 使用 [[Target]]") token = next(t for t in tokens if t.type == "inline") rel = parse_relation(token) + assert rel is not None assert rel["type"] == "使用" # Unicode in context tokens = md.parse("- type [[Target]] (测试)") token = next(t for t in tokens if t.type == "inline") rel = parse_relation(token) + assert rel is not None assert rel["context"] == "测试" # Model validation with Unicode diff --git a/tests/mcp/conftest.py b/tests/mcp/conftest.py index 64366b9c8..db77585f4 100644 --- a/tests/mcp/conftest.py +++ b/tests/mcp/conftest.py @@ -1,6 +1,6 @@ """Tests for the MCP server implementation using FastAPI TestClient.""" -from typing import AsyncGenerator +from typing import Any, AsyncGenerator, cast import pytest import pytest_asyncio @@ -16,7 +16,7 @@ @pytest.fixture(scope="function") def mcp() -> FastMCP: - return mcp_server # pyright: ignore [reportReturnType] + return cast(Any, mcp_server) @pytest.fixture(scope="function") diff --git a/tests/mcp/test_permalink_collision_file_overwrite.py b/tests/mcp/test_permalink_collision_file_overwrite.py index dbdae340c..a25ca9a59 100644 --- a/tests/mcp/test_permalink_collision_file_overwrite.py +++ b/tests/mcp/test_permalink_collision_file_overwrite.py @@ -172,6 +172,7 @@ async def test_notes_with_similar_titles_maintain_separate_files(app, test_proje ) permalink = None + assert isinstance(result, str) # Extract permalink from result for line in result.split("\n"): if line.startswith("permalink:"): @@ -180,6 +181,7 @@ async def test_notes_with_similar_titles_maintain_separate_files(app, test_proje break # Verify each note can be read back with its own content + assert permalink is not None content = await read_note(permalink, project=test_project.name) assert f"Unique content for {title}" in content, ( f"Note with title '{title}' should maintain its unique content" diff --git a/tests/mcp/test_project_context.py b/tests/mcp/test_project_context.py index 94a5c94f2..09af793e4 100644 --- a/tests/mcp/test_project_context.py +++ b/tests/mcp/test_project_context.py @@ -6,6 +6,8 @@ from __future__ import annotations +from typing import Any, AsyncIterator, cast + import pytest @@ -22,6 +24,10 @@ async def set_state(self, key: str, value: object, **kwargs) -> None: self._state[key] = value +def _ctx(context: _ContextState) -> Any: + return cast(Any, context) + + @pytest.mark.asyncio async def test_returns_none_when_no_default_and_no_project(config_manager, monkeypatch): from basic_memory.mcp.project_context import resolve_project_parameter @@ -186,7 +192,7 @@ async def fake_get_available_workspaces(context=None): fake_get_available_workspaces, ) - resolved = await resolve_workspace_parameter(context=context) + resolved = await resolve_workspace_parameter(context=_ctx(context)) assert resolved.tenant_id == only_workspace.tenant_id assert await context.get_state("active_workspace") == only_workspace.model_dump() @@ -220,7 +226,7 @@ async def fake_get_available_workspaces(context=None): ) with pytest.raises(ValueError, match="Multiple workspaces are available"): - await resolve_workspace_parameter(context=_ContextState()) + await resolve_workspace_parameter(context=_ctx(_ContextState())) @pytest.mark.asyncio @@ -307,7 +313,7 @@ async def fail_if_called(context=None): # pragma: no cover fail_if_called, ) - resolved = await resolve_workspace_parameter(context=context) + resolved = await resolve_workspace_parameter(context=_ctx(context)) assert resolved.tenant_id == cached_workspace.tenant_id @@ -340,7 +346,7 @@ async def fail_if_called(): # pragma: no cover fail_if_called, ) - resolved = await resolve_project_parameter(project=None, context=context) + resolved = await resolve_project_parameter(project=None, context=_ctx(context)) assert resolved == cached_project.name @@ -366,8 +372,8 @@ async def fake_default_lookup(): fake_default_lookup, ) - first = await resolve_project_parameter(project=None, context=context) - second = await resolve_project_parameter(project=None, context=context) + first = await resolve_project_parameter(project=None, context=_ctx(context)) + second = await resolve_project_parameter(project=None, context=_ctx(context)) assert first == "cloud-default" assert second == "cloud-default" @@ -397,7 +403,7 @@ async def fail_if_called(*args, **kwargs): # pragma: no cover fail_if_called, ) - resolved = await get_active_project(client=None, context=context) + resolved = await get_active_project(client=cast(Any, None), context=_ctx(context)) assert resolved == cached_project @@ -426,7 +432,9 @@ async def fail_if_called(*args, **kwargs): # pragma: no cover fail_if_called, ) - resolved = await get_active_project(client=None, project="my-research", context=context) + resolved = await get_active_project( + client=cast(Any, None), project="my-research", context=_ctx(context) + ) assert resolved == cached_project @@ -464,9 +472,9 @@ async def fake_resolve_project_parameter(project=None, **kwargs): ) active_project, resolved_path, is_memory_url = await resolve_project_and_path( - client=None, + client=cast(Any, None), identifier="memory://my-research/notes/roadmap.md", - context=context, + context=_ctx(context), ) assert active_project == cached_project @@ -709,7 +717,7 @@ async def test_factory_mode_skips_workspace_resolution(self, config_manager, mon # Set up a factory (simulates what cloud MCP server does) @asynccontextmanager - async def fake_factory(workspace=None): + async def fake_factory(workspace: Any = None) -> AsyncIterator[Any]: from httpx import ASGITransport, AsyncClient from basic_memory.api.app import app as fastapi_app @@ -733,10 +741,11 @@ async def fail_if_called(**kwargs): # pragma: no cover # Patch get_cloud_control_plane_client to fail if called @asynccontextmanager - async def fail_control_plane(): # pragma: no cover + async def fail_control_plane() -> AsyncIterator[Any]: # pragma: no cover raise AssertionError( "get_cloud_control_plane_client must not be called in factory mode" ) + yield monkeypatch.setattr( "basic_memory.mcp.async_client.get_cloud_control_plane_client", diff --git a/tests/mcp/test_project_context_telemetry.py b/tests/mcp/test_project_context_telemetry.py index c0f90ef4e..1863bc765 100644 --- a/tests/mcp/test_project_context_telemetry.py +++ b/tests/mcp/test_project_context_telemetry.py @@ -4,6 +4,7 @@ import importlib from contextlib import contextmanager +from typing import Any, cast import pytest @@ -60,7 +61,7 @@ async def fake_get_available_workspaces(context=None): monkeypatch.setattr(project_context.telemetry, "span", fake_span) monkeypatch.setattr(project_context, "get_available_workspaces", fake_get_available_workspaces) - resolved = await project_context.resolve_workspace_parameter(context=context) + resolved = await project_context.resolve_workspace_parameter(context=cast(Any, context)) assert resolved.tenant_id == workspace.tenant_id assert spans == [ diff --git a/tests/mcp/test_recent_activity_prompt_modes.py b/tests/mcp/test_recent_activity_prompt_modes.py index 4a632eeda..2c29333f4 100644 --- a/tests/mcp/test_recent_activity_prompt_modes.py +++ b/tests/mcp/test_recent_activity_prompt_modes.py @@ -5,6 +5,7 @@ """ import pytest +from typing import Any, cast from basic_memory.mcp.prompts.recent_activity import recent_activity_prompt @@ -102,7 +103,7 @@ async def fake_fn(**kwargs): monkeypatch.setattr("basic_memory.mcp.prompts.recent_activity.recent_activity", fake_fn) - await recent_activity_prompt(timeframe=None, project=None) # pyright: ignore[reportGeneralTypeIssues] + await recent_activity_prompt(timeframe=cast(Any, None), project=None) assert captured_kwargs["timeframe"] == "7d" assert captured_kwargs["project"] is None diff --git a/tests/mcp/test_tool_build_context.py b/tests/mcp/test_tool_build_context.py index 64d0fbece..a0bccdc0b 100644 --- a/tests/mcp/test_tool_build_context.py +++ b/tests/mcp/test_tool_build_context.py @@ -92,6 +92,9 @@ async def test_get_discussion_context_timeframe(client, test_graph, test_project timeframe="30d", ) + assert isinstance(recent, dict) + assert isinstance(older, dict) + # Calculate total related items total_recent_related = ( sum(len(item["related_results"]) for item in recent["results"]) if recent["results"] else 0 @@ -161,6 +164,7 @@ async def test_build_context_string_depth_parameter(client, test_graph, test_pro # Test valid string depth parameter — should convert to int try: result = await build_context(url=test_url, depth="2", project=test_project.name) + assert isinstance(result, dict) assert isinstance(result["metadata"]["depth"], int) assert result["metadata"]["depth"] == 2 except ToolError: diff --git a/tests/mcp/test_tool_contracts.py b/tests/mcp/test_tool_contracts.py index 996c66cbd..0da74ae08 100644 --- a/tests/mcp/test_tool_contracts.py +++ b/tests/mcp/test_tool_contracts.py @@ -3,6 +3,8 @@ from __future__ import annotations import inspect +from collections.abc import Callable +from typing import Any, cast from basic_memory.mcp import tools @@ -134,7 +136,7 @@ def _signature_params(tool_obj: object) -> list[str]: params = [] - for param in inspect.signature(tool_obj).parameters.values(): + for param in inspect.signature(cast(Callable[..., Any], tool_obj)).parameters.values(): if param.name == "context": continue params.append(param.name) diff --git a/tests/mcp/test_tool_project_management.py b/tests/mcp/test_tool_project_management.py index dd98e22a3..e5fea5524 100644 --- a/tests/mcp/test_tool_project_management.py +++ b/tests/mcp/test_tool_project_management.py @@ -117,6 +117,7 @@ async def test_create_and_delete_project_and_name_match_branch( project_path=str(project_root), set_default=False, ) + assert isinstance(result, str) assert result.startswith("✓") assert "My Project" in result @@ -495,6 +496,7 @@ async def test_list_memory_projects_json_includes_workspace_info(app, test_proje ): result = await list_memory_projects(output_format="json", workspace="org-tenant-abc") + assert isinstance(result, dict) by_name = {p["name"]: p for p in result["projects"]} # Cloud project carries workspace info diff --git a/tests/mcp/test_tool_read_note.py b/tests/mcp/test_tool_read_note.py index 5fa4dbfb2..8775bdb35 100644 --- a/tests/mcp/test_tool_read_note.py +++ b/tests/mcp/test_tool_read_note.py @@ -42,7 +42,7 @@ async def test_read_note_title_search_fallback_fetches_by_permalink(monkeypatch, direct_identifier = memory_url_path("Fallback Title Note") class SelectiveKnowledgeClient(OriginalKnowledgeClient): - async def resolve_entity(self, identifier: str, *, strict: bool = False) -> int: + async def resolve_entity(self, identifier: str, *, strict: bool = False) -> str: # Fail on the direct identifier to force fallback to title search if identifier == direct_identifier: raise RuntimeError("force direct lookup failure") @@ -94,7 +94,7 @@ async def fake_search_notes_fn(*, query, search_type, **kwargs): # Ensure direct resolution doesn't short-circuit the fallback logic. class FailingKnowledgeClient(OriginalKnowledgeClient): - async def resolve_entity(self, identifier: str, *, strict: bool = False) -> int: + async def resolve_entity(self, identifier: str, *, strict: bool = False) -> str: raise RuntimeError("force fallback") monkeypatch.setattr(clients_mod, "KnowledgeClient", FailingKnowledgeClient) @@ -123,7 +123,7 @@ async def test_read_note_title_fallback_requires_exact_title_match(monkeypatch, OriginalKnowledgeClient = clients_mod.KnowledgeClient class StrictFailingKnowledgeClient(OriginalKnowledgeClient): - async def resolve_entity(self, identifier: str, *, strict: bool = False) -> int: + async def resolve_entity(self, identifier: str, *, strict: bool = False) -> str: if strict: raise RuntimeError("force strict direct lookup failure") return await super().resolve_entity(identifier, strict=strict) @@ -286,7 +286,7 @@ async def test_read_note_memory_url_fallback_uses_search_tool_normalization( search_calls: list[tuple[str, str, str | None]] = [] class SelectiveKnowledgeClient(OriginalKnowledgeClient): - async def resolve_entity(self, identifier: str, *, strict: bool = False) -> int: + async def resolve_entity(self, identifier: str, *, strict: bool = False) -> str: if strict and identifier.endswith("test/memory-url-fallback-note"): raise RuntimeError("force direct lookup failure") return await super().resolve_entity(identifier, strict=strict) diff --git a/tests/mcp/test_tool_recent_activity.py b/tests/mcp/test_tool_recent_activity.py index 8ef7b7862..630de8497 100644 --- a/tests/mcp/test_tool_recent_activity.py +++ b/tests/mcp/test_tool_recent_activity.py @@ -1,6 +1,7 @@ """Tests for discussion context MCP tool.""" from datetime import datetime, timedelta, timezone +from typing import Any, cast import pytest @@ -235,7 +236,7 @@ class P: path = "/tmp/p" proj_activity = await recent_activity_module._get_project_activity( - client=None, project_info=P(), params={}, depth=1 + client=None, project_info=cast(Any, P()), params={}, depth=1 ) assert proj_activity.item_count == 2 assert "folder" in proj_activity.active_folders diff --git a/tests/mcp/test_tool_utils.py b/tests/mcp/test_tool_utils.py index 2d7820d07..69526c11e 100644 --- a/tests/mcp/test_tool_utils.py +++ b/tests/mcp/test_tool_utils.py @@ -1,14 +1,16 @@ """Tests for MCP tool utilities.""" +from typing import Any, cast + import pytest -from httpx import HTTPStatusError +from httpx import HTTPStatusError, Request from mcp.server.fastmcp.exceptions import ToolError from basic_memory.mcp.tools.utils import ( + call_delete, call_get, call_post, call_put, - call_delete, get_error_message, ) @@ -26,7 +28,9 @@ def __init__(self, status_code=200): def raise_for_status(self): if self.status_code >= 400: raise HTTPStatusError( - message=f"HTTP Error {self.status_code}", request=None, response=self + message=f"HTTP Error {self.status_code}", + request=Request("GET", "http://test.com"), + response=cast(Any, self), ) return MockResponse @@ -57,13 +61,17 @@ async def delete(self, *args, **kwargs): return self._responses["delete"] +def _client(client: _Client) -> Any: + return cast(Any, client) + + @pytest.mark.asyncio async def test_call_get_success(mock_response): """Test successful GET request.""" client = _Client() client.set_response("get", mock_response()) - response = await call_get(client, "http://test.com") + response = await call_get(_client(client), "http://test.com") assert response.status_code == 200 @@ -74,7 +82,7 @@ async def test_call_get_error(mock_response): client.set_response("get", mock_response(404)) with pytest.raises(ToolError) as exc: - await call_get(client, "http://test.com") + await call_get(_client(client), "http://test.com") assert "Resource not found" in str(exc.value) @@ -86,7 +94,7 @@ async def test_call_post_success(mock_response): response.json = lambda: {"test": "data"} client.set_response("post", response) - response = await call_post(client, "http://test.com", json={"test": "data"}) + response = await call_post(_client(client), "http://test.com", json={"test": "data"}) assert response.status_code == 200 @@ -100,7 +108,7 @@ async def test_call_post_error(mock_response): client.set_response("post", response) with pytest.raises(ToolError) as exc: - await call_post(client, "http://test.com", json={"test": "data"}) + await call_post(_client(client), "http://test.com", json={"test": "data"}) assert "Internal server error" in str(exc.value) @@ -110,7 +118,7 @@ async def test_call_put_success(mock_response): client = _Client() client.set_response("put", mock_response()) - response = await call_put(client, "http://test.com", json={"test": "data"}) + response = await call_put(_client(client), "http://test.com", json={"test": "data"}) assert response.status_code == 200 @@ -121,7 +129,7 @@ async def test_call_put_error(mock_response): client.set_response("put", mock_response(400)) with pytest.raises(ToolError) as exc: - await call_put(client, "http://test.com", json={"test": "data"}) + await call_put(_client(client), "http://test.com", json={"test": "data"}) assert "Invalid request" in str(exc.value) @@ -131,7 +139,7 @@ async def test_call_delete_success(mock_response): client = _Client() client.set_response("delete", mock_response()) - response = await call_delete(client, "http://test.com") + response = await call_delete(_client(client), "http://test.com") assert response.status_code == 200 @@ -142,7 +150,7 @@ async def test_call_delete_error(mock_response): client.set_response("delete", mock_response(403)) with pytest.raises(ToolError) as exc: - await call_delete(client, "http://test.com") + await call_delete(_client(client), "http://test.com") assert "Access denied" in str(exc.value) @@ -153,7 +161,7 @@ async def test_call_get_with_params(mock_response): client.set_response("get", mock_response()) params = {"key": "value", "test": "data"} - await call_get(client, "http://test.com", params=params) + await call_get(_client(client), "http://test.com", params=params) assert len(client.calls) == 1 method, _args, kwargs = client.calls[0] @@ -199,7 +207,7 @@ async def test_call_post_with_json(mock_response): client.set_response("post", response) json_data = {"key": "value", "nested": {"test": "data"}} - await call_post(client, "http://test.com", json=json_data) + await call_post(_client(client), "http://test.com", json=json_data) assert len(client.calls) == 1 method, _args, kwargs = client.calls[0] diff --git a/tests/mcp/test_tool_utils_cloud_auth.py b/tests/mcp/test_tool_utils_cloud_auth.py index a121ecc03..b08532f6a 100644 --- a/tests/mcp/test_tool_utils_cloud_auth.py +++ b/tests/mcp/test_tool_utils_cloud_auth.py @@ -1,7 +1,9 @@ """Cloud auth error translation tests for MCP tool HTTP helpers.""" +from typing import Any, cast + import pytest -from httpx import HTTPStatusError +from httpx import HTTPStatusError, Request from mcp.server.fastmcp.exceptions import ToolError from basic_memory.mcp.tools.utils import call_post @@ -20,8 +22,8 @@ def raise_for_status(self): if self.status_code >= 400: raise HTTPStatusError( message=f"HTTP Error {self.status_code}", - request=None, - response=self, + request=Request("POST", "http://test/v2/projects/"), + response=cast(Any, self), ) @@ -48,7 +50,7 @@ async def test_call_post_401_with_cloud_key_shows_actionable_remediation(config_ ) with pytest.raises(ToolError) as exc: - await call_post(client, "/v2/projects/", json={"name": "test"}) + await call_post(cast(Any, client), "/v2/projects/", json={"name": "test"}) message = str(exc.value) assert "configured cloud API key was rejected" in message diff --git a/tests/mcp/test_tool_workspace_management.py b/tests/mcp/test_tool_workspace_management.py index 0cf3f8c8c..0df206355 100644 --- a/tests/mcp/test_tool_workspace_management.py +++ b/tests/mcp/test_tool_workspace_management.py @@ -1,5 +1,7 @@ """Tests for workspace MCP tools.""" +from typing import Any, cast + import pytest from basic_memory.mcp.project_context import get_available_workspaces, set_workspace_provider @@ -100,8 +102,8 @@ async def fake_get_available_workspaces(context=None): fake_get_available_workspaces, ) - first = await list_workspaces(context=context) - second = await list_workspaces(context=context) + first = await list_workspaces(context=cast(Any, context)) + second = await list_workspaces(context=cast(Any, context)) assert "# Available Workspaces (1)" in first assert "# Available Workspaces (1)" in second @@ -184,11 +186,11 @@ async def counting_provider() -> list[WorkspaceInfo]: context = _ContextState() # First call: provider is invoked, result cached - first = await get_available_workspaces(context=context) + first = await get_available_workspaces(context=cast(Any, context)) assert len(first) == 1 assert call_count["provider"] == 1 # Second call: served from context cache, provider not called again - second = await get_available_workspaces(context=context) + second = await get_available_workspaces(context=cast(Any, context)) assert len(second) == 1 assert call_count["provider"] == 1 diff --git a/tests/mcp/test_tool_write_note.py b/tests/mcp/test_tool_write_note.py index a4d29b8d8..f73198f0f 100644 --- a/tests/mcp/test_tool_write_note.py +++ b/tests/mcp/test_tool_write_note.py @@ -1,6 +1,7 @@ """Tests for note tools that exercise the full stack with SQLite.""" from textwrap import dedent +from typing import Any import pytest @@ -204,6 +205,7 @@ async def test_issue_93_write_note_respects_custom_permalink_existing_note(app, assert "# Created note" in result1 assert f"project: {test_project.name}" in result1 + assert isinstance(result1, str) # Extract the auto-generated permalink initial_permalink = None @@ -287,7 +289,7 @@ async def test_write_note_with_tag_array_from_bug_report(app, test_project): was passing an array of tags and getting a type mismatch error. """ # This is the exact payload from the bug report - bug_payload = { + bug_payload: dict[str, Any] = { "project": test_project.name, "title": "Title", "directory": "folder", diff --git a/tests/mcp/test_ui_sdk.py b/tests/mcp/test_ui_sdk.py index 4b7bb9a35..975d8221f 100644 --- a/tests/mcp/test_ui_sdk.py +++ b/tests/mcp/test_ui_sdk.py @@ -5,9 +5,11 @@ - basic_memory.mcp.tools.ui_sdk (_text_block, search_notes_ui, read_note_ui) """ +from typing import cast from unittest.mock import MagicMock import pytest +from mcp.types import TextContent from basic_memory.mcp.ui.sdk import ( MissingMCPUIServerError, @@ -133,5 +135,6 @@ class TestTextBlock: def test_returns_single_text_content(self): blocks = _text_block("hello world") assert len(blocks) == 1 - assert blocks[0].type == "text" - assert blocks[0].text == "hello world" + block = cast(TextContent, blocks[0]) + assert block.type == "text" + assert block.text == "hello world" diff --git a/tests/repository/test_entity_repository.py b/tests/repository/test_entity_repository.py index f2c9fcf0e..8a47d76de 100644 --- a/tests/repository/test_entity_repository.py +++ b/tests/repository/test_entity_repository.py @@ -935,7 +935,8 @@ async def test_get_all_file_paths_project_isolation( async def test_permalink_exists(entity_repository: EntityRepository, sample_entity: Entity): """Test checking if a permalink exists without loading full entity.""" # Existing permalink should return True - assert await entity_repository.permalink_exists(sample_entity.permalink) is True # pyright: ignore [reportArgumentType] + assert sample_entity.permalink is not None + assert await entity_repository.permalink_exists(sample_entity.permalink) is True # Non-existent permalink should return False assert await entity_repository.permalink_exists("nonexistent/permalink") is False @@ -990,7 +991,8 @@ async def test_get_file_path_for_permalink( ): """Test getting file_path for a permalink without loading full entity.""" # Existing permalink should return file_path - file_path = await entity_repository.get_file_path_for_permalink(sample_entity.permalink) # pyright: ignore [reportArgumentType] + assert sample_entity.permalink is not None + file_path = await entity_repository.get_file_path_for_permalink(sample_entity.permalink) assert file_path == sample_entity.file_path # Non-existent permalink should return None diff --git a/tests/repository/test_fastembed_provider.py b/tests/repository/test_fastembed_provider.py index 4974317b8..bc0aed7ca 100644 --- a/tests/repository/test_fastembed_provider.py +++ b/tests/repository/test_fastembed_provider.py @@ -46,7 +46,7 @@ def embed(self, texts: list[str], batch_size: int = 64, **kwargs): async def test_fastembed_provider_lazy_loads_and_reuses_model(monkeypatch): """Provider should instantiate FastEmbed lazily and reuse the loaded model.""" module = type(sys)("fastembed") - module.TextEmbedding = _StubTextEmbedding + setattr(module, "TextEmbedding", _StubTextEmbedding) monkeypatch.setitem(sys.modules, "fastembed", module) _StubTextEmbedding.init_count = 0 @@ -67,7 +67,7 @@ async def test_fastembed_provider_lazy_loads_and_reuses_model(monkeypatch): async def test_fastembed_provider_dimension_mismatch_raises_error(monkeypatch): """Provider should fail fast when model output dimensions differ from configured dimensions.""" module = type(sys)("fastembed") - module.TextEmbedding = _StubTextEmbedding + setattr(module, "TextEmbedding", _StubTextEmbedding) monkeypatch.setitem(sys.modules, "fastembed", module) provider = FastEmbedEmbeddingProvider(model_name="stub-model", dimensions=4) @@ -99,7 +99,7 @@ def _raising_import(name, globals=None, locals=None, fromlist=(), level=0): async def test_fastembed_provider_passes_runtime_knobs_to_fastembed(monkeypatch): """Provider should pass optional runtime tuning knobs through to FastEmbed.""" module = type(sys)("fastembed") - module.TextEmbedding = _StubTextEmbedding + setattr(module, "TextEmbedding", _StubTextEmbedding) monkeypatch.setitem(sys.modules, "fastembed", module) _StubTextEmbedding.last_init_kwargs = {} _StubTextEmbedding.last_embed_kwargs = {} @@ -126,7 +126,7 @@ async def test_fastembed_provider_passes_runtime_knobs_to_fastembed(monkeypatch) async def test_fastembed_provider_parallel_one_disables_multiprocessing(monkeypatch): """parallel=1 should not pass FastEmbed multiprocessing kwargs.""" module = type(sys)("fastembed") - module.TextEmbedding = _StubTextEmbedding + setattr(module, "TextEmbedding", _StubTextEmbedding) monkeypatch.setitem(sys.modules, "fastembed", module) _StubTextEmbedding.last_embed_kwargs = {} @@ -140,7 +140,7 @@ async def test_fastembed_provider_parallel_one_disables_multiprocessing(monkeypa async def test_fastembed_provider_parallel_two_passes_multiprocessing(monkeypatch): """parallel>1 should keep passing FastEmbed multiprocessing kwargs.""" module = type(sys)("fastembed") - module.TextEmbedding = _StubTextEmbedding + setattr(module, "TextEmbedding", _StubTextEmbedding) monkeypatch.setitem(sys.modules, "fastembed", module) _StubTextEmbedding.last_embed_kwargs = {} diff --git a/tests/repository/test_hybrid_fusion.py b/tests/repository/test_hybrid_fusion.py index 10f4cc82c..c0226f0dd 100644 --- a/tests/repository/test_hybrid_fusion.py +++ b/tests/repository/test_hybrid_fusion.py @@ -7,11 +7,16 @@ """ from dataclasses import dataclass +from datetime import datetime +from typing import Any, Optional, cast from unittest.mock import AsyncMock, patch import pytest +from basic_memory.repository.embedding_provider import EmbeddingProvider +from basic_memory.repository.search_index_row import SearchIndexRow from basic_memory.repository.search_repository_base import FUSION_BONUS, SearchRepositoryBase +from basic_memory.schemas.search import SearchItemType, SearchRetrievalMode @dataclass @@ -30,7 +35,6 @@ class FakeRow: to_id: int | None = None relation_type: str | None = None entity_id: int | None = None - content_snippet: str | None = None category: str | None = None created_at: str | None = None updated_at: str | None = None @@ -46,7 +50,7 @@ def __init__(self): self._semantic_vector_k = 100 self._semantic_min_similarity = 0.0 # _search_hybrid calls _assert_semantic_available which checks this - self._embedding_provider = type("EP", (), {"dimensions": 384})() + self._embedding_provider = _fake_embedding_provider() self._vector_dimensions = 384 self._vector_tables_initialized = True self.session_maker = None @@ -58,7 +62,21 @@ async def init_search_index(self): def _prepare_search_term(self, term, is_prefix=True): return term # pragma: no cover - async def search(self, **kwargs): + async def search( + self, + search_text: Optional[str] = None, + permalink: Optional[str] = None, + permalink_match: Optional[str] = None, + title: Optional[str] = None, + note_types: Optional[list[str]] = None, + after_date: Optional[datetime] = None, + search_item_types: Optional[list[SearchItemType]] = None, + metadata_filters: Optional[dict[str, Any]] = None, + retrieval_mode: SearchRetrievalMode = SearchRetrievalMode.FTS, + min_similarity: Optional[float] = None, + limit: int = 10, + offset: int = 0, + ) -> list[SearchIndexRow]: return [] # pragma: no cover async def _ensure_vector_tables(self): @@ -83,7 +101,24 @@ def _distance_to_similarity(self, distance: float) -> float: return 1.0 / (1.0 + max(distance, 0.0)) # pragma: no cover -HYBRID_KWARGS = dict( +def _fake_embedding_provider() -> EmbeddingProvider: + return cast( + EmbeddingProvider, + type( + "EP", + (), + { + "model_name": "fake", + "dimensions": 384, + "embed_query": AsyncMock(return_value=[0.0] * 384), + "embed_documents": AsyncMock(return_value=[]), + "runtime_log_attrs": lambda self: {}, + }, + )(), + ) + + +HYBRID_KWARGS: dict[str, Any] = dict( search_text="test", permalink=None, permalink_match=None, diff --git a/tests/repository/test_observation_repository.py b/tests/repository/test_observation_repository.py index 27b863995..25641c8ab 100644 --- a/tests/repository/test_observation_repository.py +++ b/tests/repository/test_observation_repository.py @@ -4,8 +4,8 @@ import pytest import pytest_asyncio -import sqlalchemy from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.exc import IntegrityError from basic_memory import db from basic_memory.models import Entity, Observation, Project @@ -59,7 +59,7 @@ async def test_create_observation_entity_does_not_exist( "content": "Test content", "context": "test-context", } - with pytest.raises(sqlalchemy.exc.IntegrityError): + with pytest.raises(IntegrityError): await observation_repository.create(observation_data) diff --git a/tests/repository/test_openai_provider.py b/tests/repository/test_openai_provider.py index e76882a60..b55d0c41c 100644 --- a/tests/repository/test_openai_provider.py +++ b/tests/repository/test_openai_provider.py @@ -81,7 +81,7 @@ def _reset_embedding_provider_cache_fixture(): async def test_openai_provider_lazy_loads_and_reuses_client(monkeypatch): """Provider should instantiate AsyncOpenAI lazily and reuse a single client.""" module = type(sys)("openai") - module.AsyncOpenAI = _StubAsyncOpenAI + setattr(module, "AsyncOpenAI", _StubAsyncOpenAI) monkeypatch.setitem(sys.modules, "openai", module) monkeypatch.setenv("OPENAI_API_KEY", "test-key") _StubAsyncOpenAI.init_count = 0 @@ -105,7 +105,7 @@ async def test_openai_provider_lazy_loads_and_reuses_client(monkeypatch): async def test_openai_provider_dimension_mismatch_raises_error(monkeypatch): """Provider should fail fast when response dimensions differ from configured dimensions.""" module = type(sys)("openai") - module.AsyncOpenAI = _StubAsyncOpenAI + setattr(module, "AsyncOpenAI", _StubAsyncOpenAI) monkeypatch.setitem(sys.modules, "openai", module) monkeypatch.setenv("OPENAI_API_KEY", "test-key") @@ -139,7 +139,7 @@ def _raising_import(name, globals=None, locals=None, fromlist=(), level=0): async def test_openai_provider_missing_api_key_raises_error(monkeypatch): """OPENAI_API_KEY is required unless api_key is passed explicitly.""" module = type(sys)("openai") - module.AsyncOpenAI = _StubAsyncOpenAI + setattr(module, "AsyncOpenAI", _StubAsyncOpenAI) monkeypatch.setitem(sys.modules, "openai", module) monkeypatch.delenv("OPENAI_API_KEY", raising=False) @@ -421,7 +421,7 @@ def __init__(self, *, api_key: str, base_url=None, timeout=30.0): self.embeddings = shared_api module = type(sys)("openai") - module.AsyncOpenAI = _ConcurrentAsyncOpenAI + setattr(module, "AsyncOpenAI", _ConcurrentAsyncOpenAI) monkeypatch.setitem(sys.modules, "openai", module) monkeypatch.setenv("OPENAI_API_KEY", "test-key") @@ -452,7 +452,7 @@ def __init__(self, *, api_key: str, base_url=None, timeout=30.0): self.embeddings = _MalformedEmbeddingsApi() module = type(sys)("openai") - module.AsyncOpenAI = _MalformedAsyncOpenAI + setattr(module, "AsyncOpenAI", _MalformedAsyncOpenAI) monkeypatch.setitem(sys.modules, "openai", module) monkeypatch.setenv("OPENAI_API_KEY", "test-key") diff --git a/tests/repository/test_postgres_search_repository.py b/tests/repository/test_postgres_search_repository.py index 4ac72859c..20f455b10 100644 --- a/tests/repository/test_postgres_search_repository.py +++ b/tests/repository/test_postgres_search_repository.py @@ -36,6 +36,9 @@ async def embed_query(self, text: str) -> list[float]: async def embed_documents(self, texts: list[str]) -> list[list[float]]: return [self._vectorize(text) for text in texts] + def runtime_log_attrs(self) -> dict[str, object]: + return {} + @staticmethod def _vectorize(text: str) -> list[float]: normalized = text.lower() @@ -715,6 +718,9 @@ async def embed_query(self, text: str) -> list[float]: async def embed_documents(self, texts: list[str]) -> list[list[float]]: return [[0.0] * 8 for _ in texts] + def runtime_log_attrs(self) -> dict[str, object]: + return {} + @pytest.mark.asyncio async def test_postgres_dimension_mismatch_triggers_table_recreation(session_maker, test_project): diff --git a/tests/repository/test_relation_repository.py b/tests/repository/test_relation_repository.py index 1bfc07cdd..06f054ebf 100644 --- a/tests/repository/test_relation_repository.py +++ b/tests/repository/test_relation_repository.py @@ -4,10 +4,10 @@ import pytest import pytest_asyncio -import sqlalchemy +from sqlalchemy.exc import IntegrityError from basic_memory import db -from basic_memory.models import Entity, Relation, Project +from basic_memory.models import Entity, Project, Relation from basic_memory.repository.relation_repository import RelationRepository @@ -168,7 +168,7 @@ async def test_create_relation_entity_does_not_exist( "relation_type": "test_relation", "context": "test-context", } - with pytest.raises(sqlalchemy.exc.IntegrityError): + with pytest.raises(IntegrityError): await relation_repository.create(relation_data) @@ -194,6 +194,7 @@ async def test_find_relation(relation_repository: RelationRepository, sample_rel to_permalink=sample_relation.to_entity.permalink, relation_type=sample_relation.relation_type, ) + assert relation is not None assert relation.id == sample_relation.id diff --git a/tests/repository/test_search_repository_edit_bug_fix.py b/tests/repository/test_search_repository_edit_bug_fix.py index 864a8f3c2..98eebdffd 100644 --- a/tests/repository/test_search_repository_edit_bug_fix.py +++ b/tests/repository/test_search_repository_edit_bug_fix.py @@ -160,6 +160,7 @@ async def test_index_item_respects_project_isolation_during_edit(): results1_after = await repo1.search(search_text="project 1 content EDITED") assert len(results1_after) == 1 assert results1_after[0].title == "Test Note in Project 1" + assert results1_after[0].content_snippet is not None assert "EDITED" in results1_after[0].content_snippet # CRITICAL TEST: Verify project 2's note is still there (the bug would delete it) @@ -167,6 +168,7 @@ async def test_index_item_respects_project_isolation_during_edit(): assert len(results2_after) == 1, "Project 2's note disappeared after editing project 1's note!" assert results2_after[0].title == "Test Note in Project 2" assert results2_after[0].project_id == project2_id + assert results2_after[0].content_snippet is not None assert "original" in results2_after[0].content_snippet # Should still be original # Double-check: project 1 should not be able to see project 2's note diff --git a/tests/repository/test_semantic_search_base.py b/tests/repository/test_semantic_search_base.py index f42a32f06..1e8a58db1 100644 --- a/tests/repository/test_semantic_search_base.py +++ b/tests/repository/test_semantic_search_base.py @@ -6,13 +6,16 @@ import asyncio from contextlib import asynccontextmanager +from datetime import datetime from types import SimpleNamespace +from typing import Any from unittest.mock import AsyncMock import pytest import basic_memory.repository.search_repository_base as search_repository_base_module from basic_memory.repository.fastembed_provider import FastEmbedEmbeddingProvider +from basic_memory.repository.search_index_row import SearchIndexRow from basic_memory.repository.search_repository_base import ( MAX_VECTOR_CHUNK_CHARS, SearchRepositoryBase, @@ -71,7 +74,21 @@ async def init_search_index(self): def _prepare_search_term(self, term, is_prefix=True): return term - async def search(self, **kwargs): + async def search( + self, + search_text: str | None = None, + permalink: str | None = None, + permalink_match: str | None = None, + title: str | None = None, + note_types: list[str] | None = None, + after_date: datetime | None = None, + search_item_types: list[SearchItemType] | None = None, + metadata_filters: dict[str, Any] | None = None, + retrieval_mode: SearchRetrievalMode = SearchRetrievalMode.FTS, + min_similarity: float | None = None, + limit: int = 10, + offset: int = 0, + ) -> list[SearchIndexRow]: return [] async def _ensure_vector_tables(self): @@ -637,9 +654,13 @@ async def _yielding_write_scope(): ) prepared = await repo._prepare_entity_vector_jobs_window([1, 2]) + prepared_results = [ + result for result in prepared if isinstance(result, _PreparedEntityVectorSync) + ] - assert [result.sync_start for result in prepared] == [10.0, 11.0] - assert [result.prepare_seconds for result in prepared] == [2.0, 2.0] + assert len(prepared_results) == 2 + assert [result.sync_start for result in prepared_results] == [10.0, 11.0] + assert [result.prepare_seconds for result in prepared_results] == [2.0, 2.0] @pytest.mark.asyncio diff --git a/tests/repository/test_sqlite_vector_search_repository.py b/tests/repository/test_sqlite_vector_search_repository.py index ea25675c0..6daf20f65 100644 --- a/tests/repository/test_sqlite_vector_search_repository.py +++ b/tests/repository/test_sqlite_vector_search_repository.py @@ -3,6 +3,7 @@ import asyncio from contextlib import asynccontextmanager from datetime import datetime, timezone +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock import pytest @@ -27,6 +28,9 @@ async def embed_query(self, text: str) -> list[float]: async def embed_documents(self, texts: list[str]) -> list[list[float]]: return [self._vectorize(text) for text in texts] + def runtime_log_attrs(self) -> dict[str, object]: + return {} + @staticmethod def _vectorize(text: str) -> list[float]: normalized = text.lower() @@ -81,8 +85,9 @@ def _enable_semantic( pytest.skip("sqlite-vec dependency is required for sqlite vector repository tests.") search_repository._semantic_enabled = True - search_repository._embedding_provider = embedding_provider or StubEmbeddingProvider() - search_repository._vector_dimensions = search_repository._embedding_provider.dimensions + provider = embedding_provider or StubEmbeddingProvider() + search_repository._embedding_provider = provider + search_repository._vector_dimensions = provider.dimensions search_repository._vector_tables_initialized = False @@ -355,9 +360,10 @@ async def fake_scoped_session(session_maker): monkeypatch.setattr(repo, "_upsert_scheduled_chunk_records", _stub_upsert) prepared = await repo._prepare_entity_vector_jobs_window([1, 2]) + prepared_results = [result for result in prepared if not isinstance(result, BaseException)] assert fetched_windows == [[1, 2]] - assert [result.entity_id for result in prepared] == [1, 2] + assert [result.entity_id for result in prepared_results] == [1, 2] assert max_active_write_scopes == 1 @@ -418,9 +424,10 @@ async def fake_scoped_session(session_maker): monkeypatch.setattr(repo, "_upsert_scheduled_chunk_records", _stub_upsert) prepared = await asyncio.wait_for(repo._prepare_entity_vector_jobs_window([1]), timeout=1.0) + prepared_results = [result for result in prepared if not isinstance(result, BaseException)] - assert len(prepared) == 1 - assert prepared[0].entity_id == 1 + assert len(prepared_results) == 1 + assert prepared_results[0].entity_id == 1 @pytest.mark.asyncio @@ -535,7 +542,7 @@ async def capturing_execute(stmt, params=None): async with db.scoped_session(search_repository.session_maker) as session: await search_repository._prepare_vector_session(session) - session.execute = capturing_execute + cast(Any, session).execute = capturing_execute query_embedding = [0.1] * search_repository._vector_dimensions diff --git a/tests/repository/test_vector_pagination.py b/tests/repository/test_vector_pagination.py index 79cf30399..821ce84ad 100644 --- a/tests/repository/test_vector_pagination.py +++ b/tests/repository/test_vector_pagination.py @@ -6,11 +6,15 @@ from contextlib import asynccontextmanager from dataclasses import dataclass +from datetime import datetime +from typing import Any from unittest.mock import AsyncMock, patch import pytest from basic_memory.repository.search_repository_base import SearchRepositoryBase +from basic_memory.repository.search_index_row import SearchIndexRow +from basic_memory.schemas.search import SearchItemType, SearchRetrievalMode @dataclass @@ -43,7 +47,21 @@ async def init_search_index(self): def _prepare_search_term(self, term, is_prefix=True): return term # pragma: no cover - async def search(self, **kwargs): + async def search( + self, + search_text: str | None = None, + permalink: str | None = None, + permalink_match: str | None = None, + title: str | None = None, + note_types: list[str] | None = None, + after_date: datetime | None = None, + search_item_types: list[SearchItemType] | None = None, + metadata_filters: dict[str, Any] | None = None, + retrieval_mode: SearchRetrievalMode = SearchRetrievalMode.FTS, + min_similarity: float | None = None, + limit: int = 10, + offset: int = 0, + ) -> list[SearchIndexRow]: return [] # pragma: no cover async def _ensure_vector_tables(self): @@ -73,6 +91,20 @@ async def fake_scoped_session(session_maker): yield AsyncMock() +class _EmbeddingProvider: + dimensions = 384 + model_name = "stub" + + async def embed_query(self, text: str) -> list[float]: + return [0.0] * self.dimensions + + async def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[0.0] * self.dimensions for _ in texts] + + def runtime_log_attrs(self) -> dict[str, object]: + return {} + + def _make_descending_vector_rows(count: int) -> list[dict]: """Build vector rows with scores descending from ~1.0 to ~0.5.""" rows = [] @@ -98,8 +130,7 @@ async def test_page1_scores_gte_page2_scores(): # 20 results with descending scores fake_rows = _make_descending_vector_rows(20) - mock_embed = AsyncMock(return_value=[0.0] * 384) - repo._embedding_provider = type("EP", (), {"embed_query": mock_embed, "dimensions": 384})() + repo._embedding_provider = _EmbeddingProvider() fake_index_rows = {i: FakeRow(id=i) for i in range(20)} diff --git a/tests/repository/test_vector_threshold.py b/tests/repository/test_vector_threshold.py index 267b95d65..f4da09edf 100644 --- a/tests/repository/test_vector_threshold.py +++ b/tests/repository/test_vector_threshold.py @@ -2,15 +2,20 @@ from contextlib import asynccontextmanager from dataclasses import dataclass +from datetime import datetime +from typing import Any, Optional, cast from unittest.mock import AsyncMock, patch import pytest +from basic_memory.repository.embedding_provider import EmbeddingProvider +from basic_memory.repository.search_index_row import SearchIndexRow from basic_memory.repository.search_repository_base import ( SMALL_NOTE_CONTENT_LIMIT, TOP_CHUNKS_PER_RESULT, SearchRepositoryBase, ) +from basic_memory.schemas.search import SearchItemType, SearchRetrievalMode @dataclass @@ -46,7 +51,21 @@ async def init_search_index(self): def _prepare_search_term(self, term, is_prefix=True): return term # pragma: no cover - async def search(self, **kwargs): + async def search( + self, + search_text: Optional[str] = None, + permalink: Optional[str] = None, + permalink_match: Optional[str] = None, + title: Optional[str] = None, + note_types: Optional[list[str]] = None, + after_date: Optional[datetime] = None, + search_item_types: Optional[list[SearchItemType]] = None, + metadata_filters: Optional[dict[str, Any]] = None, + retrieval_mode: SearchRetrievalMode = SearchRetrievalMode.FTS, + min_similarity: Optional[float] = None, + limit: int = 10, + offset: int = 0, + ) -> list[SearchIndexRow]: return [] # pragma: no cover async def _ensure_vector_tables(self): @@ -90,13 +109,20 @@ def _make_vector_rows(scores: list[float]) -> list[dict]: return rows +def _fake_embedding_provider(mock_embed: AsyncMock) -> EmbeddingProvider: + return cast( + EmbeddingProvider, + type("EP", (), {"embed_query": mock_embed, "dimensions": 384})(), + ) + + @asynccontextmanager async def fake_scoped_session(session_maker): """Fake scoped_session that yields a mock session object.""" yield AsyncMock() -COMMON_SEARCH_KWARGS = dict( +COMMON_SEARCH_KWARGS: dict[str, Any] = dict( search_text="test", permalink=None, permalink_match=None, @@ -119,7 +145,7 @@ async def test_threshold_zero_returns_all(): fake_rows = _make_vector_rows([0.9, 0.5, 0.3]) mock_embed = AsyncMock(return_value=[0.0] * 384) - repo._embedding_provider = type("EP", (), {"embed_query": mock_embed, "dimensions": 384})() + repo._embedding_provider = _fake_embedding_provider(mock_embed) with ( patch( @@ -150,7 +176,7 @@ async def test_threshold_filters_low_scores(): fake_rows = _make_vector_rows([0.9, 0.5, 0.3]) mock_embed = AsyncMock(return_value=[0.0] * 384) - repo._embedding_provider = type("EP", (), {"embed_query": mock_embed, "dimensions": 384})() + repo._embedding_provider = _fake_embedding_provider(mock_embed) with ( patch( @@ -183,7 +209,7 @@ async def test_threshold_returns_empty_when_all_below(): fake_rows = _make_vector_rows([0.5, 0.4, 0.3]) mock_embed = AsyncMock(return_value=[0.0] * 384) - repo._embedding_provider = type("EP", (), {"embed_query": mock_embed, "dimensions": 384})() + repo._embedding_provider = _fake_embedding_provider(mock_embed) mock_fetch = AsyncMock() @@ -214,7 +240,7 @@ async def test_per_query_min_similarity_overrides_instance_default(): fake_rows = _make_vector_rows([0.9, 0.5, 0.3]) mock_embed = AsyncMock(return_value=[0.0] * 384) - repo._embedding_provider = type("EP", (), {"embed_query": mock_embed, "dimensions": 384})() + repo._embedding_provider = _fake_embedding_provider(mock_embed) with ( patch( @@ -247,7 +273,7 @@ async def test_per_query_min_similarity_tightens_threshold(): fake_rows = _make_vector_rows([0.9, 0.5, 0.3]) mock_embed = AsyncMock(return_value=[0.0] * 384) - repo._embedding_provider = type("EP", (), {"embed_query": mock_embed, "dimensions": 384})() + repo._embedding_provider = _fake_embedding_provider(mock_embed) with ( patch( @@ -280,7 +306,7 @@ async def test_matched_chunk_text_populated_on_vector_results(): fake_rows = _make_vector_rows([0.9, 0.7]) mock_embed = AsyncMock(return_value=[0.0] * 384) - repo._embedding_provider = type("EP", (), {"embed_query": mock_embed, "dimensions": 384})() + repo._embedding_provider = _fake_embedding_provider(mock_embed) with ( patch( @@ -335,7 +361,7 @@ async def test_top_n_chunks_joined_in_matched_chunk_text(): fake_rows = _make_multi_chunk_vector_rows(si_id=0, scores=chunk_scores) mock_embed = AsyncMock(return_value=[0.0] * 384) - repo._embedding_provider = type("EP", (), {"embed_query": mock_embed, "dimensions": 384})() + repo._embedding_provider = _fake_embedding_provider(mock_embed) # content_snippet exceeds SMALL_NOTE_CONTENT_LIMIT → top-N chunks path large_content = "x" * (SMALL_NOTE_CONTENT_LIMIT + 1) @@ -358,6 +384,7 @@ async def test_top_n_chunks_joined_in_matched_chunk_text(): assert len(results) == 1 text = results[0].matched_chunk_text + assert text is not None # Top 5 chunks by similarity: 0.9, 0.85, 0.8, 0.75, 0.6 (0.4 and 0.3 excluded) parts = text.split("\n---\n") @@ -378,7 +405,7 @@ async def test_small_note_returns_full_content_as_matched_chunk(): fake_rows = _make_vector_rows([0.9]) mock_embed = AsyncMock(return_value=[0.0] * 384) - repo._embedding_provider = type("EP", (), {"embed_query": mock_embed, "dimensions": 384})() + repo._embedding_provider = _fake_embedding_provider(mock_embed) small_content = "This is a short note with all the important details." assert len(small_content) <= SMALL_NOTE_CONTENT_LIMIT @@ -413,7 +440,7 @@ async def test_large_note_returns_chunks_not_full_content(): fake_rows = _make_vector_rows([0.9]) mock_embed = AsyncMock(return_value=[0.0] * 384) - repo._embedding_provider = type("EP", (), {"embed_query": mock_embed, "dimensions": 384})() + repo._embedding_provider = _fake_embedding_provider(mock_embed) large_content = "x" * (SMALL_NOTE_CONTENT_LIMIT + 500) diff --git a/tests/schema/test_resolver.py b/tests/schema/test_resolver.py index 4c45b8136..babc71d9e 100644 --- a/tests/schema/test_resolver.py +++ b/tests/schema/test_resolver.py @@ -43,6 +43,7 @@ async def test_inline_schema_uses_type_for_entity(self, mock_search_fn): "schema": {"field": "string"}, } result = await resolve_schema(frontmatter, mock_search_fn) + assert result is not None assert result.entity == "CustomType" @pytest.mark.asyncio @@ -51,6 +52,7 @@ async def test_inline_schema_defaults_entity_to_unknown(self, mock_search_fn): "schema": {"field": "string"}, } result = await resolve_schema(frontmatter, mock_search_fn) + assert result is not None assert result.entity == "unknown" @pytest.mark.asyncio @@ -61,6 +63,7 @@ async def test_inline_schema_respects_validation_mode(self, mock_search_fn): "settings": {"validation": "strict"}, } result = await resolve_schema(frontmatter, mock_search_fn) + assert result is not None assert result.validation_mode == "strict" diff --git a/tests/schema/test_validator.py b/tests/schema/test_validator.py index 5d30431a7..4b5a9a266 100644 --- a/tests/schema/test_validator.py +++ b/tests/schema/test_validator.py @@ -182,6 +182,7 @@ def test_enum_invalid_value_warn_mode(self): assert result.passed is True # warn mode fr = result.field_results[0] assert fr.status == "enum_mismatch" + assert fr.message is not None assert "archived" in fr.message assert len(result.warnings) == 1 @@ -323,6 +324,7 @@ def test_enum_frontmatter_invalid_value_warn(self): result = validate_note("test-note", schema, [], [], frontmatter={"status": "archived"}) assert result.passed is True assert result.field_results[0].status == "enum_mismatch" + assert result.field_results[0].message is not None assert "archived" in result.field_results[0].message assert len(result.warnings) == 1 diff --git a/tests/schemas/test_search.py b/tests/schemas/test_search.py index b1fe51e89..b2fd8639e 100644 --- a/tests/schemas/test_search.py +++ b/tests/schemas/test_search.py @@ -47,7 +47,10 @@ def test_search_retrieval_mode_defaults_to_fts(): query = SearchQuery(text="search implementation") assert query.retrieval_mode == SearchRetrievalMode.FTS - vector_query = SearchQuery(text="search implementation", retrieval_mode="vector") + vector_query = SearchQuery( + text="search implementation", + retrieval_mode=SearchRetrievalMode.VECTOR, + ) assert vector_query.retrieval_mode == SearchRetrievalMode.VECTOR diff --git a/tests/services/test_entity_service.py b/tests/services/test_entity_service.py index df071ce84..1a97ab825 100644 --- a/tests/services/test_entity_service.py +++ b/tests/services/test_entity_service.py @@ -3,6 +3,7 @@ import uuid from pathlib import Path from textwrap import dedent +from typing import Any, cast import pytest import yaml @@ -21,6 +22,12 @@ from basic_memory.utils import generate_permalink +def _permalink(entity: EntityModel | EntitySchema) -> str: + permalink = entity.permalink + assert permalink is not None + return permalink + + class _DeleteTestEmbeddingProvider: """Deterministic embedding provider for entity delete cleanup tests.""" @@ -138,7 +145,7 @@ async def test_create_entity( assert len(entity.relations) == 0 # Verify we can retrieve it using permalink - retrieved = await entity_service.get_by_permalink(entity.permalink) + retrieved = await entity_service.get_by_permalink(_permalink(entity)) assert retrieved.title == "Test Entity" assert retrieved.note_type == "test" assert retrieved.created_at is not None @@ -249,13 +256,13 @@ async def test_get_by_permalink(entity_service: EntityService): entity2 = await entity_service.create_entity(entity2_data) # Find by type1 and name - found = await entity_service.get_by_permalink(entity1_data.permalink) + found = await entity_service.get_by_permalink(_permalink(entity1_data)) assert found is not None assert found.id == entity1.id assert found.note_type == entity1.note_type # Find by type2 and name - found = await entity_service.get_by_permalink(entity2_data.permalink) + found = await entity_service.get_by_permalink(_permalink(entity2_data)) assert found is not None assert found.id == entity2.id assert found.note_type == entity2.note_type @@ -276,7 +283,7 @@ async def test_get_entity_success(entity_service: EntityService): await entity_service.create_entity(entity_data) # Get by permalink - retrieved = await entity_service.get_by_permalink(entity_data.permalink) + retrieved = await entity_service.get_by_permalink(_permalink(entity_data)) assert isinstance(retrieved, EntityModel) assert retrieved.title == "TestEntity" @@ -294,12 +301,12 @@ async def test_delete_entity_success(entity_service: EntityService): await entity_service.create_entity(entity_data) # Act using permalink - result = await entity_service.delete_entity(entity_data.permalink) + result = await entity_service.delete_entity(_permalink(entity_data)) # Assert assert result is True with pytest.raises(EntityNotFoundError): - await entity_service.get_by_permalink(entity_data.permalink) + await entity_service.get_by_permalink(_permalink(entity_data)) @pytest.mark.asyncio @@ -318,7 +325,7 @@ async def test_delete_entity_by_id(entity_service: EntityService): # Assert assert result is True with pytest.raises(EntityNotFoundError): - await entity_service.get_by_permalink(entity_data.permalink) + await entity_service.get_by_permalink(_permalink(entity_data)) @pytest.mark.asyncio @@ -332,7 +339,7 @@ async def test_delete_entity_removes_search_and_vector_state( if app_config.database_backend == DatabaseBackend.SQLITE: pytest.importorskip("sqlite_vec") - repository = search_service.repository + repository = cast(Any, search_service.repository) repository._semantic_enabled = True repository._embedding_provider = _DeleteTestEmbeddingProvider() repository._vector_dimensions = repository._embedding_provider.dimensions @@ -403,7 +410,7 @@ async def test_create_entity_with_special_chars(entity_service: EntityService): assert entity.title == name # Verify after retrieval using permalink - await entity_service.get_by_permalink(entity_data.permalink) + await entity_service.get_by_permalink(_permalink(entity_data)) @pytest.mark.asyncio @@ -424,7 +431,7 @@ async def test_get_entities_by_permalinks(entity_service: EntityService): await entity_service.create_entity(entity2_data) # Open nodes by path IDs - permalinks = [entity1_data.permalink, entity2_data.permalink] + permalinks = [_permalink(entity1_data), _permalink(entity2_data)] found = await entity_service.get_entities_by_permalinks(permalinks) assert len(found) == 2 @@ -451,7 +458,7 @@ async def test_get_entities_some_not_found(entity_service: EntityService): await entity_service.create_entity(entity_data) # Try to open two nodes, one exists, one doesn't - permalinks = [entity_data.permalink, "type1/non_existent"] + permalinks = [_permalink(entity_data), "type1/non_existent"] found = await entity_service.get_entities_by_permalinks(permalinks) assert len(found) == 1 @@ -482,6 +489,7 @@ async def test_update_note_entity_content(entity_service: EntityService, file_se ) entity = await entity_service.create_entity(schema) + assert entity.entity_metadata is not None assert entity.entity_metadata.get("status") == "draft" # Update content with a relation @@ -599,12 +607,14 @@ async def test_create_or_update_existing(entity_service: EntityService, file_ser ) ) - entity.content = "Updated content" + entity_for_update = cast(Any, entity) + entity_for_update.content = "Updated content" # Update name - updated, created = await entity_service.create_or_update_entity(entity) + updated, created = await entity_service.create_or_update_entity(entity_for_update) assert updated.title == "test" + assert updated.entity_metadata is not None assert updated.entity_metadata["status"] == "final" assert created is False @@ -847,7 +857,7 @@ async def test_edit_entity_append(entity_service: EntityService, file_service: F # Edit entity with append operation updated = await entity_service.edit_entity( - identifier=entity.permalink, operation="append", content="Appended content" + identifier=_permalink(entity), operation="append", content="Appended content" ) # Verify content was appended @@ -873,7 +883,7 @@ async def test_edit_entity_prepend(entity_service: EntityService, file_service: # Edit entity with prepend operation updated = await entity_service.edit_entity( - identifier=entity.permalink, operation="prepend", content="Prepended content" + identifier=_permalink(entity), operation="prepend", content="Prepended content" ) # Verify content was prepended @@ -899,7 +909,7 @@ async def test_edit_entity_find_replace(entity_service: EntityService, file_serv # Edit entity with find_replace operation updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="find_replace", content="new content", find_text="old content", @@ -939,7 +949,7 @@ async def test_edit_entity_replace_section( # Edit entity with replace_section operation updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="replace_section", content="New section 1 content", section="## Section 1", @@ -970,7 +980,7 @@ async def test_edit_entity_replace_section_create_new( # Edit entity with replace_section operation for non-existent section updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="replace_section", content="New section content", section="## New Section", @@ -1007,7 +1017,7 @@ async def test_edit_entity_invalid_operation(entity_service: EntityService): with pytest.raises(ValueError, match="Unsupported operation"): await entity_service.edit_entity( - identifier=entity.permalink, operation="invalid_operation", content="content" + identifier=_permalink(entity), operation="invalid_operation", content="content" ) @@ -1026,7 +1036,7 @@ async def test_edit_entity_find_replace_missing_find_text(entity_service: Entity with pytest.raises(ValueError, match="find_text is required"): await entity_service.edit_entity( - identifier=entity.permalink, operation="find_replace", content="new content" + identifier=_permalink(entity), operation="find_replace", content="new content" ) @@ -1045,7 +1055,7 @@ async def test_edit_entity_replace_section_missing_section(entity_service: Entit with pytest.raises(ValueError, match="section is required"): await entity_service.edit_entity( - identifier=entity.permalink, operation="replace_section", content="new content" + identifier=_permalink(entity), operation="replace_section", content="new content" ) @@ -1079,7 +1089,7 @@ async def test_edit_entity_with_observations_and_relations( # Edit entity by appending content with new observations/relations updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="append", content="\n- [category] New observation\n- relates to [[New Entity]]", ) @@ -1185,7 +1195,7 @@ async def test_edit_entity_find_replace_not_found(entity_service: EntityService) # Try to replace text that doesn't exist with pytest.raises(ValueError, match="Text to replace not found: 'nonexistent'"): await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="find_replace", content="new content", find_text="nonexistent", @@ -1210,7 +1220,7 @@ async def test_edit_entity_find_replace_multiple_occurrences_expected_one( # Try to replace with expected count of 1 when there are 2 with pytest.raises(ValueError, match="Expected 1 occurrences of 'banana', but found 2"): await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="find_replace", content="replacement", find_text="banana", @@ -1235,7 +1245,7 @@ async def test_edit_entity_find_replace_multiple_occurrences_success( # Replace with correct expected count updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="find_replace", content="apple", find_text="banana", @@ -1264,7 +1274,7 @@ async def test_edit_entity_find_replace_empty_find_text(entity_service: EntitySe # Try with empty find_text with pytest.raises(ValueError, match="find_text cannot be empty or whitespace only"): await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="find_replace", content="new content", find_text=" ", # whitespace only @@ -1301,7 +1311,10 @@ async def test_edit_entity_find_replace_multiline( new_text = "This is new content\nthat replaces the old paragraph." updated = await entity_service.edit_entity( - identifier=entity.permalink, operation="find_replace", content=new_text, find_text=find_text + identifier=_permalink(entity), + operation="find_replace", + content=new_text, + find_text=find_text, ) # Verify replacement worked @@ -1341,7 +1354,7 @@ async def test_edit_entity_replace_section_multiple_sections_error(entity_servic # Try to replace section when multiple exist with pytest.raises(ValueError, match="Multiple sections found with header '## Section 1'"): await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="replace_section", content="New content", section="## Section 1", @@ -1364,7 +1377,7 @@ async def test_edit_entity_replace_section_empty_section(entity_service: EntityS # Try with empty section with pytest.raises(ValueError, match="section cannot be empty or whitespace only"): await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="replace_section", content="new content", section=" ", # whitespace only @@ -1398,7 +1411,7 @@ async def test_edit_entity_replace_section_header_variations( # Test replacing with different header format (no ##) updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="replace_section", content="New section content", section="Section Name", # No ## prefix @@ -1438,7 +1451,7 @@ async def test_edit_entity_replace_section_at_end_of_document( # Replace the last section updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="replace_section", content="New last section content", section="## Last Section", @@ -1485,7 +1498,7 @@ async def test_edit_entity_replace_section_with_subsections( # Replace parent section (should only replace content until first subsection) updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="replace_section", content="New parent content", section="## Parent Section", @@ -1530,7 +1543,7 @@ async def test_edit_entity_replace_section_strips_duplicate_header( # Replace section with content that includes the duplicate header # (This is what LLMs sometimes do) updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="replace_section", content="## Testing\nNew content for testing section", section="## Testing", @@ -1577,7 +1590,7 @@ async def test_edit_entity_insert_before_section( ) updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="insert_before_section", content="Inserted before section 2", section="## Section 2", @@ -1617,7 +1630,7 @@ async def test_edit_entity_insert_after_section( ) updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="insert_after_section", content="Inserted after section 1 heading", section="## Section 1", @@ -1648,7 +1661,7 @@ async def test_edit_entity_insert_before_section_not_found(entity_service: Entit with pytest.raises(ValueError, match="Section '## Missing' not found"): await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="insert_before_section", content="new content", section="## Missing", @@ -1669,7 +1682,7 @@ async def test_edit_entity_insert_after_section_not_found(entity_service: Entity with pytest.raises(ValueError, match="Section '## Missing' not found"): await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="insert_after_section", content="new content", section="## Missing", @@ -1692,7 +1705,7 @@ async def test_edit_entity_insert_before_section_multiple_sections_error( with pytest.raises(ValueError, match="Multiple sections found"): await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="insert_before_section", content="new content", section="## Dup", @@ -1715,7 +1728,7 @@ async def test_edit_entity_insert_before_section_missing_section_param( with pytest.raises(ValueError, match="section is required"): await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="insert_before_section", content="new content", ) @@ -1735,7 +1748,7 @@ async def test_edit_entity_insert_before_section_empty_section(entity_service: E with pytest.raises(ValueError, match="section cannot be empty"): await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="insert_before_section", content="new content", section=" ", @@ -1764,7 +1777,7 @@ async def test_edit_entity_insert_after_section_at_end_of_document( ) updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="insert_after_section", content="Inserted after the last section heading", section="## Only Section", @@ -1800,7 +1813,7 @@ async def test_edit_entity_insert_after_section_preserves_paragraph_separation( ) updated = await entity_service.edit_entity( - identifier=entity.permalink, + identifier=_permalink(entity), operation="insert_after_section", content="Inserted line", section="## Section", @@ -1840,7 +1853,7 @@ async def test_move_entity_success( # Move entity assert entity.permalink == f"{generate_permalink(project_config.name)}/original/test-note" await entity_service.move_entity( - identifier=entity.permalink, + identifier=_permalink(entity), destination_path="moved/test-note.md", project_config=project_config, app_config=app_config, @@ -1854,7 +1867,7 @@ async def test_move_entity_success( assert new_path.exists() # Verify database was updated - updated_entity = await entity_service.get_by_permalink(entity.permalink) + updated_entity = await entity_service.get_by_permalink(_permalink(entity)) assert updated_entity.file_path == "moved/test-note.md" # Verify file content is preserved @@ -1886,7 +1899,7 @@ async def test_move_entity_with_permalink_update( # Move entity await entity_service.move_entity( - identifier=entity.permalink, + identifier=_permalink(entity), destination_path="moved/test-note.md", project_config=project_config, app_config=app_config, @@ -1900,7 +1913,7 @@ async def test_move_entity_with_permalink_update( # Verify frontmatter was updated with new permalink new_content, _ = await file_service.read_file("moved/test-note.md") - assert moved_entity.permalink in new_content + assert _permalink(moved_entity) in new_content @pytest.mark.asyncio @@ -1924,7 +1937,7 @@ async def test_move_entity_creates_destination_directory( # Move to deeply nested path that doesn't exist await entity_service.move_entity( - identifier=entity.permalink, + identifier=_permalink(entity), destination_path="deeply/nested/folders/test-note.md", project_config=project_config, app_config=app_config, @@ -1978,7 +1991,7 @@ async def test_move_entity_source_file_missing( with pytest.raises(ValueError, match="Source file not found:"): await entity_service.move_entity( - identifier=entity.permalink, + identifier=_permalink(entity), destination_path="new/path.md", project_config=project_config, app_config=app_config, @@ -2016,7 +2029,7 @@ async def test_move_entity_destination_exists( # Try to move entity1 to entity2's location with pytest.raises(ValueError, match="Destination already exists:"): await entity_service.move_entity( - identifier=entity1.permalink, + identifier=_permalink(entity1), destination_path=entity2.file_path, project_config=project_config, app_config=app_config, @@ -2044,7 +2057,7 @@ async def test_move_entity_invalid_destination_path( # Test absolute path with pytest.raises(ValueError, match="Invalid destination path:"): await entity_service.move_entity( - identifier=entity.permalink, + identifier=_permalink(entity), destination_path="/absolute/path.md", project_config=project_config, app_config=app_config, @@ -2053,7 +2066,7 @@ async def test_move_entity_invalid_destination_path( # Test empty path with pytest.raises(ValueError, match="Invalid destination path:"): await entity_service.move_entity( - identifier=entity.permalink, + identifier=_permalink(entity), destination_path="", project_config=project_config, app_config=app_config, @@ -2131,7 +2144,7 @@ async def test_move_entity_preserves_observations_and_relations( # Move entity await entity_service.move_entity( - identifier=entity.permalink, + identifier=_permalink(entity), destination_path="moved/test-note.md", project_config=project_config, app_config=app_config, @@ -2139,6 +2152,7 @@ async def test_move_entity_preserves_observations_and_relations( # Get moved entity moved_entity = await entity_service.link_resolver.resolve_link("moved/test-note.md") + assert moved_entity is not None # Verify observations and relations are preserved assert len(moved_entity.observations) == 1 @@ -2158,6 +2172,7 @@ async def test_move_entity_rollback_on_database_failure( file_service: FileService, project_config: ProjectConfig, entity_repository: EntityRepository, + monkeypatch, ): """Test that filesystem changes are rolled back on database failures.""" # Create test entity @@ -2175,33 +2190,25 @@ async def test_move_entity_rollback_on_database_failure( app_config = BasicMemoryConfig(update_permalinks_on_move=False) - # Mock repository update to fail - original_update = entity_repository.update - async def failing_update(*args, **kwargs): return None # Simulate failure - entity_repository.update = failing_update - - try: - with pytest.raises(ValueError, match="Move failed:"): - await entity_service.move_entity( - identifier=entity.permalink, - destination_path="moved/test-note.md", - project_config=project_config, - app_config=app_config, - ) + monkeypatch.setattr(entity_repository, "update", failing_update) - # Verify rollback - original file should still exist - assert await file_service.exists(original_path) + with pytest.raises(ValueError, match="Move failed:"): + await entity_service.move_entity( + identifier=_permalink(entity), + destination_path="moved/test-note.md", + project_config=project_config, + app_config=app_config, + ) - # Verify destination file was cleaned up - destination_path = project_config.home / "moved/test-note.md" - assert not destination_path.exists() + # Verify rollback - original file should still exist + assert await file_service.exists(original_path) - finally: - # Restore original update method - entity_repository.update = original_update + # Verify destination file was cleaned up + destination_path = project_config.home / "moved/test-note.md" + assert not destination_path.exists() @pytest.mark.asyncio @@ -2238,7 +2245,7 @@ async def test_move_entity_with_complex_observations( # Move entity await entity_service.move_entity( - identifier=entity.permalink, + identifier=_permalink(entity), destination_path="moved/complex-note.md", project_config=project_config, app_config=app_config, @@ -2246,6 +2253,7 @@ async def test_move_entity_with_complex_observations( # Verify moved entity maintains structure moved_entity = await entity_service.link_resolver.resolve_link("moved/complex-note.md") + assert moved_entity is not None # Check observations with tags and context design_obs = [obs for obs in moved_entity.observations if obs.category == "design"][0] @@ -2612,7 +2620,7 @@ async def test_delete_entity_by_permalink_already_deleted(entity_service: Entity assert await entity_service.delete_entity(created.id) is True # Delete again by permalink - should return True (EntityNotFoundError caught) - assert await entity_service.delete_entity(entity_data.permalink) is True + assert await entity_service.delete_entity(_permalink(entity_data)) is True @pytest.mark.asyncio diff --git a/tests/services/test_entity_service_disable_permalinks.py b/tests/services/test_entity_service_disable_permalinks.py index 9ac684374..42cc0da9f 100644 --- a/tests/services/test_entity_service_disable_permalinks.py +++ b/tests/services/test_entity_service_disable_permalinks.py @@ -204,13 +204,14 @@ async def test_move_entity_with_permalinks_disabled( # Create entity entity = await entity_service.create_entity(entity_data) original_permalink = entity.permalink + assert original_permalink is not None # Now disable permalinks app_config_disabled = BasicMemoryConfig(disable_permalinks=True, update_permalinks_on_move=True) # Move entity moved = await entity_service.move_entity( - identifier=entity.permalink, + identifier=original_permalink, destination_path="new_folder/test_entity.md", project_config=project_config, app_config=app_config_disabled, diff --git a/tests/services/test_project_service.py b/tests/services/test_project_service.py index 8e58f8480..c4c0cb26a 100644 --- a/tests/services/test_project_service.py +++ b/tests/services/test_project_service.py @@ -233,9 +233,10 @@ async def test_set_default_project_async(project_service: ProjectService, test_p assert project.is_default is True # Make sure old default is no longer default - old_default_project = await project_service.repository.get_by_name(original_default) - if old_default_project: - assert old_default_project.is_default is not True + if original_default: + old_default_project = await project_service.repository.get_by_name(original_default) + if old_default_project: + assert old_default_project.is_default is not True finally: # Restore original default (only if it exists in database) @@ -327,8 +328,10 @@ async def test_add_project_with_set_default_true(project_service: ProjectService try: # Get original default project from database - original_default_project = await project_service.repository.get_by_name( - original_default + original_default_project = ( + await project_service.repository.get_by_name(original_default) + if original_default + else None ) # Add project with set_default=True @@ -345,7 +348,9 @@ async def test_add_project_with_set_default_true(project_service: ProjectService # Verify original default is no longer default in database if original_default_project: + assert original_default is not None refreshed_original = await project_service.repository.get_by_name(original_default) + assert refreshed_original is not None assert refreshed_original.is_default is not True # Verify only one project has is_default=True @@ -394,8 +399,10 @@ async def test_add_project_with_set_default_false(project_service: ProjectServic assert new_project.is_default is not True # Verify original default is still default - original_default_project = await project_service.repository.get_by_name( - original_default + original_default_project = ( + await project_service.repository.get_by_name(original_default) + if original_default + else None ) if original_default_project: assert original_default_project.is_default is True diff --git a/tests/services/test_search_service.py b/tests/services/test_search_service.py index b4040cce9..d71f0d5f4 100644 --- a/tests/services/test_search_service.py +++ b/tests/services/test_search_service.py @@ -994,6 +994,7 @@ async def test_index_entity_with_duplicate_observations( # Reload entity with observations (get_by_permalink eagerly loads observations) entity = await entity_repo.get_by_permalink("test/duplicate-obs") + assert entity is not None # Verify we have duplicate observations assert len(entity.observations) == 2 @@ -1052,6 +1053,7 @@ async def test_index_entity_dedupes_observations_by_permalink( # Reload entity with observations (get_by_permalink eagerly loads observations) entity = await entity_repo.get_by_permalink("test/dedupe-test") + assert entity is not None assert len(entity.observations) == 3 # Index the entity @@ -1103,6 +1105,7 @@ async def test_index_entity_multiple_categories_same_content( # Reload entity with observations (get_by_permalink eagerly loads observations) entity = await entity_repo.get_by_permalink("test/multi-category") + assert entity is not None assert len(entity.observations) == 2 # Verify permalinks are different due to different categories @@ -1156,6 +1159,7 @@ async def test_index_entity_markdown_strips_nul_bytes(search_service, session_ma } entity = await entity_repo.create(entity_data) entity = await entity_repo.get_by_permalink("test/nul-test") + assert entity is not None # Index with NUL-containing content (simulates rclone-preallocated file) nul_content = "# NUL Test\x00\x00\nSome content\x00here" diff --git a/tests/services/test_semantic_search.py b/tests/services/test_semantic_search.py index 7bbf6b2e9..7776cf0dd 100644 --- a/tests/services/test_semantic_search.py +++ b/tests/services/test_semantic_search.py @@ -175,7 +175,9 @@ async def test_semantic_vector_sync_batch_skips_embed_opt_out_and_reports_skips( result = await search_service.sync_entity_vectors_batch([41, 42]) sync_batch.assert_awaited_once() - assert sync_batch.await_args.args[0] == [42] + sync_batch_args = sync_batch.await_args + assert sync_batch_args is not None + assert sync_batch_args.args[0] == [42] assert result.entities_total == 2 assert result.entities_synced == 1 assert result.entities_skipped == 1 diff --git a/tests/services/test_task_scheduler_semantic.py b/tests/services/test_task_scheduler_semantic.py index c4de471d0..6dd0233c0 100644 --- a/tests/services/test_task_scheduler_semantic.py +++ b/tests/services/test_task_scheduler_semantic.py @@ -2,6 +2,7 @@ import asyncio from pathlib import Path +from typing import Any, cast import pytest @@ -56,14 +57,14 @@ async def test_reindex_entity_task_chains_vector_sync_when_semantic_enabled(tmp_ project_config = ProjectConfig(name="test-project", home=tmp_path) scheduler = await get_task_scheduler( - entity_service=entity_service, - sync_service=sync_service, - search_service=search_service, + entity_service=cast(Any, entity_service), + sync_service=cast(Any, sync_service), + search_service=cast(Any, search_service), project_config=project_config, app_config=app_config, ) # Enable background tasks for this test — uses stubs, no real DB race risk - scheduler._test_mode = False # pyright: ignore [reportAttributeAccessIssue] + cast(Any, scheduler)._test_mode = False scheduler.schedule("reindex_entity", entity_id=42) await asyncio.sleep(0.05) @@ -86,14 +87,14 @@ async def test_reindex_entity_task_skips_vector_sync_when_semantic_disabled(tmp_ project_config = ProjectConfig(name="test-project", home=tmp_path) scheduler = await get_task_scheduler( - entity_service=entity_service, - sync_service=sync_service, - search_service=search_service, + entity_service=cast(Any, entity_service), + sync_service=cast(Any, sync_service), + search_service=cast(Any, search_service), project_config=project_config, app_config=app_config, ) # Enable background tasks for this test — uses stubs, no real DB race risk - scheduler._test_mode = False # pyright: ignore [reportAttributeAccessIssue] + cast(Any, scheduler)._test_mode = False scheduler.schedule("reindex_entity", entity_id=42) await asyncio.sleep(0.05) @@ -116,14 +117,14 @@ async def test_sync_entity_vectors_task_maps_to_search_service(tmp_path): project_config = ProjectConfig(name="test-project", home=tmp_path) scheduler = await get_task_scheduler( - entity_service=entity_service, - sync_service=sync_service, - search_service=search_service, + entity_service=cast(Any, entity_service), + sync_service=cast(Any, sync_service), + search_service=cast(Any, search_service), project_config=project_config, app_config=app_config, ) # Enable background tasks for this test — uses stubs, no real DB race risk - scheduler._test_mode = False # pyright: ignore [reportAttributeAccessIssue] + cast(Any, scheduler)._test_mode = False scheduler.schedule("sync_entity_vectors", entity_id=7) await asyncio.sleep(0.05) diff --git a/tests/sync/test_sync_service.py b/tests/sync/test_sync_service.py index 0be9b8488..db76f996c 100644 --- a/tests/sync/test_sync_service.py +++ b/tests/sync/test_sync_service.py @@ -4,6 +4,7 @@ from datetime import datetime, timezone from pathlib import Path from textwrap import dedent +from typing import Any, cast import pytest @@ -1783,7 +1784,8 @@ async def test_sync_file_continues_on_semantic_dependency_error( await sync_service.sync(project_dir) # Patch index_entity to raise SemanticDependenciesMissingError - original_index = sync_service.search_service.index_entity + search_service_mock = cast(Any, sync_service.search_service) + original_index = search_service_mock.index_entity call_count = 0 async def index_with_semantic_error(entity, **kwargs): @@ -1791,7 +1793,7 @@ async def index_with_semantic_error(entity, **kwargs): call_count += 1 raise SemanticDependenciesMissingError("sqlite-vec package is missing") - sync_service.search_service.index_entity = AsyncMock(side_effect=index_with_semantic_error) + search_service_mock.index_entity = AsyncMock(side_effect=index_with_semantic_error) try: # Modify the file so it gets re-synced @@ -1809,4 +1811,4 @@ async def index_with_semantic_error(entity, **kwargs): # Verify circuit breaker was NOT triggered (failure not recorded) assert "semantic_test.md" not in sync_service._file_failures finally: - sync_service.search_service.index_entity = original_index + search_service_mock.index_entity = original_index diff --git a/tests/sync/test_sync_service_incremental.py b/tests/sync/test_sync_service_incremental.py index beca0db45..44401e462 100644 --- a/tests/sync/test_sync_service_incremental.py +++ b/tests/sync/test_sync_service_incremental.py @@ -19,9 +19,31 @@ from basic_memory.config import ProjectConfig from basic_memory.indexing.models import IndexingBatchResult +from basic_memory.models import Project from basic_memory.sync.sync_service import SyncService +async def _current_project(sync_service: SyncService) -> Project: + project_id = sync_service.entity_repository.project_id + assert project_id is not None + + project = await sync_service.project_repository.find_by_id(project_id) + assert project is not None + return project + + +def _last_scan_timestamp(project: Project) -> float: + timestamp = project.last_scan_timestamp + assert timestamp is not None + return timestamp + + +def _last_file_count(project: Project) -> int: + file_count = project.last_file_count + assert file_count is not None + return file_count + + async def create_test_file(path: Path, content: str = "test content") -> None: """Create a test file with given content.""" path.parent.mkdir(parents=True, exist_ok=True) @@ -59,11 +81,9 @@ async def test_first_sync_uses_full_scan(sync_service: SyncService, project_conf assert "file2.md" in report.new # Verify watermark was set - project = await sync_service.project_repository.find_by_id( - sync_service.entity_repository.project_id - ) + project = await _current_project(sync_service) assert project.last_scan_timestamp is not None - assert project.last_file_count >= 2 # May include config files + assert _last_file_count(project) >= 2 # May include config files @pytest.mark.asyncio @@ -168,11 +188,8 @@ async def test_force_full_bypasses_watermark_optimization( assert len(report.new) == 2 # Verify watermark was set - project = await sync_service.project_repository.find_by_id( - sync_service.entity_repository.project_id - ) - assert project.last_scan_timestamp is not None - initial_timestamp = project.last_scan_timestamp + project = await _current_project(sync_service) + initial_timestamp = _last_scan_timestamp(project) # Sleep to ensure time passes await sleep_past_watermark() @@ -202,11 +219,8 @@ async def test_force_full_bypasses_watermark_optimization( assert "file1.md" in report.modified # Verify watermark was still updated after force_full - project = await sync_service.project_repository.find_by_id( - sync_service.entity_repository.project_id - ) - assert project.last_scan_timestamp is not None - assert project.last_scan_timestamp > initial_timestamp + project = await _current_project(sync_service) + assert _last_scan_timestamp(project) > initial_timestamp @pytest.mark.asyncio @@ -534,9 +548,7 @@ async def test_watermark_updated_after_successful_sync( await create_test_file(project_dir / "file1.md", "# File 1") # Get project before sync - project_before = await sync_service.project_repository.find_by_id( - sync_service.entity_repository.project_id - ) + project_before = await _current_project(sync_service) assert project_before.last_scan_timestamp is None assert project_before.last_file_count is None @@ -546,14 +558,12 @@ async def test_watermark_updated_after_successful_sync( sync_end = time.time() # Verify watermark was set - project_after = await sync_service.project_repository.find_by_id( - sync_service.entity_repository.project_id - ) + project_after = await _current_project(sync_service) assert project_after.last_scan_timestamp is not None - assert project_after.last_file_count >= 1 # May include config files + assert _last_file_count(project_after) >= 1 # May include config files # Watermark should be between sync start and end - assert sync_start <= project_after.last_scan_timestamp <= sync_end + assert sync_start <= _last_scan_timestamp(project_after) <= sync_end @pytest.mark.asyncio @@ -572,14 +582,13 @@ async def test_watermark_uses_sync_start_time( sync_end = time.time() # Get watermark - project = await sync_service.project_repository.find_by_id( - sync_service.entity_repository.project_id - ) + project = await _current_project(sync_service) # Watermark should be closer to start than end # (In practice, watermark == sync_start_timestamp captured in sync()) - time_from_start = abs(project.last_scan_timestamp - sync_start) - time_from_end = abs(project.last_scan_timestamp - sync_end) + project_timestamp = _last_scan_timestamp(project) + time_from_start = abs(project_timestamp - sync_start) + time_from_end = abs(project_timestamp - sync_end) assert time_from_start < time_from_end @@ -600,10 +609,8 @@ async def test_watermark_file_count_accurate( await sync_service.sync(project_dir) # Verify file count - project1 = await sync_service.project_repository.find_by_id( - sync_service.entity_repository.project_id - ) - initial_count = project1.last_file_count + project1 = await _current_project(sync_service) + initial_count = _last_file_count(project1) assert initial_count >= 3 # May include config files # Add more files @@ -615,10 +622,8 @@ async def test_watermark_file_count_accurate( await sync_service.sync(project_dir) # Verify updated count increased by 2 - project2 = await sync_service.project_repository.find_by_id( - sync_service.entity_repository.project_id - ) - assert project2.last_file_count == initial_count + 2 + project2 = await _current_project(sync_service) + assert _last_file_count(project2) == initial_count + 2 # ============================================================================== @@ -667,9 +672,7 @@ async def test_empty_directory_handles_incremental_scan( assert len(report1.new) == 0 # Verify watermark was set even for empty directory - project = await sync_service.project_repository.find_by_id( - sync_service.entity_repository.project_id - ) + project = await _current_project(sync_service) assert project.last_scan_timestamp is not None # May have config files, so just check it's set assert project.last_file_count is not None diff --git a/tests/sync/test_watch_service_reload.py b/tests/sync/test_watch_service_reload.py index f6ee8ff9e..45cc59167 100644 --- a/tests/sync/test_watch_service_reload.py +++ b/tests/sync/test_watch_service_reload.py @@ -9,6 +9,7 @@ import asyncio from dataclasses import dataclass +from typing import Any, cast import pytest @@ -33,11 +34,15 @@ async def get_active_projects(self): return self.projects_return or [] +def _watch_service(config: BasicMemoryConfig, repo: _Repo) -> WatchService: + return WatchService(config, cast(Any, repo), quiet=True) + + @pytest.mark.asyncio async def test_schedule_restart_uses_config_interval(monkeypatch): config = BasicMemoryConfig(watch_project_reload_interval=2) repo = _Repo() - watch_service = WatchService(config, repo, quiet=True) + watch_service = _watch_service(config, repo) stop_event = asyncio.Event() slept: list[int] = [] @@ -58,12 +63,12 @@ async def fake_sleep(seconds): async def test_watch_projects_cycle_handles_empty_project_list(monkeypatch): config = BasicMemoryConfig() repo = _Repo() - watch_service = WatchService(config, repo, quiet=True) + watch_service = _watch_service(config, repo) stop_event = asyncio.Event() stop_event.set() - captured = {"args": None, "kwargs": None} + captured: dict[str, Any] = {"args": None, "kwargs": None} async def awatch_stub(*args, **kwargs): captured["args"] = args @@ -76,18 +81,20 @@ async def awatch_stub(*args, **kwargs): await watch_service._watch_projects_cycle([], stop_event) + kwargs = captured["kwargs"] + assert isinstance(kwargs, dict) assert captured["args"] == () - assert captured["kwargs"]["debounce"] == config.sync_delay - assert captured["kwargs"]["watch_filter"] == watch_service.filter_changes - assert captured["kwargs"]["recursive"] is True - assert captured["kwargs"]["stop_event"] is stop_event + assert kwargs["debounce"] == config.sync_delay + assert kwargs["watch_filter"] == watch_service.filter_changes + assert kwargs["recursive"] is True + assert kwargs["stop_event"] is stop_event @pytest.mark.asyncio async def test_run_handles_no_projects(monkeypatch): config = BasicMemoryConfig() repo = _Repo(projects_return=[]) - watch_service = WatchService(config, repo, quiet=True) + watch_service = _watch_service(config, repo) slept: list[int] = [] @@ -120,7 +127,7 @@ async def test_run_reloads_projects_each_cycle(monkeypatch, tmp_path): ], ] ) - watch_service = WatchService(config, repo, quiet=True) + watch_service = _watch_service(config, repo) cycle_count = 0 @@ -159,7 +166,7 @@ async def test_run_filters_cloud_only_projects_each_cycle(monkeypatch, tmp_path) Project(id=2, name="cloud-only", path="cloud-slug", permalink="cloud-only"), ] ) - watch_service = WatchService(config, repo, quiet=True) + watch_service = _watch_service(config, repo) seen_project_names: list[list[str]] = [] @@ -200,7 +207,7 @@ async def test_run_keeps_cloud_projects_with_local_bisync(monkeypatch, tmp_path) ), ] ) - watch_service = WatchService(config, repo, quiet=True) + watch_service = _watch_service(config, repo) seen_project_names: list[list[str]] = [] @@ -226,7 +233,7 @@ async def test_run_continues_after_cycle_error(monkeypatch, tmp_path): repo = _Repo( projects_return=[Project(id=1, name="test", path=str(tmp_path / "test"), permalink="test")] ) - watch_service = WatchService(config, repo, quiet=True) + watch_service = _watch_service(config, repo) call_count = 0 slept: list[int] = [] @@ -261,7 +268,7 @@ async def test_timer_task_cancelled_properly(monkeypatch, tmp_path): repo = _Repo( projects_return=[Project(id=1, name="test", path=str(tmp_path / "test"), permalink="test")] ) - watch_service = WatchService(config, repo, quiet=True) + watch_service = _watch_service(config, repo) created_tasks: list[asyncio.Task] = [] real_create_task = asyncio.create_task @@ -312,7 +319,7 @@ async def test_new_project_addition_scenario(monkeypatch, tmp_path): ] repo = _Repo(projects_side_effect=[initial_projects, initial_projects, updated_projects]) - watch_service = WatchService(config, repo, quiet=True) + watch_service = _watch_service(config, repo) cycle_count = 0 project_lists_used: list[list[Project]] = [] diff --git a/tests/test_config.py b/tests/test_config.py index 5c4fe5d79..b852ead53 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,6 +3,7 @@ import tempfile import pytest from datetime import datetime +from typing import Any, cast from basic_memory.config import ( BasicMemoryConfig, @@ -13,6 +14,10 @@ from pathlib import Path +def _migrate_legacy_projects(data: dict[str, Any]) -> dict[str, Any]: + return cast(dict[str, Any], cast(Any, BasicMemoryConfig.migrate_legacy_projects)(data)) + + class TestBasicMemoryConfig: """Test BasicMemoryConfig behavior with BASIC_MEMORY_HOME environment variable.""" @@ -1346,7 +1351,7 @@ def test_migrate_promotes_local_sync_path_to_path(self): } } } - result = BasicMemoryConfig.migrate_legacy_projects(data) + result = _migrate_legacy_projects(data) assert result["projects"]["specs"]["path"] == "/Users/test/Documents/specs" def test_migrate_does_not_overwrite_absolute_path(self): @@ -1360,7 +1365,7 @@ def test_migrate_does_not_overwrite_absolute_path(self): } } } - result = BasicMemoryConfig.migrate_legacy_projects(data) + result = _migrate_legacy_projects(data) assert result["projects"]["specs"]["path"] == "/Users/test/Documents/specs" def test_migrate_skips_entries_without_local_sync_path(self): @@ -1373,7 +1378,7 @@ def test_migrate_skips_entries_without_local_sync_path(self): } } } - result = BasicMemoryConfig.migrate_legacy_projects(data) + result = _migrate_legacy_projects(data) assert result["projects"]["cloud-only"]["path"] == "cloud-only" def test_migrate_handles_mixed_projects(self, tmp_path): @@ -1391,7 +1396,7 @@ def test_migrate_handles_mixed_projects(self, tmp_path): }, } } - result = BasicMemoryConfig.migrate_legacy_projects(data) + result = _migrate_legacy_projects(data) assert result["projects"]["local-proj"]["path"] == local_path assert result["projects"]["cloud-only"]["path"] == "cloud-only" assert result["projects"]["cloud-bisync"]["path"] == bisync_path diff --git a/tests/test_production_cascade_delete.py b/tests/test_production_cascade_delete.py index 455f2e128..148d311b0 100644 --- a/tests/test_production_cascade_delete.py +++ b/tests/test_production_cascade_delete.py @@ -12,10 +12,15 @@ import sys from datetime import datetime, timezone from pathlib import Path -from typing import Optional +from typing import Any, Optional from sqlalchemy import text -from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) class ProductionCascadeTest: @@ -33,8 +38,34 @@ def __init__(self, db_path: Optional[Path] = None): # Create backup path self.backup_path = self.db_path.with_suffix(".db.backup") - self.engine = None - self.session_maker = None + self.engine: AsyncEngine | None = None + self.session_maker: async_sessionmaker[AsyncSession] | None = None + + def _engine(self) -> AsyncEngine: + assert self.engine is not None + return self.engine + + def _session_maker(self) -> async_sessionmaker[AsyncSession]: + assert self.session_maker is not None + return self.session_maker + + @staticmethod + def _first_int(result: Any) -> int: + row = result.fetchone() + assert row is not None + return int(row[0]) + + @staticmethod + def _lastrowid(result: Any) -> int: + value = result.lastrowid + assert isinstance(value, int) + return value + + @staticmethod + def _rowcount(result: Any) -> int: + value = result.rowcount + assert isinstance(value, int) + return value async def setup(self): """Setup database connection.""" @@ -59,21 +90,21 @@ async def setup(self): async def cleanup(self): """Cleanup database connection.""" if self.engine: - await self.engine.dispose() + await self._engine().dispose() async def check_foreign_keys_enabled(self) -> bool: """Check if foreign keys are enabled in this session.""" - async with self.session_maker() as session: + async with self._session_maker()() as session: # Enable foreign keys like production does await session.execute(text("PRAGMA foreign_keys=ON")) result = await session.execute(text("PRAGMA foreign_keys")) - fk_enabled = result.fetchone()[0] + fk_enabled = self._first_int(result) return bool(fk_enabled) async def check_schema(self): """Check current database schema for foreign key constraints.""" - async with self.session_maker() as session: + async with self._session_maker()() as session: await session.execute(text("PRAGMA foreign_keys=ON")) # Check entity table foreign keys @@ -96,7 +127,7 @@ async def check_schema(self): async def create_test_data(self) -> tuple[int, int]: """Create test project and entity. Returns (project_id, entity_id).""" - async with self.session_maker() as session: + async with self._session_maker()() as session: await session.execute(text("PRAGMA foreign_keys=ON")) # Create test project @@ -119,7 +150,7 @@ async def create_test_data(self) -> tuple[int, int]: "updated_at": now, }, ) - project_id = result.lastrowid + project_id = self._lastrowid(result) # Create test entity linked to project entity_sql = """ @@ -143,7 +174,7 @@ async def create_test_data(self) -> tuple[int, int]: "updated_at": now, }, ) - entity_id = result.lastrowid + entity_id = self._lastrowid(result) await session.commit() @@ -152,19 +183,19 @@ async def create_test_data(self) -> tuple[int, int]: async def verify_test_data_exists(self, project_id: int, entity_id: int) -> bool: """Verify test data exists before deletion.""" - async with self.session_maker() as session: + async with self._session_maker()() as session: # Check project exists result = await session.execute( text("SELECT COUNT(*) FROM project WHERE id = :project_id"), {"project_id": project_id}, ) - project_count = result.fetchone()[0] + project_count = self._first_int(result) # Check entity exists result = await session.execute( text("SELECT COUNT(*) FROM entity WHERE id = :entity_id"), {"entity_id": entity_id} ) - entity_count = result.fetchone()[0] + entity_count = self._first_int(result) exists = project_count > 0 and entity_count > 0 if exists: @@ -180,7 +211,7 @@ async def verify_test_data_exists(self, project_id: int, entity_id: int) -> bool async def test_cascade_delete(self, project_id: int, entity_id: int) -> bool: """Test if deleting project cascades to delete entity.""" - async with self.session_maker() as session: + async with self._session_maker()() as session: await session.execute(text("PRAGMA foreign_keys=ON")) try: @@ -191,7 +222,7 @@ async def test_cascade_delete(self, project_id: int, entity_id: int) -> bool: text("DELETE FROM project WHERE id = :project_id"), {"project_id": project_id} ) - if result.rowcount == 0: + if self._rowcount(result) == 0: print("❌ Project deletion failed - no rows affected") return False @@ -203,7 +234,7 @@ async def test_cascade_delete(self, project_id: int, entity_id: int) -> bool: text("SELECT COUNT(*) FROM entity WHERE id = :entity_id"), {"entity_id": entity_id}, ) - entity_count = result.fetchone()[0] + entity_count = self._first_int(result) if entity_count == 0: print("✅ CASCADE DELETE working: Entity was automatically deleted") @@ -228,7 +259,7 @@ async def test_cascade_delete(self, project_id: int, entity_id: int) -> bool: async def cleanup_test_data(self, project_id: int, entity_id: int): """Clean up any remaining test data.""" - async with self.session_maker() as session: + async with self._session_maker()() as session: await session.execute(text("PRAGMA foreign_keys=ON")) try: diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index d731bad38..dc1cbf09b 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -4,6 +4,7 @@ import string import sys from pathlib import Path +from typing import Any, cast import pytest @@ -55,7 +56,7 @@ async def test_compute_checksum_error(): """Test checksum error handling.""" with pytest.raises(FileError): # Try to hash an object that can't be encoded - await compute_checksum(object()) # pyright: ignore [reportArgumentType] + await compute_checksum(cast(Any, object())) @pytest.mark.asyncio diff --git a/tests/utils/test_parse_tags.py b/tests/utils/test_parse_tags.py index c61dc9e21..bd99d7aa9 100644 --- a/tests/utils/test_parse_tags.py +++ b/tests/utils/test_parse_tags.py @@ -1,6 +1,6 @@ """Tests for parse_tags utility function.""" -from typing import List, Union +from typing import Any, List, Union, cast import pytest @@ -51,7 +51,7 @@ class TagObject: def __str__(self) -> str: return "tag1,tag2" - result = parse_tags(TagObject()) # pyright: ignore [reportArgumentType] + result = parse_tags(cast(Any, TagObject())) assert result == ["tag1", "tag2"] From a3e4badcf953bc85437b231bad72b1288a36f03f Mon Sep 17 00:00:00 2001 From: phernandez Date: Mon, 13 Apr 2026 09:54:44 -0500 Subject: [PATCH 2/2] fix(core): preserve empty frontmatter permalink semantics Signed-off-by: phernandez --- src/basic_memory/services/entity_service.py | 53 ++++++++++++------- .../test_entity_service_write_result.py | 22 ++++++++ 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/src/basic_memory/services/entity_service.py b/src/basic_memory/services/entity_service.py index 9d7f294bf..51e9876a8 100644 --- a/src/basic_memory/services/entity_service.py +++ b/src/basic_memory/services/entity_service.py @@ -60,6 +60,11 @@ class EntityWriteResult: search_content: str +def _frontmatter_permalink(value: object) -> str | None: + """Return an explicit frontmatter permalink only when YAML parsed a real string.""" + return value if isinstance(value, str) and value else None + + class EntityService(BaseService[EntityModel]): """Service for managing entities in the database.""" @@ -287,11 +292,13 @@ async def create_entity_with_content(self, schema: EntitySchema) -> EntityWriteR schema.note_type = content_frontmatter["type"] if "permalink" in content_frontmatter: - content_markdown = self._build_frontmatter_markdown( - schema.title, - schema.note_type, - _coerce_to_string(content_frontmatter["permalink"]), - ) + content_permalink = _frontmatter_permalink(content_frontmatter["permalink"]) + if content_permalink is not None: + content_markdown = self._build_frontmatter_markdown( + schema.title, + schema.note_type, + content_permalink, + ) # Get unique permalink (prioritizing content frontmatter) unless disabled if self.app_config and self.app_config.disable_permalinks: @@ -394,11 +401,13 @@ async def update_entity_with_content( schema.note_type = content_frontmatter["type"] if "permalink" in content_frontmatter: - content_markdown = self._build_frontmatter_markdown( - schema.title, - schema.note_type, - _coerce_to_string(content_frontmatter["permalink"]), - ) + content_permalink = _frontmatter_permalink(content_frontmatter["permalink"]) + if content_permalink is not None: + content_markdown = self._build_frontmatter_markdown( + schema.title, + schema.note_type, + content_permalink, + ) # Check if we need to update the permalink based on content frontmatter (unless disabled) new_permalink = entity.permalink # Default to existing @@ -525,11 +534,13 @@ async def fast_write_entity( schema.note_type = content_frontmatter["type"] if "permalink" in content_frontmatter: - content_markdown = self._build_frontmatter_markdown( - schema.title, - schema.note_type, - _coerce_to_string(content_frontmatter["permalink"]), - ) + content_permalink = _frontmatter_permalink(content_frontmatter["permalink"]) + if content_permalink is not None: + content_markdown = self._build_frontmatter_markdown( + schema.title, + schema.note_type, + content_permalink, + ) # --- Permalink Resolution --- if self.app_config and self.app_config.disable_permalinks: @@ -668,11 +679,13 @@ async def fast_edit_entity( update_data["note_type"] = _coerce_to_string(content_frontmatter["type"]) if "permalink" in content_frontmatter: - content_markdown = self._build_frontmatter_markdown( - _coerce_to_string(update_data.get("title", entity.title)), - _coerce_to_string(update_data.get("note_type", entity.note_type)), - _coerce_to_string(content_frontmatter["permalink"]), - ) + content_permalink = _frontmatter_permalink(content_frontmatter["permalink"]) + if content_permalink is not None: + content_markdown = self._build_frontmatter_markdown( + _coerce_to_string(update_data.get("title", entity.title)), + _coerce_to_string(update_data.get("note_type", entity.note_type)), + content_permalink, + ) metadata = normalize_frontmatter_metadata(content_frontmatter or {}) update_data["entity_metadata"] = {k: v for k, v in metadata.items() if v is not None} diff --git a/tests/services/test_entity_service_write_result.py b/tests/services/test_entity_service_write_result.py index 31bd62280..95daac423 100644 --- a/tests/services/test_entity_service_write_result.py +++ b/tests/services/test_entity_service_write_result.py @@ -27,6 +27,28 @@ async def test_create_entity_with_content_returns_full_and_search_content( assert result.search_content == "Create body content" +@pytest.mark.asyncio +@pytest.mark.parametrize("permalink_line", ["permalink:", "permalink: null", 'permalink: ""']) +async def test_create_entity_ignores_empty_frontmatter_permalink( + entity_service, file_service, permalink_line: str +) -> None: + result = await entity_service.create_entity_with_content( + EntitySchema( + title="Empty Frontmatter Permalink", + directory="notes", + note_type="note", + content=f"---\n{permalink_line}\n---\nCreate body content", + ) + ) + + file_path = file_service.get_entity_path(result.entity) + file_content, _ = await file_service.read_file(file_path) + + assert result.entity.permalink == "test-project/notes/empty-frontmatter-permalink" + assert "permalink: test-project/notes/empty-frontmatter-permalink" in file_content + assert "permalink: None" not in file_content + + @pytest.mark.asyncio async def test_update_entity_with_content_returns_full_and_search_content( entity_service, file_service