Skip to content

Commit ec97055

Browse files
committed
refactor(cache-middleware): address dagardner-nv review feedback
Per design discussion in #1888:
- Drop the similarity_threshold floor from 0.85 to 0 and remove manual type/range validation from the constructor; Pydantic enforces bounds at the config layer (CacheMiddlewareConfig), and callers constructing CacheMiddleware directly are responsible for passing valid values.
- Update CacheMiddlewareConfig.similarity_threshold to ge=0, with an updated description documenting the performance cost of low values and the cache-collision risk near 0.
- Replace the manual difflib.SequenceMatcher loop in _find_similar_key with difflib.get_close_matches for cleaner, more idiomatic code.
- Remove constructor-level validation tests that no longer apply.

Signed-off-by: ColinM-sys <cmcdonough@50words.com>
1 parent 79e3eff commit ec97055

3 files changed

Lines changed: 27 additions & 140 deletions

File tree

packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware.py

Lines changed: 13 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
import json
3232
import logging
33-
import math
3433
from collections import OrderedDict
3534
from collections.abc import AsyncIterator
3635
from typing import Any
@@ -45,13 +44,6 @@
4544

4645
logger = logging.getLogger(__name__)
4746

48-
# Lower bound on fuzzy-match similarity to reduce the cache-poisoning surface.
49-
# A threshold below this makes it trivial to craft an input that is "similar
50-
# enough" to a legitimate user's cached key to hijack their response (for the
51-
# current process, which is how the in-memory cache is scoped). 0.85 is the
52-
# smallest value that we're comfortable shipping as an unconditional default;
53-
# operators with strict needs should use 1.0 (exact match only).
54-
_MIN_FUZZY_THRESHOLD = 0.85
5547

