
Commit d09cdc6

charliegillet, claude, and ryan-t-christensen committed
fix(nodes): add input validation/sanitization to LLM chat drivers (#559)
* feat(vscode): improve stop button feedback in Pipeline Observability screen

  Handle TASK_STATE.STOPPING in the control button to show "Stopping..." with a disabled state and distinct orange styling, preventing duplicate clicks and giving immediate visual feedback during pipeline shutdown.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(nodes): add input validation and sanitization to LLM chat drivers

  Add a shared validation module and integrate it into ChatBase and all LLM nodes that bypass the base chat path, preventing control character injection, empty prompts, malformed model names, and unsafe token limits from reaching provider APIs.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: remove unrelated PageStatus "Stopping..." changes from LLM validation PR

  The PageStatus changes belong in a separate PR (#549) and were accidentally included here.

* fix: address CodeRabbit review findings on input validation

  - Re-check prompt emptiness after sanitization to catch control-only prompts
  - Validate total_tokens param before clamping output tokens
  - Move min 1024 output tokens check after validate_max_tokens clamping
  - Use validate_prompt() in Mistral/Perplexity chat() overrides to match shared ChatBase validation behavior

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(nodes): address kwit75's review feedback on LLM input validation

  - Remove double/triple validation: keep validate_prompt() only in chat_string() (the main entry point), remove redundant sanitize_prompt() from _chat() and validate_prompt() from Mistral/Perplexity chat() overrides
  - Fix validate_model_name(None) breaking existing setups: return None gracefully when model is not yet configured instead of raising ValueError
  - Add @ to model name regex to support org@model provider formats
  - IBM Watson formatting changes kept as-is since they are ruff format enforced style, not hand-made cosmetic changes

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(nodes): remove redundant sanitize_prompt from gemini and ibm_watson drivers

  Validation is now centralized in ChatBase.chat_string() — individual drivers should not duplicate sanitization. Removes the redundant calls and imports per kwit75's review feedback.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(nodes): add type guards for model name and token validation

  Address CodeRabbit suggestions:
  - Guard validate_model_name against non-string inputs
  - Guard validate_max_tokens against bool values (isinstance(True, int) is True in Python)

* add mistral and perplexity validation

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: ryan-t-christensen <ryan.christensen@rocketride.ai>
1 parent 715d459 commit d09cdc6
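
In short, every chat entry point now funnels prompts through the shared ai.common.validation module before they reach a provider. The sketch below is not part of the commit; it only illustrates how validate_prompt is expected to behave, assuming a toy whitespace-based token counter as a stand-in for the drivers' real getTokens method:

    from ai.common.validation import validate_prompt

    def toy_token_counter(text: str) -> int:
        # Hypothetical stand-in for ChatBase.getTokens
        return len(text.split())

    # Control characters are stripped; ordinary whitespace (tabs, newlines) is kept
    clean = validate_prompt('Hello\x00 world\n', 16384, toy_token_counter)
    assert clean == 'Hello world\n'

    # Empty or control-character-only prompts are rejected before any API call
    try:
        validate_prompt('\x00\x1f', 16384, toy_token_counter)
    except ValueError as err:
        print(err)  # Prompt is empty after sanitization.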

5 files changed

Lines changed: 182 additions & 23 deletions


nodes/src/nodes/llm_ibm_watson/ibm_watson.py

Lines changed: 15 additions & 11 deletions
@@ -31,10 +31,20 @@


 # Known IBM Cloud regions for Watson services
-_VALID_LOCATIONS = frozenset({
-    'us-south', 'us-east', 'eu-gb', 'eu-de', 'eu-es',
-    'jp-tok', 'jp-osa', 'au-syd', 'ca-tor', 'br-sao',
-})
+_VALID_LOCATIONS = frozenset(
+    {
+        'us-south',
+        'us-east',
+        'eu-gb',
+        'eu-de',
+        'eu-es',
+        'jp-tok',
+        'jp-osa',
+        'au-syd',
+        'ca-tor',
+        'br-sao',
+    }
+)

 _LOCATION_RE = re.compile(r'^[a-z0-9]([a-z0-9-]*[a-z0-9])?$')

@@ -58,10 +68,7 @@ def _validate_location(location):
     if not _LOCATION_RE.match(location):
         raise ValueError(f'Invalid location format: {location!r}')
     if location not in _VALID_LOCATIONS:
-        raise ValueError(
-            f'Unknown IBM Cloud location: {location!r}. '
-            f'Valid locations: {", ".join(sorted(_VALID_LOCATIONS))}'
-        )
+        raise ValueError(f'Unknown IBM Cloud location: {location!r}. Valid locations: {", ".join(sorted(_VALID_LOCATIONS))}')
     return f'https://{location}.ml.cloud.ibm.com'


@@ -131,9 +138,6 @@ def _chat(self, prompt: str) -> str:
         Returns:
             str: The generated text response from the model
         """
-        if not prompt:
-            raise ValueError('Prompt is empty.')
-
         messages = [{'role': 'user', 'content': prompt}]

         response = self._llm.chat(messages=messages)
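
For reference, _validate_location (second hunk above) checks the region string against a format regex and the _VALID_LOCATIONS allow-list, then builds the regional endpoint URL. A rough illustration of its behavior; the import path is hypothetical and the error messages are paraphrased:

    # Hypothetical import path, shown for illustration only
    from nodes.llm_ibm_watson.ibm_watson import _validate_location

    _validate_location('us-south')   # -> 'https://us-south.ml.cloud.ibm.com'
    _validate_location('US south!')  # raises ValueError: invalid location format
    _validate_location('us-west')    # raises ValueError: unknown IBM Cloud location

The removal of the empty-prompt guard in _chat is intentional: per the commit message, that check now lives in ChatBase.chat_string() via validate_prompt.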

nodes/src/nodes/llm_mistral/mistral.py

Lines changed: 3 additions & 2 deletions
@@ -41,6 +41,7 @@
 from ai.common.schema import Answer, Question
 from ai.common.chat import ChatBase
 from ai.common.config import Config
+from ai.common.validation import validate_prompt
 from mistralai.client import Mistral


@@ -237,14 +238,14 @@ def _getRetryConfig(self, model: str) -> Tuple[int, float]:

     def chat(self, question: Question) -> Answer:
         """Send a chat message to Mistral AI and get the response."""
-        # Get retry configuration for this model
+        prompt = validate_prompt(question.getPrompt(), self._modelTotalTokens, self.getTokens)
         max_retries, base_delay = self._getRetryConfig(self._model)
         last_error = None

         for attempt in range(max_retries + 1):
             try:
                 # Create the chat message
-                messages = [{'role': 'user', 'content': question.getPrompt()}]
+                messages = [{'role': 'user', 'content': prompt}]

                 # Make the API call
                 chat_response = self._client.chat.complete(

nodes/src/nodes/llm_perplexity/perplexity.py

Lines changed: 3 additions & 2 deletions
@@ -40,6 +40,7 @@
 from ai.common.schema import Answer, Question
 from ai.common.chat import ChatBase
 from ai.common.config import Config
+from ai.common.validation import validate_prompt
 from langchain_openai import ChatOpenAI


@@ -179,14 +180,14 @@ def _getRetryConfig(self, model: str) -> tuple[int, float]:

     def chat(self, question: Question) -> Answer:
         """Process a question and return an answer with retry logic."""
-        # Get retry configuration for this model
+        prompt = validate_prompt(question.getPrompt(), self._modelTotalTokens, self.getTokens)
         max_retries, base_delay = self._getRetryConfig(self._model)
         last_error = None

         for attempt in range(max_retries + 1):  # +1 for initial attempt
             try:
                 # Ask the model
-                results = self._llm.invoke(question.getPrompt())
+                results = self._llm.invoke(prompt)

                 # Create and return the answer
                 answer = Answer(expectJson=question.expectJson)
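
The Mistral and Perplexity overrides share the same shape: validate and sanitize once, before the retry loop, so a bad prompt fails fast and the cleaned text is reused on every attempt. A simplified, standalone sketch of that pattern; chat_with_retries and call_provider are hypothetical stand-ins (the real drivers use self._getRetryConfig and their provider clients):

    import time
    from ai.common.validation import validate_prompt

    def chat_with_retries(raw_prompt, call_provider, token_counter,
                          total_tokens=16384, max_retries=2, base_delay=0.5):
        # Validate exactly once; raises ValueError for empty or control-only prompts
        prompt = validate_prompt(raw_prompt, total_tokens, token_counter)
        last_error = None
        for attempt in range(max_retries + 1):  # +1 for the initial attempt
            try:
                return call_provider(prompt)    # sanitized prompt on every attempt
            except Exception as exc:            # real drivers narrow this to retryable errors
                last_error = exc
                time.sleep(base_delay * (2 ** attempt))  # assumed exponential backoff
        raise last_error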

packages/ai/src/ai/common/chat.py

Lines changed: 11 additions & 8 deletions
@@ -18,6 +18,7 @@
 from ai.common.schema import Answer, Question
 from ai.common.config import Config
 from ai.common.util import parseJson
+from ai.common.validation import validate_model_name, validate_max_tokens, validate_prompt


 class ChatBase:
@@ -27,6 +28,7 @@ class ChatBase:
     This class provides the foundation for AI chat implementations by handling:
     - Token counting and management
     - Configuration loading and validation
+    - Input validation and sanitization
     - Consistent interface for chat operations
     - Warning systems for token limits

@@ -61,18 +63,17 @@ def __init__(self, provider: str, connConfig: Dict[str, Any], bag: Dict[str, Any

         # Extract model configuration - these are the core settings that control
         # how the chat driver behaves with respect to token limits
-        self._model = config.get('model')  # Model identifier (e.g., 'gpt-4', 'claude-3')
+        self._model = validate_model_name(config.get('model'))
         self._modelTotalTokens = config.get('modelTotalTokens', 16384)  # Default to 16K if not specified
         self._modelOutputTokens = config.get('modelOutputTokens', 4096)  # Default to 4K if not specified

+        # Validate and clamp output tokens against known safe maximums
+        self._modelOutputTokens = validate_max_tokens(self._modelOutputTokens, self._modelTotalTokens)
+
         # We really can't work with a model that has a very small output window
         if self._modelOutputTokens < 1024:
             raise ValueError(f'Model output tokens ({self._modelOutputTokens}) must be at least 1024')

-        # If the output tokens exceed the total tokens, adjust accordingly
-        if self._modelOutputTokens > self._modelTotalTokens:
-            self._modelOutputTokens = self._modelTotalTokens
-
         # Log the configuration for debugging and monitoring purposes
         # This helps track which model and limits are being used in production
         debug(f' Model : {self._model}')
@@ -291,7 +292,7 @@ def _chat_with_retries(self, prompt: str) -> str:
            errors occur
        """
        from ai.constants import CONST_CHAT_MAX_RETRIES, CONST_CHAT_BASE_DELAY, CONST_CHAT_MAX_DELAY
-
+
        max_network_retries = CONST_CHAT_MAX_RETRIES
        base_delay = CONST_CHAT_BASE_DELAY
        max_delay = CONST_CHAT_MAX_DELAY
@@ -353,6 +354,9 @@ def chat_string(self, prompt: str) -> str:
            Exception: If network/API retries are exhausted or non-retryable
            errors occur
        """
+        # Validate and sanitize the prompt before processing
+        prompt = validate_prompt(prompt, self._modelTotalTokens, self.getTokens)
+
        # Count tokens in the input prompt to check against limits
        # This is important for preventing API errors and ensuring quality responses
        prompt_tokens = self.getTokens(prompt)
@@ -415,7 +419,6 @@ def chat(self, question: Question) -> Answer:
        if question.expectJson:
            max_retries = 3

-
        for retry_count in range(max_retries):
            try:
                # Parse (and strip any markdown fences) — reuse the result below
@@ -439,7 +442,7 @@ def chat(self, question: Question) -> Answer:
                    error_msg = f'Failed to get valid JSON response after {max_retries + 1} attempts. Last response: {response[:200]}...'
                    debug(f'Error: {error_msg}')
                    raise ValueError(error_msg)
-
+
                else:
                    # Create the answer and assign the text
                    answer = Answer(expectJson=False)
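
The ordering in __init__ is deliberate: the model name is validated as it is read from config (None passes through for not-yet-configured setups), validate_max_tokens clamps the output budget to the context window, and only then does the minimum-1024 check run, so it sees the value that will actually be used. A hedged worked example with made-up numbers:

    from ai.common.validation import validate_model_name, validate_max_tokens

    validate_model_name(None)           # -> None (model not configured yet, no error)
    validate_model_name(' org@model ')  # -> 'org@model' (stripped; '@' supports org@model formats)

    validate_max_tokens(4096, 16384)    # -> 4096 (within bounds, unchanged)
    validate_max_tokens(8192, 2048)     # -> 2048 (clamped to the context window)

    # A clamped result below 1024, e.g. validate_max_tokens(4096, 512) -> 512,
    # then fails ChatBase's "must be at least 1024" check instead of slipping through.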
packages/ai/src/ai/common/validation.py (new file)

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
"""
Input validation and sanitization utilities for LLM chat drivers.

This module provides functions to validate and sanitize user input before
it is sent to LLM provider APIs. It guards against:

- Control characters that cause API errors or undefined behavior
- Prompts that exceed provider context windows
- Empty or whitespace-only prompts
- Model name strings that contain unexpected characters
- Output token values that exceed known safe maximums
"""

import re
from typing import Optional

from rocketlib import debug

# Matches C0/C1 control characters EXCEPT common whitespace (\t \n \r)
_CONTROL_CHAR_RE = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]')

# Model names should be alphanumeric with hyphens, dots, slashes, colons, at-signs, and underscores
# e.g. "gpt-4", "claude-3-opus-20240229", "us.anthropic.claude-3", "meta-llama/Llama-3", "org@model"
_MODEL_NAME_RE = re.compile(r'^[a-zA-Z0-9][a-zA-Z0-9._:/@-]*$')

# Absolute upper bound for output tokens across all known providers (as of 2026)
MAX_OUTPUT_TOKENS = 1_000_000


def sanitize_prompt(prompt: str) -> str:
    """Strip control characters from a prompt string.

    Removes C0/C1 control characters that are known to cause errors or
    undefined behavior in LLM APIs while preserving normal whitespace
    (tabs, newlines, carriage returns).

    Args:
        prompt: The raw prompt string.

    Returns:
        The sanitized prompt with control characters removed.
    """
    sanitized = _CONTROL_CHAR_RE.sub('', prompt)
    if sanitized != prompt:
        removed_count = len(prompt) - len(sanitized)
        debug(f'Sanitized {removed_count} control character(s) from prompt')
    return sanitized


def validate_prompt(prompt: str, max_tokens: int, token_counter) -> str:
    """Validate and sanitize a prompt before sending to an LLM API.

    Performs the following checks in order:
    1. Rejects empty / whitespace-only prompts
    2. Strips dangerous control characters
    3. Warns if the prompt likely exceeds the model's context window

    Args:
        prompt: The raw prompt string.
        max_tokens: The model's total token limit (context window).
        token_counter: A callable that estimates token count for a string.

    Returns:
        The sanitized prompt string, ready for the API call.

    Raises:
        ValueError: If the prompt is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise ValueError('Prompt is empty or contains only whitespace.')

    # Sanitize control characters
    prompt = sanitize_prompt(prompt)

    # Re-check after sanitization to catch control-only prompts
    if not prompt.strip():
        raise ValueError('Prompt is empty after sanitization.')

    # Check token count - warn but don't block (ChatBase.chat_string already
    # has a softer check; this catches the truly egregious cases early)
    try:
        token_count = token_counter(prompt)
        if token_count > max_tokens:
            debug(f'Warning: Prompt ({token_count} tokens) exceeds model context window ({max_tokens} tokens). The request will likely be rejected by the provider.')
    except Exception:
        # Token counting failures should not block the request
        pass

    return prompt


def validate_model_name(model: Optional[str]) -> Optional[str]:
    """Validate that a model name is well-formed.

    Args:
        model: The model identifier string, or None if not yet configured.

    Returns:
        The validated model name (stripped of leading/trailing whitespace),
        or None if model was None (not yet configured).

    Raises:
        ValueError: If the model name is non-None but empty or contains
        invalid characters.
    """
    if model is None:
        return None

    if not isinstance(model, str):
        raise ValueError(f'Model name must be a string, got {type(model).__name__}.')

    if not model.strip():
        raise ValueError('Model name was provided but is empty.')

    model = model.strip()

    if not _MODEL_NAME_RE.match(model):
        raise ValueError(f'Invalid model name: {model!r}. Model names must start with an alphanumeric character and contain only letters, digits, hyphens, dots, underscores, colons, at-signs, or slashes.')

    return model


def validate_max_tokens(output_tokens: int, total_tokens: int) -> int:
    """Validate that the output token limit is within reasonable bounds.

    Args:
        output_tokens: The configured max output tokens.
        total_tokens: The model's total context window.

    Returns:
        The validated output token value (clamped if necessary).

    Raises:
        ValueError: If output_tokens is not a positive integer.
    """
    if not isinstance(output_tokens, int) or isinstance(output_tokens, bool) or output_tokens < 1:
        raise ValueError(f'Output tokens must be a positive integer, got {output_tokens!r}.')

    if not isinstance(total_tokens, int) or isinstance(total_tokens, bool) or total_tokens < 1:
        raise ValueError(f'Total tokens must be a positive integer, got {total_tokens!r}.')

    if output_tokens > MAX_OUTPUT_TOKENS:
        debug(f'Warning: Output tokens ({output_tokens}) exceeds maximum known limit ({MAX_OUTPUT_TOKENS}). Clamping to {MAX_OUTPUT_TOKENS}.')
        output_tokens = MAX_OUTPUT_TOKENS

    if output_tokens > total_tokens:
        debug(f'Warning: Output tokens ({output_tokens}) exceeds total tokens ({total_tokens}). Clamping to total tokens.')
        output_tokens = total_tokens

    return output_tokens
