"""
SSRF (Server-Side Request Forgery) protection utilities.

Validates URLs and resolved IP addresses to prevent requests to private,
loopback, link-local, and reserved IP ranges.  Supports a configurable
allowlist so self-hosted operators can permit specific internal services.

Hostnames are resolved and *every* resulting address is checked at
validation time.  Note that ``validate_url`` returns the URL unchanged, so an
HTTP client that performs its own DNS lookup (or follows a redirect)
afterwards is not pinned to the addresses that were validated here.  Callers
that need protection against DNS rebinding should connect to the addresses
returned by ``resolve_and_validate`` and validate every redirect target.

Usage::

    from library.ssrf_protection import validate_url, SSRFError

    # Block all private IPs (default)
    validate_url('http://192.168.1.1/api')  # raises SSRFError

    # Allow specific private ranges
    validate_url(
        'http://192.168.1.100/api',
        allowed_private=['192.168.1.0/24'],
    )
"""

from __future__ import annotations

import ipaddress
import os
import socket
from typing import List, Optional, Sequence
from urllib.parse import urlparse

# ---------------------------------------------------------------------------
# Blocked networks (RFC 1918, loopback, link-local, metadata, etc.)
# ---------------------------------------------------------------------------

_BLOCKED_IPV4 = [
    ipaddress.IPv4Network('0.0.0.0/8'),  # "This host" (RFC 1122)
    ipaddress.IPv4Network('10.0.0.0/8'),  # Private (RFC 1918)
    ipaddress.IPv4Network('100.64.0.0/10'),  # Shared address (RFC 6598)
    ipaddress.IPv4Network('127.0.0.0/8'),  # Loopback (RFC 1122)
    ipaddress.IPv4Network('169.254.0.0/16'),  # Link-local (RFC 3927) + cloud metadata
    ipaddress.IPv4Network('172.16.0.0/12'),  # Private (RFC 1918)
    ipaddress.IPv4Network('192.0.0.0/24'),  # IETF protocol assignments (RFC 6890)
    ipaddress.IPv4Network('192.0.2.0/24'),  # Documentation (RFC 5737)
    ipaddress.IPv4Network('192.168.0.0/16'),  # Private (RFC 1918)
    ipaddress.IPv4Network('198.18.0.0/15'),  # Benchmarking (RFC 2544)
    ipaddress.IPv4Network('198.51.100.0/24'),  # Documentation (RFC 5737)
    ipaddress.IPv4Network('203.0.113.0/24'),  # Documentation (RFC 5737)
    ipaddress.IPv4Network('224.0.0.0/4'),  # Multicast (RFC 5771)
    ipaddress.IPv4Network('240.0.0.0/4'),  # Reserved (RFC 1112)
    ipaddress.IPv4Network('255.255.255.255/32'),  # Broadcast
]

_BLOCKED_IPV6 = [
    ipaddress.IPv6Network('::1/128'),  # Loopback
    ipaddress.IPv6Network('::/128'),  # Unspecified
    ipaddress.IPv6Network('::ffff:0:0/96'),  # IPv4-mapped (embedded v4 is checked as well)
    ipaddress.IPv6Network('64:ff9b::/96'),  # NAT64 (RFC 6052)
    ipaddress.IPv6Network('100::/64'),  # Discard (RFC 6666)
    ipaddress.IPv6Network('2001:db8::/32'),  # Documentation (RFC 3849)
    ipaddress.IPv6Network('fc00::/7'),  # Unique local (RFC 4193)
    ipaddress.IPv6Network('fe80::/10'),  # Link-local (RFC 4291)
    ipaddress.IPv6Network('ff00::/8'),  # Multicast (RFC 4291)
]

# Hostnames that are always blocked regardless of IP resolution.
_BLOCKED_HOSTNAMES = frozenset(
    {
        'localhost',
        'metadata.google.internal',
    }
)

# Environment variable for the global allowlist (comma-separated CIDRs).
SSRF_ALLOWLIST_ENV = 'ROCKETRIDE_SSRF_ALLOWLIST'

# Only allow http and https schemes.
_ALLOWED_SCHEMES = frozenset({'http', 'https'})


# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------


