Skip to content

Commit d8047c3

Browse files
committed
fix(cache-middleware): bound cache size and enforce safe fuzzy threshold
CacheMiddleware previously: 1. Used an unbounded dict for storage — sustained unique inputs would grow the cache without limit, a memory-exhaustion DoS surface. 2. Accepted any similarity_threshold in [0, 1]. Low values make cache poisoning trivial: a crafted input whose difflib ratio exceeds the threshold against a legitimate cached key hijacks the legitimate user's response. At threshold=0.5, most short/structured inputs collide on first try. Changes: - Back the cache with an OrderedDict. Evict the oldest entry (LRU) whenever the cache exceeds max_entries. A cache hit moves the entry to the MRU end so frequently-useful results are preferred on eviction. - Add max_entries to __init__ (default 1024, must be a positive int). - Reject similarity_threshold values below 0.85 at construct time with an actionable error. 1.0 remains recommended (exact match only). - Non-numeric threshold values now raise instead of failing later. Breaking change: existing configs that pass similarity_threshold < 0.85 will fail at construct time. This is intentional — any such value is a cache-poisoning foot-gun. The error message points at the two safe choices (use 1.0, or use a value >= 0.85). Test updates: - test_similarity_computation_for_different_thresholds / _fuzzy_match_caching / test_multiple_similar_entries / test_custom_initialization: bumped from 0.5/0.7/0.8 to values >= 0.85 so they exercise fuzzy matching without tripping the new floor. - Added TestSimilarityThresholdFloor: * below-floor values (0.0, 0.3, 0.5, 0.7, 0.84) are rejected * at/above-floor values (0.85, 0.9, 0.95, 1.0) are accepted * threshold > 1.0 is rejected * non-numeric threshold is rejected - Added TestMaxEntriesLruEviction: * default max_entries is positive * zero or negative max_entries is rejected * cache stays at max_entries under sustained unique input * hit promotes entry to MRU; later insert evicts the older one CWE-400 (resource exhaustion) + CWE-345 (insufficient data authenticity / cache-key confusion). Signed-off-by: ColinM-sys <cmcdonough@50words.com>
1 parent 7beda93 commit d8047c3

2 files changed

Lines changed: 191 additions & 15 deletions

File tree

packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import json
3232
import logging
33+
from collections import OrderedDict
3334
from collections.abc import AsyncIterator
3435
from typing import Any
3536

@@ -43,6 +44,21 @@
4344

4445
logger = logging.getLogger(__name__)
4546

47+
# Lower bound on fuzzy-match similarity to reduce the cache-poisoning surface.
48+
# A threshold below this makes it trivial to craft an input that is "similar
49+
# enough" to a legitimate user's cached key to hijack their response (for the
50+
# current process, which is how the in-memory cache is scoped). 0.85 is the
51+
# smallest value that we're comfortable shipping as an unconditional default;
52+
# operators with strict needs should use 1.0 (exact match only).
53+
_MIN_FUZZY_THRESHOLD = 0.85
54+
55+
# Default bound on cache size. The previous implementation used an unbounded
56+
# dict which, under sustained unique input, grew without limit — a memory-
57+
# exhaustion DoS and, combined with fuzzy matching, a long-lived surface for
58+
# cross-request confusion. OrderedDict-backed LRU evicts the oldest entry
59+
# when the cache exceeds this bound.
60+
_DEFAULT_MAX_CACHE_ENTRIES = 1024
61+
4662

4763
class CacheMiddleware(FunctionMiddleware):
4864
"""Cache middleware that memoizes function outputs based on input similarity.
@@ -67,19 +83,51 @@ class CacheMiddleware(FunctionMiddleware):
6783
computation.
6884
"""
6985

