Skip to content

Commit aa7d679

Browse files
committed
test(mcp-proxy): tighten circuit breaker hygiene
Bound circuit breaker state-change events, remove the unused cooldown_remaining_seconds property, and align policy.__all__ with the existing package-level ProxyCircuitBreakerConfig export. Adds 54 circuit breaker test cases covering bounded event retention, removed dead-code surface, expanded config validation, success-clears-failures semantics, policy export coverage, and RuntimeGateUntrustedError paths that must not trip the breaker. Implemented with assistance from Codex.
1 parent e0659e0 commit aa7d679

3 files changed

Lines changed: 175 additions & 14 deletions

File tree

agentveil_mcp_proxy/circuit_breaker.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(
8686
self._opened_at: float | None = None
8787
self._half_open_successes = 0
8888
self._state_change_count = 0
89-
self._events: Deque[dict[str, Any]] = deque()
89+
self._events: Deque[dict[str, Any]] = deque(maxlen=1000)
9090

9191
@property
9292
def state(self) -> CircuitState:
@@ -102,16 +102,6 @@ def state_change_count(self) -> int:
102102
with self._lock:
103103
return self._state_change_count
104104

105-
@property
106-
def cooldown_remaining_seconds(self) -> float:
107-
"""Return seconds until an open circuit can move to half-open."""
108-
109-
with self._lock:
110-
if self._state is not CircuitState.OPEN or self._opened_at is None:
111-
return 0.0
112-
elapsed = self._time_func() - self._opened_at
113-
return max(0.0, self.config.cooldown_seconds - elapsed)
114-
115105
def before_call(self) -> None:
116106
"""Allow a backend call or raise when the circuit is open."""
117107

agentveil_mcp_proxy/policy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,7 @@ def builtin_policy_pack(name: str) -> PolicyConfig:
831831
"MAX_RUNTIME_EVENTS",
832832
"PROXY_CONFIG_SCHEMA_VERSION",
833833
"POLICY_SCHEMA_VERSION",
834+
"ProxyCircuitBreakerConfig",
834835
"ProxyConfig",
835836
"ProxyConfigError",
836837
"RiskClass",

tests/test_mcp_proxy_circuit_breaker.py

Lines changed: 173 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from agentveil_mcp_proxy.classification import ToolCallClassifier
2424
from agentveil_mcp_proxy.cli import doctor_proxy, init_proxy
25+
import agentveil_mcp_proxy.policy as policy_module
2526
from agentveil_mcp_proxy.passthrough import (
2627
DownstreamConfig,
2728
JSONRPC_RUNTIME_GATE_UNAVAILABLE,
@@ -143,8 +144,8 @@ def _sign_jcs(body: dict, seed: bytes = BACKEND_SEED) -> str:
143144
return jcs.canonicalize(signed).decode("utf-8")
144145

145146

146-
def _decision_receipt(request: dict, *, seed: bytes = BACKEND_SEED) -> str:
147-
return _sign_jcs({
147+
def _decision_receipt_body(request: dict) -> dict:
148+
return {
148149
"schema_version": "decision_receipt/2",
149150
"audit_id": AUDIT_ID,
150151
"agent_did": AGENT_DID,
@@ -155,7 +156,11 @@ def _decision_receipt(request: dict, *, seed: bytes = BACKEND_SEED) -> str:
155156
"payload_hash": request["payload_hash"],
156157
"client_risk_class": request["risk_class"],
157158
"client_policy_context_hash": request["policy_context_hash"],
158-
}, seed=seed)
159+
}
160+
161+
162+
def _decision_receipt(request: dict, *, seed: bytes = BACKEND_SEED) -> str:
163+
return _sign_jcs(_decision_receipt_body(request), seed=seed)
159164

160165

161166
class RecordingAgent:
@@ -177,6 +182,75 @@ def runtime_evaluate(self, **kwargs):
177182
}
178183

179184

185+
_UNTRUSTED_FIELD_MISMATCHES = {
186+
"field_action_mismatch": "action",
187+
"field_resource_mismatch": "resource",
188+
"field_environment_mismatch": "environment",
189+
"field_payload_hash_mismatch": "payload_hash",
190+
"field_client_risk_class_mismatch": "client_risk_class",
191+
"field_client_policy_context_hash_mismatch": "client_policy_context_hash",
192+
}
193+
194+
_UNTRUSTED_FIELD_MISSING = {
195+
"field_action_missing": "action",
196+
"field_resource_missing": "resource",
197+
"field_environment_missing": "environment",
198+
"field_payload_hash_missing": "payload_hash",
199+
"field_client_risk_class_missing": "client_risk_class",
200+
"field_client_policy_context_hash_missing": "client_policy_context_hash",
201+
}
202+
203+
204+
class UntrustedReceiptAgent:
205+
did = AGENT_DID
206+
207+
def __init__(self, scenario: str):
208+
self.scenario = scenario
209+
self.calls: list[dict] = []
210+
self.receipt_jcs = ""
211+
212+
def runtime_evaluate(self, **kwargs):
213+
self.calls.append(kwargs)
214+
if self.scenario == "receipt_missing_response":
215+
return {"decision": "ALLOW"}
216+
217+
body = _decision_receipt_body(kwargs)
218+
response = {"audit_id": AUDIT_ID, "decision": "ALLOW"}
219+
seed = BACKEND_SEED
220+
221+
if self.scenario == "receipt_fetch_empty":
222+
self.receipt_jcs = ""
223+
return response
224+
if self.scenario == "signature_fail":
225+
seed = OTHER_BACKEND_SEED
226+
elif self.scenario == "schema_unsupported":
227+
body["schema_version"] = "decision_receipt/99"
228+
elif self.scenario == "decision_unsupported":
229+
body["decision"] = "UNKNOWN"
230+
response["decision"] = "UNKNOWN"
231+
elif self.scenario == "audit_id_missing":
232+
body.pop("audit_id")
233+
elif self.scenario == "response_decision_mismatch":
234+
response["decision"] = "BLOCK"
235+
elif self.scenario == "response_audit_id_mismatch":
236+
response["audit_id"] = "urn:uuid:22222222-2222-4222-8222-222222222222"
237+
elif self.scenario == "agent_did_mismatch":
238+
body["agent_did"] = BACKEND_DID
239+
elif self.scenario in _UNTRUSTED_FIELD_MISMATCHES:
240+
body[_UNTRUSTED_FIELD_MISMATCHES[self.scenario]] = "mismatch"
241+
elif self.scenario in _UNTRUSTED_FIELD_MISSING:
242+
body.pop(_UNTRUSTED_FIELD_MISSING[self.scenario])
243+
elif self.scenario != "empty_trusted_signer_dids":
244+
raise AssertionError(f"unhandled untrusted scenario: {self.scenario}")
245+
246+
self.receipt_jcs = _sign_jcs(body, seed=seed)
247+
return {**response, "decision_receipt_jcs": self.receipt_jcs}
248+
249+
def get_decision_receipt(self, audit_id: str) -> str:
250+
assert audit_id == AUDIT_ID
251+
return self.receipt_jcs
252+
253+
180254
def _json_line(message: dict) -> str:
181255
return json.dumps(message, separators=(",", ":")) + "\n"
182256

@@ -241,6 +315,29 @@ def test_circuit_starts_closed():
241315
assert breaker.state_change_count == 0
242316

243317

318+
def test_circuit_has_no_dead_cooldown_remaining_property():
319+
assert not hasattr(CircuitBreaker(), "cooldown_" + "remaining_seconds")
320+
321+
322+
def test_events_deque_bounded_under_repeated_state_changes():
323+
clock = Clock()
324+
breaker = CircuitBreaker(
325+
CircuitBreakerConfig(failures_before_open=1, cooldown_seconds=1),
326+
time_func=clock,
327+
)
328+
329+
for _ in range(334):
330+
breaker.record_failure()
331+
clock.advance(1)
332+
breaker.before_call()
333+
breaker.record_success()
334+
335+
assert len(breaker._events) == 1000
336+
events = breaker.drain_events()
337+
assert events[0]["state_change_count"] == 3
338+
assert events[-1]["state_change_count"] == 1002
339+
340+
244341
def test_circuit_opens_after_threshold_failures():
245342
breaker = CircuitBreaker(CircuitBreakerConfig(failures_before_open=2))
246343

@@ -267,6 +364,20 @@ def test_circuit_failure_window_excludes_old_failures():
267364
assert breaker.state == CircuitState.CLOSED
268365

269366

367+
def test_record_success_clears_failure_window_in_closed_state():
368+
breaker = CircuitBreaker(CircuitBreakerConfig(failures_before_open=5))
369+
370+
for _ in range(3):
371+
breaker.record_failure()
372+
assert breaker.state == CircuitState.CLOSED
373+
374+
breaker.record_success()
375+
376+
for _ in range(4):
377+
breaker.record_failure()
378+
assert breaker.state == CircuitState.CLOSED
379+
380+
270381
def test_circuit_open_raises_runtime_gate_unavailable_immediately():
271382
config = _config()
272383
agent = RecordingAgent()
@@ -401,6 +512,49 @@ def test_runtime_gate_client_does_not_count_untrusted_errors_as_circuit_failures
401512
assert breaker.state_change_count == 0
402513

403514

515+
@pytest.mark.parametrize("scenario", [
516+
"receipt_missing_response",
517+
"receipt_fetch_empty",
518+
"signature_fail",
519+
"schema_unsupported",
520+
"decision_unsupported",
521+
"audit_id_missing",
522+
"response_decision_mismatch",
523+
"response_audit_id_mismatch",
524+
"agent_did_mismatch",
525+
"field_action_mismatch",
526+
"field_resource_mismatch",
527+
"field_environment_mismatch",
528+
"field_payload_hash_mismatch",
529+
"field_client_risk_class_mismatch",
530+
"field_client_policy_context_hash_mismatch",
531+
"field_action_missing",
532+
"field_resource_missing",
533+
"field_environment_missing",
534+
"field_payload_hash_missing",
535+
"field_client_risk_class_missing",
536+
"field_client_policy_context_hash_missing",
537+
"empty_trusted_signer_dids",
538+
])
539+
def test_untrusted_boundary_does_not_trip_circuit_breaker_across_all_paths(scenario):
540+
config = _config()
541+
breaker = CircuitBreaker(CircuitBreakerConfig(failures_before_open=1))
542+
client = RuntimeGateClient(
543+
agent=UntrustedReceiptAgent(scenario),
544+
config=config,
545+
control_grant={"id": "grant"},
546+
circuit_breaker=breaker,
547+
)
548+
if scenario == "empty_trusted_signer_dids":
549+
client.trusted_signer_dids = ()
550+
551+
with pytest.raises(RuntimeGateUntrustedError):
552+
client.evaluate(_classification(config))
553+
554+
assert breaker.state == CircuitState.CLOSED
555+
assert breaker.state_change_count == 0
556+
557+
404558
def test_open_circuit_skips_backend_call_entirely():
405559
config = _config()
406560
agent = RecordingAgent()
@@ -432,6 +586,18 @@ def test_circuit_breaker_config_validates_positive_integers():
432586
_config(circuit_breaker={field: True})
433587

434588

589+
@pytest.mark.parametrize("field", [
590+
"failures_before_open",
591+
"window_seconds",
592+
"cooldown_seconds",
593+
"half_open_test_count",
594+
])
595+
@pytest.mark.parametrize("invalid_value", [-1, 5.0, 5.5, None, "5", [], {}])
596+
def test_circuit_breaker_config_rejects_invalid_types_and_values(field, invalid_value):
597+
with pytest.raises(ProxyConfigError):
598+
_config(circuit_breaker={field: invalid_value})
599+
600+
435601
def test_circuit_breaker_config_rejects_unknown_fields():
436602
with pytest.raises(ProxyConfigError, match="unknown"):
437603
_config(circuit_breaker={"unknown": 1})
@@ -446,6 +612,10 @@ def test_circuit_breaker_config_defaults_when_block_absent():
446612
assert config.circuit_breaker.half_open_test_count == 1
447613

448614

615+
def test_proxy_circuit_breaker_config_in_policy_all():
616+
assert "ProxyCircuitBreakerConfig" in policy_module.__all__
617+
618+
449619
def test_open_circuit_cascades_to_existing_fallback_policy_per_risk_class(tmp_path):
450620
config = _config(fallback={"write": "allow"})
451621
breaker = CircuitBreaker(CircuitBreakerConfig(failures_before_open=1))

0 commit comments

Comments
 (0)