5648
# Default bound on cache size. The previous implementation used an unbounded
5749
# dict which, under sustained unique input, grew without limit — a memory-
@@ -96,43 +88,17 @@ def __init__(
9688
Args:
9789
enabled_mode: Either "always" or "eval". If "eval", only caches
9890
when Context.is_evaluating is True.
99-
similarity_threshold: Similarity threshold in [_MIN_FUZZY_THRESHOLD, 1.0].
100-
If 1.0, performs exact matching. Lower values enable difflib-
101-
based fuzzy matching. Values below _MIN_FUZZY_THRESHOLD are
102-
rejected to prevent cache-poisoning where a crafted input
103-
collides with a legitimate user's cached key.
91+
similarity_threshold: Similarity threshold in [0, 1.0]. If 1.0,
92+
performs exact matching. Lower values enable difflib-based
93+
fuzzy matching; note that difflib is quadratic in the worst
94+
case, so large caches with low thresholds may have a
95+
performance cost. Values near 0 increase the risk of cache
96+
collisions where different inputs return the same cached
97+
response.
10498
max_entries: Maximum number of cache entries. When exceeded, the
10599
oldest entry is evicted (LRU). Defaults to
106-
_DEFAULT_MAX_CACHE_ENTRIES. Must be >= 1.
107-
108-
Raises:
109-
ValueError: If similarity_threshold is outside [_MIN_FUZZY_THRESHOLD, 1.0]
110-
or max_entries is not a positive integer.
100+
_DEFAULT_MAX_CACHE_ENTRIES.
111101
"""
112-
# Reject bool explicitly — `isinstance(True, int)` is True in Python,
113-
# and `True`/`False` silently sneaking through as numeric is a classic
114-
# config bug (user passes the wrong key, gets no error). Check bool
115-
# FIRST so the "must be a number" message doesn't lie.
116-
if isinstance(similarity_threshold, bool):
117-
raise ValueError(
118-
f"similarity_threshold must be a number, got bool ({similarity_threshold!r})")
119-
if not isinstance(similarity_threshold, (int, float)):
120-
raise ValueError(
121-
f"similarity_threshold must be a number, got {type(similarity_threshold).__name__}")
122-
if not math.isfinite(similarity_threshold):
123-
raise ValueError(
124-
f"similarity_threshold must be finite, got {similarity_threshold!r}")
125-
if similarity_threshold < _MIN_FUZZY_THRESHOLD or similarity_threshold > 1.0:
126-
raise ValueError(
127-
f"similarity_threshold={similarity_threshold} is outside the safe range "
128-
f"[{_MIN_FUZZY_THRESHOLD}, 1.0]. Lower values make cache-poisoning trivial — "
129-
"a crafted input can collide with a legitimate user's cached key. Use 1.0 "
130-
"for exact matching (recommended), or a value >= "
131-
f"{_MIN_FUZZY_THRESHOLD} for fuzzy matching.")
132-
# Same bool-as-int foot-gun applies to max_entries.
133-
if isinstance(max_entries, bool) or not isinstance(max_entries, int) or max_entries < 1:
134-
raise ValueError(f"max_entries must be a positive integer, got {max_entries!r}")
135-
136102
super().__init__(is_final=True)
137103
self._enabled_mode = enabled_mode
138104
self._similarity_threshold = similarity_threshold
@@ -202,22 +168,13 @@ def _find_similar_key(self, input_str: str) -> str | None:
202168
# Exact matching - fast path
203169
return input_str if input_str in self._cache else None
204170

205-
# Fuzzy matching using difflib
206171
import difflib
207172

208-
best_match = None
209-
best_ratio = 0.0
210-
211-
for cached_key in self._cache:
212-
# Use SequenceMatcher for similarity computation
213-
matcher = difflib.SequenceMatcher(None, input_str, cached_key)
214-
ratio = matcher.ratio()
215-
216-
if ratio >= self._similarity_threshold and ratio > best_ratio:
217-
best_ratio = ratio
218-
best_match = cached_key
219-
220-
return best_match
173+
best_matches = difflib.get_close_matches(
174+
input_str, self._cache.keys(), n=1, cutoff=self._similarity_threshold)
175+
if best_matches:
176+
return best_matches[0]
177+
return None
221178

222179
async def function_middleware_invoke(self,
223180
*args: Any,

packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware_config.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from nat.data_models.middleware import FunctionMiddlewareBaseConfig
2222
from nat.middleware.cache.cache_middleware import _DEFAULT_MAX_CACHE_ENTRIES
23-
from nat.middleware.cache.cache_middleware import _MIN_FUZZY_THRESHOLD
2423

2524

2625
class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
@@ -33,13 +32,13 @@ class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
3332
enabled_mode: Controls when caching is active:
3433
- "always": Cache is always enabled
3534
- "eval": Cache only active when Context.is_evaluating is True
36-
similarity_threshold: Float in [_MIN_FUZZY_THRESHOLD, 1.0] for input
37-
matching:
35+
similarity_threshold: Float in [0, 1.0] for input matching:
3836
- 1.0: Exact string matching (fastest, recommended)
39-
- >= _MIN_FUZZY_THRESHOLD: Fuzzy matching via difflib. Values
40-
below this bound are rejected as a cache-poisoning risk —
41-
crafted inputs at lower thresholds can collide with a
42-
legitimate user's cached key.
37+
- < 1.0: Fuzzy matching via difflib. Note that difflib is
38+
quadratic in the worst case, so large caches with low
39+
thresholds may have a performance cost. Values near 0
40+
increase the risk of cache collisions where different
41+
inputs return the same cached response.
4342
max_entries: Upper bound on cached entries. When exceeded, the
4443
least-recently-used entry is evicted. Must be a positive int;
4544
defaults to _DEFAULT_MAX_CACHE_ENTRIES.
@@ -49,12 +48,14 @@ class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
4948
default="eval", description="When caching is enabled: 'always' or 'eval' (only during evaluation)")
5049
similarity_threshold: float = Field(
5150
default=1.0,
52-
ge=_MIN_FUZZY_THRESHOLD,
51+
ge=0,
5352
le=1.0,
5453
description=(
55-
f"Similarity threshold in [{_MIN_FUZZY_THRESHOLD}, 1.0]. Use 1.0 for exact matching "
56-
"(recommended). Lower values enable fuzzy matching but are bounded below to prevent "
57-
"cache-poisoning collisions with legitimate cached keys."),
54+
"Similarity threshold in [0, 1.0]. Use 1.0 for exact matching (recommended). "
55+
"Lower values enable fuzzy matching via difflib; note that difflib is quadratic "
56+
"in the worst case, so large caches with low thresholds may have a performance "
57+
"cost. Values near 0 increase the risk of cache collisions where different "
58+
"inputs return the same cached response."),
5859
)
5960
max_entries: int = Field(
6061
default=_DEFAULT_MAX_CACHE_ENTRIES,

packages/nvidia_nat_core/tests/nat/middleware/test_cache_middleware.py

Lines changed: 2 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def test_default_initialization(self):
6262

6363
def test_custom_initialization(self):
6464
"""Test custom initialization."""
65-
# Use 0.9 (above the enforced minimum) to exercise non-default fuzzy mode.
6665
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9)
6766
# Check attributes are set
6867
assert hasattr(middleware, '_enabled_mode')
@@ -109,11 +108,7 @@ async def mock_next_call(*args, **kwargs):
109108
assert result3.result == "Result for test"
110109

111110
async def test_fuzzy_match_caching(self, middleware_context):
112-
"""Test fuzzy matching with similarity_threshold < 1.0.
113-
114-
Uses 0.9 (above the enforced minimum) — 0.8 is no longer a valid
115-
threshold after the cache-poisoning hardening.
116-
"""
111+
"""Test fuzzy matching with similarity_threshold < 1.0."""
117112
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9)
118113

119114
call_count = 0
@@ -272,9 +267,6 @@ async def mock_next_call(*args, **kwargs):
272267

273268
def test_similarity_computation_for_different_thresholds(self):
274269
"""Test similarity computation for different thresholds."""
275-
# This is more of a unit test for the similarity logic.
276-
# Uses 0.9 (above the enforced minimum) to exercise fuzzy matching
277-
# without enabling cache-poisoning-prone low thresholds.
278270
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9)
279271

280272
# Directly test internal methods
@@ -291,11 +283,7 @@ def test_similarity_computation_for_different_thresholds(self):
291283
assert middleware._find_similar_key("xyz123abc") is None # noqa
292284

293285
async def test_multiple_similar_entries(self, middleware_context):
294-
"""Test behavior with multiple similar cached entries.
295-
296-
Uses 0.85 (the enforced minimum) instead of the original 0.7 —
297-
below 0.85 is now rejected as a cache-poisoning risk.
298-
"""
286+
"""Test behavior with multiple similar cached entries."""
299287
middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.85)
300288

301289
# Pre-populate cache with similar entries
@@ -319,47 +307,6 @@ async def mock_next_call(*args, **kwargs):
319307
# The exact behavior depends on which cached key is most similar
320308

321309

322-
class TestSimilarityThresholdFloor:
323-
"""The constructor must reject similarity thresholds below the safe floor.
324-
325-
Below ~0.85, crafting an input whose difflib ratio exceeds the threshold
326-
against a legitimate cached key is trivial (small edits, common prefixes,
327-
shared structural tokens). Accepting those values silently produces a
328-
cache where one caller can hijack another caller's response.
329-
"""
330-
331-
@pytest.mark.parametrize("threshold", [0.0, 0.3, 0.5, 0.7, 0.84])
332-
def test_below_floor_is_rejected(self, threshold):
333-
with pytest.raises(ValueError, match="outside the safe range"):
334-
CacheMiddleware(enabled_mode="always", similarity_threshold=threshold)
335-
336-
@pytest.mark.parametrize("threshold", [0.85, 0.9, 0.95, 1.0])
337-
def test_at_or_above_floor_is_allowed(self, threshold):
338-
mw = CacheMiddleware(enabled_mode="always", similarity_threshold=threshold)
339-
assert mw._similarity_threshold == threshold # noqa: SLF001
340-
341-
def test_threshold_above_one_is_rejected(self):
342-
with pytest.raises(ValueError, match="outside the safe range"):
343-
CacheMiddleware(enabled_mode="always", similarity_threshold=1.5)
344-
345-
def test_threshold_non_numeric_is_rejected(self):
346-
with pytest.raises(ValueError, match="must be a number"):
347-
CacheMiddleware(enabled_mode="always", similarity_threshold="high") # type: ignore[arg-type]
348-
349-
@pytest.mark.parametrize("bad_bool", [True, False])
350-
def test_threshold_bool_is_rejected(self, bad_bool):
351-
"""`isinstance(True, int)` is True in Python — reject bools explicitly
352-
so a config with the wrong key type doesn't silently become 1.0 or 0.0."""
353-
with pytest.raises(ValueError, match="got bool"):
354-
CacheMiddleware(enabled_mode="always", similarity_threshold=bad_bool) # type: ignore[arg-type]
355-
356-
@pytest.mark.parametrize("bad_value", [float("nan"), float("inf"), float("-inf")])
357-
def test_threshold_non_finite_is_rejected(self, bad_value):
358-
"""NaN, +inf, -inf must be rejected before the range comparison."""
359-
with pytest.raises(ValueError, match="must be finite"):
360-
CacheMiddleware(enabled_mode="always", similarity_threshold=bad_value)
361-
362-
363310
class TestMaxEntriesLruEviction:
364311
"""The cache must bound its size to prevent memory-exhaustion DoS.
365312
@@ -372,24 +319,6 @@ async def test_default_max_entries_is_positive(self):
372319
mw = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0)
373320
assert mw._max_entries > 0 # noqa: SLF001
374321

375-
def test_zero_max_entries_is_rejected(self):
376-
with pytest.raises(ValueError, match="positive integer"):
377-
CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=0)
378-
379-
def test_negative_max_entries_is_rejected(self):
380-
with pytest.raises(ValueError, match="positive integer"):
381-
CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=-5)
382-
383-
@pytest.mark.parametrize("bad_bool", [True, False])
384-
def test_bool_max_entries_is_rejected(self, bad_bool):
385-
"""Same bool-as-int foot-gun protection as similarity_threshold."""
386-
with pytest.raises(ValueError, match="positive integer"):
387-
CacheMiddleware(
388-
enabled_mode="always",
389-
similarity_threshold=1.0,
390-
max_entries=bad_bool, # type: ignore[arg-type]
391-
)
392-
393322
async def test_cache_evicts_oldest_when_exceeding_max_entries(self, middleware_context):
394323
"""Insert more unique entries than max_entries; verify size stays bounded."""
395324
mw = CacheMiddleware(

0 commit comments

Comments
 (0)