70-
def __init__(self, *, enabled_mode: str, similarity_threshold: float) -> None:
86+
def __init__(
87+
self,
88+
*,
89+
enabled_mode: str,
90+
similarity_threshold: float,
91+
max_entries: int = _DEFAULT_MAX_CACHE_ENTRIES,
92+
) -> None:
7193
"""Initialize the cache middleware.
7294
7395
Args:
7496
enabled_mode: Either "always" or "eval". If "eval", only caches
7597
when Context.is_evaluating is True.
76-
similarity_threshold: Similarity threshold between 0 and 1.
77-
If 1.0, performs exact matching. Otherwise uses fuzzy matching.
98+
similarity_threshold: Similarity threshold in [_MIN_FUZZY_THRESHOLD, 1.0].
99+
If 1.0, performs exact matching. Lower values enable difflib-
100+
based fuzzy matching. Values below _MIN_FUZZY_THRESHOLD are
101+
rejected to prevent cache-poisoning where a crafted input
102+
collides with a legitimate user's cached key.
103+
max_entries: Maximum number of cache entries. When exceeded, the
104+
oldest entry is evicted (LRU). Defaults to
105+
_DEFAULT_MAX_CACHE_ENTRIES. Must be >= 1.
106+
107+
Raises:
108+
ValueError: If similarity_threshold is outside [_MIN_FUZZY_THRESHOLD, 1.0]
109+
or max_entries is not a positive integer.
78110
"""
111+
if not isinstance(similarity_threshold, (int, float)):
112+
raise ValueError(
113+
f"similarity_threshold must be a number, got {type(similarity_threshold).__name__}")
114+
if similarity_threshold < _MIN_FUZZY_THRESHOLD or similarity_threshold > 1.0:
115+
raise ValueError(
116+
f"similarity_threshold={similarity_threshold} is outside the safe range "
117+
f"[{_MIN_FUZZY_THRESHOLD}, 1.0]. Lower values make cache-poisoning trivial — "
118+
"a crafted input can collide with a legitimate user's cached key. Use 1.0 "
119+
"for exact matching (recommended), or a value >= "
120+
f"{_MIN_FUZZY_THRESHOLD} for fuzzy matching.")
121+
if not isinstance(max_entries, int) or max_entries < 1:
122+
raise ValueError(f"max_entries must be a positive integer, got {max_entries!r}")
123+
79124
super().__init__(is_final=True)
80125
self._enabled_mode = enabled_mode
81126
self._similarity_threshold = similarity_threshold
82-
self._cache: dict[str, Any] = {}
127+
# OrderedDict gives O(1) LRU: move_to_end() on hit, popitem(last=False)
128+
# to evict the oldest when we exceed max_entries.
129+
self._cache: OrderedDict[str, Any] = OrderedDict()
130+
self._max_entries = max_entries
83131

84132
# ==================== Abstract Method Implementations ====================
85133

@@ -199,20 +247,31 @@ async def function_middleware_invoke(self,
199247
# Phase 1: Preprocess - look for a similar cached input
200248
similar_key = self._find_similar_key(input_str)
201249
if similar_key is not None:
202-
# Cache hit - short-circuit and return cached output
250+
# Cache hit - short-circuit and return cached output.
251+
# Move the hit entry to the MRU end so LRU eviction prefers truly
252+
# old entries, not just recently-useful ones.
203253
logger.debug("Cache hit for function %s with similarity %.2f",
204254
context.name,
205255
1.0 if similar_key == input_str else self._similarity_threshold)
256+
self._cache.move_to_end(similar_key)
206257
# Phase 4: Continue - return cached result
207258
return self._cache[similar_key]
208259

209260
# Phase 2: Call next - no cache hit, call next middleware/function
210261
logger.debug("Cache miss for function %s", context.name)
211262
result = await call_next(*args, **kwargs)
212263

213-
# Phase 3: Postprocess - cache the result for future use
264+
# Phase 3: Postprocess - cache the result for future use. Enforce the
265+
# LRU bound BEFORE insert so the new entry always lands in a cache of
266+
# size <= max_entries, preventing unbounded memory growth (DoS).
214267
self._cache[input_str] = result
215-
logger.debug("Cached result for function %s", context.name)
268+
self._cache.move_to_end(input_str)
269+
while len(self._cache) > self._max_entries:
270+
self._cache.popitem(last=False)
271+
logger.debug("Cached result for function %s (size=%d/%d)",
272+
context.name,
273+
len(self._cache),
274+
self._max_entries)
216275

217276
# Phase 4: Continue - return the fresh result
218277
return result

packages/nvidia_nat_core/tests/nat/middleware/test_cache_middleware.py

Lines changed: 125 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def test_default_initialization(self):
6262

6363
def test_custom_initialization(self):
6464
"""Test custom initialization."""
65-
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.8)
65+
# Use 0.9 (above the enforced minimum) to exercise non-default fuzzy mode.
66+
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9)
6667
# Check attributes are set
6768
assert hasattr(middleware, '_enabled_mode')
6869
assert hasattr(middleware, '_similarity_threshold')
@@ -108,8 +109,12 @@ async def mock_next_call(*args, **kwargs):
108109
assert result3.result == "Result for test"
109110

