Skip to content

Commit b189e85

Browse files
author
amabito
committed
feat(budget): address R3 review -- async ABC, TTL prune, defensive guards
Respond to lan17's R3 review on PR #144 with the mechanical items that do not depend on pending config-layer decisions (limit model, budget_id, unknown_model_behavior). Changes: - Migrate BudgetStore from Protocol to async ABC with __init_subclass__ guard that walks the MRO to reject sync overrides at class creation - InMemoryBudgetStore: async wrapper around sync helper, threading.Lock retained for CPU-bound critical section - TTL prune for stale period buckets on rollover, runs before max_buckets capacity check so rollover at capacity reclaims space - Monotonic prune watermark (a backwards clock is a no-op and never moves the watermark back) - _compute_utilization low-side clamp to [0.0, 1.0] (refund semantics) - Defensive guards: NaN/Inf cost and clock coerced to 0.0, negative token counts clamped to 0 - Revert root pyproject.toml (remove unrelated [dependency-groups], restore version 7.3.1) - Remove clear_budget_stores from __all__ (testing utility) - Document token attribution intent (single int -> output-only) Tests: 67 -> 91 (24 new: async migration, TTL prune coverage, adversarial guards, ABC contract enforcement)
1 parent b4d4d24 commit b189e85

File tree

6 files changed

+755
-73
lines changed

6 files changed

+755
-73
lines changed
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
"""Budget evaluator for per-agent LLM cost and token tracking."""
22

33
from agent_control_evaluator_budget.budget.config import BudgetEvaluatorConfig
4-
from agent_control_evaluator_budget.budget.evaluator import (
5-
BudgetEvaluator,
6-
clear_budget_stores,
7-
)
4+
from agent_control_evaluator_budget.budget.evaluator import BudgetEvaluator
85
from agent_control_evaluator_budget.budget.memory_store import InMemoryBudgetStore
96
from agent_control_evaluator_budget.budget.store import BudgetSnapshot, BudgetStore
107

8+
# Note: clear_budget_stores is a testing utility and is intentionally not
9+
# re-exported here. Import it directly from the evaluator submodule in tests:
10+
# from agent_control_evaluator_budget.budget.evaluator import clear_budget_stores
11+
1112
__all__ = [
1213
"BudgetEvaluator",
1314
"BudgetEvaluatorConfig",
1415
"BudgetSnapshot",
1516
"BudgetStore",
1617
"InMemoryBudgetStore",
17-
"clear_budget_stores",
1818
]

evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ def _extract_tokens(data: Any, token_path: str | None) -> tuple[int, int]:
109109
if token_path:
110110
val = _extract_by_path(data, token_path)
111111
if isinstance(val, int) and not isinstance(val, bool) and val >= 0:
112+
# When token_path resolves to a single int we cannot distinguish
113+
# input vs output. Attribute the whole count to output because
114+
# output rates are typically higher than input rates in pricing
115+
# tables, so this over-estimates cost rather than under-estimates.
112116
return 0, val
113117
if isinstance(val, dict):
114118
data = val
@@ -211,7 +215,7 @@ async def evaluate(self, data: Any) -> EvaluatorResult:
211215
step_metadata = _extract_metadata(data, self.config.metadata_paths)
212216

