Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 102 additions & 2 deletions bot/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
import asyncio
import logging
import os
import shlex
from pathlib import Path
from typing import Any

from agents import Agent
from agents import Runner
from agents import ShellTool
from agents import ShellToolLocalEnvironment
from agents import ShellToolLocalSkill
from agents import TResponseInputItem
from agents.mcp import MCPServerStdio
from agents.mcp import MCPServerStreamableHttp
Expand All @@ -24,6 +29,8 @@

MAX_TURNS = 10
MCP_SESSION_TIMEOUT_SECONDS = 30.0
SHELL_TIMEOUT = 30.0
SKILLS_DIR = Path(__file__).resolve().parent.parent / "skills"

set_tracing_disabled(True)

Expand All @@ -47,20 +54,106 @@ def _get_model() -> OpenAIResponsesModel | OpenAIChatCompletionsModel:
return OpenAIResponsesModel(model=model_name, openai_client=client)


def _parse_skill_description(content: str) -> str:
"""Return the description field from a SKILL.md YAML frontmatter, or ""."""
if not content.startswith("---"):
return ""
end = content.find("\n---", 3)
if end == -1:
return ""
for line in content[3:end].splitlines():
if line.startswith("description:"):
return line[len("description:") :].strip()
return ""


def _load_shell_skills() -> list[ShellToolLocalSkill]:
"""Discover local shell skills under SKILLS_DIR.

Each immediate subdirectory of SKILLS_DIR containing a SKILL.md is mounted
as a ShellToolLocalSkill. The skill name is the directory name; the
description is read from the SKILL.md YAML frontmatter.
"""
if not SKILLS_DIR.is_dir():
return []
skills: list[ShellToolLocalSkill] = []
for skill_dir in sorted(SKILLS_DIR.iterdir()):
skill_md = skill_dir / "SKILL.md"
if not skill_dir.is_dir() or not skill_md.is_file():
continue
skills.append(
ShellToolLocalSkill(
name=skill_dir.name,
description=_parse_skill_description(skill_md.read_text(encoding="utf-8")),
path=str(skill_dir),
)
)
return skills


async def _shell_executor(request: Any) -> str:
"""Run each shell command from the request and return combined output.

Two layers of defence keep the bot from running anything other than the
``obsidian`` CLI:

1. **Allowlist** — after ``shlex.split``, the first token of each command
must be exactly ``obsidian``. Anything else is rejected without
execution. This blocks attempts to invoke unrelated binaries.
2. **No shell** — commands are executed via ``create_subprocess_exec``
(not ``_shell``), so shell metacharacters like ``;``, ``&&``, ``|``,
``$()``, and backticks are passed as literal arguments to ``obsidian``
instead of being interpreted. This blocks command-chaining injection.

Honours ``action.timeout_ms`` when set, otherwise falls back to
``SHELL_TIMEOUT``. stderr is merged into stdout for simplicity.
"""
action = request.data.action
timeout = (action.timeout_ms / 1000.0) if action.timeout_ms else SHELL_TIMEOUT

outputs: list[str] = []
for command in action.commands:
try:
tokens = shlex.split(command)
except ValueError as e:
outputs.append(f"rejected: cannot parse command ({e}): {command}")
continue
if not tokens or tokens[0] != "obsidian":
outputs.append(f"rejected: only the 'obsidian' CLI is allowed: {command}")
continue

proc = await asyncio.create_subprocess_exec(
*tokens,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
try:
stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=timeout)
outputs.append(stdout.decode("utf-8", errors="replace"))
except TimeoutError:
proc.kill()
await proc.communicate()
outputs.append(f"Command timed out after {timeout}s: {command}")
break
return "\n".join(outputs)


class OpenAIAgent:
"""A wrapper for OpenAI Agent with MCP server support."""
"""A wrapper for OpenAI Agent with MCP server and local shell skill support."""

