Skip to content

Commit 8ae042f

Browse files
author
amabito
committed
fix(evaluators): budget evaluator R1 -- security hardening + 6 adversarial tests
3-body review findings:

Security:
- Sanitize pipe/equals characters in scope-key metadata values (key-injection prevention)
- Add max_buckets=100K to InMemoryBudgetStore (OOM prevention, fail-closed)
- Block dunder attribute access in _extract_by_path (any path segment starting with an underscore is rejected)
- Add math.isfinite guard on extracted cost values
- Skip per-user rules when the `per` field is missing from metadata (previously this collapsed per-user budgets into the global bucket)

Correctness:
- Changed the exceeded check from > to >= (utilization=100% now triggers exceeded)
- Removed unused BudgetSnapshot import from evaluator.py

Tests (6 adversarial):
- Exact limit boundary (USD and tokens)
- Scope key injection via pipe character
- max_buckets OOM prevention
- Missing per-field skips the rule
- Dunder path rejection

54 budget tests, 284 total evaluator tests passing.
1 parent adaa614 commit 8ae042f

3 files changed

Lines changed: 90 additions & 10 deletions

File tree

evaluators/builtin/src/agent_control_evaluators/budget/evaluator.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
from __future__ import annotations
99

1010
import logging
11+
import math
1112
from datetime import datetime, timezone
1213
from typing import Any
1314

1415
from agent_control_evaluators._base import Evaluator, EvaluatorMetadata
1516
from agent_control_evaluators._registry import register_evaluator
1617
from agent_control_evaluators.budget.config import BudgetEvaluatorConfig, BudgetLimitRule
17-
from agent_control_evaluators.budget.store import BudgetSnapshot, InMemoryBudgetStore
18+
from agent_control_evaluators.budget.store import InMemoryBudgetStore
1819
from agent_control_models import EvaluatorResult
1920

2021
logger = logging.getLogger(__name__)
@@ -40,6 +41,11 @@ def _derive_period_key(window: str | None) -> str:
4041
return ""
4142

4243

44+
# Translation table mapping the scope-key delimiters ("|" separates
# dimensions, "=" separates a dimension name from its value) to "_".
_SCOPE_DELIMITER_TABLE = str.maketrans("|=", "__")


def _sanitize_scope_value(val: str) -> str:
    """Neutralize scope-key delimiters in a scope value.

    Scope keys are assembled as ``"name=value|name=value"``. A value that
    contains ``|`` or ``=`` could therefore forge an extra dimension, so
    both characters are replaced with ``_``.

    Args:
        val: The raw scope value. Annotated as ``str``; any other type is
            coerced with ``str()`` so that a non-string value (for example
            a numeric id) produces a key instead of raising
            ``AttributeError``.

    Returns:
        The value as a string with every ``|`` and ``=`` replaced by ``_``.
    """
    # NOTE(review): the replacement is lossy -- "a|b" and "a_b" map to the
    # same sanitized value and therefore share a bucket. Confirm that this
    # is acceptable for the identifiers used as scope values.
    return str(val).translate(_SCOPE_DELIMITER_TABLE)
