Skip to content

Commit 0e6583c

Browse files
committed
fix(mcp-proxy): close P10.5-security defense-in-depth gaps
Enforce DecisionReceipt schema and binding semantics during offline evidence bundle verification, including orphan signed-receipt warnings. Add durable fsync handling for secure CLI JSON writes. Bound client JSON-RPC input lines and downstream response bookkeeping for malicious or noisy downstream behavior. Implemented with assistance from Codex.
1 parent dca780e commit 0e6583c

6 files changed

Lines changed: 490 additions & 22 deletions

File tree

agentveil_mcp_proxy/cli.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
parse_utc_timestamp,
4444
verify_evidence_bundle_file,
4545
)
46+
from agentveil_mcp_proxy.evidence.proof import _fsync_parent_directory
4647
from agentveil_mcp_proxy.identity import (
4748
IdentityDecryptError,
4849
IdentityError,
@@ -203,18 +204,31 @@ def _secure_write_json(path: Path, data: dict[str, Any], *, force: bool = False)
203204
if force:
204205
tmp_path = path.with_name(f".{path.name}.tmp")
205206
flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
206-
with os.fdopen(os.open(tmp_path, flags, 0o600), "w", encoding="utf-8") as fh:
207-
json.dump(data, fh, indent=2, sort_keys=True)
208-
fh.write("\n")
209-
os.replace(tmp_path, path)
210-
os.chmod(path, 0o600)
207+
try:
208+
with os.fdopen(os.open(tmp_path, flags, 0o600), "w", encoding="utf-8") as fh:
209+
json.dump(data, fh, indent=2, sort_keys=True)
210+
fh.write("\n")
211+
fh.flush()
212+
os.fsync(fh.fileno())
213+
os.replace(tmp_path, path)
214+
os.chmod(path, 0o600)
215+
_fsync_parent_directory(path)
216+
except Exception:
217+
try:
218+
tmp_path.unlink()
219+
except OSError:
220+
pass
221+
raise
211222
return
212223

213224
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
214225
with os.fdopen(os.open(path, flags, 0o600), "w", encoding="utf-8") as fh:
215226
json.dump(data, fh, indent=2, sort_keys=True)
216227
fh.write("\n")
228+
fh.flush()
229+
os.fsync(fh.fileno())
217230
os.chmod(path, 0o600)
231+
_fsync_parent_directory(path)
218232

219233

220234
def _read_json(path: Path, label: str) -> dict[str, Any]:

agentveil_mcp_proxy/evidence/proof.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222

2323
EVIDENCE_EXPORT_SCHEMA_VERSION = 1
24+
_DECISION_RECEIPT_SCHEMAS = frozenset({"decision_receipt/1", "decision_receipt/2"})
2425
_RECEIPT_RECORD_CROSS_CHECK_FIELDS = (
2526
("payload_hash", "payload_hash"),
2627
("risk_class", "client_risk_class"),
@@ -219,11 +220,14 @@ def verify_evidence_bundle(
219220
actual_digest = hashlib.sha256(receipt_jcs.encode("utf-8")).hexdigest()
220221
if digest != actual_digest:
221222
raise EvidenceVerificationError("signed receipt digest mismatch")
222-
verified_bodies[digest] = _verify_receipt_with_pinned_signers(
223+
receipt_body = _verify_receipt_with_pinned_signers(
223224
receipt_jcs,
224225
pinned_signers,
225226
)
227+
_verify_decision_receipt_semantics(receipt_body)
228+
verified_bodies[digest] = receipt_body
226229

230+
referenced_receipt_digests: set[str] = set()
227231
for record in records:
228232
if not isinstance(record, dict):
229233
continue
@@ -232,16 +236,23 @@ def verify_evidence_bundle(
232236
continue
233237
if receipt_digest not in verified_bodies:
234238
continue
239+
referenced_receipt_digests.add(receipt_digest)
235240
receipt_body = verified_bodies[receipt_digest]
236241
for record_field, receipt_field in _RECEIPT_RECORD_CROSS_CHECK_FIELDS:
237242
record_value = record.get(record_field)
238243
receipt_value = receipt_body.get(receipt_field)
239-
if record_value is None or receipt_value is None:
244+
if record_value is None:
240245
continue
246+
if receipt_value is None:
247+
raise EvidenceVerificationError(
248+
f"DecisionReceipt {receipt_field} missing for referenced record"
249+
)
241250
if receipt_value != record_value:
242251
raise EvidenceVerificationError(
243252
f"DecisionReceipt {receipt_field} mismatch with record {record_field}"
244253
)
254+
for digest in sorted(set(verified_bodies) - referenced_receipt_digests):
255+
warnings.append(f"signed receipt {digest[:16]}... not referenced by any record")
245256

246257
return EvidenceVerificationResult(
247258
valid=True,
@@ -313,6 +324,13 @@ def _verify_receipt_with_pinned_signers(
313324
raise EvidenceVerificationError("signed receipt signer is not trusted") from last_error
314325

315326

327+
def _verify_decision_receipt_semantics(body: Mapping[str, Any]) -> None:
328+
if body.get("schema_version") not in _DECISION_RECEIPT_SCHEMAS:
329+
raise EvidenceVerificationError("signed receipt schema unsupported")
330+
if not isinstance(body.get("audit_id"), str) or not body.get("audit_id"):
331+
raise EvidenceVerificationError("signed receipt audit_id missing")
332+
333+
316334
def _compute_unverified_receipt_count(
317335
records: Iterable[Mapping[str, Any]],
318336
receipts: Mapping[str, Any],

agentveil_mcp_proxy/passthrough.py

Lines changed: 142 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@
6060
JSONRPC_DOWNSTREAM_TIMEOUT = -32014
6161
DEFAULT_DOWNSTREAM_RESPONSE_TIMEOUT_SECONDS = 30.0
6262
MAX_DOWNSTREAM_MESSAGE_BYTES = 1 * 1024 * 1024
63+
MAX_CLIENT_MESSAGE_BYTES = 1 * 1024 * 1024
64+
MAX_PENDING_RESPONSES = 1000
65+
DEFAULT_TIMED_OUT_ID_RETENTION_SECONDS = 600.0
6366
SAFE_ENV_KEYS = (
6467
"PATH",
6568
"HOME",
@@ -287,6 +290,46 @@ def jsonrpc_error(
287290
}
288291

289292

293+
def _read_bounded_line(client_in: TextIO, max_bytes: int) -> tuple[str | None, bool]:
294+
read = getattr(client_in, "read", None)
295+
if not callable(read):
296+
try:
297+
raw_line = next(client_in) # type: ignore[arg-type]
298+
except StopIteration:
299+
return None, False
300+
raw_bytes = raw_line.encode("utf-8", errors="replace")
301+
if not raw_line.endswith("\n") or len(raw_bytes.rstrip(b"\n")) > max_bytes:
302+
return "", True
303+
return raw_line, False
304+
305+
chunks: list[str] = []
306+
byte_count = 0
307+
while True:
308+
char = read(1)
309+
if char == "":
310+
if chunks:
311+
return "", True
312+
return None, False
313+
if char == "\n":
314+
return "".join(chunks) + "\n", False
315+
char_size = len(char.encode("utf-8", errors="replace"))
316+
if byte_count + char_size > max_bytes:
317+
_discard_line_remainder(client_in)
318+
return "", True
319+
chunks.append(char)
320+
byte_count += char_size
321+
322+
323+
def _discard_line_remainder(client_in: TextIO) -> None:
324+
read = getattr(client_in, "read", None)
325+
if not callable(read):
326+
return
327+
while True:
328+
char = read(1)
329+
if char in {"", "\n"}:
330+
return
331+
332+
290333
def _blocked_error(
291334
request_id: Any,
292335
message: str,
@@ -356,8 +399,11 @@ def __init__(
356399
self._runtime_gate_startup_error: Exception | None = None
357400
self._runtime_gate_errors = 0
358401
self._downstream_timeouts = 0
402+
self._client_oversized_messages = 0
403+
self._unsolicited_downstream_responses = 0
359404
self._security_events: Deque[Mapping[str, Any]] = deque(maxlen=1000)
360-
self._timed_out_response_ids: set[str] = set()
405+
self._inflight_ids: set[str] = set()
406+
self._timed_out_response_ids: dict[str, float] = {}
361407
self._windows_job: _WindowsJobObject | None = None
362408

363409
@property
@@ -384,6 +430,18 @@ def downstream_timeouts(self) -> int:
384430

385431
return self._downstream_timeouts
386432

433+
@property
434+
def client_oversized_messages(self) -> int:
435+
"""Number of oversized or unterminated client messages rejected."""
436+
437+
return self._client_oversized_messages
438+
439+
@property
440+
def unsolicited_downstream_responses(self) -> int:
441+
"""Number of downstream responses dropped for unknown client request IDs."""
442+
443+
return self._unsolicited_downstream_responses
444+
387445
@property
388446
def security_events(self) -> tuple[Mapping[str, Any], ...]:
389447
"""Sanitized in-memory security events for P5 failure handling."""
@@ -496,7 +554,21 @@ def run_stdio(self, client_in: TextIO, client_out: TextIO) -> int:
496554
self._notification_writer = lambda message: self._write_client(client_out, message)
497555
self.start()
498556
try:
499-
for raw_line in client_in:
557+
while True:
558+
raw_line, rejected = _read_bounded_line(client_in, MAX_CLIENT_MESSAGE_BYTES)
559+
if rejected:
560+
self._increment_client_oversized_messages()
561+
self._write_client(
562+
client_out,
563+
jsonrpc_error(
564+
None,
565+
JSONRPC_INVALID_REQUEST,
566+
"client request exceeds maximum size",
567+
),
568+
)
569+
continue
570+
if raw_line is None:
571+
break
500572
if not raw_line.strip():
501573
continue
502574
responses = self.handle_client_line(raw_line)
@@ -528,12 +600,17 @@ def handle_client_line(self, raw_line: str) -> list[dict[str, Any]]:
528600
policy_error, approval_outcome = self._policy_error_response(classification, request_id)
529601
if policy_error is not None:
530602
return [policy_error] if has_id else []
531-
self._send_downstream(message)
532-
if not has_id:
533-
return []
534-
response = self._wait_downstream_response(request_id)
535-
self._record_approval_result(approval_outcome, response)
536-
return [response]
603+
response_key = self._register_inflight_id(request_id) if has_id else None
604+
try:
605+
self._send_downstream(message)
606+
if not has_id:
607+
return []
608+
response = self._wait_downstream_response(request_id)
609+
self._record_approval_result(approval_outcome, response)
610+
return [response]
611+
finally:
612+
if response_key is not None:
613+
self._unregister_inflight_id(response_key)
537614
except DownstreamTimeoutError:
538615
self._increment_downstream_timeouts()
539616
self._record_approval_error(approval_outcome, "downstream_response_timeout")
@@ -801,11 +878,31 @@ def _increment_downstream_timeouts(self) -> None:
801878
with self._counters_lock:
802879
self._downstream_timeouts += 1
803880

881+
def _increment_client_oversized_messages(self) -> None:
882+
with self._counters_lock:
883+
self._client_oversized_messages += 1
884+
885+
def _increment_unsolicited_downstream_responses(self) -> None:
886+
with self._counters_lock:
887+
self._unsolicited_downstream_responses += 1
888+
889+
def _register_inflight_id(self, request_id: Any) -> str:
890+
response_key = self._id_key(request_id)
891+
with self._stdout_condition:
892+
self._inflight_ids.add(response_key)
893+
return response_key
894+
895+
def _unregister_inflight_id(self, response_key: str) -> None:
896+
with self._stdout_condition:
897+
self._inflight_ids.discard(response_key)
898+
self._prune_pending_responses_locked()
899+
804900
def _wait_downstream_response(self, expected_id: Any) -> dict[str, Any]:
805901
response_key = self._id_key(expected_id)
806902
deadline = time.monotonic() + self.downstream.response_timeout_seconds
807903
with self._stdout_condition:
808904
while True:
905+
self._prune_timed_out_ids_locked()
809906
queued = self._responses.get(response_key)
810907
if queued:
811908
response = queued.pop(0)
@@ -816,7 +913,9 @@ def _wait_downstream_response(self, expected_id: Any) -> dict[str, Any]:
816913
raise self._downstream_error
817914
remaining = deadline - time.monotonic()
818915
if remaining <= 0:
819-
self._timed_out_response_ids.add(response_key)
916+
self._timed_out_response_ids[
917+
response_key
918+
] = time.monotonic() + DEFAULT_TIMED_OUT_ID_RETENTION_SECONDS
820919
raise DownstreamTimeoutError("downstream response timed out")
821920
self._stdout_condition.wait(timeout=remaining)
822921

@@ -911,13 +1010,46 @@ def _handle_downstream_message(self, response: Any) -> None:
9111010
return
9121011
if "id" in response:
9131012
with self._stdout_condition:
1013+
self._prune_timed_out_ids_locked()
9141014
response_key = self._id_key(response.get("id"))
9151015
if response_key in self._timed_out_response_ids:
916-
self._timed_out_response_ids.remove(response_key)
1016+
self._timed_out_response_ids.pop(response_key, None)
1017+
return
1018+
if response_key not in self._inflight_ids:
1019+
self._increment_unsolicited_downstream_responses()
9171020
return
9181021
self._responses.setdefault(response_key, []).append(response)
1022+
self._prune_pending_responses_locked()
9191023
self._stdout_condition.notify_all()
9201024

1025+
def _prune_timed_out_ids_locked(self, now: float | None = None) -> None:
1026+
now = time.monotonic() if now is None else now
1027+
expired = [
1028+
response_key
1029+
for response_key, expires_at in self._timed_out_response_ids.items()
1030+
if expires_at <= now
1031+
]
1032+
for response_key in expired:
1033+
self._timed_out_response_ids.pop(response_key, None)
1034+
1035+
def _prune_pending_responses_locked(self) -> None:
1036+
pending_count = sum(len(responses) for responses in self._responses.values())
1037+
while pending_count > MAX_PENDING_RESPONSES:
1038+
dropped = False
1039+
for response_key, responses in list(self._responses.items()):
1040+
if response_key in self._inflight_ids:
1041+
continue
1042+
if responses:
1043+
responses.pop(0)
1044+
pending_count -= 1
1045+
dropped = True
1046+
if not responses:
1047+
self._responses.pop(response_key, None)
1048+
if pending_count <= MAX_PENDING_RESPONSES:
1049+
return
1050+
if not dropped:
1051+
return
1052+
9211053
def _downstream_buffer_too_large(self, buffer: str) -> bool:
9221054
return len(buffer.encode("utf-8", errors="replace")) > MAX_DOWNSTREAM_MESSAGE_BYTES
9231055

tests/test_mcp_proxy_cli.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
import json
88
import os
99
from pathlib import Path
10+
import stat
1011

1112
import pytest
1213

14+
import agentveil_mcp_proxy.cli as proxy_cli
1315
from agentveil.delegation import verify_delegation
1416
from agentveil_mcp_proxy.cli import (
1517
AGENTVEIL_DEV_SIGNER_DIDS,
@@ -97,6 +99,38 @@ def _replace_grant(
9799
return grant
98100

99101

102+
def test_secure_write_json_fsyncs_before_close(tmp_path, monkeypatch):
103+
calls: list[int] = []
104+
105+
def fake_fsync(fd: int) -> None:
106+
os.fstat(fd)
107+
calls.append(fd)
108+
109+
monkeypatch.setattr(proxy_cli.os, "fsync", fake_fsync)
110+
111+
proxy_cli._secure_write_json(tmp_path / "config.json", {"ok": True})
112+
113+
assert calls
114+
115+
116+
def test_secure_write_json_force_fsyncs_parent_directory_on_posix(tmp_path, monkeypatch):
117+
if os.name == "nt":
118+
pytest.skip("directory fsync is POSIX-specific")
119+
calls: list[bool] = []
120+
121+
def fake_fsync(fd: int) -> None:
122+
calls.append(stat.S_ISDIR(os.fstat(fd).st_mode))
123+
124+
monkeypatch.setattr(proxy_cli.os, "fsync", fake_fsync)
125+
path = tmp_path / "config.json"
126+
proxy_cli._secure_write_json(path, {"old": True})
127+
128+
calls.clear()
129+
proxy_cli._secure_write_json(path, {"new": True}, force=True)
130+
131+
assert calls[-1] is True
132+
133+
100134
def test_init_creates_identity_config_and_control_grant_with_0600(tmp_path):
101135
home = tmp_path / "avp-home"
102136
result = init_proxy(

0 commit comments

Comments
 (0)