class SSRFError(ValueError):
    """Raised when a URL is rejected by the SSRF rules.

    Subclasses ``ValueError`` so existing ``except ValueError`` handlers keep
    working.  Inside this module that also means an ``except ValueError``
    clause would catch it, so parsing and checking are kept in separate
    ``try`` blocks.
    """


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def validate_url(
    url: str,
    *,
    allowed_private: Optional[Sequence[str]] = None,
) -> str:
    """Validate *url* against the SSRF rules and return it unchanged.

    Parameters
    ----------
    url:
        The URL to validate (must use ``http`` or ``https`` scheme).
    allowed_private:
        An optional list of CIDR strings (e.g. ``['192.168.1.0/24']``) that
        should be permitted even though they fall within blocked ranges.
        This is merged with the global allowlist from the
        ``ROCKETRIDE_SSRF_ALLOWLIST`` environment variable.

    Returns
    -------
    str
        The original *url* unchanged, if validation passes.

    Raises
    ------
    SSRFError
        If the URL is malformed, uses a disallowed scheme, has no hostname or
        an invalid port, names a blocked hostname, cannot be resolved, or
        resolves to a blocked IP address.
    """
    # urlparse raises a bare ValueError for e.g. an unterminated '[' literal.
    try:
        parsed = urlparse(url)
    except ValueError as exc:
        raise SSRFError(f'SSRF protection: malformed URL: {exc}') from exc

    # -- Scheme check -------------------------------------------------------
    scheme = (parsed.scheme or '').lower()
    if scheme not in _ALLOWED_SCHEMES:
        raise SSRFError(f'SSRF protection: scheme {scheme!r} is not allowed. Only {sorted(_ALLOWED_SCHEMES)} are permitted.')

    # -- Extract hostname ---------------------------------------------------
    # Trailing dots are stripped so 'localhost.' cannot dodge the name check.
    hostname = (parsed.hostname or '').lower().strip('.')
    if not hostname:
        raise SSRFError('SSRF protection: URL has no hostname.')

    # -- Blocked hostname check ---------------------------------------------
    if hostname in _BLOCKED_HOSTNAMES:
        raise SSRFError(f'SSRF protection: hostname {hostname!r} is blocked.')

    # -- Build combined allowlist -------------------------------------------
    allow_nets = _build_allowlist(allowed_private)

    # -- DNS resolution + IP check ------------------------------------------
    # ``.port`` is parsed lazily and raises ValueError when it is non-numeric
    # or out of range; report that as an SSRFError like every other rejection.
    try:
        explicit_port = parsed.port
    except ValueError as exc:
        raise SSRFError(f'SSRF protection: invalid port in URL: {exc}') from exc
    port = explicit_port or (443 if scheme == 'https' else 80)
    _resolve_and_check(hostname, port, allow_nets)

    return url


def resolve_and_validate(
    hostname: str,
    port: int = 80,
    *,
    allowed_private: Optional[Sequence[str]] = None,
) -> List[str]:
    """Resolve *hostname* and validate all resulting IPs.

    Returns the list of resolved IP address strings.  Raises ``SSRFError``
    if the name cannot be resolved or any resolved address is blocked.
    """
    allow_nets = _build_allowlist(allowed_private)
    return _resolve_and_check(hostname, port, allow_nets)


# ---------------------------------------------------------------------------
# Internals
# ---------------------------------------------------------------------------


def _parse_cidrs(entries: Sequence[object]) -> list[ipaddress.IPv4Network | ipaddress.IPv6Network]:
    """Convert CIDR-like entries to network objects, skipping unusable ones."""
    nets: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = []
    for entry in entries:
        text = str(entry).strip()
        if not text:
            continue
        try:
            nets.append(ipaddress.ip_network(text, strict=False))
        except ValueError:
            # Best-effort by design: a malformed entry only fails to widen the
            # allowlist; it must never abort validation.
            continue
    return nets


def _build_allowlist(
    extra: Optional[Sequence[str]] = None,
) -> list[ipaddress.IPv4Network | ipaddress.IPv6Network]:
    """Merge the global env-var allowlist with the per-call allowlist.

    The environment variable is read on every call, so changes take effect
    without a restart.  Global entries come first, then per-call entries.
    """
    env_val = os.environ.get(SSRF_ALLOWLIST_ENV, '')
    return _parse_cidrs(env_val.split(',')) + _parse_cidrs(extra or [])


def _resolve_and_check(
    hostname: str,
    port: int,
    allow_nets: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
) -> List[str]:
    """Resolve *hostname* and check every resulting IP.

    Returns the distinct resolved address strings in resolver order.
    Raises ``SSRFError`` on resolution failure or on the first blocked IP.
    """
    # If hostname is already an IP literal, skip DNS.  Only the *parse* is
    # guarded: SSRFError is a ValueError, so guarding the check as well would
    # swallow the rejection and hand the decision to the resolver.
    try:
        literal = ipaddress.ip_address(hostname)
    except ValueError:
        literal = None  # not an IP literal — resolve via DNS
    if literal is not None:
        _check_ip(literal, hostname, allow_nets)
        return [str(literal)]

    try:
        addrinfos = socket.getaddrinfo(hostname, port, proto=socket.IPPROTO_TCP)
    except (socket.gaierror, UnicodeError) as exc:
        # UnicodeError: the name could not be IDNA-encoded (e.g. label too long).
        raise SSRFError(f'SSRF protection: cannot resolve hostname {hostname!r}: {exc}') from exc

    if not addrinfos:
        raise SSRFError(f'SSRF protection: hostname {hostname!r} resolved to no addresses.')

    resolved_ips: List[str] = []
    for _family, _type, _proto, _canonname, sockaddr in addrinfos:
        ip_str = sockaddr[0]
        # Link-local IPv6 results may carry a '%zone' suffix; parse without it.
        addr = ipaddress.ip_address(ip_str.split('%', 1)[0])
        _check_ip(addr, hostname, allow_nets)
        if ip_str not in resolved_ips:
            resolved_ips.append(ip_str)

    return resolved_ips


def _is_allowlisted(
    addr: ipaddress.IPv4Address | ipaddress.IPv6Address,
    allow_nets: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
) -> bool:
    """Return True if *addr* lies inside any allowlisted network."""
    return any(addr in net for net in allow_nets)