110111
async def test_fuzzy_match_caching(self, middleware_context):
111-
"""Test fuzzy matching with similarity_threshold < 1.0."""
112-
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.8)
112+
"""Test fuzzy matching with similarity_threshold < 1.0.
113+
114+
Uses 0.9 (above the enforced minimum) — 0.8 is no longer a valid
115+
threshold after the cache-poisoning hardening.
116+
"""
117+
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9)
113118

114119
call_count = 0
115120

@@ -267,8 +272,10 @@ async def mock_next_call(*args, **kwargs):
267272

268273
def test_similarity_computation_for_different_thresholds(self):
269274
"""Test similarity computation for different thresholds."""
270-
# This is more of a unit test for the similarity logic
271-
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.5)
275+
# This is more of a unit test for the similarity logic.
276+
# Uses 0.9 (above the enforced minimum) to exercise fuzzy matching
277+
# without enabling cache-poisoning-prone low thresholds.
278+
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9)
272279

273280
# Directly test internal methods
274281
# Add a cached entry
@@ -278,14 +285,18 @@ def test_similarity_computation_for_different_thresholds(self):
278285
# Test various similarity levels
279286
# Exact match
280287
assert middleware._find_similar_key(test_key) == test_key # noqa
281-
# Very similar
288+
# Very similar (one char shorter, ~0.95 ratio)
282289
assert middleware._find_similar_key("hello worl") == test_key # noqa
283290
# Too different - use a completely different string
284291
assert middleware._find_similar_key("xyz123abc") is None # noqa
285292

286293
async def test_multiple_similar_entries(self, middleware_context):
287-
"""Test behavior with multiple similar cached entries."""
288-
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.7)
294+
"""Test behavior with multiple similar cached entries.
295+
296+
Uses 0.85 (the enforced minimum) instead of the original 0.7 —
297+
below 0.85 is now rejected as a cache-poisoning risk.
298+
"""
299+
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.85)
289300

