diff --git a/.gitignore b/.gitignore index 23d6b5655..54edbf9e6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,7 @@ __pycache__ .env* .venv/ logs/ +pageindex.egg-info/ +*.db +venv/ +uv.lock diff --git a/examples/cloud_demo.py b/examples/cloud_demo.py new file mode 100644 index 000000000..cd3344b40 --- /dev/null +++ b/examples/cloud_demo.py @@ -0,0 +1,62 @@ +""" +Agentic Vectorless RAG with PageIndex SDK - Cloud Demo + +Uses CloudClient for fully-managed document indexing and QA. +No LLM API key needed — the cloud service handles everything. + +Steps: + 1 — Upload and index a PDF via PageIndex cloud + 2 — Stream a question with tool call visibility + +Requirements: + pip install pageindex + export PAGEINDEX_API_KEY=your-api-key +""" +import asyncio +import os +from pathlib import Path +import requests +from pageindex import CloudClient + +_EXAMPLES_DIR = Path(__file__).parent +PDF_URL = "https://arxiv.org/pdf/1706.03762.pdf" +PDF_PATH = _EXAMPLES_DIR / "documents" / "attention.pdf" + +# Download PDF if needed +if not PDF_PATH.exists(): + print(f"Downloading {PDF_URL} ...") + PDF_PATH.parent.mkdir(parents=True, exist_ok=True) + with requests.get(PDF_URL, stream=True, timeout=30) as r: + r.raise_for_status() + with open(PDF_PATH, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + print("Download complete.\n") + +client = CloudClient(api_key=os.environ["PAGEINDEX_API_KEY"]) +col = client.collection() + +doc_id = col.add(str(PDF_PATH)) +print(f"Indexed: {doc_id}\n") + +# Streaming query +stream = col.query("What is the main contribution of this paper?", stream=True) + +async def main(): + streamed_text = False + async for event in stream: + if event.type == "answer_delta": + print(event.data, end="", flush=True) + streamed_text = True + elif event.type == "tool_call": + if streamed_text: + print() + streamed_text = False + args = event.data.get("args", "") + print(f"[tool call] {event.data['name']}({args})") + elif event.type 
== "answer_done": + print() + streamed_text = False + +asyncio.run(main()) diff --git a/examples/local_demo.py b/examples/local_demo.py new file mode 100644 index 000000000..f98d25d69 --- /dev/null +++ b/examples/local_demo.py @@ -0,0 +1,69 @@ +""" +Agentic Vectorless RAG with PageIndex SDK - Local Demo + +A simple example of using LocalClient for self-hosted document indexing +and agent-based QA. The agent uses OpenAI Agents SDK to reason over +the document's tree structure index. + +Steps: + 1 — Download and index a PDF + 2 — Stream a question with tool call visibility + +Requirements: + pip install pageindex + export OPENAI_API_KEY=your-api-key # or any LiteLLM-supported provider +""" +import asyncio +from pathlib import Path +import requests +from pageindex import LocalClient + +_EXAMPLES_DIR = Path(__file__).parent +PDF_URL = "https://arxiv.org/pdf/1706.03762.pdf" +PDF_PATH = _EXAMPLES_DIR / "documents" / "attention.pdf" +WORKSPACE = _EXAMPLES_DIR / "workspace" +MODEL = "gpt-4o-2024-11-20" # any LiteLLM-supported model + +# Download PDF if needed +if not PDF_PATH.exists(): + print(f"Downloading {PDF_URL} ...") + PDF_PATH.parent.mkdir(parents=True, exist_ok=True) + with requests.get(PDF_URL, stream=True, timeout=30) as r: + r.raise_for_status() + with open(PDF_PATH, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + print("Download complete.\n") + +client = LocalClient(model=MODEL, storage_path=str(WORKSPACE)) +col = client.collection() + +doc_id = col.add(str(PDF_PATH)) +print(f"Indexed: {doc_id}\n") + +# Streaming query +stream = col.query( + "What is the main architecture proposed in this paper and how does self-attention work?", + stream=True, +) + +async def main(): + streamed_text = False + async for event in stream: + if event.type == "answer_delta": + print(event.data, end="", flush=True) + streamed_text = True + elif event.type == "tool_call": + if streamed_text: + print() + streamed_text = False + print(f"[tool 
call] {event.data['name']}") + elif event.type == "tool_result": + preview = str(event.data)[:200] + "..." if len(str(event.data)) > 200 else event.data + print(f"[tool output] {preview}") + elif event.type == "answer_done": + print() + streamed_text = False + +asyncio.run(main()) diff --git a/pageindex/__init__.py b/pageindex/__init__.py index 658003bf5..64464418f 100644 --- a/pageindex/__init__.py +++ b/pageindex/__init__.py @@ -1,4 +1,40 @@ +# pageindex/__init__.py +# Upstream exports (backward compatibility) from .page_index import * from .page_index_md import md_to_tree from .retrieve import get_document, get_document_structure, get_page_content -from .client import PageIndexClient + +# SDK exports +from .client import PageIndexClient, LocalClient, CloudClient +from .config import IndexConfig +from .collection import Collection +from .parser.protocol import ContentNode, ParsedDocument, DocumentParser +from .storage.protocol import StorageEngine +from .events import QueryEvent +from .errors import ( + PageIndexError, + CollectionNotFoundError, + DocumentNotFoundError, + IndexingError, + CloudAPIError, + FileTypeError, +) + +__all__ = [ + "PageIndexClient", + "LocalClient", + "CloudClient", + "IndexConfig", + "Collection", + "ContentNode", + "ParsedDocument", + "DocumentParser", + "StorageEngine", + "QueryEvent", + "PageIndexError", + "CollectionNotFoundError", + "DocumentNotFoundError", + "IndexingError", + "CloudAPIError", + "FileTypeError", +] diff --git a/pageindex/agent.py b/pageindex/agent.py new file mode 100644 index 000000000..9ee7b9387 --- /dev/null +++ b/pageindex/agent.py @@ -0,0 +1,93 @@ +# pageindex/agent.py +from __future__ import annotations +from typing import AsyncIterator +from .events import QueryEvent +from .backend.protocol import AgentTools + + +SYSTEM_PROMPT = """ +You are PageIndex, a document QA assistant. +TOOL USE: +- Call list_documents() to see available documents. +- Call get_document(doc_id) to confirm status and page/line count. 
+- Call get_document_structure(doc_id) to identify relevant page ranges. +- Call get_page_content(doc_id, pages="5-7") with tight ranges; never fetch the whole document. +- Before each tool call, output one short sentence explaining the reason. +IMAGES: +- Page content may contain image references like ![image](path). Always preserve these in your answer so the downstream UI can render them. +- Place images near the relevant context in your answer. +Answer based only on tool output. Be concise. +""" + + +class QueryStream: + """Streaming query result, similar to OpenAI's RunResultStreaming. + + Usage: + stream = col.query("question", stream=True) + async for event in stream: + if event.type == "answer_delta": + print(event.data, end="", flush=True) + """ + + def __init__(self, tools: AgentTools, question: str, model: str = None): + from agents import Agent + from agents.model_settings import ModelSettings + self._agent = Agent( + name="PageIndex", + instructions=SYSTEM_PROMPT, + tools=tools.function_tools, + mcp_servers=tools.mcp_servers, + model=model, + model_settings=ModelSettings(parallel_tool_calls=False), + ) + self._question = question + + async def stream_events(self) -> AsyncIterator[QueryEvent]: + """Async generator yielding QueryEvent as they arrive.""" + from agents import Runner, ItemHelpers + from agents.stream_events import RawResponsesStreamEvent, RunItemStreamEvent + from openai.types.responses import ResponseTextDeltaEvent + + streamed_run = Runner.run_streamed(self._agent, self._question) + async for event in streamed_run.stream_events(): + if isinstance(event, RawResponsesStreamEvent): + if isinstance(event.data, ResponseTextDeltaEvent): + yield QueryEvent(type="answer_delta", data=event.data.delta) + elif isinstance(event, RunItemStreamEvent): + item = event.item + if item.type == "tool_call_item": + raw = item.raw_item + yield QueryEvent(type="tool_call", data={ + "name": raw.name, "args": getattr(raw, "arguments", "{}"), + }) + elif item.type 
== "tool_call_output_item": + yield QueryEvent(type="tool_result", data=str(item.output)) + elif item.type == "message_output_item": + text = ItemHelpers.text_message_output(item) + if text: + yield QueryEvent(type="answer_done", data=text) + + def __aiter__(self): + return self.stream_events() + + +class AgentRunner: + def __init__(self, tools: AgentTools, model: str = None): + self._tools = tools + self._model = model + + def run(self, question: str) -> str: + """Sync non-streaming query. Returns answer string.""" + from agents import Agent, Runner + from agents.model_settings import ModelSettings + agent = Agent( + name="PageIndex", + instructions=SYSTEM_PROMPT, + tools=self._tools.function_tools, + mcp_servers=self._tools.mcp_servers, + model=self._model, + model_settings=ModelSettings(parallel_tool_calls=False), + ) + result = Runner.run_sync(agent, question) + return result.final_output diff --git a/pageindex/backend/__init__.py b/pageindex/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pageindex/backend/cloud.py b/pageindex/backend/cloud.py new file mode 100644 index 000000000..587a6c65b --- /dev/null +++ b/pageindex/backend/cloud.py @@ -0,0 +1,352 @@ +# pageindex/backend/cloud.py +"""CloudBackend — connects to PageIndex cloud service (api.pageindex.ai). 
+ +API reference: https://github.com/VectifyAI/pageindex_sdk +""" +from __future__ import annotations +import json +import logging +import os +import re +import time +import urllib.parse +import requests +from typing import AsyncIterator + +from .protocol import AgentTools +from ..errors import CloudAPIError, PageIndexError +from ..events import QueryEvent + +logger = logging.getLogger(__name__) + +API_BASE = "https://api.pageindex.ai" + +_INTERNAL_TOOLS = frozenset({"ToolSearch", "Read", "Grep", "Glob", "Bash", "Edit", "Write"}) + + +class CloudBackend: + def __init__(self, api_key: str): + self._api_key = api_key + self._headers = {"api_key": api_key} + self._folder_id_cache: dict[str, str | None] = {} + self._folder_warning_shown = False + + # ── HTTP helpers ────────────────────────────────────────────────────── + + def _warn_folder_upgrade(self) -> None: + if not self._folder_warning_shown: + logger.warning( + "Folders (collections) require a Max plan. " + "All documents are stored in a single global space — collection names are ignored. 
" + "Upgrade at https://dash.pageindex.ai/subscription" + ) + self._folder_warning_shown = True + + def _request(self, method: str, path: str, **kwargs) -> dict: + url = f"{API_BASE}{path}" + for attempt in range(3): + try: + resp = requests.request(method, url, headers=self._headers, timeout=30, **kwargs) + if resp.status_code in (429, 500, 502, 503): + logger.warning("Cloud API %s %s returned %d, retrying...", method, path, resp.status_code) + time.sleep(2 ** attempt) + continue + if resp.status_code != 200: + body = resp.text[:500] if resp.text else "" + raise CloudAPIError(f"Cloud API error {resp.status_code}: {body}") + return resp.json() if resp.content else {} + except requests.RequestException as e: + if attempt == 2: + raise CloudAPIError(f"Cloud API request failed: {e}") from e + time.sleep(2 ** attempt) + raise CloudAPIError("Max retries exceeded") + + @staticmethod + def _validate_collection_name(name: str) -> None: + if not re.match(r'^[a-zA-Z0-9_-]{1,128}$', name): + raise PageIndexError( + f"Invalid collection name: {name!r}. " + "Must be 1-128 chars of [a-zA-Z0-9_-]." 
+ ) + + @staticmethod + def _enc(value: str) -> str: + return urllib.parse.quote(value, safe="") + + # ── Collection management (mapped to folders) ───────────────────────── + + def create_collection(self, name: str) -> None: + self._validate_collection_name(name) + try: + resp = self._request("POST", "/folder/", json={"name": name}) + self._folder_id_cache[name] = resp.get("folder", {}).get("id") + except CloudAPIError as e: + if "403" in str(e): + self._warn_folder_upgrade() + self._folder_id_cache[name] = None + else: + raise + + def get_or_create_collection(self, name: str) -> None: + self._validate_collection_name(name) + try: + data = self._request("GET", "/folders/") + for folder in data.get("folders", []): + if folder.get("name") == name: + self._folder_id_cache[name] = folder["id"] + return + resp = self._request("POST", "/folder/", json={"name": name}) + self._folder_id_cache[name] = resp.get("folder", {}).get("id") + except CloudAPIError as e: + if "403" in str(e): + self._warn_folder_upgrade() + self._folder_id_cache[name] = None + else: + raise + + def _get_folder_id(self, name: str) -> str | None: + """Resolve collection name to folder ID. 
Returns None if folders not available.""" + if name in self._folder_id_cache: + return self._folder_id_cache.get(name) + try: + data = self._request("GET", "/folders/") + for folder in data.get("folders", []): + if folder.get("name") == name: + self._folder_id_cache[name] = folder["id"] + return folder["id"] + except CloudAPIError: + pass + self._folder_id_cache[name] = None + return None + + def list_collections(self) -> list[str]: + data = self._request("GET", "/folders/") + return [f["name"] for f in data.get("folders", [])] + + def delete_collection(self, name: str) -> None: + folder_id = self._get_folder_id(name) + if folder_id: + self._request("DELETE", f"/folder/{self._enc(folder_id)}/") + + # ── Document management ─────────────────────────────────────────────── + + def add_document(self, collection: str, file_path: str) -> str: + folder_id = self._get_folder_id(collection) + data = {"if_retrieval": "true"} + if folder_id: + data["folder_id"] = folder_id + + with open(file_path, "rb") as f: + resp = self._request("POST", "/doc/", files={"file": f}, data=data) + + doc_id = resp["doc_id"] + + # Poll until retrieval-ready + for _ in range(120): # 10 min max + tree_resp = self._request("GET", f"/doc/{self._enc(doc_id)}/", params={"type": "tree"}) + if tree_resp.get("retrieval_ready"): + return doc_id + status = tree_resp.get("status", "") + if status == "failed": + raise CloudAPIError(f"Document {doc_id} indexing failed") + time.sleep(5) + + raise CloudAPIError(f"Document {doc_id} indexing timed out") + + def get_document(self, collection: str, doc_id: str, include_text: bool = False) -> dict: + resp = self._request("GET", f"/doc/{self._enc(doc_id)}/metadata/") + # Fetch structure in the same call via tree endpoint + tree_resp = self._request("GET", f"/doc/{self._enc(doc_id)}/", + params={"type": "tree", "summary": "true"}) + raw_tree = tree_resp.get("tree", tree_resp.get("structure", tree_resp.get("result", []))) + return { + "doc_id": resp.get("id", doc_id), 
+ "doc_name": resp.get("name", ""), + "doc_description": resp.get("description", ""), + "doc_type": "pdf", + "status": resp.get("status", ""), + "structure": self._normalize_tree(raw_tree), + } + + def get_document_structure(self, collection: str, doc_id: str) -> list: + resp = self._request("GET", f"/doc/{self._enc(doc_id)}/", params={"type": "tree", "summary": "true"}) + raw_tree = resp.get("tree", resp.get("structure", resp.get("result", []))) + return self._normalize_tree(raw_tree) + + def get_page_content(self, collection: str, doc_id: str, pages: str) -> list: + resp = self._request("GET", f"/doc/{self._enc(doc_id)}/", params={"type": "ocr", "format": "page"}) + # Filter to requested pages + from ..index.utils import parse_pages + page_nums = set(parse_pages(pages)) + all_pages = resp.get("pages", resp.get("ocr", resp.get("result", []))) + if isinstance(all_pages, list): + return [ + {"page": p.get("page", p.get("page_index")), + "content": p.get("content", p.get("markdown", ""))} + for p in all_pages + if p.get("page", p.get("page_index")) in page_nums + ] + return [] + + @staticmethod + def _normalize_tree(nodes: list) -> list: + """Normalize cloud tree nodes to match local schema.""" + result = [] + for node in nodes: + normalized = { + "title": node.get("title", ""), + "node_id": node.get("node_id", ""), + "summary": node.get("summary", node.get("prefix_summary", "")), + "start_index": node.get("start_index", node.get("page_index")), + "end_index": node.get("end_index", node.get("page_index")), + } + if "text" in node: + normalized["text"] = node["text"] + children = node.get("nodes", []) + if children: + normalized["nodes"] = CloudBackend._normalize_tree(children) + result.append(normalized) + return result + + def list_documents(self, collection: str) -> list[dict]: + folder_id = self._get_folder_id(collection) + params = {"limit": 100} + if folder_id: + params["folder_id"] = folder_id + data = self._request("GET", "/docs/", params=params) + return [ + 
{"doc_id": d.get("id", ""), "doc_name": d.get("name", ""), "doc_type": "pdf"} + for d in data.get("documents", []) + ] + + def delete_document(self, collection: str, doc_id: str) -> None: + self._request("DELETE", f"/doc/{self._enc(doc_id)}/") + + # ── Query (uses cloud chat/completions, no LLM key needed) ──────────── + + def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str: + """Non-streaming query via cloud chat/completions.""" + doc_id = doc_ids if doc_ids else self._get_all_doc_ids(collection) + resp = self._request("POST", "/chat/completions/", json={ + "messages": [{"role": "user", "content": question}], + "doc_id": doc_id, + "stream": False, + }) + # Extract answer from response + choices = resp.get("choices", []) + if choices: + return choices[0].get("message", {}).get("content", "") + return resp.get("content", resp.get("answer", "")) + + async def query_stream(self, collection: str, question: str, + doc_ids: list[str] | None = None) -> AsyncIterator[QueryEvent]: + """Streaming query via cloud chat/completions SSE. + + Events are yielded in real-time as they arrive from the server. + A background thread handles the blocking HTTP stream and pushes + events through an asyncio.Queue for true async streaming. 
+ """ + import asyncio + import threading + + doc_id = doc_ids if doc_ids else self._get_all_doc_ids(collection) + headers = self._headers + queue: asyncio.Queue[QueryEvent | None] = asyncio.Queue() + loop = asyncio.get_event_loop() + + def _stream(): + """Background thread: read SSE and push events to queue.""" + resp = requests.post( + f"{API_BASE}/chat/completions/", + headers=headers, + json={ + "messages": [{"role": "user", "content": question}], + "doc_id": doc_id, + "stream": True, + "stream_metadata": True, + }, + stream=True, + timeout=120, + ) + try: + if resp.status_code != 200: + body = resp.text[:500] if resp.text else "" + loop.call_soon_threadsafe( + queue.put_nowait, + QueryEvent(type="answer_done", + data=f"Cloud streaming error {resp.status_code}: {body}"), + ) + return + + current_tool_name = None + current_tool_args: list[str] = [] + + for line in resp.iter_lines(decode_unicode=True): + if not line or not line.startswith("data: "): + continue + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + try: + chunk = json.loads(data_str) + except json.JSONDecodeError: + continue + + meta = chunk.get("block_metadata", {}) + block_type = meta.get("type", "") + choices = chunk.get("choices", []) + delta = choices[0].get("delta", {}) if choices else {} + content = delta.get("content", "") + + if block_type == "mcp_tool_use_start": + current_tool_name = meta.get("tool_name", "") + current_tool_args = [] + + elif block_type == "tool_use": + if content: + current_tool_args.append(content) + + elif block_type == "tool_use_stop": + if current_tool_name and current_tool_name not in _INTERNAL_TOOLS: + args_str = "".join(current_tool_args) + loop.call_soon_threadsafe( + queue.put_nowait, + QueryEvent(type="tool_call", data={ + "name": current_tool_name, + "args": args_str, + }), + ) + current_tool_name = None + current_tool_args = [] + + elif block_type == "text" and content: + loop.call_soon_threadsafe( + queue.put_nowait, + 
QueryEvent(type="answer_delta", data=content), + ) + + finally: + resp.close() + loop.call_soon_threadsafe(queue.put_nowait, None) # sentinel + + thread = threading.Thread(target=_stream, daemon=True) + thread.start() + + while True: + event = await queue.get() + if event is None: + break + yield event + + thread.join(timeout=5) + + def _get_all_doc_ids(self, collection: str) -> list[str]: + """Get all document IDs in a collection.""" + docs = self.list_documents(collection) + return [d["doc_id"] for d in docs] + + # ── Not used in cloud mode ──────────────────────────────────────────── + + def get_agent_tools(self, collection: str, doc_ids: list[str] | None = None) -> AgentTools: + """Not used in cloud mode — query goes through chat/completions.""" + return AgentTools() diff --git a/pageindex/backend/local.py b/pageindex/backend/local.py new file mode 100644 index 000000000..ae2ac25f1 --- /dev/null +++ b/pageindex/backend/local.py @@ -0,0 +1,245 @@ +# pageindex/backend/local.py +import hashlib +import os +import re +import uuid +import shutil +from pathlib import Path + +from ..parser.protocol import DocumentParser, ParsedDocument +from ..parser.pdf import PdfParser +from ..parser.markdown import MarkdownParser +from ..storage.protocol import StorageEngine +from ..index.pipeline import build_index +from ..index.utils import parse_pages, get_pdf_page_content, get_md_page_content, remove_fields +from ..backend.protocol import AgentTools +from ..errors import FileTypeError, DocumentNotFoundError, IndexingError, PageIndexError + +_COLLECTION_NAME_RE = re.compile(r'^[a-zA-Z0-9_-]{1,128}$') + + +class LocalBackend: + def __init__(self, storage: StorageEngine, files_dir: str, model: str = None, + retrieve_model: str = None, index_config=None): + self._storage = storage + self._files_dir = Path(files_dir) + self._model = model + self._retrieve_model = retrieve_model or model + self._index_config = index_config + self._parsers: list[DocumentParser] = [PdfParser(), 
MarkdownParser()] + + def register_parser(self, parser: DocumentParser) -> None: + self._parsers.insert(0, parser) # user parsers checked first + + def get_retrieve_model(self) -> str | None: + return self._retrieve_model + + def _resolve_parser(self, file_path: str) -> DocumentParser: + ext = os.path.splitext(file_path)[1].lower() + for parser in self._parsers: + if ext in parser.supported_extensions(): + return parser + raise FileTypeError(f"No parser for extension: {ext}") + + # Collection management + def _validate_collection_name(self, name: str) -> None: + if not _COLLECTION_NAME_RE.match(name): + raise PageIndexError(f"Invalid collection name: {name!r}. Must be 1-128 chars of [a-zA-Z0-9_-].") + + def create_collection(self, name: str) -> None: + self._validate_collection_name(name) + self._storage.create_collection(name) + + def get_or_create_collection(self, name: str) -> None: + self._validate_collection_name(name) + self._storage.get_or_create_collection(name) + + def list_collections(self) -> list[str]: + return self._storage.list_collections() + + def delete_collection(self, name: str) -> None: + self._storage.delete_collection(name) + col_dir = self._files_dir / name + if col_dir.exists(): + shutil.rmtree(col_dir) + + @staticmethod + def _file_hash(file_path: str) -> str: + """Compute SHA-256 hash of a file.""" + h = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + h.update(chunk) + return h.hexdigest() + + # Document management + def add_document(self, collection: str, file_path: str) -> str: + file_path = os.path.realpath(file_path) + if not os.path.isfile(file_path): + raise FileTypeError(f"Not a regular file: {file_path}") + parser = self._resolve_parser(file_path) + + # Dedup: skip if same file already indexed in this collection + file_hash = self._file_hash(file_path) + existing_id = self._storage.find_document_by_hash(collection, file_hash) + if existing_id: + return existing_id + + doc_id = 
str(uuid.uuid4()) + + # Copy file to managed directory + ext = os.path.splitext(file_path)[1] + col_dir = self._files_dir / collection + col_dir.mkdir(parents=True, exist_ok=True) + managed_path = col_dir / f"{doc_id}{ext}" + shutil.copy2(file_path, managed_path) + + try: + # Store images alongside the document: files/{collection}/{doc_id}/images/ + images_dir = str(col_dir / doc_id / "images") + parsed = parser.parse(file_path, model=self._model, images_dir=images_dir) + result = build_index(parsed, model=self._model, opt=self._index_config) + + # Cache page text for fast retrieval (avoids re-reading files) + pages = [{"page": n.index, "content": n.content, + **({"images": n.images} if n.images else {})} + for n in parsed.nodes if n.content] + + # Strip text from structure to save storage space (PDF only; + # markdown needs text in structure for fallback retrieval) + doc_type = ext.lstrip(".") + if doc_type == "pdf": + clean_structure = remove_fields(result["structure"], fields=["text"]) + else: + clean_structure = result["structure"] + + self._storage.save_document(collection, doc_id, { + "doc_name": parsed.doc_name, + "doc_description": result.get("doc_description", ""), + "file_path": str(managed_path), + "file_hash": file_hash, + "doc_type": doc_type, + "structure": clean_structure, + "pages": pages, + }) + except Exception as e: + managed_path.unlink(missing_ok=True) + doc_dir = col_dir / doc_id + if doc_dir.exists(): + shutil.rmtree(doc_dir) + raise IndexingError(f"Failed to index {file_path}: {e}") from e + + return doc_id + + def get_document(self, collection: str, doc_id: str, include_text: bool = False) -> dict: + """Get document metadata with structure. + + Args: + include_text: If True, populate each structure node's 'text' field + from cached page content. WARNING: may be very large — do NOT + use in agent/LLM contexts as it can exhaust the context window. 
+ """ + doc = self._storage.get_document(collection, doc_id) + if not doc: + return {} + doc["structure"] = self._storage.get_document_structure(collection, doc_id) + if include_text: + pages = self._storage.get_pages(collection, doc_id) or [] + page_map = {p["page"]: p["content"] for p in pages} + self._fill_node_text(doc["structure"], page_map) + return doc + + @staticmethod + def _fill_node_text(nodes: list, page_map: dict) -> None: + """Recursively fill 'text' on structure nodes from cached page content.""" + for node in nodes: + start = node.get("start_index") + end = node.get("end_index") + if start is not None and end is not None: + node["text"] = "\n".join( + page_map.get(p, "") for p in range(start, end + 1) + ) + if "nodes" in node: + LocalBackend._fill_node_text(node["nodes"], page_map) + + def get_document_structure(self, collection: str, doc_id: str) -> list: + return self._storage.get_document_structure(collection, doc_id) + + def get_page_content(self, collection: str, doc_id: str, pages: str) -> list: + doc = self._storage.get_document(collection, doc_id) + if not doc: + raise DocumentNotFoundError(f"Document {doc_id} not found") + page_nums = parse_pages(pages) + + # Try cached pages first (fast, no file I/O) + cached_pages = self._storage.get_pages(collection, doc_id) + if cached_pages: + return [p for p in cached_pages if p["page"] in page_nums] + + # Fallback to reading from file + if doc["doc_type"] == "pdf": + return get_pdf_page_content(doc["file_path"], page_nums) + else: + structure = self._storage.get_document_structure(collection, doc_id) + return get_md_page_content(structure, page_nums) + + def list_documents(self, collection: str) -> list[dict]: + return self._storage.list_documents(collection) + + def delete_document(self, collection: str, doc_id: str) -> None: + doc = self._storage.get_document(collection, doc_id) + if doc and doc.get("file_path"): + Path(doc["file_path"]).unlink(missing_ok=True) + # Clean up images directory: 
files/{collection}/{doc_id}/ + doc_dir = self._files_dir / collection / doc_id + if doc_dir.exists(): + shutil.rmtree(doc_dir) + self._storage.delete_document(collection, doc_id) + + def get_agent_tools(self, collection: str, doc_ids: list[str] | None = None) -> AgentTools: + from agents import function_tool + import json + storage = self._storage + col_name = collection + backend = self + filter_ids = doc_ids + + @function_tool + def list_documents() -> str: + """List all documents in the collection.""" + docs = storage.list_documents(col_name) + if filter_ids: + docs = [d for d in docs if d["doc_id"] in filter_ids] + return json.dumps(docs) + + @function_tool + def get_document(doc_id: str) -> str: + """Get document metadata.""" + return json.dumps(storage.get_document(col_name, doc_id)) + + @function_tool + def get_document_structure(doc_id: str) -> str: + """Get document tree structure (without text).""" + structure = storage.get_document_structure(col_name, doc_id) + return json.dumps(remove_fields(structure, fields=["text"]), ensure_ascii=False) + + @function_tool + def get_page_content(doc_id: str, pages: str) -> str: + """Get page content. 
Use tight ranges: '5-7', '3,8', '12'.""" + result = backend.get_page_content(col_name, doc_id, pages) + return json.dumps(result, ensure_ascii=False) + + return AgentTools(function_tools=[list_documents, get_document, get_document_structure, get_page_content]) + + def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str: + from ..agent import AgentRunner + tools = self.get_agent_tools(collection, doc_ids) + return AgentRunner(tools=tools, model=self._retrieve_model).run(question) + + async def query_stream(self, collection: str, question: str, + doc_ids: list[str] | None = None): + from ..agent import QueryStream + tools = self.get_agent_tools(collection, doc_ids) + stream = QueryStream(tools=tools, question=question, model=self._retrieve_model) + async for event in stream: + yield event diff --git a/pageindex/backend/protocol.py b/pageindex/backend/protocol.py new file mode 100644 index 000000000..6e4c7a3c6 --- /dev/null +++ b/pageindex/backend/protocol.py @@ -0,0 +1,34 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Protocol, Any, AsyncIterator, runtime_checkable + +from ..events import QueryEvent + + +@dataclass +class AgentTools: + """Structured container for agent tool configuration (local mode only).""" + function_tools: list[Any] = field(default_factory=list) + mcp_servers: list[Any] = field(default_factory=list) + + +@runtime_checkable +class Backend(Protocol): + # Collection management + def create_collection(self, name: str) -> None: ... + def get_or_create_collection(self, name: str) -> None: ... + def list_collections(self) -> list[str]: ... + def delete_collection(self, name: str) -> None: ... + + # Document management + def add_document(self, collection: str, file_path: str) -> str: ... + def get_document(self, collection: str, doc_id: str, include_text: bool = False) -> dict: ... + def get_document_structure(self, collection: str, doc_id: str) -> list: ... 
+ def get_page_content(self, collection: str, doc_id: str, pages: str) -> list: ... + def list_documents(self, collection: str) -> list[dict]: ... + def delete_document(self, collection: str, doc_id: str) -> None: ... + + # Query + def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str: ... + async def query_stream(self, collection: str, question: str, + doc_ids: list[str] | None = None) -> AsyncIterator[QueryEvent]: ... diff --git a/pageindex/client.py b/pageindex/client.py index 894dab181..806ebb638 100644 --- a/pageindex/client.py +++ b/pageindex/client.py @@ -1,18 +1,9 @@ -import os -import uuid -import json -import asyncio -import concurrent.futures +# pageindex/client.py +from __future__ import annotations from pathlib import Path - -import PyPDF2 - -from .page_index import page_index -from .page_index_md import md_to_tree -from .retrieve import get_document, get_document_structure, get_page_content -from .utils import ConfigLoader, remove_fields - -META_INDEX = "_meta.json" +from .collection import Collection +from .config import IndexConfig +from .parser.protocol import DocumentParser def _normalize_retrieve_model(model: str) -> str: @@ -26,209 +17,145 @@ def _normalize_retrieve_model(model: str) -> str: class PageIndexClient: + """PageIndex client — supports both local and cloud modes. + + Args: + api_key: PageIndex cloud API key. When provided, cloud mode is used + and local-only params (model, storage_path, index_config, …) are ignored. + model: LLM model for indexing (local mode only, default: gpt-4o-2024-11-20). + retrieve_model: LLM model for agent QA (local mode only, default: same as model). + storage_path: Directory for SQLite DB and files (local mode only, default: ./.pageindex). + storage: Custom StorageEngine instance (local mode only). + index_config: Advanced indexing parameters (local mode only, optional). + Pass an IndexConfig instance or a dict. Defaults are sensible for most use cases. 
+ + Usage: + # Local mode (auto-detected when no api_key) + client = PageIndexClient(model="gpt-5.4") + + # Cloud mode (auto-detected when api_key provided) + client = PageIndexClient(api_key="your-api-key") + + # Or use LocalClient / CloudClient for explicit mode selection """ - A client for indexing and retrieving document content. - Flow: index() -> get_document() / get_document_structure() / get_page_content() - For agent-based QA, see examples/agentic_vectorless_rag_demo.py. - """ - def __init__(self, api_key: str = None, model: str = None, retrieve_model: str = None, workspace: str = None): + def __init__(self, api_key: str = None, model: str = None, + retrieve_model: str = None, storage_path: str = None, + storage=None, index_config: IndexConfig | dict = None): if api_key: - os.environ["OPENAI_API_KEY"] = api_key - elif not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"): - os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY") - self.workspace = Path(workspace).expanduser() if workspace else None + self._init_cloud(api_key) + else: + self._init_local(model, retrieve_model, storage_path, storage, index_config) + + def _init_cloud(self, api_key: str): + from .backend.cloud import CloudBackend + self._backend = CloudBackend(api_key=api_key) + + def _init_local(self, model: str = None, retrieve_model: str = None, + storage_path: str = None, storage=None, + index_config: IndexConfig | dict = None): + # Build IndexConfig: merge model/retrieve_model with index_config overrides = {} if model: overrides["model"] = model if retrieve_model: overrides["retrieve_model"] = retrieve_model - opt = ConfigLoader().load(overrides or None) - self.model = opt.model - self.retrieve_model = _normalize_retrieve_model(opt.retrieve_model or self.model) - if self.workspace: - self.workspace.mkdir(parents=True, exist_ok=True) - self.documents = {} - if self.workspace: - self._load_workspace() - - def index(self, file_path: str, mode: str = "auto") -> str: - """Index 
a document. Returns a document_id.""" - # Persist a canonical absolute path so workspace reloads do not - # reinterpret caller-relative paths against the workspace directory. - file_path = os.path.abspath(os.path.expanduser(file_path)) - if not os.path.exists(file_path): - raise FileNotFoundError(f"File not found: {file_path}") - - doc_id = str(uuid.uuid4()) - ext = os.path.splitext(file_path)[1].lower() - - is_pdf = ext == '.pdf' - is_md = ext in ['.md', '.markdown'] - - if mode == "pdf" or (mode == "auto" and is_pdf): - print(f"Indexing PDF: {file_path}") - result = page_index( - doc=file_path, - model=self.model, - if_add_node_summary='yes', - if_add_node_text='yes', - if_add_node_id='yes', - if_add_doc_description='yes' - ) - # Extract per-page text so queries don't need the original PDF - pages = [] - with open(file_path, 'rb') as f: - pdf_reader = PyPDF2.PdfReader(f) - for i, page in enumerate(pdf_reader.pages, 1): - pages.append({'page': i, 'content': page.extract_text() or ''}) - - self.documents[doc_id] = { - 'id': doc_id, - 'type': 'pdf', - 'path': file_path, - 'doc_name': result.get('doc_name', ''), - 'doc_description': result.get('doc_description', ''), - 'page_count': len(pages), - 'structure': result['structure'], - 'pages': pages, - } - - elif mode == "md" or (mode == "auto" and is_md): - print(f"Indexing Markdown: {file_path}") - coro = md_to_tree( - md_path=file_path, - if_thinning=False, - if_add_node_summary='yes', - summary_token_threshold=200, - model=self.model, - if_add_doc_description='yes', - if_add_node_text='yes', - if_add_node_id='yes' - ) - try: - asyncio.get_running_loop() - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: - result = pool.submit(asyncio.run, coro).result() - except RuntimeError: - result = asyncio.run(coro) - self.documents[doc_id] = { - 'id': doc_id, - 'type': 'md', - 'path': file_path, - 'doc_name': result.get('doc_name', ''), - 'doc_description': result.get('doc_description', ''), - 'line_count': 
result.get('line_count', 0), - 'structure': result['structure'], - } + if isinstance(index_config, IndexConfig): + opt = index_config.model_copy(update=overrides) + elif isinstance(index_config, dict): + merged = {**index_config, **overrides} # explicit model/retrieve_model win + opt = IndexConfig(**merged) else: - raise ValueError(f"Unsupported file format for: {file_path}") + opt = IndexConfig(**overrides) if overrides else IndexConfig() - print(f"Indexing complete. Document ID: {doc_id}") - if self.workspace: - self._save_doc(doc_id) - return doc_id + self._validate_llm_provider(opt.model) - @staticmethod - def _make_meta_entry(doc: dict) -> dict: - """Build a lightweight meta entry from a document dict.""" - entry = { - 'type': doc.get('type', ''), - 'doc_name': doc.get('doc_name', ''), - 'doc_description': doc.get('doc_description', ''), - 'path': doc.get('path', ''), - } - if doc.get('type') == 'pdf': - entry['page_count'] = doc.get('page_count') - elif doc.get('type') == 'md': - entry['line_count'] = doc.get('line_count') - return entry + storage_path = Path(storage_path or ".pageindex").resolve() + storage_path.mkdir(parents=True, exist_ok=True) + + from .storage.sqlite import SQLiteStorage + from .backend.local import LocalBackend + storage_engine = storage or SQLiteStorage(str(storage_path / "pageindex.db")) + self._backend = LocalBackend( + storage=storage_engine, + files_dir=str(storage_path / "files"), + model=opt.model, + retrieve_model=_normalize_retrieve_model(opt.retrieve_model or opt.model), + index_config=opt, + ) @staticmethod - def _read_json(path) -> dict | None: - """Read a JSON file, returning None on any error.""" + def _validate_llm_provider(model: str) -> None: + """Validate model and check API key via litellm. 
Warns if key seems missing.""" try: - with open(path, "r", encoding="utf-8") as f: - return json.load(f) - except (json.JSONDecodeError, OSError) as e: - print(f"Warning: corrupt {Path(path).name}: {e}") - return None - - def _save_doc(self, doc_id: str): - doc = self.documents[doc_id].copy() - # Strip text from structure nodes — redundant with pages (PDF only) - if doc.get('structure') and doc.get('type') == 'pdf': - doc['structure'] = remove_fields(doc['structure'], fields=['text']) - path = self.workspace / f"{doc_id}.json" - with open(path, "w", encoding="utf-8") as f: - json.dump(doc, f, ensure_ascii=False, indent=2) - self._save_meta(doc_id, self._make_meta_entry(doc)) - # Drop heavy fields; will lazy-load on demand - self.documents[doc_id].pop('structure', None) - self.documents[doc_id].pop('pages', None) - - def _rebuild_meta(self) -> dict: - """Scan individual doc JSON files and return a meta dict.""" - meta = {} - for path in self.workspace.glob("*.json"): - if path.name == META_INDEX: - continue - doc = self._read_json(path) - if doc and isinstance(doc, dict): - meta[path.stem] = self._make_meta_entry(doc) - return meta - - def _read_meta(self) -> dict | None: - """Read and validate _meta.json, returning None on any corruption.""" - meta = self._read_json(self.workspace / META_INDEX) - if meta is not None and not isinstance(meta, dict): - print(f"Warning: {META_INDEX} is not a JSON object, ignoring") - return None - return meta - - def _save_meta(self, doc_id: str, entry: dict): - meta = self._read_meta() or self._rebuild_meta() - meta[doc_id] = entry - meta_path = self.workspace / META_INDEX - with open(meta_path, "w", encoding="utf-8") as f: - json.dump(meta, f, ensure_ascii=False, indent=2) - - def _load_workspace(self): - meta = self._read_meta() - if meta is None: - meta = self._rebuild_meta() - if meta: - print(f"Loaded {len(meta)} document(s) from workspace (legacy mode).") - for doc_id, entry in meta.items(): - doc = dict(entry, id=doc_id) - if 
doc.get('path') and not os.path.isabs(doc['path']): - doc['path'] = str((self.workspace / doc['path']).resolve()) - self.documents[doc_id] = doc - - def _ensure_doc_loaded(self, doc_id: str): - """Load full document JSON on demand (structure, pages, etc.).""" - doc = self.documents.get(doc_id) - if not doc or doc.get('structure') is not None: - return - full = self._read_json(self.workspace / f"{doc_id}.json") - if not full: + import litellm + litellm.model_cost_map_url = "" + _, provider, _, _ = litellm.get_llm_provider(model=model) + except Exception: return - doc['structure'] = full.get('structure', []) - if full.get('pages'): - doc['pages'] = full['pages'] - - def get_document(self, doc_id: str) -> str: - """Return document metadata JSON.""" - return get_document(self.documents, doc_id) - - def get_document_structure(self, doc_id: str) -> str: - """Return document tree structure JSON (without text fields).""" - if self.workspace: - self._ensure_doc_loaded(doc_id) - return get_document_structure(self.documents, doc_id) - - def get_page_content(self, doc_id: str, pages: str) -> str: - """Return page content for the given pages string (e.g. '5-7', '3,8', '12').""" - if self.workspace: - self._ensure_doc_loaded(doc_id) - return get_page_content(self.documents, doc_id, pages) + + key = litellm.get_api_key(llm_provider=provider, dynamic_api_key=None) + if not key: + import os + common_var = f"{provider.upper()}_API_KEY" + if not os.getenv(common_var): + from .errors import PageIndexError + raise PageIndexError( + f"API key not configured for provider '{provider}' (model: {model}). " + f"Set the {common_var} environment variable." + ) + + def collection(self, name: str = "default") -> Collection: + """Get or create a collection. 
Defaults to 'default'.""" + self._backend.get_or_create_collection(name) + return Collection(name=name, backend=self._backend) + + def list_collections(self) -> list[str]: + return self._backend.list_collections() + + def delete_collection(self, name: str) -> None: + self._backend.delete_collection(name) + + def register_parser(self, parser: DocumentParser) -> None: + """Register a custom document parser. Only available in local mode.""" + if not hasattr(self._backend, 'register_parser'): + from .errors import PageIndexError + raise PageIndexError("Custom parsers are not supported in cloud mode") + self._backend.register_parser(parser) + + +class LocalClient(PageIndexClient): + """Local mode — indexes and queries documents on your machine. + + Args: + model: LLM model for indexing (default: gpt-4o-2024-11-20) + retrieve_model: LLM model for agent QA (default: same as model) + storage_path: Directory for SQLite DB and files (default: ./.pageindex) + storage: Custom StorageEngine instance (default: SQLiteStorage) + index_config: Advanced indexing parameters. Pass an IndexConfig instance + or a dict. All fields have sensible defaults — most users don't need this. + + Example:: + + # Simple — defaults are fine + client = LocalClient(model="gpt-5.4") + + # Advanced — tune indexing parameters + from pageindex.config import IndexConfig + client = LocalClient( + model="gpt-5.4", + index_config=IndexConfig(toc_check_page_num=30), + ) + """ + + def __init__(self, model: str = None, retrieve_model: str = None, + storage_path: str = None, storage=None, + index_config: IndexConfig | dict = None): + self._init_local(model, retrieve_model, storage_path, storage, index_config) + + +class CloudClient(PageIndexClient): + """Cloud mode — fully managed by PageIndex cloud service. 
No LLM key needed.""" + + def __init__(self, api_key: str): + self._init_cloud(api_key) diff --git a/pageindex/collection.py b/pageindex/collection.py new file mode 100644 index 000000000..f963d2293 --- /dev/null +++ b/pageindex/collection.py @@ -0,0 +1,69 @@ +# pageindex/collection.py +from __future__ import annotations +from typing import AsyncIterator +from .events import QueryEvent +from .backend.protocol import Backend + + +class QueryStream: + """Wraps backend.query_stream() as an async iterable object.""" + + def __init__(self, backend: Backend, collection: str, question: str, + doc_ids: list[str] | None = None): + self._backend = backend + self._collection = collection + self._question = question + self._doc_ids = doc_ids + + async def stream_events(self) -> AsyncIterator[QueryEvent]: + async for event in self._backend.query_stream( + self._collection, self._question, self._doc_ids + ): + yield event + + def __aiter__(self): + return self.stream_events() + + +class Collection: + def __init__(self, name: str, backend: Backend): + self._name = name + self._backend = backend + + @property + def name(self) -> str: + return self._name + + def add(self, file_path: str) -> str: + return self._backend.add_document(self._name, file_path) + + def list_documents(self) -> list[dict]: + return self._backend.list_documents(self._name) + + def get_document(self, doc_id: str, include_text: bool = False) -> dict: + return self._backend.get_document(self._name, doc_id, include_text=include_text) + + def get_document_structure(self, doc_id: str) -> list: + return self._backend.get_document_structure(self._name, doc_id) + + def get_page_content(self, doc_id: str, pages: str) -> list: + return self._backend.get_page_content(self._name, doc_id, pages) + + def delete_document(self, doc_id: str) -> None: + self._backend.delete_document(self._name, doc_id) + + def query(self, question: str, doc_ids: list[str] | None = None, + stream: bool = False) -> str | QueryStream: + """Query 
documents in this collection. + + - stream=False: returns answer string (sync) + - stream=True: returns async iterable of QueryEvent + + Usage: + answer = col.query("question") + async for event in col.query("question", stream=True): + ... + """ + if stream: + return QueryStream(self._backend, self._name, question, doc_ids) + return self._backend.query(self._name, question, doc_ids) diff --git a/pageindex/config.py b/pageindex/config.py new file mode 100644 index 000000000..fd3b12fc5 --- /dev/null +++ b/pageindex/config.py @@ -0,0 +1,22 @@ +# pageindex/config.py +from __future__ import annotations +from pydantic import BaseModel + + +class IndexConfig(BaseModel): + """Configuration for the PageIndex indexing pipeline. + + All fields have sensible defaults. Advanced users can override + via LocalClient(index_config=IndexConfig(...)) or a dict. + """ + model_config = {"extra": "forbid"} + + model: str = "gpt-4o-2024-11-20" + retrieve_model: str | None = None + toc_check_page_num: int = 20 + max_page_num_each_node: int = 10 + max_token_num_each_node: int = 20000 + if_add_node_id: bool = True + if_add_node_summary: bool = True + if_add_doc_description: bool = True + if_add_node_text: bool = False diff --git a/pageindex/config.yaml b/pageindex/config.yaml deleted file mode 100644 index 591fe9331..000000000 --- a/pageindex/config.yaml +++ /dev/null @@ -1,10 +0,0 @@ -model: "gpt-4o-2024-11-20" -# model: "anthropic/claude-sonnet-4-6" -retrieve_model: "gpt-5.4" # defaults to `model` if not set -toc_check_page_num: 20 -max_page_num_each_node: 10 -max_token_num_each_node: 20000 -if_add_node_id: "yes" -if_add_node_summary: "yes" -if_add_doc_description: "no" -if_add_node_text: "no" \ No newline at end of file diff --git a/pageindex/errors.py b/pageindex/errors.py new file mode 100644 index 000000000..790b68ffd --- /dev/null +++ b/pageindex/errors.py @@ -0,0 +1,28 @@ +class PageIndexError(Exception): + """Base exception for all PageIndex SDK errors.""" + pass + + +class 
CollectionNotFoundError(PageIndexError): + """Collection does not exist.""" + pass + + +class DocumentNotFoundError(PageIndexError): + """Document ID not found.""" + pass + + +class IndexingError(PageIndexError): + """Indexing pipeline failure.""" + pass + + +class CloudAPIError(PageIndexError): + """Cloud API returned error.""" + pass + + +class FileTypeError(PageIndexError): + """Unsupported file type.""" + pass diff --git a/pageindex/events.py b/pageindex/events.py new file mode 100644 index 000000000..fc8f30497 --- /dev/null +++ b/pageindex/events.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass +from typing import Literal, Any + + +@dataclass +class QueryEvent: + """Event emitted during streaming query.""" + type: Literal["reasoning", "tool_call", "tool_result", "answer_delta", "answer_done"] + data: Any diff --git a/pageindex/index/__init__.py b/pageindex/index/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pageindex/index/legacy_utils.py b/pageindex/index/legacy_utils.py new file mode 100644 index 000000000..1d6aab510 --- /dev/null +++ b/pageindex/index/legacy_utils.py @@ -0,0 +1,2 @@ +# Re-export from the original utils.py for backward compatibility +from ..utils import * diff --git a/pageindex/index/page_index.py b/pageindex/index/page_index.py new file mode 100644 index 000000000..291309066 --- /dev/null +++ b/pageindex/index/page_index.py @@ -0,0 +1,1155 @@ +import os +import json +import copy +import math +import random +import re +from .legacy_utils import * +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + + +################### check title in page ######################################################### +async def check_title_appearance(item, page_list, start_index=1, model=None): + title=item['title'] + if 'physical_index' not in item or item['physical_index'] is None: + return {'list_index': item.get('list_index'), 'answer': 'no', 'title':title, 'page_number': None} + + + page_number = 
item['physical_index'] + page_text = page_list[page_number-start_index][0] + + + prompt = f""" + Your job is to check if the given section appears or starts in the given page_text. + + Note: do fuzzy matching, ignore any space inconsistency in the page_text. + + The given section title is {title}. + The given page_text is {page_text}. + + Reply format: + {{ + + "thinking": + "answer": "yes or no" (yes if the section appears or starts in the page_text, no otherwise) + }} + Directly return the final JSON structure. Do not output anything else.""" + + response = await llm_acompletion(model=model, prompt=prompt) + response = extract_json(response) + if 'answer' in response: + answer = response['answer'] + else: + answer = 'no' + return {'list_index': item['list_index'], 'answer': answer, 'title': title, 'page_number': page_number} + + +async def check_title_appearance_in_start(title, page_text, model=None, logger=None): + prompt = f""" + You will be given the current section title and the current page_text. + Your job is to check if the current section starts in the beginning of the given page_text. + If there are other contents before the current section title, then the current section does not start in the beginning of the given page_text. + If the current section title is the first content in the given page_text, then the current section starts in the beginning of the given page_text. + + Note: do fuzzy matching, ignore any space inconsistency in the page_text. + + The given section title is {title}. + The given page_text is {page_text}. + + reply format: + {{ + "thinking": + "start_begin": "yes or no" (yes if the section starts in the beginning of the page_text, no otherwise) + }} + Directly return the final JSON structure. 
Do not output anything else.""" + + response = await llm_acompletion(model=model, prompt=prompt) + response = extract_json(response) + if logger: + logger.info(f"Response: {response}") + return response.get("start_begin", "no") + + +async def check_title_appearance_in_start_concurrent(structure, page_list, model=None, logger=None): + if logger: + logger.info("Checking title appearance in start concurrently") + + # skip items without physical_index + for item in structure: + if item.get('physical_index') is None: + item['appear_start'] = 'no' + + # only for items with valid physical_index + tasks = [] + valid_items = [] + for item in structure: + if item.get('physical_index') is not None: + page_text = page_list[item['physical_index'] - 1][0] + tasks.append(check_title_appearance_in_start(item['title'], page_text, model=model, logger=logger)) + valid_items.append(item) + + results = await asyncio.gather(*tasks, return_exceptions=True) + for item, result in zip(valid_items, results): + if isinstance(result, Exception): + if logger: + logger.error(f"Error checking start for {item['title']}: {result}") + item['appear_start'] = 'no' + else: + item['appear_start'] = result + + return structure + + +def toc_detector_single_page(content, model=None): + prompt = f""" + Your job is to detect if there is a table of content provided in the given text. + + Given text: {content} + + return the following JSON format: + {{ + "thinking": + "toc_detected": "", + }} + + Directly return the final JSON structure. Do not output anything else. + Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents.""" + + response = llm_completion(model=model, prompt=prompt) + # print('response', response) + json_content = extract_json(response) + return json_content['toc_detected'] + + +def check_if_toc_extraction_is_complete(content, toc, model=None): + prompt = f""" + You are given a partial document and a table of contents. 
+ Your job is to check if the table of contents is complete, which it contains all the main sections in the partial document. + + Reply format: + {{ + "thinking": + "completed": "yes" or "no" + }} + Directly return the final JSON structure. Do not output anything else.""" + + prompt = prompt + '\n Document:\n' + content + '\n Table of contents:\n' + toc + response = llm_completion(model=model, prompt=prompt) + json_content = extract_json(response) + return json_content['completed'] + + +def check_if_toc_transformation_is_complete(content, toc, model=None): + prompt = f""" + You are given a raw table of contents and a table of contents. + Your job is to check if the table of contents is complete. + + Reply format: + {{ + "thinking": + "completed": "yes" or "no" + }} + Directly return the final JSON structure. Do not output anything else.""" + + prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc + response = llm_completion(model=model, prompt=prompt) + json_content = extract_json(response) + return json_content['completed'] + +def extract_toc_content(content, model=None): + prompt = f""" + Your job is to extract the full table of contents from the given text, replace ... with : + + Given text: {content} + + Directly return the full table of contents content. 
Do not output anything else.""" + + response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + + if_complete = check_if_toc_transformation_is_complete(content, response, model) + if if_complete == "yes" and finish_reason == "finished": + return response + + chat_history = [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": response}, + ] + prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure""" + new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True) + response = response + new_response + if_complete = check_if_toc_transformation_is_complete(content, response, model) + + attempt = 0 + max_attempts = 5 + + while not (if_complete == "yes" and finish_reason == "finished"): + attempt += 1 + if attempt > max_attempts: + raise Exception('Failed to complete table of contents after maximum retries') + + chat_history = [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": response}, + ] + prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure""" + new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True) + response = response + new_response + if_complete = check_if_toc_transformation_is_complete(content, response, model) + + return response + +def detect_page_index(toc_content, model=None): + print('start detect_page_index') + prompt = f""" + You will be given a table of contents. + + Your job is to detect if there are page numbers/indices given within the table of contents. + + Given text: {toc_content} + + Reply format: + {{ + "thinking": + "page_index_given_in_toc": "" + }} + Directly return the final JSON structure. 
Do not output anything else.""" + + response = llm_completion(model=model, prompt=prompt) + json_content = extract_json(response) + return json_content['page_index_given_in_toc'] + +def toc_extractor(page_list, toc_page_list, model): + def transform_dots_to_colon(text): + text = re.sub(r'\.{5,}', ': ', text) + # Handle dots separated by spaces + text = re.sub(r'(?:\. ){5,}\.?', ': ', text) + return text + + toc_content = "" + for page_index in toc_page_list: + toc_content += page_list[page_index][0] + toc_content = transform_dots_to_colon(toc_content) + has_page_index = detect_page_index(toc_content, model=model) + + return { + "toc_content": toc_content, + "page_index_given_in_toc": has_page_index + } + + + + +def toc_index_extractor(toc, content, model=None): + print('start toc_index_extractor') + toc_extractor_prompt = """ + You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format. + + The provided pages contains tags like and to indicate the physical location of the page X. + + The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc. + + The response should be in the following JSON format: + [ + { + "structure": (string), + "title": , + "physical_index": "<physical_index_X>" (keep the format) + }, + ... + ] + + Only add the physical_index to the sections that are in the provided pages. + If the section is not in the provided pages, do not add the physical_index to it. + Directly return the final JSON structure. 
Do not output anything else.""" + + prompt = toc_extractor_prompt + '\nTable of contents:\n' + str(toc) + '\nDocument pages:\n' + content + response = llm_completion(model=model, prompt=prompt) + json_content = extract_json(response) + return json_content + + + +def toc_transformer(toc_content, model=None): + print('start toc_transformer') + init_prompt = """ + You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents. + + structure is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc. + + The response should be in the following JSON format: + { + table_of_contents: [ + { + "structure": <structure index, "x.x.x" or None> (string), + "title": <title of the section>, + "page": <page number or None>, + }, + ... + ], + } + You should transform the full table of contents in one go. + Directly return the final JSON structure, do not output anything else. 
""" + + prompt = init_prompt + '\n Given table of contents\n:' + toc_content + last_complete, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) + if if_complete == "yes" and finish_reason == "finished": + last_complete = extract_json(last_complete) + cleaned_response=convert_page_to_int(last_complete['table_of_contents']) + return cleaned_response + + last_complete = get_json_content(last_complete) + attempt = 0 + max_attempts = 5 + while not (if_complete == "yes" and finish_reason == "finished"): + attempt += 1 + if attempt > max_attempts: + raise Exception('Failed to complete toc transformation after maximum retries') + position = last_complete.rfind('}') + if position != -1: + last_complete = last_complete[:position+2] + prompt = f""" + Your task is to continue the table of contents json structure, directly output the remaining part of the json structure. + The response should be in the following JSON format: + + The raw table of contents json structure is: + {toc_content} + + The incomplete transformed table of contents json structure is: + {last_complete} + + Please continue the json structure, directly output the remaining part of the json structure.""" + + new_complete, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + + if new_complete.startswith('```json'): + new_complete = get_json_content(new_complete) + last_complete = last_complete+new_complete + + if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) + + + last_complete = extract_json(last_complete) + + cleaned_response=convert_page_to_int(last_complete['table_of_contents']) + return cleaned_response + + + + +def find_toc_pages(start_page_index, page_list, opt, logger=None): + print('start find_toc_pages') + last_page_is_yes = False + toc_page_list = [] + i = start_page_index + + while i < len(page_list): 
+ # Only check beyond max_pages if we're still finding TOC pages + if i >= opt.toc_check_page_num and not last_page_is_yes: + break + detected_result = toc_detector_single_page(page_list[i][0],model=opt.model) + if detected_result == 'yes': + if logger: + logger.info(f'Page {i} has toc') + toc_page_list.append(i) + last_page_is_yes = True + elif detected_result == 'no' and last_page_is_yes: + if logger: + logger.info(f'Found the last page with toc: {i-1}') + break + i += 1 + + if not toc_page_list and logger: + logger.info('No toc found') + + return toc_page_list + +def remove_page_number(data): + if isinstance(data, dict): + data.pop('page_number', None) + for key in list(data.keys()): + if 'nodes' in key: + remove_page_number(data[key]) + elif isinstance(data, list): + for item in data: + remove_page_number(item) + return data + +def extract_matching_page_pairs(toc_page, toc_physical_index, start_page_index): + pairs = [] + for phy_item in toc_physical_index: + for page_item in toc_page: + if phy_item.get('title') == page_item.get('title'): + physical_index = phy_item.get('physical_index') + if physical_index is not None and int(physical_index) >= start_page_index: + pairs.append({ + 'title': phy_item.get('title'), + 'page': page_item.get('page'), + 'physical_index': physical_index + }) + return pairs + + +def calculate_page_offset(pairs): + differences = [] + for pair in pairs: + try: + physical_index = pair['physical_index'] + page_number = pair['page'] + difference = physical_index - page_number + differences.append(difference) + except (KeyError, TypeError): + continue + + if not differences: + return None + + difference_counts = {} + for diff in differences: + difference_counts[diff] = difference_counts.get(diff, 0) + 1 + + most_common = max(difference_counts.items(), key=lambda x: x[1])[0] + + return most_common + +def add_page_offset_to_toc_json(data, offset): + for i in range(len(data)): + if data[i].get('page') is not None and isinstance(data[i]['page'], 
int): + data[i]['physical_index'] = data[i]['page'] + offset + del data[i]['page'] + + return data + + + +def page_list_to_group_text(page_contents, token_lengths, max_tokens=20000, overlap_page=1): + num_tokens = sum(token_lengths) + + if num_tokens <= max_tokens: + # merge all pages into one text + page_text = "".join(page_contents) + return [page_text] + + subsets = [] + current_subset = [] + current_token_count = 0 + + expected_parts_num = math.ceil(num_tokens / max_tokens) + average_tokens_per_part = math.ceil(((num_tokens / expected_parts_num) + max_tokens) / 2) + + for i, (page_content, page_tokens) in enumerate(zip(page_contents, token_lengths)): + if current_token_count + page_tokens > average_tokens_per_part: + + subsets.append(''.join(current_subset)) + # Start new subset from overlap if specified + overlap_start = max(i - overlap_page, 0) + current_subset = page_contents[overlap_start:i] + current_token_count = sum(token_lengths[overlap_start:i]) + + # Add current page to the subset + current_subset.append(page_content) + current_token_count += page_tokens + + # Add the last subset if it contains any pages + if current_subset: + subsets.append(''.join(current_subset)) + + print('divide page_list to groups', len(subsets)) + return subsets + +def add_page_number_to_toc(part, structure, model=None): + fill_prompt_seq = """ + You are given an JSON structure of a document and a partial part of the document. Your task is to check if the title that is described in the structure is started in the partial given document. + + The provided text contains tags like <physical_index_X> and <physical_index_X> to indicate the physical location of the page X. + + If the full target section starts in the partial given document, insert the given JSON structure with the "start": "yes", and "start_index": "<physical_index_X>". + + If the full target section does not start in the partial given document, insert "start": "no", "start_index": None. 
+ + The response should be in the following format. + [ + { + "structure": <structure index, "x.x.x" or None> (string), + "title": <title of the section>, + "start": "<yes or no>", + "physical_index": "<physical_index_X> (keep the format)" or None + }, + ... + ] + The given structure contains the result of the previous part, you need to fill the result of the current part, do not change the previous result. + Directly return the final JSON structure. Do not output anything else.""" + + prompt = fill_prompt_seq + f"\n\nCurrent Partial Document:\n{part}\n\nGiven Structure\n{json.dumps(structure, indent=2)}\n" + current_json_raw = llm_completion(model=model, prompt=prompt) + json_result = extract_json(current_json_raw) + + for item in json_result: + if 'start' in item: + del item['start'] + return json_result + + +def remove_first_physical_index_section(text): + """ + Removes the first section between <physical_index_X> and <physical_index_X> tags, + and returns the remaining text. + """ + pattern = r'<physical_index_\d+>.*?<physical_index_\d+>' + match = re.search(pattern, text, re.DOTALL) + if match: + # Remove the first matched section + return text.replace(match.group(0), '', 1) + return text + +### add verify completeness +def generate_toc_continue(toc_content, part, model=None): + print('start generate_toc_continue') + prompt = """ + You are an expert in extracting hierarchical tree structure. + You are given a tree structure of the previous part and the text of the current part. + Your task is to continue the tree structure from the previous part to include the current part. + + The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc. + + For the title, you need to extract the original title from the text, only fix the space inconsistency. 
+ + The provided text contains tags like <physical_index_X> and <physical_index_X> to indicate the start and end of page X. \ + + For the physical_index, you need to extract the physical index of the start of the section from the text. Keep the <physical_index_X> format. + + The response should be in the following format. + [ + { + "structure": <structure index, "x.x.x"> (string), + "title": <title of the section, keep the original title>, + "physical_index": "<physical_index_X> (keep the format)" + }, + ... + ] + + Directly return the additional part of the final JSON structure. Do not output anything else.""" + + prompt = prompt + '\nGiven text\n:' + part + '\nPrevious tree structure\n:' + json.dumps(toc_content, indent=2) + response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + if finish_reason == 'finished': + return extract_json(response) + else: + raise Exception(f'finish reason: {finish_reason}') + +### add verify completeness +def generate_toc_init(part, model=None): + print('start generate_toc_init') + prompt = """ + You are an expert in extracting hierarchical tree structure, your task is to generate the tree structure of the document. + + The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc. + + For the title, you need to extract the original title from the text, only fix the space inconsistency. + + The provided text contains tags like <physical_index_X> and <physical_index_X> to indicate the start and end of page X. + + For the physical_index, you need to extract the physical index of the start of the section from the text. Keep the <physical_index_X> format. + + The response should be in the following format. 
+ [ + {{ + "structure": <structure index, "x.x.x"> (string), + "title": <title of the section, keep the original title>, + "physical_index": "<physical_index_X> (keep the format)" + }}, + + ], + + + Directly return the final JSON structure. Do not output anything else.""" + + prompt = prompt + '\nGiven text\n:' + part + response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + + if finish_reason == 'finished': + return extract_json(response) + else: + raise Exception(f'finish reason: {finish_reason}') + +def process_no_toc(page_list, start_index=1, model=None, logger=None): + page_contents=[] + token_lengths=[] + for page_index in range(start_index, start_index+len(page_list)): + page_text = f"<physical_index_{page_index}>\n{page_list[page_index-start_index][0]}\n<physical_index_{page_index}>\n\n" + page_contents.append(page_text) + token_lengths.append(count_tokens(page_text, model)) + group_texts = page_list_to_group_text(page_contents, token_lengths) + logger.info(f'len(group_texts): {len(group_texts)}') + + toc_with_page_number= generate_toc_init(group_texts[0], model) + for group_text in group_texts[1:]: + toc_with_page_number_additional = generate_toc_continue(toc_with_page_number, group_text, model) + toc_with_page_number.extend(toc_with_page_number_additional) + logger.info(f'generate_toc: {toc_with_page_number}') + + toc_with_page_number = convert_physical_index_to_int(toc_with_page_number) + logger.info(f'convert_physical_index_to_int: {toc_with_page_number}') + + return toc_with_page_number + +def process_toc_no_page_numbers(toc_content, toc_page_list, page_list, start_index=1, model=None, logger=None): + page_contents=[] + token_lengths=[] + toc_content = toc_transformer(toc_content, model) + logger.info(f'toc_transformer: {toc_content}') + for page_index in range(start_index, start_index+len(page_list)): + page_text = 
f"<physical_index_{page_index}>\n{page_list[page_index-start_index][0]}\n<physical_index_{page_index}>\n\n" + page_contents.append(page_text) + token_lengths.append(count_tokens(page_text, model)) + + group_texts = page_list_to_group_text(page_contents, token_lengths) + logger.info(f'len(group_texts): {len(group_texts)}') + + toc_with_page_number=copy.deepcopy(toc_content) + for group_text in group_texts: + toc_with_page_number = add_page_number_to_toc(group_text, toc_with_page_number, model) + logger.info(f'add_page_number_to_toc: {toc_with_page_number}') + + toc_with_page_number = convert_physical_index_to_int(toc_with_page_number) + logger.info(f'convert_physical_index_to_int: {toc_with_page_number}') + + return toc_with_page_number + + + +def process_toc_with_page_numbers(toc_content, toc_page_list, page_list, toc_check_page_num=None, model=None, logger=None): + toc_with_page_number = toc_transformer(toc_content, model) + logger.info(f'toc_with_page_number: {toc_with_page_number}') + + toc_no_page_number = remove_page_number(copy.deepcopy(toc_with_page_number)) + + start_page_index = toc_page_list[-1] + 1 + main_content = "" + for page_index in range(start_page_index, min(start_page_index + toc_check_page_num, len(page_list))): + main_content += f"<physical_index_{page_index+1}>\n{page_list[page_index][0]}\n<physical_index_{page_index+1}>\n\n" + + toc_with_physical_index = toc_index_extractor(toc_no_page_number, main_content, model) + logger.info(f'toc_with_physical_index: {toc_with_physical_index}') + + toc_with_physical_index = convert_physical_index_to_int(toc_with_physical_index) + logger.info(f'toc_with_physical_index: {toc_with_physical_index}') + + matching_pairs = extract_matching_page_pairs(toc_with_page_number, toc_with_physical_index, start_page_index) + logger.info(f'matching_pairs: {matching_pairs}') + + offset = calculate_page_offset(matching_pairs) + logger.info(f'offset: {offset}') + + toc_with_page_number = 
add_page_offset_to_toc_json(toc_with_page_number, offset) + logger.info(f'toc_with_page_number: {toc_with_page_number}') + + toc_with_page_number = process_none_page_numbers(toc_with_page_number, page_list, model=model) + logger.info(f'toc_with_page_number: {toc_with_page_number}') + + return toc_with_page_number + + + +##check if needed to process none page numbers +def process_none_page_numbers(toc_items, page_list, start_index=1, model=None): + for i, item in enumerate(toc_items): + if "physical_index" not in item: + # logger.info(f"fix item: {item}") + # Find previous physical_index + prev_physical_index = 0 # Default if no previous item exists + for j in range(i - 1, -1, -1): + if toc_items[j].get('physical_index') is not None: + prev_physical_index = toc_items[j]['physical_index'] + break + + # Find next physical_index + next_physical_index = -1 # Default if no next item exists + for j in range(i + 1, len(toc_items)): + if toc_items[j].get('physical_index') is not None: + next_physical_index = toc_items[j]['physical_index'] + break + + page_contents = [] + for page_index in range(prev_physical_index, next_physical_index+1): + # Add bounds checking to prevent IndexError + list_index = page_index - start_index + if list_index >= 0 and list_index < len(page_list): + page_text = f"<physical_index_{page_index}>\n{page_list[list_index][0]}\n<physical_index_{page_index}>\n\n" + page_contents.append(page_text) + else: + continue + + item_copy = copy.deepcopy(item) + del item_copy['page'] + result = add_page_number_to_toc(page_contents, item_copy, model) + if isinstance(result[0]['physical_index'], str) and result[0]['physical_index'].startswith('<physical_index'): + item['physical_index'] = int(result[0]['physical_index'].split('_')[-1].rstrip('>').strip()) + del item['page'] + + return toc_items + + + + +def check_toc(page_list, opt=None): + toc_page_list = find_toc_pages(start_page_index=0, page_list=page_list, opt=opt) + if len(toc_page_list) == 0: + print('no toc 
found') + return {'toc_content': None, 'toc_page_list': [], 'page_index_given_in_toc': 'no'} + else: + print('toc found') + toc_json = toc_extractor(page_list, toc_page_list, opt.model) + + if toc_json['page_index_given_in_toc'] == 'yes': + print('index found') + return {'toc_content': toc_json['toc_content'], 'toc_page_list': toc_page_list, 'page_index_given_in_toc': 'yes'} + else: + current_start_index = toc_page_list[-1] + 1 + + while (toc_json['page_index_given_in_toc'] == 'no' and + current_start_index < len(page_list) and + current_start_index < opt.toc_check_page_num): + + additional_toc_pages = find_toc_pages( + start_page_index=current_start_index, + page_list=page_list, + opt=opt + ) + + if len(additional_toc_pages) == 0: + break + + additional_toc_json = toc_extractor(page_list, additional_toc_pages, opt.model) + if additional_toc_json['page_index_given_in_toc'] == 'yes': + print('index found') + return {'toc_content': additional_toc_json['toc_content'], 'toc_page_list': additional_toc_pages, 'page_index_given_in_toc': 'yes'} + + else: + current_start_index = additional_toc_pages[-1] + 1 + print('index not found') + return {'toc_content': toc_json['toc_content'], 'toc_page_list': toc_page_list, 'page_index_given_in_toc': 'no'} + + + + + + +################### fix incorrect toc ######################################################### +async def single_toc_item_index_fixer(section_title, content, model=None): + toc_extractor_prompt = """ + You are given a section title and several pages of a document, your job is to find the physical index of the start page of the section in the partial document. + + The provided pages contains tags like <physical_index_X> and <physical_index_X> to indicate the physical location of the page X. 

    Reply in a JSON format:
    {
        "thinking": <explain which page, started and closed by <physical_index_X>, contains the start of this section>,
        "physical_index": "<physical_index_X>" (keep the format)
    }
    Directly return the final JSON structure. Do not output anything else."""

    prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content
    response = await llm_acompletion(model=model, prompt=prompt)
    json_content = extract_json(response)
    return convert_physical_index_to_int(json_content['physical_index'])



async def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_index=1, model=None, logger=None):
    # Re-resolve each incorrect TOC entry by re-asking the LLM over the page
    # span bracketed by its nearest correct neighbours, then re-check each
    # answer.  Returns (updated_toc, still_invalid_results).
    print(f'start fix_incorrect_toc with {len(incorrect_results)} incorrect results')
    incorrect_indices = {result['list_index'] for result in incorrect_results}

    end_index = len(page_list) + start_index - 1

    incorrect_results_and_range_logs = []
    # Helper function to process and check a single incorrect item
    async def process_and_check_item(incorrect_item):
        list_index = incorrect_item['list_index']

        # Check if list_index is valid
        if list_index < 0 or list_index >= len(toc_with_page_number):
            # Return an invalid result for out-of-bounds indices
            return {
                'list_index': list_index,
                'title': incorrect_item['title'],
                'physical_index': incorrect_item.get('physical_index'),
                'is_valid': False
            }

        # Find the previous correct item
        prev_correct = None
        for i in range(list_index-1, -1, -1):
            if i not in incorrect_indices and i >= 0 and i < len(toc_with_page_number):
                physical_index = toc_with_page_number[i].get('physical_index')
                if physical_index is not None:
                    prev_correct = physical_index
                    break
        # If no previous correct item found, use start_index
        if prev_correct is None:
            prev_correct = start_index - 1

        # Find the next correct item
        next_correct = None
        for i in range(list_index+1, len(toc_with_page_number)):
            if i not in incorrect_indices and i >= 0 and i < len(toc_with_page_number):
                physical_index = toc_with_page_number[i].get('physical_index')
                if physical_index is not None:
                    next_correct = physical_index
                    break
        # If no next correct item found, use end_index
        if next_correct is None:
            next_correct = end_index

        incorrect_results_and_range_logs.append({
            'list_index': list_index,
            'title': incorrect_item['title'],
            'prev_correct': prev_correct,
            'next_correct': next_correct
        })

        # Build the tagged text for the candidate page span.
        page_contents=[]
        for page_index in range(prev_correct, next_correct+1):
            # Add bounds checking to prevent IndexError
            page_list_idx = page_index - start_index
            if page_list_idx >= 0 and page_list_idx < len(page_list):
                page_text = f"<physical_index_{page_index}>\n{page_list[page_list_idx][0]}\n<physical_index_{page_index}>\n\n"
                page_contents.append(page_text)
            else:
                continue
        content_range = ''.join(page_contents)

        physical_index_int = await single_toc_item_index_fixer(incorrect_item['title'], content_range, model)

        # Check if the result is correct
        check_item = incorrect_item.copy()
        check_item['physical_index'] = physical_index_int
        check_result = await check_title_appearance(check_item, page_list, start_index, model)

        return {
            'list_index': list_index,
            'title': incorrect_item['title'],
            'physical_index': physical_index_int,
            'is_valid': check_result['answer'] == 'yes'
        }

    # Process incorrect items concurrently
    tasks = [
        process_and_check_item(item)
        for item in incorrect_results
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    # Log failures, then drop them so downstream handling sees only results.
    for item, result in zip(incorrect_results, results):
        if isinstance(result, Exception):
            print(f"Processing item {item} generated an exception: {result}")
            continue
    results = [result for result in results if not isinstance(result, Exception)]

    # Update the toc_with_page_number with the fixed indices and check for any invalid results
    invalid_results = []
    for result in results:
        if result['is_valid']:
            # Add bounds checking to prevent IndexError
            list_idx = result['list_index']
            if 0 <= list_idx < len(toc_with_page_number):
                toc_with_page_number[list_idx]['physical_index'] = result['physical_index']
            else:
                # Index is out of bounds, treat as invalid
                invalid_results.append({
                    'list_index': result['list_index'],
                    'title': result['title'],
                    'physical_index': result['physical_index'],
                })
        else:
            invalid_results.append({
                'list_index': result['list_index'],
                'title': result['title'],
                'physical_index': result['physical_index'],
            })

    logger.info(f'incorrect_results_and_range_logs: {incorrect_results_and_range_logs}')
    logger.info(f'invalid_results: {invalid_results}')

    return toc_with_page_number, invalid_results



async def fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results, start_index=1, max_attempts=3, model=None, logger=None):
    # Repeatedly run fix_incorrect_toc until nothing is left to fix or the
    # attempt budget is exhausted.  Returns (toc, remaining_incorrect).
    print('start fix_incorrect_toc')
    fix_attempt = 0
    current_toc = toc_with_page_number
    current_incorrect = incorrect_results

    while current_incorrect:
        print(f"Fixing {len(current_incorrect)} incorrect results")

        current_toc, current_incorrect = await fix_incorrect_toc(current_toc, page_list, current_incorrect, start_index, model, logger)

        fix_attempt += 1
        if fix_attempt >= max_attempts:
            logger.info("Maximum fix attempts reached")
            break

    return current_toc, current_incorrect




################### verify toc #########################################################
async def verify_toc(page_list, list_result, start_index=1, N=None, model=None):
    # Spot-check (N items) or fully check (N is None) that each TOC title
    # actually appears on its assigned physical page.
    # Returns (accuracy, incorrect_results).
    print('start verify_toc')
    # Find the last non-None physical_index
    last_physical_index = None
    for item in reversed(list_result):
        if item.get('physical_index') is not None:
            last_physical_index = item['physical_index']
            break

    # Early return if we don't have valid physical indices
    # (a TOC ending before the document midpoint is considered implausible).
    if last_physical_index is None or last_physical_index < len(page_list)/2:
        return 0, []

    # Determine which items to check
    if N is None:
        print('check all items')
        sample_indices = range(0, len(list_result))
    else:
        N = min(N, len(list_result))
        print(f'check {N} items')
        sample_indices = random.sample(range(0, len(list_result)), N)

    # Prepare items with their list indices
    indexed_sample_list = []
    for idx in sample_indices:
        item = list_result[idx]
        # Skip items with None physical_index (these were invalidated by validate_and_truncate_physical_indices)
        if item.get('physical_index') is not None:
            item_with_index = item.copy()
            item_with_index['list_index'] = idx  # Add the original index in list_result
            indexed_sample_list.append(item_with_index)

    # Run checks concurrently
    tasks = [
        check_title_appearance(item, page_list, start_index, model)
        for item in indexed_sample_list
    ]
    results = await asyncio.gather(*tasks)

    # Process results
    correct_count = 0
    incorrect_results = []
    for result in results:
        if result['answer'] == 'yes':
            correct_count += 1
        else:
            incorrect_results.append(result)

    # Calculate accuracy
    checked_count = len(results)
    accuracy = correct_count / checked_count if checked_count > 0 else 0
    print(f"accuracy: {accuracy*100:.2f}%")
    return accuracy, incorrect_results




################### main process #########################################################
async def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, start_index=1, opt=None, logger=None):
    # Run one TOC-building strategy, verify it, repair it when close enough,
    # and otherwise fall back to the next-simpler strategy recursively.
    print(mode)
    print(f'start_index: {start_index}')

    if mode == 'process_toc_with_page_numbers':
        toc_with_page_number = process_toc_with_page_numbers(toc_content, toc_page_list, page_list, toc_check_page_num=opt.toc_check_page_num, model=opt.model, logger=logger)
    elif mode == 'process_toc_no_page_numbers':
        toc_with_page_number = process_toc_no_page_numbers(toc_content, toc_page_list, page_list, model=opt.model, logger=logger)
    else:
        toc_with_page_number = process_no_toc(page_list, start_index=start_index, model=opt.model, logger=logger)

    toc_with_page_number = [item for item in toc_with_page_number if item.get('physical_index') is not None]

    toc_with_page_number = validate_and_truncate_physical_indices(
        toc_with_page_number,
        len(page_list),
        start_index=start_index,
        logger=logger
    )

    accuracy, incorrect_results = await verify_toc(page_list, toc_with_page_number, start_index=start_index, model=opt.model)

    # NOTE(review): 'mode' key below always logs the same literal, not the
    # actual mode — looks unintentional; confirm before relying on logs.
    logger.info({
        'mode': 'process_toc_with_page_numbers',
        'accuracy': accuracy,
        'incorrect_results': incorrect_results
    })
    if accuracy == 1.0 and len(incorrect_results) == 0:
        return toc_with_page_number
    if accuracy > 0.6 and len(incorrect_results) > 0:
        toc_with_page_number, incorrect_results = await fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results,start_index=start_index, max_attempts=3, model=opt.model, logger=logger)
        return toc_with_page_number
    else:
        # Accuracy too low: degrade to the next-simpler strategy.
        if mode == 'process_toc_with_page_numbers':
            return await meta_processor(page_list, mode='process_toc_no_page_numbers', toc_content=toc_content, toc_page_list=toc_page_list, start_index=start_index, opt=opt, logger=logger)
        elif mode == 'process_toc_no_page_numbers':
            return await meta_processor(page_list, mode='process_no_toc', start_index=start_index, opt=opt, logger=logger)
        else:
            raise Exception('Processing failed')


async def process_large_node_recursively(node, page_list, opt=None, logger=None):
    # Nodes exceeding both the page and token budgets are re-parsed as
    # stand-alone documents to produce child nodes; recurses into children.
    node_page_list = page_list[node['start_index']-1:node['end_index']]
    token_num = sum([page[1] for page in node_page_list])

    if node['end_index'] - node['start_index'] > opt.max_page_num_each_node and token_num >= opt.max_token_num_each_node:
        print('large node:', node['title'], 'start_index:', node['start_index'], 'end_index:', node['end_index'], 'token_num:', token_num)

        node_toc_tree = await meta_processor(node_page_list, mode='process_no_toc', start_index=node['start_index'], opt=opt, logger=logger)
        node_toc_tree = await check_title_appearance_in_start_concurrent(node_toc_tree, page_list, model=opt.model, logger=logger)

        # Filter out items with None physical_index before post_processing
        valid_node_toc_items = [item for item in node_toc_tree if item.get('physical_index') is not None]

        # If the first extracted item repeats the node's own title, treat the
        # remainder as children; otherwise all items become children.
        if valid_node_toc_items and node['title'].strip() == valid_node_toc_items[0]['title'].strip():
            node['nodes'] = post_processing(valid_node_toc_items[1:], node['end_index'])
            node['end_index'] = valid_node_toc_items[1]['start_index'] if len(valid_node_toc_items) > 1 else node['end_index']
        else:
            node['nodes'] = post_processing(valid_node_toc_items, node['end_index'])
            node['end_index'] = valid_node_toc_items[0]['start_index'] if valid_node_toc_items else node['end_index']

    if 'nodes' in node and node['nodes']:
        tasks = [
            process_large_node_recursively(child_node, page_list, opt, logger=logger)
            for child_node in node['nodes']
        ]
        await asyncio.gather(*tasks)

    return node

async def tree_parser(page_list, opt, doc=None, logger=None):
    # Full pipeline: detect the TOC, build the flat TOC with physical
    # indices, convert it into a tree, then expand oversized nodes.
    check_toc_result = check_toc(page_list, opt)
    logger.info(check_toc_result)

    if check_toc_result.get("toc_content") and check_toc_result["toc_content"].strip() and check_toc_result["page_index_given_in_toc"] == "yes":
        toc_with_page_number = await meta_processor(
            page_list,
            mode='process_toc_with_page_numbers',
            start_index=1,
            toc_content=check_toc_result['toc_content'],
            toc_page_list=check_toc_result['toc_page_list'],
            opt=opt,
            logger=logger)
    else:
        toc_with_page_number = await meta_processor(
            page_list,
            mode='process_no_toc',
            start_index=1,
            opt=opt,
            logger=logger)

    toc_with_page_number = add_preface_if_needed(toc_with_page_number)
    toc_with_page_number = await check_title_appearance_in_start_concurrent(toc_with_page_number, page_list, model=opt.model, logger=logger)

    # Filter out items with None physical_index before post_processings
    valid_toc_items = [item for item in toc_with_page_number if item.get('physical_index') is not None]

    toc_tree = post_processing(valid_toc_items, len(page_list))
    tasks = [
        process_large_node_recursively(node, page_list, opt, logger=logger)
        for node in toc_tree
    ]
    await asyncio.gather(*tasks)

    return toc_tree


def page_index_main(doc, opt=None):
    # Entry point: validate the input, parse the PDF into (text, tokens)
    # pages, build the structure tree, and decorate it per the options
    # (node ids, node text, summaries, document description).
    logger = JsonLogger(doc)

    is_valid_pdf = (
        (isinstance(doc, str) and os.path.isfile(doc) and doc.lower().endswith(".pdf")) or
        isinstance(doc, BytesIO)
    )
    if not is_valid_pdf:
        raise ValueError("Unsupported input type. Expected a PDF file path or BytesIO object.")

    print('Parsing PDF...')
    page_list = get_page_tokens(doc, model=opt.model)

    logger.info({'total_page_number': len(page_list)})
    logger.info({'total_token': sum([page[1] for page in page_list])})

    async def page_index_builder():
        structure = await tree_parser(page_list, opt, doc=doc, logger=logger)
        if opt.if_add_node_id:
            write_node_id(structure)
        if opt.if_add_node_text:
            add_node_text(structure, page_list)
        if opt.if_add_node_summary:
            # Summaries need node text; add it temporarily when not requested.
            if not opt.if_add_node_text:
                add_node_text(structure, page_list)
            await generate_summaries_for_structure(structure, model=opt.model)
            if not opt.if_add_node_text:
                remove_structure_text(structure)
        if opt.if_add_doc_description:
            # Create a clean structure without unnecessary fields for description generation
            clean_structure = create_clean_structure_for_description(structure)
            doc_description = generate_doc_description(clean_structure, model=opt.model)
            structure = format_structure(structure, order=['title', 'node_id', 'start_index', 'end_index', 'summary', 'text', 'nodes'])
            return {
                'doc_name': get_pdf_name(doc),
                'doc_description': doc_description,
                'structure': structure,
            }
        structure = format_structure(structure, order=['title', 'node_id', 'start_index', 'end_index', 'summary', 'text', 'nodes'])
        return {
            'doc_name': get_pdf_name(doc),
            'structure': structure,
        }

    return asyncio.run(page_index_builder())

+def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None, + if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None, if_add_node_text=None): + + from ..config import IndexConfig + user_opt = { + arg: value for arg, value in locals().items() + if arg != "doc" and value is not None + } + opt = IndexConfig(**user_opt) + return page_index_main(doc, opt) + + +def validate_and_truncate_physical_indices(toc_with_page_number, page_list_length, start_index=1, logger=None): + """ + Validates and truncates physical indices that exceed the actual document length. + This prevents errors when TOC references pages that don't exist in the document (e.g. the file is broken or incomplete). + """ + if not toc_with_page_number: + return toc_with_page_number + + max_allowed_page = page_list_length + start_index - 1 + truncated_items = [] + + for i, item in enumerate(toc_with_page_number): + if item.get('physical_index') is not None: + original_index = item['physical_index'] + if original_index > max_allowed_page: + item['physical_index'] = None + truncated_items.append({ + 'title': item.get('title', 'Unknown'), + 'original_index': original_index + }) + if logger: + logger.info(f"Removed physical_index for '{item.get('title', 'Unknown')}' (was {original_index}, too far beyond document)") + + if truncated_items and logger: + logger.info(f"Total removed items: {len(truncated_items)}") + + print(f"Document validation: {page_list_length} pages, max allowed index: {max_allowed_page}") + if truncated_items: + print(f"Truncated {len(truncated_items)} TOC items that exceeded document length") + + return toc_with_page_number \ No newline at end of file diff --git a/pageindex/index/page_index_md.py b/pageindex/index/page_index_md.py new file mode 100644 index 000000000..e6078c26f --- /dev/null +++ b/pageindex/index/page_index_md.py @@ -0,0 +1,341 @@ +import asyncio +import json +import re +import os +try: + from 
.legacy_utils import * +except: + from legacy_utils import * + +async def get_node_summary(node, summary_token_threshold=200, model=None): + node_text = node.get('text') + num_tokens = count_tokens(node_text, model=model) + if num_tokens < summary_token_threshold: + return node_text + else: + return await generate_node_summary(node, model=model) + + +async def generate_summaries_for_structure_md(structure, summary_token_threshold, model=None): + nodes = structure_to_list(structure) + tasks = [get_node_summary(node, summary_token_threshold=summary_token_threshold, model=model) for node in nodes] + summaries = await asyncio.gather(*tasks) + + for node, summary in zip(nodes, summaries): + if not node.get('nodes'): + node['summary'] = summary + else: + node['prefix_summary'] = summary + return structure + + +def extract_nodes_from_markdown(markdown_content): + header_pattern = r'^(#{1,6})\s+(.+)$' + code_block_pattern = r'^```' + node_list = [] + + lines = markdown_content.split('\n') + in_code_block = False + + for line_num, line in enumerate(lines, 1): + stripped_line = line.strip() + + # Check for code block delimiters (triple backticks) + if re.match(code_block_pattern, stripped_line): + in_code_block = not in_code_block + continue + + # Skip empty lines + if not stripped_line: + continue + + # Only look for headers when not inside a code block + if not in_code_block: + match = re.match(header_pattern, stripped_line) + if match: + title = match.group(2).strip() + node_list.append({'node_title': title, 'line_num': line_num}) + + return node_list, lines + + +def extract_node_text_content(node_list, markdown_lines): + all_nodes = [] + for node in node_list: + line_content = markdown_lines[node['line_num'] - 1] + header_match = re.match(r'^(#{1,6})', line_content) + + if header_match is None: + print(f"Warning: Line {node['line_num']} does not contain a valid header: '{line_content}'") + continue + + processed_node = { + 'title': node['node_title'], + 'line_num': 
node['line_num'], + 'level': len(header_match.group(1)) + } + all_nodes.append(processed_node) + + for i, node in enumerate(all_nodes): + start_line = node['line_num'] - 1 + if i + 1 < len(all_nodes): + end_line = all_nodes[i + 1]['line_num'] - 1 + else: + end_line = len(markdown_lines) + + node['text'] = '\n'.join(markdown_lines[start_line:end_line]).strip() + return all_nodes + +def update_node_list_with_text_token_count(node_list, model=None): + + def find_all_children(parent_index, parent_level, node_list): + """Find all direct and indirect children of a parent node""" + children_indices = [] + + # Look for children after the parent + for i in range(parent_index + 1, len(node_list)): + current_level = node_list[i]['level'] + + # If we hit a node at same or higher level than parent, stop + if current_level <= parent_level: + break + + # This is a descendant + children_indices.append(i) + + return children_indices + + # Make a copy to avoid modifying the original + result_list = node_list.copy() + + # Process nodes from end to beginning to ensure children are processed before parents + for i in range(len(result_list) - 1, -1, -1): + current_node = result_list[i] + current_level = current_node['level'] + + # Get all children of this node + children_indices = find_all_children(i, current_level, result_list) + + # Start with the node's own text + node_text = current_node.get('text', '') + total_text = node_text + + # Add all children's text + for child_index in children_indices: + child_text = result_list[child_index].get('text', '') + if child_text: + total_text += '\n' + child_text + + # Calculate token count for combined text + result_list[i]['text_token_count'] = count_tokens(total_text, model=model) + + return result_list + + +def tree_thinning_for_index(node_list, min_node_token=None, model=None): + def find_all_children(parent_index, parent_level, node_list): + children_indices = [] + + for i in range(parent_index + 1, len(node_list)): + current_level = 
node_list[i]['level'] + + if current_level <= parent_level: + break + + children_indices.append(i) + + return children_indices + + result_list = node_list.copy() + nodes_to_remove = set() + + for i in range(len(result_list) - 1, -1, -1): + if i in nodes_to_remove: + continue + + current_node = result_list[i] + current_level = current_node['level'] + + total_tokens = current_node.get('text_token_count', 0) + + if total_tokens < min_node_token: + children_indices = find_all_children(i, current_level, result_list) + + children_texts = [] + for child_index in sorted(children_indices): + if child_index not in nodes_to_remove: + child_text = result_list[child_index].get('text', '') + if child_text.strip(): + children_texts.append(child_text) + nodes_to_remove.add(child_index) + + if children_texts: + parent_text = current_node.get('text', '') + merged_text = parent_text + for child_text in children_texts: + if merged_text and not merged_text.endswith('\n'): + merged_text += '\n\n' + merged_text += child_text + + result_list[i]['text'] = merged_text + + result_list[i]['text_token_count'] = count_tokens(merged_text, model=model) + + for index in sorted(nodes_to_remove, reverse=True): + result_list.pop(index) + + return result_list + + +def build_tree_from_nodes(node_list): + if not node_list: + return [] + + stack = [] + root_nodes = [] + node_counter = 1 + + for node in node_list: + current_level = node['level'] + + tree_node = { + 'title': node['title'], + 'node_id': str(node_counter).zfill(4), + 'text': node['text'], + 'line_num': node['line_num'], + 'nodes': [] + } + node_counter += 1 + + while stack and stack[-1][1] >= current_level: + stack.pop() + + if not stack: + root_nodes.append(tree_node) + else: + parent_node, parent_level = stack[-1] + parent_node['nodes'].append(tree_node) + + stack.append((tree_node, current_level)) + + return root_nodes + + +def clean_tree_for_output(tree_nodes): + cleaned_nodes = [] + + for node in tree_nodes: + cleaned_node = { + 
'title': node['title'], + 'node_id': node['node_id'], + 'text': node['text'], + 'line_num': node['line_num'] + } + + if node['nodes']: + cleaned_node['nodes'] = clean_tree_for_output(node['nodes']) + + cleaned_nodes.append(cleaned_node) + + return cleaned_nodes + + +async def md_to_tree(md_path, if_thinning=False, min_token_threshold=None, if_add_node_summary=False, summary_token_threshold=None, model=None, if_add_doc_description=False, if_add_node_text=False, if_add_node_id=True): + with open(md_path, 'r', encoding='utf-8') as f: + markdown_content = f.read() + line_count = markdown_content.count('\n') + 1 + + print(f"Extracting nodes from markdown...") + node_list, markdown_lines = extract_nodes_from_markdown(markdown_content) + + print(f"Extracting text content from nodes...") + nodes_with_content = extract_node_text_content(node_list, markdown_lines) + + if if_thinning: + nodes_with_content = update_node_list_with_text_token_count(nodes_with_content, model=model) + print(f"Thinning nodes...") + nodes_with_content = tree_thinning_for_index(nodes_with_content, min_token_threshold, model=model) + + print(f"Building tree from nodes...") + tree_structure = build_tree_from_nodes(nodes_with_content) + + if if_add_node_id: + write_node_id(tree_structure) + + print(f"Formatting tree structure...") + + if if_add_node_summary: + # Always include text for summary generation + tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes']) + + print(f"Generating summaries for each node...") + tree_structure = await generate_summaries_for_structure_md(tree_structure, summary_token_threshold=summary_token_threshold, model=model) + + if not if_add_node_text: + # Remove text after summary generation if not requested + tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes']) + + if if_add_doc_description: + print(f"Generating document 
description...") + clean_structure = create_clean_structure_for_description(tree_structure) + doc_description = generate_doc_description(clean_structure, model=model) + return { + 'doc_name': os.path.splitext(os.path.basename(md_path))[0], + 'doc_description': doc_description, + 'line_count': line_count, + 'structure': tree_structure, + } + else: + # No summaries needed, format based on text preference + if if_add_node_text: + tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes']) + else: + tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes']) + + return { + 'doc_name': os.path.splitext(os.path.basename(md_path))[0], + 'line_count': line_count, + 'structure': tree_structure, + } + + +if __name__ == "__main__": + import os + import json + + # MD_NAME = 'Detect-Order-Construct' + MD_NAME = 'cognitive-load' + MD_PATH = os.path.join(os.path.dirname(__file__), '..', 'examples/documents/', f'{MD_NAME}.md') + + + MODEL="gpt-4.1" + IF_THINNING=False + THINNING_THRESHOLD=5000 + SUMMARY_TOKEN_THRESHOLD=200 + IF_SUMMARY=True + + tree_structure = asyncio.run(md_to_tree( + md_path=MD_PATH, + if_thinning=IF_THINNING, + min_token_threshold=THINNING_THRESHOLD, + if_add_node_summary=IF_SUMMARY, + summary_token_threshold=SUMMARY_TOKEN_THRESHOLD, + model=MODEL)) + + print('\n' + '='*60) + print('TREE STRUCTURE') + print('='*60) + print_json(tree_structure) + + print('\n' + '='*60) + print('TABLE OF CONTENTS') + print('='*60) + print_toc(tree_structure['structure']) + + output_path = os.path.join(os.path.dirname(__file__), '..', 'results', f'{MD_NAME}_structure.json') + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(tree_structure, f, indent=2, ensure_ascii=False) + + print(f"\nTree structure saved to: {output_path}") \ No newline at
end of file diff --git a/pageindex/index/pipeline.py b/pageindex/index/pipeline.py new file mode 100644 index 000000000..70d8c2fe6 --- /dev/null +++ b/pageindex/index/pipeline.py @@ -0,0 +1,122 @@ +# pageindex/index/pipeline.py +from __future__ import annotations +from ..parser.protocol import ContentNode, ParsedDocument + + +def detect_strategy(nodes: list[ContentNode]) -> str: + """Determine which indexing strategy to use based on node data.""" + if any(n.level is not None for n in nodes): + return "level_based" + return "content_based" + + +def build_tree_from_levels(nodes: list[ContentNode]) -> list[dict]: + """Strategy 0: Build tree from explicit level information. + Adapted from pageindex/page_index_md.py:build_tree_from_nodes.""" + stack = [] + root_nodes = [] + + for node in nodes: + tree_node = { + "title": node.title or "", + "text": node.content, + "line_num": node.index, + "nodes": [], + } + current_level = node.level or 1 + + while stack and stack[-1][1] >= current_level: + stack.pop() + + if not stack: + root_nodes.append(tree_node) + else: + parent_node, _ = stack[-1] + parent_node["nodes"].append(tree_node) + + stack.append((tree_node, current_level)) + + return root_nodes + + +def _run_async(coro): + """Run an async coroutine, handling the case where an event loop is already running.""" + import asyncio + import concurrent.futures + try: + asyncio.get_running_loop() + # Already inside an event loop -- run in a separate thread + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + return pool.submit(asyncio.run, coro).result() + except RuntimeError: + return asyncio.run(coro) + + +def build_index(parsed: ParsedDocument, model: str = None, opt=None) -> dict: + """Main entry point: ParsedDocument -> tree structure dict. 
+ Routes to the appropriate strategy and runs enhancement.""" + from .utils import (write_node_id, add_node_text, remove_structure_text, + generate_summaries_for_structure, generate_doc_description, + create_clean_structure_for_description) + from ..config import IndexConfig + + if opt is None: + opt = IndexConfig(model=model) if model else IndexConfig() + + nodes = parsed.nodes + strategy = detect_strategy(nodes) + + if strategy == "level_based": + structure = build_tree_from_levels(nodes) + # For level-based, text is already in the tree nodes + else: + # Strategies 1-3: convert ContentNode list to page_list format for existing pipeline + page_list = [(n.content, n.tokens) for n in nodes] + structure = _run_async(_content_based_pipeline(page_list, opt)) + + # Unified enhancement + if opt.if_add_node_id: + write_node_id(structure) + + if strategy != "level_based": + if opt.if_add_node_text or opt.if_add_node_summary: + add_node_text(structure, page_list) + + if opt.if_add_node_summary: + _run_async(generate_summaries_for_structure(structure, model=opt.model)) + + if not opt.if_add_node_text and strategy != "level_based": + remove_structure_text(structure) + + result = { + "doc_name": parsed.doc_name, + "structure": structure, + } + + if opt.if_add_doc_description: + clean_structure = create_clean_structure_for_description(structure) + result["doc_description"] = generate_doc_description( + clean_structure, model=opt.model + ) + + return result + + +class _NullLogger: + """Minimal logger that satisfies the tree_parser interface without writing files.""" + def info(self, message, **kwargs): pass + def error(self, message, **kwargs): pass + def debug(self, message, **kwargs): pass + + +async def _content_based_pipeline(page_list, opt): + """Strategies 1-3: delegates to the existing PDF pipeline from pageindex/page_index.py. + + The page_list is already in the format expected by tree_parser: + [(page_text, token_count), ...] 
+ """ + from .page_index import tree_parser + + logger = _NullLogger() + structure = await tree_parser(page_list, opt, doc=None, logger=logger) + return structure diff --git a/pageindex/index/utils.py b/pageindex/index/utils.py new file mode 100644 index 000000000..f416d6d3d --- /dev/null +++ b/pageindex/index/utils.py @@ -0,0 +1,431 @@ +import litellm +import logging +import time +import json +import copy +import re +import asyncio +import PyPDF2 + +logger = logging.getLogger(__name__) + + +def count_tokens(text, model=None): + if not text: + return 0 + return litellm.token_counter(model=model, text=text) + + +def llm_completion(model, prompt, chat_history=None, return_finish_reason=False): + if model: + model = model.removeprefix("litellm/") + max_retries = 10 + messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}] + for i in range(max_retries): + try: + litellm.drop_params = True + response = litellm.completion( + model=model, + messages=messages, + temperature=0, + ) + content = response.choices[0].message.content + if return_finish_reason: + finish_reason = "max_output_reached" if response.choices[0].finish_reason == "length" else "finished" + return content, finish_reason + return content + except Exception as e: + logger.warning("Retrying LLM completion (%d/%d)", i + 1, max_retries) + logger.error(f"Error: {e}") + if i < max_retries - 1: + time.sleep(1) + else: + logger.error('Max retries reached for prompt: ' + prompt) + raise RuntimeError(f"LLM call failed after {max_retries} retries") from e + + + +async def llm_acompletion(model, prompt): + if model: + model = model.removeprefix("litellm/") + max_retries = 10 + messages = [{"role": "user", "content": prompt}] + for i in range(max_retries): + try: + litellm.drop_params = True + response = await litellm.acompletion( + model=model, + messages=messages, + temperature=0, + ) + return response.choices[0].message.content + except 
Exception as e: + logger.warning("Retrying async LLM completion (%d/%d)", i + 1, max_retries) + logger.error(f"Error: {e}") + if i < max_retries - 1: + await asyncio.sleep(1) + else: + logger.error('Max retries reached for prompt: ' + prompt) + raise RuntimeError(f"LLM call failed after {max_retries} retries") from e + + +def extract_json(content): + try: + # First, try to extract JSON enclosed within ```json and ``` + start_idx = content.find("```json") + if start_idx != -1: + start_idx += 7 # Adjust index to start after the delimiter + end_idx = content.rfind("```") + json_content = content[start_idx:end_idx].strip() + else: + # If no delimiters, assume entire content could be JSON + json_content = content.strip() + + # Clean up common issues that might cause parsing errors + json_content = json_content.replace('None', 'null') # Replace Python None with JSON null + json_content = json_content.replace('\n', ' ').replace('\r', ' ') # Remove newlines + json_content = ' '.join(json_content.split()) # Normalize whitespace + + # Attempt to parse and return the JSON object + return json.loads(json_content) + except json.JSONDecodeError as e: + logging.error(f"Failed to extract JSON: {e}") + # Try to clean up the content further if initial parsing fails + try: + # Remove any trailing commas before closing brackets/braces + json_content = json_content.replace(',]', ']').replace(',}', '}') + return json.loads(json_content) + except Exception: + logging.error("Failed to parse JSON even after cleanup") + return {} + except Exception as e: + logging.error(f"Unexpected error while extracting JSON: {e}") + return {} + + +def get_json_content(response): + start_idx = response.find("```json") + if start_idx != -1: + start_idx += 7 + response = response[start_idx:] + + end_idx = response.rfind("```") + if end_idx != -1: + response = response[:end_idx] + + json_content = response.strip() + return json_content + + +def write_node_id(data, node_id=0): + if isinstance(data, dict): + 
data['node_id'] = str(node_id).zfill(4) + node_id += 1 + for key in list(data.keys()): + if 'nodes' in key: + node_id = write_node_id(data[key], node_id) + elif isinstance(data, list): + for index in range(len(data)): + node_id = write_node_id(data[index], node_id) + return node_id + + +def remove_fields(data, fields=None): + fields = fields or ["text"] + if isinstance(data, dict): + return {k: remove_fields(v, fields) + for k, v in data.items() if k not in fields} + elif isinstance(data, list): + return [remove_fields(item, fields) for item in data] + return data + + +def structure_to_list(structure): + if isinstance(structure, dict): + nodes = [] + nodes.append(structure) + if 'nodes' in structure: + nodes.extend(structure_to_list(structure['nodes'])) + return nodes + elif isinstance(structure, list): + nodes = [] + for item in structure: + nodes.extend(structure_to_list(item)) + return nodes + + +def get_nodes(structure): + if isinstance(structure, dict): + structure_node = copy.deepcopy(structure) + structure_node.pop('nodes', None) + nodes = [structure_node] + for key in list(structure.keys()): + if 'nodes' in key: + nodes.extend(get_nodes(structure[key])) + return nodes + elif isinstance(structure, list): + nodes = [] + for item in structure: + nodes.extend(get_nodes(item)) + return nodes + + +def get_leaf_nodes(structure): + if isinstance(structure, dict): + if not structure['nodes']: + structure_node = copy.deepcopy(structure) + structure_node.pop('nodes', None) + return [structure_node] + else: + leaf_nodes = [] + for key in list(structure.keys()): + if 'nodes' in key: + leaf_nodes.extend(get_leaf_nodes(structure[key])) + return leaf_nodes + elif isinstance(structure, list): + leaf_nodes = [] + for item in structure: + leaf_nodes.extend(get_leaf_nodes(item)) + return leaf_nodes + + +async def generate_node_summary(node, model=None): + prompt = f"""You are given a part of a document, your task is to generate a description of the partial document about what 
are main points covered in the partial document. + + Partial Document Text: {node['text']} + + Directly return the description, do not include any other text. + """ + response = await llm_acompletion(model, prompt) + return response + + +async def generate_summaries_for_structure(structure, model=None): + nodes = structure_to_list(structure) + tasks = [generate_node_summary(node, model=model) for node in nodes] + summaries = await asyncio.gather(*tasks) + + for node, summary in zip(nodes, summaries): + node['summary'] = summary + return structure + + +def generate_doc_description(structure, model=None): + prompt = f"""You are an expert in generating descriptions for a document. + You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents. + + Document Structure: {structure} + + Directly return the description, do not include any other text. + """ + response = llm_completion(model, prompt) + return response + + +def list_to_tree(data): + def get_parent_structure(structure): + """Helper function to get the parent structure code""" + if not structure: + return None + parts = str(structure).split('.') + return '.'.join(parts[:-1]) if len(parts) > 1 else None + + # First pass: Create nodes and track parent-child relationships + nodes = {} + root_nodes = [] + + for item in data: + structure = item.get('structure') + node = { + 'title': item.get('title'), + 'start_index': item.get('start_index'), + 'end_index': item.get('end_index'), + 'nodes': [] + } + + nodes[structure] = node + + # Find parent + parent_structure = get_parent_structure(structure) + + if parent_structure: + # Add as child to parent if parent exists + if parent_structure in nodes: + nodes[parent_structure]['nodes'].append(node) + else: + root_nodes.append(node) + else: + # No parent, this is a root node + root_nodes.append(node) + + # Helper function to clean empty children arrays + 
def clean_node(node): + if not node['nodes']: + del node['nodes'] + else: + for child in node['nodes']: + clean_node(child) + return node + + # Clean and return the tree + return [clean_node(node) for node in root_nodes] + + +def post_processing(structure, end_physical_index): + # First convert page_number to start_index in flat list + for i, item in enumerate(structure): + item['start_index'] = item.get('physical_index') + if i < len(structure) - 1: + if structure[i + 1].get('appear_start') == 'yes': + item['end_index'] = structure[i + 1]['physical_index']-1 + else: + item['end_index'] = structure[i + 1]['physical_index'] + else: + item['end_index'] = end_physical_index + tree = list_to_tree(structure) + if len(tree)!=0: + return tree + else: + ### remove appear_start + for node in structure: + node.pop('appear_start', None) + node.pop('physical_index', None) + return structure + + +def reorder_dict(data, key_order): + if not key_order: + return data + return {key: data[key] for key in key_order if key in data} + + +def format_structure(structure, order=None): + if not order: + return structure + if isinstance(structure, dict): + if 'nodes' in structure: + structure['nodes'] = format_structure(structure['nodes'], order) + if not structure.get('nodes'): + structure.pop('nodes', None) + structure = reorder_dict(structure, order) + elif isinstance(structure, list): + structure = [format_structure(item, order) for item in structure] + return structure + + +def create_clean_structure_for_description(structure): + """ + Create a clean structure for document description generation, + excluding unnecessary fields like 'text'. 
+ """ + if isinstance(structure, dict): + clean_node = {} + # Only include essential fields for description + for key in ['title', 'node_id', 'summary', 'prefix_summary']: + if key in structure: + clean_node[key] = structure[key] + + # Recursively process child nodes + if 'nodes' in structure and structure['nodes']: + clean_node['nodes'] = create_clean_structure_for_description(structure['nodes']) + + return clean_node + elif isinstance(structure, list): + return [create_clean_structure_for_description(item) for item in structure] + else: + return structure + + +def _get_text_of_pages(page_list, start_page, end_page): + """Concatenate text from page_list for pages [start_page, end_page] (1-indexed).""" + text = "" + for page_num in range(start_page - 1, end_page): + text += page_list[page_num][0] + return text + + +def add_node_text(node, page_list): + """Recursively add 'text' field to each node from page_list content. + + Each node must have 'start_index' and 'end_index' (1-indexed page numbers). + page_list is [(page_text, token_count), ...]. 
+ """ + if isinstance(node, dict): + start_page = node.get('start_index') + end_page = node.get('end_index') + if start_page is not None and end_page is not None: + node['text'] = _get_text_of_pages(page_list, start_page, end_page) + if 'nodes' in node: + add_node_text(node['nodes'], page_list) + elif isinstance(node, list): + for item in node: + add_node_text(item, page_list) + + +def remove_structure_text(data): + if isinstance(data, dict): + data.pop('text', None) + if 'nodes' in data: + remove_structure_text(data['nodes']) + elif isinstance(data, list): + for item in data: + remove_structure_text(item) + return data + + +# ── Functions migrated from retrieve.py ────────────────────────────────────── + +def parse_pages(pages: str) -> list[int]: + """Parse a pages string like '5-7', '3,8', or '12' into a sorted list of ints.""" + result = [] + for part in pages.split(','): + part = part.strip() + if '-' in part: + start, end = int(part.split('-', 1)[0].strip()), int(part.split('-', 1)[1].strip()) + if start > end: + raise ValueError(f"Invalid range '{part}': start must be <= end") + result.extend(range(start, end + 1)) + else: + result.append(int(part)) + result = [p for p in result if p >= 1] + result = sorted(set(result)) + if len(result) > 1000: + raise ValueError(f"Page range too large: {len(result)} pages (max 1000)") + return result + + +def get_pdf_page_content(file_path: str, page_nums: list[int]) -> list[dict]: + """Extract text for specific PDF pages (1-indexed), opening the PDF once.""" + with open(file_path, 'rb') as f: + pdf_reader = PyPDF2.PdfReader(f) + total = len(pdf_reader.pages) + valid_pages = [p for p in page_nums if 1 <= p <= total] + return [ + {'page': p, 'content': pdf_reader.pages[p - 1].extract_text() or ''} + for p in valid_pages + ] + + +def get_md_page_content(structure: list, page_nums: list[int]) -> list[dict]: + """ + For Markdown documents, 'pages' are line numbers. 
+ Find nodes whose line_num falls within [min(page_nums), max(page_nums)] and return their text. + """ + if not page_nums: + return [] + min_line, max_line = min(page_nums), max(page_nums) + results = [] + seen = set() + + def _traverse(nodes): + for node in nodes: + ln = node.get('line_num') + if ln and min_line <= ln <= max_line and ln not in seen: + seen.add(ln) + results.append({'page': ln, 'content': node.get('text', '')}) + if node.get('nodes'): + _traverse(node['nodes']) + + _traverse(structure) + results.sort(key=lambda x: x['page']) + return results diff --git a/pageindex/page_index.py b/pageindex/page_index.py index 9004309fb..fab228345 100644 --- a/pageindex/page_index.py +++ b/pageindex/page_index.py @@ -1113,11 +1113,12 @@ async def page_index_builder(): def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None, if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None, if_add_node_text=None): + from .config import IndexConfig user_opt = { arg: value for arg, value in locals().items() if arg != "doc" and value is not None } - opt = ConfigLoader().load(user_opt) + opt = IndexConfig(**user_opt) return page_index_main(doc, opt) diff --git a/pageindex/parser/__init__.py b/pageindex/parser/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pageindex/parser/markdown.py b/pageindex/parser/markdown.py new file mode 100644 index 000000000..f62013c4c --- /dev/null +++ b/pageindex/parser/markdown.py @@ -0,0 +1,59 @@ +import re +from pathlib import Path +from .protocol import ContentNode, ParsedDocument +from ..index.utils import count_tokens + + +class MarkdownParser: + def supported_extensions(self) -> list[str]: + return [".md", ".markdown"] + + def parse(self, file_path: str, **kwargs) -> ParsedDocument: + path = Path(file_path) + model = kwargs.get("model") + + with open(path, "r", encoding="utf-8") as f: + content = f.read() + + lines = content.split("\n") + 
headers = self._extract_headers(lines) + nodes = self._build_nodes(headers, lines, model) + + return ParsedDocument(doc_name=path.stem, nodes=nodes) + + def _extract_headers(self, lines: list[str]) -> list[dict]: + header_pattern = r"^(#{1,6})\s+(.+)$" + code_block_pattern = r"^```" + headers = [] + in_code_block = False + + for line_num, line in enumerate(lines, 1): + stripped = line.strip() + if re.match(code_block_pattern, stripped): + in_code_block = not in_code_block + continue + if not in_code_block and stripped: + match = re.match(header_pattern, stripped) + if match: + headers.append({ + "title": match.group(2).strip(), + "level": len(match.group(1)), + "line_num": line_num, + }) + return headers + + def _build_nodes(self, headers: list[dict], lines: list[str], model: str | None) -> list[ContentNode]: + nodes = [] + for i, header in enumerate(headers): + start = header["line_num"] - 1 + end = headers[i + 1]["line_num"] - 1 if i + 1 < len(headers) else len(lines) + text = "\n".join(lines[start:end]).strip() + tokens = count_tokens(text, model=model) + nodes.append(ContentNode( + content=text, + tokens=tokens, + title=header["title"], + index=header["line_num"], + level=header["level"], + )) + return nodes diff --git a/pageindex/parser/pdf.py b/pageindex/parser/pdf.py new file mode 100644 index 000000000..14e7f833e --- /dev/null +++ b/pageindex/parser/pdf.py @@ -0,0 +1,101 @@ +import pymupdf +from pathlib import Path +from .protocol import ContentNode, ParsedDocument +from ..index.utils import count_tokens + +# Minimum image dimension to keep (skip icons/artifacts) +_MIN_IMAGE_SIZE = 32 + + +class PdfParser: + def supported_extensions(self) -> list[str]: + return [".pdf"] + + def parse(self, file_path: str, **kwargs) -> ParsedDocument: + path = Path(file_path) + model = kwargs.get("model") + images_dir = kwargs.get("images_dir") + nodes = [] + + with pymupdf.open(str(path)) as doc: + for i, page in enumerate(doc): + page_num = i + 1 + if images_dir: + 
content, images = self._extract_page_with_images( + doc, page, page_num, images_dir) + else: + content = page.get_text() + images = None + + tokens = count_tokens(content, model=model) + nodes.append(ContentNode( + content=content or "", + tokens=tokens, + index=page_num, + images=images if images else None, + )) + + return ParsedDocument(doc_name=path.stem, nodes=nodes) + + @staticmethod + def _extract_page_with_images(doc, page, page_num: int, + images_dir: str) -> tuple[str, list[dict]]: + """Extract text and images from a page, preserving their relative order. + + Uses get_text("dict") to iterate blocks in reading order. + Text blocks become text; image blocks are saved to disk and replaced + with an inline placeholder: ![image](path) + """ + images_path = Path(images_dir) + images_path.mkdir(parents=True, exist_ok=True) + # Use path relative to cwd so downstream consumers can access directly + try: + rel_images_path = images_path.relative_to(Path.cwd()) + except ValueError: + rel_images_path = images_path + + parts: list[str] = [] + images: list[dict] = [] + img_idx = 0 + + for block in page.get_text("dict")["blocks"]: + if block["type"] == 0: # text block + lines = [] + for line in block["lines"]: + spans_text = "".join(span["text"] for span in line["spans"]) + lines.append(spans_text) + parts.append("\n".join(lines)) + + elif block["type"] == 1: # image block + width = block.get("width", 0) + height = block.get("height", 0) + if width < _MIN_IMAGE_SIZE or height < _MIN_IMAGE_SIZE: + continue + + image_bytes = block.get("image") + ext = block.get("ext", "png") + if not image_bytes: + continue + + try: + pix = pymupdf.Pixmap(image_bytes) + if pix.n > 4: + pix = pymupdf.Pixmap(pymupdf.csRGB, pix) + filename = f"p{page_num}_img{img_idx}.png" + save_path = images_path / filename + pix.save(str(save_path)) + pix = None + except Exception: + continue + + rel_path = str(rel_images_path / filename) + images.append({ + "path": rel_path, + "width": width, + "height": 
height, + }) + parts.append(f"![image]({rel_path})") + img_idx += 1 + + content = "\n".join(parts) + return content, images diff --git a/pageindex/parser/protocol.py b/pageindex/parser/protocol.py new file mode 100644 index 000000000..76d7b0a78 --- /dev/null +++ b/pageindex/parser/protocol.py @@ -0,0 +1,28 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Protocol, runtime_checkable + + +@dataclass +class ContentNode: + """Universal content unit produced by parsers.""" + content: str + tokens: int + title: str | None = None + index: int | None = None + level: int | None = None + images: list[dict] | None = None # [{"path": str, "width": int, "height": int}, ...] + + +@dataclass +class ParsedDocument: + """Unified parser output. Always a flat list of ContentNode.""" + doc_name: str + nodes: list[ContentNode] + metadata: dict | None = None + + +@runtime_checkable +class DocumentParser(Protocol): + def supported_extensions(self) -> list[str]: ... + def parse(self, file_path: str, **kwargs) -> ParsedDocument: ... diff --git a/pageindex/storage/__init__.py b/pageindex/storage/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pageindex/storage/protocol.py b/pageindex/storage/protocol.py new file mode 100644 index 000000000..427021b2d --- /dev/null +++ b/pageindex/storage/protocol.py @@ -0,0 +1,18 @@ +from __future__ import annotations +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class StorageEngine(Protocol): + def create_collection(self, name: str) -> None: ... + def get_or_create_collection(self, name: str) -> None: ... + def list_collections(self) -> list[str]: ... + def delete_collection(self, name: str) -> None: ... + def save_document(self, collection: str, doc_id: str, doc: dict) -> None: ... + def find_document_by_hash(self, collection: str, file_hash: str) -> str | None: ... + def get_document(self, collection: str, doc_id: str) -> dict: ... 
+ def get_document_structure(self, collection: str, doc_id: str) -> list: ... + def get_pages(self, collection: str, doc_id: str) -> list | None: ... + def list_documents(self, collection: str) -> list[dict]: ... + def delete_document(self, collection: str, doc_id: str) -> None: ... + def close(self) -> None: ... diff --git a/pageindex/storage/sqlite.py b/pageindex/storage/sqlite.py new file mode 100644 index 000000000..86e71cc8a --- /dev/null +++ b/pageindex/storage/sqlite.py @@ -0,0 +1,164 @@ +import json +import sqlite3 +import threading +from pathlib import Path + + +class SQLiteStorage: + def __init__(self, db_path: str): + self._db_path = Path(db_path).expanduser() + self._db_path.parent.mkdir(parents=True, exist_ok=True) + self._local = threading.local() + self._connections: list[sqlite3.Connection] = [] + self._conn_lock = threading.Lock() + self._init_schema() + + def _get_conn(self) -> sqlite3.Connection: + """Return a thread-local SQLite connection.""" + if not hasattr(self._local, "conn"): + conn = sqlite3.connect(str(self._db_path)) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + self._local.conn = conn + with self._conn_lock: + self._connections.append(conn) + return self._local.conn + + def _init_schema(self): + conn = self._get_conn() + conn.execute("PRAGMA user_version = 1") + conn.executescript(""" + CREATE TABLE IF NOT EXISTS collections ( + name TEXT PRIMARY KEY CHECK(length(name) <= 128 AND name GLOB '[a-zA-Z0-9_-]*'), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + CREATE TABLE IF NOT EXISTS documents ( + doc_id TEXT PRIMARY KEY, + collection_name TEXT NOT NULL REFERENCES collections(name) ON DELETE CASCADE, + doc_name TEXT, + doc_description TEXT, + file_path TEXT, + file_hash TEXT, + doc_type TEXT NOT NULL, + structure JSON, + pages JSON, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS idx_docs_collection ON documents(collection_name); + CREATE INDEX IF NOT 
EXISTS idx_docs_hash ON documents(collection_name, file_hash); + """) + conn.commit() + + def create_collection(self, name: str) -> None: + conn = self._get_conn() + conn.execute("INSERT INTO collections (name) VALUES (?)", (name,)) + conn.commit() + + def get_or_create_collection(self, name: str) -> None: + conn = self._get_conn() + conn.execute("INSERT OR IGNORE INTO collections (name) VALUES (?)", (name,)) + conn.commit() + + def list_collections(self) -> list[str]: + conn = self._get_conn() + rows = conn.execute("SELECT name FROM collections ORDER BY name").fetchall() + return [r[0] for r in rows] + + def delete_collection(self, name: str) -> None: + conn = self._get_conn() + conn.execute("DELETE FROM collections WHERE name = ?", (name,)) + conn.commit() + + def save_document(self, collection: str, doc_id: str, doc: dict) -> None: + conn = self._get_conn() + conn.execute( + """INSERT OR REPLACE INTO documents + (doc_id, collection_name, doc_name, doc_description, file_path, file_hash, doc_type, structure, pages) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + (doc_id, collection, doc.get("doc_name"), doc.get("doc_description"), + doc.get("file_path"), doc.get("file_hash"), doc["doc_type"], + json.dumps(doc.get("structure", [])), + json.dumps(doc.get("pages")) if doc.get("pages") else None), + ) + conn.commit() + + def find_document_by_hash(self, collection: str, file_hash: str) -> str | None: + conn = self._get_conn() + row = conn.execute( + "SELECT doc_id FROM documents WHERE collection_name = ? AND file_hash = ?", + (collection, file_hash), + ).fetchone() + return row[0] if row else None + + def get_document(self, collection: str, doc_id: str) -> dict: + conn = self._get_conn() + row = conn.execute( + "SELECT doc_id, doc_name, doc_description, file_path, doc_type FROM documents WHERE doc_id = ? 
AND collection_name = ?", + (doc_id, collection), + ).fetchone() + if not row: + return {} + return {"doc_id": row[0], "doc_name": row[1], "doc_description": row[2], + "file_path": row[3], "doc_type": row[4]} + + def get_document_structure(self, collection: str, doc_id: str) -> list: + conn = self._get_conn() + row = conn.execute( + "SELECT structure FROM documents WHERE doc_id = ? AND collection_name = ?", + (doc_id, collection), + ).fetchone() + if not row: + return [] + return json.loads(row[0]) + + def get_pages(self, collection: str, doc_id: str) -> list | None: + """Return cached page content, or None if not cached.""" + conn = self._get_conn() + row = conn.execute( + "SELECT pages FROM documents WHERE doc_id = ? AND collection_name = ?", + (doc_id, collection), + ).fetchone() + if not row or not row[0]: + return None + return json.loads(row[0]) + + def list_documents(self, collection: str) -> list[dict]: + conn = self._get_conn() + rows = conn.execute( + "SELECT doc_id, doc_name, doc_type FROM documents WHERE collection_name = ? ORDER BY created_at", + (collection,), + ).fetchall() + return [{"doc_id": r[0], "doc_name": r[1], "doc_type": r[2]} for r in rows] + + def delete_document(self, collection: str, doc_id: str) -> None: + conn = self._get_conn() + conn.execute( + "DELETE FROM documents WHERE doc_id = ? 
AND collection_name = ?", + (doc_id, collection), + ) + conn.commit() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + def close(self) -> None: + """Close all tracked SQLite connections across all threads.""" + with self._conn_lock: + for conn in self._connections: + try: + conn.close() + except Exception: + pass + self._connections.clear() + if hasattr(self._local, "conn"): + del self._local.conn + + def __del__(self): + try: + self.close() + except Exception: + pass diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..8927acc18 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,48 @@ +[build-system] +requires = ["setuptools>=68.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "pageindex" +version = "0.3.0" +description = "Python SDK for PageIndex" +readme = "README.md" +license = {text = "MIT"} +requires-python = ">=3.10" +authors = [ + {name = "Ray", email = "ray@vectify.ai"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +keywords = ["rag", "document", "retrieval", "llm", "pageindex"] +dependencies = [ + "litellm>=1.82.0", + "pymupdf>=1.26.0", + "PyPDF2>=3.0.0", + "python-dotenv>=1.0.0", + "pyyaml>=6.0", + "openai-agents>=0.1.0", + "requests>=2.28.0", + "httpx[socks]>=0.28.1", +] + +[project.optional-dependencies] +dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] + +[project.urls] +Homepage = "https://github.com/VectifyAI/PageIndex" +Documentation = "https://docs.pageindex.ai" +Repository = "https://github.com/VectifyAI/PageIndex" +Issues = 
"https://github.com/VectifyAI/PageIndex/issues" + +[tool.setuptools.packages.find] +include = ["pageindex*"] diff --git a/run_pageindex.py b/run_pageindex.py index 673439d89..a2d4c3185 100644 --- a/run_pageindex.py +++ b/run_pageindex.py @@ -1,9 +1,9 @@ import argparse import os import json -from pageindex import * -from pageindex.page_index_md import md_to_tree -from pageindex.utils import ConfigLoader +from pageindex.index.page_index import * +from pageindex.index.page_index_md import md_to_tree +from pageindex.config import IndexConfig if __name__ == "__main__": # Set up argument parser @@ -11,7 +11,7 @@ parser.add_argument('--pdf_path', type=str, help='Path to the PDF file') parser.add_argument('--md_path', type=str, help='Path to the Markdown file') - parser.add_argument('--model', type=str, default=None, help='Model to use (overrides config.yaml)') + parser.add_argument('--model', type=str, default=None, help='Model to use') parser.add_argument('--toc-check-pages', type=int, default=None, help='Number of pages to check for table of contents (PDF only)') @@ -20,15 +20,15 @@ parser.add_argument('--max-tokens-per-node', type=int, default=None, help='Maximum number of tokens per node (PDF only)') - parser.add_argument('--if-add-node-id', type=str, default=None, - help='Whether to add node id to the node') - parser.add_argument('--if-add-node-summary', type=str, default=None, - help='Whether to add summary to the node') - parser.add_argument('--if-add-doc-description', type=str, default=None, - help='Whether to add doc description to the doc') - parser.add_argument('--if-add-node-text', type=str, default=None, - help='Whether to add text to the node') - + parser.add_argument('--if-add-node-id', action='store_true', default=None, + help='Add node id to the node') + parser.add_argument('--if-add-node-summary', action='store_true', default=None, + help='Add summary to the node') + parser.add_argument('--if-add-doc-description', action='store_true', default=None, + 
help='Add doc description to the doc') + parser.add_argument('--if-add-node-text', action='store_true', default=None, + help='Add text to the node') + # Markdown specific arguments parser.add_argument('--if-thinning', type=str, default='no', help='Whether to apply tree thinning for markdown (markdown only)') @@ -37,77 +37,61 @@ parser.add_argument('--summary-token-threshold', type=int, default=200, help='Token threshold for generating summaries (markdown only)') args = parser.parse_args() - + # Validate that exactly one file type is specified if not args.pdf_path and not args.md_path: raise ValueError("Either --pdf_path or --md_path must be specified") if args.pdf_path and args.md_path: raise ValueError("Only one of --pdf_path or --md_path can be specified") - + + # Build IndexConfig from CLI args (None values use defaults) + config_overrides = { + k: v for k, v in { + "model": args.model, + "toc_check_page_num": args.toc_check_pages, + "max_page_num_each_node": args.max_pages_per_node, + "max_token_num_each_node": args.max_tokens_per_node, + "if_add_node_id": args.if_add_node_id, + "if_add_node_summary": args.if_add_node_summary, + "if_add_doc_description": args.if_add_doc_description, + "if_add_node_text": args.if_add_node_text, + }.items() if v is not None + } + opt = IndexConfig(**config_overrides) + if args.pdf_path: # Validate PDF file if not args.pdf_path.lower().endswith('.pdf'): raise ValueError("PDF file must have .pdf extension") if not os.path.isfile(args.pdf_path): raise ValueError(f"PDF file not found: {args.pdf_path}") - - # Process PDF file - user_opt = { - 'model': args.model, - 'toc_check_page_num': args.toc_check_pages, - 'max_page_num_each_node': args.max_pages_per_node, - 'max_token_num_each_node': args.max_tokens_per_node, - 'if_add_node_id': args.if_add_node_id, - 'if_add_node_summary': args.if_add_node_summary, - 'if_add_doc_description': args.if_add_doc_description, - 'if_add_node_text': args.if_add_node_text, - } - opt = 
ConfigLoader().load({k: v for k, v in user_opt.items() if v is not None}) # Process the PDF toc_with_page_number = page_index_main(args.pdf_path, opt) print('Parsing done, saving to file...') - + # Save results - pdf_name = os.path.splitext(os.path.basename(args.pdf_path))[0] + pdf_name = os.path.splitext(os.path.basename(args.pdf_path))[0] output_dir = './results' output_file = f'{output_dir}/{pdf_name}_structure.json' os.makedirs(output_dir, exist_ok=True) - + with open(output_file, 'w', encoding='utf-8') as f: json.dump(toc_with_page_number, f, indent=2) - + print(f'Tree structure saved to: {output_file}') - + elif args.md_path: # Validate Markdown file if not args.md_path.lower().endswith(('.md', '.markdown')): raise ValueError("Markdown file must have .md or .markdown extension") if not os.path.isfile(args.md_path): raise ValueError(f"Markdown file not found: {args.md_path}") - + # Process markdown file print('Processing markdown file...') - - # Process the markdown import asyncio - - # Use ConfigLoader to get consistent defaults (matching PDF behavior) - from pageindex.utils import ConfigLoader - config_loader = ConfigLoader() - - # Create options dict with user args - user_opt = { - 'model': args.model, - 'if_add_node_summary': args.if_add_node_summary, - 'if_add_doc_description': args.if_add_doc_description, - 'if_add_node_text': args.if_add_node_text, - 'if_add_node_id': args.if_add_node_id - } - - # Load config with defaults from config.yaml - opt = config_loader.load(user_opt) - + toc_with_page_number = asyncio.run(md_to_tree( md_path=args.md_path, if_thinning=args.if_thinning.lower() == 'yes', @@ -119,16 +103,16 @@ if_add_node_text=opt.if_add_node_text, if_add_node_id=opt.if_add_node_id )) - + print('Parsing done, saving to file...') - + # Save results - md_name = os.path.splitext(os.path.basename(args.md_path))[0] + md_name = os.path.splitext(os.path.basename(args.md_path))[0] output_dir = './results' output_file = 
from pageindex.agent import AgentRunner, SYSTEM_PROMPT
from pageindex.backend.protocol import AgentTools


def test_agent_runner_init():
    """AgentRunner keeps the model name it was constructed with."""
    mock_tools = AgentTools(function_tools=["mock_tool"])
    agent = AgentRunner(tools=mock_tools, model="gpt-4o")
    assert agent._model == "gpt-4o"


def test_system_prompt_has_tool_instructions():
    """The system prompt must mention every tool the agent can call."""
    for tool_name in ("list_documents", "get_document_structure", "get_page_content"):
        assert tool_name in SYSTEM_PROMPT
from pageindex.backend.cloud import CloudBackend, API_BASE


def test_cloud_backend_init():
    """The backend stores the API key and sends it via the api_key header."""
    cloud = CloudBackend(api_key="pi-test")
    assert cloud._api_key == "pi-test"
    assert cloud._headers["api_key"] == "pi-test"


def test_api_base_url():
    """The default endpoint points at the pageindex.ai service."""
    assert "pageindex.ai" in API_BASE


def test_get_retrieve_model_is_none():
    """Cloud collections expose no local function tools."""
    cloud = CloudBackend(api_key="pi-test")
    tools = cloud.get_agent_tools("col")
    assert tools.function_tools == []
# tests/test_config.py
import pytest
from pageindex.config import IndexConfig


def test_defaults():
    """A bare IndexConfig exposes the documented default values."""
    cfg = IndexConfig()
    assert cfg.model == "gpt-4o-2024-11-20"
    assert cfg.retrieve_model is None
    assert cfg.toc_check_page_num == 20


def test_overrides():
    """Constructor keyword arguments override the defaults."""
    cfg = IndexConfig(model="gpt-5.4", retrieve_model="claude-sonnet")
    assert (cfg.model, cfg.retrieve_model) == ("gpt-5.4", "claude-sonnet")


def test_unknown_key_raises():
    """Unrecognized keys are rejected rather than silently ignored."""
    with pytest.raises(Exception):
        IndexConfig(nonexistent_key="value")


def test_model_copy_with_update():
    """model_copy(update=...) changes only the named field."""
    base = IndexConfig(toc_check_page_num=30)
    patched = base.model_copy(update={"model": "gpt-5.4"})
    assert patched.model == "gpt-5.4"
    assert patched.toc_check_page_num == 30
"hello" + assert node.tokens == 5 + assert node.title is None + assert node.index is None + assert node.level is None + + +def test_content_node_all_fields(): + node = ContentNode(content="# Intro", tokens=10, title="Intro", index=1, level=1) + assert node.title == "Intro" + assert node.index == 1 + assert node.level == 1 + + +def test_parsed_document(): + nodes = [ContentNode(content="page1", tokens=100, index=1)] + doc = ParsedDocument(doc_name="test.pdf", nodes=nodes) + assert doc.doc_name == "test.pdf" + assert len(doc.nodes) == 1 + assert doc.metadata is None + + +def test_parsed_document_with_metadata(): + nodes = [ContentNode(content="page1", tokens=100)] + doc = ParsedDocument(doc_name="test.pdf", nodes=nodes, metadata={"author": "John"}) + assert doc.metadata["author"] == "John" + + +def test_document_parser_protocol(): + """Verify a class implementing DocumentParser is structurally compatible.""" + class MyParser: + def supported_extensions(self) -> list[str]: + return [".txt"] + def parse(self, file_path: str, **kwargs) -> ParsedDocument: + return ParsedDocument(doc_name="test", nodes=[]) + + parser = MyParser() + assert parser.supported_extensions() == [".txt"] + result = parser.parse("test.txt") + assert isinstance(result, ParsedDocument) diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 000000000..af55e7c57 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,27 @@ +from pageindex.errors import ( + PageIndexError, + CollectionNotFoundError, + DocumentNotFoundError, + IndexingError, + CloudAPIError, + FileTypeError, +) + + +def test_all_errors_inherit_from_base(): + for cls in [CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]: + assert issubclass(cls, PageIndexError) + assert issubclass(cls, Exception) + + +def test_error_message(): + err = FileTypeError("Unsupported: .docx") + assert str(err) == "Unsupported: .docx" + + +def test_catch_base_catches_all(): + for cls in 
from pageindex.events import QueryEvent
from pageindex.backend.protocol import AgentTools


def test_query_event():
    """A QueryEvent carries its type tag and payload unchanged."""
    ev = QueryEvent(type="answer_delta", data="hello")
    assert (ev.type, ev.data) == ("answer_delta", "hello")


def test_query_event_types():
    """Every documented event type round-trips through the constructor."""
    valid_types = ("reasoning", "tool_call", "tool_result", "answer_delta", "answer_done")
    for event_type in valid_types:
        assert QueryEvent(type=event_type, data="test").type == event_type


def test_agent_tools_default_empty():
    """A bare AgentTools has no tools and no MCP servers."""
    empty = AgentTools()
    assert empty.function_tools == []
    assert empty.mcp_servers == []


def test_agent_tools_with_values():
    """Supplied tool/server lists are stored as given."""
    populated = AgentTools(function_tools=["tool1"], mcp_servers=["server1"])
    assert len(populated.function_tools) == 1
    assert len(populated.mcp_servers) == 1
import pytest
from pathlib import Path
from pageindex.parser.markdown import MarkdownParser
from pageindex.parser.protocol import ContentNode, ParsedDocument


@pytest.fixture
def sample_md(tmp_path):
    """Write a small two-chapter markdown file and return its path."""
    md_file = tmp_path / "test.md"
    md_file.write_text("""# Chapter 1
Some intro text.

## Section 1.1
Details here.

## Section 1.2
More details.

# Chapter 2
Another chapter.
""")
    return str(md_file)


def test_supported_extensions():
    extensions = MarkdownParser().supported_extensions()
    for ext in (".md", ".markdown"):
        assert ext in extensions


def test_parse_returns_parsed_document(sample_md):
    doc = MarkdownParser().parse(sample_md)
    assert isinstance(doc, ParsedDocument)
    assert doc.doc_name == "test"


def test_parse_nodes_have_level(sample_md):
    nodes = MarkdownParser().parse(sample_md).nodes
    assert len(nodes) == 4
    assert (nodes[0].level, nodes[0].title) == (1, "Chapter 1")
    assert (nodes[1].level, nodes[1].title) == (2, "Section 1.1")
    assert nodes[3].level == 1


def test_parse_nodes_have_content(sample_md):
    nodes = MarkdownParser().parse(sample_md).nodes
    assert "Some intro text" in nodes[0].content
    assert "Details here" in nodes[1].content


def test_parse_nodes_have_index(sample_md):
    doc = MarkdownParser().parse(sample_md)
    assert all(node.index is not None for node in doc.nodes)
# tests/sdk/test_pipeline.py
import asyncio
from unittest.mock import patch, AsyncMock

from pageindex.parser.protocol import ContentNode, ParsedDocument
from pageindex.index.pipeline import (
    detect_strategy, build_tree_from_levels, build_index,
    _content_based_pipeline, _NullLogger,
)


def _heading(content, title, index, level, tokens=10):
    """Shorthand for a leveled ContentNode."""
    return ContentNode(content=content, tokens=tokens, title=title,
                       index=index, level=level)


def test_detect_strategy_with_level():
    headings = [
        _heading("# Intro", "Intro", 1, 1),
        _heading("## Details", "Details", 5, 2),
    ]
    assert detect_strategy(headings) == "level_based"


def test_detect_strategy_without_level():
    pages = [
        ContentNode(content="Page 1 text", tokens=100, index=1),
        ContentNode(content="Page 2 text", tokens=100, index=2),
    ]
    assert detect_strategy(pages) == "content_based"


def test_build_tree_from_levels():
    tree = build_tree_from_levels([
        _heading("ch1 text", "Chapter 1", 1, 1),
        _heading("s1.1 text", "Section 1.1", 5, 2),
        _heading("s1.2 text", "Section 1.2", 10, 2),
        _heading("ch2 text", "Chapter 2", 20, 1),
    ])
    # Two roots (the chapters); both sections nest under chapter 1.
    assert len(tree) == 2
    assert tree[0]["title"] == "Chapter 1"
    children = tree[0]["nodes"]
    assert len(children) == 2
    assert children[0]["title"] == "Section 1.1"
    assert children[1]["title"] == "Section 1.2"
    assert tree[1]["title"] == "Chapter 2"
    assert len(tree[1]["nodes"]) == 0


def test_build_tree_from_levels_single_level():
    tree = build_tree_from_levels([
        _heading("a", "A", 1, 1, tokens=5),
        _heading("b", "B", 2, 1, tokens=5),
    ])
    assert len(tree) == 2
    assert tree[0]["title"] == "A"
    assert tree[1]["title"] == "B"


def test_build_tree_from_levels_deep_nesting():
    tree = build_tree_from_levels([
        _heading("h1", "H1", 1, 1, tokens=5),
        _heading("h2", "H2", 2, 2, tokens=5),
        _heading("h3", "H3", 3, 3, tokens=5),
    ])
    # Walk one level at a time: H1 -> H2 -> H3.
    current = tree
    for expected_title in ("H1", "H2", "H3"):
        assert len(current) == 1
        assert current[0]["title"] == expected_title
        current = current[0]["nodes"]


def test_content_based_pipeline_does_not_raise():
    """_content_based_pipeline should delegate to tree_parser, not raise NotImplementedError."""
    from types import SimpleNamespace

    expected_tree = [{"title": "Intro", "start_index": 1, "end_index": 2, "nodes": []}]

    async def stub_tree_parser(page_list, opt, doc=None, logger=None):
        return expected_tree

    pages = [("Page 1 text", 50), ("Page 2 text", 60)]
    options = SimpleNamespace(model="test-model")

    with patch("pageindex.index.page_index.tree_parser", new=stub_tree_parser):
        produced = asyncio.run(_content_based_pipeline(pages, options))

    assert produced == expected_tree


def test_null_logger_methods():
    """_NullLogger should have info/error/debug and not raise."""
    sink = _NullLogger()
    sink.info("test message")
    sink.error("test error")
    sink.debug("test debug")
    sink.info({"key": "value"})
import pytest
from pageindex.storage.sqlite import SQLiteStorage


@pytest.fixture
def storage(tmp_path):
    return SQLiteStorage(str(tmp_path / "test.db"))


def _pdf_doc(name, path):
    """Minimal pdf-document payload accepted by save_document."""
    return {"doc_name": name, "doc_type": "pdf", "file_path": path, "structure": []}


def test_create_and_list_collections(storage):
    storage.create_collection("papers")
    assert "papers" in storage.list_collections()


def test_get_or_create_collection_idempotent(storage):
    for _ in range(2):
        storage.get_or_create_collection("papers")
    assert storage.list_collections().count("papers") == 1


def test_delete_collection(storage):
    storage.create_collection("papers")
    storage.delete_collection("papers")
    assert "papers" not in storage.list_collections()


def test_save_and_get_document(storage):
    storage.create_collection("papers")
    payload = {
        "doc_name": "test.pdf", "doc_description": "A test",
        "file_path": "/tmp/test.pdf", "doc_type": "pdf",
        "structure": [{"title": "Intro", "node_id": "0001"}],
    }
    storage.save_document("papers", "doc-1", payload)
    fetched = storage.get_document("papers", "doc-1")
    assert fetched["doc_name"] == "test.pdf"
    assert fetched["doc_type"] == "pdf"


def test_get_document_structure(storage):
    storage.create_collection("papers")
    tree = [{"title": "Ch1", "node_id": "0001", "nodes": []}]
    storage.save_document("papers", "doc-1", {
        "doc_name": "test.pdf", "doc_type": "pdf",
        "file_path": "/tmp/test.pdf", "structure": tree,
    })
    assert storage.get_document_structure("papers", "doc-1")[0]["title"] == "Ch1"


def test_list_documents(storage):
    storage.create_collection("papers")
    storage.save_document("papers", "doc-1", _pdf_doc("p1.pdf", "/tmp/p1.pdf"))
    storage.save_document("papers", "doc-2", _pdf_doc("p2.pdf", "/tmp/p2.pdf"))
    assert len(storage.list_documents("papers")) == 2


def test_delete_document(storage):
    storage.create_collection("papers")
    storage.save_document("papers", "doc-1", _pdf_doc("test.pdf", "/tmp/test.pdf"))
    storage.delete_document("papers", "doc-1")
    assert len(storage.list_documents("papers")) == 0


def test_delete_collection_cascades_documents(storage):
    storage.create_collection("papers")
    storage.save_document("papers", "doc-1", _pdf_doc("test.pdf", "/tmp/test.pdf"))
    storage.delete_collection("papers")
    assert "papers" not in storage.list_collections()