diff --git a/.gitignore b/.gitignore index f93bf8d..d5992cc 100644 --- a/.gitignore +++ b/.gitignore @@ -208,3 +208,4 @@ __marimo__/ DEVELOPMENT_WORKFLOW.md docs/ +toolcalling.txt diff --git a/README.md b/README.md index 34feed4..9eee0ba 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ - **Seamless model switching** — capabilities update instantly when you switch models mid-conversation - **Chain-of-thought reasoning** for models that support it (e.g. `qwen3`, `deepseek-r1`, `deepseek-v3.1`, `gpt-oss`) - **Tool calling** and a full agent loop for multi-step model actions +- **Custom coding tools** (`read`, `grep`, `glob`, `ls`, `write`, `edit`, `multiedit`, `apply_patch`, `bash`, `batch`, planning/todo/task tools, and more) - **Web search** via Ollama's built-in tools (requires an Ollama API key) - **Vision / image attachments** for vision-capable models (e.g. `gemma3`, `llava`) - **Context window alignment** — `max_context_tokens` is forwarded to Ollama as `options.num_ctx` so the server-side context window always matches the client-side trim budget @@ -278,6 +279,20 @@ enabled = false directory = "~/.local/state/ollamaterm/conversations" metadata_path = "~/.local/state/ollamaterm/conversations/index.json" +[tools] +# Enable schema-first custom coding tools +enabled = true +# Base root for file/search/edit tools +workspace_root = "." +# Allow temporary external roots via external-directory tool +allow_external_directories = false +command_timeout_seconds = 30 +max_output_lines = 200 +max_output_bytes = 50000 +max_read_bytes = 200000 +max_search_results = 200 +default_external_directories = [] + [capabilities] # Show the model's reasoning trace inside the assistant bubble. # Thinking support itself is auto-detected — this controls only the UI display. @@ -346,6 +361,99 @@ The agent loop allows the model to invoke tools multiple times before producing a final answer. Control the upper bound with `max_tool_iterations` in `[capabilities]`. +In addition to Ollama web tools, OllamaTerm now ships a schema-first local +coding toolset designed for agentic workflows: + +- File and search tools: `read`, `ls`, `glob`, `grep`, `codesearch` +- Editing tools: `write`, `edit`, `multiedit`, `apply_patch` +- Runtime tools: `bash`, `batch`, `external-directory` +- Planning/state tools: `plan-enter`, `plan-exit`, `plan`, `todo`, `todoread`, `todowrite`, `task`, `question` +- Introspection tools: `registry`, `tool`, `truncation`, `invalid` + +These tools are controlled by the `[tools]` config section and are constrained +by workspace-root path checks, command timeouts, and output truncation limits. + +#### Function tools with Ollama (alpha/experimental) + +OllamaTerm passes tools to the Ollama Python SDK in two forms: + +- JSON function tools generated from the schema-first tool specs (the majority of tools below) +- Python callables for built-in Ollama integrations when enabled (e.g. `web_search`, `web_fetch`) + +The model emits `tool_calls`, the app executes them, appends a `tool` role message with the result, and continues the loop until the assistant returns a final answer. + +> Warning: This tool suite is experimental. Most tools are untested and may be buggy or missing edge-case handling. Use with caution and review changes carefully, especially file edits. Outputs may be truncated according to configured limits. + +##### Available tools (names and key parameters) + +- Files & search + - `list` (built-in) — List files and directories. + - `path?: string` (default: workspace root) + - `ls` (custom) — Alternate directory listing with tree-style output. + - `path?: string`, `ignore?: string[]` + - `read` — Read a file window. + - `path: string`, `offset?: int`, `limit?: int` + - `glob` — Find files by glob. + - `pattern: string`, `path?: string`, `max_results?: int` + - `grep` / `codesearch` — Search file contents. + - `query: string`, `path?: string`, `case_sensitive?: bool`, `fixed_strings?: bool`, `max_results?: int` + +- Editing + - `write` — Atomic full-file write. + - `path: string`, `content: string`, `overwrite?: bool`, `create_dirs?: bool` + - `edit` — Single snippet replace. + - `path: string`, `old_text: string`, `new_text: string`, `replace_all?: bool` + - `multiedit` — Multiple snippet edits atomically. + - `path: string`, `edits: { old_text, new_text, replace_all? }[]` + - `apply_patch` — Apply structured patch hunks. + - `path: string`, `hunks: { old_text, new_text, replace_all? }[]` + +- Runtime + - `bash` — Run a shell command (capped by time/output limits). + - `command: string`, `cwd?: string` + - `batch` — Run a sequence of tool calls. + - `calls: { name: string, arguments: object }[]`, `continue_on_error?: bool` + - `external-directory` — Manage temporary external directory allowlist for this session. + - `action: string`, `path?: string` + +- Planning & state + - `plan-enter` | `plan-exit` | `plan` + - `plan-enter: { goal?: string }` + - `plan: { action?: string, content?: string }` + - `todo` | `todoread` | `todowrite` | `task` + - `todo: { item: string }` + - `todowrite: { items: string[], mode?: "append"|"replace" }` + - `task: { action?: string, name?: string, status?: string }` + - `question` — Emit a structured clarification question. + - `prompt: string`, `context?: string` + +- Introspection & utility + - `registry` — List available tools. + - `tool` — Inspect a tool definition. + - `truncation` — Show output truncation limits. + - `invalid` — Always fails (for error-path testing). + +- Web (requires tool-capable model; `web_search_enabled = true` and an API key) + - `websearch` — Perform a web search via Ollama integration. + - `query: string`, `max_results?: int` + - `webfetch` — Fetch a URL via Ollama integration. + - `url: string` + +Notes: + +- Directory listing may appear as `list` (built-in) or `ls` (custom) depending on which tool set is active. Both list files; prefer `list` when available. +- File and command tools will prompt for permission. Paths are restricted to the configured workspace by default. +- Large outputs are truncated. Use `offset`/`limit` (for `read`) and `max_results` (for `grep`/`glob`) to scope results. + +##### Quick examples + +```text +List files here → Call tool: list { "path": "." } +Search for a string → Call tool: grep { "query": "TODO", "path": "." } +Read a file window → Call tool: read { "path": "src/main.py", "offset": 1, "limit": 120 } +Make an edit → Call tool: edit { "path": "README.md", "old_text": "foo", "new_text": "bar", "replace_all": true } +``` + ### Web search Set `web_search_enabled = true` in `[capabilities]` and provide an Ollama API diff --git a/config.example.toml b/config.example.toml index be1740d..0302798 100644 --- a/config.example.toml +++ b/config.example.toml @@ -51,6 +51,22 @@ enabled = false directory = "~/.local/state/ollamaterm/conversations" metadata_path = "~/.local/state/ollamaterm/conversations/index.json" +[tools] +# Enable schema-first custom coding tools (read/write/search/edit/bash/plan/todo/etc.) +enabled = true +# Base root for relative paths in file/search/edit tools. +workspace_root = "." +# Allow adding temporary external roots via external-directory tool. +allow_external_directories = false +# Safety/runtime limits. +command_timeout_seconds = 30 +max_output_lines = 200 +max_output_bytes = 50000 +max_read_bytes = 200000 +max_search_results = 200 +# Optional always-allowed external roots. +default_external_directories = [] + [capabilities] # Whether to render the model's reasoning trace in the assistant bubble. # The trace is shown only when the active model supports thinking (auto-detected). diff --git a/src/ollama_chat/__main__.py b/src/ollama_chat/__main__.py index 6fbde7f..d2f3aa5 100644 --- a/src/ollama_chat/__main__.py +++ b/src/ollama_chat/__main__.py @@ -3,8 +3,8 @@ from __future__ import annotations import argparse +from collections.abc import Sequence from importlib import metadata -from typing import Sequence from .app import OllamaChatApp from .config import ensure_config_dir diff --git a/src/ollama_chat/app.py b/src/ollama_chat/app.py index 88d369b..a1dadc4 100644 --- a/src/ollama_chat/app.py +++ b/src/ollama_chat/app.py @@ -3,16 +3,16 @@ from __future__ import annotations import asyncio +from collections.abc import Awaitable, Callable from datetime import datetime import inspect import logging import os -import sys +from pathlib import Path import random import shutil import subprocess -from pathlib import Path -from collections.abc import Awaitable, Callable +import sys from typing import Any from urllib.parse import urlparse @@ -24,7 +24,7 @@ from textual.widgets import Button, Footer, Header, Input, OptionList, Static from .capabilities import AttachmentState, CapabilityContext, SearchState -from .chat import CapabilityReport, OllamaChat, ChatSendOptions +from .chat import CapabilityReport, ChatSendOptions, OllamaChat from .commands import parse_inline_directives from .config import load_config from .exceptions import ( @@ -46,11 +46,11 @@ from .state import ConnectionState, ConversationState, StateManager from .stream_handler import StreamHandler from .task_manager import TaskManager -from .tools import ToolRegistry, build_registry, ToolRegistryOptions +from .tooling import ToolRegistry, ToolRegistryOptions, ToolRuntimeOptions, build_registry +from .widgets.activity_bar import ActivityBar from .widgets.conversation import ConversationView from .widgets.input_box import InputBox from .widgets.message import MessageBubble -from .widgets.activity_bar import ActivityBar from .widgets.status_bar import StatusBar LOGGER = logging.getLogger(__name__) @@ -116,7 +116,7 @@ async def _open_native_file_dialog( cleaned = token.strip("',()><[]") if cleaned.startswith("file://"): return urllib.parse.unquote(cleaned[len("file://") :]) - except (asyncio.TimeoutError, OSError): + except (TimeoutError, OSError): pass # --- zenity --- @@ -137,7 +137,7 @@ async def _open_native_file_dialog( path = stdout.decode().strip() if path: return path - except (asyncio.TimeoutError, OSError): + except (TimeoutError, OSError): pass # --- kdialog --- @@ -158,7 +158,7 @@ async def _open_native_file_dialog( path = stdout.decode().strip() if path: return path - except (asyncio.TimeoutError, OSError): + except (TimeoutError, OSError): pass return None @@ -520,13 +520,34 @@ def __init__(self) -> None: # used is gated at call time by _effective_caps.tools_enabled. This # ensures the registry is ready when the first tool-capable model loads. try: + tools_cfg = self.config.get("tools", {}) options = ( ToolRegistryOptions( web_search_api_key=( self.capabilities.web_search_api_key if self.capabilities.web_search_enabled else None - ) + ), + enable_custom_tools=bool(tools_cfg.get("enabled", True)), + runtime_options=ToolRuntimeOptions( + enabled=bool(tools_cfg.get("enabled", True)), + workspace_root=str(tools_cfg.get("workspace_root", ".")), + allow_external_directories=bool( + tools_cfg.get("allow_external_directories", False) + ), + command_timeout_seconds=int( + tools_cfg.get("command_timeout_seconds", 30) + ), + max_output_lines=int(tools_cfg.get("max_output_lines", 200)), + max_output_bytes=int(tools_cfg.get("max_output_bytes", 50_000)), + max_read_bytes=int(tools_cfg.get("max_read_bytes", 200_000)), + max_search_results=int(tools_cfg.get("max_search_results", 200)), + default_external_directories=tuple( + str(item) + for item in tools_cfg.get("default_external_directories", []) + if str(item).strip() + ), + ), ) ) self._tool_registry: ToolRegistry | None = build_registry(options) diff --git a/src/ollama_chat/chat.py b/src/ollama_chat/chat.py index 41a099a..6195fc3 100644 --- a/src/ollama_chat/chat.py +++ b/src/ollama_chat/chat.py @@ -5,8 +5,10 @@ import asyncio from collections.abc import AsyncGenerator from dataclasses import dataclass, field -import logging import inspect +import json +import logging +import re from typing import Any, Literal from .exceptions import ( @@ -201,7 +203,7 @@ async def list_models(self) -> list[str]: names: list[str] = [] models: Any = None if hasattr(response, "models"): - models = getattr(response, "models") + models = response.models elif isinstance(response, dict): models = response.get("models") elif hasattr(response, "model_dump"): @@ -308,7 +310,7 @@ async def show_model_capabilities( # SDK object path. if hasattr(response, "capabilities"): - caps_raw = getattr(response, "capabilities") + caps_raw = response.capabilities # Treat explicit None as unknown; anything else counts as present. if caps_raw is not None: caps_known = True @@ -419,6 +421,59 @@ def _extract_chunk_tool_calls(cls, chunk: Any) -> list[Any]: value = cls._extract_from_chunk(chunk, "tool_calls") return value if isinstance(value, list) else [] + @staticmethod + def _parse_inline_tool_call_from_content( + content: str, allowed_names: set[str] + ) -> list[dict[str, Any]]: + """Parse a tool call embedded as JSON in content code blocks. + + Some models emit a JSON object like {"name": "ls", "arguments": {}} + instead of structured tool_calls. Convert it to a minimal tool_call + dict only when the name is in allowed_names. + """ + text = (content or "").strip() + if not text: + return [] + + # Prefer ```json code blocks if present. + match = re.search(r"```json\s*(\{[\s\S]*?\})\s*```", text, re.IGNORECASE) + candidate = match.group(1) if match else text + + try: + parsed = json.loads(candidate) + except Exception: + return [] + + def _as_call(obj: dict[str, Any]) -> list[dict[str, Any]]: + if not isinstance(obj, dict): + return [] + if isinstance(obj.get("function"), dict): + fn = obj["function"] + name = str(fn.get("name", "")) + if name and name in allowed_names: + args = fn.get("arguments", {}) + if not isinstance(args, dict): + args = {} + return [{"function": {"name": name, "arguments": args}}] + return [] + name = str(obj.get("name", "")) + if name and name in allowed_names: + args = obj.get("arguments", {}) + if not isinstance(args, dict): + args = {} + return [{"function": {"name": name, "arguments": args}}] + return [] + + if isinstance(parsed, list): + for item in parsed: + calls = _as_call(item) + if calls: + return calls + return [] + if isinstance(parsed, dict): + return _as_call(parsed) + return [] + def _map_exception(self, exc: Exception) -> OllamaChatError: if isinstance(exc, OllamaChatError): return exc @@ -504,6 +559,20 @@ async def _stream_once_with_capabilities( yield ChatChunk(kind="content", text=content_text) chunk_tool_calls = self._extract_chunk_tool_calls(chunk) + + # If the model printed a JSON tool call in content (no structured field), + # parse it and treat it as a tool_call so the agent loop can proceed. + if not chunk_tool_calls and tools and content_text: + allowed: set[str] = set() + for t in tools: + if isinstance(t, dict): + fn = t.get("function", {}) + if isinstance(fn, dict): + n = fn.get("name") + if isinstance(n, str) and n: + allowed.add(n) + for tc in self._parse_inline_tool_call_from_content(content_text, allowed): + chunk_tool_calls.append(tc) for tc in chunk_tool_calls: name, args, index = self._parse_tool_call(tc) if name: diff --git a/src/ollama_chat/commands.py b/src/ollama_chat/commands.py index 2a1a642..4c9aa49 100644 --- a/src/ollama_chat/commands.py +++ b/src/ollama_chat/commands.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import os import re + from .capabilities import CapabilityContext _IMAGE_PREFIX_RE = re.compile(r"(?:^|\s)/image\s+(\S+)") diff --git a/src/ollama_chat/config.py b/src/ollama_chat/config.py index ff1c02f..caf24e9 100644 --- a/src/ollama_chat/config.py +++ b/src/ollama_chat/config.py @@ -7,6 +7,7 @@ import os from pathlib import Path import re +import tomllib # stdlib since Python 3.11 (project requires >=3.11) from typing import Any from urllib.parse import urlparse @@ -21,8 +22,6 @@ from .exceptions import ConfigValidationError -import tomllib # stdlib since Python 3.11 (project requires >=3.11) - LOGGER = logging.getLogger(__name__) CONFIG_DIR = Path.home() / ".config" / "ollamaterm" @@ -276,6 +275,46 @@ def _validate_path_string(cls, value: Any) -> str: return normalized +class ToolsConfig(BaseModel): + """Runtime policy for schema-based coding tools.""" + + enabled: bool = True + workspace_root: str = "." + allow_external_directories: bool = False + command_timeout_seconds: int = Field(default=30, ge=1, le=600) + max_output_lines: int = Field(default=200, ge=1, le=10_000) + max_output_bytes: int = Field(default=50_000, ge=256, le=5_000_000) + max_read_bytes: int = Field(default=200_000, ge=256, le=20_000_000) + max_search_results: int = Field(default=200, ge=1, le=10_000) + default_external_directories: list[str] = Field(default_factory=list) + + @field_validator("workspace_root", mode="before") + @classmethod + def _validate_workspace_root(cls, value: Any) -> str: + if not isinstance(value, str): + raise ValueError("workspace_root must be a string.") + normalized = value.strip() + if not normalized: + raise ValueError("workspace_root must not be empty.") + return normalized + + @field_validator("default_external_directories", mode="before") + @classmethod + def _validate_external_directories(cls, value: Any) -> list[str]: + if value is None: + return [] + if not isinstance(value, list): + raise ValueError("default_external_directories must be a list.") + normalized: list[str] = [] + for item in value: + if not isinstance(item, str): + raise ValueError("default_external_directories entries must be strings.") + candidate = item.strip() + if candidate: + normalized.append(candidate) + return normalized + + class CapabilitiesConfig(BaseModel): """User-facing feature preferences for thinking, tools, web search, and vision. @@ -313,6 +352,7 @@ class Config(BaseModel): security: SecurityConfig = SecurityConfig() logging: LoggingConfig = LoggingConfig() persistence: PersistenceConfig = PersistenceConfig() + tools: ToolsConfig = ToolsConfig() capabilities: CapabilitiesConfig = CapabilitiesConfig() @model_validator(mode="after") diff --git a/src/ollama_chat/custom_tools.py b/src/ollama_chat/custom_tools.py new file mode 100644 index 0000000..1bd5651 --- /dev/null +++ b/src/ollama_chat/custom_tools.py @@ -0,0 +1,1217 @@ +"""Schema-first custom coding tools for Ollama function calling.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +import fnmatch +import json +import os +from pathlib import Path +import re +import subprocess +import tempfile +from typing import Any + +from .exceptions import OllamaToolError + +_SEARCH_SKIP_DIR_NAMES = { + ".git", + ".venv", + "__pycache__", + ".mypy_cache", + ".ruff_cache", + "node_modules", +} + + +@dataclass(frozen=True) +class ToolRuntimeOptions: + """Runtime limits and safety controls for local tools.""" + + enabled: bool = True + workspace_root: str = "." + allow_external_directories: bool = False + command_timeout_seconds: int = 30 + max_output_lines: int = 200 + max_output_bytes: int = 50_000 + max_read_bytes: int = 200_000 + max_search_results: int = 200 + default_external_directories: tuple[str, ...] = () + + +@dataclass(frozen=True) +class ToolSpec: + """JSON-schema function tool definition + handler.""" + + name: str + description: str + parameters_schema: dict[str, Any] + handler: Callable[[dict[str, Any]], str] + safety_level: str = "safe" + category: str = "meta" + + def as_ollama_tool(self) -> dict[str, Any]: + """Render the tool in Ollama's function schema format.""" + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": self.parameters_schema, + }, + } + + +@dataclass +class _ToolState: + """Session-local mutable state for plan/todo/task tools.""" + + allowed_external_roots: set[Path] = field(default_factory=set) + plan_mode: bool = False + plan_content: str = "" + todos: list[str] = field(default_factory=list) + tasks: dict[str, str] = field(default_factory=dict) + + +class CustomToolSuite: + """Factory for schema-first custom coding tools.""" + + def __init__( + self, + runtime_options: ToolRuntimeOptions, + web_search_fn: Callable[[str, int], str] | None = None, + web_fetch_fn: Callable[[str], str] | None = None, + ) -> None: + self._runtime_options = runtime_options + root = (runtime_options.workspace_root or ".").strip() or "." + self._workspace_root = Path(os.path.expanduser(root)).resolve() + self._state = _ToolState() + self._default_external_roots: set[Path] = set() + for entry in runtime_options.default_external_directories: + try: + self._default_external_roots.add( + Path(os.path.expanduser(entry)).resolve() + ) + except Exception: + continue + + self._web_search_fn = web_search_fn + self._web_fetch_fn = web_fetch_fn + self._executor: Callable[[str, dict[str, Any]], str] | None = None + self._specs: dict[str, ToolSpec] = {} + self._build_specs() + + def bind_executor(self, executor: Callable[[str, dict[str, Any]], str]) -> None: + """Provide an executor callback used by the batch tool.""" + self._executor = executor + + def specs(self) -> list[ToolSpec]: + """Return all custom tool specs.""" + return list(self._specs.values()) + + def get_spec(self, name: str) -> ToolSpec | None: + """Return one spec by name.""" + return self._specs.get(name) + + @staticmethod + def _object_schema( + properties: dict[str, Any], + required: list[str] | None = None, + *, + additional_properties: bool = False, + ) -> dict[str, Any]: + return { + "type": "object", + "properties": properties, + "required": required or [], + "additionalProperties": additional_properties, + } + + def _register(self, spec: ToolSpec) -> None: + self._specs[spec.name] = spec + + def _build_specs(self) -> None: + self._register( + ToolSpec( + name="read", + description=( + "Read file contents from workspace with optional line window." + ), + parameters_schema=self._object_schema( + { + "path": { + "type": "string", + "description": "Absolute or workspace-relative path.", + }, + "offset": { + "type": "integer", + "description": "1-indexed starting line number.", + }, + "limit": { + "type": "integer", + "description": "Max number of lines to return.", + }, + }, + required=["path"], + ), + handler=self._handle_read, + category="fs", + ) + ) + + self._register( + ToolSpec( + name="ls", + description="List files and directories.", + parameters_schema=self._object_schema( + { + "path": { + "type": "string", + "description": "Directory path (default workspace root).", + }, + "max_entries": { + "type": "integer", + "description": "Maximum entries to return.", + }, + } + ), + handler=self._handle_ls, + category="fs", + ) + ) + + self._register( + ToolSpec( + name="glob", + description="Find files by glob pattern.", + parameters_schema=self._object_schema( + { + "pattern": { + "type": "string", + "description": "Glob pattern, e.g. **/*.py", + }, + "path": { + "type": "string", + "description": "Base path (default workspace root).", + }, + "max_results": { + "type": "integer", + "description": "Maximum number of matches.", + }, + }, + required=["pattern"], + ), + handler=self._handle_glob, + category="search", + ) + ) + + grep_schema = self._object_schema( + { + "query": { + "type": "string", + "description": "Regex or literal text query.", + }, + "path": { + "type": "string", + "description": "File or directory path.", + }, + "case_sensitive": { + "type": "boolean", + "description": "Case-sensitive search.", + }, + "fixed_strings": { + "type": "boolean", + "description": "Treat query as literal text.", + }, + "max_results": { + "type": "integer", + "description": "Maximum matching lines.", + }, + }, + required=["query"], + ) + self._register( + ToolSpec( + name="grep", + description="Search file content and return matching lines.", + parameters_schema=grep_schema, + handler=self._handle_grep, + category="search", + ) + ) + self._register( + ToolSpec( + name="codesearch", + description="Code search alias for grep.", + parameters_schema=grep_schema, + handler=self._handle_grep, + category="search", + ) + ) + + self._register( + ToolSpec( + name="write", + description="Write full file content atomically.", + parameters_schema=self._object_schema( + { + "path": { + "type": "string", + "description": "Target file path.", + }, + "content": { + "type": "string", + "description": "Complete file content.", + }, + "overwrite": { + "type": "boolean", + "description": "Allow overwrite of existing file.", + }, + "create_dirs": { + "type": "boolean", + "description": "Create parent directories if missing.", + }, + }, + required=["path", "content"], + ), + handler=self._handle_write, + safety_level="confirm", + category="edit", + ) + ) + + edit_schema = self._object_schema( + { + "path": { + "type": "string", + "description": "Target file path.", + }, + "old_text": { + "type": "string", + "description": "Text to replace.", + }, + "new_text": { + "type": "string", + "description": "Replacement text.", + }, + "replace_all": { + "type": "boolean", + "description": "Replace all occurrences.", + }, + }, + required=["path", "old_text", "new_text"], + ) + self._register( + ToolSpec( + name="edit", + description="Replace a snippet in a file.", + parameters_schema=edit_schema, + handler=self._handle_edit, + safety_level="confirm", + category="edit", + ) + ) + + array_edit_item = self._object_schema( + { + "old_text": { + "type": "string", + "description": "Text to replace.", + }, + "new_text": { + "type": "string", + "description": "Replacement text.", + }, + "replace_all": { + "type": "boolean", + "description": "Replace all occurrences.", + }, + }, + required=["old_text", "new_text"], + ) + self._register( + ToolSpec( + name="multiedit", + description="Apply multiple snippet edits atomically.", + parameters_schema=self._object_schema( + { + "path": { + "type": "string", + "description": "Target file path.", + }, + "edits": { + "type": "array", + "description": "Ordered edit operations.", + "items": array_edit_item, + }, + }, + required=["path", "edits"], + ), + handler=self._handle_multiedit, + safety_level="confirm", + category="edit", + ) + ) + + self._register( + ToolSpec( + name="apply_patch", + description=( + "Apply structured patch hunks (old_text/new_text) to a file." + ), + parameters_schema=self._object_schema( + { + "path": { + "type": "string", + "description": "Target file path.", + }, + "hunks": { + "type": "array", + "description": "Patch hunks.", + "items": array_edit_item, + }, + }, + required=["path", "hunks"], + ), + handler=self._handle_apply_patch, + safety_level="confirm", + category="edit", + ) + ) + + self._register( + ToolSpec( + name="bash", + description="Run a shell command with timeout and output caps.", + parameters_schema=self._object_schema( + { + "command": { + "type": "string", + "description": "Shell command.", + }, + "cwd": { + "type": "string", + "description": "Optional working directory.", + }, + }, + required=["command"], + ), + handler=self._handle_bash, + safety_level="confirm", + category="shell", + ) + ) + + self._register( + ToolSpec( + name="batch", + description="Run a sequence of tool calls in one invocation.", + parameters_schema=self._object_schema( + { + "calls": { + "type": "array", + "description": "Array of tool call objects.", + "items": self._object_schema( + { + "name": { + "type": "string", + "description": "Tool name.", + }, + "arguments": { + "type": "object", + "description": "Arguments object.", + "additionalProperties": True, + }, + }, + required=["name", "arguments"], + additional_properties=False, + ), + }, + "continue_on_error": { + "type": "boolean", + "description": "Continue after errors.", + }, + }, + required=["calls"], + ), + handler=self._handle_batch, + category="meta", + ) + ) + + self._register( + ToolSpec( + name="external-directory", + description="Manage external directory allowlist for this session.", + parameters_schema=self._object_schema( + { + "action": { + "type": "string", + "description": "add, remove, or list.", + }, + "path": { + "type": "string", + "description": "Directory path for add/remove.", + }, + } + ), + handler=self._handle_external_directory, + safety_level="confirm", + category="fs", + ) + ) + + self._register( + ToolSpec( + name="registry", + description="List available tools and metadata.", + parameters_schema=self._object_schema({}), + handler=self._handle_registry, + category="meta", + ) + ) + self._register( + ToolSpec( + name="tool", + description="Inspect a tool definition by name.", + parameters_schema=self._object_schema( + { + "name": { + "type": "string", + "description": "Tool name.", + } + }, + required=["name"], + ), + handler=self._handle_tool, + category="meta", + ) + ) + self._register( + ToolSpec( + name="invalid", + description="Intentionally fail for debugging tool error handling.", + parameters_schema=self._object_schema({}), + handler=self._handle_invalid, + category="meta", + ) + ) + self._register( + ToolSpec( + name="truncation", + description="Show current output truncation limits.", + parameters_schema=self._object_schema({}), + handler=self._handle_truncation, + category="meta", + ) + ) + + self._register( + ToolSpec( + name="plan-enter", + description="Enter planning mode and optionally set a plan goal.", + parameters_schema=self._object_schema( + { + "goal": { + "type": "string", + "description": "Optional initial plan text.", + } + } + ), + handler=self._handle_plan_enter, + category="planning", + ) + ) + self._register( + ToolSpec( + name="plan-exit", + description="Exit planning mode.", + parameters_schema=self._object_schema({}), + handler=self._handle_plan_exit, + category="planning", + ) + ) + self._register( + ToolSpec( + name="plan", + description="Get/set/append/clear the current plan content.", + parameters_schema=self._object_schema( + { + "action": { + "type": "string", + "description": "get, set, append, or clear.", + }, + "content": { + "type": "string", + "description": "Content for set/append.", + }, + } + ), + handler=self._handle_plan, + category="planning", + ) + ) + self._register( + ToolSpec( + name="question", + description="Emit a structured clarification question.", + parameters_schema=self._object_schema( + { + "prompt": { + "type": "string", + "description": "Question text.", + }, + "context": { + "type": "string", + "description": "Optional context.", + }, + }, + required=["prompt"], + ), + handler=self._handle_question, + category="planning", + ) + ) + + self._register( + ToolSpec( + name="todo", + description="Add one todo item.", + parameters_schema=self._object_schema( + { + "item": { + "type": "string", + "description": "Todo item text.", + } + }, + required=["item"], + ), + handler=self._handle_todo, + category="task", + ) + ) + self._register( + ToolSpec( + name="todoread", + description="Read the current todo list.", + parameters_schema=self._object_schema({}), + handler=self._handle_todoread, + category="task", + ) + ) + self._register( + ToolSpec( + name="todowrite", + description="Replace or append todo items.", + parameters_schema=self._object_schema( + { + "items": { + "type": "array", + "description": "Todo item strings.", + "items": { + "type": "string", + "description": "Todo item.", + }, + }, + "mode": { + "type": "string", + "description": "replace or append.", + }, + }, + required=["items"], + ), + handler=self._handle_todowrite, + category="task", + ) + ) + self._register( + ToolSpec( + name="task", + description="Set/get/list named task statuses.", + parameters_schema=self._object_schema( + { + "action": { + "type": "string", + "description": "set, get, or list.", + }, + "name": { + "type": "string", + "description": "Task name.", + }, + "status": { + "type": "string", + "description": "Task status for set action.", + }, + } + ), + handler=self._handle_task, + category="task", + ) + ) + + self._register( + ToolSpec( + name="lsp", + description="Language-server tool stub (not configured).", + parameters_schema=self._object_schema( + { + "action": { + "type": "string", + "description": "Requested LSP action.", + }, + "path": { + "type": "string", + "description": "Target path.", + }, + "symbol": { + "type": "string", + "description": "Optional symbol.", + }, + } + ), + handler=self._handle_lsp, + category="meta", + ) + ) + self._register( + ToolSpec( + name="skill", + description="Skill invocation tool stub.", + parameters_schema=self._object_schema( + { + "name": { + "type": "string", + "description": "Skill name.", + }, + "input": { + "type": "string", + "description": "Skill input.", + }, + } + ), + handler=self._handle_skill, + category="meta", + ) + ) + self._register( + ToolSpec( + name="websearch", + description="Search the web using Ollama web_search integration.", + parameters_schema=self._object_schema( + { + "query": { + "type": "string", + "description": "Search query.", + }, + "max_results": { + "type": "integer", + "description": "Maximum number of results.", + }, + }, + required=["query"], + ), + handler=self._handle_websearch, + category="web", + ) + ) + self._register( + ToolSpec( + name="webfetch", + description="Fetch a URL using Ollama web_fetch integration.", + parameters_schema=self._object_schema( + { + "url": { + "type": "string", + "description": "URL to fetch.", + } + }, + required=["url"], + ), + handler=self._handle_webfetch, + category="web", + ) + ) + + @staticmethod + def _to_json(value: Any) -> str: + return json.dumps(value, indent=2, ensure_ascii=False, sort_keys=True) + + def _resolve_any_path(self, path_text: str) -> Path: + expanded = os.path.expanduser(path_text) + candidate = Path(expanded) + if not candidate.is_absolute(): + candidate = self._workspace_root / candidate + return candidate.resolve() + + def _allowed_roots(self) -> set[Path]: + roots = {self._workspace_root} + roots.update(self._default_external_roots) + if self._runtime_options.allow_external_directories: + roots.update(self._state.allowed_external_roots) + return roots + + def _is_path_allowed(self, path: Path) -> bool: + for root in self._allowed_roots(): + if path == root or root in path.parents: + return True + return False + + def _resolve_path(self, path_text: str, *, must_exist: bool = False) -> Path: + resolved = self._resolve_any_path(path_text) + if not self._is_path_allowed(resolved): + raise OllamaToolError( + f"Path {str(resolved)!r} is outside allowed workspace roots." + ) + if must_exist and not resolved.exists(): + raise OllamaToolError(f"Path does not exist: {str(resolved)!r}") + return resolved + + def _atomic_write(self, path: Path, content: str) -> None: + parent = path.parent + parent.mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile( + "w", + encoding="utf-8", + dir=str(parent), + delete=False, + ) as temp_file: + temp_file.write(content) + temp_name = temp_file.name + Path(temp_name).replace(path) + + def _handle_read(self, args: dict[str, Any]) -> str: + path = self._resolve_path(str(args["path"]), must_exist=True) + if path.is_dir(): + raise OllamaToolError("read expects a file path.") + + raw = path.read_bytes() + if len(raw) > self._runtime_options.max_read_bytes: + raw = raw[: self._runtime_options.max_read_bytes] + text = raw.decode("utf-8", errors="ignore") + + offset = int(args.get("offset", 1)) + limit = int(args.get("limit", 200)) + offset = max(1, offset) + limit = max(1, limit) + + lines = text.splitlines() + start = offset - 1 + end = start + limit + output: list[str] = [] + for idx, line in enumerate(lines[start:end], start=offset): + output.append(f"{idx:>6}\t{line}") + return "\n".join(output) + + def _handle_ls(self, args: dict[str, Any]) -> str: + target = self._resolve_path(str(args.get("path", ".")), must_exist=True) + if not target.is_dir(): + raise OllamaToolError("ls expects a directory path.") + + max_entries = max(1, int(args.get("max_entries", 200))) + entries = sorted(target.iterdir(), key=lambda p: p.name.lower()) + lines: list[str] = [] + for entry in entries[:max_entries]: + if entry.is_dir(): + lines.append(f"[dir ] {entry.name}/") + else: + lines.append(f"[file] {entry.name} ({entry.stat().st_size} bytes)") + if len(entries) > max_entries: + lines.append(f"... {len(entries) - max_entries} more entries") + return "\n".join(lines) + + def _handle_glob(self, args: dict[str, Any]) -> str: + base = self._resolve_path(str(args.get("path", ".")), must_exist=True) + if not base.is_dir(): + raise OllamaToolError("glob path must be a directory.") + pattern = str(args["pattern"]) + max_results = max( + 1, + int(args.get("max_results", self._runtime_options.max_search_results)), + ) + + found: list[str] = [] + for root, dir_names, file_names in os.walk(base): + dir_names[:] = [d for d in dir_names if d not in _SEARCH_SKIP_DIR_NAMES] + root_path = Path(root) + for name in sorted(file_names + dir_names): + full = root_path / name + rel = str(full.relative_to(base)) + if fnmatch.fnmatch(rel, pattern): + prefix = "dir" if full.is_dir() else "file" + found.append(f"{prefix}: {rel}") + if len(found) >= max_results: + return "\n".join(found) + if not found: + return "No matches found." + return "\n".join(found) + + def _iter_search_files(self, target: Path) -> list[Path]: + if target.is_file(): + return [target] + + files: list[Path] = [] + for root, dir_names, file_names in os.walk(target): + dir_names[:] = [d for d in dir_names if d not in _SEARCH_SKIP_DIR_NAMES] + root_path = Path(root) + for file_name in file_names: + files.append(root_path / file_name) + return files + + def _handle_grep(self, args: dict[str, Any]) -> str: + query = str(args["query"]) + case_sensitive = bool(args.get("case_sensitive", False)) + fixed_strings = bool(args.get("fixed_strings", False)) + max_results = max( + 1, + int(args.get("max_results", self._runtime_options.max_search_results)), + ) + + target = self._resolve_path(str(args.get("path", ".")), must_exist=True) + + flags = 0 if case_sensitive else re.IGNORECASE + if fixed_strings: + pattern = re.compile(re.escape(query), flags) + else: + try: + pattern = re.compile(query, flags) + except re.error as exc: + raise OllamaToolError(f"Invalid regex: {exc}") from exc + + matches: list[str] = [] + for file_path in self._iter_search_files(target): + if len(matches) >= max_results: + break + try: + if file_path.stat().st_size > self._runtime_options.max_read_bytes: + continue + content = file_path.read_text(encoding="utf-8", errors="ignore") + except Exception: + continue + + try: + rel = file_path.relative_to(self._workspace_root) + except ValueError: + rel = file_path + for line_no, line in enumerate(content.splitlines(), start=1): + if pattern.search(line): + matches.append(f"{str(rel)}:{line_no}:{line}") + if len(matches) >= max_results: + break + + if not matches: + return "No matches found." + return "\n".join(matches) + + def _apply_edits(self, content: str, edits: list[dict[str, Any]]) -> str: + updated = content + for edit in edits: + old_text = str(edit["old_text"]) + new_text = str(edit["new_text"]) + replace_all = bool(edit.get("replace_all", False)) + if old_text not in updated: + raise OllamaToolError("Edit failed: old_text not found.") + if replace_all: + updated = updated.replace(old_text, new_text) + else: + updated = updated.replace(old_text, new_text, 1) + return updated + + def _handle_write(self, args: dict[str, Any]) -> str: + path = self._resolve_path(str(args["path"])) + overwrite = bool(args.get("overwrite", True)) + create_dirs = bool(args.get("create_dirs", True)) + + if path.exists() and not overwrite: + raise OllamaToolError(f"Refusing to overwrite existing file: {str(path)!r}") + if not path.parent.exists() and not create_dirs: + raise OllamaToolError("Parent directory missing and create_dirs is false.") + + content = str(args["content"]) + self._atomic_write(path, content) + return f"Wrote {len(content.encode('utf-8'))} bytes to {str(path)}" + + def _handle_edit(self, args: dict[str, Any]) -> str: + path = self._resolve_path(str(args["path"]), must_exist=True) + content = path.read_text(encoding="utf-8", errors="ignore") + updated = self._apply_edits( + content, + [ + { + "old_text": str(args["old_text"]), + "new_text": str(args["new_text"]), + "replace_all": bool(args.get("replace_all", False)), + } + ], + ) + self._atomic_write(path, updated) + return f"Edited file: {str(path)}" + + def _handle_multiedit(self, args: dict[str, Any]) -> str: + path = self._resolve_path(str(args["path"]), must_exist=True) + edits = args.get("edits", []) + if not isinstance(edits, list) or not edits: + raise OllamaToolError("multiedit requires a non-empty edits list.") + + content = path.read_text(encoding="utf-8", errors="ignore") + updated = self._apply_edits(content, edits) + self._atomic_write(path, updated) + return f"Applied {len(edits)} edit(s) to {str(path)}" + + def _handle_apply_patch(self, args: dict[str, Any]) -> str: + path = self._resolve_path(str(args["path"]), must_exist=True) + hunks = args.get("hunks", []) + if not isinstance(hunks, list) or not hunks: + raise OllamaToolError("apply_patch requires a non-empty hunks list.") + + content = path.read_text(encoding="utf-8", errors="ignore") + updated = self._apply_edits(content, hunks) + self._atomic_write(path, updated) + return f"Applied {len(hunks)} hunk(s) to {str(path)}" + + def _handle_bash(self, args: dict[str, Any]) -> str: + command = str(args["command"]).strip() + if not command: + raise OllamaToolError("bash command must not be empty.") + + cwd = str(args.get("cwd", ".")).strip() or "." + cwd_path = self._resolve_path(cwd, must_exist=True) + if not cwd_path.is_dir(): + raise OllamaToolError("bash cwd must be a directory.") + + try: + completed = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + cwd=str(cwd_path), + timeout=max(1, int(self._runtime_options.command_timeout_seconds)), + ) + except subprocess.TimeoutExpired as exc: + raise OllamaToolError(f"bash timed out after {exc.timeout}s") from exc + + output = (completed.stdout or "") + (completed.stderr or "") + header = f"exit_code={completed.returncode} cwd={str(cwd_path)}" + if output.strip(): + return f"{header}\n{output}" + return header + + def _handle_batch(self, args: dict[str, Any]) -> str: + calls = args.get("calls", []) + if not isinstance(calls, list) or not calls: + raise OllamaToolError("batch requires a non-empty calls list.") + if self._executor is None: + raise OllamaToolError("batch executor is not configured.") + + continue_on_error = bool(args.get("continue_on_error", True)) + rows: list[dict[str, Any]] = [] + for index, call in enumerate(calls): + if not isinstance(call, dict): + raise OllamaToolError(f"batch call index {index} must be an object.") + name = str(call.get("name", "")).strip() + call_args = call.get("arguments", {}) + if not name: + raise OllamaToolError(f"batch call index {index} missing tool name.") + if not isinstance(call_args, dict): + raise OllamaToolError( + f"batch call index {index} arguments must be object." + ) + if name == "batch": + raise OllamaToolError("batch cannot call itself recursively.") + + try: + result = self._executor(name, call_args) + rows.append({"index": index, "name": name, "ok": True, "result": result}) + except OllamaToolError as exc: + rows.append( + { + "index": index, + "name": name, + "ok": False, + "error": str(exc), + } + ) + if not continue_on_error: + break + return self._to_json(rows) + + def _handle_external_directory(self, args: dict[str, Any]) -> str: + if not self._runtime_options.allow_external_directories: + raise OllamaToolError( + "external-directory is disabled by policy. " + "Enable tools.allow_external_directories." + ) + + action = str(args.get("action", "list")).strip().lower() or "list" + if action == "list": + roots = sorted(str(path) for path in self._state.allowed_external_roots) + if not roots: + return "No external directories configured." + return "\n".join(roots) + + path_text = str(args.get("path", "")).strip() + if not path_text: + raise OllamaToolError("external-directory add/remove requires a path.") + path = self._resolve_any_path(path_text) + + if action == "add": + if not path.exists() or not path.is_dir(): + raise OllamaToolError("Path must exist and be a directory.") + self._state.allowed_external_roots.add(path) + return f"Added external directory: {str(path)}" + + if action == "remove": + self._state.allowed_external_roots.discard(path) + return f"Removed external directory: {str(path)}" + + raise OllamaToolError("external-directory action must be add, remove, or list.") + + def _handle_registry(self, _args: dict[str, Any]) -> str: + rows = [] + for spec in self._specs.values(): + rows.append( + { + "name": spec.name, + "category": spec.category, + "safety_level": spec.safety_level, + "description": spec.description, + } + ) + rows.sort(key=lambda item: str(item["name"])) + return self._to_json(rows) + + def _handle_tool(self, args: dict[str, Any]) -> str: + name = str(args["name"]) + spec = self._specs.get(name) + if spec is None: + raise OllamaToolError(f"Unknown tool {name!r}") + return self._to_json( + { + "name": spec.name, + "description": spec.description, + "category": spec.category, + "safety_level": spec.safety_level, + "parameters": spec.parameters_schema, + } + ) + + def _handle_invalid(self, _args: dict[str, Any]) -> str: + raise OllamaToolError("invalid tool invoked intentionally") + + def _handle_truncation(self, _args: dict[str, Any]) -> str: + return self._to_json( + { + "max_output_lines": self._runtime_options.max_output_lines, + "max_output_bytes": self._runtime_options.max_output_bytes, + "max_read_bytes": self._runtime_options.max_read_bytes, + "max_search_results": self._runtime_options.max_search_results, + } + ) + + def _handle_plan_enter(self, args: dict[str, Any]) -> str: + self._state.plan_mode = True + goal = str(args.get("goal", "")).strip() + if goal: + self._state.plan_content = goal + return "plan mode enabled" + + def _handle_plan_exit(self, _args: dict[str, Any]) -> str: + self._state.plan_mode = False + return "plan mode disabled" + + def _handle_plan(self, args: dict[str, Any]) -> str: + action = str(args.get("action", "get")).strip().lower() or "get" + if action == "get": + return self._state.plan_content or "" + if action == "set": + self._state.plan_content = str(args.get("content", "")) + return "plan updated" + if action == "append": + addition = str(args.get("content", "")) + if self._state.plan_content: + self._state.plan_content = f"{self._state.plan_content}\n{addition}" + else: + self._state.plan_content = addition + return "plan appended" + if action == "clear": + self._state.plan_content = "" + return "plan cleared" + raise OllamaToolError("plan action must be get, set, append, or clear.") + + def _handle_question(self, args: dict[str, Any]) -> str: + return self._to_json( + { + "type": "question", + "prompt": str(args["prompt"]), + "context": str(args.get("context", "")), + } + ) + + def _handle_todo(self, args: dict[str, Any]) -> str: + item = str(args["item"]).strip() + if not item: + raise OllamaToolError("todo item must not be empty.") + self._state.todos.append(item) + return f"todo added ({len(self._state.todos)} item(s) total)" + + def _handle_todoread(self, _args: dict[str, Any]) -> str: + if not self._state.todos: + return "[]" + return self._to_json(self._state.todos) + + def _handle_todowrite(self, args: dict[str, Any]) -> str: + items = args.get("items", []) + if not isinstance(items, list): + raise OllamaToolError("todowrite items must be an array.") + normalized = [str(item).strip() for item in items if str(item).strip()] + mode = str(args.get("mode", "replace")).strip().lower() or "replace" + if mode == "replace": + self._state.todos = normalized + elif mode == "append": + self._state.todos.extend(normalized) + else: + raise OllamaToolError("todowrite mode must be replace or append.") + return f"todo list now has {len(self._state.todos)} item(s)" + + def _handle_task(self, args: dict[str, Any]) -> str: + action = str(args.get("action", "list")).strip().lower() or "list" + if action == "list": + return self._to_json(self._state.tasks) + if action == "get": + name = str(args.get("name", "")).strip() + if not name: + raise OllamaToolError("task get requires a task name.") + return self._state.tasks.get(name, "") + if action == "set": + name = str(args.get("name", "")).strip() + status = str(args.get("status", "")).strip() + if not name or not status: + raise OllamaToolError("task set requires name and status.") + self._state.tasks[name] = status + return f"task {name!r} set to {status!r}" + raise OllamaToolError("task action must be list, get, or set.") + + @staticmethod + def _not_configured(name: str) -> str: + return f"{name} is not configured in this runtime." + + def _handle_lsp(self, _args: dict[str, Any]) -> str: + return self._not_configured("lsp") + + def _handle_skill(self, _args: dict[str, Any]) -> str: + return self._not_configured("skill") + + def _handle_websearch(self, args: dict[str, Any]) -> str: + if self._web_search_fn is None: + return self._not_configured("websearch") + query = str(args["query"]) + max_results = int(args.get("max_results", 5)) + return str(self._web_search_fn(query, max_results)) + + def _handle_webfetch(self, args: dict[str, Any]) -> str: + if self._web_fetch_fn is None: + return self._not_configured("webfetch") + url = str(args["url"]) + return str(self._web_fetch_fn(url)) diff --git a/src/ollama_chat/logging_utils.py b/src/ollama_chat/logging_utils.py index 183ddf2..aa153eb 100644 --- a/src/ollama_chat/logging_utils.py +++ b/src/ollama_chat/logging_utils.py @@ -2,8 +2,6 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone -import json import logging import os from pathlib import Path @@ -12,7 +10,6 @@ import structlog - def _best_effort_private_permissions(path: Path) -> None: if os.name != "posix": return diff --git a/src/ollama_chat/message_store.py b/src/ollama_chat/message_store.py index dfd9e57..6dc64a9 100644 --- a/src/ollama_chat/message_store.py +++ b/src/ollama_chat/message_store.py @@ -2,8 +2,9 @@ from __future__ import annotations +from collections.abc import Iterable import json -from typing import Any, Iterable +from typing import Any # Public message type (no internal fields). Message = dict[str, Any] diff --git a/src/ollama_chat/persistence.py b/src/ollama_chat/persistence.py index 6f5ce78..f179d81 100644 --- a/src/ollama_chat/persistence.py +++ b/src/ollama_chat/persistence.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime, timezone +from datetime import UTC, datetime import json import os from pathlib import Path @@ -112,7 +112,7 @@ def save_conversation( raise PersistenceDisabledError("Persistence is disabled in configuration.") self._ensure_paths() - now = datetime.now(timezone.utc) + now = datetime.now(UTC) created_at = now.isoformat() filename = f"{now.strftime('%Y%m%d-%H%M%S')}-{uuid4().hex[:8]}.json" target = self.directory / filename @@ -164,7 +164,7 @@ def export_markdown(self, messages: list[dict[str, str]], model: str) -> Path: raise PersistenceDisabledError("Persistence is disabled in configuration.") self._ensure_paths() - filename = f"{datetime.now(timezone.utc).strftime('%Y%m%d-%H%M%S')}-export.md" + filename = f"{datetime.now(UTC).strftime('%Y%m%d-%H%M%S')}-export.md" target = self.directory / filename lines = [f"# Conversation Export ({model})", ""] for message in messages: diff --git a/src/ollama_chat/support/__init__.py b/src/ollama_chat/support/__init__.py new file mode 100644 index 0000000..2374cfb --- /dev/null +++ b/src/ollama_chat/support/__init__.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +__all__ = [ + "bus", + "file_time", + "ripgrep", + "lsp_client", + "question_service", + "permission", +] diff --git a/src/ollama_chat/support/bus.py b/src/ollama_chat/support/bus.py new file mode 100644 index 0000000..a63bb13 --- /dev/null +++ b/src/ollama_chat/support/bus.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import asyncio +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + + +@dataclass +class _Subscriber: + callback: Callable[[str, dict[str, Any]], None] + + +class Bus: + """Very small in-process pub/sub bus. + + This is intentionally minimal; it is enough for tools to notify the UI or + observers during tests without introducing external dependencies. + """ + + def __init__(self) -> None: + self._subscribers: defaultdict[str, list[_Subscriber]] = defaultdict(list) + self._lock = asyncio.Lock() + + async def publish(self, event: str, payload: dict[str, Any]) -> None: + async with self._lock: + for sub in list(self._subscribers.get(event, [])): + try: + sub.callback(event, payload) + except Exception: + continue + + def publish_nowait(self, event: str, payload: dict[str, Any]) -> None: + for sub in list(self._subscribers.get(event, [])): + try: + sub.callback(event, payload) + except Exception: + continue + + def subscribe(self, event: str, callback: Callable[[str, dict[str, Any]], None]) -> None: + self._subscribers[event].append(_Subscriber(callback)) + + def unsubscribe(self, event: str, callback: Callable[[str, dict[str, Any]], None]) -> None: + items = self._subscribers.get(event, []) + self._subscribers[event] = [s for s in items if s.callback is not callback] + + +# Global bus instance convenient for modules +bus = Bus() diff --git a/src/ollama_chat/support/file_time.py b/src/ollama_chat/support/file_time.py new file mode 100644 index 0000000..fa7afc6 --- /dev/null +++ b/src/ollama_chat/support/file_time.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Callable +from datetime import datetime +import os +from pathlib import Path +from typing import Any + +_state: dict[str, dict[str, datetime]] = {} +_locks: dict[str, asyncio.Lock] = {} + + +def record_read(session_id: str, filepath: str | os.PathLike[str]) -> None: + """Record that session_id has read filepath at the current time.""" + path = str(Path(filepath).resolve()) + _state.setdefault(session_id, {})[path] = datetime.utcnow() + + +def get_read_time(session_id: str, filepath: str | os.PathLike[str]) -> datetime | None: + path = str(Path(filepath).resolve()) + return _state.get(session_id, {}).get(path) + + +async def assert_read(session_id: str, filepath: str | os.PathLike[str]) -> None: + """ + Guard against overwriting files that have not been read in this session. + + If OLLAMATERM_DISABLE_FILETIME_CHECK is set, the check is skipped. + """ + if os.environ.get("OLLAMATERM_DISABLE_FILETIME_CHECK"): + return + + path = str(Path(filepath).resolve()) + read_time = get_read_time(session_id, path) + if read_time is None: + raise RuntimeError(f"You must read file {path!r} before overwriting it.") + + try: + mtime = os.stat(path).st_mtime_ns / 1e9 + except FileNotFoundError: + # New file, allow creation. + return + + # Allow small tolerance for filesystem timestamp granularity. + if mtime > read_time.timestamp() + 0.05: + raise RuntimeError(f"File {path!r} has been modified since last read.") + + +async def with_lock(filepath: str | os.PathLike[str], fn: Callable[[], Any]): + """Serialize concurrent writes to the same file.""" + key = str(Path(filepath).resolve()) + lock = _locks.setdefault(key, asyncio.Lock()) + async with lock: + return await fn() diff --git a/src/ollama_chat/support/lsp_client.py b/src/ollama_chat/support/lsp_client.py new file mode 100644 index 0000000..f43d5c3 --- /dev/null +++ b/src/ollama_chat/support/lsp_client.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class _State: + touched: set[str] = field(default_factory=set) + diagnostics: dict[str, list[dict[str, Any]]] = field(default_factory=dict) + + +_state = _State() + + +def touch_file(path: str | Path, *, notify: bool) -> None: # noqa: ARG001 - notify kept for API compatibility + p = str(Path(path).resolve()) + _state.touched.add(p) + # In a full implementation, this would send didOpen/didChange/didSave notifications + # to a running LSP server and refresh diagnostics. Here we keep a simple stub. + + +def get_diagnostics() -> dict[str, list[dict[str, Any]]]: + # Return a copy to avoid external mutation + return {k: list(v) for k, v in _state.diagnostics.items()} + + +def set_diagnostics(path: str | Path, messages: list[dict[str, Any]]) -> None: + p = str(Path(path).resolve()) + _state.diagnostics[p] = list(messages) + + +def has_clients_for(path: str | Path) -> bool: + # Stubbed: no language servers are wired by default in this standalone package. + return False diff --git a/src/ollama_chat/support/permission.py b/src/ollama_chat/support/permission.py new file mode 100644 index 0000000..f0e2a34 --- /dev/null +++ b/src/ollama_chat/support/permission.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class PermissionRequest: + permission: str + patterns: list[str] + always: list[str] + metadata: dict[str, Any] + + +def evaluate(_req: PermissionRequest) -> bool: # pragma: no cover - stub + """Placeholder permission evaluation. + + A real implementation would consult a ruleset. Here we always return True + and rely on the caller to use ctx.ask() to request interactive approval. + """ + return True diff --git a/src/ollama_chat/support/question_service.py b/src/ollama_chat/support/question_service.py new file mode 100644 index 0000000..8caf090 --- /dev/null +++ b/src/ollama_chat/support/question_service.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +import time +from typing import Any + +from .bus import bus + + +@dataclass +class _Pending: + future: asyncio.Future + session_id: str + questions: list[dict[str, Any]] + + +_pending: dict[str, _Pending] = {} + + +def _new_id(prefix: str = "q") -> str: + return f"{prefix}_{int(time.time() * 1000)}" + + +async def ask( + *, + session_id: str, + questions: list[dict[str, Any]], + tool: dict[str, Any] | None = None, +) -> list[list[str]]: + """Publish a question event and await an answer. + + In this standalone implementation, if no reply is provided within a short + timeout, an empty answer is returned to avoid deadlock. + """ + qid = _new_id("question") + fut: asyncio.Future = asyncio.get_running_loop().create_future() + _pending[qid] = _Pending(future=fut, session_id=session_id, questions=questions) + try: + await bus.publish( + "question.asked", + {"id": qid, "session_id": session_id, "questions": questions, "tool": tool or {}}, + ) + try: + return await asyncio.wait_for(fut, timeout=0.5) + except TimeoutError: + return [[] for _ in questions] + finally: + _pending.pop(qid, None) + + +def reply(question_id: str, answers: list[list[str]]) -> None: + """Programmatically reply to a pending question (tests / UI).""" + item = _pending.get(question_id) + if item and not item.future.done(): + item.future.set_result(answers) diff --git a/src/ollama_chat/support/ripgrep.py b/src/ollama_chat/support/ripgrep.py new file mode 100644 index 0000000..b9836ea --- /dev/null +++ b/src/ollama_chat/support/ripgrep.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +import os +from pathlib import Path +import shutil + +RG_BIN = Path.home() / ".local" / "share" / "ollamaterm" / "bin" / "rg" + + +async def filepath() -> str: + """Return a path to a ripgrep binary if available. + + Preference order: + 1) RG_BIN path if present + 2) "rg" on PATH + 3) "ripgrep" on PATH + + Falls back to returning "rg" which may or may not exist at runtime. + """ + if RG_BIN.exists(): + return str(RG_BIN) + for name in ("rg", "ripgrep"): + found = shutil.which(name) + if found: + return found + return "rg" + + +async def files( + cwd: str, + glob: list[str] | None = None, + follow: bool = True, + hidden: bool = False, + signal: asyncio.Event | None = None, +) -> AsyncIterator[str]: + """Yield file paths using ripgrep if available, else Python fallback. + + Args roughly map to: rg --files [--follow] [--hidden] [--glob pat]... {cwd} + """ + rg = await filepath() + if shutil.which(Path(rg).name): + args = [rg, "--files"] + if follow: + args.append("--follow") + if hidden: + args.append("--hidden") + for pat in glob or []: + args += ["--glob", pat] + args.append(str(cwd)) + try: + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ) + except Exception: + proc = None + if proc is not None and proc.stdout is not None: + while True: + if signal and signal.is_set(): + break + chunk = await proc.stdout.readline() + if not chunk: + break + yield chunk.decode().rstrip("\n") + return + + # Fallback: walk the directory. + for root, _dirnames, filenames in os.walk(cwd): + for name in filenames: + if signal and signal.is_set(): + return + yield str(Path(root) / name) diff --git a/src/ollama_chat/tooling.py b/src/ollama_chat/tooling.py new file mode 100644 index 0000000..4e5ba85 --- /dev/null +++ b/src/ollama_chat/tooling.py @@ -0,0 +1,461 @@ +"""Tool registry for Ollama agent-loop tool calling.""" + +from __future__ import annotations + +from collections.abc import Callable +import asyncio +from dataclasses import dataclass, field +import logging +import os +import sys +import threading +import time +from typing import Any + +from .custom_tools import CustomToolSuite, ToolRuntimeOptions, ToolSpec +from .exceptions import OllamaToolError +from .tools.base import ToolContext + +try: + from ollama import web_fetch as _ollama_web_fetch + from ollama import web_search as _ollama_web_search +except ModuleNotFoundError: # pragma: no cover - optional dependency. + _ollama_web_search = None # type: ignore[assignment] + _ollama_web_fetch = None # type: ignore[assignment] + +LOGGER = logging.getLogger(__name__) + +# Serialise all temporary env-var mutations so that concurrent threads +# (e.g. when tool execution is offloaded via asyncio.to_thread) cannot +# observe each other's transient OLLAMA_API_KEY value. +_env_lock = threading.Lock() + + +def _with_temp_env(key: str, value: str, fn: Callable[[], str]) -> str: + """Temporarily set an environment variable, call fn(), then restore. + + Protected by a module-level lock so that concurrent threads do not + observe each other's transient environment changes. + """ + with _env_lock: + old_value = os.environ.get(key) + os.environ[key] = value + try: + return fn() + finally: + if old_value is None: + os.environ.pop(key, None) + else: + os.environ[key] = old_value + + +def _truncate_output(text: str, max_lines: int, max_bytes: int) -> tuple[str, bool]: + """Apply deterministic truncation by byte and line limits.""" + truncated = False + result = text + + if max_bytes > 0: + encoded = result.encode("utf-8", errors="ignore") + if len(encoded) > max_bytes: + truncated = True + clipped = encoded[:max_bytes] + result = clipped.decode("utf-8", errors="ignore") + result += "\n... [truncated by byte limit]" + + if max_lines > 0: + lines = result.splitlines() + if len(lines) > max_lines: + truncated = True + result = "\n".join(lines[:max_lines] + ["... [truncated by line limit]"]) + + return result, truncated + + +class ToolsPackageAdapter: + def __init__( + self, + runtime_options: ToolRuntimeOptions, + ask_cb: Callable[..., Any] | None = None, + metadata_cb: Callable[[dict], None] | None = None, + ) -> None: + self._runtime = runtime_options + self._ask_cb = ask_cb + self._metadata_cb = metadata_cb + + def to_specs(self) -> list[ToolSpec]: + specs: list[ToolSpec] = [] + # Ensure modules that import `from support import ...` can resolve the local + # package alias (ollama_chat.support) without modifying their import lines. + try: + import ollama_chat.support as _support_pkg # type: ignore[import-not-found] + sys.modules.setdefault("support", _support_pkg) + except Exception: + pass + try: + from .tools.registry import get_registry # local/lazy import + tools = get_registry().tools_for_model() + except Exception: + tools = [] + for tool in tools: + name = getattr(tool, "id", "") + if not name: + continue + try: + schema = tool.params_schema.model_json_schema() + except Exception: + schema = {"type": "object", "properties": {}, "required": []} + if schema.get("type") != "object": + schema = { + "type": "object", + "properties": schema.get("properties", {}), + "required": schema.get("required", []), + "additionalProperties": True, + } + else: + schema.setdefault("additionalProperties", True) + + def make_handler(t=tool) -> Callable[[dict[str, Any]], str]: + def handler(args: dict[str, Any]) -> str: + async def _run() -> str: + ctx = ToolContext( + session_id="default", + message_id=str(time.time_ns()), + agent="ollama", + abort=asyncio.Event(), + extra={ + "project_dir": self._runtime.workspace_root, + "bypassCwdCheck": self._runtime.allow_external_directories, + }, + ) + if self._metadata_cb is not None: + ctx._metadata_cb = self._metadata_cb # type: ignore[attr-defined] + if self._ask_cb is not None: + ctx._ask_cb = self._ask_cb # type: ignore[attr-defined] + result = await t.run(args, ctx) + return str(result.output) + + return asyncio.run(_run()) + + return handler + + specs.append( + ToolSpec( + name=name, + description=getattr(tool, "description", name), + parameters_schema=schema, + handler=make_handler(), + safety_level="safe", + category="builtin", + ) + ) + return specs + + +class ToolRegistry: + """Registry of callable tools available to the model during an agent loop.""" + + def __init__(self, runtime_options: ToolRuntimeOptions | None = None) -> None: + self._tools: dict[str, Callable[..., Any]] = {} + self._specs: dict[str, ToolSpec] = {} + self._runtime_options = runtime_options or ToolRuntimeOptions() + + def register(self, fn: Callable[..., Any]) -> None: + """Register a callable as a named tool. + + The function name is used as the tool name. + """ + self._tools[fn.__name__] = fn + LOGGER.debug( + "tools.registered", + extra={"event": "tools.registered", "tool": fn.__name__}, + ) + + def register_spec(self, spec: ToolSpec) -> None: + """Register a schema-first tool specification.""" + self._specs[spec.name] = spec + LOGGER.debug( + "tools.spec.registered", + extra={"event": "tools.spec.registered", "tool": spec.name}, + ) + + def list_tool_names(self) -> list[str]: + """Return all callable and schema tool names.""" + names = set(self._tools.keys()) | set(self._specs.keys()) + return sorted(names) + + def build_tools_list(self) -> list[Any]: + """Return callable and schema tools for passing to the Ollama SDK.""" + callables = list(self._tools.values()) + schema_tools = [spec.as_ollama_tool() for spec in self._specs.values()] + return callables + schema_tools + + def _validate_value(self, name: str, value: Any, schema: dict[str, Any]) -> None: + expected = schema.get("type") + if expected is None: + return + + if expected == "string": + if not isinstance(value, str): + raise OllamaToolError(f"Argument {name!r} must be a string.") + return + + if expected == "integer": + if not isinstance(value, int) or isinstance(value, bool): + raise OllamaToolError(f"Argument {name!r} must be an integer.") + return + + if expected == "number": + if not isinstance(value, (int, float)) or isinstance(value, bool): + raise OllamaToolError(f"Argument {name!r} must be numeric.") + return + + if expected == "boolean": + if not isinstance(value, bool): + raise OllamaToolError(f"Argument {name!r} must be a boolean.") + return + + if expected == "object": + if not isinstance(value, dict): + raise OllamaToolError(f"Argument {name!r} must be an object.") + return + + if expected == "array": + if not isinstance(value, list): + raise OllamaToolError(f"Argument {name!r} must be an array.") + item_schema = schema.get("items") + if isinstance(item_schema, dict): + for idx, item in enumerate(value): + self._validate_value(f"{name}[{idx}]", item, item_schema) + return + + def _validate_schema_arguments( + self, + schema: dict[str, Any], + arguments: dict[str, Any], + ) -> dict[str, Any]: + """Validate arguments against a constrained JSON schema subset.""" + if not isinstance(arguments, dict): + raise OllamaToolError("Tool arguments must be a JSON object.") + + if schema.get("type") != "object": + raise OllamaToolError("Tool schema root must be an object.") + + properties = schema.get("properties", {}) + if not isinstance(properties, dict): + raise OllamaToolError("Tool schema properties must be an object.") + required = schema.get("required", []) + if not isinstance(required, list): + raise OllamaToolError("Tool schema required must be a list.") + additional_allowed = bool(schema.get("additionalProperties", False)) + + missing = [name for name in required if name not in arguments] + if missing: + raise OllamaToolError(f"Missing required argument(s): {', '.join(missing)}") + + for arg_name, arg_value in arguments.items(): + prop_schema = properties.get(arg_name) + if prop_schema is None: + if additional_allowed: + continue + raise OllamaToolError(f"Unknown argument: {arg_name!r}") + self._validate_value(arg_name, arg_value, prop_schema) + + return arguments + + def execute(self, name: str, arguments: dict[str, Any]) -> str: + """Execute a named tool and return its string result. + + Raises OllamaToolError if the tool is unknown or raises an exception. + This method is synchronous; callers in an async context should use + ``asyncio.to_thread(registry.execute, name, args)`` to avoid blocking + the event loop. + """ + spec = self._specs.get(name) + if spec is not None: + try: + validated = self._validate_schema_arguments( + spec.parameters_schema, + arguments, + ) + result = str(spec.handler(validated)) + truncated, _ = _truncate_output( + result, + max_lines=self._runtime_options.max_output_lines, + max_bytes=self._runtime_options.max_output_bytes, + ) + return truncated + except OllamaToolError: + raise + except Exception as exc: # noqa: BLE001 + raise OllamaToolError(f"Tool {name!r} raised an error: {exc}") from exc + + fn = self._tools.get(name) + if fn is None: + raise OllamaToolError(f"Unknown tool requested by model: {name!r}") + try: + result = str(fn(**arguments)) + truncated, _ = _truncate_output( + result, + max_lines=self._runtime_options.max_output_lines, + max_bytes=self._runtime_options.max_output_bytes, + ) + return truncated + except OllamaToolError: + raise + except Exception as exc: # noqa: BLE001 - tool functions can fail arbitrarily. + raise OllamaToolError(f"Tool {name!r} raised an error: {exc}") from exc + + @property + def is_empty(self) -> bool: + """Return True when no tools are registered.""" + return not bool(self._tools or self._specs) + + +@dataclass(frozen=True) +class ToolRegistryOptions: + """Options used to build a ToolRegistry without boolean flags. + + If ``web_search_api_key`` is a non-empty string, web_search and web_fetch + tools are registered with the provided key. If it is empty or None, no web + tools are added. + """ + + web_search_api_key: str | None = None + enable_custom_tools: bool = False + enable_builtin_tools: bool = True + runtime_options: ToolRuntimeOptions = field(default_factory=ToolRuntimeOptions) + + +def build_registry(options: ToolRegistryOptions | None = None) -> ToolRegistry: + """Build a ToolRegistry based on provided options. + + - Optional callable-based web_search/web_fetch registration remains for + backward compatibility. + - Optional schema-based custom coding tools are registered when + ``enable_custom_tools`` is true. + """ + runtime = options.runtime_options if options is not None else ToolRuntimeOptions() + registry = ToolRegistry(runtime_options=runtime) + if options is None: + return registry + + api_key = (options.web_search_api_key or "").strip() + web_search_fn: Callable[[str, int], str] | None = None + web_fetch_fn: Callable[[str], str] | None = None + + if api_key: + web_search_fn = _make_web_search_tool(api_key) + web_fetch_fn = _make_web_fetch_tool(api_key) + # Backwards-compatible callable registrations. + registry.register(web_search_fn) + registry.register(web_fetch_fn) + LOGGER.info( + "tools.web_search.enabled", + extra={"event": "tools.web_search.enabled"}, + ) + + # Register built-in class-based tools first so that custom tools may override + # duplicate names when both systems are enabled. + builtin_names: set[str] = set() + if options.enable_builtin_tools: + adapter = ToolsPackageAdapter(options.runtime_options) + builtin_specs = adapter.to_specs() + for spec in builtin_specs: + registry.register_spec(spec) + builtin_names.add(spec.name) + + if options.enable_custom_tools: + suite = CustomToolSuite( + runtime_options=options.runtime_options, + web_search_fn=web_search_fn, + web_fetch_fn=web_fetch_fn, + ) + suite.bind_executor(registry.execute) + for spec in suite.specs(): + # Prefer built-in implementations for overlapping names in the initial + # allowlist (read, edit, grep, codesearch, list). Skip duplicates. + if spec.name in builtin_names: + continue + registry.register_spec(spec) + return registry + + +def _make_web_search_tool(api_key: str) -> Callable[..., str]: + """Return a web_search callable with the API key bound at creation time.""" + + if _ollama_web_search is None: + raise OllamaToolError( + "web_search is unavailable: the ollama package is not installed." + ) + + def _web_search_tool(query: str, max_results: int = 5) -> str: + """Search the web for a query and return relevant results. + + Args: + query: The search query string. + max_results: Maximum number of results to return (1-10). + + Returns: + Formatted search results as a string. + """ + try: + return _with_temp_env( + "OLLAMA_API_KEY", + api_key, + lambda: str(_ollama_web_search(query, max_results=max_results)), + ) + except Exception as exc: # noqa: BLE001 + raise OllamaToolError(f"web_search failed: {exc}") from exc + + return _web_search_tool + + +def _make_web_fetch_tool(api_key: str) -> Callable[..., str]: + """Return a web_fetch callable with the API key bound at creation time.""" + + if _ollama_web_fetch is None: + raise OllamaToolError( + "web_fetch is unavailable: the ollama package is not installed." + ) + + def _web_fetch_tool(url: str) -> str: + """Fetch the content of a web page by URL. + + Args: + url: The URL to fetch. + + Returns: + The page title and content as a string. + """ + try: + return _with_temp_env( + "OLLAMA_API_KEY", + api_key, + lambda: str(_ollama_web_fetch(url)), + ) + except Exception as exc: # noqa: BLE001 + raise OllamaToolError(f"web_fetch failed: {exc}") from exc + + return _web_fetch_tool + + +def build_default_registry( + web_search_enabled: bool = False, + web_search_api_key: str = "", +) -> ToolRegistry: + """Compatibility wrapper for legacy API with a boolean flag. + + Prefer ``build_registry(ToolRegistryOptions(web_search_api_key=...))``. + """ + if not web_search_enabled: + return ToolRegistry() + + # Resolve API key: explicit config value takes precedence over env var. + api_key = web_search_api_key or os.environ.get("OLLAMA_API_KEY", "").strip() + if not api_key: + raise OllamaToolError( + "web_search_enabled is True but no OLLAMA_API_KEY was found. " + "Set web_search_api_key in [capabilities] or export OLLAMA_API_KEY." + ) + + return build_registry(ToolRegistryOptions(web_search_api_key=api_key)) diff --git a/src/ollama_chat/tools.py b/src/ollama_chat/tools.py deleted file mode 100644 index 3e5afce..0000000 --- a/src/ollama_chat/tools.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Tool registry for Ollama agent-loop tool calling.""" - -from __future__ import annotations - -import logging -import os -import threading -from dataclasses import dataclass -from collections.abc import Callable -from typing import Any - -from .exceptions import OllamaToolError - -try: - from ollama import web_fetch as _ollama_web_fetch - from ollama import web_search as _ollama_web_search -except ModuleNotFoundError: # pragma: no cover - optional dependency. - _ollama_web_search = None # type: ignore[assignment] - _ollama_web_fetch = None # type: ignore[assignment] - -LOGGER = logging.getLogger(__name__) - -# Serialise all temporary env-var mutations so that concurrent threads -# (e.g. when tool execution is offloaded via asyncio.to_thread) cannot -# observe each other's transient OLLAMA_API_KEY value. -_env_lock = threading.Lock() - - -def _with_temp_env(key: str, value: str, fn: Callable[[], str]) -> str: - """Temporarily set an environment variable, call fn(), then restore. - - Protected by a module-level lock so that concurrent threads do not - observe each other's transient environment changes. - """ - with _env_lock: - old_value = os.environ.get(key) - os.environ[key] = value - try: - return fn() - finally: - if old_value is None: - os.environ.pop(key, None) - else: - os.environ[key] = old_value - - -class ToolRegistry: - """Registry of callable tools available to the model during an agent loop.""" - - def __init__(self) -> None: - self._tools: dict[str, Callable[..., Any]] = {} - - def register(self, fn: Callable[..., Any]) -> None: - """Register a callable as a named tool. - - The function name is used as the tool name. - """ - self._tools[fn.__name__] = fn - LOGGER.debug( - "tools.registered", - extra={"event": "tools.registered", "tool": fn.__name__}, - ) - - def build_tools_list(self) -> list[Callable[..., Any]]: - """Return the list of tool callables for passing to the Ollama SDK.""" - return list(self._tools.values()) - - def execute(self, name: str, arguments: dict[str, Any]) -> str: - """Execute a named tool and return its string result. - - Raises OllamaToolError if the tool is unknown or raises an exception. - This method is synchronous; callers in an async context should use - ``asyncio.to_thread(registry.execute, name, args)`` to avoid blocking - the event loop. - """ - fn = self._tools.get(name) - if fn is None: - raise OllamaToolError(f"Unknown tool requested by model: {name!r}") - try: - result = fn(**arguments) - return str(result) - except OllamaToolError: - raise - except Exception as exc: # noqa: BLE001 - tool functions can fail arbitrarily. - raise OllamaToolError(f"Tool {name!r} raised an error: {exc}") from exc - - @property - def is_empty(self) -> bool: - """Return True when no tools are registered.""" - return not bool(self._tools) - - -@dataclass(frozen=True) -class ToolRegistryOptions: - """Options used to build a ToolRegistry without boolean flags. - - If ``web_search_api_key`` is a non-empty string, web_search and web_fetch - tools are registered with the provided key. If it is empty or None, no web - tools are added. - """ - - web_search_api_key: str | None = None - - -def build_registry(options: ToolRegistryOptions | None = None) -> ToolRegistry: - """Build a ToolRegistry based on provided options. - - - When ``options.web_search_api_key`` is a non-empty string, register - web_search and web_fetch tools with that key. - - Otherwise return an empty registry. - """ - registry = ToolRegistry() - if options is None: - return registry - - api_key = (options.web_search_api_key or "").strip() - if not api_key: - return registry - - # Validate and register tools with the provided key - registry.register(_make_web_search_tool(api_key)) - registry.register(_make_web_fetch_tool(api_key)) - LOGGER.info( - "tools.web_search.enabled", - extra={"event": "tools.web_search.enabled"}, - ) - return registry - - -def _make_web_search_tool(api_key: str) -> Callable[..., str]: - """Return a web_search callable with the API key bound at creation time.""" - - if _ollama_web_search is None: - raise OllamaToolError( - "web_search is unavailable: the ollama package is not installed." - ) - - def _web_search_tool(query: str, max_results: int = 5) -> str: - """Search the web for a query and return relevant results. - - Args: - query: The search query string. - max_results: Maximum number of results to return (1-10). - - Returns: - Formatted search results as a string. - """ - try: - return _with_temp_env( - "OLLAMA_API_KEY", - api_key, - lambda: str(_ollama_web_search(query, max_results=max_results)), - ) - except Exception as exc: # noqa: BLE001 - raise OllamaToolError(f"web_search failed: {exc}") from exc - - return _web_search_tool - - -def _make_web_fetch_tool(api_key: str) -> Callable[..., str]: - """Return a web_fetch callable with the API key bound at creation time.""" - - if _ollama_web_fetch is None: - raise OllamaToolError( - "web_fetch is unavailable: the ollama package is not installed." - ) - - def _web_fetch_tool(url: str) -> str: - """Fetch the content of a web page by URL. - - Args: - url: The URL to fetch. - - Returns: - The page title and content as a string. - """ - try: - return _with_temp_env( - "OLLAMA_API_KEY", - api_key, - lambda: str(_ollama_web_fetch(url)), - ) - except Exception as exc: # noqa: BLE001 - raise OllamaToolError(f"web_fetch failed: {exc}") from exc - - return _web_fetch_tool - - -def build_default_registry( - web_search_enabled: bool = False, - web_search_api_key: str = "", -) -> ToolRegistry: - """Compatibility wrapper for legacy API with a boolean flag. - - Prefer ``build_registry(ToolRegistryOptions(web_search_api_key=...))``. - """ - if not web_search_enabled: - return ToolRegistry() - - # Resolve API key: explicit config value takes precedence over env var. - api_key = web_search_api_key or os.environ.get("OLLAMA_API_KEY", "").strip() - if not api_key: - raise OllamaToolError( - "web_search_enabled is True but no OLLAMA_API_KEY was found. " - "Set web_search_api_key in [capabilities] or export OLLAMA_API_KEY." - ) - - return build_registry(ToolRegistryOptions(web_search_api_key=api_key)) diff --git a/src/ollama_chat/tools/__init__.py b/src/ollama_chat/tools/__init__.py new file mode 100644 index 0000000..981efd6 --- /dev/null +++ b/src/ollama_chat/tools/__init__.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +# Public package for schema-first tool implementations described in +# PYTHON_TOOLS_PROMPT.md. These modules are intentionally isolated from the +# existing ollama_chat runtime and can be imported independently. + +__all__ = [ + "base", + "truncation", + "external_directory", + "registry", + "bash_tool", + "read_tool", + "edit_tool", + "write_tool", + "glob_tool", + "grep_tool", + "webfetch_tool", + "websearch_tool", + "codesearch_tool", + "task_tool", + "batch_tool", + "lsp_tool", + "plan_tool", + "question_tool", + "todo_tool", + "skill_tool", + "apply_patch_tool", + "multiedit_tool", + "ls_tool", + "invalid_tool", +] diff --git a/src/ollama_chat/tools/apply_patch_tool.py b/src/ollama_chat/tools/apply_patch_tool.py new file mode 100644 index 0000000..464188a --- /dev/null +++ b/src/ollama_chat/tools/apply_patch_tool.py @@ -0,0 +1,274 @@ +from __future__ import annotations + +from dataclasses import dataclass +from difflib import unified_diff +from pathlib import Path +from typing import Any + +from support import bus, lsp_client + +from .base import ParamsSchema, Tool, ToolContext, ToolResult +from .external_directory import assert_external_directory + + +class ApplyPatchParams(ParamsSchema): + patch_text: str + + +@dataclass +class _AddHunk: + path: str + content: str + + +@dataclass +class _DeleteHunk: + path: str + + +@dataclass +class _UpdateHunk: + path: str + chunks: list[tuple[str, str]] # (old_text, new_text) + move_to: str | None = None + + +_HUNK_MARKER = "*** " +_BEGIN = "*** Begin Patch" +_END = "*** End Patch" +_ADD = "*** Add File:" +_UPDATE = "*** Update File:" +_DELETE = "*** Delete File:" +_MOVE_TO = "*** Move to:" + + +def _extract_between(text: str, start: str, end: str) -> str: + try: + s = text.index(start) + e = text.rindex(end) + return text[s + len(start) : e] + except ValueError: + return text + + +def _parse_patch(text: str) -> list[Any]: + body = _extract_between(text, _BEGIN, _END) + lines = body.splitlines() + i = 0 + hunks: list[Any] = [] + + def _collect_until_header(start: int) -> tuple[list[str], int]: + out: list[str] = [] + j = start + while j < len(lines): + if lines[j].startswith(_HUNK_MARKER) and not lines[j].startswith("@@"): + break + out.append(lines[j]) + j += 1 + return out, j + + while i < len(lines): + line = lines[i].strip() + if not line: + i += 1 + continue + if line.startswith(_ADD): + path = line[len(_ADD) :].strip() + content_lines, i = _collect_until_header(i + 1) + content: list[str] = [] + for ln in content_lines: + if ln.startswith("+"): + content.append(ln[1:]) + else: + # Allow raw lines as well (defensive) + content.append(ln) + hunks.append(_AddHunk(path=path, content="\n".join(content) + ("\n" if content_lines else ""))) + continue + if line.startswith(_DELETE): + path = line[len(_DELETE) :].strip() + hunks.append(_DeleteHunk(path=path)) + i += 1 + continue + if line.startswith(_UPDATE): + path = line[len(_UPDATE) :].strip() + move_to: str | None = None + j = i + 1 + # Optional move to next line + if j < len(lines) and lines[j].startswith(_MOVE_TO): + move_to = lines[j][len(_MOVE_TO) :].strip() + j += 1 + chunks: list[tuple[str, str]] = [] + cur_old: list[str] = [] + cur_new: list[str] = [] + in_chunk = False + while j < len(lines): + s = lines[j] + if s.startswith(_HUNK_MARKER) and not s.startswith("@@"): + break + if s.startswith("@@"): + # start new chunk + if in_chunk and (cur_old or cur_new): + chunks.append(("\n".join(cur_old), "\n".join(cur_new))) + cur_old, cur_new = [], [] + in_chunk = True + j += 1 + continue + if not in_chunk: + j += 1 + continue + if s.startswith(" "): + cur_old.append(s[1:]) + cur_new.append(s[1:]) + elif s.startswith("-"): + cur_old.append(s[1:]) + elif s.startswith("+"): + cur_new.append(s[1:]) + else: + # treat as context + cur_old.append(s) + cur_new.append(s) + j += 1 + if in_chunk and (cur_old or cur_new): + chunks.append(("\n".join(cur_old), "\n".join(cur_new))) + hunks.append(_UpdateHunk(path=path, chunks=chunks, move_to=move_to)) + i = j + continue + # Unknown line: skip + i += 1 + + return hunks + + +def _apply_update_chunks(old: str, chunks: list[tuple[str, str]]) -> str: + updated = old + for old_text, new_text in chunks: + if not old_text and not new_text: + continue + # Try exact match first + if old_text and old_text in updated: + updated = updated.replace(old_text, new_text, 1) + continue + # Try without trailing newline + if old_text.endswith("\n") and old_text[:-1] in updated: + updated = updated.replace(old_text[:-1], new_text, 1) + continue + # As a last resort, raise error to surface mismatch + raise RuntimeError("apply_patch verification failed: cannot locate chunk in target file") + return updated + + +class ApplyPatchTool(Tool): + id = "apply_patch" + params_schema = ApplyPatchParams + + async def execute(self, params: ApplyPatchParams, ctx: ToolContext) -> ToolResult: + hunks = _parse_patch(params.patch_text) + if not hunks: + return ToolResult( + title="apply_patch", + output="apply_patch verification failed: no hunks found", + metadata={"ok": False}, + ) + + # Resolve paths and build previews/diffs + diffs: list[str] = [] + file_list: list[str] = [] + actions: list[tuple[str, Any]] = [] # (action, data) + + for h in hunks: + if isinstance(h, _AddHunk): + path = Path(h.path).expanduser().resolve() + await assert_external_directory(ctx, str(path)) + old_content = "" + new_content = h.content + diff = "\n".join( + unified_diff(old_content.splitlines(), new_content.splitlines(), fromfile=str(path), tofile=str(path), lineterm="") + ) + diffs.append(diff) + file_list.append(str(path)) + actions.append(("add", (path, new_content))) + elif isinstance(h, _DeleteHunk): + path = Path(h.path).expanduser().resolve() + await assert_external_directory(ctx, str(path)) + old_content = path.read_text(encoding="utf-8", errors="ignore") if path.exists() else "" + new_content = "" + diff = "\n".join( + unified_diff(old_content.splitlines(), new_content.splitlines(), fromfile=str(path), tofile=str(path), lineterm="") + ) + diffs.append(diff) + file_list.append(str(path)) + actions.append(("delete", (path,))) + elif isinstance(h, _UpdateHunk): + src = Path(h.path).expanduser().resolve() + await assert_external_directory(ctx, str(src)) + dst = Path(h.move_to).expanduser().resolve() if h.move_to else None + if dst is not None: + await assert_external_directory(ctx, str(dst)) + old_content = src.read_text(encoding="utf-8", errors="ignore") if src.exists() else "" + new_content = _apply_update_chunks(old_content, h.chunks) + diff = "\n".join( + unified_diff(old_content.splitlines(), new_content.splitlines(), fromfile=str(src), tofile=str(dst or src), lineterm="") + ) + diffs.append(diff) + file_list.append(str(dst or src)) + actions.append(("update", (src, dst, new_content))) + + diff_text = "\n".join(diffs) + await ctx.ask( + permission="edit", + patterns=file_list, + always=["*"], + metadata={"diff": diff_text, "files": file_list}, + ) + + changed: list[str] = [] + # Apply changes + for kind, data in actions: + if kind == "add": + path, content = data + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content, encoding="utf-8") + changed.append(f"A {path}") + try: + await bus.bus.publish("file.edited", {"file": str(path)}) + await bus.bus.publish("file.watcher.updated", {"file": str(path), "event": "add"}) + except Exception: + pass + try: + lsp_client.touch_file(str(path), notify=True) + except Exception: + pass + elif kind == "delete": + (path,) = data + try: + path.unlink(missing_ok=True) + changed.append(f"D {path}") + try: + await bus.bus.publish("file.watcher.updated", {"file": str(path), "event": "unlink"}) + except Exception: + pass + except Exception: + changed.append(f"D {path} (failed)") + elif kind == "update": + src, dst, content = data + target = dst or src + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content, encoding="utf-8") + if dst and dst != src: + try: + src.unlink(missing_ok=True) + except Exception: + pass + changed.append(f"M {target}") + try: + await bus.bus.publish("file.edited", {"file": str(target)}) + await bus.bus.publish("file.watcher.updated", {"file": str(target), "event": "change"}) + except Exception: + pass + try: + lsp_client.touch_file(str(target), notify=True) + except Exception: + pass + + output = "\n".join(changed) if changed else "No changes applied." + return ToolResult(title="apply_patch", output=output, metadata={"changed": len(changed)}) diff --git a/src/ollama_chat/tools/base.py b/src/ollama_chat/tools/base.py new file mode 100644 index 0000000..748d225 --- /dev/null +++ b/src/ollama_chat/tools/base.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +import asyncio +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field, replace +from typing import Any + +from pydantic import BaseModel, ValidationError + +from .truncation import truncate_output + + +@dataclass +class Attachment: + type: str # "file" + mime: str + url: str # data:;base64, or https://... + + +@dataclass +class ToolResult: + title: str + output: str + metadata: dict[str, Any] + attachments: list[Attachment] = field(default_factory=list) + + +@dataclass +class ToolContext: + session_id: str + message_id: str + agent: str + abort: asyncio.Event # set when the user aborts + call_id: str | None = None + extra: dict[str, Any] = field(default_factory=dict) + messages: list[Any] = field(default_factory=list) + + # Injected by the runtime: + _metadata_cb: Callable[[dict], None] = field(default=lambda _: None, repr=False) + _ask_cb: Callable[..., Awaitable[None]] | None = field(default=None, repr=False) + + def metadata(self, title: str | None = None, metadata: dict | None = None) -> None: + """Update live streaming metadata visible in the UI.""" + self._metadata_cb({"title": title, "metadata": metadata or {}}) + + async def ask( + self, + permission: str, + patterns: list[str], + always: list[str], + metadata: dict, + ) -> None: + """ + Request user approval. + Raises PermissionDeniedError / PermissionRejectedError on denial. + """ + if self._ask_cb: + await self._ask_cb( + permission=permission, + patterns=patterns, + always=always, + metadata=metadata, + session_id=self.session_id, + ) + + def with_call_id(self, call_id: str) -> ToolContext: + """Return a shallow-copied context with a different call_id.""" + return replace(self, call_id=call_id) + + +class ParamsSchema(BaseModel): + """Base class for all tool parameter schemas.""" + + model_config = {"extra": "ignore"} + + +class Tool(ABC): + """ + Abstract base class for all tools. + + Subclasses set: + id – unique tool name (matches permission key) + description – shown to the LLM + params_schema – a ParamsSchema subclass + + The base run() helper performs: + 1. Validate params via the Pydantic schema + 2. Call the concrete execute() + 3. Apply output truncation (truncate_output()) unless the result + already has metadata["truncated"] set + """ + + id: str + description: str = "" + params_schema: type[ParamsSchema] + + @abstractmethod + async def execute(self, params: ParamsSchema, ctx: ToolContext) -> ToolResult: # pragma: no cover - interface only + ... + + def schema(self) -> dict: + """Return the OpenAI function-calling schema for this tool.""" + return { + "name": self.id, + "description": self.description, + "parameters": self.params_schema.model_json_schema(), + } + + def format_validation_error(self, error: Exception) -> str | None: + """Override to provide custom error messages for schema validation failures.""" + return None + + async def run(self, raw_params: dict[str, Any], ctx: ToolContext) -> ToolResult: + """Validate, execute, and apply truncation to the tool output.""" + try: + params = self.params_schema.model_validate(raw_params) + except ValidationError as exc: # pragma: no cover - defensive + msg = self.format_validation_error(exc) or str(exc) + return ToolResult( + title=f"{self.id}: invalid parameters", + output=msg, + metadata={"ok": False, "validation_error": True}, + ) + + result = await self.execute(params, ctx) + # Respect explicit truncation metadata from the tool implementation. + if not result.metadata.get("truncated"): + trunc = await truncate_output(result.output, agent=ctx.agent) + result.output = trunc.content + # Merge truncation metadata non-destructively. + result.metadata = {**result.metadata, "truncated": trunc.truncated} + if trunc.output_path: + result.metadata.setdefault("output_path", trunc.output_path) + return result diff --git a/src/ollama_chat/tools/bash_tool.py b/src/ollama_chat/tools/bash_tool.py new file mode 100644 index 0000000..35cbb37 --- /dev/null +++ b/src/ollama_chat/tools/bash_tool.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import asyncio +import os +from pathlib import Path +import shlex +import signal +import subprocess + +from .base import ParamsSchema, Tool, ToolContext, ToolResult + +DEFAULT_TIMEOUT_MS = 120_000 +MAX_METADATA_LENGTH = 30_000 + +# Minimal arity table used to build stable permission patterns +ARITY: dict[str, int] = { + "cat": 1, + "cd": 1, + "chmod": 1, + "chown": 1, + "cp": 1, + "echo": 1, + "grep": 1, + "kill": 1, + "ls": 1, + "mkdir": 1, + "mv": 1, + "rm": 1, + "touch": 1, + "which": 1, + "git": 2, + "npm": 2, + "bun": 2, + "docker": 2, + "python": 2, + "pip": 2, + "cargo": 2, + "go": 2, + "make": 2, + "yarn": 2, + "npm run": 3, + "bun run": 3, + "git config": 3, + "docker compose": 3, +} + + +class BashParams(ParamsSchema): + command: str + description: str # short summary for live metadata + timeout: int | None = None # milliseconds; default DEFAULT_TIMEOUT_MS + workdir: str | None = None # defaults to project directory + + +class BashTool(Tool): + id = "bash" + params_schema = BashParams + description = ( + "Executes a bash command with timeout and output caps. Use workdir instead of cd." + ) + + async def _kill_process_tree(self, proc: asyncio.subprocess.Process) -> None: + try: + if os.name == "nt": # Windows + subprocess.run(["taskkill", "/F", "/T", "/PID", str(proc.pid)]) + else: + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + except Exception: + try: + proc.terminate() + except Exception: + pass + + def _extract_tokens(self, command: str) -> list[str]: + try: + return shlex.split(command) + except Exception: + return command.strip().split() + + def _arity_prefix(self, tokens: list[str]) -> list[str]: + for length in range(len(tokens), 0, -1): + key = " ".join(tokens[:length]) + if key in ARITY: + return tokens[: ARITY[key]] + return tokens[:1] + + async def _check_external_dirs(self, tokens: list[str], cwd: Path, ctx: ToolContext) -> None: + pathlike_cmds = {"cd", "rm", "cp", "mv", "mkdir", "touch", "chmod", "chown", "cat"} + if not tokens: + return + if tokens[0] not in pathlike_cmds: + return + dirs: set[Path] = set() + for tok in tokens[1:]: + if tok.startswith("-"): + continue + p = Path(tok) + if not p.is_absolute(): + p = (cwd / p).resolve() + try: + if p.is_dir(): + target = p + else: + target = p.parent + dirs.add(target) + except Exception: + continue + if not dirs: + return + globs = [str(d / "*") for d in dirs] + await ctx.ask( + permission="external_directory", + patterns=globs, + always=globs, + metadata={}, + ) + + async def execute(self, params: BashParams, ctx: ToolContext) -> ToolResult: + command = params.command.strip() + if not command: + return ToolResult(title="bash", output="Empty command.", metadata={"ok": False}) + + project_dir = Path(str(ctx.extra.get("project_dir", "."))).expanduser().resolve() + cwd = Path(params.workdir or project_dir).expanduser().resolve() + if not cwd.exists() or not cwd.is_dir(): + return ToolResult(title="bash", output="Invalid working directory.", metadata={"ok": False}) + + tokens = self._extract_tokens(command) + await self._check_external_dirs(tokens, cwd, ctx) + + # Ask permission with patterns and arity prefix + prefix = " ".join(self._arity_prefix(tokens)) + await ctx.ask( + permission="bash", + patterns=[command], + always=[prefix + " *"], + metadata={}, + ) + + extra_env: dict[str, str] = {} + timeout_ms = int(params.timeout or DEFAULT_TIMEOUT_MS) + timeout_sec = max(1.0, timeout_ms / 1000.0) + + # Spawn process in its own session/process group for reliable termination + proc = await asyncio.create_subprocess_shell( + command, + cwd=str(cwd), + env={**os.environ, **extra_env}, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + start_new_session=(os.name != "nt"), + ) + + output_chunks: list[str] = [] + aborted = False + timed_out = False + + async def _read_stream(stream: asyncio.StreamReader) -> None: + while True: + if ctx.abort.is_set(): + break + chunk = await stream.readline() + if not chunk: + break + output_chunks.append(chunk.decode(errors="ignore")) + tail = ("".join(output_chunks))[-MAX_METADATA_LENGTH:] + ctx.metadata(title="bash", metadata={"output": tail, "description": params.description}) + + tasks = [] + if proc.stdout is not None: + tasks.append(asyncio.create_task(_read_stream(proc.stdout))) + if proc.stderr is not None: + tasks.append(asyncio.create_task(_read_stream(proc.stderr))) + + try: + await asyncio.wait_for(proc.wait(), timeout=timeout_sec) + except TimeoutError: + timed_out = True + await self._kill_process_tree(proc) + if ctx.abort.is_set(): + aborted = True + await self._kill_process_tree(proc) + + # Ensure readers finish + await asyncio.gather(*tasks, return_exceptions=True) + + exit_code = proc.returncode if proc.returncode is not None else -1 + output = "".join(output_chunks) + + meta_note = [] + if timed_out: + meta_note.append("timed out") + if aborted: + meta_note.append("aborted") + status_part = f" status={','.join(meta_note)}" if meta_note else "" + output += f"\n exit_code={exit_code}{status_part}" + + return ToolResult(title=f"bash: {command}", output=output, metadata={"exit_code": exit_code, "timed_out": timed_out, "aborted": aborted}) diff --git a/src/ollama_chat/tools/batch_tool.py b/src/ollama_chat/tools/batch_tool.py new file mode 100644 index 0000000..8942f2f --- /dev/null +++ b/src/ollama_chat/tools/batch_tool.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import asyncio +from typing import Any + +from pydantic import BaseModel + +from .base import ParamsSchema, Tool, ToolContext, ToolResult + + +class ToolCallSpec(BaseModel): + tool: str + parameters: dict[str, Any] + + +class BatchParams(ParamsSchema): + tool_calls: list[ToolCallSpec] + + +DISALLOWED = {"batch"} +FILTERED_FROM_SUGGESTIONS = {"invalid", "apply_patch"} | DISALLOWED + + +class BatchTool(Tool): + id = "batch" + params_schema = BatchParams + + async def execute(self, params: BatchParams, ctx: ToolContext) -> ToolResult: + calls = list(params.tool_calls or []) + over = 0 + if len(calls) > 25: + over = len(calls) - 25 + calls = calls[:25] + + # Lazy import to avoid circular dependency at module import time. + from .registry import get_registry # noqa: WPS433 + registry = get_registry() + + async def run_one(index: int, call: ToolCallSpec) -> tuple[int, bool, str | None, ToolResult | None]: + if call.tool in DISALLOWED: + return index, False, "Tool not allowed in batch", None + tool = registry.get(call.tool) + if not tool: + return index, False, f"Tool not in registry: {call.tool}", None + part_id = f"part_{index}" + try: + result = await tool.run(call.parameters, ctx.with_call_id(part_id)) + return index, True, None, result + except Exception as exc: # noqa: BLE001 + return index, False, str(exc), None + + tasks = [run_one(i, call) for i, call in enumerate(calls)] + results = await asyncio.gather(*tasks, return_exceptions=False) + + ok_count = sum(1 for _, ok, _, _ in results if ok) + fail_count = len(results) - ok_count + attachments = [] + for _, ok, _, res in results: + if ok and res is not None: + attachments.extend(res.attachments) + + if fail_count == 0: + summary = f"All {ok_count} tools executed successfully." + else: + summary = f"Executed {ok_count}/{len(results)} tools. {fail_count} failed." + if over: + summary += f" Skipped {over} additional call(s)." + + return ToolResult(title="batch", output=summary, metadata={"ok": fail_count == 0}, attachments=attachments) diff --git a/src/ollama_chat/tools/codesearch_tool.py b/src/ollama_chat/tools/codesearch_tool.py new file mode 100644 index 0000000..351c574 --- /dev/null +++ b/src/ollama_chat/tools/codesearch_tool.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import json + +from .base import ParamsSchema, Tool, ToolContext, ToolResult + +EXA_BASE_URL = "https://mcp.exa.ai" +EXA_ENDPOINT = "/mcp" + + +class CodeSearchParams(ParamsSchema): + query: str + tokens_num: int = 5000 # 1000–50000 + + +class CodeSearchTool(Tool): + id = "codesearch" + params_schema = CodeSearchParams + + async def execute(self, params: CodeSearchParams, ctx: ToolContext) -> ToolResult: + await ctx.ask( + permission="codesearch", + patterns=[params.query], + always=["*"], + metadata={"query": params.query, "tokens_num": params.tokens_num}, + ) + + try: + import httpx # noqa: WPS433 + except Exception as exc: # pragma: no cover - optional dep + return ToolResult(title="codesearch", output=f"Missing dependency: {exc}", metadata={"ok": False}) + + body = { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "get_code_context_exa", + "arguments": { + "query": params.query, + "tokensNum": int(params.tokens_num), + }, + }, + } + headers = { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + } + url = EXA_BASE_URL + EXA_ENDPOINT + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post(url, headers=headers, content=json.dumps(body)) + text = resp.text + except Exception as exc: # pragma: no cover - network + return ToolResult(title="codesearch", output=f"Code search failed: {exc}", metadata={"ok": False}) + + result_text = None + for line in text.splitlines(): + line = line.strip() + if not line.startswith("data: "): + continue + try: + data = json.loads(line[6:]) + result_text = ( + data.get("result", {}) + .get("content", [{"text": ""}])[0] + .get("text", "") + ) + if result_text: + break + except Exception: + continue + if not result_text: + result_text = "No code snippets or documentation found. Try a different query." + return ToolResult(title="codesearch", output=result_text, metadata={}) diff --git a/src/ollama_chat/tools/edit_tool.py b/src/ollama_chat/tools/edit_tool.py new file mode 100644 index 0000000..1029f95 --- /dev/null +++ b/src/ollama_chat/tools/edit_tool.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from difflib import unified_diff +from pathlib import Path + +from support import bus, lsp_client +from support import file_time as file_time_service + +from .base import ParamsSchema, Tool, ToolContext, ToolResult +from .external_directory import assert_external_directory + + +class EditParams(ParamsSchema): + file_path: str + old_string: str + new_string: str + replace_all: bool = False + + +class EditTool(Tool): + id = "edit" + params_schema = EditParams + + async def execute(self, params: EditParams, ctx: ToolContext) -> ToolResult: + if params.old_string == params.new_string: + return ToolResult(title=params.file_path, output="No changes to apply.", metadata={"ok": False}) + + file_path = Path(params.file_path).expanduser().resolve() + await assert_external_directory(ctx, str(file_path)) + + # Special case: create new file when old_string is empty + if params.old_string == "": + diff_lines = list( + unified_diff([], params.new_string.splitlines(), fromfile=str(file_path), tofile=str(file_path), lineterm="") + ) + diff_str = "\n".join(diff_lines) + await ctx.ask( + permission="edit", + patterns=[str(file_path)], + always=["*"], + metadata={"filepath": str(file_path), "diff": diff_str}, + ) + def _write() -> None: + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(params.new_string, encoding="utf-8") + await file_time_service.with_lock(str(file_path), _write) + try: + await bus.bus.publish("file.edited", {"file": str(file_path)}) + await bus.bus.publish("file.watcher.updated", {"file": str(file_path), "event": "add"}) + except Exception: + pass + try: + file_time_service.record_read(ctx.session_id, str(file_path)) + except Exception: + pass + try: + lsp_client.touch_file(str(file_path), notify=True) + except Exception: + pass + return ToolResult(title=str(file_path), output="File created.", metadata={"created": True}) + + # Otherwise, require prior read + try: + await file_time_service.assert_read(ctx.session_id, str(file_path)) + except Exception as exc: + return ToolResult(title=str(file_path), output=str(exc), metadata={"ok": False}) + + try: + content = file_path.read_text(encoding="utf-8", errors="replace") + except Exception as exc: + return ToolResult(title=str(file_path), output=str(exc), metadata={"ok": False}) + + # Simple replacement strategy: exact text, optionally first occurrence only + occurrences = content.count(params.old_string) + if occurrences == 0: + return ToolResult( + title=str(file_path), + output="Could not find oldString in the file. Make sure to read the file and pass the exact text.", + metadata={"ok": False}, + ) + if not params.replace_all and occurrences > 1: + return ToolResult( + title=str(file_path), + output="Found multiple matches for oldString; set replace_all=true to replace them all.", + metadata={"ok": False}, + ) + + new_content = ( + content.replace(params.old_string, params.new_string) + if params.replace_all + else content.replace(params.old_string, params.new_string, 1) + ) + + diff_lines = list( + unified_diff( + content.splitlines(), + new_content.splitlines(), + fromfile=str(file_path), + tofile=str(file_path), + lineterm="", + ) + ) + diff_str = "\n".join(diff_lines) + await ctx.ask( + permission="edit", + patterns=[str(file_path)], + always=["*"], + metadata={"filepath": str(file_path), "diff": diff_str}, + ) + + def _write() -> None: + file_path.write_text(new_content, encoding="utf-8") + await file_time_service.with_lock(str(file_path), _write) + + try: + await bus.bus.publish("file.edited", {"file": str(file_path)}) + await bus.bus.publish("file.watcher.updated", {"file": str(file_path), "event": "change"}) + except Exception: + pass + try: + file_time_service.record_read(ctx.session_id, str(file_path)) + except Exception: + pass + try: + lsp_client.touch_file(str(file_path), notify=True) + except Exception: + pass + + # LSP diagnostics + try: + diagnostics = lsp_client.get_diagnostics() + errors = [d for d in diagnostics.get(str(file_path), []) if d.get("severity") == 1] + output = f"Applied {'all' if params.replace_all else 'one'} replacement successfully." + if errors: + output += "\n\n" + "\n".join(d.get("message", "") for d in errors[:20]) + "\n" + except Exception: + output = f"Applied {'all' if params.replace_all else 'one'} replacement successfully." + + return ToolResult(title=str(file_path), output=output, metadata={"changed": True}) diff --git a/src/ollama_chat/tools/external_directory.py b/src/ollama_chat/tools/external_directory.py new file mode 100644 index 0000000..e0af69b --- /dev/null +++ b/src/ollama_chat/tools/external_directory.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from pathlib import Path + + +async def assert_external_directory( + ctx, + target: str | None, + bypass: bool = False, + kind: str = "file", # "file" | "directory" +) -> None: + """ + Ask the user for approval when operating outside the project/worktree roots. + The project directory is derived from ctx.extra.get("project_dir") or CWD. + """ + if target is None or bypass: + return + + try: + target_path = Path(target).resolve() + except Exception: + # If the path cannot be resolved, fall back to asking for the parent dir. + target_path = Path(str(target)).expanduser() + + project_dir_text = str(ctx.extra.get("project_dir", ".")) + worktree_text = str(ctx.extra.get("worktree", project_dir_text)) + project_dir = Path(project_dir_text).expanduser().resolve() + worktree = Path(worktree_text).expanduser().resolve() + + # If target is inside project_dir or worktree, no approval required. + def _inside(root: Path, child: Path) -> bool: + try: + return root == child or root in child.parents + except Exception: + return False + + if _inside(project_dir, target_path) or _inside(worktree, target_path): + return + + parent_dir = target_path if kind == "directory" else target_path.parent + glob = str(parent_dir / "*") + await ctx.ask( + permission="external_directory", + patterns=[glob], + always=[glob], + metadata={"filepath": str(target_path), "parentDir": str(parent_dir)}, + ) diff --git a/src/ollama_chat/tools/glob_tool.py b/src/ollama_chat/tools/glob_tool.py new file mode 100644 index 0000000..83a20b8 --- /dev/null +++ b/src/ollama_chat/tools/glob_tool.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import asyncio +import os +from pathlib import Path + +from support import ripgrep + +from .base import ParamsSchema, Tool, ToolContext, ToolResult +from .external_directory import assert_external_directory + +MAX_RESULTS = 100 + + +class GlobParams(ParamsSchema): + pattern: str + path: str | None = None # default: project directory + + +class GlobTool(Tool): + id = "glob" + params_schema = GlobParams + + async def execute(self, params: GlobParams, ctx: ToolContext) -> ToolResult: + pattern = params.pattern + search_root = Path(str(params.path or ctx.extra.get("project_dir", "."))) + search_root = search_root.expanduser().resolve() + + await ctx.ask( + permission="glob", + patterns=[pattern], + always=["*"], + metadata={"pattern": pattern, "path": str(search_root)}, + ) + await assert_external_directory(ctx, str(search_root), kind="directory") + + files: list[str] = [] + # Try ripgrep first + try: + rg = await ripgrep.filepath() + proc = await asyncio.create_subprocess_exec( + rg, + "--files", + "--glob", + pattern, + str(search_root), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ) + if proc.stdout is not None: + raw = await proc.stdout.read() + files = [f for f in raw.decode().split("\n") if f] + except Exception: + files = [] + + if not files: + # Fallback to Python globbing + try: + for p in search_root.rglob(pattern): + files.append(str(p)) + except Exception: + pass + + truncated = False + if len(files) > MAX_RESULTS: + files = files[:MAX_RESULTS] + truncated = True + + # Sort by mtime desc + try: + files.sort(key=lambda f: os.stat(f).st_mtime, reverse=True) + except Exception: + pass + + output = "\n".join(files) + if truncated: + output += "\n... results truncated; refine your search pattern." + return ToolResult(title=f"glob: {pattern}", output=output, metadata={"truncated": truncated}) diff --git a/src/ollama_chat/tools/grep_tool.py b/src/ollama_chat/tools/grep_tool.py new file mode 100644 index 0000000..e370229 --- /dev/null +++ b/src/ollama_chat/tools/grep_tool.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import asyncio +import os +from pathlib import Path +import re + +from .base import ParamsSchema, Tool, ToolContext, ToolResult +from .external_directory import assert_external_directory + +MAX_LINE_LENGTH = 2000 +MAX_RESULTS = 100 + + +class GrepParams(ParamsSchema): + pattern: str + path: str | None = None + include: str | None = None # file glob filter, e.g. "*.py" + + +class GrepTool(Tool): + id = "grep" + params_schema = GrepParams + + async def execute(self, params: GrepParams, ctx: ToolContext) -> ToolResult: + pattern = params.pattern + search_root = Path(str(params.path or ctx.extra.get("project_dir", "."))).expanduser().resolve() + + await ctx.ask( + permission="grep", + patterns=[pattern], + always=["*"], + metadata={"pattern": pattern, "path": str(search_root)}, + ) + await assert_external_directory(ctx, str(search_root), kind="directory") + + rg = "rg" + try: + # Prefer ripgrep if available + args = [ + rg, + "-nH", + "--hidden", + "--no-messages", + "--field-match-separator=|", + "--regexp", + pattern, + ] + if params.include: + args += ["--glob", str(params.include)] + args.append(str(search_root)) + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + code = proc.returncode if proc.returncode is not None else 0 + text = stdout.decode() + if not text and code not in (0, 1, 2): + return ToolResult(title="grep", output="No files found.", metadata={}) + lines = [ln for ln in text.splitlines() if ln.strip()] + entries: list[tuple[str, int, str]] = [] + for ln in lines: + try: + filepath, linenum, linetext = ln.split("|", 3)[:3] + entries.append((filepath, int(linenum), linetext)) + except Exception: + continue + # Sort by mtime desc + try: + entries.sort(key=lambda e: os.stat(e[0]).st_mtime, reverse=True) + except Exception: + pass + + truncated = False + if len(entries) > MAX_RESULTS: + entries = entries[:MAX_RESULTS] + truncated = True + + out_lines: list[str] = [] + for fp, n, text in entries: + snippet = text if len(text) <= MAX_LINE_LENGTH else text[:MAX_LINE_LENGTH] + if len(text) > MAX_LINE_LENGTH: + snippet += "..." + out_lines.append(f"{fp}:\n Line {n}: {snippet}") + output = f"Found {len(entries)} matches\n" + "\n".join(out_lines) + if truncated: + output += "\n... results truncated; refine your query." + return ToolResult(title="grep", output=output, metadata={"truncated": truncated}) + except FileNotFoundError: + # Fallback: Python regex across files + try: + regex = re.compile(pattern) + except re.error as exc: + return ToolResult(title="grep", output=f"Invalid regex: {exc}", metadata={"ok": False}) + + files: list[Path] = [] + target = search_root + if target.is_file(): + files = [target] + else: + for root, _, filenames in os.walk(target): + for name in filenames: + files.append(Path(root) / name) + + matches: list[tuple[str, int, str]] = [] + for fp in files: + try: + with open(fp, encoding="utf-8", errors="ignore") as f: + for i, line in enumerate(f, start=1): + if regex.search(line): + matches.append((str(fp), i, line.rstrip("\n"))) + if len(matches) >= MAX_RESULTS: + break + except Exception: + continue + if not matches: + return ToolResult(title="grep", output="No files found.", metadata={}) + out_lines = [f"{fp}:\n Line {n}: {txt}" for fp, n, txt in matches] + return ToolResult(title="grep", output=f"Found {len(matches)} matches\n" + "\n".join(out_lines), metadata={}) diff --git a/src/ollama_chat/tools/invalid_tool.py b/src/ollama_chat/tools/invalid_tool.py new file mode 100644 index 0000000..ff45de3 --- /dev/null +++ b/src/ollama_chat/tools/invalid_tool.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from .base import ParamsSchema, Tool, ToolContext, ToolResult + + +class InvalidParams(ParamsSchema): + tool: str + error: str + + +class InvalidTool(Tool): + id = "invalid" + description = "Catches malformed tool calls. Never explicitly invoked." + params_schema = InvalidParams + + async def execute(self, params: InvalidParams, ctx: ToolContext) -> ToolResult: # noqa: D401 - simple + return ToolResult( + title="Invalid Tool", + output=( + f"The arguments provided to the tool {params.tool!r} are invalid: " + f"{params.error}" + ), + metadata={"ok": False}, + ) diff --git a/src/ollama_chat/tools/ls_tool.py b/src/ollama_chat/tools/ls_tool.py new file mode 100644 index 0000000..b909c0d --- /dev/null +++ b/src/ollama_chat/tools/ls_tool.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from collections import defaultdict +import os +from pathlib import Path + +from .base import ParamsSchema, Tool, ToolContext, ToolResult +from .external_directory import assert_external_directory + +IGNORE_PATTERNS = [ + "node_modules/", + "__pycache__/", + ".git/", + "dist/", + "build/", + "target/", + "vendor/", + "bin/", + "obj/", + ".idea/", + ".vscode/", + ".zig-cache/", + "zig-out", + ".coverage", + "coverage/", + "tmp/", + "temp/", + ".cache/", + "cache/", + "logs/", + ".venv/", + "venv/", + "env/", +] +LIMIT = 100 + + +class ListParams(ParamsSchema): + path: str | None = None + ignore: list[str] | None = None + + +class ListTool(Tool): + id = "list" + params_schema = ListParams + + async def execute(self, params: ListParams, ctx: ToolContext) -> ToolResult: + search = Path(str(params.path or ctx.extra.get("project_dir", "."))).expanduser().resolve() + await assert_external_directory(ctx, str(search), kind="directory") + await ctx.ask( + permission="list", + patterns=[str(search)], + always=["*"], + metadata={"path": str(search)}, + ) + + ignore = set(IGNORE_PATTERNS) + for pat in params.ignore or []: + if pat: + ignore.add(pat) + + # Gather files under the search root + files: list[Path] = [] + for root, dirnames, filenames in os.walk(search): + # Skip ignored directories by prefix match + dirnames[:] = [ + d for d in dirnames if not any((Path(root) / d).as_posix().endswith(p.rstrip("/")) for p in ignore) + ] + for name in filenames: + files.append(Path(root) / name) + if len(files) >= LIMIT: + break + if len(files) >= LIMIT: + break + + # Build tree structures + dirs: set[Path] = set([search]) + files_by_dir: defaultdict[Path, list[str]] = defaultdict(list) + for fp in files: + dirs.add(fp.parent) + files_by_dir[fp.parent].append(fp.name) + + # Ensure parents are included + for d in list(dirs): + for p in d.parents: + if search in p.parents or p == search: + dirs.add(p) + + def render_dir(dir_path: Path, depth: int) -> str: + indent = " " * depth + out = f"{indent}{dir_path.name}/\n" if depth > 0 else "" + children = sorted({d for d in dirs if d.parent == dir_path and d != dir_path}, key=lambda p: p.name.lower()) + for child in children: + out += render_dir(child, depth + 1) + for fname in sorted(files_by_dir.get(dir_path, [])): + out += f"{' ' * (depth + 1)}{fname}\n" + return out + + output = f"{str(search)}/\n" + render_dir(search, 0) + return ToolResult(title=f"list: {str(search)}", output=output.rstrip("\n"), metadata={"count": len(files)}) diff --git a/src/ollama_chat/tools/lsp_tool.py b/src/ollama_chat/tools/lsp_tool.py new file mode 100644 index 0000000..1830cd2 --- /dev/null +++ b/src/ollama_chat/tools/lsp_tool.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from pathlib import Path + +from support import lsp_client + +from .base import ParamsSchema, Tool, ToolContext, ToolResult +from .external_directory import assert_external_directory + +OPERATIONS = [ + "goToDefinition", + "findReferences", + "hover", + "documentSymbol", + "workspaceSymbol", + "goToImplementation", + "prepareCallHierarchy", + "incomingCalls", + "outgoingCalls", +] + + +class LspParams(ParamsSchema): + operation: str + file_path: str + line: int + character: int + + +class LspTool(Tool): + id = "lsp" + params_schema = LspParams + + async def execute(self, params: LspParams, ctx: ToolContext) -> ToolResult: + file_path = str(Path(params.file_path).expanduser().resolve()) + await assert_external_directory(ctx, file_path) + await ctx.ask(permission="lsp", patterns=["*"], always=["*"], metadata={}) + + if params.operation not in OPERATIONS: + return ToolResult( + title="lsp", + output=f"Unsupported operation: {params.operation}", + metadata={"ok": False}, + ) + + uri = Path(file_path).as_uri() + position = {"line": max(0, params.line - 1), "character": max(0, params.character - 1)} + + if not lsp_client.has_clients_for(file_path): + return ToolResult( + title="lsp", + output="No LSP server available for this file type.", + metadata={"ok": False}, + ) + + # This standalone implementation does not wire a real LSP; return a stub response. + return ToolResult( + title=f"lsp: {params.operation}", + output=f"No results found for {params.operation}", + metadata={"uri": uri, "position": position}, + ) diff --git a/src/ollama_chat/tools/multiedit_tool.py b/src/ollama_chat/tools/multiedit_tool.py new file mode 100644 index 0000000..c4f6c9e --- /dev/null +++ b/src/ollama_chat/tools/multiedit_tool.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from pydantic import BaseModel + +from .base import ParamsSchema, Tool, ToolContext, ToolResult +from .edit_tool import EditParams, EditTool + + +class EditOp(BaseModel): + file_path: str + old_string: str + new_string: str + replace_all: bool = False + + +class MultiEditParams(ParamsSchema): + file_path: str + edits: list[EditOp] + + +class MultiEditTool(Tool): + id = "multiedit" + params_schema = MultiEditParams + + async def execute(self, params: MultiEditParams, ctx: ToolContext) -> ToolResult: + if not params.edits: + return ToolResult( + title=params.file_path, + output="No edits provided.", + metadata={"ok": False}, + ) + editor = EditTool() + results = [] + last_output = "" + for op in params.edits: + result = await editor.execute( + EditParams( + file_path=params.file_path, + old_string=op.old_string, + new_string=op.new_string, + replace_all=op.replace_all, + ), + ctx, + ) + results.append(result) + last_output = result.output + return ToolResult( + title=params.file_path, + output=last_output, + metadata={"results": [r.metadata for r in results]}, + ) diff --git a/src/ollama_chat/tools/plan_tool.py b/src/ollama_chat/tools/plan_tool.py new file mode 100644 index 0000000..58aafc5 --- /dev/null +++ b/src/ollama_chat/tools/plan_tool.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from support import question_service + +from .base import ParamsSchema, Tool, ToolContext, ToolResult + + +class PlanExitTool(Tool): + id = "plan_exit" + params_schema = ParamsSchema # no parameters + + async def execute(self, params: ParamsSchema, ctx: ToolContext) -> ToolResult: # noqa: ARG002 - unused params + # Ask user to confirm switching to build agent. + questions = [ + { + "question": "Plan is complete. Switch to build agent?", + "header": "Build Agent", + "options": [ + {"label": "Yes", "description": "Switch to build agent"}, + {"label": "No", "description": "Stay with plan agent"}, + ], + "custom": False, + } + ] + answers = await question_service.ask(session_id=ctx.session_id, questions=questions) + if answers and answers[0] and answers[0][0] == "No": + return ToolResult( + title="Plan Mode", + output="User chose to stay in plan mode.", + metadata={"switched": False}, + ) + + # In a full runtime, we'd emit a synthetic user message to trigger build agent. + return ToolResult( + title="Switching to build agent", + output=( + "User approved switching to build agent. Wait for further instructions." + ), + metadata={"switched": True}, + ) diff --git a/src/ollama_chat/tools/question_tool.py b/src/ollama_chat/tools/question_tool.py new file mode 100644 index 0000000..3c623a9 --- /dev/null +++ b/src/ollama_chat/tools/question_tool.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from pydantic import BaseModel + +from support import question_service + +from .base import ParamsSchema, Tool, ToolContext, ToolResult + + +class QuestionOption(BaseModel): + label: str + description: str + + +class QuestionInfo(BaseModel): + question: str + header: str + options: list[QuestionOption] + multiple: bool = False + custom: bool = True + + +class QuestionParams(ParamsSchema): + questions: list[QuestionInfo] + + +class QuestionTool(Tool): + id = "question" + params_schema = QuestionParams + description = "Suspends execution to ask the user structured questions." + + async def execute(self, params: QuestionParams, ctx: ToolContext) -> ToolResult: + answers = await question_service.ask( + session_id=ctx.session_id, + questions=[q.model_dump() for q in params.questions], + tool={"message_id": ctx.message_id, "call_id": ctx.call_id}, + ) + pairs: list[str] = [] + for q, ans in zip(params.questions, answers): + pairs.append(f'"{q.question}"="{", ".join(ans)}"') + output = ( + "User has answered your questions: " + + "; ".join(pairs) + + ". You can now continue..." + ) + return ToolResult( + title="Question Answered", + output=output, + metadata={"answers": answers}, + ) diff --git a/src/ollama_chat/tools/read_tool.py b/src/ollama_chat/tools/read_tool.py new file mode 100644 index 0000000..b0484f1 --- /dev/null +++ b/src/ollama_chat/tools/read_tool.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import base64 +import mimetypes +import os +from pathlib import Path + +from support import file_time as file_time_service +from support import lsp_client + +from .base import Attachment, ParamsSchema, Tool, ToolContext, ToolResult +from .external_directory import assert_external_directory + +DEFAULT_READ_LIMIT = 2000 +MAX_LINE_LENGTH = 2000 +MAX_BYTES = 50 * 1024 # 50 KB + + +class ReadParams(ParamsSchema): + file_path: str + offset: int | None = None # 1-indexed; default 1 + limit: int | None = None # default DEFAULT_READ_LIMIT + + +class ReadTool(Tool): + id = "read" + params_schema = ReadParams + + async def execute(self, params: ReadParams, ctx: ToolContext) -> ToolResult: + # Resolve absolute path + project_dir = Path(str(ctx.extra.get("project_dir", "."))).expanduser().resolve() + raw_path = Path(params.file_path) + file_path = raw_path if raw_path.is_absolute() else (project_dir / raw_path) + file_path = file_path.resolve() + + await assert_external_directory(ctx, str(file_path), bypass=bool(ctx.extra.get("bypassCwdCheck"))) + await ctx.ask( + permission="read", + patterns=[str(file_path)], + always=["*"], + metadata={}, + ) + + if not file_path.exists(): + parent = file_path.parent + name = file_path.name.lower() + suggestions: list[str] = [] + try: + for entry in parent.iterdir(): + ename = entry.name.lower() + if name in ename or ename in name: + suggestions.append(str(entry)) + if len(suggestions) >= 3: + break + except Exception: + pass + raise FileNotFoundError( + f"File not found: {str(file_path)}. Suggestions: {', '.join(suggestions) if suggestions else 'none'}" + ) + + # If directory: list entries with offset/limit + if file_path.is_dir(): + entries = sorted(os.listdir(file_path)) + off = max(0, (params.offset or 1) - 1) + lim = params.limit or DEFAULT_READ_LIMIT + shown = entries[off : off + lim] + lines = [] + for name in shown: + p = file_path / name + suffix = "/" if p.is_dir() else "" + lines.append(name + suffix) + content = ( + f"{str(file_path)}\ndirectory\n\n" + + "\n".join(lines) + + "\n" + ) + return ToolResult(title=str(file_path), output=content, metadata={}) + + # MIME and attachment handling for images and PDFs + mime, _ = mimetypes.guess_type(str(file_path)) + try: + if mime and (mime.startswith("image/") or mime == "application/pdf") and not mime.endswith("svg+xml") and "vnd.fastbidsheet" not in mime: + data = file_path.read_bytes() + b64 = base64.b64encode(data).decode() + attachment = Attachment(type="file", mime=mime, url=f"data:{mime};base64,{b64}") + return ToolResult( + title=str(file_path), + output="Binary attachment returned.", + metadata={"attachment": True}, + attachments=[attachment], + ) + except Exception: + # Fallback to text flow + pass + + # Binary detection + try: + with open(file_path, "rb") as bf: + sample = bf.read(4096) + if b"\x00" in sample: + raise RuntimeError("Cannot read binary file") + non_printable = sum(1 for b in sample if b < 9 or (13 < b < 32)) + if len(sample) > 0 and non_printable / len(sample) > 0.3: + raise RuntimeError("Cannot read binary file") + except Exception as exc: + return ToolResult(title=str(file_path), output=str(exc), metadata={"ok": False}) + + # Read text with offset/limit and byte cap + offset = max(1, int(params.offset or 1)) + limit = max(1, int(params.limit or DEFAULT_READ_LIMIT)) + out_lines: list[str] = [] + total_bytes = 0 + total_lines = 0 + start_line = offset + end_line = offset - 1 + try: + with open(file_path, encoding="utf-8", errors="replace") as f: + for i, raw in enumerate(f, start=1): + total_lines = i + if i < offset: + continue + if len(out_lines) >= limit or total_bytes > MAX_BYTES: + break + line = raw.rstrip("\n") + if len(line) > MAX_LINE_LENGTH: + line = line[:MAX_LINE_LENGTH] + "... (line truncated)" + out_lines.append(f"{i}: {line}") + total_bytes += len(line.encode("utf-8", errors="ignore")) + end_line = i + except Exception as exc: + return ToolResult(title=str(file_path), output=str(exc), metadata={"ok": False}) + + summary: str + if end_line < total_lines: + summary = f"(Showing lines {start_line}-{end_line} of total {total_lines})" + else: + summary = f"(End of file - total {total_lines} lines)" + + content = ( + f"{str(file_path)}\nfile\n\n" + + "\n".join(out_lines) + + f"\n{summary}\n" + ) + + # Post-read hooks + try: + file_time_service.record_read(ctx.session_id, str(file_path)) + except Exception: + pass + try: + lsp_client.touch_file(str(file_path), notify=False) + except Exception: + pass + + return ToolResult(title=str(file_path), output=content, metadata={}) diff --git a/src/ollama_chat/tools/registry.py b/src/ollama_chat/tools/registry.py new file mode 100644 index 0000000..319ac1d --- /dev/null +++ b/src/ollama_chat/tools/registry.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from .apply_patch_tool import ApplyPatchTool +from .base import Tool +from .bash_tool import BashTool +from .batch_tool import BatchTool +from .codesearch_tool import CodeSearchTool +from .edit_tool import EditTool +from .glob_tool import GlobTool +from .grep_tool import GrepTool +from .invalid_tool import InvalidTool +from .ls_tool import ListTool +from .lsp_tool import LspTool +from .multiedit_tool import MultiEditTool +from .plan_tool import PlanExitTool +from .question_tool import QuestionTool +from .read_tool import ReadTool +from .skill_tool import SkillTool +from .task_tool import TaskTool +from .todo_tool import TodoReadTool, TodoWriteTool +from .webfetch_tool import WebFetchTool +from .websearch_tool import WebSearchTool +from .write_tool import WriteTool + + +@dataclass +class ToolDefinition: + name: str + factory: type[Tool] + + +class ToolRegistry: + def __init__(self) -> None: + self._tools: dict[str, Tool] = {} + + def register(self, tool: Tool) -> None: + self._tools[tool.id] = tool + + def get(self, name: str) -> Tool | None: + return self._tools.get(name) + + def all(self) -> list[Tool]: + return list(self._tools.values()) + + def tools_for_model(self, provider_id: str | None = None, model_id: str | None = None, agent: str | None = None) -> list[Tool]: # noqa: D401 - simple + # Minimal filtering: return all built-ins. Environment gating is not enforced here. + return self.all() + + @classmethod + def build_default(cls) -> ToolRegistry: + reg = cls() + # Built-in tool order roughly as specified + reg.register(InvalidTool()) + reg.register(QuestionTool()) + reg.register(BashTool()) + reg.register(ReadTool()) + reg.register(GlobTool()) + reg.register(GrepTool()) + reg.register(EditTool()) + reg.register(WriteTool()) + reg.register(TaskTool()) + reg.register(WebFetchTool()) + reg.register(TodoWriteTool()) + reg.register(TodoReadTool()) + reg.register(WebSearchTool()) + reg.register(CodeSearchTool()) + reg.register(SkillTool()) + reg.register(ApplyPatchTool()) + reg.register(MultiEditTool()) + reg.register(ListTool()) + reg.register(LspTool()) + reg.register(BatchTool()) + reg.register(PlanExitTool()) + return reg + + +# Provide a simple singleton accessor +_default_registry: ToolRegistry | None = None + + +def get_registry() -> ToolRegistry: + global _default_registry + if _default_registry is None: + _default_registry = ToolRegistry.build_default() + return _default_registry diff --git a/src/ollama_chat/tools/skill_tool.py b/src/ollama_chat/tools/skill_tool.py new file mode 100644 index 0000000..66268d6 --- /dev/null +++ b/src/ollama_chat/tools/skill_tool.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import os +from pathlib import Path + +from .base import ParamsSchema, Tool, ToolContext, ToolResult + + +class SkillParams(ParamsSchema): + name: str + + +class SkillTool(Tool): + id = "skill" + params_schema = SkillParams + + def _skill_dirs(self, ctx: ToolContext) -> list[Path]: + dirs: list[Path] = [] + dirs.append(Path.home() / ".config" / "opencode" / "skills") + project = Path(str(ctx.extra.get("project_dir", "."))).expanduser().resolve() + dirs.append(project / ".opencode" / "skills") + for extra in ctx.extra.get("skill_dirs", []) or []: + try: + p = Path(str(extra)).expanduser().resolve() + dirs.append(p) + except Exception: + continue + return dirs + + def _find_skill_file(self, name: str, dirs: list[Path]) -> Path | None: + for base in dirs: + candidate = base / name / "SKILL.md" + if candidate.exists() and candidate.is_file(): + return candidate + return None + + def _parse_skill_md(self, path: Path) -> tuple[str, str]: + text = path.read_text(encoding="utf-8", errors="replace") + desc = "" + body = text + if text.startswith("---\n"): + try: + end = text.index("\n---\n", 4) + front = text[4:end] + body = text[end + 5 :] + for line in front.splitlines(): + if line.lower().startswith("description:"): + desc = line.split(":", 1)[1].strip() + break + except ValueError: + pass + return desc, body + + async def execute(self, params: SkillParams, ctx: ToolContext) -> ToolResult: + dirs = self._skill_dirs(ctx) + name = params.name.strip() + await ctx.ask(permission="skill", patterns=[name], always=[name], metadata={}) + path = self._find_skill_file(name, dirs) + if not path: + available: list[str] = [] + for base in dirs: + if base.exists(): + for child in base.iterdir(): + if (child / "SKILL.md").exists(): + available.append(child.name) + available.sort() + return ToolResult( + title="skill", + output=f"Skill '{name}' not found. Available: {', '.join(available) if available else 'none'}", + metadata={"ok": False}, + ) + + desc, content = self._parse_skill_md(path) + skill_dir = path.parent + # List up to 10 files (excluding SKILL.md) + files: list[str] = [] + for root, _dirs, fnames in os.walk(skill_dir): + for fname in fnames: + p = Path(root) / fname + try: + if p.samefile(path): + continue + except Exception: + if str(p) == str(path): + continue + files.append(str(p.relative_to(skill_dir))) + if len(files) >= 10: + break + if len(files) >= 10: + break + + base_url = skill_dir.as_uri() + body = ( + f"\n" + f"# Skill: {name}\n\n{content}\n\n" + f"Base directory: {base_url}\n" + "Relative paths are relative to this base directory.\n\n" + "\n" + + "\n".join(f"{f}" for f in files) + + "\n\n" + ) + return ToolResult(title=f"skill: {name}", output=body, metadata={"description": desc}) diff --git a/src/ollama_chat/tools/task_tool.py b/src/ollama_chat/tools/task_tool.py new file mode 100644 index 0000000..ce15994 --- /dev/null +++ b/src/ollama_chat/tools/task_tool.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import uuid + +from .base import ParamsSchema, Tool, ToolContext, ToolResult + + +class TaskParams(ParamsSchema): + description: str + prompt: str + subagent_type: str + task_id: str | None = None + command: str | None = None + + +class TaskTool(Tool): + id = "task" + params_schema = TaskParams + + async def execute(self, params: TaskParams, ctx: ToolContext) -> ToolResult: + if not ctx.extra.get("bypassAgentCheck"): + await ctx.ask( + permission="task", + patterns=[params.subagent_type], + always=["*"], + metadata={"description": params.description, "command": params.command or ""}, + ) + + # Minimal standalone behavior: generate or reuse a task_id and echo the prompt + task_id = params.task_id or str(uuid.uuid4()) + ctx.metadata( + title=params.description, + metadata={"session_id": task_id, "agent": params.subagent_type}, + ) + text = params.prompt.strip() + output = f"task_id: {task_id}\n\n\n{text}\n" + return ToolResult(title=params.description, output=output, metadata={"task_id": task_id}) diff --git a/src/ollama_chat/tools/todo_tool.py b/src/ollama_chat/tools/todo_tool.py new file mode 100644 index 0000000..49e700e --- /dev/null +++ b/src/ollama_chat/tools/todo_tool.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import json +from pathlib import Path + +from pydantic import BaseModel + +from support import bus + +from .base import ParamsSchema, Tool, ToolContext, ToolResult + +try: # Optional dependency + import aiosqlite # type: ignore +except Exception: # pragma: no cover - optional + aiosqlite = None # type: ignore[assignment] + +DB_PATH = Path.home() / ".local" / "share" / "ollamaterm" / "todo.sqlite3" + + +async def _ensure_schema(): + if aiosqlite is None: # pragma: no cover - optional + return + async with aiosqlite.connect(DB_PATH) as db: # type: ignore[attribute-defined-outside-init] + await db.execute( + """ + CREATE TABLE IF NOT EXISTS todos ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + content TEXT NOT NULL, + status TEXT NOT NULL, + priority TEXT NOT NULL, + position INTEGER NOT NULL + ) + """ + ) + await db.commit() + + +class TodoItem(BaseModel): + content: str + status: str # pending | in_progress | completed | cancelled + priority: str # high | medium | low + + +class TodoWriteParams(ParamsSchema): + todos: list[TodoItem] + + +class TodoReadParams(ParamsSchema): + pass + + +class TodoWriteTool(Tool): + id = "todowrite" + params_schema = TodoWriteParams + + async def execute(self, params: TodoWriteParams, ctx: ToolContext) -> ToolResult: + await ctx.ask(permission="todowrite", patterns=["*"], always=["*"], metadata={}) + if aiosqlite is None: # pragma: no cover - optional + return ToolResult(title="todowrite", output="aiosqlite is not installed.", metadata={"ok": False}) + await _ensure_schema() + async with aiosqlite.connect(DB_PATH) as db: # type: ignore[attribute-defined-outside-init] + await db.execute("DELETE FROM todos WHERE session_id=?", [ctx.session_id]) + await db.executemany( + "INSERT INTO todos(session_id,content,status,priority,position) VALUES(?,?,?,?,?)", + [ + (ctx.session_id, t.content, t.status, t.priority, i) + for i, t in enumerate(params.todos) + ], + ) + await db.commit() + try: + await bus.bus.publish( + "todo.updated", + { + "session_id": ctx.session_id, + "todos": [t.model_dump() for t in params.todos], + }, + ) + except Exception: + pass + remaining = len([t for t in params.todos if t.status != "completed"]) + return ToolResult( + title=f"{remaining} todos", + output=json.dumps([t.model_dump() for t in params.todos], indent=2), + metadata={"count": len(params.todos)}, + ) + + +class TodoReadTool(Tool): + id = "todoread" + params_schema = TodoReadParams + + async def execute(self, params: TodoReadParams, ctx: ToolContext) -> ToolResult: # noqa: ARG002 - unused params + await ctx.ask(permission="todoread", patterns=["*"], always=["*"], metadata={}) + if aiosqlite is None: # pragma: no cover - optional + return ToolResult(title="todoread", output="aiosqlite is not installed.", metadata={"ok": False}) + await _ensure_schema() + async with aiosqlite.connect(DB_PATH) as db: # type: ignore[attribute-defined-outside-init] + cursor = await db.execute( + "SELECT content,status,priority FROM todos WHERE session_id=? ORDER BY position", + [ctx.session_id], + ) + rows = await cursor.fetchall() + todos = [ + {"content": r[0], "status": r[1], "priority": r[2]} # type: ignore[index] + for r in rows + ] + return ToolResult( + title="todos", + output=json.dumps(todos, indent=2), + metadata={"count": len(todos)}, + ) diff --git a/src/ollama_chat/tools/truncation.py b/src/ollama_chat/tools/truncation.py new file mode 100644 index 0000000..2993876 --- /dev/null +++ b/src/ollama_chat/tools/truncation.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import time + +try: # Lazy import only when writing files + import aiofiles # type: ignore +except Exception: # pragma: no cover - optional at runtime + aiofiles = None # type: ignore[assignment] + +MAX_LINES = 2000 +MAX_BYTES = 50 * 1024 # 50 KB +OUTPUT_DIR = Path.home() / ".local" / "share" / "ollamaterm" / "tool-output" +RETENTION_SECONDS = 7 * 24 * 60 * 60 # 7 days + + +@dataclass +class TruncateResult: + content: str + truncated: bool + output_path: str | None = None + + +def _agent_has_task_tool(agent: str | None) -> bool: + # Without full permission plumbing available in this standalone package, + # conservatively return False. Callers can still follow the generic hint. + return False + + +async def _write_full_output(full_text: str) -> str | None: + """Persist full text to OUTPUT_DIR; return file path or None on failure.""" + try: + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + name = f"tool_{int(time.time() * 1000)}.txt" + path = OUTPUT_DIR / name + if aiofiles is None: + path.write_text(full_text) + else: # pragma: no cover - requires aiofiles + async with aiofiles.open(path, "w", encoding="utf-8") as f: + await f.write(full_text) + return str(path) + except Exception: + return None + + +async def truncate_output( + text: str, + direction: str = "head", # "head" | "tail" + agent: str | None = None, + max_lines: int = MAX_LINES, + max_bytes: int = MAX_BYTES, +) -> TruncateResult: + """ + Apply line and byte caps. If truncated, persist the full text to disk and + append a helpful hint to the preview. + """ + if not text: + return TruncateResult(content="", truncated=False, output_path=None) + + encoded = text.encode("utf-8", errors="ignore") + lines = text.splitlines() + + if len(lines) <= max_lines and len(encoded) <= max_bytes: + return TruncateResult(content=text, truncated=False, output_path=None) + + truncated = [] + total_bytes = 0 + + iterable = enumerate(lines) if direction == "head" else enumerate(reversed(lines)) + for _, line in iterable: + candidate = ("\n".join(truncated + [line])).encode("utf-8", errors="ignore") + if len(truncated) + 1 > max_lines or len(candidate) > max_bytes: + break + truncated.append(line) + total_bytes = len(candidate) + + if direction != "head": + truncated.reverse() + + preview = "\n".join(truncated) + hidden_lines = max(0, len(lines) - len(truncated)) + hidden_bytes = max(0, len(encoded) - total_bytes) + + out_path = await _write_full_output(text) + + has_task = _agent_has_task_tool(agent) + if has_task: + hint = ( + "Use the explore agent with Read/Grep to inspect the full output, " + "or open the saved file if available." + ) + else: + hint = ( + "Refine your query or use read with offset/limit to page through the file." + ) + + suffix = ( + f"\n... {hidden_lines} lines / {hidden_bytes} bytes truncated ...\n\n{hint}" + ) + return TruncateResult(content=preview + suffix, truncated=True, output_path=out_path) + + +async def cleanup_old_outputs() -> None: + """ + Delete files in OUTPUT_DIR older than RETENTION_SECONDS. Filename format: + tool_{epoch_ms}. Unknown filenames are ignored. + """ + if not OUTPUT_DIR.exists(): + return + now = time.time() + for entry in OUTPUT_DIR.iterdir(): + try: + if not entry.is_file(): + continue + stem = entry.stem # e.g. tool_1700000000000 + if not stem.startswith("tool_"): + continue + ts_ms = int(stem.split("_", 1)[1]) + age = now - (ts_ms / 1000.0) + if age > RETENTION_SECONDS: + try: + entry.unlink(missing_ok=True) + except Exception: + continue + except Exception: + continue diff --git a/src/ollama_chat/tools/webfetch_tool.py b/src/ollama_chat/tools/webfetch_tool.py new file mode 100644 index 0000000..0667ead --- /dev/null +++ b/src/ollama_chat/tools/webfetch_tool.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import base64 + +from .base import Attachment, ParamsSchema, Tool, ToolContext, ToolResult + +MAX_RESPONSE_BYTES = 5 * 1024 * 1024 + + +class WebFetchParams(ParamsSchema): + url: str + format: str = "markdown" # "markdown" | "text" | "html" + timeout: float | None = None # seconds; max 120 + + +class WebFetchTool(Tool): + id = "webfetch" + params_schema = WebFetchParams + + async def execute(self, params: WebFetchParams, ctx: ToolContext) -> ToolResult: + url = params.url.strip() + if not (url.startswith("http://") or url.startswith("https://")): + return ToolResult(title="webfetch", output="Invalid URL scheme.", metadata={"ok": False}) + + await ctx.ask( + permission="webfetch", + patterns=[url], + always=["*"], + metadata={"url": url}, + ) + + timeout_sec = min(max(1.0, float(params.timeout or 30.0)), 120.0) + try: + from bs4 import BeautifulSoup # type: ignore # noqa: WPS433 + import httpx # noqa: WPS433 + from markdownify import markdownify # type: ignore # noqa: WPS433 + except Exception as exc: # pragma: no cover - optional deps + return ToolResult(title="webfetch", output=f"Missing dependency: {exc}", metadata={"ok": False}) + + fmt = (params.format or "markdown").strip().lower() + if fmt not in {"markdown", "text", "html"}: + fmt = "markdown" + + headers_map = { + "markdown": "text/markdown;q=1.0, text/x-markdown;q=0.9, text/plain;q=0.8, text/html;q=0.7, */*;q=0.1", + "text": "text/plain;q=1.0, text/markdown;q=0.9, text/html;q=0.8, */*;q=0.1", + "html": "text/html;q=1.0, application/xhtml+xml;q=0.9, text/plain;q=0.8, */*;q=0.1", + } + headers = { + "Accept": headers_map[fmt], + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + } + + async with httpx.AsyncClient(timeout=timeout_sec, follow_redirects=True) as client: + response = await client.get(url, headers=headers) + if response.status_code == 403 and response.headers.get("cf-mitigated") == "challenge": + response = await client.get(url, headers={**headers, "User-Agent": "ollamterm"}) + + size_header = response.headers.get("content-length") + if size_header and int(size_header) > MAX_RESPONSE_BYTES: + return ToolResult(title="webfetch", output="Response too large.", metadata={"ok": False}) + + data = await response.aread() + if len(data) > MAX_RESPONSE_BYTES: + return ToolResult(title="webfetch", output="Response too large.", metadata={"ok": False}) + + content_type = (response.headers.get("content-type", "").split(";")[0].strip().lower()) + if content_type.startswith("image/") and not content_type.endswith("svg+xml") and "vnd.fastbidsheet" not in content_type: + b64 = base64.b64encode(data).decode() + att = Attachment(type="file", mime=content_type, url=f"data:{content_type};base64,{b64}") + return ToolResult( + title="webfetch", + output="Image fetched successfully.", + metadata={"attachment": True}, + attachments=[att], + ) + + text = data.decode("utf-8", errors="replace") + if fmt == "html" or not content_type: + body = text + elif fmt == "markdown" and content_type in {"text/html", "application/xhtml+xml"}: + body = markdownify( + text, + heading_style="atx", + bullets="-", + strip=["script", "style", "meta", "link"], + ) + elif fmt == "text" and content_type in {"text/html", "application/xhtml+xml"}: + body = BeautifulSoup(text, "html.parser").get_text(separator="\n") + else: + body = text + + return ToolResult(title="webfetch", output=body, metadata={}) diff --git a/src/ollama_chat/tools/websearch_tool.py b/src/ollama_chat/tools/websearch_tool.py new file mode 100644 index 0000000..83bba76 --- /dev/null +++ b/src/ollama_chat/tools/websearch_tool.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import json + +from .base import ParamsSchema, Tool, ToolContext, ToolResult + +EXA_BASE_URL = "https://mcp.exa.ai" +EXA_ENDPOINT = "/mcp" + + +class WebSearchParams(ParamsSchema): + query: str + num_results: int | None = 8 + livecrawl: str | None = None # "fallback" | "preferred" + type: str | None = "auto" # "auto" | "fast" | "deep" + context_max_characters: int | None = None + + +class WebSearchTool(Tool): + id = "websearch" + params_schema = WebSearchParams + + async def execute(self, params: WebSearchParams, ctx: ToolContext) -> ToolResult: + await ctx.ask( + permission="websearch", + patterns=[params.query], + always=["*"], + metadata={"query": params.query}, + ) + + try: + import httpx # noqa: WPS433 + except Exception as exc: # pragma: no cover - optional dep + return ToolResult(title="websearch", output=f"Missing dependency: {exc}", metadata={"ok": False}) + + body = { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "web_search_exa", + "arguments": { + "query": params.query, + "type": (params.type or "auto"), + "numResults": int(params.num_results or 8), + "livecrawl": (params.livecrawl or "fallback"), + "contextMaxCharacters": params.context_max_characters, + }, + }, + } + headers = { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + } + url = EXA_BASE_URL + EXA_ENDPOINT + try: + async with httpx.AsyncClient(timeout=25.0) as client: + resp = await client.post(url, headers=headers, content=json.dumps(body)) + text = resp.text + except Exception as exc: # pragma: no cover - network + return ToolResult(title="websearch", output=f"Search failed: {exc}", metadata={"ok": False}) + + result_text = None + for line in text.splitlines(): + line = line.strip() + if not line.startswith("data: "): + continue + try: + data = json.loads(line[6:]) + result_text = ( + data.get("result", {}) + .get("content", [{"text": ""}])[0] + .get("text", "") + ) + if result_text: + break + except Exception: + continue + if not result_text: + result_text = "No search results found." + return ToolResult(title="websearch", output=result_text, metadata={}) diff --git a/src/ollama_chat/tools/write_tool.py b/src/ollama_chat/tools/write_tool.py new file mode 100644 index 0000000..66f3018 --- /dev/null +++ b/src/ollama_chat/tools/write_tool.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from difflib import unified_diff +from pathlib import Path + +from support import bus, lsp_client +from support import file_time as file_time_service + +from .base import ParamsSchema, Tool, ToolContext, ToolResult +from .external_directory import assert_external_directory + + +class WriteParams(ParamsSchema): + file_path: str + content: str + + +class WriteTool(Tool): + id = "write" + params_schema = WriteParams + + async def execute(self, params: WriteParams, ctx: ToolContext) -> ToolResult: + file_path = Path(params.file_path).expanduser().resolve() + await assert_external_directory(ctx, str(file_path)) + + exists = file_path.exists() + old_content = "" + if exists: + old_content = file_path.read_text(encoding="utf-8", errors="replace") + try: + await file_time_service.assert_read(ctx.session_id, str(file_path)) + except Exception as exc: + return ToolResult(title=str(file_path), output=str(exc), metadata={"ok": False}) + + diff_lines = list( + unified_diff( + old_content.splitlines(), + params.content.splitlines(), + fromfile=str(file_path), + tofile=str(file_path), + lineterm="", + ) + ) + diff_str = "\n".join(diff_lines) + await ctx.ask( + permission="edit", + patterns=[str(file_path)], + always=["*"], + metadata={"filepath": str(file_path), "diff": diff_str}, + ) + + def _write() -> None: + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(params.content, encoding="utf-8") + + # Serialize concurrent writes + await file_time_service.with_lock(str(file_path), _write) + + # Events and bookkeeping + try: + await bus.bus.publish( + "file.edited", {"file": str(file_path), "event": "change" if exists else "add"} + ) + await bus.bus.publish( + "file.watcher.updated", + {"file": str(file_path), "event": "change" if exists else "add"}, + ) + except Exception: + pass + + try: + file_time_service.record_read(ctx.session_id, str(file_path)) + except Exception: + pass + + try: + lsp_client.touch_file(str(file_path), notify=True) + except Exception: + pass + + # LSP diagnostics + try: + diagnostics = lsp_client.get_diagnostics() + errors = [d for d in diagnostics.get(str(file_path), []) if d.get("severity") == 1] + other_files = [ + (p, [d for d in ds if d.get("severity") == 1]) + for p, ds in diagnostics.items() + if p != str(file_path) + ] + other_files = [(p, es) for p, es in other_files if es][:5] + output = "Wrote file successfully." + if errors: + output += "\n\n" + "\n".join(d.get("message", "") for d in errors[:20]) + "\n" + for p, es in other_files: + output += ( + f"\n\n" + + "\n".join(d.get("message", "") for d in es[:20]) + + "\n" + ) + except Exception: + output = "Wrote file successfully." + + return ToolResult(title=str(file_path), output=output, metadata={"changed": True}) diff --git a/src/ollama_chat/widgets/activity_bar.py b/src/ollama_chat/widgets/activity_bar.py index be785d6..b1bb474 100644 --- a/src/ollama_chat/widgets/activity_bar.py +++ b/src/ollama_chat/widgets/activity_bar.py @@ -2,15 +2,12 @@ from __future__ import annotations -import asyncio import logging from typing import Any from textual.app import ComposeResult -from textual.containers import Horizontal, Vertical -from textual.message import Message from textual.timer import Timer -from textual.widgets import Button, Label, Static +from textual.widgets import Label, Static LOGGER = logging.getLogger(__name__) diff --git a/tests/test_app_actions.py b/tests/test_app_actions.py index 5c0e638..a9261d1 100644 --- a/tests/test_app_actions.py +++ b/tests/test_app_actions.py @@ -174,8 +174,8 @@ async def _no_prompt(self) -> str: # noqa: ANN001 self.app.chat = _FakeChat() self.app.persistence = _FakePersistence() self.app.state = StateManager() - from ollama_chat.state import ConnectionState from ollama_chat.capabilities import AttachmentState, SearchState + from ollama_chat.state import ConnectionState from ollama_chat.task_manager import TaskManager self.app._task_manager = TaskManager() diff --git a/tests/test_chat.py b/tests/test_chat.py index 04395bf..24a850c 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -7,7 +7,7 @@ from ollama_chat.chat import ChatChunk, OllamaChat from ollama_chat.exceptions import OllamaModelNotFoundError, OllamaStreamingError -from ollama_chat.tools import ToolRegistry +from ollama_chat.tooling import ToolRegistry async def _chunk_stream( diff --git a/tests/test_config.py b/tests/test_config.py index 4826274..596cbb2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,8 +2,8 @@ from __future__ import annotations -import tempfile from pathlib import Path +import tempfile import unittest from ollama_chat.config import DEFAULT_CONFIG, load_config @@ -38,6 +38,11 @@ def test_missing_config_uses_defaults(self) -> None: self.assertEqual( config["logging"]["level"], DEFAULT_CONFIG["logging"]["level"] ) + self.assertEqual(config["tools"]["enabled"], DEFAULT_CONFIG["tools"]["enabled"]) + self.assertEqual( + config["tools"]["workspace_root"], + DEFAULT_CONFIG["tools"]["workspace_root"], + ) def test_partial_config_overrides_selected_values(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: @@ -58,11 +63,44 @@ def test_partial_config_overrides_selected_values(self) -> None: self.assertEqual(config["ollama"]["models"], ["qwen2.5", "llama3.2"]) self.assertFalse(config["ui"]["show_timestamps"]) self.assertEqual(config["app"]["title"], DEFAULT_CONFIG["app"]["title"]) + self.assertEqual(config["tools"]["workspace_root"], DEFAULT_CONFIG["tools"]["workspace_root"]) self.assertEqual( config["security"]["allow_remote_hosts"], DEFAULT_CONFIG["security"]["allow_remote_hosts"], ) + def test_tools_section_overrides_selected_values(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + config_path = Path(temp_dir) / "config.toml" + config_path.write_text( + """ +[tools] +enabled = true +workspace_root = "~/workspace" +allow_external_directories = true +command_timeout_seconds = 45 +max_output_lines = 300 +max_output_bytes = 123456 +max_read_bytes = 654321 +max_search_results = 25 +default_external_directories = ["/tmp", "~/sandbox"] + """.strip(), + encoding="utf-8", + ) + config = load_config(config_path=config_path) + self.assertTrue(config["tools"]["enabled"]) + self.assertEqual(config["tools"]["workspace_root"], "~/workspace") + self.assertTrue(config["tools"]["allow_external_directories"]) + self.assertEqual(config["tools"]["command_timeout_seconds"], 45) + self.assertEqual(config["tools"]["max_output_lines"], 300) + self.assertEqual(config["tools"]["max_output_bytes"], 123456) + self.assertEqual(config["tools"]["max_read_bytes"], 654321) + self.assertEqual(config["tools"]["max_search_results"], 25) + self.assertEqual( + config["tools"]["default_external_directories"], + ["/tmp", "~/sandbox"], + ) + def test_models_fallback_to_single_model_when_models_missing(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: config_path = Path(temp_dir) / "config.toml" diff --git a/tests/test_custom_tools.py b/tests/test_custom_tools.py new file mode 100644 index 0000000..12dd064 --- /dev/null +++ b/tests/test_custom_tools.py @@ -0,0 +1,149 @@ +"""Tests for schema-first custom coding tools.""" + +from __future__ import annotations + +from pathlib import Path +import tempfile +import unittest + +from ollama_chat.exceptions import OllamaToolError +from ollama_chat.tooling import ToolRegistryOptions, ToolRuntimeOptions, build_registry + + +class CustomToolsTests(unittest.TestCase): + """Validate registration and core custom-tool behaviors.""" + + def test_custom_tool_names_registered(self) -> None: + registry = build_registry( + ToolRegistryOptions( + enable_custom_tools=True, + # Built-in tools are enabled by default; list of names can include + # overlap from the built-in adapter. We only require that the + # schema-first set is a subset of the registry names. + runtime_options=ToolRuntimeOptions(), + ) + ) + names = set(registry.list_tool_names()) + expected_subset = { + "apply_patch", + "bash", + "batch", + "glob", + "invalid", + "ls", + "multiedit", + "plan-enter", + "plan-exit", + "plan", + "question", + "registry", + "task", + "todo", + "todoread", + "todowrite", + "tool", + "truncation", + "webfetch", + "websearch", + "write", + } + self.assertTrue(expected_subset.issubset(names)) + + def test_builtin_adapter_precedence_for_allowlisted_names(self) -> None: + registry = build_registry( + ToolRegistryOptions( + enable_custom_tools=True, + runtime_options=ToolRuntimeOptions(), + ) + ) + names = set(registry.list_tool_names()) + # Ensure allowlisted built-ins are present + for name in {"codesearch", "edit", "grep", "list", "read"}: + self.assertIn(name, names) + + def test_write_read_edit_round_trip(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + registry = build_registry( + ToolRegistryOptions( + enable_custom_tools=True, + runtime_options=ToolRuntimeOptions(workspace_root=str(root)), + ) + ) + + target = root / "hello.txt" + result = registry.execute( + "write", + { + "path": str(target), + "content": "hello world\n", + "overwrite": True, + }, + ) + self.assertIn("Wrote", result) + + read_text = registry.execute("read", {"path": str(target)}) + self.assertIn("hello world", read_text) + + registry.execute( + "edit", + { + "path": str(target), + "old_text": "world", + "new_text": "tooling", + }, + ) + updated = target.read_text(encoding="utf-8") + self.assertIn("hello tooling", updated) + + def test_batch_and_todo_workflow(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + registry = build_registry( + ToolRegistryOptions( + enable_custom_tools=True, + runtime_options=ToolRuntimeOptions(workspace_root=tmp), + ) + ) + response = registry.execute( + "batch", + { + "calls": [ + {"name": "todo", "arguments": {"item": "first"}}, + { + "name": "todowrite", + "arguments": {"items": ["second"], "mode": "append"}, + }, + {"name": "todoread", "arguments": {}}, + ] + }, + ) + self.assertIn("\"ok\": true", response) + self.assertIn("second", response) + + def test_external_directory_policy_disabled_by_default(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + registry = build_registry( + ToolRegistryOptions( + enable_custom_tools=True, + runtime_options=ToolRuntimeOptions(workspace_root=tmp), + ) + ) + with self.assertRaises(OllamaToolError): + registry.execute( + "external-directory", + {"action": "add", "path": "/tmp"}, + ) + + def test_invalid_tool_raises(self) -> None: + registry = build_registry( + ToolRegistryOptions( + enable_custom_tools=True, + runtime_options=ToolRuntimeOptions(), + ) + ) + with self.assertRaises(OllamaToolError): + registry.execute("invalid", {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_logging.py b/tests/test_logging.py index e87a0b6..e08728c 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -5,13 +5,14 @@ from collections.abc import AsyncGenerator import json import logging -import tempfile from pathlib import Path +import tempfile import unittest +import structlog + from ollama_chat.chat import OllamaChat from ollama_chat.logging_utils import configure_logging -import structlog class RetryClient: diff --git a/tests/test_package_exports.py b/tests/test_package_exports.py index a8e1d32..f832a2e 100644 --- a/tests/test_package_exports.py +++ b/tests/test_package_exports.py @@ -26,7 +26,7 @@ def test_lazy_exports_resolve_known_symbols(self) -> None: def test_unknown_symbol_raises_attribute_error(self) -> None: with self.assertRaises(AttributeError): - getattr(ollama_chat, "THIS_DOES_NOT_EXIST") + ollama_chat.THIS_DOES_NOT_EXIST if __name__ == "__main__": diff --git a/tests/test_persistence.py b/tests/test_persistence.py index f07a531..6359c36 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -3,9 +3,9 @@ from __future__ import annotations import json +from pathlib import Path import tempfile import time -from pathlib import Path import unittest from ollama_chat.persistence import ( diff --git a/tests/test_slash_commands.py b/tests/test_slash_commands.py index 8684935..4898062 100644 --- a/tests/test_slash_commands.py +++ b/tests/test_slash_commands.py @@ -10,8 +10,8 @@ except Exception: # pragma: no cover OllamaChatApp = None # type: ignore[assignment,misc] -from ollama_chat.state import ConnectionState, StateManager from ollama_chat.capabilities import AttachmentState, SearchState +from ollama_chat.state import ConnectionState, StateManager from ollama_chat.task_manager import TaskManager diff --git a/tests/test_stream_handler.py b/tests/test_stream_handler.py index abe4159..68462a9 100644 --- a/tests/test_stream_handler.py +++ b/tests/test_stream_handler.py @@ -2,8 +2,8 @@ from __future__ import annotations -import unittest from typing import Any +import unittest from ollama_chat.stream_handler import StreamHandler diff --git a/tests/test_tools.py b/tests/test_tools.py index fecae37..e10e960 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -5,7 +5,13 @@ import unittest from ollama_chat.exceptions import OllamaToolError -from ollama_chat.tools import ToolRegistry, build_default_registry +from ollama_chat.tooling import ( + ToolRegistry, + ToolRegistryOptions, + ToolRuntimeOptions, + build_default_registry, + build_registry, +) def _add(a: int, b: int) -> int: @@ -86,11 +92,11 @@ def test_build_default_registry_registers_web_tools_when_enabled(self) -> None: self.assertTrue(any("web_fetch" in n for n in tool_names)) def test_build_default_registry_raises_without_api_key(self) -> None: - from ollama_chat.exceptions import OllamaToolError - # Ensure OLLAMA_API_KEY is not set for this test. import os + from ollama_chat.exceptions import OllamaToolError + old = os.environ.pop("OLLAMA_API_KEY", None) try: with self.assertRaises(OllamaToolError): @@ -106,6 +112,79 @@ def test_multiple_registrations_do_not_duplicate(self) -> None: # Second registration overwrites the first (same name key). self.assertEqual(len(registry.build_tools_list()), 1) + def test_schema_tools_are_exported_when_custom_tools_enabled(self) -> None: + registry = build_registry(ToolRegistryOptions(enable_custom_tools=True)) + tools = registry.build_tools_list() + schema_tools = [item for item in tools if isinstance(item, dict)] + self.assertTrue( + any(item["function"]["name"] == "read" for item in schema_tools) + ) + self.assertTrue( + any(item["function"]["name"] == "bash" for item in schema_tools) + ) + + def test_builtin_adapter_allowlist_and_execution(self) -> None: + # Built-in adapter is enabled by default; custom tools disabled. + from pathlib import Path + import tempfile + + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + registry = build_registry( + ToolRegistryOptions( + enable_custom_tools=False, + runtime_options=ToolRuntimeOptions(workspace_root=str(root)), + ) + ) + names = set(registry.list_tool_names()) + # Exactly the allowlist should be present from built-ins + for name in {"codesearch", "edit", "grep", "list", "read"}: + self.assertIn(name, names) + # Sanity: non-allowlisted names from custom suite should not be present + self.assertNotIn("write", names) + + # Verify read executes by creating a file and reading it + target = root / "foo.txt" + target.write_text("hello\nworld\n", encoding="utf-8") + read_out = registry.execute("read", {"file_path": str(target), "limit": 1}) + self.assertIn("hello", read_out) + + # grep may use ripgrep if available; still should not crash on a simple pattern + grep_out = registry.execute("grep", {"pattern": "hello", "path": str(root)}) + self.assertIn("Found", grep_out) + + def test_schema_tool_validation_rejects_missing_required_argument(self) -> None: + registry = build_registry(ToolRegistryOptions(enable_custom_tools=True)) + with self.assertRaises(OllamaToolError): + registry.execute("read", {}) + + def test_truncation_applies_to_schema_tool_outputs(self) -> None: + registry = build_registry( + ToolRegistryOptions( + enable_custom_tools=True, + runtime_options=registry_runtime_options( + max_output_lines=2, + max_output_bytes=5000, + ), + ) + ) + registry.execute("todo", {"item": "line-1"}) + registry.execute("todo", {"item": "line-2"}) + registry.execute("todo", {"item": "line-3"}) + rendered = registry.execute("todoread", {}) + self.assertIn("truncated", rendered) + + +def registry_runtime_options( + *, + max_output_lines: int, + max_output_bytes: int, +) -> ToolRuntimeOptions: + return ToolRuntimeOptions( + max_output_lines=max_output_lines, + max_output_bytes=max_output_bytes, + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_widgets.py b/tests/test_widgets.py index 8c9395b..94bdbf4 100644 --- a/tests/test_widgets.py +++ b/tests/test_widgets.py @@ -5,7 +5,7 @@ import unittest try: - from textual.widgets import Input, Button, Label + from textual.widgets import Button, Input, Label from ollama_chat.widgets.activity_bar import ActivityBar from ollama_chat.widgets.code_block import CodeBlock, split_message