Skip to content

Commit 0172879

Browse files
committed
fix(cache-middleware): align config schema, reject bool/non-finite thresholds
Per CodeRabbit on NVIDIA#1879:

1. cache_middleware_config.py was out of sync with the new constructor
   semantics — schema allowed similarity_threshold in [0, 1] and had no
   max_entries field at all, so config-driven instantiation could pass
   values the constructor now rejects. Re-export _MIN_FUZZY_THRESHOLD and
   _DEFAULT_MAX_CACHE_ENTRIES from the module and wire both constants into
   the schema:

       similarity_threshold: ge=_MIN_FUZZY_THRESHOLD, le=1.0
       max_entries: ge=1, default=_DEFAULT_MAX_CACHE_ENTRIES

   Docstrings mirror the constructor's rationale (cache-poisoning for
   threshold, DoS for max_entries).

2. Constructor validation did not reject `bool` (since
   isinstance(True, int) is True in Python) or non-finite floats
   (NaN / +inf / -inf). Both are classic config-bug foot-guns: a boolean
   silently becoming 0 or 1.0, or a parser that hands back NaN from an
   upstream source, would slip past the range comparison. Add explicit:

   - bool rejection on similarity_threshold (before the number check)
   - math.isfinite() check on similarity_threshold
   - bool rejection on max_entries

New tests:

- test_threshold_bool_is_rejected (parametrized True / False)
- test_threshold_non_finite_is_rejected (parametrized NaN / +inf / -inf)
- test_bool_max_entries_is_rejected (parametrized True / False)

Signed-off-by: ColinM-sys <cmcdonough@50words.com>
1 parent d8047c3 commit 0172879

3 files changed

Lines changed: 64 additions & 8 deletions

File tree

packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import json
3232
import logging
33+
import math
3334
from collections import OrderedDict
3435
from collections.abc import AsyncIterator
3536
from typing import Any
@@ -108,17 +109,28 @@ def __init__(
108109
ValueError: If similarity_threshold is outside [_MIN_FUZZY_THRESHOLD, 1.0]
109110
or max_entries is not a positive integer.
110111
"""
112+
# Reject bool explicitly — `isinstance(True, int)` is True in Python,
113+
# and `True`/`False` silently sneaking through as numeric is a classic
114+
# config bug (user passes the wrong key, gets no error). Check bool
115+
# FIRST so the "must be a number" message doesn't lie.
116+
if isinstance(similarity_threshold, bool):
117+
raise ValueError(
118+
f"similarity_threshold must be a number, got bool ({similarity_threshold!r})")
111119
if not isinstance(similarity_threshold, (int, float)):
112120
raise ValueError(
113121
f"similarity_threshold must be a number, got {type(similarity_threshold).__name__}")
122+
if not math.isfinite(similarity_threshold):
123+
raise ValueError(
124+
f"similarity_threshold must be finite, got {similarity_threshold!r}")
114125
if similarity_threshold < _MIN_FUZZY_THRESHOLD or similarity_threshold > 1.0:
115126
raise ValueError(
116127
f"similarity_threshold={similarity_threshold} is outside the safe range "
117128
f"[{_MIN_FUZZY_THRESHOLD}, 1.0]. Lower values make cache-poisoning trivial — "
118129
"a crafted input can collide with a legitimate user's cached key. Use 1.0 "
119130
"for exact matching (recommended), or a value >= "
120131
f"{_MIN_FUZZY_THRESHOLD} for fuzzy matching.")
121-
if not isinstance(max_entries, int) or max_entries < 1:
132+
# Same bool-as-int foot-gun applies to max_entries.
133+
if isinstance(max_entries, bool) or not isinstance(max_entries, int) or max_entries < 1:
122134
raise ValueError(f"max_entries must be a positive integer, got {max_entries!r}")
123135

124136
super().__init__(is_final=True)

packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware_config.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from pydantic import Field
2020

2121
from nat.data_models.middleware import FunctionMiddlewareBaseConfig
22+
from nat.middleware.cache.cache_middleware import _DEFAULT_MAX_CACHE_ENTRIES
23+
from nat.middleware.cache.cache_middleware import _MIN_FUZZY_THRESHOLD
2224

2325

2426
class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
@@ -31,14 +33,33 @@ class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
3133
enabled_mode: Controls when caching is active:
3234
- "always": Cache is always enabled
3335
- "eval": Cache only active when Context.is_evaluating is True
34-
similarity_threshold: Float between 0 and 1 for input matching:
35-
- 1.0: Exact string matching (fastest)
36-
- < 1.0: Fuzzy matching using difflib similarity
36+
similarity_threshold: Float in [_MIN_FUZZY_THRESHOLD, 1.0] for input
37+
matching:
38+
- 1.0: Exact string matching (fastest, recommended)
39+
- >= _MIN_FUZZY_THRESHOLD: Fuzzy matching via difflib. Values
40+
below this bound are rejected as a cache-poisoning risk —
41+
crafted inputs at lower thresholds can collide with a
42+
legitimate user's cached key.
43+
max_entries: Upper bound on cached entries. When exceeded, the
44+
least-recently-used entry is evicted. Must be a positive int;
45+
defaults to _DEFAULT_MAX_CACHE_ENTRIES.
3746
"""
3847

3948
enabled_mode: Literal["always", "eval"] = Field(
4049
default="eval", description="When caching is enabled: 'always' or 'eval' (only during evaluation)")
41-
similarity_threshold: float = Field(default=1.0,
42-
ge=0.0,
43-
le=1.0,
44-
description="Similarity threshold between 0 and 1. Use 1.0 for exact matching")
50+
similarity_threshold: float = Field(
51+
default=1.0,
52+
ge=_MIN_FUZZY_THRESHOLD,
53+
le=1.0,
54+
description=(
55+
f"Similarity threshold in [{_MIN_FUZZY_THRESHOLD}, 1.0]. Use 1.0 for exact matching "
56+
"(recommended). Lower values enable fuzzy matching but are bounded below to prevent "
57+
"cache-poisoning collisions with legitimate cached keys."),
58+
)
59+
max_entries: int = Field(
60+
default=_DEFAULT_MAX_CACHE_ENTRIES,
61+
ge=1,
62+
description=("Maximum number of cache entries before LRU eviction. Must be >= 1. "
63+
"Prevents memory-exhaustion DoS from unbounded cache growth under "
64+
"sustained unique inputs."),
65+
)

packages/nvidia_nat_core/tests/nat/middleware/test_cache_middleware.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,19 @@ def test_threshold_non_numeric_is_rejected(self):
346346
with pytest.raises(ValueError, match="must be a number"):
347347
CacheMiddleware(enabled_mode="always", similarity_threshold="high") # type: ignore[arg-type]
348348

349+
@pytest.mark.parametrize("bad_bool", [True, False])
350+
def test_threshold_bool_is_rejected(self, bad_bool):
351+
"""`isinstance(True, int)` is True in Python — reject bools explicitly
352+
so a config with the wrong key type doesn't silently become 1.0 or 0.0."""
353+
with pytest.raises(ValueError, match="got bool"):
354+
CacheMiddleware(enabled_mode="always", similarity_threshold=bad_bool) # type: ignore[arg-type]
355+
356+
@pytest.mark.parametrize("bad_value", [float("nan"), float("inf"), float("-inf")])
357+
def test_threshold_non_finite_is_rejected(self, bad_value):
358+
"""NaN, +inf, -inf must be rejected before the range comparison."""
359+
with pytest.raises(ValueError, match="must be finite"):
360+
CacheMiddleware(enabled_mode="always", similarity_threshold=bad_value)
361+
349362

350363
class TestMaxEntriesLruEviction:
351364
"""The cache must bound its size to prevent memory-exhaustion DoS.
@@ -367,6 +380,16 @@ def test_negative_max_entries_is_rejected(self):
367380
with pytest.raises(ValueError, match="positive integer"):
368381
CacheMiddleware(enabled_mode="always", similarity_threshold=1.0, max_entries=-5)
369382

383+
@pytest.mark.parametrize("bad_bool", [True, False])
384+
def test_bool_max_entries_is_rejected(self, bad_bool):
385+
"""Same bool-as-int foot-gun protection as similarity_threshold."""
386+
with pytest.raises(ValueError, match="positive integer"):
387+
CacheMiddleware(
388+
enabled_mode="always",
389+
similarity_threshold=1.0,
390+
max_entries=bad_bool, # type: ignore[arg-type]
391+
)
392+
370393
async def test_cache_evicts_oldest_when_exceeding_max_entries(self, middleware_context):
371394
"""Insert more unique entries than max_entries; verify size stays bounded."""
372395
mw = CacheMiddleware(

0 commit comments

Comments
 (0)