Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
c817550
feat: implement prompt injection detection module (Issue #1979)
nac7 Jun 6, 2026
f24c247
feat: implement context length validation (Issue #1983)
nac7 Jun 6, 2026
8ef91f0
feat: redact sensitive data from logs to prevent data leaks (Issue #1…
nac7 Jun 6, 2026
884c214
fix: address Greptile and CodeRabbit review comments on PR #1998
nac7 Jun 6, 2026
daa9c98
Fix 4 critical issues on PR #2000 (fix/redact-sensitive-logs branch)
nac7 Jun 6, 2026
7909252
Fix type error in redact_list - accept list or tuple
nac7 Jun 6, 2026
cc9ab91
Add full Apache 2.0 license headers to files modified for PR #2000
nac7 Jun 6, 2026
358168f
Fix model context window partial match to prioritize longer keys
nac7 Jun 6, 2026
2191c3d
fix: support ChatMessage dataclass in estimate_message_tokens
nac7 Jun 7, 2026
117e34f
fix: resolve lint and test CI failures on PR #2000
nac7 Jun 7, 2026
cdee1f8
fix: remaining 2 CI test failures in test_sensitive_redaction
nac7 Jun 7, 2026
e001a54
fix: redact dicts in list values and resolve Codecov GPG failure
nac7 Jun 7, 2026
d97865a
tests: add coverage for uncovered lines across 6 files
nac7 Jun 7, 2026
c398fbf
tests: replace named lambda with def to satisfy ruff E731
nac7 Jun 7, 2026
344ba8f
style: apply ruff-format to test files
nac7 Jun 7, 2026
0bbabc9
tests: cover stream_async injection (274-276), redact_dict non-dict (…
nac7 Jun 7, 2026
9faf41c
fix: add gpt-3.5-turbo 16k variants and gpt-4-32k to MODEL_CONTEXT_WI…
nac7 Jun 7, 2026
8697f6b
fix(redactor): replace substring matching with segment-based keyword …
nac7 Jun 7, 2026
c728abd
fix: anchor system_override to line start and correct nested_comment …
nac7 Jun 8, 2026
cd3a6ad
fix: extend bearer/token regex separator to cover space-separated Aut…
nac7 Jun 8, 2026
33f9905
fix: widen token separator class to [:=\s] to cover tab and all white…
nac7 Jun 8, 2026
b575121
fix: pre-format log records before redacting to prevent TypeError on …
nac7 Jun 8, 2026
2ec1795
fix(token_counter): return None for unknown models to skip context-le…
nac7 Jun 8, 2026
cb36275
fix(ci+injections): remove orphaned GPG step and strip PII from injec…
nac7 Jun 8, 2026
e9997db
Merge branch 'develop' into fix/redact-sensitive-logs
nac7 Jun 8, 2026
f10ac76
fix(sensitive_filter): attach filter to handlers so child-logger reco…
nac7 Jun 8, 2026
e9c51a5
style: apply ruff formatting to test_sensitive_redaction.py
nac7 Jun 8, 2026
a0cb871
fix(sensitive_filter): guard logging.lastResort None to satisfy pyright
nac7 Jun 8, 2026
e22852b
test(sensitive_filter): cover named-logger own-handler branch in setu…
nac7 Jun 8, 2026
e146ab0
fix(test): filter by message instead of total count in test_no_op_tra…
nac7 Jun 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions nemoguardrails/actions/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
tool_calls_var,
)
from nemoguardrails.exceptions import LLMCallException
from nemoguardrails.llm.token_counter import ContextLengthExceededError, validate_context_length
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.logging.llm_tracker import track_llm_call
from nemoguardrails.types import ChatMessage, LLMModel, LLMResponse, LLMResponseChunk, UsageInfo
Expand Down Expand Up @@ -74,6 +75,13 @@ async def llm_call(
_log_prompt(prompt)
chat_prompt = _ensure_chat_messages(prompt)

# Validate context length before sending to LLM
try:
validate_context_length(prompt, model_name=model_name or model.model_name)
except ContextLengthExceededError as e:
logger.error(f"Context length validation failed: {e}")
raise LLMCallException(e)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if streaming_handler:
return await _stream_llm_call(model, chat_prompt, streaming_handler, stop, llm_params)

Expand Down
27 changes: 27 additions & 0 deletions nemoguardrails/guardrails/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
from nemoguardrails.guardrails.guardrails_types import LLMMessages
from nemoguardrails.guardrails.iorails import IORails
from nemoguardrails.logging.explain import ExplainInfo
from nemoguardrails.logging.sensitive_filter import setup_sensitive_data_filter
from nemoguardrails.rails.llm.config import RailsConfig
from nemoguardrails.rails.llm.injections import PromptInjectionDetectedError, validate_prompt_safety
from nemoguardrails.rails.llm.llmrails import LLMRails
from nemoguardrails.rails.llm.options import GenerationResponse, RailsResult, RailType
from nemoguardrails.types import LLMModel
Expand Down Expand Up @@ -78,6 +80,12 @@ def __init__(
else:
configure_logging(logging.INFO)

# Setup sensitive data redaction in logs to prevent data leaks
try:
setup_sensitive_data_filter(logging.getLogger())
except Exception as e:
log.warning(f"Failed to setup sensitive data filter: {e}")
Comment thread
nac7 marked this conversation as resolved.

if use_iorails:
fallback_reason = IORails.unsupported_reason(config, llm)
if fallback_reason is None:
Expand Down Expand Up @@ -210,6 +218,12 @@ def generate(
"""Generate an LLM response synchronously with guardrails applied.
Supported in both IORails and LLMRails
"""
# Validate input for prompt injection attempts
try:
validate_prompt_safety(prompt=prompt, messages=messages)
except PromptInjectionDetectedError as e:
log.warning(f"Prompt injection attempt blocked: {e}")
raise

generate_messages = self._convert_to_messages(prompt, messages)
return self.rails_engine.generate(messages=generate_messages, **kwargs)
Expand Down Expand Up @@ -238,6 +252,13 @@ async def generate_async(
"""Generate an LLM response asynchronously with guardrails applied.
Supported by both LLMRails and IORails
"""
# Validate input for prompt injection attempts
try:
validate_prompt_safety(prompt=prompt, messages=messages)
except PromptInjectionDetectedError as e:
log.warning(f"Prompt injection attempt blocked: {e}")
raise

await self._ensure_started()

generate_messages = self._convert_to_messages(prompt, messages)
Expand All @@ -247,6 +268,12 @@ def stream_async(
self, prompt: str | None = None, messages: LLMMessages | None = None, **kwargs
) -> AsyncIterator[str | dict]:
"""Generate an LLM response asynchronously with streaming support."""
# Validate input for prompt injection attempts
try:
validate_prompt_safety(prompt=prompt, messages=messages)
except PromptInjectionDetectedError as e:
log.warning(f"Prompt injection attempt blocked: {e}")
raise

stream_messages = self._convert_to_messages(prompt, messages)

Expand Down
250 changes: 250 additions & 0 deletions nemoguardrails/llm/token_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Token counting and context length validation utilities.

Provides methods to estimate token counts for prompts and validate
that prompts don't exceed model context windows.
"""

import dataclasses
import logging
from typing import Any, List, Optional, Union

# Module-level logger; this module uses it for debug output from context-length validation.
log = logging.getLogger(__name__)


class ContextLengthExceededError(ValueError):
    """Raised when prompt exceeds model context length.

    Only ``message`` is forwarded to ``ValueError``, so ``str(err)`` is the
    message alone; the numeric details are exposed as attributes.

    Attributes:
        prompt_tokens: Token count that was computed for the prompt.
        max_tokens: Context window size the prompt was checked against.
        model_name: Model name supplied by the caller, or None.
    """

    def __init__(
        self,
        message: str,
        prompt_tokens: int,
        max_tokens: int,
        model_name: Optional[str] = None,
    ):
        """Store the validation details and initialise the base exception.

        Args:
            message: Human-readable description of the failure.
            prompt_tokens: Token count computed for the prompt.
            max_tokens: Context window size used for the check.
            model_name: Name of the model, if known.
        """
        self.prompt_tokens = prompt_tokens
        self.max_tokens = max_tokens
        self.model_name = model_name
        super().__init__(message)

    def __reduce__(self):
        """Support ``copy`` and ``pickle`` by supplying every constructor argument.

        The inherited implementation rebuilds the exception from ``self.args``,
        which holds only the message, and would therefore fail with a
        ``TypeError`` because ``prompt_tokens`` and ``max_tokens`` are required.
        """
        message = self.args[0] if self.args else ""
        ctor_args = (message, self.prompt_tokens, self.max_tokens, self.model_name)
        # The instance dict is passed as state so extra attributes survive too.
        return (type(self), ctor_args, self.__dict__)


class TokenCounter:
    """Estimates token counts and validates prompts against context windows.

    No tokenizer is invoked: every count is derived from character length, so
    the results are approximations intended for a pre-flight check only.
    """

    # Approximate tokens per character ratios for different model families.
    # NOTE: reference values only. estimate_tokens() applies one fixed ratio
    # (4 characters per token) and does not read this table.
    TOKENS_PER_CHAR = {
        "gpt": 0.25,  # OpenAI models: ~4 chars per token
        "claude": 0.27,  # Anthropic: ~3.7 chars per token
        "llama": 0.28,  # Meta: ~3.6 chars per token
        "mistral": 0.28,
        "gemini": 0.26,
        "default": 0.27,
    }

    # Model context window limits (in tokens).
    # get_model_context_window() tries an exact match first and then a
    # substring match with the longest key first, so a specific key such as
    # "gpt-3.5-turbo-16k" wins over "gpt-3.5-turbo" wherever it sits in this dict.
    MODEL_CONTEXT_WINDOWS = {
        # OpenAI
        "gpt-4o": 128000,
        "gpt-4-turbo": 128000,
        "gpt-4-32k": 32768,
        "gpt-4": 8192,
        "gpt-3.5-turbo-16k": 16384,
        "gpt-3.5-turbo-0125": 16384,
        "gpt-3.5-turbo-1106": 16384,
        "gpt-3.5-turbo": 4096,
        # Anthropic
        "claude-3-opus": 200000,
        "claude-3-sonnet": 200000,
        "claude-3-haiku": 200000,
        "claude-3": 200000,
        "claude-2.1": 100000,
        "claude-2": 100000,
        # Meta Llama
        "llama-2": 4096,
        "llama-2-70b": 4096,
        "llama-3": 8192,
        "llama-3-70b": 8192,
        # Mistral
        "mistral-7b": 32768,
        "mistral-large": 32768,
        # Google
        "gemini-pro": 32768,
        "gemini-2.0-flash": 1000000,
        # Only reachable through an exact lookup of the literal name "default";
        # the partial-match loop skips it and unknown models yield None.
        "default": 4096,
    }

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """Estimate token count for text.

        Args:
            text: The text to estimate tokens for

        Returns:
            0 for empty/falsy text, otherwise ``len(text) // 4`` with a
            minimum of 1.
        """
        if not text:
            return 0
        # Fixed heuristic of 4 characters per token, rounded down.
        return max(1, len(text) // 4)

    @staticmethod
    def estimate_message_tokens(messages: List[Any]) -> int:
        """Estimate total token count for message list.

        Accepts both plain dicts and dataclass instances (e.g. ChatMessage)
        that expose a ``content`` attribute. Entries of any other type add
        only the per-message overhead.

        Args:
            messages: List of message dicts or dataclass instances with a 'content' field

        Returns:
            Approximate total token count including formatting
        """
        if not messages:
            return 0

        # Account for message structure overhead (~4 tokens per message),
        # charged for every entry including ones whose content is not counted.
        total_tokens = len(messages) * 4

        for msg in messages:
            if isinstance(msg, dict):
                content = msg.get("content", "")
            elif dataclasses.is_dataclass(msg) and not isinstance(msg, type):
                # is_dataclass() is also true for dataclass *types*; only
                # instances carry message content.
                content = getattr(msg, "content", None) or ""
            else:
                continue

            if isinstance(content, str):
                total_tokens += TokenCounter.estimate_tokens(content)
            elif isinstance(content, list):
                # Multimodal content: a list of typed parts.
                for item in content:
                    if not isinstance(item, dict):
                        continue
                    part_type = item.get("type")
                    if part_type == "text":
                        total_tokens += TokenCounter.estimate_tokens(item.get("text", ""))
                    elif part_type in ("image_url", "image"):
                        # Image tokens vary; rough flat estimate.
                        total_tokens += 85

        return total_tokens

    @staticmethod
    def _key_matches(model_name_lower: str, key: str) -> bool:
        """Return True when ``key`` occurs in the name as a complete version.

        A key ending in a digit (e.g. ``"gpt-4"``, ``"llama-3"``) is rejected at
        positions where the name keeps extending that number, either with
        another digit or with a dot followed by a digit. ``"gpt-4.1"`` and
        ``"llama-3.1"`` are different models than the keys they contain, and
        the table holds no window for them.
        """
        if not key[-1:].isdigit():
            return key in model_name_lower

        pos = model_name_lower.find(key)
        while pos != -1:
            rest = model_name_lower[pos + len(key):]
            extends_number = rest[:1].isdigit() or (rest[:1] == "." and rest[1:2].isdigit())
            if not extends_number:
                return True
            pos = model_name_lower.find(key, pos + 1)
        return False

    @staticmethod
    def get_model_context_window(model_name: Optional[str]) -> Optional[int]:
        """Get context window size for a model.

        The lookup is case-insensitive: an exact key match is tried first, then
        a substring match with the longest (most specific) key first.

        Args:
            model_name: Name of the model

        Returns:
            Context window in tokens, or None if the model is not recognised
        """
        if not model_name:
            return None

        model_name_lower = model_name.lower()
        windows = TokenCounter.MODEL_CONTEXT_WINDOWS

        # Exact match
        if model_name_lower in windows:
            return windows[model_name_lower]

        # Partial match (longest/most-specific key first, skip 'default')
        for key in sorted(windows, key=len, reverse=True):
            if key != "default" and TokenCounter._key_matches(model_name_lower, key):
                return windows[key]

        # Unknown model: return None so callers can skip validation
        return None

    @staticmethod
    def validate_context_length(
        prompt: Union[str, List[Any]],
        model_name: Optional[str] = None,
        max_tokens: Optional[int] = None,
    ) -> None:
        """Validate that prompt fits within model context window.

        Prompts that are neither ``str`` nor ``list`` are not validated, and
        the check is skipped when no ``max_tokens`` is given and the model is
        not recognised.

        Args:
            prompt: The prompt (string, list of dicts, or list of ChatMessage dataclasses) to validate
            model_name: Name of the model (for context window lookup)
            max_tokens: Override context window size

        Raises:
            ContextLengthExceededError: If the estimated prompt size exceeds
                90% of the context window
        """
        if isinstance(prompt, str):
            prompt_tokens = TokenCounter.estimate_tokens(prompt)
        elif isinstance(prompt, list):
            prompt_tokens = TokenCounter.estimate_message_tokens(prompt)
        else:
            return  # Can't validate unknown type

        # Determine context window
        if max_tokens is None:
            max_tokens = TokenCounter.get_model_context_window(model_name)
            if max_tokens is None:
                log.debug("Skipping context-length check: unrecognised model '%s'", model_name)
                return

        # Validate (reserve 10% for safety margin and output tokens)
        safety_threshold = int(max_tokens * 0.9)

        if prompt_tokens > safety_threshold:
            raise ContextLengthExceededError(
                f"Prompt exceeds model context length. "
                f"Prompt tokens: {prompt_tokens}, "
                f"Model context window: {max_tokens} "
                f"(using 90% threshold: {safety_threshold} tokens). "
                f"Context length exceeded by {prompt_tokens - safety_threshold} tokens. "
                f"Please reduce prompt length or use a model with larger context window.",
                prompt_tokens=prompt_tokens,
                max_tokens=max_tokens,
                model_name=model_name,
            )

        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        log.debug(
            "Prompt token validation passed: %s/%s tokens (model: %s)",
            prompt_tokens,
            safety_threshold,
            model_name or "unknown",
        )


def validate_context_length(
    prompt: Union[str, List[Any]],
    model_name: Optional[str] = None,
    max_tokens: Optional[int] = None,
) -> None:
    """Module-level shortcut for ``TokenCounter.validate_context_length``.

    All arguments are forwarded unchanged, so callers do not need to
    reference the ``TokenCounter`` class directly.

    Args:
        prompt: The prompt to validate
        model_name: Name of the model
        max_tokens: Override context window size

    Raises:
        ContextLengthExceededError: If prompt exceeds context window
    """
    TokenCounter.validate_context_length(
        prompt=prompt,
        model_name=model_name,
        max_tokens=max_tokens,
    )
Loading