Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 160 additions & 28 deletions aperag/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,42 @@

import logging
import os
from typing import Any, Dict
from typing import Any, Dict, Optional

import httpx
from fastmcp import FastMCP
from fastmcp.server.dependencies import get_http_headers

from aperag.domains.knowledge_base.schemas import CollectionViewList

# Import view models for type safety
from aperag.domains.retrieval.schemas import SearchResult
from aperag.domains.web_access.schemas import WebReadResponse, WebSearchResponse
from aperag.mcp.tools import (
ByteRange,
)
from aperag.mcp.tools import (
get_collection_metadata as _d10c_get_collection_metadata,
)
from aperag.mcp.tools import (
get_document_metadata as _d10c_get_document_metadata,
)
from aperag.mcp.tools import (
list_collections as _d10c_list_collections,
)
from aperag.mcp.tools import (
list_documents as _d10c_list_documents,
)
from aperag.mcp.tools import (
read_document as _d10c_read_document,
)
from aperag.mcp.tools import (
read_document_chunk as _d10c_read_document_chunk,
)
from aperag.mcp.tools import (
read_document_outline as _d10c_read_document_outline,
)
from aperag.mcp.tools import (
read_document_section as _d10c_read_document_section,
)

logger = logging.getLogger(__name__)

Expand All @@ -35,8 +60,32 @@
API_BASE_URL = "http://localhost:8000"


# === D10.c read primitives ===
#
# Per docs/modularization/d10-design-pack.md §A — 8 read primitives that
# replace the legacy HTTP-delegated list_collections + add 7 net-new
# tools (list_documents / get_collection_metadata / get_document_metadata
# / read_document / read_document_outline / read_document_section /
# read_document_chunk).
#
# Each primitive enforces (in order, never cache-shortcut per §E.7):
# 1. tenancy gate (D9 base canonical SoT — db_ops.query_collection)
# 2. 3-level authorization (D9 §2 — tools/authorization.py)
# 3. parse_version computation (only the 4 parse-version-keyed primitives)
# 4. authoritative fetch (un-cached; D10.g #99 wires cache around this)
#
# chenyexuan's D10.d (#96) split-search registrations land adjacent to
# this block — append below the marker, no merge churn expected.


@mcp_server.tool
async def list_collections() -> Dict[str, Any]:
async def list_collections(
cursor: Optional[str] = None,
limit: int = 50,
sort_by: str = "created_at",
sort_order: str = "desc",
title_filter: Optional[str] = None,
) -> Dict[str, Any]:
"""Discover which knowledge bases the current user can access.

Use this when:
Expand All @@ -63,32 +112,115 @@ async def list_collections() -> Dict[str, Any]:
- After completion: "Checked which knowledge bases are available."

Returns:
List of collections with only essential information (id, title, description)
for secure and efficient LLM use.
Paginated CollectionList envelope per D10 §A.1 (items + next_cursor + total_count).
"""
result = await _d10c_list_collections(
cursor=cursor,
limit=limit,
sort_by=sort_by, # type: ignore[arg-type]
sort_order=sort_order, # type: ignore[arg-type]
title_filter=title_filter,
)
return result.model_dump()

Note:
Uses CollectionViewList view model for type-safe response parsing but filters
sensitive and unnecessary information.

@mcp_server.tool
async def list_documents(
collection_id: str,
cursor: Optional[str] = None,
limit: int = 50,
sort_by: str = "created_at",
sort_order: str = "desc",
title_filter: Optional[str] = None,
type_filter: Optional[list[str]] = None,
indexed_only: bool = False,
) -> Dict[str, Any]:
"""List documents within a collection. D10 §A.2."""
result = await _d10c_list_documents(
collection_id,
cursor=cursor,
limit=limit,
sort_by=sort_by, # type: ignore[arg-type]
sort_order=sort_order, # type: ignore[arg-type]
title_filter=title_filter,
type_filter=type_filter,
indexed_only=indexed_only,
)
return result.model_dump()


@mcp_server.tool
async def get_collection_metadata(collection_id: str) -> Dict[str, Any]:
"""Get full metadata for a specific collection. D10 §A.4."""
result = await _d10c_get_collection_metadata(collection_id)
return result.model_dump()


@mcp_server.tool
async def get_document_metadata(collection_id: str, document_id: str) -> Dict[str, Any]:
"""Get metadata for a specific document. D10 §A.3."""
result = await _d10c_get_document_metadata(collection_id, document_id)
return result.model_dump()


@mcp_server.tool
async def read_document(
collection_id: str,
document_id: str,
range_start: Optional[int] = None,
range_end: Optional[int] = None,
) -> Dict[str, Any]:
"""Read parsed markdown content of a document. D10 §A.5.

Optional byte range is best-effort and NOT stable across re-parse.
"""
try:
api_key = get_api_key()
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(
f"{API_BASE_URL}/api/v2/collections", headers={"Authorization": f"Bearer {api_key}"}
)
if response.status_code == 200:
try:
# Parse response using view model for type safety
collection_list = CollectionViewList.model_validate(response.json())
# Return the modified object using model_dump()
return collection_list.model_dump()
except Exception as e:
logger.error(f"Failed to parse collections response: {e}")
return {"error": "Failed to parse collections response", "details": str(e)}
else:
return {"error": f"Failed to fetch collections: {response.status_code}", "details": response.text}
except ValueError as e:
return {"error": str(e)}
byte_range: Optional[ByteRange] = None
if range_start is not None and range_end is not None:
byte_range = ByteRange(start=range_start, end=range_end)
result = await _d10c_read_document(collection_id, document_id, range=byte_range)
return result.model_dump()


@mcp_server.tool
async def read_document_outline(
collection_id: str,
document_id: str,
max_depth: int = 6,
) -> Dict[str, Any]:
"""Read the heading tree (table of contents) of a document. D10 §A.6."""
result = await _d10c_read_document_outline(collection_id, document_id, max_depth=max_depth)
return result.model_dump()


@mcp_server.tool
async def read_document_section(
collection_id: str,
document_id: str,
section_path: Optional[str] = None,
heading_anchor: Optional[str] = None,
) -> Dict[str, Any]:
"""Read a section by section_path (preferred) or heading_anchor. D10 §A.7."""
result = await _d10c_read_document_section(
collection_id,
document_id,
section_path=section_path,
heading_anchor=heading_anchor,
)
return result.model_dump()


@mcp_server.tool
async def read_document_chunk(
collection_id: str,
document_id: str,
chunk_id: str,
) -> Dict[str, Any]:
"""Read a chunk by stable chunk_id. D10 §A.8."""
result = await _d10c_read_document_chunk(collection_id, document_id, chunk_id)
return result.model_dump()


# === end D10.c read primitives ===


@mcp_server.tool
Expand Down
4 changes: 4 additions & 0 deletions aperag/mcp/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from aperag.mcp.tools.handles import ChunkId, HeadingAnchor, SectionPath
from aperag.mcp.tools.list_collections import list_collections
from aperag.mcp.tools.list_documents import list_documents
from aperag.mcp.tools.parse_version import ParseVersionT, compute_parse_version
from aperag.mcp.tools.read_document import read_document
from aperag.mcp.tools.read_document_chunk import read_document_chunk
from aperag.mcp.tools.read_document_outline import read_document_outline
Expand All @@ -53,6 +54,9 @@
"ChunkId",
"HeadingAnchor",
"SectionPath",
# parse_version helpers (§E.2)
"ParseVersionT",
"compute_parse_version",
# Read primitive functions (per §A.1 - §A.8)
"list_collections",
"list_documents",
Expand Down
Loading
Loading