Skip to content
Merged
26 changes: 15 additions & 11 deletions nodes/src/nodes/llm_ibm_watson/ibm_watson.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,20 @@


# Known IBM Cloud regions for Watson services
_VALID_LOCATIONS = frozenset({
'us-south', 'us-east', 'eu-gb', 'eu-de', 'eu-es',
'jp-tok', 'jp-osa', 'au-syd', 'ca-tor', 'br-sao',
})
_VALID_LOCATIONS = frozenset(
{
'us-south',
'us-east',
'eu-gb',
'eu-de',
'eu-es',
'jp-tok',
'jp-osa',
'au-syd',
'ca-tor',
'br-sao',
}
)

_LOCATION_RE = re.compile(r'^[a-z0-9]([a-z0-9-]*[a-z0-9])?$')

Expand All @@ -58,10 +68,7 @@ def _validate_location(location):
if not _LOCATION_RE.match(location):
raise ValueError(f'Invalid location format: {location!r}')
if location not in _VALID_LOCATIONS:
raise ValueError(
f'Unknown IBM Cloud location: {location!r}. '
f'Valid locations: {", ".join(sorted(_VALID_LOCATIONS))}'
)
raise ValueError(f'Unknown IBM Cloud location: {location!r}. Valid locations: {", ".join(sorted(_VALID_LOCATIONS))}')
Comment thread
ryan-t-christensen marked this conversation as resolved.
return f'https://{location}.ml.cloud.ibm.com'


Expand Down Expand Up @@ -131,9 +138,6 @@ def _chat(self, prompt: str) -> str:
Returns:
str: The generated text response from the model
"""
if not prompt:
raise ValueError('Prompt is empty.')

messages = [{'role': 'user', 'content': prompt}]

response = self._llm.chat(messages=messages)
Expand Down
5 changes: 3 additions & 2 deletions nodes/src/nodes/llm_mistral/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ai.common.schema import Answer, Question
from ai.common.chat import ChatBase
from ai.common.config import Config
from ai.common.validation import validate_prompt
from mistralai.client import Mistral


Expand Down Expand Up @@ -237,14 +238,14 @@ def _getRetryConfig(self, model: str) -> Tuple[int, float]:

def chat(self, question: Question) -> Answer:
"""Send a chat message to Mistral AI and get the response."""
# Get retry configuration for this model
prompt = validate_prompt(question.getPrompt(), self._modelTotalTokens, self.getTokens)
max_retries, base_delay = self._getRetryConfig(self._model)
last_error = None

for attempt in range(max_retries + 1):
try:
# Create the chat message
messages = [{'role': 'user', 'content': question.getPrompt()}]
messages = [{'role': 'user', 'content': prompt}]

# Make the API call
chat_response = self._client.chat.complete(
Expand Down
5 changes: 3 additions & 2 deletions nodes/src/nodes/llm_perplexity/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ai.common.schema import Answer, Question
from ai.common.chat import ChatBase
from ai.common.config import Config
from ai.common.validation import validate_prompt
from langchain_openai import ChatOpenAI


Expand Down Expand Up @@ -179,14 +180,14 @@ def _getRetryConfig(self, model: str) -> tuple[int, float]:

def chat(self, question: Question) -> Answer:
"""Process a question and return an answer with retry logic."""
# Get retry configuration for this model
prompt = validate_prompt(question.getPrompt(), self._modelTotalTokens, self.getTokens)
max_retries, base_delay = self._getRetryConfig(self._model)
last_error = None

for attempt in range(max_retries + 1): # +1 for initial attempt
try:
# Ask the model
results = self._llm.invoke(question.getPrompt())
results = self._llm.invoke(prompt)

# Create and return the answer
answer = Answer(expectJson=question.expectJson)
Expand Down
19 changes: 11 additions & 8 deletions packages/ai/src/ai/common/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ai.common.schema import Answer, Question
from ai.common.config import Config
from ai.common.util import parseJson
from ai.common.validation import validate_model_name, validate_max_tokens, validate_prompt


class ChatBase:
Expand All @@ -27,6 +28,7 @@ class ChatBase:
This class provides the foundation for AI chat implementations by handling:
- Token counting and management
- Configuration loading and validation
- Input validation and sanitization
- Consistent interface for chat operations
- Warning systems for token limits

Expand Down Expand Up @@ -61,18 +63,17 @@ def __init__(self, provider: str, connConfig: Dict[str, Any], bag: Dict[str, Any

# Extract model configuration - these are the core settings that control
# how the chat driver behaves with respect to token limits
self._model = config.get('model') # Model identifier (e.g., 'gpt-4', 'claude-3')
self._model = validate_model_name(config.get('model'))
self._modelTotalTokens = config.get('modelTotalTokens', 16384) # Default to 16K if not specified
self._modelOutputTokens = config.get('modelOutputTokens', 4096) # Default to 4K if not specified

# Validate and clamp output tokens against known safe maximums
self._modelOutputTokens = validate_max_tokens(self._modelOutputTokens, self._modelTotalTokens)

# We really can't work with a model that has a very small output window
if self._modelOutputTokens < 1024:
raise ValueError(f'Model output tokens ({self._modelOutputTokens}) must be at least 1024')

# If the output tokens exceed the total tokens, adjust accordingly
if self._modelOutputTokens > self._modelTotalTokens:
self._modelOutputTokens = self._modelTotalTokens

# Log the configuration for debugging and monitoring purposes
# This helps track which model and limits are being used in production
debug(f' Model : {self._model}')
Expand Down Expand Up @@ -291,7 +292,7 @@ def _chat_with_retries(self, prompt: str) -> str:
errors occur
"""
from ai.constants import CONST_CHAT_MAX_RETRIES, CONST_CHAT_BASE_DELAY, CONST_CHAT_MAX_DELAY

