Skip to content

Commit 272f533

Browse files
committed
Refactor search API to consistently handle types and entity_types
Updated the naming conventions and handling of `types` and `entity_types` across the codebase for consistency. Modified function definitions, queries, and filters to ensure clear separation and proper usage. Adjusted and added tests to align with the new structure.
1 parent 2006528 commit 272f533

16 files changed

Lines changed: 183 additions & 79 deletions

File tree

src/basic_memory/cli/commands/tool.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,29 @@
22

33
import asyncio
44
import sys
5-
from typing import Optional, List, Annotated
5+
from typing import Annotated, List, Optional
66

77
import typer
88
from loguru import logger
99
from rich import print as rprint
1010

1111
from basic_memory.cli.app import app
12-
from basic_memory.mcp.tools import build_context as mcp_build_context
13-
from basic_memory.mcp.tools import read_note as mcp_read_note
14-
from basic_memory.mcp.tools import recent_activity as mcp_recent_activity
15-
from basic_memory.mcp.tools import search_notes as mcp_search
16-
from basic_memory.mcp.tools import write_note as mcp_write_note
1712

1813
# Import prompts
1914
from basic_memory.mcp.prompts.continue_conversation import (
2015
continue_conversation as mcp_continue_conversation,
2116
)
22-
2317
from basic_memory.mcp.prompts.recent_activity import (
2418
recent_activity_prompt as recent_activity_prompt,
2519
)
26-
20+
from basic_memory.mcp.tools import build_context as mcp_build_context
21+
from basic_memory.mcp.tools import read_note as mcp_read_note
22+
from basic_memory.mcp.tools import recent_activity as mcp_recent_activity
23+
from basic_memory.mcp.tools import search_notes as mcp_search
24+
from basic_memory.mcp.tools import write_note as mcp_write_note
2725
from basic_memory.schemas.base import TimeFrame
2826
from basic_memory.schemas.memory import MemoryUrl
29-
from basic_memory.schemas.search import SearchQuery, SearchItemType
27+
from basic_memory.schemas.search import SearchItemType
3028

