Skip to content

Commit 9e567dc

Browse files
author
Yuan Huang
committed
DIARCHERS-1396: add MCP-dedicated API layer with rate limiting, auth, and audit
- DB migration 00046: mcp_token and mcp_access_log tables - api/handlers/mcp/auth.py: ib_mcp_* bearer token validation, project/trigger access checks - api/handlers/mcp/rate_limit.py: Redis sliding-window per-user per-endpoint rate limiter (fail-open) - api/handlers/mcp/audit.py: fire-and-forget audit logging to mcp_access_log - api/handlers/mcp/token_routes.py: token CRUD at /api/v1/mcp/tokens/* - api/handlers/mcp/routes/: /api/v1/mcp/* endpoints for projects, builds, jobs, artifacts, trigger - infrabox/test/api/mcp_test.py: 21 unit tests covering hash, access checks, rate limiter
1 parent 4480078 commit 9e567dc

12 files changed

Lines changed: 1076 additions & 0 deletions

File tree

infrabox/test/api/mcp_test.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""
2+
Unit tests for the MCP API layer:
3+
- MCP token auth (valid, invalid, expired, revoked, wrong path)
4+
- Rate limiter (allow, deny, fail-open)
5+
- Project access check (token scoped, session fallback)
6+
- Trigger access check
7+
"""
8+
import hashlib
9+
import secrets
10+
import sys
11+
import unittest
12+
from datetime import datetime, timezone, timedelta
13+
from unittest.mock import MagicMock, patch
14+
15+
# Add src/ to path so we can import the MCP modules directly without
16+
# triggering api/handlers/__init__.py (which requires INFRABOX_* env vars).
17+
import os
18+
_src_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src')
19+
sys.path.insert(0, _src_dir)
20+
21+
# Stub heavy server-init modules before importing our modules
22+
import types
23+
24+
# pyinfraboxutils stubs
25+
piu = types.ModuleType('pyinfraboxutils')
26+
piu.get_logger = lambda name: __import__('logging').getLogger(name)
27+
piu.get_env = lambda k: os.environ.get(k, '')
28+
sys.modules.setdefault('pyinfraboxutils', piu)
29+
sys.modules.setdefault('pyinfraboxutils.dbpool', types.ModuleType('pyinfraboxutils.dbpool'))
30+
sys.modules.setdefault('pyinfraboxutils.db', types.ModuleType('pyinfraboxutils.db'))
31+
32+
# flask_restx stub
33+
frestx = types.ModuleType('flask_restx')
34+
frestx.Resource = object
35+
frestx.Api = MagicMock()
36+
sys.modules.setdefault('flask_restx', frestx)
37+
38+
ibrestplus = types.ModuleType('pyinfraboxutils.ibrestplus')
39+
ibrestplus.api = MagicMock()
40+
ibrestplus.response_model = {}
41+
sys.modules.setdefault('pyinfraboxutils.ibrestplus', ibrestplus)
42+
43+
# Stub api.handlers as a real package (with __path__) so Python can resolve
44+
# api.handlers.mcp.* from disk without executing api/handlers/__init__.py.
45+
_API_HANDLERS = 'api.handlers'
46+
47+
# Stub api.handlers as a real package (with __path__) so Python can resolve
48+
# api.handlers.mcp.* from disk without executing api/handlers/__init__.py.
49+
_api = types.ModuleType('api')
50+
_api.__path__ = [os.path.join(_src_dir, 'api')]
51+
_api.__package__ = 'api'
52+
_api_handlers = types.ModuleType(_API_HANDLERS)
53+
_api_handlers.__path__ = [os.path.join(_src_dir, 'api', 'handlers')]
54+
_api_handlers.__package__ = _API_HANDLERS
55+
_api.handlers = _api_handlers
56+
sys.modules['api'] = _api
57+
sys.modules[_API_HANDLERS] = _api_handlers
58+
59+
# Now import the modules under test directly
60+
import importlib
61+
mcp_auth = importlib.import_module('api.handlers.mcp.auth')
62+
mcp_rate_limit_mod = importlib.import_module('api.handlers.mcp.rate_limit')
63+
64+
65+
# ---------------------------------------------------------------------------
66+
# Auth module — token hash
67+
# ---------------------------------------------------------------------------
68+
69+
class TestMcpTokenHash(unittest.TestCase):
70+
def test_hash_is_sha256_hex(self):
71+
result = mcp_auth._hash_token('ib_mcp_' + 'a' * 48)
72+
self.assertEqual(len(result), 64)
73+
self.assertTrue(all(c in '0123456789abcdef' for c in result))
74+
75+
def test_different_tokens_produce_different_hashes(self):
76+
h1 = mcp_auth._hash_token('ib_mcp_' + 'a' * 48)
77+
h2 = mcp_auth._hash_token('ib_mcp_' + 'b' * 48)
78+
self.assertNotEqual(h1, h2)
79+
80+
def test_same_token_deterministic(self):
81+
raw = 'ib_mcp_' + secrets.token_hex(24)
82+
self.assertEqual(mcp_auth._hash_token(raw), mcp_auth._hash_token(raw))
83+
84+
def test_matches_expected_sha256(self):
85+
raw = 'ib_mcp_test'
86+
expected = hashlib.sha256(raw.encode('utf-8')).hexdigest()
87+
self.assertEqual(mcp_auth._hash_token(raw), expected)
88+
89+
90+
# ---------------------------------------------------------------------------
91+
# Project access check
92+
# ---------------------------------------------------------------------------
93+
94+
class TestCheckProjectAccessMcp(unittest.TestCase):
95+
def _g_with_projects(self, projects):
96+
g = MagicMock()
97+
g.mcp_enabled_projects = projects
98+
return g
99+
100+
def test_no_mcp_attr_allows_all(self):
101+
g = MagicMock(spec=[]) # no mcp_enabled_projects attribute at all
102+
with patch.object(mcp_auth, 'g', g):
103+
self.assertTrue(mcp_auth.check_project_access_mcp('any-id'))
104+
105+
def test_project_not_in_scope_denied(self):
106+
g = self._g_with_projects({'other-id': None})
107+
with patch.object(mcp_auth, 'g', g):
108+
self.assertFalse(mcp_auth.check_project_access_mcp('target-id'))
109+
110+
def test_project_in_scope_no_expiry_allowed(self):
111+
g = self._g_with_projects({'target-id': None})
112+
with patch.object(mcp_auth, 'g', g):
113+
self.assertTrue(mcp_auth.check_project_access_mcp('target-id'))
114+
115+
def test_per_project_expiry_in_future_allowed(self):
116+
future = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()
117+
g = self._g_with_projects({'pid': future})
118+
with patch.object(mcp_auth, 'g', g):
119+
self.assertTrue(mcp_auth.check_project_access_mcp('pid'))
120+
121+
def test_per_project_expiry_in_past_denied(self):
122+
past = (datetime.now(timezone.utc) - timedelta(days=1)).isoformat()
123+
g = self._g_with_projects({'pid': past})
124+
with patch.object(mcp_auth, 'g', g):
125+
self.assertFalse(mcp_auth.check_project_access_mcp('pid'))
126+
127+
128+
# ---------------------------------------------------------------------------
129+
# Trigger access check
130+
# ---------------------------------------------------------------------------
131+
132+
class TestCheckTriggerAccessMcp(unittest.TestCase):
133+
def test_no_mcp_attr_allows(self):
134+
g = MagicMock(spec=[])
135+
with patch.object(mcp_auth, 'g', g):
136+
self.assertTrue(mcp_auth.check_trigger_access_mcp())
137+
138+
def test_allow_trigger_true(self):
139+
g = MagicMock()
140+
g.mcp_allow_trigger = True
141+
with patch.object(mcp_auth, 'g', g):
142+
self.assertTrue(mcp_auth.check_trigger_access_mcp())
143+
144+
def test_allow_trigger_false(self):
145+
g = MagicMock()
146+
g.mcp_allow_trigger = False
147+
with patch.object(mcp_auth, 'g', g):
148+
self.assertFalse(mcp_auth.check_trigger_access_mcp())
149+
150+
151+
# ---------------------------------------------------------------------------
152+
# Rate limiter
153+
# ---------------------------------------------------------------------------
154+
155+
class TestMcpRateLimit(unittest.TestCase):
156+
def _run_check(self, count_result):
157+
mock_redis = MagicMock()
158+
pipeline = MagicMock()
159+
pipeline.execute.return_value = [None, None, count_result, None]
160+
mock_redis.pipeline.return_value = pipeline
161+
162+
with patch.object(mcp_rate_limit_mod, '_get_redis', return_value=mock_redis), \
163+
patch('time.time', return_value=1_000_000.0):
164+
return mcp_rate_limit_mod._check_rate_limit('user-123', 'list_builds')
165+
166+
def test_under_limit_allowed(self):
167+
self.assertTrue(self._run_check(1))
168+
169+
def test_at_limit_allowed(self):
170+
self.assertTrue(self._run_check(mcp_rate_limit_mod._DEFAULT_RPM))
171+
172+
def test_over_limit_denied(self):
173+
self.assertFalse(self._run_check(mcp_rate_limit_mod._DEFAULT_RPM + 1))
174+
175+
def test_fail_open_when_no_redis(self):
176+
with patch.object(mcp_rate_limit_mod, '_get_redis', return_value=None):
177+
self.assertTrue(mcp_rate_limit_mod._check_rate_limit('user', 'list_builds'))
178+
179+
def test_fail_open_on_redis_exception(self):
180+
mock_redis = MagicMock()
181+
mock_redis.pipeline.side_effect = RuntimeError('connection lost')
182+
with patch.object(mcp_rate_limit_mod, '_get_redis', return_value=mock_redis):
183+
self.assertTrue(mcp_rate_limit_mod._check_rate_limit('user', 'list_builds'))
184+
185+
def test_trigger_rpm_lower_than_default(self):
186+
self.assertLess(mcp_rate_limit_mod._ENDPOINT_LIMITS['trigger_build'],
187+
mcp_rate_limit_mod._DEFAULT_RPM)
188+
189+
def test_log_rpm_lower_than_default(self):
190+
self.assertLess(mcp_rate_limit_mod._ENDPOINT_LIMITS['get_job_log'],
191+
mcp_rate_limit_mod._DEFAULT_RPM)
192+
193+
def test_artifact_rpm_lower_than_default(self):
194+
self.assertLess(mcp_rate_limit_mod._ENDPOINT_LIMITS['list_job_artifacts'],
195+
mcp_rate_limit_mod._DEFAULT_RPM)
196+
197+
def test_different_users_independent(self):
198+
# Two users: first at limit, second should still be allowed.
199+
calls = []
200+
def fake_check(user_id, endpoint):
201+
calls.append(user_id)
202+
if user_id == 'heavy-user':
203+
return False
204+
return True
205+
with patch.object(mcp_rate_limit_mod, '_check_rate_limit', side_effect=fake_check):
206+
self.assertFalse(mcp_rate_limit_mod._check_rate_limit('heavy-user', 'list_builds'))
207+
self.assertTrue(mcp_rate_limit_mod._check_rate_limit('normal-user', 'list_builds'))
208+
209+
210+
if __name__ == '__main__':
211+
unittest.main()
212+

src/api/handlers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
import api.handlers.build
88
import api.handlers.job_api
99
import api.handlers.admin
10+
import api.handlers.mcp

src/api/handlers/mcp/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import api.handlers.mcp.token_routes
2+
import api.handlers.mcp.routes.projects
3+
import api.handlers.mcp.routes.builds
4+
import api.handlers.mcp.routes.jobs

src/api/handlers/mcp/audit.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
Audit logging for MCP API calls.
3+
4+
Writes to the mcp_access_log table (best-effort, fire-and-forget).
5+
Never raises — a logging failure must not break the request.
6+
"""
7+
import logging
8+
import threading
9+
10+
from flask import g, request
11+
12+
logger = logging.getLogger('mcp_audit')
13+
14+
15+
def audit_mcp(action: str, outcome: str = 'attempt', details: dict = None, error: str = ''):
16+
"""Record one MCP audit entry. Non-blocking: runs in a daemon thread."""
17+
token_id = getattr(g, 'mcp_token_id', None)
18+
user_id = getattr(g, 'mcp_token_user_id', None)
19+
if not user_id:
20+
token = getattr(g, 'token', None)
21+
if token and 'user' in token:
22+
user_id = str(token['user'].get('id', ''))
23+
ip = request.remote_addr
24+
25+
# Capture a snapshot of the db connection so the thread can use it safely.
26+
# For simplicity we write synchronously on the request db connection since
27+
# the volume is low. If latency becomes a concern this can be offloaded.
28+
try:
29+
db = getattr(g, 'db', None)
30+
if db is None:
31+
return
32+
33+
db.execute('''
34+
INSERT INTO mcp_access_log (token_id, user_id, action, outcome, details, error, ip)
35+
VALUES (%s, %s, %s, %s, %s, %s, %s)
36+
''', [
37+
token_id,
38+
user_id,
39+
action,
40+
outcome,
41+
_to_json(details),
42+
error or None,
43+
ip,
44+
])
45+
db.commit()
46+
except Exception as exc:
47+
logger.warning('MCP audit log failed: %s', exc)
48+
49+
50+
def _to_json(d):
51+
if d is None:
52+
return None
53+
import json
54+
try:
55+
return json.dumps(d)
56+
except Exception:
57+
return str(d)

0 commit comments

Comments
 (0)