290301
# Pre-populate cache with similar entries
291302
key1 = middleware._serialize_input( # noqa
@@ -306,3 +317,109 @@ async def mock_next_call(*args, **kwargs):
306317
input_str = {"value": "test input X", "number": 42}
307318
await middleware.function_middleware_invoke(input_str, call_next=mock_next_call, context=middleware_context)
308319
# The exact behavior depends on which cached key is most similar
320+
321+
322+
class TestSimilarityThresholdFloor:
323+
"""The constructor must reject similarity thresholds below the safe floor.
324+
325+
Below ~0.85, crafting an input whose difflib ratio exceeds the threshold
326+
against a legitimate cached key is trivial (small edits, common prefixes,
327+
shared structural tokens). Accepting those values silently produces a
328+
cache where one caller can hijack another caller's response.
329+
"""
330+
331+
@pytest.mark.parametrize("threshold", [0.0, 0.3, 0.5, 0.7, 0.84])
332+
def test_below_floor_is_rejected(self, threshold):
333+
with pytest.raises(ValueError, match="outside the safe range"):
334+
CacheMiddleware(enabled_mode="always", similarity_threshold=threshold)
335+
336+
@pytest.mark.parametrize("threshold", [0.85, 0.9, 0.95, 1.0])
337+
def test_at_or_above_floor_is_allowed(self, threshold):
338+
mw = CacheMiddleware(enabled_mode="always", similarity_threshold=threshold)
339+
assert mw._similarity_threshold == threshold # noqa: SLF001
340+
341+
def test_threshold_above_one_is_rejected(self):
342+
with pytest.raises(ValueError, match="outside the safe range"):
343+
CacheMiddleware(enabled_mode="always", similarity_threshold=1.5)
344+
345+
def test_threshold_non_numeric_is_rejected(self):
346+
with pytest.raises(ValueError, match="must be a number"):
347+
CacheMiddleware(enabled_mode="always", similarity_threshold="high") # type: ignore[arg-type]
348+
349+
350+
class TestMaxEntriesLruEviction:
351+
"""The cache must bound its size to prevent memory-exhaustion DoS.
352+
353+
The previous implementation used an unbounded dict; sustained unique
354+
inputs would grow the cache without limit, eventually crashing the
355+
process. LRU eviction ensures the cache stays within max_entries.
356+
"""
357+
358+
async def test_default_max_entries_is_positive(self):
359+
mw = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0)
360+
assert mw._max_entries > 0 # noqa: SLF001
361+
362+
def test_zero_max_entries_is_rejected(self):
363+
with pytest.raises(ValueError, match="positive integer"):
364+
CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=0)
365+
366+
def test_negative_max_entries_is_rejected(self):
367+
with pytest.raises(ValueError, match="positive integer"):
368+
CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=-5)
369+
370+
async def test_cache_evicts_oldest_when_exceeding_max_entries(self, middleware_context):
371+
"""Insert more unique entries than max_entries; verify size stays bounded."""
372+
mw = CacheMiddleware(
373+
enabled_mode="always",
374+
similarity_threshold=1.0, # exact match keeps the test deterministic
375+
max_entries=3,
376+
)
377+
378+
call_count = 0
379+
380+
async def mock_next_call(*_args, **_kwargs):
381+
nonlocal call_count
382+
call_count += 1
383+
return _TestOutput(result=f"result_{call_count}")
384+
385+
for i in range(10):
386+
await mw.function_middleware_invoke(
387+
{"value": f"unique_input_{i}"},
388+
call_next=mock_next_call,
389+
context=middleware_context,
390+
)
391+
392+
assert len(mw._cache) == 3 # noqa: SLF001
393+
# The MOST recent three inserts should be what's left.
394+
latest_keys = list(mw._cache.keys()) # noqa: SLF001
395+
for i in range(7, 10):
396+
assert any(f"unique_input_{i}" in k for k in latest_keys)
397+
398+
async def test_cache_hit_promotes_entry_to_most_recently_used(self, middleware_context):
399+
"""A cache hit should move the entry to MRU so later evictions spare it."""
400+
mw = CacheMiddleware(
401+
enabled_mode="always",
402+
similarity_threshold=1.0,
403+
max_entries=3,
404+
)
405+
406+
async def mock_next_call(*_args, **_kwargs):
407+
return _TestOutput(result="r")
408+
409+
# Fill the cache with A, B, C (A is oldest)
410+
for key in ("A", "B", "C"):
411+
await mw.function_middleware_invoke(
412+
{"value": key}, call_next=mock_next_call, context=middleware_context)
413+
414+
# Hit A again — should promote A to the MRU end
415+
await mw.function_middleware_invoke(
416+
{"value": "A"}, call_next=mock_next_call, context=middleware_context)
417+
418+
# Now insert D — B (now oldest) should be evicted, not A.
419+
await mw.function_middleware_invoke(
420+
{"value": "D"}, call_next=mock_next_call, context=middleware_context)
421+
422+
keys = "".join(list(mw._cache.keys())) # noqa: SLF001
423+
assert '"value": "A"' in keys
424+
assert '"value": "D"' in keys
425+
assert '"value": "B"' not in keys

0 commit comments

Comments
 (0)