diff --git a/src/basic_memory/mcp/tools/read_note.py b/src/basic_memory/mcp/tools/read_note.py index 1d1adec87..e22718436 100644 --- a/src/basic_memory/mcp/tools/read_note.py +++ b/src/basic_memory/mcp/tools/read_note.py @@ -1,6 +1,7 @@ """Read note tool for Basic Memory MCP server.""" from textwrap import dedent +from typing import Optional from loguru import logger @@ -8,13 +9,17 @@ from basic_memory.mcp.server import mcp from basic_memory.mcp.tools.search import search_notes from basic_memory.mcp.tools.utils import call_get +from basic_memory.mcp.project_session import get_active_project from basic_memory.schemas.memory import memory_url_path +from basic_memory.utils import validate_project_path @mcp.tool( description="Read a markdown note by title or permalink.", ) -async def read_note(identifier: str, page: int = 1, page_size: int = 10) -> str: +async def read_note( + identifier: str, page: int = 1, page_size: int = 10, project: Optional[str] = None +) -> str: """Read a markdown note from the knowledge base. This tool finds and retrieves a note by its title, permalink, or content search, @@ -26,6 +31,7 @@ async def read_note(identifier: str, page: int = 1, page_size: int = 10) -> str: Can be a full memory:// URL, a permalink, a title, or search text page: Page number for paginated results (default: 1) page_size: Number of items per page (default: 10) + project: Optional project name to read from. If not provided, uses current active project. Returns: The full markdown content of the note if found, or helpful guidance if not found. @@ -42,10 +48,41 @@ async def read_note(identifier: str, page: int = 1, page_size: int = 10) -> str: # Read with pagination read_note("Project Updates", page=2, page_size=5) + + # Read from specific project + read_note("Meeting Notes", project="work-project") """ + + # Get the active project first to check project-specific sync status + active_project = get_active_project(project) + + # Validate identifier to prevent path traversal attacks + # We need to check both the raw identifier and the processed path + processed_path = memory_url_path(identifier) + project_path = active_project.home + + if not validate_project_path(identifier, project_path) or not validate_project_path(processed_path, project_path): + logger.warning( + "Attempted path traversal attack blocked", + identifier=identifier, + processed_path=processed_path, + project=active_project.name, + ) + return f"# Error\n\nIdentifier '{identifier}' is not allowed - paths must stay within project boundaries" + + # Check migration status and wait briefly if needed + from basic_memory.mcp.tools.utils import wait_for_migration_or_return_status + + migration_status = await wait_for_migration_or_return_status( + timeout=5.0, project_name=active_project.name + ) + if migration_status: # pragma: no cover + return f"# System Status\n\n{migration_status}\n\nPlease wait for migration to complete before reading notes." + project_url = active_project.project_url + # Get the file via REST API - first try direct permalink lookup entity_path = memory_url_path(identifier) - path = f"/resource/{entity_path}" + path = f"{project_url}/resource/{entity_path}" logger.info(f"Attempting to read note from URL: {path}") try: @@ -62,14 +99,14 @@ async def read_note(identifier: str, page: int = 1, page_size: int = 10) -> str: # Fallback 1: Try title search via API logger.info(f"Search title for: {identifier}") - title_results = await search_notes(query=identifier, search_type="title") + title_results = await search_notes.fn(query=identifier, search_type="title", project=project) if title_results and title_results.results: result = title_results.results[0] # Get the first/best match if result.permalink: try: # Try to fetch the content using the found permalink - path = f"/resource/{result.permalink}" + path = f"{project_url}/resource/{result.permalink}" response = await call_get( client, path, params={"page": page, "page_size": page_size} ) @@ -86,7 +123,7 @@ async def read_note(identifier: str, page: int = 1, page_size: int = 10) -> str: # Fallback 2: Text search as a last resort logger.info(f"Title search failed, trying text search for: {identifier}") - text_results = await search_notes(query=identifier, search_type="text") + text_results = await search_notes.fn(query=identifier, search_type="text", project=project) # We didn't find a direct match, construct a helpful error message if not text_results or not text_results.results: diff --git a/src/basic_memory/repository/search_repository.py b/src/basic_memory/repository/search_repository.py index 9c93cc6c4..20343a76f 100644 --- a/src/basic_memory/repository/search_repository.py +++ b/src/basic_memory/repository/search_repository.py @@ -523,8 +523,8 @@ async def index_item( async with db.scoped_session(self.session_maker) as session: # Delete existing record if any await session.execute( - text("DELETE FROM search_index WHERE permalink = :permalink"), - {"permalink": search_index_row.permalink}, + text("DELETE FROM search_index WHERE permalink = :permalink AND project_id = :project_id"), + {"permalink": search_index_row.permalink, "project_id": self.project_id}, ) # Prepare data for insert with project_id diff --git a/src/basic_memory/utils.py b/src/basic_memory/utils.py index e0b43ee86..25a519f35 100644 --- a/src/basic_memory/utils.py +++ b/src/basic_memory/utils.py @@ -53,26 +53,53 @@ def generate_permalink(file_path: Union[Path, str, PathLike]) -> str: # Remove extension base = os.path.splitext(path_str)[0] - # Check if we have non-ASCII characters that should be preserved - has_non_ascii = any(ord(char) > 127 for char in base) + # Check if we have CJK characters that should be preserved + # CJK ranges: \u4e00-\u9fff (CJK Unified Ideographs), \u3000-\u303f (CJK symbols), + # \u3400-\u4dbf (CJK Extension A), \uff00-\uffef (Fullwidth forms) + has_cjk_chars = any( + '\u4e00' <= char <= '\u9fff' or + '\u3000' <= char <= '\u303f' or + '\u3400' <= char <= '\u4dbf' or + '\uff00' <= char <= '\uffef' + for char in base + ) - if has_non_ascii: - # Preserve non-ASCII characters like Chinese while still processing ASCII parts - result = base + if has_cjk_chars: + # For text with CJK characters, selectively transliterate only Latin accented chars + result = "" + for char in base: + if ('\u4e00' <= char <= '\u9fff' or + '\u3000' <= char <= '\u303f' or + '\u3400' <= char <= '\u4dbf'): + # Preserve CJK ideographs and symbols + result += char + elif ('\uff00' <= char <= '\uffef'): + # Remove Chinese fullwidth punctuation entirely (like ,!?) + continue + else: + # Transliterate Latin accented characters to ASCII + result += unidecode(char) + + # Insert hyphens between CJK and Latin character transitions + # Match: CJK followed by Latin letter/digit, or Latin letter/digit followed by CJK + result = re.sub(r'([\u4e00-\u9fff\u3000-\u303f\u3400-\u4dbf])([a-zA-Z0-9])', r'\1-\2', result) + result = re.sub(r'([a-zA-Z0-9])([\u4e00-\u9fff\u3000-\u303f\u3400-\u4dbf])', r'\1-\2', result) # Insert dash between camelCase result = re.sub(r"([a-z0-9])([A-Z])", r"\1-\2", result) - # Convert only ASCII letters to lowercase, preserve non-ASCII + # Convert ASCII letters to lowercase, preserve CJK lower_text = "".join(c.lower() if c.isascii() and c.isalpha() else c for c in result) # Replace underscores with hyphens text_with_hyphens = lower_text.replace("_", "-") - # Replace spaces and unsafe ASCII chars with hyphens, preserve non-ASCII chars - # Includes Chinese character ranges (CJK Unified Ideographs, CJK symbols, etc.) + # Remove apostrophes entirely (don't replace with hyphens) + text_no_apostrophes = text_with_hyphens.replace("'", "") + + # Replace unsafe chars with hyphens, but preserve CJK characters clean_text = re.sub( - r"[^a-z0-9\u4e00-\u9fff\u3000-\u303f\u3400-\u4dbf/\-]", "-", text_with_hyphens + r"[^a-z0-9\u4e00-\u9fff\u3000-\u303f\u3400-\u4dbf/\-]", "-", text_no_apostrophes ) else: # Original ASCII-only processing for backward compatibility @@ -88,8 +115,11 @@ def generate_permalink(file_path: Union[Path, str, PathLike]) -> str: # replace underscores with hyphens text_with_hyphens = lower_text.replace("_", "-") + # Remove apostrophes entirely (don't replace with hyphens) + text_no_apostrophes = text_with_hyphens.replace("'", "") + # Replace remaining invalid chars with hyphens - clean_text = re.sub(r"[^a-z0-9/\-]", "-", text_with_hyphens) + clean_text = re.sub(r"[^a-z0-9/\-]", "-", text_no_apostrophes) # Collapse multiple hyphens clean_text = re.sub(r"-+", "-", clean_text) @@ -187,3 +217,105 @@ def parse_tags(tags: Union[List[str], str, None]) -> List[str]: except (ValueError, TypeError): # pragma: no cover logger.warning(f"Couldn't parse tags from input of type {type(tags)}: {tags}") return [] + + +def normalize_file_path_for_comparison(file_path: str) -> str: + """Normalize a file path for conflict detection. + + This function normalizes file paths to help detect potential conflicts: + - Converts to lowercase for case-insensitive comparison + - Normalizes Unicode characters + - Handles path separators consistently + + Args: + file_path: The file path to normalize + + Returns: + Normalized file path for comparison purposes + """ + import unicodedata + + # Convert to lowercase for case-insensitive comparison + normalized = file_path.lower() + + # Normalize Unicode characters (NFD normalization) + normalized = unicodedata.normalize('NFD', normalized) + + # Replace path separators with forward slashes + normalized = normalized.replace('\\', '/') + + # Remove multiple slashes + normalized = re.sub(r'/+', '/', normalized) + + return normalized + + +def detect_potential_file_conflicts(file_path: str, existing_paths: List[str]) -> List[str]: + """Detect potential conflicts between a file path and existing paths. + + This function checks for various types of conflicts: + - Case sensitivity differences + - Unicode normalization differences + - Path separator differences + - Permalink generation conflicts + + Args: + file_path: The file path to check + existing_paths: List of existing file paths to check against + + Returns: + List of existing paths that might conflict with the given file path + """ + conflicts = [] + + # Normalize the input file path + normalized_input = normalize_file_path_for_comparison(file_path) + input_permalink = generate_permalink(file_path) + + for existing_path in existing_paths: + # Skip identical paths + if existing_path == file_path: + continue + + # Check for case-insensitive path conflicts + normalized_existing = normalize_file_path_for_comparison(existing_path) + if normalized_input == normalized_existing: + conflicts.append(existing_path) + continue + + # Check for permalink conflicts + existing_permalink = generate_permalink(existing_path) + if input_permalink == existing_permalink: + conflicts.append(existing_path) + continue + + return conflicts + + +def validate_project_path(path: str, project_path: Path) -> bool: + """Ensure path stays within project boundaries.""" + # Allow empty strings as they resolve to the project root + if not path: + return True + + # Check for obvious path traversal patterns first + if ".." in path or "~" in path: + return False + + # Check for Windows-style path traversal (even on Unix systems) + if "\\.." in path or path.startswith("\\"): + return False + + # Block absolute paths (Unix-style starting with / or Windows-style with drive letters) + if path.startswith("/") or (len(path) >= 2 and path[1] == ":"): + return False + + # Block paths with control characters (but allow whitespace that will be stripped) + if path.strip() and any(ord(c) < 32 and c not in [" ", "\t"] for c in path): + return False + + try: + resolved = (project_path / path).resolve() + return resolved.is_relative_to(project_path.resolve()) + except (ValueError, OSError): + return False \ No newline at end of file diff --git a/tests/mcp/test_tool_read_note.py b/tests/mcp/test_tool_read_note.py index 831e98886..04dbc747e 100644 --- a/tests/mcp/test_tool_read_note.py +++ b/tests/mcp/test_tool_read_note.py @@ -26,7 +26,7 @@ async def mock_call_get(): @pytest_asyncio.fixture async def mock_search(): """Mock for search tool.""" - with patch("basic_memory.mcp.tools.read_note.search_notes") as mock: + with patch("basic_memory.mcp.tools.read_note.search_notes.fn") as mock: # Default to empty results mock.return_value = SearchResponse(results=[], current_page=1, page_size=1) yield mock @@ -36,10 +36,10 @@ async def mock_search(): async def test_read_note_by_title(app): """Test reading a note by its title.""" # First create a note - await write_note(title="Special Note", folder="test", content="Note content here") + await write_note.fn(title="Special Note", folder="test", content="Note content here") # Should be able to read it by title - content = await read_note("Special Note") + content = await read_note.fn("Special Note") assert "Note content here" in content @@ -47,7 +47,7 @@ async def test_read_note_by_title(app): async def test_note_unicode_content(app): """Test handling of unicode content in""" content = "# Test 🚀\nThis note has emoji 🎉 and unicode ♠♣♥♦" - result = await write_note(title="Unicode Test", folder="test", content=content) + result = await write_note.fn(title="Unicode Test", folder="test", content=content) assert ( dedent(""" @@ -60,7 +60,7 @@ async def test_note_unicode_content(app): ) # Read back should preserve unicode - result = await read_note("test/unicode-test") + result = await read_note.fn("test/unicode-test") assert content in result @@ -75,16 +75,16 @@ async def test_multiple_notes(app): ] for _, title, folder, content, tags in notes_data: - await write_note(title=title, folder=folder, content=content, tags=tags) + await write_note.fn(title=title, folder=folder, content=content, tags=tags) # Should be able to read each one for permalink, title, folder, content, _ in notes_data: - note = await read_note(permalink) + note = await read_note.fn(permalink) assert content in note # read multiple notes at once - result = await read_note("test/*") + result = await read_note.fn("test/*") # note we can't compare times assert "--- memory://test/note-1" in result @@ -108,15 +108,15 @@ async def test_multiple_notes_pagination(app): ] for _, title, folder, content, tags in notes_data: - await write_note(title=title, folder=folder, content=content, tags=tags) + await write_note.fn(title=title, folder=folder, content=content, tags=tags) # Should be able to read each one for permalink, title, folder, content, _ in notes_data: - note = await read_note(permalink) + note = await read_note.fn(permalink) assert content in note # read multiple notes at once with pagination - result = await read_note("test/*", page=1, page_size=2) + result = await read_note.fn("test/*", page=1, page_size=2) # note we can't compare times assert "--- memory://test/note-1" in result @@ -136,7 +136,7 @@ async def test_read_note_memory_url(app): - Return the note content """ # First create a note - result = await write_note( + result = await write_note.fn( title="Memory URL Test", folder="test", content="Testing memory:// URL handling", @@ -145,7 +145,7 @@ async def test_read_note_memory_url(app): # Should be able to read it with a memory:// URL memory_url = "memory://test/memory-url-test" - content = await read_note(memory_url) + content = await read_note.fn(memory_url) assert "Testing memory:// URL handling" in content @@ -159,7 +159,7 @@ async def test_read_note_direct_success(mock_call_get): mock_call_get.return_value = mock_response # Call the function - result = await read_note("test/test-note") + result = await read_note.fn("test/test-note") # Verify direct lookup was used mock_call_get.assert_called_once() @@ -199,7 +199,7 @@ async def test_read_note_title_search_fallback(mock_call_get, mock_search): ) # Call the function - result = await read_note("Test Note") + result = await read_note.fn("Test Note") # Verify title search was used mock_search.assert_called_once() @@ -253,7 +253,7 @@ async def test_read_note_text_search_fallback(mock_call_get, mock_search): ] # Call the function - result = await read_note("some query") + result = await read_note.fn("some query") # Verify both search types were used assert mock_search.call_count == 2 @@ -281,7 +281,7 @@ async def test_read_note_complete_fallback(mock_call_get, mock_search): mock_search.return_value = SearchResponse(results=[], current_page=1, page_size=1) # Call the function - result = await read_note("nonexistent") + result = await read_note.fn("nonexistent") # Verify search was used assert mock_search.call_count == 2 @@ -294,3 +294,326 @@ async def test_read_note_complete_fallback(mock_call_get, mock_search): assert "Recent Activity" in result assert "Create New Note" in result assert "write_note(" in result + + +class TestReadNoteSecurityValidation: + """Test read_note security validation features.""" + + @pytest.mark.asyncio + async def test_read_note_blocks_path_traversal_unix(self, app): + """Test that Unix-style path traversal attacks are blocked in identifier parameter.""" + # Test various Unix-style path traversal patterns + attack_identifiers = [ + "../secrets.txt", + "../../etc/passwd", + "../../../root/.ssh/id_rsa", + "notes/../../../etc/shadow", + "folder/../../outside/file.md", + "../../../../etc/hosts", + "../../../home/user/.env", + ] + + for attack_identifier in attack_identifiers: + result = await read_note.fn(identifier=attack_identifier) + + assert isinstance(result, str) + assert "# Error" in result + assert "paths must stay within project boundaries" in result + assert attack_identifier in result + + @pytest.mark.asyncio + async def test_read_note_blocks_path_traversal_windows(self, app): + """Test that Windows-style path traversal attacks are blocked in identifier parameter.""" + # Test various Windows-style path traversal patterns + attack_identifiers = [ + "..\\secrets.txt", + "..\\..\\Windows\\System32\\config\\SAM", + "notes\\..\\..\\..\\Windows\\System32", + "\\\\server\\share\\file.txt", + "..\\..\\Users\\user\\.env", + "\\\\..\\..\\Windows", + "..\\..\\..\\Boot.ini", + ] + + for attack_identifier in attack_identifiers: + result = await read_note.fn(identifier=attack_identifier) + + assert isinstance(result, str) + assert "# Error" in result + assert "paths must stay within project boundaries" in result + assert attack_identifier in result + + @pytest.mark.asyncio + async def test_read_note_blocks_absolute_paths(self, app): + """Test that absolute paths are blocked in identifier parameter.""" + # Test various absolute path patterns + attack_identifiers = [ + "/etc/passwd", + "/home/user/.env", + "/var/log/auth.log", + "/root/.ssh/id_rsa", + "C:\\Windows\\System32\\config\\SAM", + "C:\\Users\\user\\.env", + "D:\\secrets\\config.json", + "/tmp/malicious.txt", + "/usr/local/bin/evil", + ] + + for attack_identifier in attack_identifiers: + result = await read_note.fn(identifier=attack_identifier) + + assert isinstance(result, str) + assert "# Error" in result + assert "paths must stay within project boundaries" in result + assert attack_identifier in result + + @pytest.mark.asyncio + async def test_read_note_blocks_home_directory_access(self, app): + """Test that home directory access patterns are blocked in identifier parameter.""" + # Test various home directory access patterns + attack_identifiers = [ + "~/secrets.txt", + "~/.env", + "~/.ssh/id_rsa", + "~/Documents/passwords.txt", + "~\\AppData\\secrets", + "~\\Desktop\\config.ini", + "~/.bashrc", + "~/Library/Preferences/secret.plist", + ] + + for attack_identifier in attack_identifiers: + result = await read_note.fn(identifier=attack_identifier) + + assert isinstance(result, str) + assert "# Error" in result + assert "paths must stay within project boundaries" in result + assert attack_identifier in result + + @pytest.mark.asyncio + async def test_read_note_blocks_memory_url_attacks(self, app): + """Test that memory URLs with path traversal are blocked.""" + # Test memory URLs with attacks embedded + attack_identifiers = [ + "memory://../../etc/passwd", + "memory://../../../root/.ssh/id_rsa", + "memory://~/.env", + "memory:///etc/passwd", + "memory://notes/../../../etc/shadow", + "memory://..\\..\\Windows\\System32", + ] + + for attack_identifier in attack_identifiers: + result = await read_note.fn(identifier=attack_identifier) + + assert isinstance(result, str) + assert "# Error" in result + assert "paths must stay within project boundaries" in result + + @pytest.mark.asyncio + async def test_read_note_blocks_mixed_attack_patterns(self, app): + """Test that mixed legitimate/attack patterns are blocked in identifier parameter.""" + # Test mixed patterns that start legitimate but contain attacks + attack_identifiers = [ + "notes/../../../etc/passwd", + "docs/../../.env", + "legitimate/path/../../.ssh/id_rsa", + "project/folder/../../../Windows/System32", + "valid/folder/../../home/user/.bashrc", + "assets/../../../tmp/evil.exe", + ] + + for attack_identifier in attack_identifiers: + result = await read_note.fn(identifier=attack_identifier) + + assert isinstance(result, str) + assert "# Error" in result + assert "paths must stay within project boundaries" in result + + @pytest.mark.asyncio + async def test_read_note_allows_safe_identifiers(self, app): + """Test that legitimate identifiers are still allowed.""" + # Test various safe identifier patterns + safe_identifiers = [ + "notes/meeting", + "docs/readme", + "projects/2025/planning", + "archive/old-notes/backup", + "folder/subfolder/document", + "research/ml/algorithms", + "meeting-notes", + "test/simple-note", + ] + + for safe_identifier in safe_identifiers: + result = await read_note.fn(identifier=safe_identifier) + + assert isinstance(result, str) + # Should not contain security error message + assert ( + "# Error" not in result or "paths must stay within project boundaries" not in result + ) + # Should either succeed or fail for legitimate reasons (not found, etc.) + # but not due to security validation + + @pytest.mark.asyncio + async def test_read_note_allows_legitimate_titles(self, app): + """Test that legitimate note titles work normally.""" + # Create a test note first + await write_note.fn( + title="Security Test Note", + folder="security-tests", + content="# Security Test Note\nThis is a legitimate note for security testing.", + ) + + # Test reading by title (should work) + result = await read_note.fn("Security Test Note") + + assert isinstance(result, str) + # Should not be a security error + assert "# Error" not in result or "paths must stay within project boundaries" not in result + # Should either return the note content or search results + + @pytest.mark.asyncio + async def test_read_note_empty_identifier_security(self, app): + """Test that empty identifier is handled securely.""" + # Empty identifier should be allowed (may return search results or error, but not security error) + result = await read_note.fn(identifier="") + + assert isinstance(result, str) + # Empty identifier should not trigger security error + assert "# Error" not in result or "paths must stay within project boundaries" not in result + + @pytest.mark.asyncio + async def test_read_note_security_with_all_parameters(self, app): + """Test security validation works with all read_note parameters.""" + # Test that security validation is applied even when all other parameters are provided + result = await read_note.fn( + identifier="../../../etc/malicious", + page=1, + page_size=5, + project=None, # Use default project + ) + + assert isinstance(result, str) + assert "# Error" in result + assert "paths must stay within project boundaries" in result + assert "../../../etc/malicious" in result + + @pytest.mark.asyncio + async def test_read_note_security_logging(self, app, caplog): + """Test that security violations are properly logged.""" + # Attempt path traversal attack + result = await read_note.fn(identifier="../../../etc/passwd") + + assert "# Error" in result + assert "paths must stay within project boundaries" in result + + # Check that security violation was logged + # Note: This test may need adjustment based on the actual logging setup + # The security validation should generate a warning log entry + + @pytest.mark.asyncio + async def test_read_note_preserves_functionality_with_security(self, app): + """Test that security validation doesn't break normal note reading functionality.""" + # Create a note with complex content to ensure security validation doesn't interfere + await write_note.fn( + title="Full Feature Security Test Note", + folder="security-tests", + content=dedent(""" + # Full Feature Security Test Note + + This note tests that security validation doesn't break normal functionality. + + ## Observations + - [security] Path validation working correctly #security + - [feature] All features still functional #test + + ## Relations + - relates_to [[Security Implementation]] + - depends_on [[Path Validation]] + + Additional content with various formatting. + """).strip(), + tags=["security", "test", "full-feature"], + entity_type="guide", + ) + + # Test reading by permalink + result = await read_note.fn("security-tests/full-feature-security-test-note") + + # Should succeed normally (not a security error) + assert isinstance(result, str) + assert "# Error" not in result or "paths must stay within project boundaries" not in result + # Should either return content or search results, but not security error + + +class TestReadNoteSecurityEdgeCases: + """Test edge cases for read_note security validation.""" + + @pytest.mark.asyncio + async def test_read_note_unicode_identifier_attacks(self, app): + """Test that Unicode-based path traversal attempts are blocked.""" + # Test Unicode path traversal attempts + unicode_attack_identifiers = [ + "notes/文档/../../../etc/passwd", # Chinese characters + "docs/café/../../.env", # Accented characters + "files/αβγ/../../../secret.txt", # Greek characters + ] + + for attack_identifier in unicode_attack_identifiers: + result = await read_note.fn(identifier=attack_identifier) + + assert isinstance(result, str) + assert "# Error" in result + assert "paths must stay within project boundaries" in result + + @pytest.mark.asyncio + async def test_read_note_very_long_attack_identifier(self, app): + """Test handling of very long attack identifiers.""" + # Create a very long path traversal attack + long_attack_identifier = "../" * 1000 + "etc/malicious" + + result = await read_note.fn(identifier=long_attack_identifier) + + assert isinstance(result, str) + assert "# Error" in result + assert "paths must stay within project boundaries" in result + + @pytest.mark.asyncio + async def test_read_note_case_variations_attacks(self, app): + """Test that case variations don't bypass security.""" + # Test case variations (though case sensitivity depends on filesystem) + case_attack_identifiers = [ + "../ETC/passwd", + "../Etc/PASSWD", + "..\\WINDOWS\\system32", + "~/.SSH/id_rsa", + ] + + for attack_identifier in case_attack_identifiers: + result = await read_note.fn(identifier=attack_identifier) + + assert isinstance(result, str) + assert "# Error" in result + assert "paths must stay within project boundaries" in result + + @pytest.mark.asyncio + async def test_read_note_whitespace_in_attack_identifiers(self, app): + """Test that whitespace doesn't help bypass security.""" + # Test attack identifiers with various whitespace + whitespace_attack_identifiers = [ + " ../../../etc/passwd ", + "\t../../../secrets\t", + " ..\\..\\Windows ", + "notes/ ../../ malicious", + ] + + for attack_identifier in whitespace_attack_identifiers: + result = await read_note.fn(identifier=attack_identifier) + + assert isinstance(result, str) + # The attack should still be blocked even with whitespace + if ".." in attack_identifier.strip() or "~" in attack_identifier.strip(): + assert "# Error" in result + assert "paths must stay within project boundaries" in result diff --git a/tests/repository/test_search_repository_edit_bug_fix.py b/tests/repository/test_search_repository_edit_bug_fix.py new file mode 100644 index 000000000..0f4209341 --- /dev/null +++ b/tests/repository/test_search_repository_edit_bug_fix.py @@ -0,0 +1,270 @@ +"""Tests for the search repository edit bug fix. + +This test reproduces the critical bug where editing notes causes them to disappear +from the search index due to missing project_id filter in index_item() method. +""" + +from datetime import datetime, timezone + +import pytest +import pytest_asyncio + +from basic_memory.models.project import Project +from basic_memory.repository.search_repository import SearchRepository, SearchIndexRow +from basic_memory.schemas.search import SearchItemType + + +@pytest_asyncio.fixture +async def second_test_project(project_repository): + """Create a second project for testing project isolation during edits.""" + project_data = { + "name": "Second Edit Test Project", + "description": "Another project for testing edit bug", + "path": "/second/edit/test/path", + "is_active": True, + "is_default": None, + } + return await project_repository.create(project_data) + + +@pytest_asyncio.fixture +async def second_search_repo(session_maker, second_test_project): + """Create a search repository for the second project.""" + return SearchRepository(session_maker, project_id=second_test_project.id) + + +@pytest.mark.asyncio +async def test_index_item_respects_project_isolation_during_edit(): + """Test that index_item() doesn't delete records from other projects during edits. + + This test reproduces the critical bug where editing a note in one project + would delete search index entries with the same permalink from ALL projects, + causing notes to disappear from the search index. + """ + from basic_memory import db + from basic_memory.models.base import Base + from basic_memory.repository.search_repository import SearchRepository + from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker + + # Create a separate in-memory database for this test + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + session_maker = async_sessionmaker(engine, expire_on_commit=False) + + # Create the database schema + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Create two projects + async with db.scoped_session(session_maker) as session: + project1 = Project( + name="Project 1", + description="First project", + path="/project1/path", + is_active=True, + is_default=True + ) + project2 = Project( + name="Project 2", + description="Second project", + path="/project2/path", + is_active=True, + is_default=False + ) + session.add(project1) + session.add(project2) + await session.flush() + + project1_id = project1.id + project2_id = project2.id + await session.commit() + + # Create search repositories for both projects + repo1 = SearchRepository(session_maker, project_id=project1_id) + repo2 = SearchRepository(session_maker, project_id=project2_id) + + # Initialize search index + await repo1.init_search_index() + + # Create two notes with the SAME permalink in different projects + # This simulates the same note name/structure across different projects + same_permalink = "notes/test-note" + + search_row1 = SearchIndexRow( + id=1, + type=SearchItemType.ENTITY.value, + title="Test Note in Project 1", + content_stems="project 1 content original", + content_snippet="This is the original content in project 1", + permalink=same_permalink, + file_path="notes/test_note.md", + entity_id=1, + metadata={"entity_type": "note"}, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + project_id=project1_id, + ) + + search_row2 = SearchIndexRow( + id=2, + type=SearchItemType.ENTITY.value, + title="Test Note in Project 2", + content_stems="project 2 content original", + content_snippet="This is the original content in project 2", + permalink=same_permalink, # SAME permalink as project 1 + file_path="notes/test_note.md", + entity_id=2, + metadata={"entity_type": "note"}, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + project_id=project2_id, + ) + + # Index both items in their respective projects + await repo1.index_item(search_row1) + await repo2.index_item(search_row2) + + # Verify both projects can find their respective notes + results1_before = await repo1.search(search_text="project 1 content") + assert len(results1_before) == 1 + assert results1_before[0].title == "Test Note in Project 1" + assert results1_before[0].project_id == project1_id + + results2_before = await repo2.search(search_text="project 2 content") + assert len(results2_before) == 1 + assert results2_before[0].title == "Test Note in Project 2" + assert results2_before[0].project_id == project2_id + + # Now simulate editing the note in project 1 (which re-indexes it) + # This would trigger the bug where the DELETE query doesn't filter by project_id + edited_search_row1 = SearchIndexRow( + id=1, + type=SearchItemType.ENTITY.value, + title="Test Note in Project 1", + content_stems="project 1 content EDITED", # Changed content + content_snippet="This is the EDITED content in project 1", + permalink=same_permalink, + file_path="notes/test_note.md", + entity_id=1, + metadata={"entity_type": "note"}, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + project_id=project1_id, + ) + + # Re-index the edited note in project 1 + # BEFORE THE FIX: This would delete the note from project 2 as well! + await repo1.index_item(edited_search_row1) + + # Verify project 1 has the edited version + results1_after = await repo1.search(search_text="project 1 content EDITED") + assert len(results1_after) == 1 + assert results1_after[0].title == "Test Note in Project 1" + assert "EDITED" in results1_after[0].content_snippet + + # CRITICAL TEST: Verify project 2's note is still there (the bug would delete it) + results2_after = await repo2.search(search_text="project 2 content") + assert len(results2_after) == 1, "Project 2's note disappeared after editing project 1's note!" + assert results2_after[0].title == "Test Note in Project 2" + assert results2_after[0].project_id == project2_id + assert "original" in results2_after[0].content_snippet # Should still be original + + # Double-check: project 1 should not be able to see project 2's note + cross_search = await repo1.search(search_text="project 2 content") + assert len(cross_search) == 0 + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_index_item_updates_existing_record_same_project(): + """Test that index_item() correctly updates existing records within the same project.""" + from basic_memory import db + from basic_memory.models.base import Base + from basic_memory.repository.search_repository import SearchRepository + from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker + + # Create a separate in-memory database for this test + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + session_maker = async_sessionmaker(engine, expire_on_commit=False) + + # Create the database schema + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Create one project + async with db.scoped_session(session_maker) as session: + project = Project( + name="Test Project", + description="Test project", + path="/test/path", + is_active=True, + is_default=True + ) + session.add(project) + await session.flush() + project_id = project.id + await session.commit() + + # Create search repository + repo = SearchRepository(session_maker, project_id=project_id) + await repo.init_search_index() + + permalink = "test/my-note" + + # Create initial note + initial_row = SearchIndexRow( + id=1, + type=SearchItemType.ENTITY.value, + title="My Test Note", + content_stems="initial content here", + content_snippet="This is the initial content", + permalink=permalink, + file_path="test/my_note.md", + entity_id=1, + metadata={"entity_type": "note"}, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + project_id=project_id, + ) + + # Index the initial version + await repo.index_item(initial_row) + + # Verify it exists + results_initial = await repo.search(search_text="initial content") + assert len(results_initial) == 1 + assert results_initial[0].content_snippet == "This is the initial content" + + # Now update the note (simulate an edit) + updated_row = SearchIndexRow( + id=1, + type=SearchItemType.ENTITY.value, + title="My Test Note", + content_stems="updated content here", # Changed + content_snippet="This is the UPDATED content", # Changed + permalink=permalink, # Same permalink + file_path="test/my_note.md", + entity_id=1, + metadata={"entity_type": "note"}, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + project_id=project_id, + ) + + # Re-index (should replace the old version) + await repo.index_item(updated_row) + + # Verify the old version is gone + results_old = await repo.search(search_text="initial content") + assert len(results_old) == 0 + + # Verify the new version exists + results_new = await repo.search(search_text="updated content") + assert len(results_new) == 1 + assert results_new[0].content_snippet == "This is the UPDATED content" + + # Verify we only have one record (not duplicated) + all_results = await repo.search(search_text="My Test Note") + assert len(all_results) == 1 + + await engine.dispose() \ No newline at end of file