47+
48+
4349
def _build_scope_key(
4450
scope: dict[str, str],
4551
per: str | None,
@@ -48,19 +54,22 @@ def _build_scope_key(
4854
"""Build a composite scope key from static dimensions and per-field.
4955
5056
Format: "channel=slack|user_id=u1" or "__global__" if empty.
57+
Values are sanitized to prevent injection via pipe/equals characters.
5158
"""
5259
parts: list[str] = []
5360
for k, v in sorted(scope.items()):
54-
parts.append(f"{k}={v}")
61+
parts.append(f"{k}={_sanitize_scope_value(v)}")
5562
if per and per in metadata:
56-
parts.append(f"{per}={metadata[per]}")
63+
parts.append(f"{per}={_sanitize_scope_value(metadata[per])}")
5764
return "|".join(parts) if parts else "__global__"
5865

5966

6067
def _extract_by_path(data: Any, path: str) -> Any:
6168
"""Extract a value from nested data using dot-notation path."""
6269
current = data
6370
for part in path.split("."):
71+
if part.startswith("_"):
72+
return None
6473
if isinstance(current, dict):
6574
current = current.get(part)
6675
elif hasattr(current, part):
@@ -112,7 +121,7 @@ def _extract_cost(data: Any, cost_path: str | None) -> float | None:
112121
if data is None or cost_path is None:
113122
return None
114123
val = _extract_by_path(data, cost_path)
115-
if isinstance(val, (int, float)) and val >= 0:
124+
if isinstance(val, (int, float)) and math.isfinite(val) and val >= 0:
116125
return float(val)
117126
return None
118127

@@ -286,11 +295,16 @@ async def evaluate(self, data: Any) -> EvaluatorResult:
286295

287296

288297
def _scope_matches(rule: BudgetLimitRule, metadata: dict[str, str]) -> bool:
    """Decide whether a limit rule applies to the given step metadata.

    Every ``key: expected`` pair in ``rule.scope`` must equal the value
    stored under the same key in ``metadata``; a rule with an empty scope
    dict therefore matches any metadata (global rule).

    When ``rule.per`` is set, ``metadata`` must also contain that field.
    Otherwise the rule is skipped, so that per-user budgets do not
    collapse into a single global bucket.

    Args:
        rule: The limit rule; its ``scope`` mapping and ``per`` field are read.
        metadata: Metadata extracted for the current step.

    Returns:
        True if the rule applies to this step, False otherwise.
    """
    dimensions_agree = all(
        metadata.get(dimension) == wanted
        for dimension, wanted in rule.scope.items()
    )
    if not dimensions_agree:
        return False

    per_field = rule.per
    # A rule without a per-field applies as soon as the scope agrees; a rule
    # with one additionally needs that field to be present in the metadata.
    return not per_field or per_field in metadata

evaluators/builtin/src/agent_control_evaluators/budget/store.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,12 @@ class InMemoryBudgetStore:
103103
Redis or Postgres-backed store (separate package).
104104
"""
105105

106-
def __init__(self) -> None:
106+
_DEFAULT_MAX_BUCKETS = 100_000
107+
108+
def __init__(self, *, max_buckets: int = _DEFAULT_MAX_BUCKETS) -> None:
    """Create an empty store.

    Args:
        max_buckets: Upper bound on the number of buckets the store will
            hold (keyword-only). Defaults to ``_DEFAULT_MAX_BUCKETS``.
    """
    # Cap on the number of entries allowed in ``_buckets``.
    self._max_buckets = max_buckets
    # Buckets keyed by a pair of strings.
    # NOTE(review): presumably (scope key, period key) -- confirm against
    # record_and_check.
    self._buckets: dict[tuple[str, str], _Bucket] = {}
    # Lock held while ``_buckets`` is read or modified.
    self._lock = threading.Lock()
109112

110113
def record_and_check(
111114
self,
@@ -135,6 +138,16 @@ def record_and_check(
135138
with self._lock:
136139
bucket = self._buckets.get(key)
137140
if bucket is None:
141+
if len(self._buckets) >= self._max_buckets:
142+
# Fail-closed: treat as exceeded to prevent OOM
143+
return BudgetSnapshot(
144+
spent_usd=cost_usd,
145+
spent_tokens=input_tokens + output_tokens,
146+
limit_usd=limit_usd,
147+
limit_tokens=limit_tokens,
148+
utilization=1.0,
149+
exceeded=True,
150+
)
138151
bucket = _Bucket()
139152
self._buckets[key] = bucket
140153
bucket.spent_usd += cost_usd
@@ -146,9 +159,9 @@ def record_and_check(
146159
bucket.spent_usd, total_tokens, limit_usd, limit_tokens
147160
)
148161
exceeded = False
149-
if limit_usd is not None and bucket.spent_usd > limit_usd:
162+
if limit_usd is not None and bucket.spent_usd >= limit_usd:
150163
exceeded = True
151-
if limit_tokens is not None and total_tokens > limit_tokens:
164+
if limit_tokens is not None and total_tokens >= limit_tokens:
152165
exceeded = True
153166

154167
return BudgetSnapshot(
@@ -185,9 +198,9 @@ def get_snapshot(
185198
bucket.spent_usd, total_tokens, limit_usd, limit_tokens
186199
)
187200
exceeded = False
188-
if limit_usd is not None and bucket.spent_usd > limit_usd:
201+
if limit_usd is not None and bucket.spent_usd >= limit_usd:
189202
exceeded = True
190-
if limit_tokens is not None and total_tokens > limit_tokens:
203+
if limit_tokens is not None and total_tokens >= limit_tokens:
191204
exceeded = True
192205
return BudgetSnapshot(
193206
spent_usd=bucket.spent_usd,

evaluators/builtin/tests/budget/test_budget.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,56 @@ def test_from_dict(self) -> None:
387387
ev = BudgetEvaluator.from_dict({"limits": [{"limit_usd": 5.0}]})
388388
assert isinstance(ev, BudgetEvaluator)
389389
assert ev.config.limits[0].limit_usd == 5.0
390+
391+
392+
# ---------------------------------------------------------------------------
393+
# Security / adversarial tests (R1 findings)
394+
# ---------------------------------------------------------------------------
395+
396+
397+
class TestBudgetAdversarial:
    """Adversarial tests for budget evaluator security."""

    def test_exceeded_at_exact_limit_usd(self) -> None:
        """Spending exactly the limit must trigger exceeded (>= not >)."""
        store = InMemoryBudgetStore()
        snap = store.record_and_check("s", "p", 0, 0, 1.0, limit_usd=1.0)
        assert snap.exceeded is True
        assert snap.utilization == pytest.approx(1.0)

    def test_exceeded_at_exact_limit_tokens(self) -> None:
        """Input + output tokens summing exactly to the limit must trigger exceeded."""
        store = InMemoryBudgetStore()
        snap = store.record_and_check("s", "p", 500, 500, 0.0, limit_tokens=1000)
        assert snap.exceeded is True

    def test_scope_key_injection_pipe(self) -> None:
        """Pipe in metadata value must be sanitized, not create new scope dimension."""
        key = _build_scope_key({"ch": "slack"}, "uid", {"ch": "slack", "uid": "u1|ch=admin"})
        # Compare the full list of dimensions: checking for "|" inside an
        # element of key.split("|") can never fail, so it proves nothing.
        assert key.split("|") == ["ch=slack", "uid=u1_ch_admin"]

    def test_max_buckets_prevents_oom(self) -> None:
        """Once max_buckets is reached, new scopes are rejected fail-closed."""
        store = InMemoryBudgetStore(max_buckets=5)
        snaps = [
            store.record_and_check(f"scope-{i}", "p", 1, 1, 0.01, limit_usd=100.0)
            for i in range(10)
        ]
        # The store never grows past the cap ...
        assert len(store._buckets) == 5
        # ... the first five scopes are tracked normally (well under the limit) ...
        assert not any(snap.exceeded for snap in snaps[:5])
        # ... and every scope after the cap is reported as exceeded.
        assert all(snap.exceeded for snap in snaps[5:])

    @pytest.mark.asyncio
    async def test_per_without_metadata_skips_rule(self) -> None:
        """per='user_id' but user_id missing -> rule skipped, not global."""
        from agent_control_evaluators.budget.evaluator import BudgetEvaluator

        config = BudgetEvaluatorConfig(
            limits=[{"scope": {}, "per": "user_id", "limit_usd": 1.0}],
            cost_path="cost",
            metadata_paths={"user_id": "user_id"},
        )
        ev = BudgetEvaluator(config)
        # No user_id in data -> rule skipped -> not matched (no applicable rules)
        result = await ev.evaluate({"cost": 999.0})
        assert result.matched is False

    def test_extract_by_path_rejects_dunder(self) -> None:
        """Dunder path segments must be rejected for dicts and for objects."""
        from agent_control_evaluators.budget.evaluator import _extract_by_path

        assert _extract_by_path({"a": 1}, "__class__") is None
        assert _extract_by_path({"a": {"__init__": 1}}, "a.__init__") is None
        # Attribute access on a non-dict object is the case the guard exists
        # for: without it, hasattr/getattr would resolve "__class__".
        assert _extract_by_path(object(), "__class__") is None
        assert _extract_by_path({"a": object()}, "a.__class__") is None

0 commit comments

Comments
 (0)