diff --git a/CHANGELOG.md b/CHANGELOG.md index b02c9da1..71d55687 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,120 @@ All notable behavioural changes to `strands-robots` are logged here. Follows [Keep a Changelog](https://keepachangelog.com/) conventions. +## Unreleased - #196 (gr00t_inference validation hardening, supersedes #90) + +### Added + +- ``validate_inputs()`` centralises all parameter validation with action-aware + scoping: read-only actions (``find_containers``, ``list``, ``status``, + ``stop``) only validate ``port``/``host``/``protocol``; mutating actions + (``start``, ``restart``, ``lifecycle``) validate the full parameter surface. +- ``host`` parameter now accepts both IP addresses and RFC-952 hostnames + (e.g. ``localhost``, ``host.docker.internal``). All-numeric strings that + fail ``ipaddress.ip_address()`` (e.g. ``127.0.01``, ``999.999.999.999``) + are rejected as obvious IP typos. +- ``protocol`` validation moved into ``validate_inputs()`` (previously + hand-rolled outside the helper, breaking the single-entry-point contract). +- ``pgrep`` patterns now match both ``--port N`` and ``--port=N`` forms, + preventing silent miss when the service is started with the ``=`` syntax. +- Integration tests that invoke ``gr00t_inference()`` end-to-end and assert + that invalid inputs are caught (pins the ``try/except ValueError`` wiring). +- End-to-end regression test for ``_stop_service`` cross-port-kill scenario: + verifies that a process on port 8000 is NOT killed when stopping port 80. +- Exception clauses narrowed throughout: ``_is_gr00t_process`` / ``_is_gr00t_host_process`` + use ``(OSError, subprocess.SubprocessError, UnicodeDecodeError)``; + ``_list_running_services``/``_is_service_running`` use ``OSError``; + ``_stop_service`` uses ``(OSError, subprocess.SubprocessError)``; + ``_start_service`` uses ``(OSError, RuntimeError)``. + Only ``_download_checkpoint`` retains ``except Exception`` (``# noqa: BLE001``) + because huggingface_hub raises varied, opaque exception types. +- ``action`` parameter validated against a complete allowlist of 10 valid + actions; unknown actions get a clear error with the valid set listed. +- ``image_name``, ``volumes``, and ``container_command`` parameters are now + validated (Docker image reference, path traversal, shell metacharacters). +- ``pgrep`` pattern factored into ``_PGREP_INFERENCE_PORT_FMT`` module-level + constant — single source of truth across all 4 usage sites. +- ``_PGREP_INFERENCE_PORT_FMT`` and ``_is_gr00t_*_process`` now match both + N1.5/N1.6 (``inference_service.py``) and N1.7 (``gr00t.eval.run_gr00t_server``) + entry-points. Closes the N1.7 stop/status identification gap. +- All exception types in process-probe helpers now log (WARNING for + PermissionError, DEBUG for other OSError/SubprocessError/UnicodeDecodeError). + +### Changed + +- ``validate_inputs()`` parameters are now all required (no defaults). + ``gr00t_inference()`` is the single source of truth for default values; + the validator no longer duplicates them (prevents silent drift). +- ``_DOCKER_IMAGE_RE`` extended to support private-registry references with + port numbers (e.g. ``localhost:5000/myorg/img:tag``). +- ``_is_gr00t_process`` / ``_is_gr00t_host_process`` now require ``--port`` + in the process cmdline to match — prevents false-killing editors or + log-tailers that happen to touch ``inference_service.py``. +- ``PermissionError`` in process probes now logs at WARNING level instead + of being silently swallowed. +- **BREAKING** Default ``host`` changed from ``0.0.0.0`` to ``127.0.0.1`` + (loopback-only, per AGENTS.md > Review Learnings #86 > "Safety Defaults"). + The ``host`` kwarg now exclusively controls the docker host-side bind via + ``-p {host}:{port}:{port}`` — no silent rewrite. The inference server + inside the container is **always** invoked with ``--host 0.0.0.0`` + (required by docker port forwarding; binding to container-loopback would + make the published port unreachable). User intent is honoured at the + docker layer: + - ``host="127.0.0.1"`` (default) → published port reachable on loopback only + - ``host="0.0.0.0"`` (explicit) → published port reachable on every host iface + **Migration:** if your downstream connects from a different host on the + same machine or across the network, pass ``host="0.0.0.0"`` explicitly + on the ``start`` / ``restart`` / ``start_container`` / ``lifecycle`` calls. +- Host-system fallback (``pgrep``) is documented as Linux-only. Non-Linux + platforms will see "No service running" rather than silently succeeding. + +### Fixed + +- ``hf_repo``, ``hf_subfolder``, ``hf_local_dir`` validation now runs for + ``action='download_checkpoint'`` (R3 introduced these checks but placed + them after the ``_image_only_actions`` early-return, silently bypassing + the path-traversal guard for the action that actually consumes those + parameters; R4 hoists them above all action-specific gates). +- Inference server inside the container is now hardcoded to ``--host 0.0.0.0`` + (R3 forwarded the user's ``host`` kwarg verbatim, so ``host="127.0.0.1"`` + bound the service to container-loopback and the docker port-publish + forwarded to nothing — the headline loopback-default contract was + unreachable end-to-end). +- Duplicate ``torch_mock.manual_seed`` assignment in ``tests/mocks/torch_mock.py``. +- Option-injection guard: ``repo_url``, ``repo_tag``, ``policy_name`` starting + with ``-`` are rejected (prevents git/docker flag injection via subprocess argv). +- Host-system fallback (pgrep) now returns a clear error on non-Linux platforms + instead of silently reporting success. +- All-numeric hostname guard narrowed to multi-label patterns only — single-label + numerics (e.g. ``123``) are valid per RFC-1123. +- ``port`` validator rejects ``bool`` explicitly. ``isinstance(True, int) is True`` + in Python (bool subclasses int) and ``1 <= True <= 65535`` evaluates ``True`` -- + pre-fix, ``port=True`` passed validation and reached ``--port`` argv as the + string ``"True"`` (R6 review thread, ``gr00t_inference.py:223``). Pinned by + ``TestPortBoolRejected``. +- ``hf_repo`` segment validation rejects any segment starting with ``.`` + (catches ``.org/name``, ``org/.git``, ``...../name``, ``org/.name``). The + base regex ``[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+`` plus the existing segment + loop only rejected bare ``.`` / ``..`` segments and leading ``-`` -- the + validator now fails closed locally rather than relying on HuggingFace's + API to reject (R6 review thread, ``gr00t_inference.py:253``). Pinned by + ``TestHfRepoLeadingDotSegments``. + +### Notes + +- Host validation is **broader** than before for hostnames (RFC-952 names like + ``localhost`` and ``host.docker.internal`` now pass) but **stricter** for + IP-like typos (all-numeric labels like ``127.0.01`` are rejected). +- Validation scope covers ``port``, ``host``, ``protocol``, ``data_config``, + ``embodiment_tag``, ``container_name``, TRT dtypes, ``checkpoint_path``, + ``trt_engine_path``, ``image_name``, ``volumes``, and ``container_command``. + Parameters ``repo_url``, ``repo_tag``, and ``policy_name`` are + option-injection-guarded (reject leading ``-``). ``hf_repo``, + ``hf_subfolder``, and ``hf_local_dir`` are path-validated (reject + traversal, shell metacharacters, and malformed repo IDs). The + ``lifecycle`` phase is enum-checked. + + ## Unreleased - #178 (LiberoOffScreenRenderEngine retired) ### Removed: ``LiberoOffScreenRenderEngine`` simulation backend (BREAKING) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 8bfd9c6f..3b92412e 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -11,9 +11,12 @@ from a single prompt - see #148 for the motivation. """ +import ipaddress import os +import re import socket import subprocess +import sys import time from pathlib import Path from typing import Any @@ -41,6 +44,321 @@ def _checkpoints_dir() -> Path: return get_base_dir() / "checkpoints" +# ───────────────────────────────────────────────────────────────────── +# Input validation helpers +# ───────────────────────────────────────────────────────────────────── + +# Docker image reference pattern — supports registry:port/path:tag and @sha256:digest. +# Examples: "gr00t:latest", "nvcr.io/nvidia/gr00t:n1.7", "localhost:5000/myorg/img:tag" +# "nvcr.io/nvidia/gr00t@sha256:abcdef..." (digest-pinned, supply-chain recommended) +_DOCKER_IMAGE_RE = re.compile( + r"^[a-zA-Z0-9]" # must start with alnum + r"(?:[a-zA-Z0-9._\-]*[a-zA-Z0-9])?" # optional middle chars (host/path prefix) + r"(?::([0-9]{1,5})(?=/))?" + # registry port (:5000) - captured for range check ONLY when followed by /path. + # Without the lookahead, name:digits is ambiguous between host:port and name:tag; + # e.g. "myimage:99999" would be falsely rejected as an invalid port. + r"(?:/[a-zA-Z0-9][a-zA-Z0-9._\-]*)*" # path components (/org/img) + r"(?::[a-zA-Z0-9][a-zA-Z0-9._\-]*" # option A: :tag + r"|@sha256:[a-f0-9]{64})?$" # option B: @sha256:digest (mutually exclusive with tag) +) + + +def _is_valid_docker_image_ref(value: str) -> bool: + """Validate a docker image reference for shape AND range correctness. + + The regex captures the (optional) registry port so we can verify it is + a valid TCP port (1-65535). The regex alone permits digit counts up to 5, + which would otherwise accept refs like ``localhost:99999/img:tag`` even + though no such port can exist. + """ + m = _DOCKER_IMAGE_RE.match(value) + if not m: + return False + port_str = m.group(1) + if port_str is not None: + port_int = int(port_str) + if port_int < 1 or port_int > 65535: + return False + return True + + +# Characters that cause harm in subprocess argv or shell interpolation. +# Narrowed per AGENTS.md review-learnings: quotes/bangs/parens/brackets +# appear in legitimate filesystem paths and all subprocess calls here are +# argv-style (no shell=True), so they pose no injection risk in path values. +# Backslash (\) is also legal on Linux (only / and NUL are forbidden by POSIX) +# and carries no special meaning in argv-style subprocess calls. +_SHELL_META = re.compile(r"[;&|`$<>\n\r\x00]") + +# Strict patterns for enumerable parameters. +_DATA_CONFIG_RE = re.compile(r"^[a-z][a-z0-9_]{0,63}$") +_EMBODIMENT_TAG_RE = re.compile(r"^[a-z][a-z0-9_]{0,31}$") +_CONTAINER_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}$") +# RFC-952/1123 hostname pattern for host validation. +# Trailing dot is accepted per RFC 1034 §3.1 (FQDNs may end with a dot to +# disambiguate from search-list completion). +_HOSTNAME_RE = re.compile( + r"^[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?" + r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*" + r"\.?$" # optional trailing dot (FQDN root indicator) +) +# Reject multi-label all-numeric strings — prevents typos like "127.0.01" +# which pass _HOSTNAME_RE but are clearly malformed IP attempts, not hostnames. +# Single-label numerics (e.g. "123") are valid per RFC-1123. +_ALL_NUMERIC_RE = re.compile(r"^[0-9]+(?:\.[0-9]+)+$") + +# Factored pgrep pattern — single source of truth for both docker-exec and +# host-fallback discovery paths. ERE syntax (procps-ng on Linux). +# NOTE: _PGREP_INFERENCE_PORT_FMT uses `( |$)` (ERE, space-only boundary) while +# _PYTHON_PORT_RE_FMT uses `(?:\s|$)` (Python re, any-whitespace boundary). +# This is intentional: pgrep is constrained to ERE, and cmdlines are space-separated +# in practice (procps-ng converts NUL → space when reading /proc/*/cmdline). +# Matches both N1.5/N1.6 (inference_service.py) and N1.7 (gr00t.eval.run_gr00t_server) +_PGREP_INFERENCE_PORT_FMT = r"(inference_service\.py|gr00t\.eval\.run_gr00t_server).*--port[= ]{port}( |$)" +# Python-side equivalent for re.search — uses (?:\s|$) instead of ( |$) +# because Python re is always ERE-ish and \s is more precise. +_PYTHON_PORT_RE_FMT = r"--port[= ]{port}(?:\s|$)" + +# Allowlists for TensorRT dtype parameters. +_VALID_VIT_DTYPES = {"fp16", "fp8"} +_VALID_LLM_DTYPES = {"fp16", "nvfp4", "fp8"} +_VALID_DIT_DTYPES = {"fp16", "fp8"} + +# Complete allowlist of valid actions for the tool. +_VALID_ACTIONS = frozenset( + { + "find_containers", + "list", + "status", + "stop", + "start", + "restart", + "build_image", + "download_checkpoint", + "start_container", + "lifecycle", + } +) + + +def _validate_path(value: str, label: str, *, reject_colon: bool = False) -> None: + """Reject paths containing shell metacharacters, null bytes, or traversal sequences. + + Args: + value: The path string to validate. + label: Human-readable label for error messages. + reject_colon: When True, reject ':' in the value. Required for Docker + volume mount paths where ':' would be re-interpreted as + host:container:options separator by docker -v. + """ + if "\x00" in value: + raise ValueError(f"{label} must not contain null bytes") + if value.startswith("-"): + raise ValueError(f"{label} must not start with '-' (got {value!r})") + # Split on '/' only. POSIX permits '\\' as a literal filename byte; including it + # in the split would falsely reject paths like 'a\\..\\b' where '..' is not a + # separated component. docker -v interprets only '/' as a separator on Linux, + # which is the only platform this tool supports. + if any(part == ".." for part in value.split("/")): + raise ValueError(f"{label} must not contain '..' path traversal components") + if _SHELL_META.search(value): + raise ValueError(f"{label} contains disallowed characters: {value!r}") + if reject_colon and ":" in value: + raise ValueError( + f"{label} must not contain ':' (docker -v interprets it as host:container:options separator; got {value!r})" + ) + + +def validate_inputs( + *, + action: str, + data_config: str, + embodiment_tag: str, + port: int, + host: str, + vit_dtype: str, + llm_dtype: str, + dit_dtype: str, + checkpoint_path: str | None, + trt_engine_path: str, + container_name: str | None, + protocol: str, + image_name: str | None = None, + volumes: dict[str, str] | None = None, + container_command: str | None = None, + repo_url: str | None = None, + repo_tag: str | None = None, + policy_name: str | None = None, + hf_repo: str | None = None, + hf_subfolder: str | None = None, + hf_local_dir: str | None = None, + lifecycle: str | None = None, +) -> None: + """Validate all user-supplied parameters in one place. + + Raises ValueError for any invalid input. Callers exposing this through + an AgentTool MUST wrap in try/except and convert to the structured error + dict (``{"status": "error", "message": str(e)}``). + + This centralises validation so that the main tool function stays focused + on orchestration and each check is independently testable via this + single entry-point. + + Validation is scoped to the action: actions whose only user-controlled + surface is port/host/protocol (find_containers, list, status, stop) + skip full parameter validation; mutating actions (start, restart, + lifecycle, build_image, download_checkpoint, start_container) validate + the full parameter surface. + """ + # Action allowlist — reject unknown actions early with a clear error + if action not in _VALID_ACTIONS: + raise ValueError(f"Unknown action {action!r}. Valid actions: {sorted(_VALID_ACTIONS)}") + + # Protocol — always validated regardless of action + valid_protocols = ("n1.5", "n1.6", "n1.7") + if protocol not in valid_protocols: + raise ValueError(f"Unknown protocol {protocol!r}. Valid: {list(valid_protocols)}") + # Port range — always validated. Type-check first so callers get ValueError, not TypeError. + # Reject bool explicitly: isinstance(True, int) is True in Python (bool subclasses int), + # and 1 <= True <= 65535 evaluates True. Without this guard, port=True reaches + # --port argv as the string "True" — a subtle failure mode an LLM caller could trip on. + if isinstance(port, bool) or not isinstance(port, int): + raise ValueError(f"port must be an integer, got {type(port).__name__}: {port!r}") + if not (1 <= port <= 65535): + raise ValueError(f"port must be between 1 and 65535, got {port}") + + # Host address validation — always validated (accept IPs and RFC-952 hostnames) + if not isinstance(host, str): + raise ValueError(f"host must be a string, got {type(host).__name__}: {host!r}") + # RFC 1035 §2.3.4: total hostname must not exceed 253 octets. + if len(host) > 253: + raise ValueError(f"host exceeds RFC 1035 maximum length of 253 chars (got {len(host)} chars)") + try: + ipaddress.ip_address(host) + except ValueError: + # Reject all-numeric labels (e.g. "127.0.01") — these are clearly IP typos + # not legitimate hostnames. Real hostnames must have at least one alpha label. + # Rejects all-numeric multi-label strings including Linux IPv4 short-forms + # like "127.1" — use canonical dotted-quad for clarity in agent-driven contexts. + if _ALL_NUMERIC_RE.match(host) or not _HOSTNAME_RE.match(host): + raise ValueError( + f"host must be a valid IP address or hostname (got {host!r}). " + f"Use '127.0.0.1' for loopback, '0.0.0.0' for all interfaces, " + f"or a valid hostname like 'localhost'." + ) from None + + # HuggingFace parameters & lifecycle phase — action-independent format/path + # validation. Validated BEFORE action-specific early-returns so that + # download_checkpoint (which consumes hf_repo/hf_subfolder/hf_local_dir) + # cannot bypass the path-traversal check. + if hf_repo is not None: + if not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", hf_repo): + raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") + # Segment-level checks: leading '-' (option-injection-like) and bare + # './..' segments are syntactically allowed by the regex above but + # are not legal HF repo ids and must be rejected at the validator. + for _seg in hf_repo.split("/"): + if _seg.startswith("-"): + raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") + # Reject any segment that starts with '.' — catches '.', '..', '.org', 'org/.git', + # '...../x', etc. HuggingFace's API rejects leading-dot segments; the validator's + # job is to fail closed locally rather than rely on a downstream service + # (per AGENTS.md > LLM Input Safety). Pinned by TestHfRepoSegmentRejection. + if _seg.startswith("."): + raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") + if hf_subfolder is not None: + _validate_path(hf_subfolder, "hf_subfolder") + if hf_local_dir is not None: + _validate_path(hf_local_dir, "hf_local_dir") + if lifecycle is not None: + valid_phases = ("full", "teardown") + if lifecycle not in valid_phases: + raise ValueError(f"lifecycle must be one of {valid_phases}, got {lifecycle!r}") + + # Port-only actions (find_containers, list, status, stop) only need + # port/host/protocol validation — the other params are unused by dispatch. + _port_only_actions = ("find_containers", "list", "status", "stop") + if action in _port_only_actions: + return + + # Image/download actions only consume image_name, paths, and volumes — not + # inference-time params (data_config, embodiment_tag, dtypes). + _image_only_actions = ("build_image", "download_checkpoint", "start_container") + if action in _image_only_actions: + # Validate container_name (used by start_container, interpolated into docker run --name) + if container_name is not None and not _CONTAINER_NAME_RE.match(container_name): + raise ValueError(f"container_name must match Docker naming rules (got {container_name!r})") + # Validate image_name, volumes, container_command (relevant to these actions) + if image_name is not None and not _is_valid_docker_image_ref(image_name): + raise ValueError(f"image_name must be a valid Docker image reference (got {image_name!r})") + if volumes is not None: + for vol_host, vol_container in volumes.items(): + _validate_path(vol_host, "volumes key (host path)", reject_colon=True) + _validate_path(vol_container, "volumes value (container path)", reject_colon=True) + if container_command is not None and _SHELL_META.search(container_command): + raise ValueError(f"container_command contains disallowed characters: {container_command!r}") + if checkpoint_path is not None: + _validate_path(checkpoint_path, "checkpoint_path") + # Option-injection guard for params used by these actions + for param_name, param_value in [("repo_url", repo_url), ("repo_tag", repo_tag), ("policy_name", policy_name)]: + if param_value is not None and param_value.startswith("-"): + raise ValueError(f"{param_name} must not start with '-' (got {param_value!r})") + return + + # ── Full validation for inference-mutating actions (start, restart, lifecycle) ── + + # Enumerable string parameters + if not _DATA_CONFIG_RE.match(data_config): + raise ValueError( + f"data_config must be lowercase alphanumeric/underscore (got {data_config!r}). " + f"See the tool docstring for the full list of accepted configs." + ) + if not _EMBODIMENT_TAG_RE.match(embodiment_tag): + raise ValueError(f"embodiment_tag must be lowercase alphanumeric/underscore (got {embodiment_tag!r})") + + # Docker container name + if container_name is not None and not _CONTAINER_NAME_RE.match(container_name): + raise ValueError(f"container_name must match Docker naming rules (got {container_name!r})") + + # Filesystem paths — reject shell metacharacters and traversal + if checkpoint_path is not None: + _validate_path(checkpoint_path, "checkpoint_path") + _validate_path(trt_engine_path, "trt_engine_path") + + # TensorRT dtype allowlists + if vit_dtype not in _VALID_VIT_DTYPES: + raise ValueError(f"vit_dtype must be one of {_VALID_VIT_DTYPES}, got {vit_dtype!r}") + if llm_dtype not in _VALID_LLM_DTYPES: + raise ValueError(f"llm_dtype must be one of {_VALID_LLM_DTYPES}, got {llm_dtype!r}") + if dit_dtype not in _VALID_DIT_DTYPES: + raise ValueError(f"dit_dtype must be one of {_VALID_DIT_DTYPES}, got {dit_dtype!r}") + + # Docker image reference (if provided via kwargs) + if image_name is not None and not _is_valid_docker_image_ref(image_name): + raise ValueError(f"image_name must be a valid Docker image reference (got {image_name!r})") + + # Volume paths validation + if volumes is not None: + for vol_host, vol_container in volumes.items(): + _validate_path(vol_host, "volumes key (host path)", reject_colon=True) + _validate_path(vol_container, "volumes value (container path)", reject_colon=True) + + # Container command — reject shell metacharacters + if container_command is not None and _SHELL_META.search(container_command): + raise ValueError(f"container_command contains disallowed characters: {container_command!r}") + + # Option-injection guard: reject LLM-controlled values starting with '-' + # which could be parsed as flags by git/docker/pgrep in subprocess argv. + for param_name, param_value in [ + ("repo_url", repo_url), + ("repo_tag", repo_tag), + ("policy_name", policy_name), + ]: + if param_value is not None and param_value.startswith("-"): + raise ValueError(f"{param_name} must not start with '-' (got {param_value!r})") + + @tool def gr00t_inference( action: str, @@ -50,7 +368,7 @@ def gr00t_inference( data_config: str = "fourier_gr1_arms_only", embodiment_tag: str = "gr1", denoising_steps: int = 4, - host: str = "0.0.0.0", + host: str | None = None, container_name: str | None = None, timeout: int = 60, use_tensorrt: bool = False, @@ -192,7 +510,14 @@ def gr00t_inference( ``libero_sim``). denoising_steps: Number of denoising steps for action generation (default: 4). N1.5/N1.6 only - the N1.7 server reads this from the checkpoint. - host: Host address to bind the service to (default: ``0.0.0.0``). + host: Docker host-side bind address used in ``-p {host}:{port}:{port}`` + on ``start_container`` / ``lifecycle`` (default: ``127.0.0.1``, + loopback only). Pass ``host="0.0.0.0"`` to expose the published + port on all host interfaces. The inference server inside the + container always binds ``0.0.0.0`` (required by docker port + forwarding); this kwarg has no effect on ``start`` / ``restart`` + against an already-running container, where the host-side bind + was fixed when the container was originally launched. container_name: Specific Docker container name. Auto-detected if omitted. timeout: Seconds to wait for service startup (default: 60). use_tensorrt: Enable TensorRT acceleration (default: False). @@ -297,14 +622,41 @@ def gr00t_inference( if api_token is None: api_token = os.environ.get("GROOT_API_TOKEN") - # Validate protocol up-front so users get a friendly error rather than - # an opaque docker-exec failure inside _start_service. - valid_protocols = ("n1.5", "n1.6", "n1.7") - if protocol not in valid_protocols: - return { - "status": "error", - "message": f"Unknown protocol {protocol!r}. Valid: {list(valid_protocols)}", - } + # Sentinel default: None means "user did not pass host=". + # Default to 127.0.0.1 (loopback, per AGENTS.md § LLM Input Safety). + # The host kwarg now flows verbatim into `docker -p HOST:port:port` + # (no auto-flip; see commit ecf5f0f). + if host is None: + host = "127.0.0.1" + + # ── Validate all inputs in one call (scoped per action) ───────── + try: + validate_inputs( + action=action, + data_config=data_config, + embodiment_tag=embodiment_tag, + port=port, + host=host, + vit_dtype=vit_dtype, + llm_dtype=llm_dtype, + dit_dtype=dit_dtype, + checkpoint_path=checkpoint_path, + trt_engine_path=trt_engine_path, + container_name=container_name, + protocol=protocol, + image_name=image_name, + volumes=volumes, + container_command=container_command, + repo_url=repo_url, + repo_tag=repo_tag, + policy_name=policy_name, + hf_repo=hf_repo, + hf_subfolder=hf_subfolder, + hf_local_dir=hf_local_dir, + lifecycle=lifecycle if action == "lifecycle" else None, + ) + except ValueError as e: + return {"status": "error", "message": str(e)} if action == "find_containers": return _find_gr00t_containers() @@ -342,6 +694,7 @@ def gr00t_inference( container_command=container_command, hf_local_dir=hf_local_dir, force=force, + host=host, ) elif action == "lifecycle": return _lifecycle( @@ -431,8 +784,9 @@ def gr00t_inference( protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, ) - else: - return {"status": "error", "message": f"Unknown action: {action}"} + + # Unreachable: validate_inputs() rejects unknown actions before dispatch. + return {"status": "error", "message": f"Unknown action: {action}"} # pragma: no cover def _find_gr00t_containers() -> dict[str, Any]: @@ -479,7 +833,7 @@ def _list_running_services() -> dict[str, Any]: return {"status": "success", "services": services, "message": f"Found {len(services)} running services"} - except Exception as e: + except OSError as e: return {"status": "error", "message": f"Failed to list services: {e}"} @@ -491,7 +845,7 @@ def _is_service_running(port: int) -> bool: result = sock.connect_ex(("localhost", port)) sock.close() return result == 0 - except Exception: + except OSError: return False @@ -509,6 +863,88 @@ def _check_service_status(port: int) -> dict[str, Any]: } +def _is_gr00t_process(container_name: str, pid: str, *, port: int | None = None) -> bool: + """Verify that a PID inside a container belongs to a GR00T inference process. + + This prevents accidentally killing unrelated processes that happen to + be listening on the same port. + + Args: + container_name: Docker container name to inspect. + pid: Process ID to check. + port: If provided, also verify the process is bound to this port. + """ + try: + result = subprocess.run( + ["docker", "exec", container_name, "cat", f"/proc/{pid}/cmdline"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + cmdline = result.stdout.replace("\x00", " ") + # Require both a Python interpreter AND inference_service.py in cmdline + # to avoid false-matching unrelated processes (e.g. vim editing a gr00t file) + # Match both N1.5/N1.6 (inference_service.py) and N1.7 (gr00t.eval.run_gr00t_server) + is_gr00t = ( + ("inference_service.py" in cmdline or "gr00t.eval.run_gr00t_server" in cmdline) + and ("python" in cmdline.lower() or "gr00t" in cmdline.lower()) + and "--port" in cmdline # Must have a --port flag to be a running service + ) + if is_gr00t and port is not None: + # Verify the process is serving on the requested port + # Use word-boundary regex to avoid partial matches (e.g. port 80 vs 8000) + return bool(re.search(_PYTHON_PORT_RE_FMT.format(port=port), cmdline)) + return is_gr00t + except (OSError, subprocess.SubprocessError, UnicodeDecodeError) as exc: + import logging + + _logger = logging.getLogger(__name__) + if isinstance(exc, PermissionError): + _logger.warning("Permission denied probing container process %s -- treating as non-GR00T", pid) + else: + _logger.debug("Failed to probe container process %s: %s", pid, exc) + return False + + +def _is_gr00t_host_process(pid: str, *, port: int | None = None) -> bool: + """Verify that a host PID belongs to a GR00T inference process. + + Reads /proc//cmdline directly (no Docker) to confirm the process + is a GR00T inference service, optionally bound to a specific port. + + Note: This function reads from /proc and is Linux-only. + + Args: + pid: Process ID to check. + port: If provided, also verify the process is bound to this port. + """ + try: + cmdline_path = Path(f"/proc/{pid}/cmdline") + if cmdline_path.exists(): + cmdline = cmdline_path.read_text().replace("\x00", " ") + # Require both a Python interpreter AND inference_service.py in cmdline + # to avoid false-matching unrelated processes (e.g. vim editing a gr00t file) + # Match both N1.5/N1.6 (inference_service.py) and N1.7 (gr00t.eval.run_gr00t_server) + is_gr00t = ( + ("inference_service.py" in cmdline or "gr00t.eval.run_gr00t_server" in cmdline) + and ("python" in cmdline.lower() or "gr00t" in cmdline.lower()) + and "--port" in cmdline # Must have a --port flag to be a running service + ) + if is_gr00t and port is not None: + return bool(re.search(_PYTHON_PORT_RE_FMT.format(port=port), cmdline)) + return is_gr00t + except (OSError, UnicodeDecodeError) as exc: + import logging + + _logger = logging.getLogger(__name__) + if isinstance(exc, PermissionError): + _logger.warning("Permission denied reading /proc/%s/cmdline -- treating as non-GR00T", pid) + else: + _logger.debug("Failed to probe host process %s: %s", pid, exc) + return False + + def _stop_service(port: int) -> dict[str, Any]: """Stop GR00T inference service running on specific port.""" try: @@ -520,7 +956,14 @@ def _stop_service(port: int) -> dict[str, Any]: container_name = container["name"] try: result = subprocess.run( - ["docker", "exec", container_name, "pgrep", "-f", f"inference_service.py.*--port {port}"], + [ + "docker", + "exec", + container_name, + "pgrep", + "-f", + _PGREP_INFERENCE_PORT_FMT.format(port=port), + ], capture_output=True, text=True, check=False, @@ -529,13 +972,21 @@ def _stop_service(port: int) -> dict[str, Any]: if result.returncode == 0 and result.stdout.strip(): pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_process(container_name, pid, port=port): subprocess.run(["docker", "exec", container_name, "kill", "-TERM", pid], check=True) time.sleep(2) result = subprocess.run( - ["docker", "exec", container_name, "pgrep", "-f", f"inference_service.py.*--port {port}"], + [ + "docker", + "exec", + container_name, + "pgrep", + "-f", + _PGREP_INFERENCE_PORT_FMT.format(port=port), + ], capture_output=True, text=True, check=False, @@ -544,7 +995,8 @@ def _stop_service(port: int) -> dict[str, Any]: if result.returncode == 0 and result.stdout.strip(): pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_process(container_name, pid, port=port): subprocess.run(["docker", "exec", container_name, "kill", "-KILL", pid], check=True) return { @@ -557,30 +1009,54 @@ def _stop_service(port: int) -> dict[str, Any]: except subprocess.CalledProcessError: continue - # Fallback: try host system - result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True) + # Fallback: try host system — verify via /proc//cmdline + # This path is Linux-only (ERE pgrep + /proc filesystem). + if sys.platform != "linux": + return { + "status": "error", + "message": ( + "No GR00T containers found. Host-fallback stop requires Linux " + "(pgrep + /proc). Run inside a Docker container or use action='find_containers' first." + ), + } + + result = subprocess.run( + # NOTE: ( |$) is ERE syntax; pgrep on Linux (procps-ng) defaults to ERE. + # This pattern is Linux-only; BSD pgrep may not match correctly. + ["pgrep", "-f", _PGREP_INFERENCE_PORT_FMT.format(port=port)], + capture_output=True, + text=True, + ) if result.returncode == 0: pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_host_process(pid, port=port): subprocess.run(["kill", "-TERM", pid], check=True) time.sleep(2) - result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True) + result = subprocess.run( + # NOTE: ( |$) is ERE syntax; pgrep on Linux (procps-ng) defaults to ERE. + # This pattern is Linux-only; BSD pgrep may not match correctly. + ["pgrep", "-f", _PGREP_INFERENCE_PORT_FMT.format(port=port)], + capture_output=True, + text=True, + ) if result.returncode == 0: pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_host_process(pid, port=port): subprocess.run(["kill", "-KILL", pid], check=True) return {"status": "success", "port": port, "message": f"Service on port {port} stopped"} else: return {"status": "success", "port": port, "message": f"No service running on port {port}"} - except Exception as e: + except (OSError, subprocess.SubprocessError) as e: return {"status": "error", "message": f"Failed to stop service: {e}"} @@ -589,7 +1065,6 @@ def _build_inference_command( container_name: str, checkpoint_path: str, port: int, - host: str, data_config: str, embodiment_tag: str, denoising_steps: int, @@ -640,7 +1115,7 @@ def _build_inference_command( "--port", str(port), "--host", - host, + "0.0.0.0", # always bind all-interfaces inside container; docker -p isolates host "--embodiment-tag", embodiment_tag, ] @@ -660,7 +1135,7 @@ def _build_inference_command( "--port", str(port), "--host", - host, + "0.0.0.0", # always bind all-interfaces inside container; docker -p isolates host "--data-config", data_config, "--embodiment-tag", @@ -714,7 +1189,12 @@ def _start_service( protocol: str = "n1.5", use_sim_policy_wrapper: bool = False, ) -> dict[str, Any]: - """Start GR00T inference service using Isaac-GR00T's native inference service.""" + """Start GR00T inference service using Isaac-GR00T's native inference service. + + The ``host`` kwarg controls the docker host-side port binding via + ``-p {host}:{port}:{port}``. Default ``127.0.0.1`` keeps the service on + loopback; pass ``host="0.0.0.0"`` to expose to the network. + """ try: # Find container if not specified if container_name is None: @@ -732,7 +1212,6 @@ def _start_service( container_name=container_name, checkpoint_path=checkpoint_path, port=port, - host=host, data_config=data_config, embodiment_tag=embodiment_tag, denoising_steps=denoising_steps, @@ -791,7 +1270,7 @@ def _start_service( except subprocess.CalledProcessError as e: return {"status": "error", "message": f"Failed to start service: {e.stderr or e}"} - except Exception as e: + except (OSError, RuntimeError) as e: return {"status": "error", "message": f"Unexpected error: {e}"} @@ -1029,10 +1508,15 @@ def _start_container( container_command: str, hf_local_dir: str | None, force: bool, + host: str, ) -> dict[str, Any]: """``docker run -d`` the GR00T container so subsequent ``start`` actions can ``docker exec`` into it. + The ``host`` kwarg controls the docker host-side port binding via + ``-p {host}:{port}:{port}``. Required. ``127.0.0.1`` keeps the published + port on loopback only; pass ``host="0.0.0.0"`` to expose to the network. + Idempotent: when a container with ``container_name`` is already running, returns success without touching docker. When it exists but is stopped, ``force=True`` removes + recreates it (otherwise returns @@ -1069,7 +1553,11 @@ def _start_container( "--name", name, "-p", - f"{port}:{port}", + # Bind docker port-publish to user-requested host (loopback by default). + # Service inside container binds 0.0.0.0 (must, for docker -p to work), + # but the *host* binding honours user intent. Users pass host="0.0.0.0" + # explicitly to expose to the network. + f"{host}:{port}:{port}", ] # Default volume layout: mount the checkpoint dir into /data/checkpoints @@ -1261,6 +1749,7 @@ def _lifecycle( container_command=container_command, hf_local_dir=resolved_local_dir, force=force, + host=host, ) steps.append({"step": "start_container", "result": container_result}) if container_result["status"] != "success": @@ -1320,7 +1809,7 @@ def _lifecycle( if __name__ == "__main__": - print("🐳 GR00T Inference Service Manager (Isaac-GR00T Native)") + print("GR00T Inference Service Manager (Isaac-GR00T Native)") print("Supports ZMQ, HTTP, and TensorRT inference modes") print() print("Examples:") diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py new file mode 100644 index 00000000..cff239e8 --- /dev/null +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -0,0 +1,2425 @@ +"""Tests for gr00t_inference input validation. + +Covers the validate_inputs() function which centralises all parameter +validation for the gr00t_inference tool. +""" + +import pytest + +from strands_robots.tools.gr00t_inference import validate_inputs + +# Standard valid kwargs for validate_inputs — tests override individual fields. +# validate_inputs() no longer has defaults (gr00t_inference() is the single source +# of truth for defaults), so tests must supply all required params. +_VALID_KWARGS = { + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "127.0.0.1", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "checkpoint_path": None, + "trt_engine_path": "gr00t_engine", + "container_name": None, + "protocol": "n1.5", +} + + +class TestValidateInputs: + """Tests for the validate_inputs() public function.""" + + def test_valid_defaults(self): + """Default values must pass validation.""" + validate_inputs(**_VALID_KWARGS) + + def test_valid_with_all_optional(self): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100_dualcam", + "embodiment_tag": "so100", + "port": 8000, + "vit_dtype": "fp16", + "llm_dtype": "fp8", + "dit_dtype": "fp16", + "checkpoint_path": "/data/checkpoints/model", + "trt_engine_path": "/engines/cache", + "container_name": "gr00t-n17", + } + ) + + def test_invalid_data_config_uppercase(self): + with pytest.raises(ValueError, match="data_config"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "FourierGR1", + "embodiment_tag": "gr1", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_data_config_shell_chars(self): + with pytest.raises(ValueError, match="data_config"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "foo;rm -rf /", + "embodiment_tag": "gr1", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_embodiment_tag(self): + with pytest.raises(ValueError, match="embodiment_tag"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "GR1-Sonic!", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_port_zero(self): + with pytest.raises(ValueError, match="port"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 0, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_port_too_high(self): + with pytest.raises(ValueError, match="port"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 70000, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_vit_dtype(self): + with pytest.raises(ValueError, match="vit_dtype"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "bf16", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_llm_dtype(self): + with pytest.raises(ValueError, match="llm_dtype"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "int4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_dit_dtype(self): + with pytest.raises(ValueError, match="dit_dtype"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "bf16", + } + ) + + def test_checkpoint_path_traversal(self): + with pytest.raises(ValueError, match="checkpoint_path"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "checkpoint_path": "/data/../../../etc/passwd", + } + ) + + def test_checkpoint_path_null_byte(self): + with pytest.raises(ValueError, match="checkpoint_path"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "checkpoint_path": "/data/model\x00.bin", + } + ) + + def test_trt_engine_path_shell_injection(self): + with pytest.raises(ValueError, match="trt_engine_path"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "trt_engine_path": "engine;rm -rf /", + } + ) + + def test_invalid_container_name(self): + with pytest.raises(ValueError, match="container_name"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "container_name": "-invalid-start", + } + ) + + def test_container_name_none_is_ok(self): + """container_name=None should not raise.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "container_name": None, + } + ) + + +class TestIsGr00tProcess: + """Test the _is_gr00t_process helper verifies port binding.""" + + def test_rejects_wrong_port(self, monkeypatch): + """_is_gr00t_process should reject a GR00T process on a different port.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + # Simulate cmdline: "python inference_service.py --port 8000" + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port\x008000\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + # Asking for port 80 should return False even though it's a gr00t process + assert _is_gr00t_process("container", "123", port=80) is False + + def test_accepts_matching_port(self, monkeypatch): + """_is_gr00t_process should accept when port matches.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port\x008000\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + assert _is_gr00t_process("container", "123", port=8000) is True + + def test_no_port_check_when_none(self, monkeypatch): + """_is_gr00t_process without port param should not verify port.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port\x008000\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + # Without port, just checks if it's a gr00t process + assert _is_gr00t_process("container", "123") is True + + def test_accepts_equals_style_port(self, monkeypatch): + """_is_gr00t_process should accept --port=N style.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port=5555\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + assert _is_gr00t_process("container", "123", port=5555) is True + assert _is_gr00t_process("container", "123", port=6666) is False + + +class TestIsGr00tHostProcess: + """Test the _is_gr00t_host_process helper for host-system PID verification.""" + + def test_rejects_wrong_port(self, tmp_path, monkeypatch): + """_is_gr00t_host_process should reject a process on a different port.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + # Create a fake /proc//cmdline + proc_dir = tmp_path / "proc" / "123" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00inference_service.py\x00--port\x008000\x00") + + # Monkeypatch Path to point at our fake proc, with reachability check + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("123", port=80) is False + assert called.get("p") == "/proc/123/cmdline" # patch was reached + + def test_accepts_matching_port(self, tmp_path, monkeypatch): + """_is_gr00t_host_process should accept when port matches.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "456" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00inference_service.py\x00--port\x008000\x00") + + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("456", port=8000) is True + assert called.get("p") == "/proc/456/cmdline" # patch was reached + + def test_rejects_non_gr00t_process(self, tmp_path, monkeypatch): + """_is_gr00t_host_process should reject non-GR00T processes.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "789" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00some_other_service.py\x00--port\x008000\x00") + + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("789", port=8000) is False + assert called.get("p") == "/proc/789/cmdline" # patch was reached + + def test_no_port_check_when_none(self, tmp_path, monkeypatch): + """_is_gr00t_host_process without port checks only process identity.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "321" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00inference_service.py\x00--port\x009999\x00") + + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + # Without port kwarg, just checks identity + assert _is_gr00t_host_process("321") is True + assert called.get("p") == "/proc/321/cmdline" # patch was reached + + +class TestHostValidation: + """Tests for host address validation in validate_inputs().""" + + def test_valid_loopback(self): + """127.0.0.1 is valid.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "127.0.0.1", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_all_interfaces(self): + """0.0.0.0 is valid.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "0.0.0.0", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_ipv6_loopback(self): + """::1 is valid.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "::1", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_with_spaces(self): + """Host with spaces must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "foo bar", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_empty_labels(self): + """Host with empty labels (double dot) must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "a..b", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_hostname_localhost(self): + """Valid hostnames like localhost are now accepted.""" + # Should not raise — localhost is a valid RFC-952 hostname + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "localhost", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_hostname_docker_internal(self): + """Docker internal hostname is accepted.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "host.docker.internal", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_special_chars(self): + """Hostnames with special characters are rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "--invalid-host", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + +class TestGr00tInferenceToolIntegration: + """Integration tests verifying validate_inputs is wired into the tool entry point. + + These tests invoke gr00t_inference() directly and assert that invalid inputs + are caught and returned as error dicts, NOT silently passed through. + This pins the try/except ValueError wiring so a future refactor that drops + the validation call surfaces as a test failure. + """ + + def test_shell_injection_in_data_config_returns_error(self): + """Shell metacharacters in data_config must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", data_config="foo;rm -rf /") + assert result["status"] == "error" + assert "data_config" in result["message"] + + def test_path_traversal_in_checkpoint_returns_error(self): + """Path traversal in checkpoint_path must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", checkpoint_path="/tmp/../../../etc/passwd") + assert result["status"] == "error" + assert "checkpoint_path" in result["message"] + + def test_invalid_host_returns_error(self): + """Invalid host address must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", host="--not-valid") + assert result["status"] == "error" + assert "host" in result["message"] + + def test_invalid_port_returns_error(self): + """Out-of-range port must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", port=99999) + assert result["status"] == "error" + assert "port" in result["message"] + + +class TestStopServiceCrossPortKill: + """End-to-end regression test for the cross-port-kill bug. + + Verifies that _stop_service(port=80) does NOT kill a GR00T process + running on port 8000. This pins the _is_gr00t_process(port=...) guard + so a future refactor that removes it will surface as a test failure. + """ + + def test_stop_service_does_not_kill_wrong_port(self, monkeypatch): + """_stop_service(port=80) must NOT kill a process on port 8000.""" + from strands_robots.tools.gr00t_inference import _stop_service + + killed_pids = [] + call_log = [] + + def _fake_run(cmd, *args, **kwargs): + call_log.append(cmd) + + # Mock _find_gr00t_containers returning no containers (forces host fallback) + if "docker" in cmd and "ps" in cmd: + import subprocess + + result = subprocess.CompletedProcess(cmd, 0, stdout="", stderr="") + return result + + # Mock pgrep finding PID 999 on the host + if cmd[0] == "pgrep": + import subprocess + + return subprocess.CompletedProcess(cmd, 0, stdout="999\n", stderr="") + + # Mock kill — record it + if cmd[0] == "kill": + killed_pids.append(cmd[-1]) + import subprocess + + return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="") + + import subprocess + + return subprocess.CompletedProcess(cmd, 1, stdout="", stderr="") + + def _fake_host_process(pid, *, port=None): + """Simulate a process running on port 8000, NOT port 80.""" + # The process is a GR00T process but on port 8000 + if port == 80: + return False # Not on port 80 + if port == 8000: + return True # Yes on port 8000 + return True # Generic check without port + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.subprocess.run", _fake_run) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_gr00t_host_process", _fake_host_process) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._find_gr00t_containers", + lambda: {"status": "success", "containers": []}, + ) + + _stop_service(port=80) + + # No process should have been killed (the only process is on port 8000) + assert not killed_pids, ( + f"_stop_service(port=80) killed PIDs {killed_pids} but should not have " + f"(the only GR00T process is on port 8000)" + ) + + def test_stop_service_kills_correct_port(self, monkeypatch): + """_stop_service(port=8000) MUST kill a process on port 8000.""" + from strands_robots.tools.gr00t_inference import _stop_service + + killed_pids = [] + + def _fake_run(cmd, *args, **kwargs): + import subprocess + + if cmd[0] == "pgrep": + return subprocess.CompletedProcess(cmd, 0, stdout="999\n", stderr="") + if cmd[0] == "kill": + killed_pids.append(cmd[-1]) + return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="") + return subprocess.CompletedProcess(cmd, 1, stdout="", stderr="") + + def _fake_host_process(pid, *, port=None): + """Simulate a process running on port 8000.""" + if port == 8000: + return True + return False + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.subprocess.run", _fake_run) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_gr00t_host_process", _fake_host_process) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._find_gr00t_containers", + lambda: {"status": "success", "containers": []}, + ) + + _stop_service(port=8000) + + # Process should have been killed + assert "999" in killed_pids, f"_stop_service(port=8000) should have killed PID 999 but killed {killed_pids}" + + +class TestActionScopedValidation: + """Tests verifying that validate_inputs scopes checks per action. + + Read-only actions (find_containers, list, status, stop) should only + validate port/host/protocol, not the full parameter surface like + data_config, embodiment_tag, etc. + """ + + def test_read_only_action_accepts_any_data_config(self): + """Read-only actions should not validate data_config.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + # This has hyphens/caps which WOULD fail for action="start" but passes for "list" + validate_inputs(**{**_VALID_KWARGS, "action": "list", "data_config": "Has-Hyphens-And-Caps"}) + + def test_read_only_action_still_validates_port(self): + """Read-only actions must still validate port.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="port must be between"): + validate_inputs(**{**_VALID_KWARGS, "action": "status", "port": 99999}) + + def test_read_only_action_still_validates_host(self): + """Read-only actions must still validate host.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="host must be a valid"): + validate_inputs(**{**_VALID_KWARGS, "action": "stop", "host": "--invalid"}) + + def test_read_only_action_still_validates_protocol(self): + """Read-only actions must still validate protocol.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="Unknown protocol"): + validate_inputs(**{**_VALID_KWARGS, "action": "list", "protocol": "invalid"}) + + def test_mutating_action_validates_data_config(self): + """Mutating actions must validate data_config.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="data_config"): + validate_inputs(**{**_VALID_KWARGS, "action": "start", "data_config": "foo;bar"}) + + def test_integration_read_only_action_skips_data_config_validation(self): + """gr00t_inference(action='list', data_config='invalid') must not error on data_config.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + # action="list" should not validate data_config + # It will fail at runtime (no docker) but NOT on validation + result = gr00t_inference(action="list", data_config="invalid;stuff") + # Should NOT be a validation error about data_config + if result.get("status") == "error": + assert "data_config" not in result.get("message", "") + + +class TestHostNumericTypoRejection: + """Regression tests for all-numeric hostname typos. + + Verifies that "127.0.01" (typo for 127.0.0.1) and "999.999.999.999" + are rejected by validate_inputs. These strings pass _HOSTNAME_RE but + are caught by the _ALL_NUMERIC_RE guard introduced in review round-4. + """ + + def test_invalid_host_typo_dotted_numeric(self): + """127.0.01 (typo for 127.0.0.1) must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "127.0.01", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_999_octets(self): + """999.999.999.999 (invalid IP, all-numeric) must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "999.999.999.999", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_single_numeric_label_is_valid_hostname(self): + """A bare number like '8080' is a valid single-label hostname (RFC-1123).""" + # Single-label numerics are valid hostnames; only multi-label patterns + # like '127.0.01' (IP typos) are rejected. + validate_inputs( + **{ + **_VALID_KWARGS, + "host": "8080", + } + ) + + +class TestActionAllowlistValidation: + """Tests for the action allowlist in validate_inputs. + + Verifies that unknown actions are rejected with a clear error that + lists the valid options, rather than falling through to validation + of unrelated parameters. + """ + + def test_unknown_action_rejected(self): + """Typo'd action gets a clear error listing valid actions.""" + with pytest.raises(ValueError, match="Unknown action.*Valid actions"): + validate_inputs(**{**_VALID_KWARGS, "action": "strat"}) # typo for "start" + + def test_unknown_action_integration(self): + """gr00t_inference(action='typo') returns error about unknown action.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="typo") + assert result["status"] == "error" + assert "Unknown action" in result["message"] + + def test_all_valid_actions_accepted(self): + """All 10 valid actions pass action validation (may fail later).""" + from strands_robots.tools.gr00t_inference import _VALID_ACTIONS + + for action in _VALID_ACTIONS: + # Should not raise ValueError about unknown action + # (may raise about other params, but that's fine) + try: + validate_inputs(**{**_VALID_KWARGS, "action": action}) + except ValueError as e: + assert "Unknown action" not in str(e), f"Action {action!r} wrongly rejected" + + +class TestExpandedParamValidation: + """Tests for image_name, volumes, and container_command validation.""" + + def test_invalid_image_name_rejected(self): + """Docker image with shell chars must be rejected.""" + with pytest.raises(ValueError, match="image_name must be a valid Docker"): + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "image_name": "gr00t:latest; rm -rf /", + } + ) + + def test_valid_image_name_accepted(self): + """Standard Docker image references must pass.""" + # Should not raise + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "image_name": "nvcr.io/nvidia/gr00t:n1.7", + } + ) + + def test_volume_path_traversal_rejected(self): + """Volumes with path traversal must be rejected.""" + with pytest.raises(ValueError, match="volumes key"): + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "volumes": {"/../etc/passwd": "/data"}, + } + ) + + def test_container_command_shell_meta_rejected(self): + """Container command with shell metacharacters must be rejected.""" + with pytest.raises(ValueError, match="container_command contains disallowed"): + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "container_command": "tail -f /dev/null; rm -rf /", + } + ) + + def test_valid_container_command_accepted(self): + """Standard container commands must pass.""" + # Should not raise + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "container_command": "tail -f /dev/null", + } + ) + + +class TestHappyPathIntegration: + """Happy-path integration test for gr00t_inference. + + Verifies that valid inputs pass validation and proceed to runtime + (which will fail due to missing Docker, but NOT on validation). + """ + + def test_valid_list_action_passes_validation(self): + """gr00t_inference(action='list') with valid params does not error on validation.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="list") + # The error should be about runtime (no docker), NOT validation + if result.get("status") == "error": + msg = result.get("message", "") + # Must not be a validation error + assert "must be" not in msg or "port" not in msg + assert "Unknown action" not in msg + assert "data_config" not in msg + + def test_valid_status_action_passes_validation(self): + """gr00t_inference(action='status') with valid params proceeds past validation.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="status", port=5555, host="127.0.0.1") + # Should not be a validation error + if result.get("status") == "error": + msg = result.get("message", "") + assert "Unknown action" not in msg + assert "host must be" not in msg + assert "port must be" not in msg + + +class TestDockerImageRegistryPort: + """Tests that _DOCKER_IMAGE_RE supports private registries with port numbers.""" + + def test_registry_with_port_accepted(self): + """localhost:5000/myorg/img:tag must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "localhost:5000/myorg/img:tag"}) + + def test_registry_with_port_no_tag(self): + """registry.internal:5000/img must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "registry.internal:5000/img"}) + + def test_nvcr_standard_format(self): + """nvcr.io/nvidia/gr00t:n1.7 must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "nvcr.io/nvidia/gr00t:n1.7"}) + + def test_simple_image_tag(self): + """gr00t:latest must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "gr00t:latest"}) + + +class TestProcessIdentificationRequiresPort: + """Tests that _is_gr00t_process requires --port in cmdline. + + Prevents false-matching unrelated processes like editors or log-tailers + that happen to have 'inference_service.py' and 'python' in their cmdline. + """ + + def test_process_without_port_flag_rejected(self, monkeypatch): + """A process with 'python inference_service.py' but no --port flag is not a match.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + # Mock docker exec to return a cmdline without --port + def fake_run(*args, **kwargs): + class Result: + returncode = 0 + stdout = "python inference_service.py --config test\x00" + + return Result() + + monkeypatch.setattr("subprocess.run", fake_run) + # Without --port in cmdline, should return False + assert _is_gr00t_process("container", "123", port=5555) is False + + def test_process_with_port_flag_accepted(self, monkeypatch): + """A process with --port 5555 in cmdline is a match.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + def fake_run(*args, **kwargs): + class Result: + returncode = 0 + stdout = "python inference_service.py --port 5555\x00" + + return Result() + + monkeypatch.setattr("subprocess.run", fake_run) + assert _is_gr00t_process("container", "123", port=5555) is True + + def test_editor_on_inference_service_rejected(self, monkeypatch): + """vim editing inference_service.py under a python venv is not a match.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + def fake_run(*args, **kwargs): + class Result: + returncode = 0 + stdout = "/opt/conda/envs/gr00t/bin/python vim /opt/gr00t/inference_service.py\x00" + + return Result() + + monkeypatch.setattr("subprocess.run", fake_run) + # No --port flag → rejected + assert _is_gr00t_process("container", "123", port=5555) is False + + +class TestOptionInjectionGuard: + """Test option-injection guard for argv-interpolated parameters.""" + + def test_repo_url_starting_with_dash_rejected(self): + """repo_url='--upload-pack=evil' must be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="build_image", + repo_url="--upload-pack=touch /tmp/pwned", + ) + assert result["status"] == "error" + assert "repo_url" in result["message"] + assert "must not start with '-'" in result["message"] + + def test_repo_tag_starting_with_dash_rejected(self): + """repo_tag='--config=evil' must be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="build_image", + repo_tag="--config=core.fsmonitor=evil-cmd", + ) + assert result["status"] == "error" + assert "repo_tag" in result["message"] + + def test_policy_name_starting_with_dash_rejected(self): + """policy_name='--flag' must be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start", + checkpoint_path="/data/model", + policy_name="--malicious", + ) + assert result["status"] == "error" + assert "policy_name" in result["message"] + + def test_valid_repo_url_accepted(self, monkeypatch): + """Normal https:// URL must pass the guard.""" + from strands_robots.tools import gr00t_inference as gi_mod + + # Mock _build_image to avoid actual git/docker operations + monkeypatch.setattr( + gi_mod, + "_build_image", + lambda **kwargs: {"status": "success", "message": "mocked"}, + ) + + result = gi_mod.gr00t_inference( + action="build_image", + repo_url="https://github.com/NVIDIA/Isaac-GR00T", + repo_tag="n1.7-release", + ) + # Should not be an option-injection error + assert "must not start with '-'" not in result.get("message", "") + + +class TestHostBindingHonoursUserChoice: + """R1 pin tests — host kwarg controls docker -p binding; NO auto-flip. + + Pre-R1: ``_start_service`` rewrote ``host="127.0.0.1"`` to ``"0.0.0.0"`` + when ``host_was_explicit=False``, silently widening the bind to all + interfaces. R1 drops the rewrite. The host kwarg now flows verbatim into + ``docker -p HOST:port:port``, so: + - host="127.0.0.1" (default) → docker binds loopback only + - host="0.0.0.0" (explicit) → docker binds all interfaces + """ + + def test_default_host_is_loopback_sentinel(self): + """Signature default must be None (resolves to 127.0.0.1) — AGENTS.md compliance.""" + import inspect + + from strands_robots.tools.gr00t_inference import gr00t_inference + + sig = inspect.signature(gr00t_inference) + assert sig.parameters["host"].default is None, ( + "host signature default must remain None sentinel (resolves to 127.0.0.1)" + ) + + def test_start_service_does_not_flip_default_loopback(self, monkeypatch): + """R1+R4 pin: _start_service must NOT auto-flip user's host kwarg. + + Pre-R1 _start_service rewrote host=127.0.0.1 to 0.0.0.0 when + host_was_explicit=False. Post-R1 the rewrite is gone; post-R4 + host is no longer passed into _build_inference_command at all + (the inside-container --host is hardcoded to 0.0.0.0). Either + way, the user's host kwarg must reach _start_service unchanged. + """ + import inspect + + from strands_robots.tools.gr00t_inference import _start_service + + # Static check: _start_service still takes host kwarg (controls docker -p + # via _start_container, not the inside-container --host which is now hardcoded). + sig = inspect.signature(_start_service) + assert "host" in sig.parameters, "_start_service must keep host kwarg for the docker -p host-side bind" + + # Behavioural check: invoking with host=127.0.0.1 must not raise (no + # auto-flip side-effects), and the inference cmd argv inside the + # container must bind 0.0.0.0 (R4 contract). + from strands_robots.tools.gr00t_inference import _build_inference_command + + argv = _build_inference_command( + container_name="gr00t-test", + checkpoint_path="/data/model", + port=5555, + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + denoising_steps=4, + http_server=False, + use_tensorrt=False, + trt_engine_path="gr00t_engine", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + api_token=None, + protocol="n1.5", + use_sim_policy_wrapper=False, + ) + host_idx = argv.index("--host") + assert argv[host_idx + 1] == "0.0.0.0", ( + "R4 contract: inference server must always bind 0.0.0.0 inside container" + ) + + def test_explicit_zero_zero_zero_zero_passes_through(self): + """R1+R4 pin: explicit host='0.0.0.0' is honoured for the docker -p bind. + + Verified via the public tool signature: host kwarg is still accepted + (R1: no auto-flip) and is destined for the docker host-side bind in + _start_container, not the inside-container --host which R4 hardcodes. + """ + import inspect + + from strands_robots.tools.gr00t_inference import gr00t_inference + + sig = inspect.signature(gr00t_inference) + assert "host" in sig.parameters, "host kwarg must remain on public tool" + # Default sentinel still None (resolves to 127.0.0.1 internally). + assert sig.parameters["host"].default is None + + +class TestSingleLabelNumericHostname: + """Verify single-label numeric hostnames (per RFC-1123) are accepted.""" + + def test_single_numeric_label_accepted(self): + """Single-label '123' is a valid hostname (RFC-1123).""" + validate_inputs(**{**_VALID_KWARGS, "host": "123"}) + + def test_multi_label_numeric_rejected(self): + """Multi-label '127.0.01' is rejected as an IP typo.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs(**{**_VALID_KWARGS, "host": "127.0.01"}) + + +class TestPlatformGuardForHostFallback: + """Test that host-fallback stop returns error on non-Linux platforms.""" + + def test_non_linux_platform_returns_error(self, monkeypatch): + """On non-Linux, _stop_service should error when no containers found.""" + import sys as _sys + + from strands_robots.tools.gr00t_inference import _stop_service + + # Mock no containers found + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._find_gr00t_containers", + lambda: {"status": "success", "containers": []}, + ) + monkeypatch.setattr(_sys, "platform", "darwin") + + result = _stop_service(5555) + assert result["status"] == "error" + assert "Linux" in result["message"] + + +class TestN17ProcessIdentification: + """Regression tests for N1.7 process identification — GH review thread. + + N1.7 services are started via `python -m gr00t.eval.run_gr00t_server` which + doesn't contain `inference_service.py` in cmdline. These tests ensure the + stop/status path can identify N1.7 services. + """ + + def test_n17_cmdline_detected_by_host_process_check(self, tmp_path, monkeypatch): + """_is_gr00t_host_process detects N1.7 server cmdline.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + # Simulate N1.7 cmdline: python -m gr00t.eval.run_gr00t_server --port 5555 + proc_dir = tmp_path / "proc" / "999" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00-m\x00gr00t.eval.run_gr00t_server\x00--port\x005555\x00") + + called = {} + from pathlib import Path as RealPath + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("999", port=5555) is True + assert called.get("p") == "/proc/999/cmdline" + + def test_n17_cmdline_wrong_port_rejected(self, tmp_path, monkeypatch): + """N1.7 server on wrong port is not killed.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "999" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00-m\x00gr00t.eval.run_gr00t_server\x00--port\x008000\x00") + + called = {} + from pathlib import Path as RealPath + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + # Request port 80 — should not match 8000 + assert _is_gr00t_host_process("999", port=80) is False + assert called.get("p") == "/proc/999/cmdline" + + def test_n15_cmdline_still_detected(self, tmp_path, monkeypatch): + """N1.5/N1.6 cmdline (inference_service.py) still works after N1.7 support.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "123" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00inference_service.py\x00--port\x005555\x00") + + from pathlib import Path as RealPath + + def _fake_path(p): + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("123", port=5555) is True + + +class TestExpandedParamValidationExtended: + """Extended tests for image_name, volumes, and container_command — covers happy paths.""" + + def test_valid_image_name(self): + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + image_name="localhost:5000/myorg/img:tag", + ) + + def test_invalid_image_name_shell_meta(self): + import pytest + + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="image_name"): + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + image_name="evil;rm -rf /", + ) + + def test_volumes_path_traversal_rejected(self): + import pytest + + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="volumes"): + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + volumes={"../../etc/passwd": "/data"}, + ) + + def test_container_command_shell_meta_rejected(self): + import pytest + + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="container_command"): + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + container_command="tail -f /dev/null; rm -rf /", + ) + + def test_valid_container_command(self): + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise - legitimate container commands without shell metas + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + container_command="tail -f /dev/null", + ) + + def test_valid_volumes(self): + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + volumes={"/tmp/checkpoints": "/data/checkpoints"}, + ) + + +class TestHostKwargNotPlumbed: + """R2 pin tests -- ``host_was_explicit`` kwarg is no longer plumbed. + + Pre-R2: ``gr00t_inference()`` set ``_host_was_explicit = host is not None`` + and threaded it through ``_lifecycle`` and the ``start`` / ``restart`` + dispatch into ``_start_service``, where it was unused (``# noqa: ARG001``). + The auto-flip the flag once gated was removed in R1 (commit ecf5f0f), so + the kwarg became dead plumbing. + + R2: removed per AGENTS.md > Key Conventions #10 ('No dead code'). These + pins assert the removal so a future refactor that re-introduces the flag + re-introduces a meaningless plumbing chain. + """ + + def test_start_service_signature_has_no_host_was_explicit(self): + """``_start_service`` signature must not contain ``host_was_explicit``.""" + import inspect + + from strands_robots.tools.gr00t_inference import _start_service + + params = inspect.signature(_start_service).parameters + assert "host_was_explicit" not in params, ( + "Dead kwarg `host_was_explicit` reintroduced into _start_service signature" + ) + + def test_lifecycle_signature_has_no_host_was_explicit(self): + """``_lifecycle`` signature must not contain ``host_was_explicit``.""" + import inspect + + from strands_robots.tools.gr00t_inference import _lifecycle + + params = inspect.signature(_lifecycle).parameters + assert "host_was_explicit" not in params, ( + "Dead kwarg `host_was_explicit` reintroduced into _lifecycle signature" + ) + + def test_start_dispatch_does_not_pass_host_was_explicit(self, monkeypatch): + """``gr00t_inference(action='start')`` must not forward ``host_was_explicit``.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "error", "message": "mocked"} + + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + gr00t_inference(action="start", checkpoint_path="/data/model", host="127.0.0.1") + assert "host_was_explicit" not in captured, ( + "`start` dispatch passed dead `host_was_explicit` kwarg to _start_service" + ) + + def test_restart_dispatch_does_not_pass_host_was_explicit(self, monkeypatch): + """``gr00t_inference(action='restart')`` must not forward ``host_was_explicit``.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "success", "message": "mocked"} + + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._stop_service", + lambda port: None, + ) + monkeypatch.setattr("time.sleep", lambda _: None) + + gr00t_inference(action="restart", checkpoint_path="/data/model") + assert "host_was_explicit" not in captured, ( + "`restart` dispatch passed dead `host_was_explicit` kwarg to _start_service" + ) + + +class TestReviewRound8Fixes: + """Regression tests for review round-8 fixes (2026-05-22 21:44 UTC). + + Covers: + - restart path forwarding host_was_explicit + - colon rejection in volume paths (docker -v mount-redirect) + - digest-pinned image references + - TypeError handling in validation wrapper + - dash-prefix rejection in volume paths + """ + + def test_volume_path_colon_rejected(self): + """Volume paths containing ':' must be rejected (docker -v mount-redirect).""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start_container", + image_name="gr00t:latest", + volumes={"/legit/dir:rw,nosuid": "/container/path"}, + ) + assert result["status"] == "error" + assert ":" in result["message"] or "colon" in result["message"].lower() + + def test_volume_value_colon_rejected(self): + """Container-side volume paths containing ':' must also be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start_container", + image_name="gr00t:latest", + volumes={"/host/path": "/container:path"}, + ) + assert result["status"] == "error" + assert ":" in result["message"] + + def test_volume_path_dash_prefix_rejected(self): + """Volume paths starting with '-' must be rejected (option injection).""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start_container", + image_name="gr00t:latest", + volumes={"--privileged=foo": "/bar"}, + ) + assert result["status"] == "error" + assert "'-'" in result["message"] or "start with" in result["message"] + + def test_digest_pinned_image_accepted(self): + """Digest-pinned image refs (registry/path@sha256:hex) must be accepted.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + # Should NOT fail validation on image_name (may fail later on docker ops) + result = gr00t_inference( + action="start_container", + image_name="nvcr.io/nvidia/gr00t@sha256:" + "a" * 64, + ) + # If it fails, it should NOT be an image_name validation error + if result["status"] == "error": + assert "valid Docker image" not in result["message"] + + def test_type_error_returns_structured_error(self): + """TypeError from bad parameter types must return structured error, not raise.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + # port="5555" (str instead of int) -> TypeError on `1 <= port <= 65535` + result = gr00t_inference(action="start", checkpoint_path="/data/model", port="5555") + assert result["status"] == "error" + # Must not propagate as unhandled exception - returns dict + + def test_end_to_end_bogus_action_returns_error_dict(self): + """Bogus action returns structured error dict (not raw exception).""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="bogus_action") + assert isinstance(result, dict) + assert result["status"] == "error" + assert "Unknown action" in result["message"] or "bogus_action" in result["message"] + + +class TestImageOnlyBranchValidation: + """Tests for validation on image-only actions (build_image, download_checkpoint, start_container).""" + + def test_container_name_validated_on_start_container(self): + """container_name must be validated on start_container (image-only branch).""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="container_name"): + validate_inputs( + action="start_container", + data_config="so100", + embodiment_tag="so100", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="/opt/engine", + container_name="--privileged", + protocol="n1.5", + image_name=None, + volumes=None, + container_command=None, + repo_url=None, + repo_tag=None, + policy_name=None, + ) + + def test_policy_name_dash_rejected_on_start_container(self): + """policy_name starting with '-' must be rejected on image-only actions.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="policy_name"): + validate_inputs( + action="start_container", + data_config="so100", + embodiment_tag="so100", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="/opt/engine", + container_name=None, + protocol="n1.5", + image_name=None, + volumes=None, + container_command=None, + repo_url=None, + repo_tag=None, + policy_name="--malicious", + ) + + def test_valid_container_name_accepted_on_start_container(self): + """Valid container_name passes on start_container.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="start_container", + data_config="so100", + embodiment_tag="so100", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="/opt/engine", + container_name="my-gr00t-container", + protocol="n1.5", + image_name=None, + volumes=None, + container_command=None, + repo_url=None, + repo_tag=None, + policy_name=None, + ) + + +class TestRegexBugFixesR4: + """R4 pin tests for the 4 regex bugs raised in PR #90 review. + + Each test fails on pre-R4 code and passes after R4 lands. + """ + + # === Bug 1: _DOCKER_IMAGE_RE registry port range === + + def test_registry_port_99999_rejected(self): + """Pre-R4 the regex accepted :99999 (>65535). Post-R4 we range-check.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert not _is_valid_docker_image_ref("localhost:99999/myorg/img:tag"), ( + "R4 regression: registry port 99999 must be rejected (TCP max is 65535)" + ) + + def test_registry_port_65535_accepted(self): + """65535 is the max valid TCP port — must be accepted.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("localhost:65535/myorg/img:tag") + + def test_registry_port_5000_accepted(self): + """Common private-registry port — sanity check.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("localhost:5000/myorg/img:tag") + + def test_registry_port_zero_rejected(self): + """Port 0 is not a valid bind target.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert not _is_valid_docker_image_ref("localhost:0/myorg/img:tag") + + def test_no_port_still_accepted(self): + """Image refs without a registry port must still match.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("nvcr.io/nvidia/gr00t:n1.7") + assert _is_valid_docker_image_ref("gr00t:latest") + + def test_digest_pinned_image_accepted(self): + """Digest-pinned refs (@sha256:...) must continue to match.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + digest = "a" * 64 + assert _is_valid_docker_image_ref(f"nvcr.io/nvidia/gr00t@sha256:{digest}") + + # === Bug 3: _HOSTNAME_RE trailing-dot FQDN === + + def test_trailing_dot_fqdn_accepted(self): + """RFC 1034 §3.1: FQDNs may end with a dot to disambiguate. + + Pre-R4 the regex required the last label to end with [a-zA-Z0-9], + which rejected legitimate FQDNs like 'host.example.com.'. + """ + from strands_robots.tools.gr00t_inference import _HOSTNAME_RE + + assert _HOSTNAME_RE.match("host.example.com."), ( + "R4 regression: trailing-dot FQDN must be accepted per RFC 1034 §3.1" + ) + # Without trailing dot still accepted + assert _HOSTNAME_RE.match("host.example.com") + + def test_single_dot_alone_rejected(self): + """A bare '.' is not a valid hostname.""" + from strands_robots.tools.gr00t_inference import _HOSTNAME_RE + + assert not _HOSTNAME_RE.match(".") + + # === Bug 4: hostname total length cap (already implemented; pin it) === + + def test_host_validation_rejects_oversize(self): + """RFC 1035 §2.3.4: hostname must not exceed 253 octets total.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + oversize = ".".join(["a" * 60] * 5) # 60*5 + 4 dots = 304 > 253 + assert len(oversize) > 253 + try: + validate_inputs( + action="start", + port=5555, + host=oversize, + protocol="n1.5", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + container_name=None, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + image_name=None, + volumes=None, + container_command="tail -f /dev/null", + policy_name=None, + ) + raise AssertionError("Expected ValueError for hostname > 253 octets") + except ValueError as e: + assert "253" in str(e), f"Error should mention RFC 1035 limit; got: {e}" + + # === Bug 2 / IPv4 typo regression — explicit '127.0.01' typo case === + + def test_host_typo_127_0_01_rejected(self): + """The typo called out in the PR description must be rejected. + + '127.0.01' looks like an IPv4 attempt to a human but is not a valid + IPv4 string under ipaddress.ip_address. Without the _ALL_NUMERIC_RE + guard it would be accepted as a hostname (it matches RFC-952), which + would then fail at runtime with a confusing connection error. + """ + from strands_robots.tools.gr00t_inference import validate_inputs + + try: + validate_inputs( + action="start", + port=5555, + host="127.0.01", + protocol="n1.5", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + container_name=None, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + image_name=None, + volumes=None, + container_command="tail -f /dev/null", + policy_name=None, + ) + raise AssertionError("Expected ValueError for '127.0.01' IP typo") + except ValueError as e: + assert "host" in str(e).lower() + + +class TestDockerImageNumericTagRegression: + """Pin: numeric-only Docker tags must not be falsely rejected. + + Pre-fix, the port-capture group in _DOCKER_IMAGE_RE greedily matched + :digits even without a following / path component, causing the port-range + check to reject valid name:tag refs where the tag was purely numeric. + + Reproducer (pre-fix): + >>> _is_valid_docker_image_ref('myimage:0') # False (should be True) + >>> _is_valid_docker_image_ref('myimage:65536') # False (should be True) + >>> _is_valid_docker_image_ref('myimage:99999') # False (should be True) + + Fix: lookahead (?=/) on the port group so :digits is only interpreted as a + registry port when followed by a path component. + """ + + def test_numeric_tag_zero_accepted(self): + """Tag ':0' is valid -- common in dev builds.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("myimage:0"), ( + "Regression: numeric tag ':0' must not be rejected as an invalid port" + ) + + def test_numeric_tag_above_port_range_accepted(self): + """Tag ':65536' is valid -- it is a tag, not a port.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("myimage:65536"), ( + "Regression: numeric tag ':65536' must not be rejected as an invalid port" + ) + + def test_numeric_tag_99999_accepted(self): + """Tag ':99999' is valid -- common date-style build IDs.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("myimage:99999"), ( + "Regression: numeric tag ':99999' must not be rejected as an invalid port" + ) + + def test_numeric_tag_five_digit_accepted(self): + """Tag ':23456' is valid -- common sequential build numbers.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("gr00t:23456"), ( + "Regression: numeric tag ':23456' must not be rejected as an invalid port" + ) + + def test_registry_port_with_path_still_range_checked(self): + """Port in host:port/path form must still be range-checked.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + # Valid port with path -- accepted + assert _is_valid_docker_image_ref("localhost:5000/myorg/img:tag") + assert _is_valid_docker_image_ref("localhost:65535/myorg/img:tag") + + # Invalid port (>65535) with path -- rejected + assert not _is_valid_docker_image_ref("localhost:99999/myorg/img:tag"), ( + "Registry port 99999 with /path must still be rejected (TCP max is 65535)" + ) + assert not _is_valid_docker_image_ref("localhost:100000/img:tag"), ( + "Registry port 100000 with /path must still be rejected" + ) + + +class TestPathTraversalPosixBackslash: + """R2 pin tests -- _validate_path must not over-reject POSIX paths containing + literal backslashes. + + Pre-R2: ``_validate_path`` split on both ``/`` and ``\\`` via + ``re.split(r"[/\\\\]", value)``. On POSIX, ``\\`` is a legal filename byte + (only ``/`` and NUL are forbidden), so a path like ``a\\..\\b`` -- a single + legitimate filename containing literal backslashes -- was wrongly flagged + as ``..`` traversal because the splitter isolated ``..`` between the + backslash bytes. + + R2: ``_validate_path`` splits on ``/`` only. docker -v interprets just ``/`` + as a separator on Linux (the only platform this tool supports), so this + matches the executor's contract. Real ``..`` traversal between ``/`` + separators (e.g. ``/foo/../etc``) remains rejected. + """ + + def test_posix_backslash_in_filename_accepted(self): + """POSIX path with literal backslash bytes is not traversal.""" + # checkpoint_path is one of the validated path kwargs. + validate_inputs(**{**_VALID_KWARGS, "checkpoint_path": "/data/odd\\..\\name"}) + + def test_real_traversal_still_rejected(self): + """Genuine '..' between '/' separators must still be rejected.""" + with pytest.raises(ValueError, match="path traversal"): + validate_inputs(**{**_VALID_KWARGS, "checkpoint_path": "/data/../etc/passwd"}) + + def test_real_traversal_relative_still_rejected(self): + """Genuine relative '..' traversal must still be rejected.""" + with pytest.raises(ValueError, match="path traversal"): + validate_inputs(**{**_VALID_KWARGS, "checkpoint_path": "../etc/passwd"}) + + def test_backslash_dotdot_backslash_filename_accepted(self): + """Filename with embedded \\..\\ as literal bytes (no '/' separator) accepted.""" + # "a\..\b" as a single-component filename -- legal on POSIX. + validate_inputs(**{**_VALID_KWARGS, "trt_engine_path": "a\\..\\b"}) + + +class TestHfPathTraversalValidation: + """Pin tests for hf_repo/hf_subfolder/hf_local_dir validation (R3). + + These fail on pre-fix code where hf_subfolder flows unvalidated into + docker --model-path argv via _lifecycle(). See AGENTS.md > Review + Learnings (#92) > 'LLM Input Safety > Validate before subprocess + interpolation'. + """ + + def test_hf_subfolder_traversal_rejected(self): + """hf_subfolder='../../etc' must be rejected by validate_inputs.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="hf_subfolder"): + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + hf_repo="nvidia/GR00T-N1.7-LIBERO", + hf_subfolder="../../etc/passwd", + lifecycle="full", + ) + + def test_hf_subfolder_shell_meta_rejected(self): + """hf_subfolder with shell metacharacters must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="hf_subfolder"): + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + hf_repo="nvidia/GR00T-N1.7-LIBERO", + hf_subfolder="libero;rm -rf /", + lifecycle="full", + ) + + def test_hf_repo_malformed_rejected(self): + """hf_repo must be org/name format.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="hf_repo"): + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + hf_repo="../../etc/shadow", + lifecycle="full", + ) + + def test_hf_local_dir_traversal_rejected(self): + """hf_local_dir with traversal must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="hf_local_dir"): + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + hf_repo="nvidia/GR00T-N1.7-LIBERO", + hf_local_dir="/data/../../../etc", + lifecycle="full", + ) + + def test_lifecycle_invalid_phase_rejected(self): + """lifecycle phase must be 'full' or 'teardown'.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="lifecycle"): + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + lifecycle="exec_shell", + ) + + def test_valid_hf_params_pass(self): + """Valid hf_repo/hf_subfolder/lifecycle should not raise.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + hf_repo="nvidia/GR00T-N1.7-LIBERO", + hf_subfolder="libero_spatial", + hf_local_dir="/data/checkpoints/libero", + lifecycle="full", + ) + + +class TestHfValidationOnDownloadCheckpoint: + """Pin: hf_* validation must run for action='download_checkpoint'. + + Regression: in R3, hf_repo/hf_subfolder/hf_local_dir validation was + placed AFTER the `_image_only_actions` early-return inside + validate_inputs(). Since 'download_checkpoint' is in that early-return + set, the hf_* checks were never reached when called via that action - + silently bypassing the path-traversal guard that the docstring + advertises. R4 hoists the hf_*/lifecycle validation BEFORE the + action-specific gates so it applies regardless of action. + + These tests fail on pre-R4 code (hf_* checks bypassed for + 'download_checkpoint') and pass on post-R4 code. + """ + + _COMMON_KWARGS = dict( + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + ) + + def test_hf_subfolder_traversal_rejected_on_download_checkpoint(self): + """hf_subfolder='../../etc' must be rejected on 'download_checkpoint'.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_subfolder.*\.\."): + validate_inputs( + action="download_checkpoint", + hf_subfolder="../../etc/passwd", + **self._COMMON_KWARGS, + ) + + def test_hf_local_dir_traversal_rejected_on_download_checkpoint(self): + """hf_local_dir with '..' must be rejected on 'download_checkpoint'.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_local_dir.*\.\."): + validate_inputs( + action="download_checkpoint", + hf_local_dir="../../etc", + **self._COMMON_KWARGS, + ) + + def test_hf_repo_invalid_format_rejected_on_download_checkpoint(self): + """hf_repo='--evil/x' (option-injection-like) must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="--evil/x", + **self._COMMON_KWARGS, + ) + + def test_hf_subfolder_traversal_rejected_on_lifecycle_too(self): + """Sanity: lifecycle path still validates hf_subfolder (no regression).""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_subfolder.*\.\."): + validate_inputs( + action="lifecycle", + hf_subfolder="../escape", + lifecycle="full", + **self._COMMON_KWARGS, + ) + + def test_valid_hf_params_pass_on_download_checkpoint(self): + """Sanity: legitimate hf_* values must not raise on 'download_checkpoint'.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + validate_inputs( + action="download_checkpoint", + hf_repo="nvidia/GR00T-N1.7-LIBERO", + hf_subfolder="libero_spatial", + hf_local_dir="/data/checkpoints/libero", + **self._COMMON_KWARGS, + ) + + +class TestHfRepoSegmentRejection: + """Pin: hf_repo segment-level checks reject leading '-' and bare '.' / '..' segments. + + The base regex ``^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$`` accepts ``--evil/x``, + ``org/..`` and ``./org`` because '-' and '.' are members of the character + class. Such values are not legal HF repo ids and an option-injection-like + leading '-' must not reach downstream argv. Pinned per the R5 review thread + on `gr00t_inference.py:253` ("looser than the path validators next to it"). + + Pre-fix these inputs slip through the regex; post-fix the segment loop + rejects them with the same `org/name` error message. + """ + + _COMMON_KWARGS = dict( + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + ) + + def test_hf_repo_dotdot_segment_rejected(self): + """hf_repo='org/..' is regex-valid but rejected by segment check.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="org/..", + **self._COMMON_KWARGS, + ) + + def test_hf_repo_leading_dot_segment_rejected(self): + """hf_repo='./org' is regex-valid but rejected by segment check.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="./org", + **self._COMMON_KWARGS, + ) + + def test_hf_repo_dotdot_first_segment_rejected(self): + """hf_repo='../org' is regex-valid (single '/') but rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="../org", + **self._COMMON_KWARGS, + ) + + def test_hf_repo_second_segment_dash_rejected(self): + """hf_repo='org/--evil' (option-injection in name) must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="org/--evil", + **self._COMMON_KWARGS, + ) + + def test_legitimate_hf_repos_with_dashes_still_accepted(self): + """Legitimate repo ids with internal dashes / dots must still pass.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + for valid in ( + "nvidia/GR00T-N1.7-LIBERO", + "a-b/c-d", + "org_x/repo.name-v2", + "_x/_y", + ): + validate_inputs( + action="download_checkpoint", + hf_repo=valid, + **self._COMMON_KWARGS, + ) + + +class TestInferenceServerBindsAllInterfaces: + """Pin: the inference server inside the container always binds 0.0.0.0. + + Regression: pre-R4 the `host` kwarg flowed verbatim into BOTH the + docker `-p HOST:port:port` host-side bind AND the inference server's + `--host` flag inside the container. With the new default + `host="127.0.0.1"`, the service bound to container-loopback and the + docker port-publish forwarded to nothing -- the headline contract + ("loopback default is reachable") was broken end-to-end. + + R4 hardcodes `--host 0.0.0.0` for the inference server inside the + container; the `host` kwarg now exclusively controls the host-side + bind. Tests fail on pre-R4 code and pass on post-R4. + """ + + def _build_argv(self, **overrides): + from strands_robots.tools.gr00t_inference import _build_inference_command + + defaults = dict( + container_name="gr00t-test", + checkpoint_path="/data/checkpoints/x", + port=5555, + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + denoising_steps=4, + http_server=False, + use_tensorrt=False, + trt_engine_path="gr00t_engine", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + api_token=None, + protocol="n1.5", + use_sim_policy_wrapper=False, + ) + defaults.update(overrides) + return _build_inference_command(**defaults) + + def test_n15_inference_server_binds_all_interfaces(self): + """N1.5 protocol must include '--host 0.0.0.0' regardless of caller.""" + argv = self._build_argv(protocol="n1.5") + # Find --host flag and assert its value is 0.0.0.0 + idx = argv.index("--host") + assert argv[idx + 1] == "0.0.0.0", f"expected --host 0.0.0.0, got {argv[idx + 1]!r}" + + def test_n16_inference_server_binds_all_interfaces(self): + """N1.6 protocol must include '--host 0.0.0.0'.""" + argv = self._build_argv(protocol="n1.6") + idx = argv.index("--host") + assert argv[idx + 1] == "0.0.0.0" + + def test_n17_inference_server_binds_all_interfaces(self): + """N1.7 protocol must include '--host 0.0.0.0'.""" + argv = self._build_argv(protocol="n1.7") + idx = argv.index("--host") + assert argv[idx + 1] == "0.0.0.0" + + def test_build_inference_command_signature_excludes_host(self): + """Pin: host kwarg must be removed from _build_inference_command. + + AGENTS.md > Conventions: 'No dead code'. host is no longer used + inside the cmd builder (the inside-container --host is hardcoded + to 0.0.0.0 in R4), so the parameter is removed from the signature. + Verified via inspect.signature to keep the assertion static-tool-friendly. + """ + import inspect + + from strands_robots.tools.gr00t_inference import _build_inference_command + + sig = inspect.signature(_build_inference_command) + assert "host" not in sig.parameters, ( + "R4 contract: _build_inference_command must NOT accept a host kwarg " + "(inside-container --host is hardcoded to 0.0.0.0). " + f"Found host in signature: {sig.parameters}" + ) + + +class TestPortBoolRejected: + """Pin: port=True / port=False is rejected by the type-check. + + Regression: ``isinstance(True, int) is True`` in Python (bool subclasses int) + and the range check ``1 <= True <= 65535`` evaluates True (True == 1). + Pre-fix, ``port=True`` passed validation and reached ``--port`` argv as the + string ``"True"`` -- a subtle failure mode an LLM caller could trip on. + + Pinned per the R5 review thread on ``gr00t_inference.py:223``. + """ + + _COMMON_KWARGS = dict( + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + ) + + def test_port_true_rejected(self): + """port=True must be rejected (bool is not a valid port type).""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"port must be an integer.*bool"): + validate_inputs(action="start", port=True, **self._COMMON_KWARGS) + + def test_port_false_rejected(self): + """port=False must be rejected (bool is not a valid port type, even though False==0).""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"port must be an integer.*bool"): + validate_inputs(action="start", port=False, **self._COMMON_KWARGS) + + def test_port_int_still_accepted(self): + """Inverse pin: a real int port still passes after the bool guard.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs(action="start", port=5555, **self._COMMON_KWARGS) + + +class TestHfRepoLeadingDotSegments: + """Pin: hf_repo segments starting with '.' are rejected. + + Regression: R5 closed bare ``.`` / ``..`` segments and leading-``-``, but + the regex ``^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$`` plus the segment loop + still accepted ``.org/name``, ``org/.git``, ``...../name``, etc. + HuggingFace's API rejects these so practical exploit surface is narrow, + but the validator's job per AGENTS.md > LLM Input Safety is to fail + closed locally rather than rely on a downstream service. + + Pinned per the R5 review thread on ``gr00t_inference.py:253``. + """ + + _COMMON_KWARGS = dict( + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + ) + + def test_leading_dot_first_segment_rejected(self): + """hf_repo='.org/name' is regex-valid but segment-rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo=".org/name", + **self._COMMON_KWARGS, + ) + + def test_leading_dot_second_segment_rejected(self): + """hf_repo='org/.git' (hidden-style name) must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="org/.git", + **self._COMMON_KWARGS, + ) + + def test_multi_dot_prefix_rejected(self): + """hf_repo='...../name' is regex-valid (chars in class) but rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="...../name", + **self._COMMON_KWARGS, + ) + + def test_leading_dot_hidden_name_rejected(self): + """hf_repo='org/.name' (hidden file convention) must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="org/.name", + **self._COMMON_KWARGS, + ) + + def test_legitimate_dotted_repos_still_accepted(self): + """Inverse pin: internal dots in segments are still accepted.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + for valid in ( + "nvidia/GR00T-N1.7-LIBERO", # internal dot in name segment + "a-b/c.d", # internal dot allowed + "org_x/repo.name-v2", # multiple internal dots + ): + validate_inputs( + action="download_checkpoint", + hf_repo=valid, + **self._COMMON_KWARGS, + ) + + +class TestStartContainerHostNoDefault: + """Pin: _start_container requires host without default. + + Regression: R5 identified two sources of truth for the loopback default. + gr00t_inference() resolves host=None -> "127.0.0.1", AND _start_container + declares host: str = "127.0.0.1" independently. If gr00t_inference() ever + passes host=None through, _start_container's default masks the bug. + + Fix: change _start_container signature to require host: str (drop default) + so gr00t_inference() is the single source of truth. Pinned per the R6 + review threads on gr00t_inference.py:1511 and :629. + """ + + def test_start_container_host_has_no_default(self): + """_start_container signature must require host without default.""" + import inspect + + from strands_robots.tools.gr00t_inference import _start_container + + sig = inspect.signature(_start_container) + host_param = sig.parameters.get("host") + + assert host_param is not None, "_start_container signature missing host parameter" + assert host_param.default is inspect.Parameter.empty, ( + f"_start_container host parameter has default {host_param.default!r}. " + f"Expected no default (gr00t_inference() is the single source of truth)." + ) + + def test_gr00t_inference_resolves_host_none_to_loopback(self): + """Inverse pin: gr00t_inference(host=None) still resolves to 127.0.0.1.""" + from unittest.mock import patch + + from strands_robots.tools.gr00t_inference import gr00t_inference + + with patch("strands_robots.tools.gr00t_inference._start_container") as mock_start: + mock_start.return_value = { + "status": "success", + "container_name": "gr00t", + "message": "mocked", + } + + # Call with host=None (user did not specify) + gr00t_inference( + action="start_container", + image_name="mock:latest", + ) + + # Verify _start_container was called with host="127.0.0.1" + assert mock_start.called, "_start_container not called" + call_kwargs = mock_start.call_args.kwargs + assert call_kwargs.get("host") == "127.0.0.1", ( + f"Expected gr00t_inference to resolve host=None to '127.0.0.1', got {call_kwargs.get('host')!r}" + ) diff --git a/tests/tools/test_gr00t_inference.py b/tests/tools/test_gr00t_inference.py index 38e8842b..f6053850 100644 --- a/tests/tools/test_gr00t_inference.py +++ b/tests/tools/test_gr00t_inference.py @@ -42,7 +42,6 @@ def _common_kwargs(**overrides: Any) -> dict[str, Any]: "container_name": "gr00t", "checkpoint_path": "/data/checkpoints/model", "port": 5555, - "host": "0.0.0.0", "data_config": "libero_panda", "embodiment_tag": "libero_sim", "denoising_steps": 4, @@ -572,6 +571,7 @@ def test_skips_when_already_running_and_not_force(self): image_name="gr00t:latest", container_name="gr00t", port=8000, + host="127.0.0.1", volumes=None, hf_token=None, container_command="tail -f /dev/null", @@ -597,6 +597,7 @@ def fake_run(cmd, *a, **kw): image_name="gr00t:latest", container_name="gr00t", port=8000, + host="127.0.0.1", volumes=None, hf_token=None, container_command="tail -f /dev/null", @@ -611,7 +612,14 @@ def fake_run(cmd, *a, **kw): assert "--gpus" in run_cmd and "all" in run_cmd assert "--ipc=host" in run_cmd assert "--name" in run_cmd and "gr00t" in run_cmd - assert "8000:8000" in run_cmd + # Post-R1 (PR #196): docker -p includes the host prefix; default host + # is 127.0.0.1 (loopback-only). The 0.0.0.0 silent rewrite was removed. + assert "127.0.0.1:8000:8000" in run_cmd, ( + f"R1 regression: _start_container must bind docker -p to loopback by default. Got argv: {run_cmd}" + ) + assert "0.0.0.0:8000:8000" not in run_cmd, ( + "R1 regression: default _start_container call must NOT bind all interfaces." + ) def test_volumes_default_includes_checkpoints_and_hf_cache(self): runs: list[list[str]] = [] @@ -629,6 +637,7 @@ def fake_run(cmd, *a, **kw): image_name="gr00t:latest", container_name="gr00t", port=8000, + host="127.0.0.1", volumes=None, hf_token=None, container_command="tail -f /dev/null", @@ -658,6 +667,7 @@ def fake_run(cmd, *a, **kw): image_name="gr00t:latest", container_name="gr00t", port=8000, + host="127.0.0.1", volumes={"/cp": "/data/checkpoints"}, hf_token="abc123", container_command="tail -f /dev/null", @@ -677,6 +687,7 @@ def test_unhealthy_state_without_force_errors(self): image_name="gr00t:latest", container_name="gr00t", port=8000, + host="127.0.0.1", volumes=None, hf_token=None, container_command="tail -f /dev/null",