def _check_ip(
    addr: ipaddress.IPv4Address | ipaddress.IPv6Address,
    hostname: str,
    allow_nets: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
) -> None:
    """Raise ``SSRFError`` if *addr* falls within a blocked, non-allowlisted range.

    For an IPv4-mapped IPv6 address (``::ffff:a.b.c.d``) both the IPv6 form
    and the embedded IPv4 address are checked.
    """
    mapped_v4 = addr.ipv4_mapped if isinstance(addr, ipaddress.IPv6Address) else None
    forms = [addr] if mapped_v4 is None else [addr, mapped_v4]

    for form in forms:
        if not _is_blocked(form) or _is_allowlisted(form, allow_nets):
            continue

        # Membership tests across IP versions are always False, so an IPv4
        # allowlist entry would otherwise never cover the mapped IPv6 spelling
        # of the very same host.  The embedded IPv4 address is still checked
        # on its own in the next iteration.
        if form is addr and mapped_v4 is not None and _is_allowlisted(mapped_v4, allow_nets):
            continue

        raise SSRFError(f'SSRF protection: request to {hostname!r} blocked — resolved IP {form} is in a private/reserved range. If this is intentional, add the IP or CIDR to the ROCKETRIDE_SSRF_ALLOWLIST environment variable or the node-level allowlist.')