def __init__(
self,
name: str,
mcp_servers: list | None = None,
tools: list | None = None,
instructions: str = DEFAULT_INSTRUCTIONS,
) -> None:
self.agent = Agent(
name=name,
instructions=instructions,
model=_get_model(),
mcp_servers=(mcp_servers if mcp_servers is not None else []),
tools=(tools if tools is not None else []),
)
self.name = name
self._conversations: dict[int, list[TResponseInputItem]] = {}
Expand Down Expand Up @@ -116,8 +209,15 @@ def from_dict(cls, name: str, config: dict[str, Any]) -> OpenAIAgent:
},
)
)

tools: list[Any] = []
skills = _load_shell_skills()
if skills:
environment = ShellToolLocalEnvironment(type="local", skills=skills)
tools.append(ShellTool(executor=_shell_executor, environment=environment))

instructions = config.get("instructions", DEFAULT_INSTRUCTIONS)
return cls(name, mcp_servers, instructions=instructions)
return cls(name, mcp_servers, tools=tools, instructions=instructions)

async def connect(self) -> None:
for mcp_server in self.agent.mcp_servers:
Expand Down
4 changes: 4 additions & 0 deletions skills/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Users drop their own shell skills into this directory; the bot auto-loads
# any subdirectory containing a SKILL.md. Skill content is not tracked.
*
!.gitignore
141 changes: 139 additions & 2 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import MagicMock
from unittest.mock import create_autospec
from unittest.mock import patch

import pytest
from agents import ShellTool
from agents.mcp import MCPServerStdio
from agents.mcp import MCPServerStreamableHttp
from agents.models.interface import Model
Expand All @@ -15,12 +17,16 @@
from bot.agents import MAX_TURNS
from bot.agents import OpenAIAgent
from bot.agents import _get_model
from bot.agents import _shell_executor


@pytest.fixture(autouse=True)
def _mock_model(monkeypatch):
"""Prevent tests from constructing a real OpenAI client."""
def _mock_model(monkeypatch, tmp_path_factory):
"""Prevent tests from constructing a real OpenAI client and isolate skills."""
monkeypatch.setattr("bot.agents._get_model", lambda: create_autospec(Model))
# Point SKILLS_DIR at an empty directory so tests do not auto-load the
# real ./skills/ on disk. Tests that need skills can override this.
monkeypatch.setattr("bot.agents.SKILLS_DIR", tmp_path_factory.mktemp("empty_skills"))


class TestGetModel:
Expand Down Expand Up @@ -233,3 +239,134 @@ def test_no_truncation_when_under_limit(self):
msgs = agent.get_messages(chat_id=100)
user_msgs = [m for m in msgs if m["role"] == "user"]
assert len(user_msgs) == 3


class TestLoadShellSkills:
def test_no_shell_tool_when_skills_dir_missing(self, tmp_path, monkeypatch):
monkeypatch.setattr("bot.agents.SKILLS_DIR", tmp_path / "nonexistent")
agent = OpenAIAgent.from_dict("test", {"mcpServers": {}})
shell_tools = [t for t in agent.agent.tools if isinstance(t, ShellTool)]
assert len(shell_tools) == 0

def test_shell_tool_added_when_skill_found(self, tmp_path, monkeypatch):
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text("---\nname: my-skill\ndescription: A test skill\n---\n")
monkeypatch.setattr("bot.agents.SKILLS_DIR", tmp_path)

agent = OpenAIAgent.from_dict("test", {"mcpServers": {}})
shell_tool = next(t for t in agent.agent.tools if isinstance(t, ShellTool))
skill = shell_tool.environment["skills"][0]
assert skill["name"] == "my-skill"
assert skill["description"] == "A test skill"
assert skill["path"] == str(skill_dir)

def test_multiple_skills_all_mounted(self, tmp_path, monkeypatch):
for name in ["skill-a", "skill-b"]:
d = tmp_path / name
d.mkdir()
(d / "SKILL.md").write_text(f"---\nname: {name}\ndescription: desc {name}\n---\n")
monkeypatch.setattr("bot.agents.SKILLS_DIR", tmp_path)