3129
tool_app = typer.Typer()
3230
app.add_typer(tool_app, name="tool", help="Access to MCP tools via CLI")
@@ -198,13 +196,28 @@ def search_notes(
198196
raise typer.Abort()
199197

200198
try:
201-
search_query = SearchQuery(
202-
permalink_match=query if permalink else None,
203-
text=query if not (permalink or title) else None,
204-
title=query if title else None,
205-
after_date=after_date,
199+
if permalink and title: # pragma: no cover
200+
typer.echo(
201+
"Use either --permalink or --title, not both. Exiting.",
202+
err=True,
203+
)
204+
raise typer.Exit(1)
205+
206+
# set search type
207+
search_type = ("permalink" if permalink else None,)
208+
search_type = ("permalink_match" if permalink and "*" in query else None,)
209+
search_type = ("title" if title else None,)
210+
search_type = "text" if search_type is None else search_type
211+
212+
results = asyncio.run(
213+
mcp_search(
214+
query,
215+
search_type=search_type,
216+
page=page,
217+
after_date=after_date,
218+
page_size=page_size,
219+
)
206220
)
207-
results = asyncio.run(mcp_search(query=search_query, page=page, page_size=page_size))
208221
# Use json module for more controlled serialization
209222
import json
210223

src/basic_memory/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def get_process_name(): # pragma: no cover
234234
# Global flag to track if logging has been set up
235235
_LOGGING_SETUP = False
236236

237+
237238
def setup_basic_memory_logging(): # pragma: no cover
238239
"""Set up logging for basic-memory, ensuring it only happens once."""
239240
global _LOGGING_SETUP

src/basic_memory/mcp/prompts/continue_conversation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55
"""
66

77
from textwrap import dedent
8-
from typing import Optional, Annotated
8+
from typing import Annotated, Optional
99

1010
from loguru import logger
1111
from pydantic import Field
1212

13-
from basic_memory.mcp.prompts.utils import format_prompt_context, PromptContext, PromptContextItem
13+
from basic_memory.mcp.prompts.utils import PromptContext, PromptContextItem, format_prompt_context
1414
from basic_memory.mcp.server import mcp
1515
from basic_memory.mcp.tools.build_context import build_context
1616
from basic_memory.mcp.tools.recent_activity import recent_activity
1717
from basic_memory.mcp.tools.search import search_notes
1818
from basic_memory.schemas.base import TimeFrame
1919
from basic_memory.schemas.memory import GraphContext
20-
from basic_memory.schemas.search import SearchQuery, SearchItemType
20+
from basic_memory.schemas.search import SearchItemType
2121

2222

2323
@mcp.prompt(
@@ -48,7 +48,7 @@ async def continue_conversation(
4848
# If topic provided, search for it
4949
if topic:
5050
search_results = await search_notes(
51-
SearchQuery(text=topic, after_date=timeframe, types=[SearchItemType.ENTITY])
51+
query=topic, after_date=timeframe, entity_types=[SearchItemType.ENTITY]
5252
)
5353

5454
# Build context from results

src/basic_memory/mcp/prompts/search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from basic_memory.mcp.server import mcp
1313
from basic_memory.mcp.tools.search import search_notes as search_tool
1414
from basic_memory.schemas.base import TimeFrame
15-
from basic_memory.schemas.search import SearchQuery, SearchResponse
15+
from basic_memory.schemas.search import SearchResponse
1616

1717

1818
@mcp.prompt(

src/basic_memory/mcp/tools/read_note.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from basic_memory.mcp.tools.search import search_notes
1010
from basic_memory.mcp.tools.utils import call_get
1111
from basic_memory.schemas.memory import memory_url_path
12-
from basic_memory.schemas.search import SearchQuery
1312

1413

1514
@mcp.tool(

src/basic_memory/mcp/tools/search.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
"""Search tools for Basic Memory MCP server."""
22

3-
from typing import Optional, List
3+
from typing import List, Optional
4+
45
from loguru import logger
56

7+
from basic_memory.mcp.async_client import client
68
from basic_memory.mcp.server import mcp
79
from basic_memory.mcp.tools.utils import call_post
8-
from basic_memory.schemas.search import SearchQuery, SearchResponse, SearchItemType
9-
from basic_memory.mcp.async_client import client
10+
from basic_memory.schemas.search import SearchItemType, SearchQuery, SearchResponse
1011

1112

1213
@mcp.tool(
1314
description="Search across all content in the knowledge base.",
1415
)
1516
async def search_notes(
16-
query: str,
17-
page: int = 1,
17+
query: str,
18+
page: int = 1,
1819
page_size: int = 10,
1920
search_type: str = "text",
2021
types: Optional[List[str]] = None,
@@ -31,9 +32,9 @@ async def search_notes(
3132
query: The search query string
3233
page: The page number of results to return (default 1)
3334
page_size: The number of results to return per page (default 10)
34-
search_type: Type of search to perform, one of: "text", "title", "permalink", "permalink_match" (default: "text")
35-
types: Optional list of content types to search (e.g., ["entity", "observation"])
36-
entity_types: Optional list of entity types to filter by (e.g., ["note", "person"])
35+
search_type: Type of search to perform, one of: "text", "title", "permalink" (default: "text")
36+
types: Optional list of note types to search (e.g., ["note", "person"])
37+
entity_types: Optional list of entity types to filter by (e.g., ["entity", "observation"])
3738
after_date: Optional date filter for recent content (e.g., "1 week", "2d")
3839
3940
Returns:
@@ -61,6 +62,12 @@ async def search_notes(
6162
types=["entity"],
6263
)
6364
65+
# Search with entity type filter, e.g., note vs
66+
results = await search_notes(
67+
query="meeting notes",
68+
types=["entity"],
69+
)
70+
6471
# Search for recent content
6572
results = await search_notes(
6673
query="bug report",
@@ -70,32 +77,32 @@ async def search_notes(
7077
# Pattern matching on permalinks
7178
results = await search_notes(
7279
query="docs/meeting-*",
73-
search_type="permalink_match"
80+
search_type="permalink"
7481
)
7582
"""
7683
# Create a SearchQuery object based on the parameters
7784
search_query = SearchQuery()
78-
85+
7986
# Set the appropriate search field based on search_type
8087
if search_type == "text":
8188
search_query.text = query
8289
elif search_type == "title":
8390
search_query.title = query
91+
elif search_type == "permalink" and "*" in query:
92+
search_query.permalink_match = query
8493
elif search_type == "permalink":
8594
search_query.permalink = query
86-
elif search_type == "permalink_match":
87-
search_query.permalink_match = query
8895
else:
8996
search_query.text = query # Default to text search
90-
97+
9198
# Add optional filters if provided
92-
if types:
93-
search_query.types = [SearchItemType(t) for t in types]
9499
if entity_types:
95-
search_query.entity_types = entity_types
100+
search_query.entity_types = [SearchItemType(t) for t in entity_types]
101+
if types:
102+
search_query.types = types
96103
if after_date:
97104
search_query.after_date = after_date
98-
105+
99106
logger.info(f"Searching for {search_query}")
100107
response = await call_post(
101108
client,

src/basic_memory/repository/search_repository.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import time
55
from dataclasses import dataclass
66
from datetime import datetime
7-
from typing import List, Optional, Any, Dict
7+
from typing import Any, Dict, List, Optional
88

99
from loguru import logger
10-
from sqlalchemy import text, Executable, Result
10+
from sqlalchemy import Executable, Result, text
1111
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
1212

1313
from basic_memory import db
@@ -123,9 +123,9 @@ async def search(
123123
permalink: Optional[str] = None,
124124
permalink_match: Optional[str] = None,
125125
title: Optional[str] = None,
126-
types: Optional[List[SearchItemType]] = None,
126+
types: Optional[List[str]] = None,
127127
after_date: Optional[datetime] = None,
128-
entity_types: Optional[List[str]] = None,
128+
entity_types: Optional[List[SearchItemType]] = None,
129129
limit: int = 10,
130130
offset: int = 0,
131131
) -> List[SearchIndexRow]:
@@ -174,15 +174,15 @@ async def search(
174174
else:
175175
conditions.append("permalink MATCH :permalink")
176176

177-
# Handle type filter
178-
if types:
179-
type_list = ", ".join(f"'{t.value}'" for t in types)
180-
conditions.append(f"type IN ({type_list})")
181-
182177
# Handle entity type filter
183178
if entity_types:
184-
entity_type_list = ", ".join(f"'{t}'" for t in entity_types)
185-
conditions.append(f"json_extract(metadata, '$.entity_type') IN ({entity_type_list})")
179+
type_list = ", ".join(f"'{t.value}'" for t in entity_types)
180+
conditions.append(f"type IN ({type_list})")
181+
182+
# Handle type filter
183+
if types:
184+
type_list = ", ".join(f"'{t}'" for t in types)
185+
conditions.append(f"json_extract(metadata, '$.entity_type') IN ({type_list})")
186186

187187
# Handle date filter using datetime() for proper comparison
188188
if after_date:

src/basic_memory/schemas/search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ class SearchQuery(BaseModel):
4949
title: Optional[str] = None # title only search
5050

5151
# Optional filters
52-
types: Optional[List[SearchItemType]] = None # Filter by item type
53-
entity_types: Optional[List[str]] = None # Filter by entity type
52+
types: Optional[List[str]] = None # Filter by type
53+
entity_types: Optional[List[SearchItemType]] = None # Filter by entity type
5454
after_date: Optional[Union[datetime, str]] = None # Time-based filter
5555

5656
@field_validator("after_date")

src/basic_memory/services/context_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ async def build_context(
8181
else:
8282
logger.debug(f"Build context for '{types}'")
8383
primary = await self.search_repository.search(
84-
types=types, after_date=since, limit=limit, offset=offset
84+
entity_types=types, after_date=since, limit=limit, offset=offset
8585
)
8686

8787
# Get type_id pairs for traversal

src/basic_memory/services/link_resolver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ async def resolve_link(self, link_text: str, use_search: bool = True) -> Optiona
4848

4949
# 3. Try file path
5050
found_path = await self.entity_repository.get_by_file_path(clean_text)
51-
if found_path :
51+
if found_path:
5252
logger.debug(f"Found entity with path: {found_path.file_path}")
5353
return found_path
54-
54+
5555
# search if indicated
5656
if use_search and "*" not in clean_text:
5757
# 3. Fall back to search for fuzzy matching on title
5858
results = await self.search_service.search(
59-
query=SearchQuery(title=clean_text, types=[SearchItemType.ENTITY]),
59+
query=SearchQuery(title=clean_text, entity_types=[SearchItemType.ENTITY]),
6060
)
6161

6262
if results:

0 commit comments

Comments
 (0)