Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 54 additions & 20 deletions agentic_security/mcp/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Any

import httpx
from mcp.server.fastmcp import FastMCP

from agentic_security.logutils import logger

# Initialize MCP server
mcp = FastMCP(
name="Agentic Security MCP Server",
Expand All @@ -11,6 +15,51 @@
AGENTIC_SECURITY = "http://0.0.0.0:8718"


def _api_error(error_type: str, message: str, **details: Any) -> dict[str, Any]:
error = {"type": error_type, "message": message}
error.update(details)
return {"error": error}


def _response_body(response: httpx.Response) -> Any:
try:
return response.json()
except ValueError:
return response.text


async def _request_api(
method: str,
path: str,
*,
json: dict[str, Any] | None = None,
) -> dict[str, Any] | list[Any]:
url = f"{AGENTIC_SECURITY}{path}"
try:
async with httpx.AsyncClient() as client:
request_kwargs = {}
if json is not None:
request_kwargs["json"] = json
response = await client.request(method, url, **request_kwargs)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as exc:
response = exc.response
logger.error("MCP backend returned an error: %s", exc)
return _api_error(
"http_status",
str(exc),
status_code=response.status_code,
response=_response_body(response),
)
except httpx.RequestError as exc:
logger.error("MCP backend request failed: %s", exc)
return _api_error("request", str(exc))
except ValueError as exc:
logger.error("MCP backend returned invalid JSON: %s", exc)
return _api_error("invalid_json", str(exc))


@mcp.tool()
async def verify_llm(spec: str) -> dict:
"""
Expand All @@ -22,10 +71,7 @@ async def verify_llm(spec: str) -> dict:
Args: spect(str): The specification of the LLM model to verify.

"""
url = f"{AGENTIC_SECURITY}/verify"
async with httpx.AsyncClient() as client:
response = await client.post(url, json={"spec": spec})
return response.json()
return await _request_api("POST", "/verify", json={"spec": spec})


@mcp.tool()
Expand All @@ -47,7 +93,6 @@ async def start_scan(
enableMultiStepAttack (bool, optional): Whether to enable multi-step attack

"""
url = f"{AGENTIC_SECURITY}/scan"
payload = {
"llmSpec": llmSpec,
"maxBudget": maxBudget,
Expand All @@ -57,9 +102,7 @@ async def start_scan(
"probe_datasets": [],
"secrets": {},
}
async with httpx.AsyncClient() as client:
response = await client.post(url, json=payload)
return response.json()
return await _request_api("POST", "/scan", json=payload)


@mcp.tool()
Expand All @@ -69,10 +112,7 @@ async def stop_scan() -> dict:
Returns:
dict: The confirmation from the FastAPI server that the scan has been stopped.
"""
url = f"{AGENTIC_SECURITY}/stop"
async with httpx.AsyncClient() as client:
response = await client.post(url)
return response.json()
return await _request_api("POST", "/stop")


@mcp.tool()
Expand All @@ -83,10 +123,7 @@ async def get_data_config() -> list:
Returns:
list: The response from the FastAPI server, confirming the scan has been stopped.
"""
url = f"{AGENTIC_SECURITY}/v1/data-config"
async with httpx.AsyncClient() as client:
response = await client.get(url)
return response.json()
return await _request_api("GET", "/v1/data-config")


@mcp.tool()
Expand All @@ -97,10 +134,7 @@ async def get_spec_templates() -> list:
Returns:
list: The LLM specification templates from the FastAPI server.
"""
url = f"{AGENTIC_SECURITY}/v1/llm-specs"
async with httpx.AsyncClient() as client:
response = await client.get(url)
return response.json()
return await _request_api("GET", "/v1/llm-specs")


# Run the MCP server
Expand Down
102 changes: 102 additions & 0 deletions tests/unit/test_mcp_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import httpx
import pytest

from agentic_security.mcp import main as mcp_main


class MockResponse:
def __init__(self, status_code=200, payload=None, text=""):
self.status_code = status_code
self.payload = payload
self.text = text
self.request = httpx.Request("GET", "http://testserver")

def json(self):
if isinstance(self.payload, Exception):
raise self.payload
return self.payload

def raise_for_status(self):
if self.status_code >= 400:
raise httpx.HTTPStatusError(
f"{self.status_code} error",
request=self.request,
response=self,
)


class MockAsyncClient:
def __init__(self, response=None, exc=None):
self.response = response
self.exc = exc
self.calls = []

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, traceback):
return False

async def request(self, method, url, json=None):
self.calls.append((method, url, json))
if self.exc:
raise self.exc
return self.response


def use_client(monkeypatch, client):
monkeypatch.setattr(mcp_main.httpx, "AsyncClient", lambda: client)
return client


@pytest.mark.asyncio
async def test_verify_llm_returns_backend_json(monkeypatch):
client = use_client(
monkeypatch, MockAsyncClient(MockResponse(payload={"ok": True}))
)

result = await mcp_main.verify_llm("openai:gpt-4")

assert result == {"ok": True}
assert client.calls == [
(
"POST",
"http://0.0.0.0:8718/verify",
{"spec": "openai:gpt-4"},
)
]


@pytest.mark.asyncio
async def test_start_scan_returns_http_status_error(monkeypatch):
response = MockResponse(status_code=503, payload={"detail": "backend unavailable"})
use_client(monkeypatch, MockAsyncClient(response))

result = await mcp_main.start_scan("openai:gpt-4", 10)

assert result["error"]["type"] == "http_status"
assert result["error"]["status_code"] == 503
assert result["error"]["response"] == {"detail": "backend unavailable"}


@pytest.mark.asyncio
async def test_get_data_config_returns_request_error(monkeypatch):
request = httpx.Request("GET", "http://0.0.0.0:8718/v1/data-config")
error = httpx.ConnectError("connection refused", request=request)
use_client(monkeypatch, MockAsyncClient(exc=error))

result = await mcp_main.get_data_config()

assert result["error"]["type"] == "request"
assert "connection refused" in result["error"]["message"]


@pytest.mark.asyncio
async def test_get_spec_templates_returns_invalid_json_error(monkeypatch):
response = MockResponse(payload=ValueError("bad json"))
use_client(monkeypatch, MockAsyncClient(response))

result = await mcp_main.get_spec_templates()

assert result["error"]["type"] == "invalid_json"
assert "bad json" in result["error"]["message"]