agent = OpenAIAgent.from_dict("test", {"mcpServers": {}})
shell_tool = next(t for t in agent.agent.tools if isinstance(t, ShellTool))
assert len(shell_tool.environment["skills"]) == 2

def test_directory_without_skill_md_is_skipped(self, tmp_path, monkeypatch):
(tmp_path / "not-a-skill").mkdir()
good = tmp_path / "real-skill"
good.mkdir()
(good / "SKILL.md").write_text("---\nname: real-skill\ndescription: d\n---\n")
monkeypatch.setattr("bot.agents.SKILLS_DIR", tmp_path)

agent = OpenAIAgent.from_dict("test", {"mcpServers": {}})
shell_tool = next(t for t in agent.agent.tools if isinstance(t, ShellTool))
skills = shell_tool.environment["skills"]
assert len(skills) == 1
assert skills[0]["name"] == "real-skill"

def test_mcp_servers_and_shell_skills_coexist(self, tmp_path, monkeypatch):
skill_dir = tmp_path / "s"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text("---\nname: s\ndescription: d\n---\n")
monkeypatch.setattr("bot.agents.SKILLS_DIR", tmp_path)

config = {"mcpServers": {"my-mcp": {"command": "uvx", "args": ["something"]}}}
agent = OpenAIAgent.from_dict("test", config)
assert len(agent.agent.mcp_servers) == 1
shell_tools = [t for t in agent.agent.tools if isinstance(t, ShellTool)]
assert len(shell_tools) == 1


def _shell_request(*commands: str) -> SimpleNamespace:
"""Build a real (not mocked) ShellCommandRequest-shaped object."""
return SimpleNamespace(
data=SimpleNamespace(
action=SimpleNamespace(commands=list(commands), timeout_ms=None),
),
)


class TestShellExecutorAllowlist:
@pytest.mark.anyio
async def test_rejects_non_obsidian_binary(self):
result = await _shell_executor(_shell_request("echo hello"))
assert "rejected" in result.lower()
assert "obsidian" in result.lower()

@pytest.mark.anyio
async def test_rejects_command_chained_with_semicolon(self, tmp_path):
# If the second command were ever interpreted by a shell, this file
# would be created. The allowlist + exec defence must prevent that.
sentinel = tmp_path / "should_not_exist"
result = await _shell_executor(
_shell_request(f"rm -rf /tmp/foo; touch {sentinel}"),
)
assert "rejected" in result.lower()
assert not sentinel.exists()

@pytest.mark.anyio
async def test_rejects_second_command_in_list(self, tmp_path):
# action.commands is a list — every entry must pass the allowlist.
sentinel = tmp_path / "should_not_exist"
result = await _shell_executor(
_shell_request("obsidian help", f"touch {sentinel}"),
)
assert "rejected" in result.lower()
assert not sentinel.exists()

@pytest.mark.anyio
async def test_rejects_malformed_quoting(self):
result = await _shell_executor(
_shell_request('obsidian read file="unclosed'),
)
assert "rejected" in result.lower()

@pytest.mark.anyio
async def test_metacharacters_are_not_shell_interpreted(self, tmp_path):
# Even if a command starts with `obsidian` and passes the allowlist,
# `&&` must not chain a second command via the shell. Use a fake
# `obsidian` script on PATH to confirm exec receives literal argv.
fake_bin = tmp_path / "bin"
fake_bin.mkdir()
sentinel = tmp_path / "should_not_exist"
fake_obsidian = fake_bin / "obsidian"
fake_obsidian.write_text("#!/bin/sh\necho fake-obsidian got: $*\n")
fake_obsidian.chmod(0o755)

import os

old_path = os.environ["PATH"]
os.environ["PATH"] = f"{fake_bin}{os.pathsep}{old_path}"
try:
result = await _shell_executor(
_shell_request(f"obsidian read && touch {sentinel}"),
)
finally:
os.environ["PATH"] = old_path

assert not sentinel.exists()
# The fake script echoed everything it received as argv; the `&&` and
# `touch` tokens should appear in its output as literal arguments.
assert "&&" in result
assert "touch" in result