Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/scout_apm/fastmcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ async def on_call_tool(self, context, call_next):
# Add rich metadata from tool object via context
try:
tool = await context.fastmcp_context.fastmcp.get_tool(tool_name)
self._tag_tool_metadata(tracked_request, tool)
if tool is not None:
self._tag_tool_metadata(tracked_request, tool)
except Exception as exc:
# Tool not found or other error - continue without metadata
logger.warning(f"Unable to fetch tool metadata for {tool_name}: {exc}")
Expand Down
17 changes: 16 additions & 1 deletion src/scout_apm/rq.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,21 @@ def ensure_job_instrumented():
job_instrumented = True
Job.perform = wrap_perform(Job.perform)

try:
import pickle

from rq.serializers import DefaultSerializer

if getattr(DefaultSerializer, "dumps", None) is pickle.dumps:
logger.warning(
"RQ is using the default pickle serializer, which is vulnerable to "
"Remote Code Execution (RCE) via Redis (CWE-502). Consider switching "
"to a safer serializer like rq.serializers.JSONSerializer. "
"See https://github.com/rq/rq/issues/2389 for details."
)
except Exception:
pass


@wrapt.decorator
def wrap_perform(wrapped, instance, args, kwargs):
Expand All @@ -66,7 +81,7 @@ def wrap_perform(wrapped, instance, args, kwargs):

tracked_request = TrackedRequest.instance()
tracked_request.is_real_request = True
tracked_request.tag("task_id", instance.get_id())
tracked_request.tag("task_id", instance.id)
tracked_request.tag("queue", instance.origin)
# rq strips tzinfo from enqueued_at during serde in at least some cases
# internally everything uses UTC naive datetimes, so we operate on that
Expand Down
69 changes: 47 additions & 22 deletions tests/integration/test_fastmcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,39 @@ def parse_version(v):

from scout_apm.fastmcp import ScoutMiddleware
except (ImportError, TypeError):
# fastmcp has compatibility issues with version <2.13.0
# This is due to us using internal methods to test the middleware hooks
# These internal methods were renamed in 2.13.0
fastmcp_version = "0.0.0"
pass

_fastmcp_version = parse_version(fastmcp_version)

pytestmark = pytest.mark.skipif(
parse_version(fastmcp_version) < (2, 13, 0) or sys.version_info < (3, 10),
_fastmcp_version < (2, 13, 0) or sys.version_info < (3, 10),
reason="These tests require fastMCP 2.13.0+ and Python 3.10+",
)


async def _call_tool(mcp, name, arguments):
"""
Call a tool on a FastMCP server, compatible with both 2.x and 3.x.

Returns (content_blocks, metadata) for uniform access.
"""
if _fastmcp_version >= (3,):
result = await mcp.call_tool(name, arguments)
return result.content, result.meta
else:
return await mcp._call_tool_mcp(name, arguments)


async def _list_tools(mcp):
"""
List tools on a FastMCP server, compatible with both 2.x and 3.x.
"""
if _fastmcp_version >= (3,):
return await mcp.list_tools()
else:
return await mcp._list_tools_mcp()


@contextmanager
def server_with_scout(scout_config=None):
"""
Expand Down Expand Up @@ -70,14 +91,14 @@ def add_numbers(a: int, b: int) -> int:
return a + b

# Verify tool is registered
tools_list = await mcp._list_tools_mcp()
tools_list = await _list_tools(mcp)
assert len(tools_list) == 1
assert tools_list[0].name == "add_numbers"

# Simulate tool execution using the MCP protocol method
result = await mcp._call_tool_mcp("add_numbers", {"a": 5, "b": 3})
# result is a tuple: (content_blocks, metadata)
content_blocks, metadata = result
content_blocks, metadata = await _call_tool(
mcp, "add_numbers", {"a": 5, "b": 3}
)
assert len(content_blocks) == 1
assert content_blocks[0].text == "8"

Expand All @@ -100,8 +121,10 @@ async def async_multiply(a: int, b: int) -> int:
return a * b

# Simulate tool execution
result, metadata = await mcp._call_tool_mcp("async_multiply", {"a": 4, "b": 7})
assert result[0].text == "28"
content_blocks, metadata = await _call_tool(
mcp, "async_multiply", {"a": 4, "b": 7}
)
assert content_blocks[0].text == "28"

# Verify tracking
assert len(tracked_requests) == 1
Expand Down Expand Up @@ -130,8 +153,8 @@ def search_database(query: str) -> list:
return [{"id": 1, "name": "result"}]

# Execute tool
result, metadata = await mcp._call_tool_mcp("search_db", {"query": "test"})
assert len(result) == 1
content_blocks, metadata = await _call_tool(mcp, "search_db", {"query": "test"})
assert len(content_blocks) == 1

# Verify metadata tags
assert len(tracked_requests) == 1
Expand All @@ -158,13 +181,15 @@ def process_data(data: str, password: str, count: int) -> dict:
return {"processed": True, "length": len(data)}

# Execute tool with sensitive parameter
result, metadata = await mcp._call_tool_mcp(
"process_data", {"data": "test data", "password": "secret123", "count": 5}
content_blocks, metadata = await _call_tool(
mcp,
"process_data",
{"data": "test data", "password": "secret123", "count": 5},
)
# FastMCP returns list of ContentBlock, need to parse the JSON
import json

result_data = json.loads(result[0].text)
result_data = json.loads(content_blocks[0].text)
assert result_data["processed"] is True

# Verify arguments are tagged
Expand Down Expand Up @@ -195,7 +220,7 @@ def divide_numbers(a: float, b: float) -> float:

# Execute tool that raises an error
with pytest.raises(ToolError, match="Division by zero"):
await mcp._call_tool_mcp("divide_numbers", {"a": 10, "b": 0})
await _call_tool(mcp, "divide_numbers", {"a": 10, "b": 0})

# Verify error tracking
assert len(tracked_requests) == 1
Expand All @@ -214,9 +239,9 @@ def echo(message: str) -> str:
return message

# Execute multiple times
await mcp._call_tool_mcp("echo", {"message": "first"})
await mcp._call_tool_mcp("echo", {"message": "second"})
await mcp._call_tool_mcp("echo", {"message": "third"})
await _call_tool(mcp, "echo", {"message": "first"})
await _call_tool(mcp, "echo", {"message": "second"})
await _call_tool(mcp, "echo", {"message": "third"})

# Should have 3 separate tracked requests
assert len(tracked_requests) == 3
Expand All @@ -234,8 +259,8 @@ def monitored_tool() -> str:
"""This should not be tracked."""
return "result"

result, metadata = await mcp._call_tool_mcp("monitored_tool", {})
assert result[0].text == "result"
content_blocks, metadata = await _call_tool(mcp, "monitored_tool", {})
assert content_blocks[0].text == "result"

# Should not track when monitor is disabled
assert len(tracked_requests) == 0
Loading