Skip to content

Commit 92cc94c

Browse files
committed
Improve MCP relationship guidance and tool logging
1 parent bc30023 commit 92cc94c

7 files changed

Lines changed: 207 additions & 19 deletions

src/codealive_mcp_server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@
8080
- Use specific function/class names or file path scopes when looking for particular implementations
8181
- Treat `semantic_search` and `grep_search` as the default discovery tools
8282
- Prefer `semantic_search` over the deprecated `codebase_search` legacy alias
83+
- Use `get_artifact_relationships` only with exact artifact identifiers from prior search/fetch results.
84+
It expands a known artifact's relationship graph; it does not search by path, class name, or guessed symbol.
85+
For exact source code, call `fetch_artifacts` on identifiers returned by search or relationships.
8386
- Remember that context from previous messages is maintained in the same conversation
8487
8588
Flexible data source usage:

src/middleware/observability_middleware.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
execution carries the correlation ID.
1212
"""
1313

14-
from typing import TYPE_CHECKING
14+
from typing import TYPE_CHECKING, Any
1515

1616
from loguru import logger
1717
from opentelemetry import trace
@@ -25,11 +25,43 @@
2525
_tracer = trace.get_tracer("codealive-mcp.tools")
2626

2727

28+
def _extract_tool_arguments(context: "MiddlewareContext") -> dict[str, Any]:
29+
"""Best-effort extraction of raw MCP tool arguments from FastMCP middleware context."""
30+
message = getattr(context, "message", None)
31+
args = getattr(message, "arguments", None)
32+
if isinstance(args, dict):
33+
return dict(args)
34+
35+
params = getattr(message, "params", None)
36+
if isinstance(params, dict):
37+
args = params.get("arguments")
38+
if isinstance(args, dict):
39+
return dict(args)
40+
else:
41+
args = getattr(params, "arguments", None)
42+
if isinstance(args, dict):
43+
return dict(args)
44+
45+
if isinstance(message, dict):
46+
args = message.get("arguments")
47+
if isinstance(args, dict):
48+
return dict(args)
49+
50+
params = message.get("params")
51+
if isinstance(params, dict):
52+
args = params.get("arguments")
53+
if isinstance(args, dict):
54+
return dict(args)
55+
56+
return {}
57+
58+
2859
class ObservabilityMiddleware(Middleware):
2960
"""Wrap each ``tools/call`` in an OTel span and log its outcome."""
3061

3162
async def on_call_tool(self, context: "MiddlewareContext", call_next: "CallNext"):
3263
tool_name = getattr(context.message, "name", "unknown")
64+
tool_arguments = _extract_tool_arguments(context)
3365

3466
with _tracer.start_as_current_span(
3567
f"tool {tool_name}",
@@ -44,21 +76,25 @@ async def on_call_tool(self, context: "MiddlewareContext", call_next: "CallNext"
4476
span_ctx = span.get_span_context()
4577
trace_id = format(span_ctx.trace_id, "032x") if span_ctx.trace_id else ""
4678

47-
with logger.contextualize(trace_id=trace_id, tool=tool_name):
48-
logger.info("Tool call started: {tool}", tool=tool_name)
79+
with logger.contextualize(
80+
trace_id=trace_id,
81+
tool=tool_name,
82+
tool_arguments=tool_arguments,
83+
):
84+
logger.debug("Tool call started: {tool}", tool=tool_name)
4985

5086
try:
5187
result = await call_next(context)
5288
except Exception as exc:
5389
span.set_status(StatusCode.ERROR, str(exc))
5490
span.record_exception(exc)
55-
logger.error(
56-
"Tool call failed: {tool} — {error}",
57-
tool=tool_name,
91+
logger.bind(
92+
error_type=type(exc).__name__,
5893
error=str(exc),
59-
)
94+
tool_arguments=tool_arguments,
95+
).opt(exception=exc).warning("Tool call failed: {tool}", tool=tool_name)
6096
raise
6197

6298
span.set_status(StatusCode.OK)
63-
logger.info("Tool call completed: {tool}", tool=tool_name)
99+
logger.debug("Tool call completed: {tool}", tool=tool_name)
64100
return result

src/tests/test_artifact_relationships.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ async def test_default_profile_sends_calls_only(self, mock_get_api_key):
208208
mock_get_api_key.return_value = "test_key"
209209

210210
ctx = MagicMock(spec=Context)
211-
ctx.info = AsyncMock()
211+
ctx.debug = AsyncMock()
212212
ctx.error = AsyncMock()
213213

214214
mock_response = MagicMock()
@@ -244,7 +244,7 @@ async def test_explicit_profile_maps_correctly(self, mock_get_api_key):
244244
mock_get_api_key.return_value = "test_key"
245245

246246
ctx = MagicMock(spec=Context)
247-
ctx.info = AsyncMock()
247+
ctx.debug = AsyncMock()
248248
ctx.error = AsyncMock()
249249

250250
mock_response = MagicMock()
@@ -295,7 +295,7 @@ async def test_api_error_returns_error_json(self, mock_get_api_key):
295295
mock_get_api_key.return_value = "test_key"
296296

297297
ctx = MagicMock(spec=Context)
298-
ctx.info = AsyncMock()
298+
ctx.debug = AsyncMock()
299299
ctx.error = AsyncMock()
300300

301301
mock_response = MagicMock()
@@ -322,7 +322,7 @@ async def test_not_found_response_renders_correctly(self, mock_get_api_key):
322322
mock_get_api_key.return_value = "test_key"
323323

324324
ctx = MagicMock(spec=Context)
325-
ctx.info = AsyncMock()
325+
ctx.debug = AsyncMock()
326326
ctx.error = AsyncMock()
327327

328328
mock_response = MagicMock()

src/tests/test_e2e_tools.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
import httpx
1616
import pytest
1717
from fastmcp import Client, FastMCP
18+
from loguru import logger
1819

1920
sys.path.insert(0, str(Path(__file__).parent.parent))
2021

2122
from core import CodeAliveContext
23+
from middleware.observability_middleware import ObservabilityMiddleware
2224
from tools import (
2325
chat,
2426
codebase_consultant,
@@ -1204,6 +1206,38 @@ async def test_invalid_profile_returns_error(self):
12041206
assert "callsOnly" in text
12051207
assert "literal_error" in text or "Input should be" in text
12061208

1209+
@pytest.mark.asyncio
1210+
async def test_invalid_profile_is_logged_with_arguments_by_middleware(self):
1211+
"""FastMCP validation fails before the tool body, so middleware must capture args."""
1212+
mcp = _server({})
1213+
mcp.add_middleware(ObservabilityMiddleware())
1214+
records = []
1215+
handler_id = logger.add(lambda message: records.append(message.record), level="DEBUG")
1216+
1217+
try:
1218+
async with Client(mcp) as client:
1219+
result = await client.call_tool(
1220+
"get_artifact_relationships",
1221+
{"identifier": "org/repo::x", "profile": "bogus"},
1222+
raise_on_error=False,
1223+
)
1224+
finally:
1225+
logger.remove(handler_id)
1226+
1227+
assert result.is_error
1228+
failures = [
1229+
record for record in records
1230+
if record["message"] == "Tool call failed: get_artifact_relationships"
1231+
]
1232+
assert len(failures) == 1
1233+
failure = failures[0]
1234+
assert failure["level"].name == "WARNING"
1235+
assert failure["extra"]["tool_arguments"] == {
1236+
"identifier": "org/repo::x",
1237+
"profile": "bogus",
1238+
}
1239+
assert failure["extra"]["error_type"] == "ValidationError"
1240+
12071241
@pytest.mark.asyncio
12081242
async def test_empty_identifier_returns_error(self):
12091243
mcp = _server({})

src/tests/test_observability_middleware.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from unittest.mock import AsyncMock, MagicMock, patch
66

77
import pytest
8+
from loguru import logger
89
from opentelemetry import trace
910
from opentelemetry.sdk.trace import TracerProvider, ReadableSpan
1011
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult
1112

1213
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent.parent))
1314

14-
from middleware.observability_middleware import ObservabilityMiddleware
15+
from middleware.observability_middleware import ObservabilityMiddleware, _extract_tool_arguments
1516

1617

1718
class _CollectingExporter(SpanExporter):
@@ -49,9 +50,10 @@ def otel_setup():
4950
provider.shutdown()
5051

5152

52-
def _make_context(tool_name: str = "codebase_search"):
53+
def _make_context(tool_name: str = "codebase_search", arguments: dict | None = None):
5354
ctx = MagicMock()
5455
ctx.message.name = tool_name
56+
ctx.message.arguments = arguments or {}
5557
return ctx
5658

5759

@@ -113,6 +115,28 @@ async def test_handles_missing_tool_name(self, otel_setup):
113115
assert span.name == "tool unknown"
114116
assert span.attributes["mcp.tool.name"] == "unknown"
115117

118+
@pytest.mark.asyncio
119+
async def test_lifecycle_logs_are_debug_with_tool_arguments(self, otel_setup):
120+
middleware = ObservabilityMiddleware()
121+
tool_arguments = {"identifier": "org/repo::src/svc.py::run", "profile": "callsOnly"}
122+
context = _make_context("get_artifact_relationships", tool_arguments)
123+
call_next = AsyncMock(return_value="ok")
124+
records = []
125+
handler_id = logger.add(lambda message: records.append(message.record), level="DEBUG")
126+
127+
try:
128+
await middleware.on_call_tool(context, call_next)
129+
finally:
130+
logger.remove(handler_id)
131+
132+
lifecycle = [
133+
record for record in records
134+
if record["message"].startswith("Tool call ")
135+
]
136+
assert [record["level"].name for record in lifecycle] == ["DEBUG", "DEBUG"]
137+
assert lifecycle[0]["extra"]["tool_arguments"] == tool_arguments
138+
assert lifecycle[1]["extra"]["tool_arguments"] == tool_arguments
139+
116140

117141
# ---------------------------------------------------------------------------
118142
# Failed tool call
@@ -158,3 +182,47 @@ async def test_span_records_exception_event(self, otel_setup):
158182
messages = [e.attributes["exception.message"] for e in exception_events]
159183
assert "RuntimeError" in types
160184
assert "boom" in messages
185+
186+
@pytest.mark.asyncio
187+
async def test_failure_logs_warning_with_full_tool_arguments(self, otel_setup):
188+
middleware = ObservabilityMiddleware()
189+
tool_arguments = {
190+
"identifier": "org/repo::src/svc.py::run",
191+
"profile": "bogus",
192+
"max_count_per_type": 50,
193+
}
194+
context = _make_context("get_artifact_relationships", tool_arguments)
195+
call_next = AsyncMock(side_effect=ValueError("bad profile"))
196+
records = []
197+
handler_id = logger.add(lambda message: records.append(message.record), level="DEBUG")
198+
199+
try:
200+
with pytest.raises(ValueError, match="bad profile"):
201+
await middleware.on_call_tool(context, call_next)
202+
finally:
203+
logger.remove(handler_id)
204+
205+
failures = [record for record in records if record["message"] == "Tool call failed: get_artifact_relationships"]
206+
assert len(failures) == 1
207+
failure = failures[0]
208+
assert failure["level"].name == "WARNING"
209+
assert failure["extra"]["tool"] == "get_artifact_relationships"
210+
assert failure["extra"]["tool_arguments"] == tool_arguments
211+
assert failure["extra"]["error_type"] == "ValueError"
212+
assert failure["extra"]["error"] == "bad profile"
213+
214+
215+
class TestExtractToolArguments:
216+
def test_extracts_fastmcp_arguments(self):
217+
context = _make_context("tool", {"name": "value"})
218+
assert _extract_tool_arguments(context) == {"name": "value"}
219+
220+
def test_extracts_json_rpc_params_arguments(self):
221+
context = MagicMock()
222+
context.message = {"params": {"arguments": {"identifier": "id"}}}
223+
assert _extract_tool_arguments(context) == {"identifier": "id"}
224+
225+
def test_returns_empty_dict_when_unavailable(self):
226+
context = MagicMock()
227+
context.message = object()
228+
assert _extract_tool_arguments(context) == {}

src/tests/test_tool_metadata.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,9 @@ async def test_all_tools_are_marked_read_only_with_titles():
3535
assert tool.title == title
3636
assert tool.annotations is not None
3737
assert tool.annotations.readOnlyHint is True
38+
39+
relationships_description = actual["get_artifact_relationships"].description
40+
assert relationships_description is not None
41+
assert "exact artifact identifier" in relationships_description
42+
assert "not a search tool" in relationships_description
43+
assert "fetch_artifacts" in relationships_description

src/tools/artifact_relationships.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import httpx
77
from fastmcp import Context
88
from fastmcp.exceptions import ToolError
9+
from loguru import logger
910

1011
from core import CodeAliveContext, get_api_key_from_context, log_api_request, log_api_response
1112
from utils import handle_api_error
@@ -40,10 +41,29 @@ async def get_artifact_relationships(
4041
"""
4142
Retrieve relationship groups for a single artifact by profile.
4243
43-
Use this tool to explore an artifact's call graph, inheritance hierarchy,
44-
or references. This is a drill-down tool — use it AFTER `semantic_search`,
45-
`grep_search`, legacy `codebase_search`, or `fetch_artifacts` when you need
46-
to understand how an artifact relates to others in the codebase.
44+
Use this tool to expand the relationship graph around one known artifact:
45+
call graph edges, inheritance hierarchy, or references.
46+
47+
Important usage rules:
48+
- This is a graph expansion tool, not a search tool. The `identifier`
49+
must be an exact artifact identifier returned by `semantic_search`,
50+
`grep_search`, legacy `codebase_search`, or `fetch_artifacts`.
51+
- Do not pass a repository name, file path, class name, method name, or
52+
guessed symbol name unless it is the full identifier from a prior
53+
tool result.
54+
- If `found=false` or the backend returns a not-found/inaccessible
55+
error, get a fresh identifier with `semantic_search`, `grep_search`,
56+
`codebase_search`, or `fetch_artifacts` before retrying. Repeating
57+
the same guessed identifier usually repeats the same failure.
58+
- Relationships are primarily available for symbol artifacts such as
59+
functions, methods, classes, and interfaces. Plain files and prose
60+
documents can legitimately have no relationship graph.
61+
- The response contains relationship metadata and short summaries, not
62+
full source code. Use `fetch_artifacts` on returned identifiers when
63+
exact source content is needed.
64+
- If any relationship group has `truncated=true`, increase
65+
`max_count_per_type` up to 1000 or narrow the investigation with a
66+
more specific `profile`.
4767
4868
Args:
4969
identifier: Fully qualified artifact identifier from search or fetch results.
@@ -68,17 +88,32 @@ async def get_artifact_relationships(
6888
When the artifact is not found or inaccessible:
6989
{"sourceIdentifier":"...","profile":"callsOnly","found":false}
7090
"""
91+
tool_arguments = {
92+
"identifier": identifier,
93+
"profile": profile,
94+
"max_count_per_type": max_count_per_type,
95+
}
96+
7197
if not identifier:
98+
logger.bind(tool=_TOOL_NAME, tool_arguments=tool_arguments).warning(
99+
"Tool validation failed: artifact identifier is required"
100+
)
72101
raise ToolError(f"[{_TOOL_NAME}] Artifact identifier is required.")
73102

74103
if not (1 <= max_count_per_type <= 1000):
104+
logger.bind(tool=_TOOL_NAME, tool_arguments=tool_arguments).warning(
105+
"Tool validation failed: max_count_per_type is out of range"
106+
)
75107
raise ToolError(f"[{_TOOL_NAME}] max_count_per_type must be between 1 and 1000.")
76108

77109
# Literal type handles most validation via Pydantic, but direct callers
78110
# (e.g. unit tests) can still pass invalid values — keep as fallback.
79111
api_profile = PROFILE_MAP.get(profile)
80112
if api_profile is None:
81113
supported = ", ".join(PROFILE_MAP.keys())
114+
logger.bind(tool=_TOOL_NAME, tool_arguments=tool_arguments).warning(
115+
"Tool validation failed: unsupported relationship profile"
116+
)
82117
raise ToolError(f'[{_TOOL_NAME}] Unsupported profile "{profile}". Use one of: {supported}')
83118

84119
context: CodeAliveContext = ctx.request_context.lifespan_context
@@ -98,7 +133,7 @@ async def get_artifact_relationships(
98133
"maxCountPerType": max_count_per_type,
99134
}
100135

101-
await ctx.info(f"Fetching {profile} relationships for artifact")
136+
await ctx.debug(f"Fetching {profile} relationships for artifact")
102137

103138
full_url = urljoin(context.base_url, "/api/search/artifact-relationships")
104139
request_id = log_api_request("POST", full_url, headers, body=body)
@@ -113,6 +148,12 @@ async def get_artifact_relationships(
113148
return _build_relationships_dict(response.json())
114149

115150
except (httpx.HTTPStatusError, Exception) as e:
151+
logger.bind(
152+
tool=_TOOL_NAME,
153+
tool_arguments=tool_arguments,
154+
error_type=type(e).__name__,
155+
error=str(e),
156+
).warning("Tool call failed while fetching artifact relationships")
116157
await handle_api_error(
117158
ctx, e, "get artifact relationships", method=_TOOL_NAME,
118159
recovery_hints={

0 commit comments

Comments
 (0)