max_network_retries = CONST_CHAT_MAX_RETRIES
base_delay = CONST_CHAT_BASE_DELAY
max_delay = CONST_CHAT_MAX_DELAY
Expand Down Expand Up @@ -353,6 +354,9 @@ def chat_string(self, prompt: str) -> str:
Exception: If network/API retries are exhausted or non-retryable
errors occur
"""
# Validate and sanitize the prompt before processing
prompt = validate_prompt(prompt, self._modelTotalTokens, self.getTokens)

# Count tokens in the input prompt to check against limits
# This is important for preventing API errors and ensuring quality responses
prompt_tokens = self.getTokens(prompt)
Expand Down Expand Up @@ -415,7 +419,6 @@ def chat(self, question: Question) -> Answer:
if question.expectJson:
max_retries = 3


for retry_count in range(max_retries):
try:
# Parse (and strip any markdown fences) — reuse the result below
Expand All @@ -439,7 +442,7 @@ def chat(self, question: Question) -> Answer:
error_msg = f'Failed to get valid JSON response after {max_retries + 1} attempts. Last response: {response[:200]}...'
debug(f'Error: {error_msg}')
raise ValueError(error_msg)

else:
# Create the answer and assign the text
answer = Answer(expectJson=False)
Expand Down
150 changes: 150 additions & 0 deletions packages/ai/src/ai/common/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
Input validation and sanitization utilities for LLM chat drivers.

This module provides functions to validate and sanitize user input before
it is sent to LLM provider APIs. It guards against:

- Control characters that cause API errors or undefined behavior
- Prompts that exceed provider context windows
- Empty or whitespace-only prompts
- Model name strings that contain unexpected characters
- Output token values that exceed known safe maximums
"""

import re
from typing import Optional

from rocketlib import debug

# Matches C0/C1 control characters EXCEPT common whitespace (\t \n \r)
_CONTROL_CHAR_RE = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]')

# Model names should be alphanumeric with hyphens, dots, slashes, colons, at-signs, and underscores
# e.g. "gpt-4", "claude-3-opus-20240229", "us.anthropic.claude-3", "meta-llama/Llama-3", "org@model"
_MODEL_NAME_RE = re.compile(r'^[a-zA-Z0-9][a-zA-Z0-9._:/@-]*$')

# Absolute upper bound for output tokens across all known providers (as of 2026)
MAX_OUTPUT_TOKENS = 1_000_000