213217
store = get_or_create_store(self.config)
214-
snapshots = store.record_and_check(
218+
snapshots = await store.record_and_check(
215219
scope=step_metadata,
216220
input_tokens=input_tokens,
217221
output_tokens=output_tokens,

evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66

77
from __future__ import annotations
88

9+
import math
910
import threading
1011
import time
1112
from collections.abc import Callable
1213
from dataclasses import dataclass
1314

1415
from .config import BudgetLimitRule
15-
from .store import BudgetSnapshot, round_spent
16+
from .store import BudgetSnapshot, BudgetStore, round_spent
1617

1718

1819
def _sanitize_scope_value(val: str) -> str:
@@ -34,6 +35,20 @@ def _build_scope_key(
3435
return "|".join(parts) if parts else "__global__"
3536

3637

38+
def _parse_period_key(key: str) -> tuple[int, int] | None:
39+
"""Parse 'P{window}:{index}' into (window_seconds, bucket_index).
40+
41+
Returns None for empty/cumulative keys.
42+
"""
43+
if not key or not key.startswith("P"):
44+
return None
45+
try:
46+
window_part, index_part = key[1:].split(":", 1)
47+
return int(window_part), int(index_part)
48+
except (ValueError, IndexError):
49+
return None
50+
51+
3752
def _derive_period_key(window_seconds: int | None, now: float) -> str:
3853
"""Derive a period key from window_seconds and a timestamp.
3954
@@ -63,12 +78,17 @@ def _compute_utilization(
6378
limit: int | None,
6479
limit_tokens: int | None,
6580
) -> float:
66-
"""Return max(spend_ratio, token_ratio) clamped to [0.0, 1.0]."""
81+
"""Return max(spend_ratio, token_ratio) clamped to [0.0, 1.0].
82+
83+
The low-side clamp is load-bearing: under refund semantics the internal
84+
`spent` accumulator may go negative, which would otherwise produce a
85+
negative ratio and violate the BudgetSnapshot.utilization contract.
86+
"""
6787
ratios: list[float] = []
6888
if limit is not None and limit > 0:
69-
ratios.append(min(spent / limit, 1.0))
89+
ratios.append(max(0.0, min(spent / limit, 1.0)))
7090
if limit_tokens is not None and limit_tokens > 0:
71-
ratios.append(min(spent_tokens / limit_tokens, 1.0))
91+
ratios.append(max(0.0, min(spent_tokens / limit_tokens, 1.0)))
7292
return max(ratios) if ratios else 0.0
7393

7494

@@ -85,14 +105,21 @@ def total_tokens(self) -> int:
85105
return self.input_tokens + self.output_tokens
86106

87107

88-
class InMemoryBudgetStore:
108+
class InMemoryBudgetStore(BudgetStore):
89109
"""Thread-safe in-memory budget store.
90110
91111
Initialized with a list of BudgetLimitRule. Derives period keys
92112
internally from window_seconds + injected clock.
93113
94114
Cost is accumulated as float for precision. Integer rounding
95115
happens only at snapshot time for display/reporting.
116+
117+
TTL prune: on new period rollover per window, buckets older than
118+
`current - 1` for that window are dropped. This keeps memory bounded
119+
for long-running deployments with windowed rules.
120+
121+
`max_buckets` remains as a backstop for high-cardinality group_by
122+
explosions that TTL cannot protect against.
96123
"""
97124

98125
_DEFAULT_MAX_BUCKETS = 100_000
@@ -109,16 +136,41 @@ def __init__(
109136
self._lock = threading.Lock()
110137
self._buckets: dict[tuple[str, str], _Bucket] = {}
111138
self._max_buckets = max_buckets
139+
self._last_pruned_period: dict[int, int] = {}
112140

113-
def record_and_check(
141+
async def record_and_check(
    self,
    scope: dict[str, str],
    input_tokens: int,
    output_tokens: int,
    cost: float,
) -> list[BudgetSnapshot]:
    """Record one usage event and return the resulting budget snapshots.

    This coroutine is a thin async entry point: it forwards its arguments
    unchanged to ``_record_and_check_sync`` and returns that helper's
    result. Nothing is awaited here, so the call runs to completion
    without yielding to the event loop.

    Args:
        scope: Scope metadata for the step being recorded.
        input_tokens: Input token count for the step.
        output_tokens: Output token count for the step.
        cost: Cost to add to the accumulated spend.

    Returns:
        The snapshots produced by the synchronous helper.
    """
    snapshots = self._record_and_check_sync(
        scope,
        input_tokens,
        output_tokens,
        cost,
    )
    return snapshots
150+
151+
def _record_and_check_sync(
152+
self,
153+
scope: dict[str, str],
154+
input_tokens: int,
155+
output_tokens: int,
156+
cost: float,
157+
) -> list[BudgetSnapshot]:
158+
"""Sync implementation of record_and_check.
159+
160+
NaN/Inf cost is coerced to 0.0 defensively. Once NaN enters a
161+
bucket's float accumulator, all subsequent additions produce NaN
162+
and `nan >= limit` is always False (IEEE 754), permanently
163+
disabling budget enforcement for that bucket.
164+
"""
165+
if not math.isfinite(cost):
166+
cost = 0.0
167+
# Token counts have no refund semantics; clamp to non-negative
168+
# to prevent negative injection from resetting the accumulator.
169+
input_tokens = max(0, input_tokens)
170+
output_tokens = max(0, output_tokens)
121171
now = self._clock()
172+
if not math.isfinite(now):
173+
now = 0.0
122174
snapshots: list[BudgetSnapshot] = []
123175
recorded_pairs: set[tuple[str, str]] = set()
124176

@@ -152,8 +204,14 @@ def record_and_check(
152204
recorded_pairs.add(pair)
153205
else:
154206
bucket = self._buckets.get(pair)
155-
if bucket is None:
156-
continue
207+
# Defensive: this branch is unreachable under current
208+
# invariants (recorded_pairs only contains pairs whose
209+
# bucket was successfully created, and self._lock prevents
210+
# concurrent deletion). If a future refactor violates
211+
# this, the assertion surfaces it.
212+
assert bucket is not None, (
213+
f"bucket for {pair!r} was in recorded_pairs but missing from _buckets"
214+
)
157215

158216
total_tokens = bucket.total_tokens
159217
utilization = _compute_utilization(
@@ -219,6 +277,7 @@ def reset(self, scope_key: str | None = None, period_key: str | None = None) ->
219277
with self._lock:
220278
if scope_key is None and period_key is None:
221279
self._buckets.clear()
280+
self._last_pruned_period.clear()
222281
return
223282
keys_to_remove = [
224283
k
@@ -230,10 +289,44 @@ def reset(self, scope_key: str | None = None, period_key: str | None = None) ->
230289
del self._buckets[k]
231290

232291
def _get_or_create_bucket(self, key: tuple[str, str]) -> _Bucket | None:
233-
"""Get or create a bucket. Returns None if max_buckets reached."""
292+
"""Get or create a bucket. Returns None if max_buckets reached.
293+
294+
On period rollover (new windowed bucket with a forward period index),
295+
stale buckets for the same window (bucket_index < current - 1) are
296+
pruned BEFORE the max_buckets capacity check, so that a rollover at
297+
capacity can free space rather than fail closed. Cross-scope pruning
298+
is intentional: all stale same-window buckets are dropped regardless
299+
of scope key, since the period has expired globally.
300+
301+
The watermark `_last_pruned_period[window]` only advances forward;
302+
a backwards clock does not trigger spurious prune work.
303+
304+
Caller must hold self._lock.
305+
"""
234306
bucket = self._buckets.get(key)
235307
if bucket is not None:
236308
return bucket
309+
310+
# TTL prune runs BEFORE the max_buckets check so that rollover at
311+
# capacity can reclaim space rather than fail closed permanently.
312+
parsed = _parse_period_key(key[1])
313+
if parsed is not None:
314+
window, index = parsed
315+
last_pruned = self._last_pruned_period.get(window)
316+
# Only advance on forward progress. Backwards clock is a no-op;
317+
# the previously established watermark still protects us.
318+
if last_pruned is None or index > last_pruned:
319+
stale_keys = [
320+
k
321+
for k in self._buckets
322+
if (kp := _parse_period_key(k[1])) is not None
323+
and kp[0] == window
324+
and kp[1] < index - 1
325+
]
326+
for k in stale_keys:
327+
del self._buckets[k]
328+
self._last_pruned_period[window] = index
329+
237330
if len(self._buckets) >= self._max_buckets:
238331
return None
239332
bucket = _Bucket()

evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""BudgetStore protocol -- interface for budget storage backends.
1+
"""BudgetStore abstract base class -- interface for budget storage backends.
22
33
Implementations must provide atomic record-and-check: a single call
44
that records usage and returns the current totals. This prevents
@@ -10,9 +10,11 @@
1010

1111
from __future__ import annotations
1212

13+
import inspect
1314
import math
15+
from abc import ABC, abstractmethod
1416
from dataclasses import dataclass
15-
from typing import Protocol, runtime_checkable
17+
from typing import Any
1618

1719

1820
@dataclass(frozen=True)
@@ -49,17 +51,56 @@ def round_spent(value: float) -> int:
4951
return int(value)
5052

5153

52-
@runtime_checkable
53-
class BudgetStore(Protocol):
54-
"""Protocol for budget storage backends.
54+
class BudgetStore(ABC):
55+
"""Abstract base class for budget storage backends.
5556
5657
The store is initialized with a list of BudgetLimitRule and derives
5758
period keys internally from window_seconds + current time.
5859
5960
Callers pass only usage data: scope dict, input_tokens, output_tokens, cost.
61+
62+
Negative `cost` values are permitted and reduce accumulated spend (refund
63+
semantics). `round_spent()` floors the displayed snapshot spend to 0 for
64+
negative accumulators, but the internal float accumulator may go negative
65+
so that a subsequent positive charge cancels correctly. Validation of
66+
cost >= 0 is NOT performed at the store boundary; it is the caller's
67+
responsibility if strict positive accounting is required.
68+
69+
Implementations should be safe to call from async contexts.
70+
InMemoryBudgetStore wraps a sync critical section under threading.Lock
71+
because the work is CPU-bound and brief; distributed backends
72+
(Redis/Postgres) should use native async I/O.
73+
74+
Subclasses must override `record_and_check` with a coroutine function
75+
(`async def`). A sync override is rejected at class creation time rather
76+
than failing silently at the first `await` site in production.
6077
"""
6178

62-
def record_and_check(
79+
def __init_subclass__(cls, **kwargs: Any) -> None:
80+
super().__init_subclass__(**kwargs)
81+
# Walk the MRO to find the nearest override of record_and_check.
82+
# Checking only cls.__dict__ misses mixin-inherited sync overrides
83+
# that satisfy ABC's abstractmethod check but silently break at the
84+
# first `await` call site.
85+
method = None
86+
for base in cls.__mro__:
87+
if base is BudgetStore:
88+
break
89+
if "record_and_check" in base.__dict__:
90+
raw = base.__dict__["record_and_check"]
91+
# Unwrap staticmethod/classmethod descriptors so that
92+
# inspect.iscoroutinefunction sees the underlying function.
93+
method = getattr(raw, "__func__", raw)
94+
break
95+
if method is not None and not inspect.iscoroutinefunction(method):
96+
raise TypeError(
97+
f"{cls.__name__}.record_and_check must be an async def "
98+
"(coroutine function); got a sync function. BudgetStore is "
99+
"an async ABC."
100+
)
101+
102+
@abstractmethod
103+
async def record_and_check(
63104
self,
64105
scope: dict[str, str],
65106
input_tokens: int,
@@ -77,4 +118,3 @@ def record_and_check(
77118
Returns:
78119
List of BudgetSnapshot, one per matching rule.
79120
"""
80-
...

0 commit comments

Comments
 (0)