def _is_blocked(addr: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
    """Return True if *addr* is in any blocked range of its own IP version."""
    if isinstance(addr, ipaddress.IPv4Address):
        return any(addr in net for net in _BLOCKED_IPV4)
    return any(addr in net for net in _BLOCKED_IPV6)
@staticmethod
def _build_ssrf_allowlist(cfg: dict) -> List[str]:
    """Read the SSRF private-IP allowlist from the node config.

    Expects ``cfg['ssrfAllowlist']`` to be a list of strings (CIDR notation),
    e.g. ``["192.168.1.0/24", "10.0.0.5/32"]``, or a string holding the same
    list encoded as a JSON array.

    Returns the non-empty entries, stripped of surrounding whitespace, in
    their original order.  A missing value, undecodable JSON, or JSON that
    decodes to anything other than a list yields an empty list.  Entries are
    not validated as CIDRs here.
    """
    raw = cfg.get('ssrfAllowlist') or []
    if not isinstance(raw, list):
        try:
            raw = _json.loads(str(raw))
        except (_json.JSONDecodeError, TypeError, ValueError):
            raw = []

    # Valid JSON is not necessarily a list: a number would make the loop below
    # raise TypeError, and a string or object would be iterated character by
    # character / key by key.  Treat all of those as "no allowlist".
    if not isinstance(raw, list):
        raw = []

    stripped = (str(entry).strip() for entry in raw if entry)
    return [val for val in stripped if val]
+ """ resolved_url = _resolve_path_params(url, path_params) + # --- SSRF protection: validate the resolved URL before connecting --- + validate_url(resolved_url, allowed_private=ssrf_allowed_private) + req_headers = dict(headers or {}) req_auth = None extra_params: Dict[str, str] = {} @@ -95,6 +107,7 @@ def execute_request( # Internal helpers # --------------------------------------------------------------------------- + def _resolve_path_params(url: str, path_params: Optional[Dict[str, str]]) -> str: """Replace ``:name`` placeholders in the URL with values from *path_params*.""" if not path_params: diff --git a/nodes/src/nodes/tool_http_request/http_driver.py b/nodes/src/nodes/tool_http_request/http_driver.py index a456176ae..bc7f04fa4 100644 --- a/nodes/src/nodes/tool_http_request/http_driver.py +++ b/nodes/src/nodes/tool_http_request/http_driver.py @@ -170,12 +170,14 @@ def __init__( server_name: str, enabled_methods: Set[str], url_patterns: List[re.Pattern], + ssrf_allowed_private: List[str] | None = None, ): self._server_name = (server_name or '').strip() or 'http' self._tool_name = 'http_request' self._namespaced = f'{self._server_name}.{self._tool_name}' self._enabled_methods = enabled_methods self._url_patterns = url_patterns + self._ssrf_allowed_private = ssrf_allowed_private or [] # ------------------------------------------------------------------ # ToolsBase hooks @@ -250,19 +252,14 @@ def _tool_validate(self, *, tool_name: str, input_obj: Any) -> None: # noqa: AN if method.upper() not in VALID_METHODS: raise ValueError(f'method must be one of {sorted(VALID_METHODS)}; got {method!r}') if method.upper() not in self._enabled_methods: - raise ValueError( - f'HTTP method "{method.upper()}" is not allowed. ' - f'Enabled methods: {", ".join(sorted(self._enabled_methods))}' - ) + raise ValueError(f'HTTP method "{method.upper()}" is not allowed. 
Enabled methods: {", ".join(sorted(self._enabled_methods))}') # --- Guardrail: URL whitelist (empty list = allow all) --- url = input_obj.get('url') if not url or not isinstance(url, str): raise ValueError('url is required and must be a non-empty string') if self._url_patterns and not any(p.search(url) for p in self._url_patterns): - raise ValueError( - f'URL "{url}" does not match any allowed URL pattern.' - ) + raise ValueError(f'URL "{url}" does not match any allowed URL pattern.') # --- Standard field validation --- auth = input_obj.get('auth') @@ -280,9 +277,7 @@ def _tool_validate(self, *, tool_name: str, input_obj: Any) -> None: # noqa: AN raw = body.get('raw') or {} ct = (raw.get('content_type') or 'application/json').strip().lower() if ct not in VALID_RAW_CONTENT_TYPES: - raise ValueError( - f'body.raw.content_type must be one of {sorted(VALID_RAW_CONTENT_TYPES)}; got {ct!r}' - ) + raise ValueError(f'body.raw.content_type must be one of {sorted(VALID_RAW_CONTENT_TYPES)}; got {ct!r}') def _tool_invoke(self, *, tool_name: str, input_obj: Any) -> Any: # noqa: ANN401 if not isinstance(input_obj, dict): @@ -300,4 +295,5 @@ def _tool_invoke(self, *, tool_name: str, input_obj: Any) -> Any: # noqa: ANN40 auth=input_obj.get('auth'), body=input_obj.get('body'), timeout=input_obj.get('timeout'), + ssrf_allowed_private=self._ssrf_allowed_private or None, ) diff --git a/nodes/src/nodes/tts_openai/IGlobal.py b/nodes/src/nodes/tts_openai/IGlobal.py new file mode 100644 index 000000000..66a7a512b --- /dev/null +++ b/nodes/src/nodes/tts_openai/IGlobal.py @@ -0,0 +1,184 @@ +# ============================================================================= +# MIT License +# Copyright (c) 2026 Aparavi Software AG +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, 
import os
import re

from rocketlib import IGlobalBase, OPEN_MODE, debug, warning
from ai.common.config import Config


# Valid values for TTS configuration
VALID_VOICES = {'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'}
VALID_MODELS = {'tts-1', 'tts-1-hd'}
VALID_FORMATS = {'mp3', 'opus', 'aac', 'flac', 'wav', 'pcm'}

# MIME type mapping for response formats
FORMAT_MIME_TYPES = {
    'mp3': 'audio/mpeg',
    'opus': 'audio/opus',
    'aac': 'audio/aac',
    'flac': 'audio/flac',
    'wav': 'audio/wav',
    'pcm': 'audio/pcm',
}

# Speed handling: the configured value is clamped into [_MIN_SPEED, _MAX_SPEED].
_DEFAULT_SPEED = 1.0
_MIN_SPEED = 0.25
_MAX_SPEED = 4.0

# Longest warning text reported from validateConfig before it is truncated.
_MAX_WARNING_CHARS = 500


def _compact_message(message: str) -> str:
    """Collapse whitespace runs in *message* and truncate it with an ellipsis."""
    message = re.sub(r'\s+', ' ', message).strip()
    if len(message) > _MAX_WARNING_CHARS:
        message = message[:_MAX_WARNING_CHARS].rstrip() + '\u2026'
    return message


def _coerce_speed(raw) -> float:
    """Return *raw* as a float clamped to the supported range.

    Falls back to the default (with a warning) when *raw* is missing, not
    numeric, or NaN, instead of letting ``float()`` raise during startup.
    """
    try:
        speed = float(raw)
    except (TypeError, ValueError):
        speed = None
    if speed is None or speed != speed:  # ``x != x`` is True only for NaN
        warning(f'TTS: invalid speed {raw!r}; using {_DEFAULT_SPEED}')
        return _DEFAULT_SPEED
    return max(_MIN_SPEED, min(_MAX_SPEED, speed))


def _choose(value, valid: set, default: str, label: str) -> str:
    """Return *value* if it is one of *valid*, otherwise warn and return *default*."""
    if isinstance(value, str) and value in valid:
        return value
    warning(f'TTS: invalid {label} {value!r}; using {default!r}')
    return default


class IGlobal(IGlobalBase):
    """
    Global configuration and setup for the OpenAI text-to-speech node.

    Handles API client initialization, configuration validation, and
    provides the shared OpenAI client for all instances.
    """

    def validateConfig(self):
        """Validate the configuration for the OpenAI TTS node.

        Checks the configured model and voice, then verifies the API key with
        a ``models.list()`` call.  Every problem is reported through
        ``warning`` and the method returns; nothing is raised to the caller.
        """
        try:
            # Load dependencies
            from depends import depends

            requirements = os.path.dirname(os.path.realpath(__file__)) + '/requirements.txt'
            depends(requirements)

            from openai import OpenAI, APIStatusError, AuthenticationError, RateLimitError, APIConnectionError, OpenAIError

            # Get config
            config = Config.getNodeConfig(self.glb.logicalType, self.glb.connConfig)
            apikey = config.get('apikey')
            model = config.get('model', 'tts-1')
            voice = config.get('voice', 'alloy')

            # Validate model
            if model not in VALID_MODELS:
                warning(f'Invalid TTS model: {model}. Must be one of: {", ".join(sorted(VALID_MODELS))}')
                return

            # Validate voice
            if voice not in VALID_VOICES:
                warning(f'Invalid voice: {voice}. Must be one of: {", ".join(sorted(VALID_VOICES))}')
                return

            # Validate API key with a lightweight call (no billable audio generation)
            try:
                client = OpenAI(api_key=apikey)
                # List models to verify the API key is valid without generating audio
                client.models.list()
            except APIStatusError as e:
                status = getattr(e, 'status_code', None) or getattr(e, 'status', None)
                message = str(e)
                # Prefer the structured error payload when the response has one;
                # any failure while digging it out keeps the plain str(e) text.
                try:
                    resp = getattr(e, 'response', None)
                    data = resp.json() if resp is not None else None
                    if isinstance(data, dict):
                        err = data.get('error')
                        etype = err.get('type') if isinstance(err, dict) else None
                        emsg = (err.get('message') if isinstance(err, dict) else None) or data.get('message')
                        parts = []
                        if status:
                            parts.append(f'Error {status}:')
                        if etype:
                            parts.append(etype)
                        if emsg:
                            if etype:
                                parts.append('-')
                            parts.append(emsg)
                        if parts:
                            message = ' '.join(parts)
                except Exception:
                    pass
                warning(_compact_message(message))
                return
            except (AuthenticationError, RateLimitError, APIConnectionError, OpenAIError) as e:
                warning(_compact_message(str(e)))
                return

        except Exception as e:
            warning(str(e))
            return

    def beginGlobal(self):
        """
        Initialize the global state.

        Reads configuration values, replaces unusable ones with defaults
        (logging a warning for each), and creates the shared OpenAI client
        for TTS operations.  In config mode only the defaults are set.
        """
        # Initialize instance state (avoid class-level mutable defaults)
        self._client = None
        self._model = 'tts-1'
        self._voice = 'alloy'
        self._speed = _DEFAULT_SPEED
        self._response_format = 'mp3'

        # Are we in config mode?
        if self.IEndpoint.endpoint.openMode == OPEN_MODE.CONFIG:
            return

        # Load dependencies
        from depends import depends

        requirements = os.path.dirname(os.path.realpath(__file__)) + '/requirements.txt'
        depends(requirements)

        from openai import OpenAI

        # Get the passed configuration
        config = Config.getNodeConfig(self.glb.logicalType, self.glb.connConfig)
        apikey = config.get('apikey')

        # A bad value must not abort startup: fall back to the default and say so.
        self._speed = _coerce_speed(config.get('speed', _DEFAULT_SPEED))
        self._voice = _choose(config.get('voice', 'alloy'), VALID_VOICES, 'alloy', 'voice')
        self._model = _choose(config.get('model', 'tts-1'), VALID_MODELS, 'tts-1', 'model')
        self._response_format = _choose(config.get('response_format', 'mp3'), VALID_FORMATS, 'mp3', 'response format')

        # Create the OpenAI client
        self._client = OpenAI(api_key=apikey)

        debug(f' TTS OpenAI: model={self._model}, voice={self._voice}, speed={self._speed}, format={self._response_format}')

    def endGlobal(self):
        """Clean up global state."""
        self._client = None
from rocketlib import AVI_ACTION, IInstanceBase, debug, warning
from .IGlobal import IGlobal, FORMAT_MIME_TYPES


class IInstance(IInstanceBase):
    """
    Per-pipe instance that turns incoming text into audio.

    Text arriving on the text lane is sent to the OpenAI speech endpoint
    through the client held by ``IGlobal``; the returned bytes are emitted
    on the audio lane.
    """

    IGlobal: IGlobal

    def writeText(self, text: str):
        """
        Synthesize *text* and emit the result on the audio lane.

        Empty or whitespace-only text is skipped.  Rate-limit and connection
        errors from OpenAI are logged and swallowed; every other exception is
        logged and re-raised.

        Args:
            text: The text content to synthesize into speech.
        """
        stripped = text.strip() if text else ''
        if not stripped:
            debug('TTS: skipping empty text input')
            return

        shared = self.IGlobal

        # Without a client there is nothing to call
        client = shared._client
        if not client:
            warning('TTS: OpenAI client not initialized')
            return

        try:
            fmt = shared._response_format
            response = client.audio.speech.create(
                model=shared._model,
                voice=shared._voice,
                input=stripped,
                speed=shared._speed,
                response_format=fmt,
            )

            payload = response.content
            if not payload:
                warning('TTS: received empty audio response from OpenAI')
                return

            # Unknown formats are labelled as MP3
            mime = FORMAT_MIME_TYPES.get(shared._response_format, 'audio/mpeg')

            # BEGIN / WRITE / END sequence on the audio lane
            emit = self.instance.writeAudio
            emit(AVI_ACTION.BEGIN, mime)
            emit(AVI_ACTION.WRITE, mime, payload)
            emit(AVI_ACTION.END, mime)

            debug(f'TTS: generated {len(payload)} bytes of {shared._response_format} audio')

        except Exception as err:
            # The OpenAI error classes are looked up lazily; if the package
            # cannot be imported, no error is treated as transient.
            try:
                from openai import RateLimitError, APIConnectionError
            except ImportError:
                RateLimitError = APIConnectionError = None

            transient = (
                (RateLimitError, 'TTS: rate limited by OpenAI, skipping this request: '),
                (APIConnectionError, 'TTS: connection error (transient), skipping this request: '),
            )
            for err_type, prefix in transient:
                if err_type is not None and isinstance(err, err_type):
                    warning(f'{prefix}{err}')
                    return

            # Anything else must propagate so the pipeline sees the failure
            warning(f'TTS: failed to generate speech: {err}')
            raise
+++ b/nodes/src/nodes/tts_openai/__init__.py @@ -0,0 +1,27 @@ +# ============================================================================= +# MIT License +# Copyright (c) 2026 Aparavi Software AG +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ============================================================================= + +from .IGlobal import IGlobal +from .IInstance import IInstance + +__all__ = ['IGlobal', 'IInstance'] diff --git a/nodes/src/nodes/tts_openai/requirements.txt b/nodes/src/nodes/tts_openai/requirements.txt new file mode 100644 index 000000000..ec838c5a8 --- /dev/null +++ b/nodes/src/nodes/tts_openai/requirements.txt @@ -0,0 +1 @@ +openai diff --git a/nodes/src/nodes/tts_openai/services.json b/nodes/src/nodes/tts_openai/services.json new file mode 100644 index 000000000..06bfcfca9 --- /dev/null +++ b/nodes/src/nodes/tts_openai/services.json @@ -0,0 +1,229 @@ +{ + // + // Required: + // The displayable name of this node + // + "title": "Text to Speech", + // + // Required: + // The protocol is the endpoint protocol + // + "protocol": "tts_openai://", + // + // Required: + // Class type of the node - what it does + // + "classType": ["audio"], + // + // Required: + // Capabilities are flags that change the behavior of the underlying + // engine + // + "capabilities": [], + // + // Optional: + // Register is either filter, endpoint or ignored if not specified. If the + // type is specified, a factory is registered of that given type + // + "register": "filter", + // + // Optional: + // The node is the actual pyhsical node to instantiate - if + // not specified, the protocol will be used + // + "node": "python", + // + // Optional: + // The path is the executable/script code - it is node dependent + // and is optional for most node + // + "path": "nodes.tts_openai", + // + // Required: + // The prefix map when added/removed when convertting URLs <=> paths + // + "prefix": "tts_openai", + // + // Optional: + // Description to of this driver + // + "description": [ + "A component that converts text into natural-sounding speech using OpenAI's ", + "text-to-speech API. 
It supports multiple models (tts-1, tts-1-hd) and voices ", + "(alloy, echo, fable, onyx, nova, shimmer) with configurable speed and output ", + "format. The component accepts text input and produces audio output suitable for ", + "downstream audio processing or playback nodes." + ], + // + // Optional: + // The icon is the icon to display in the UI for this node + // + "icon": "tts-openai.svg", + // + // Optional: + // Rendering hints to the UI which indicate which fields of + // the configuration should be used to display information + // + "tile": [ + "Voice: ${parameters.tts_openai.voice}" + ], + // + // Optional: + // As a pipe component, define what this pipe component takes + // and what it produces + // + "lanes": { + "text": ["audio"] + }, + "input": [ + { + "lane": "text", + "output": [ + { + "lane": "audio" + } + ] + } + ], + // + // Optional: + // Profile sections are configuration options used by the driver + // itself + // + "preconfig": { + // Define the values that will be merged into any profile configuration + // specified, unless the profile is 'absolute' + "default": "tts-1", + // Defines profiles used with the "profile": key + "profiles": { + "tts-1": { + "title": "TTS-1", + "model": "tts-1", + "voice": "alloy", + "speed": 1.0, + "response_format": "mp3", + "apikey": "" + }, + "tts-1-hd": { + "title": "TTS-1 HD", + "model": "tts-1-hd", + "voice": "alloy", + "speed": 1.0, + "response_format": "mp3", + "apikey": "" + } + } + }, + // + // Optional: + // Local fields definitions - these define fields only for the + // current service. 
You may specify them here, or directly + // in the shape + // + "fields": { + "tts_openai.voice": { + "type": "string", + "title": "Voice", + "description": "The voice to use for speech synthesis", + "enum": [ + ["alloy", "Alloy - Neutral and balanced"], + ["echo", "Echo - Warm and rounded"], + ["fable", "Fable - Expressive and dramatic"], + ["onyx", "Onyx - Deep and authoritative"], + ["nova", "Nova - Friendly and upbeat"], + ["shimmer", "Shimmer - Clear and bright"] + ], + "default": "alloy" + }, + "tts_openai.speed": { + "type": "number", + "title": "Speed", + "description": "The speed of the generated audio (0.25 to 4.0)", + "default": 1.0 + }, + "tts_openai.response_format": { + "type": "string", + "title": "Output Format", + "description": "The audio format for the generated speech", + "enum": [ + ["mp3", "MP3 - Compressed audio"], + ["opus", "Opus - Low latency streaming"], + ["aac", "AAC - Digital audio compression"], + ["flac", "FLAC - Lossless audio"], + ["wav", "WAV - Uncompressed audio"], + ["pcm", "PCM - Raw audio samples"] + ], + "default": "mp3" + }, + "tts_openai.tts-1": { + "object": "tts-1", + "properties": [ + "llm.cloud.apikey", + "tts_openai.voice", + "tts_openai.speed", + "tts_openai.response_format" + ] + }, + "tts_openai.tts-1-hd": { + "object": "tts-1-hd", + "properties": [ + "llm.cloud.apikey", + "tts_openai.voice", + "tts_openai.speed", + "tts_openai.response_format" + ] + }, + "tts_openai.profile": { + "title": "Model", + "description": "TTS model", + "type": "string", + "default": "tts-1", + "enum": [ + "*>preconfig.profiles.*.title" + ], + "conditional": [ + { + "value": "tts-1", + "properties": [ + "tts_openai.tts-1" + ] + }, + { + "value": "tts-1-hd", + "properties": [ + "tts_openai.tts-1-hd" + ] + } + ] + } + }, + // + // Required: + // Defines the fields (shape) of the service. 
Either source or target + // may be specified, or both, but at least one is required + // + "shape": [ + { + "section": "Pipe", + "title": "Text to Speech", + "properties": [ + "tts_openai.profile" + ] + } + ], + "test": { + "profiles": ["tts-1"], + "outputs": ["audio"], + "cases": [ + { + "name": "TTS generates audio from text", + "text": "Hello, this is a test of text to speech.", + "expect": { + "audio": { + "notEmpty": true + } + } + } + ] + } +} diff --git a/nodes/test/test_ssrf_protection.py b/nodes/test/test_ssrf_protection.py new file mode 100644 index 000000000..238acb4f8 --- /dev/null +++ b/nodes/test/test_ssrf_protection.py @@ -0,0 +1,304 @@ +# ============================================================================= +# MIT License +# Copyright (c) 2024 RocketRide Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE.
+# ============================================================================= + +"""Tests for the SSRF protection module.""" + +from __future__ import annotations + +import importlib.util +import ipaddress +import os +import socket +import sys +from typing import List, Tuple +from unittest.mock import patch + +import pytest + +# --------------------------------------------------------------------------- +# Load ssrf_protection directly from the file to avoid pulling in the +# ``library`` package's __init__.py which depends on heavy runtime modules. +# --------------------------------------------------------------------------- + +_MOD_PATH = os.path.join(os.path.dirname(__file__), '..', 'src', 'nodes', 'library', 'ssrf_protection.py') + +_spec = importlib.util.spec_from_file_location('ssrf_protection', _MOD_PATH) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) +sys.modules['ssrf_protection'] = _mod + +SSRFError = _mod.SSRFError +_build_allowlist = _mod._build_allowlist +_is_blocked = _mod._is_blocked +validate_url = _mod.validate_url + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _fake_getaddrinfo(ip: str): + """Return a patched getaddrinfo that always resolves to *ip*.""" + + def _patched(host, port, **_kw): + family = socket.AF_INET6 if ':' in ip else socket.AF_INET + return [(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, '', (ip, port))] + + return _patched + + +def _fake_getaddrinfo_multi(ips: List[Tuple[str, int]]): + """Return a patched getaddrinfo that resolves to multiple addresses.""" + + def _patched(host, port, **_kw): + results = [] + for ip, family in ips: + results.append((family, socket.SOCK_STREAM, socket.IPPROTO_TCP, '', (ip, port))) + return results + + return _patched + + +# --------------------------------------------------------------------------- +# Tests: blocked IP ranges +# 
--------------------------------------------------------------------------- + + +class TestBlockedIPv4: + """Private / reserved IPv4 addresses must be blocked.""" + + @pytest.mark.parametrize( + 'ip', + [ + '127.0.0.1', # loopback + '127.0.0.2', # loopback range + '10.0.0.1', # RFC 1918 + '10.255.255.255', # RFC 1918 top + '172.16.0.1', # RFC 1918 + '172.31.255.255', # RFC 1918 top + '192.168.0.1', # RFC 1918 + '192.168.255.255', # RFC 1918 top + '169.254.169.254', # cloud metadata endpoint + '169.254.0.1', # link-local + '0.0.0.0', # this host + ], + ) + def test_blocked_ipv4(self, ip): + with patch('ssrf_protection.socket.getaddrinfo', _fake_getaddrinfo(ip)): + with pytest.raises(SSRFError, match='private/reserved range'): + validate_url(f'http://example.com/{ip}') + + @pytest.mark.parametrize( + 'ip', + [ + '127.0.0.1', + '10.0.0.1', + '172.16.0.1', + '192.168.0.1', + '169.254.169.254', + '0.0.0.0', + ], + ) + def test_blocked_ip_literal(self, ip): + """Direct IP literals in the URL are also blocked.""" + with pytest.raises(SSRFError, match='private/reserved range'): + validate_url(f'http://{ip}/path') + + +class TestBlockedIPv6: + """Private / reserved IPv6 addresses must be blocked.""" + + @pytest.mark.parametrize( + 'ip', + [ + '::1', # loopback + 'fc00::1', # unique local + 'fd12:3456::1', # unique local + 'fe80::1', # link-local + ], + ) + def test_blocked_ipv6(self, ip): + with patch('ssrf_protection.socket.getaddrinfo', _fake_getaddrinfo(ip)): + with pytest.raises(SSRFError, match='private/reserved range'): + validate_url('http://example.com/') + + +# --------------------------------------------------------------------------- +# Tests: allowed public IPs +# --------------------------------------------------------------------------- + + +class TestAllowedPublic: + """Public IP addresses must pass validation.""" + + @pytest.mark.parametrize( + 'ip', + [ + '8.8.8.8', # Google DNS + '1.1.1.1', # Cloudflare DNS + '93.184.216.34', # example.com + 
'151.101.1.140', # a CDN address + ], + ) + def test_public_ip_allowed(self, ip): + with patch('ssrf_protection.socket.getaddrinfo', _fake_getaddrinfo(ip)): + result = validate_url('http://example.com/') + assert result == 'http://example.com/' + + +# --------------------------------------------------------------------------- +# Tests: scheme validation +# --------------------------------------------------------------------------- + + +class TestSchemeValidation: + """Only http and https schemes are allowed.""" + + def test_http_allowed(self): + with patch('ssrf_protection.socket.getaddrinfo', _fake_getaddrinfo('8.8.8.8')): + validate_url('http://example.com/') + + def test_https_allowed(self): + with patch('ssrf_protection.socket.getaddrinfo', _fake_getaddrinfo('8.8.8.8')): + validate_url('https://example.com/') + + @pytest.mark.parametrize( + 'url', + [ + 'ftp://example.com/', + 'file:///etc/passwd', + 'gopher://example.com/', + 'dict://example.com/', + ], + ) + def test_disallowed_scheme(self, url): + with pytest.raises(SSRFError, match='scheme.*is not allowed'): + validate_url(url) + + +# --------------------------------------------------------------------------- +# Tests: hostname validation +# --------------------------------------------------------------------------- + + +class TestHostnameValidation: + """Blocked hostnames must be rejected.""" + + def test_localhost_blocked(self): + with pytest.raises(SSRFError, match='hostname.*blocked'): + validate_url('http://localhost/path') + + def test_metadata_google_blocked(self): + with pytest.raises(SSRFError, match='hostname.*blocked'): + validate_url('http://metadata.google.internal/computeMetadata/v1/') + + def test_empty_hostname(self): + with pytest.raises(SSRFError, match='no hostname'): + validate_url('http:///path') + + +# --------------------------------------------------------------------------- +# Tests: allowlist +# --------------------------------------------------------------------------- + + +class 
TestAllowlist: + """The allowlist should permit specific private ranges.""" + + def test_allowlist_permits_specific_ip(self): + with patch('ssrf_protection.socket.getaddrinfo', _fake_getaddrinfo('192.168.1.100')): + result = validate_url( + 'http://internal-api.local/', + allowed_private=['192.168.1.0/24'], + ) + assert result == 'http://internal-api.local/' + + def test_allowlist_does_not_permit_other_range(self): + with patch('ssrf_protection.socket.getaddrinfo', _fake_getaddrinfo('10.0.0.1')): + with pytest.raises(SSRFError, match='private/reserved range'): + validate_url( + 'http://internal-api.local/', + allowed_private=['192.168.1.0/24'], + ) + + def test_env_allowlist(self): + with patch.dict(os.environ, {'ROCKETRIDE_SSRF_ALLOWLIST': '10.0.0.0/8'}): + with patch('ssrf_protection.socket.getaddrinfo', _fake_getaddrinfo('10.0.0.5')): + result = validate_url('http://internal.corp/') + assert result == 'http://internal.corp/' + + def test_env_allowlist_multiple(self): + with patch.dict(os.environ, {'ROCKETRIDE_SSRF_ALLOWLIST': '10.0.0.0/8, 172.16.0.0/12'}): + with patch('ssrf_protection.socket.getaddrinfo', _fake_getaddrinfo('172.16.5.1')): + result = validate_url('http://internal.corp/') + assert result == 'http://internal.corp/' + + +# --------------------------------------------------------------------------- +# Tests: DNS rebinding prevention +# --------------------------------------------------------------------------- + + +class TestDNSRebinding: + """DNS resolution must happen before connecting.""" + + def test_hostname_resolving_to_private_ip_blocked(self): + """A public-looking hostname that resolves to a private IP is blocked.""" + with patch('ssrf_protection.socket.getaddrinfo', _fake_getaddrinfo('169.254.169.254')): + with pytest.raises(SSRFError, match='private/reserved range'): + validate_url('http://attacker-dns-rebind.evil.com/') + + def test_unresolvable_hostname(self): + """Hostnames that fail DNS resolution must raise SSRFError.""" + with 
patch( + 'ssrf_protection.socket.getaddrinfo', + side_effect=socket.gaierror('Name or service not known'), + ): + with pytest.raises(SSRFError, match='cannot resolve hostname'): + validate_url('http://nonexistent.invalid/') + + +# --------------------------------------------------------------------------- +# Tests: internal helpers +# --------------------------------------------------------------------------- + + +class TestInternalHelpers: + """Coverage for internal helper functions.""" + + def test_is_blocked_public(self): + assert _is_blocked(ipaddress.ip_address('8.8.8.8')) is False + + def test_is_blocked_private(self): + assert _is_blocked(ipaddress.ip_address('10.0.0.1')) is True + + def test_build_allowlist_empty(self): + with patch.dict(os.environ, {}, clear=True): + nets = _build_allowlist(None) + assert nets == [] + + def test_build_allowlist_malformed_ignored(self): + nets = _build_allowlist(['not-a-cidr', '10.0.0.0/8']) + assert len(nets) == 1 + assert str(nets[0]) == '10.0.0.0/8' diff --git a/packages/shared-ui/src/assets/nodes/tts-openai.svg b/packages/shared-ui/src/assets/nodes/tts-openai.svg new file mode 100644 index 000000000..97bbcc135 --- /dev/null +++ b/packages/shared-ui/src/assets/nodes/tts-openai.svg @@ -0,0 +1 @@ +