def sanitize_prompt(prompt: str) -> str:
"""Strip control characters from a prompt string.

Removes C0/C1 control characters that are known to cause errors or
undefined behavior in LLM APIs while preserving normal whitespace
(tabs, newlines, carriage returns).

Args:
prompt: The raw prompt string.

Returns:
The sanitized prompt with control characters removed.
"""
sanitized = _CONTROL_CHAR_RE.sub('', prompt)
if sanitized != prompt:
removed_count = len(prompt) - len(sanitized)
debug(f'Sanitized {removed_count} control character(s) from prompt')
return sanitized


def validate_prompt(prompt: str, max_tokens: int, token_counter) -> str:
"""Validate and sanitize a prompt before sending to an LLM API.

Performs the following checks in order:
1. Rejects empty / whitespace-only prompts
2. Strips dangerous control characters
3. Warns if the prompt likely exceeds the model's context window

Args:
prompt: The raw prompt string.
max_tokens: The model's total token limit (context window).
token_counter: A callable that estimates token count for a string.

Returns:
The sanitized prompt string, ready for the API call.

Raises:
ValueError: If the prompt is empty or whitespace-only.
"""
if not prompt or not prompt.strip():
raise ValueError('Prompt is empty or contains only whitespace.')
Comment thread
ryan-t-christensen marked this conversation as resolved.

# Sanitize control characters
prompt = sanitize_prompt(prompt)

Comment thread
coderabbitai[bot] marked this conversation as resolved.
# Re-check after sanitization to catch control-only prompts
if not prompt.strip():
raise ValueError('Prompt is empty after sanitization.')

# Check token count - warn but don't block (ChatBase.chat_string already
# has a softer check; this catches the truly egregious cases early)
try:
token_count = token_counter(prompt)
if token_count > max_tokens:
debug(f'Warning: Prompt ({token_count} tokens) exceeds model context window ({max_tokens} tokens). The request will likely be rejected by the provider.')
except Exception:
# Token counting failures should not block the request
pass
Comment thread
ryan-t-christensen marked this conversation as resolved.

return prompt


def validate_model_name(model: Optional[str]) -> Optional[str]:
"""Validate that a model name is well-formed.

Args:
model: The model identifier string, or None if not yet configured.

Returns:
The validated model name (stripped of leading/trailing whitespace),
or None if model was None (not yet configured).

Raises:
ValueError: If the model name is non-None but empty or contains
invalid characters.
"""
if model is None:
return None

if not isinstance(model, str):
raise ValueError(f'Model name must be a string, got {type(model).__name__}.')

if not model.strip():
raise ValueError('Model name was provided but is empty.')

model = model.strip()

if not _MODEL_NAME_RE.match(model):
raise ValueError(f'Invalid model name: {model!r}. Model names must start with an alphanumeric character and contain only letters, digits, hyphens, dots, underscores, colons, at-signs, or slashes.')

return model


def validate_max_tokens(output_tokens: int, total_tokens: int) -> int:
"""Validate that the output token limit is within reasonable bounds.

Args:
output_tokens: The configured max output tokens.
total_tokens: The model's total context window.

Returns:
The validated output token value (clamped if necessary).

Raises:
ValueError: If output_tokens is not a positive integer.
"""
if not isinstance(output_tokens, int) or isinstance(output_tokens, bool) or output_tokens < 1:
raise ValueError(f'Output tokens must be a positive integer, got {output_tokens!r}.')

if not isinstance(total_tokens, int) or isinstance(total_tokens, bool) or total_tokens < 1:
raise ValueError(f'Total tokens must be a positive integer, got {total_tokens!r}.')

if output_tokens > MAX_OUTPUT_TOKENS:
debug(f'Warning: Output tokens ({output_tokens}) exceeds maximum known limit ({MAX_OUTPUT_TOKENS}). Clamping to {MAX_OUTPUT_TOKENS}.')
output_tokens = MAX_OUTPUT_TOKENS

if output_tokens > total_tokens:
debug(f'Warning: Output tokens ({output_tokens}) exceeds total tokens ({total_tokens}). Clamping to total tokens.')
output_tokens = total_tokens

return output_tokens
Loading