From ce9db15d6c9ade2d2103cdf5e745d7ac0e729e5b Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 01:04:26 -0400 Subject: [PATCH 01/18] =?UTF-8?q?improve(gr00t=5Finference):=20R0=20?= =?UTF-8?q?=E2=80=94=20rebase=20+=20baseline=20snapshot=20from=20PR=20#90?= =?UTF-8?q?=20head?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Squashed snapshot of strands-labs/robots#90's improve/groot-input-validation branch HEAD (commit 8dd443ec), rebased onto current upstream/main (which now includes #192). The original branch was based on a pre-#192 ancestor and showed 802 lines of unrelated tests/simulation/ deletions in its diff — rebasing removes that noise. Diff scope is now exactly 3 files: - strands_robots/tools/gr00t_inference.py (+471/-) - tests/policies/groot/test_gr00t_inference_validation.py (+1625/-) - CHANGELOG.md (+83/-) Subsequent commits on this branch address the 48 still-open review threads from #90, organised by theme: R1 — drop silent host=127.0.0.1 -> 0.0.0.0 auto-flip (security default) R2 — make validate_inputs() actually centralised R3 — reconcile CHANGELOG <-> PR description on host default R4 — fix the 4 regex bugs (port>65535, IPv4 short-forms, length cap, trailing dots) R5 — explicit 3-branch dispatch in validate_inputs (lifecycle, image-only, full) R6 — replace fragile Path monkeypatches; add 4 missing pin tests R7 — portable pgrep fallback (Linux-only ERE -> argv-list match + fail-safe) R8 — misc polish (stale comments, narrow except, host_was_explicit on restart) Refs: strands-labs/robots#90 (closing in favour of this consolidation). 
--- CHANGELOG.md | 83 + strands_robots/tools/gr00t_inference.py | 471 ++++- .../groot/test_gr00t_inference_validation.py | 1625 +++++++++++++++++ 3 files changed, 2153 insertions(+), 26 deletions(-) create mode 100644 tests/policies/groot/test_gr00t_inference_validation.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b02c9da1..d276f3a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,89 @@ All notable behavioural changes to `strands-robots` are logged here. Follows [Keep a Changelog](https://keepachangelog.com/) conventions. +## Unreleased - #90 (gr00t_inference validation hardening) + +### Added + +- ``validate_inputs()`` centralises all parameter validation with action-aware + scoping: port-only actions (``find_containers``, ``list``, ``status``, + ``stop``) only validate ``port``/``host``/``protocol``; image/download actions (``build_image``, ``download_checkpoint``, ``start_container``) additionally validate the container, image, volume, command and checkpoint parameters; inference actions + (``start``, ``restart``, ``lifecycle``) validate the full parameter surface. +- ``host`` parameter now accepts both IP addresses and RFC-952 hostnames + (e.g. ``localhost``, ``host.docker.internal``). Multi-label all-numeric strings that + fail ``ipaddress.ip_address()`` (e.g. ``127.0.01``, ``999.999.999.999``) + are rejected as obvious IP typos. +- ``protocol`` validation moved into ``validate_inputs()`` (previously + hand-rolled outside the helper, breaking the single-entry-point contract). +- ``pgrep`` patterns now match both ``--port N`` and ``--port=N`` forms, + preventing silent miss when the service is started with the ``=`` syntax. +- Integration tests that invoke ``gr00t_inference()`` end-to-end and assert + that invalid inputs are caught (pins the ``try/except ValueError`` wiring). +- End-to-end regression test for ``_stop_service`` cross-port-kill scenario: + verifies that a process on port 8000 is NOT killed when stopping port 80.
+- Exception clauses narrowed throughout: ``_is_gr00t_process`` + uses ``(OSError, subprocess.SubprocessError, UnicodeDecodeError)`` and ``_is_gr00t_host_process`` uses ``(OSError, UnicodeDecodeError)``; + ``_list_running_services``/``_is_service_running`` use ``OSError``; + ``_stop_service`` uses ``(OSError, subprocess.SubprocessError)``; + ``_start_service`` uses ``(OSError, RuntimeError)``. + Only ``_download_checkpoint`` retains ``except Exception`` (``# noqa: BLE001``) + because huggingface_hub raises varied, opaque exception types. +- ``action`` parameter validated against a complete allowlist of 10 valid + actions; unknown actions get a clear error with the valid set listed. +- ``image_name``, ``volumes``, and ``container_command`` parameters are now + validated (Docker image reference, path traversal, shell metacharacters). +- ``pgrep`` pattern factored into ``_PGREP_INFERENCE_PORT_FMT`` module-level + constant — single source of truth across all 4 usage sites. +- ``_PGREP_INFERENCE_PORT_FMT``, ``_is_gr00t_process`` and ``_is_gr00t_host_process`` now match both + N1.5/N1.6 (``inference_service.py``) and N1.7 (``gr00t.eval.run_gr00t_server``) + entry-points. Closes the N1.7 stop/status identification gap. +- All exception types in process-probe helpers now log (WARNING for + PermissionError, DEBUG for other OSError/SubprocessError/UnicodeDecodeError). + +### Changed + +- ``validate_inputs()`` core parameters are now required (no defaults); only the optional ``image_name``, ``volumes``, ``container_command``, ``repo_url``, ``repo_tag`` and ``policy_name`` default to ``None``. + ``gr00t_inference()`` is the single source of truth for default values; + the validator no longer duplicates them (prevents silent drift). +- ``_DOCKER_IMAGE_RE`` extended to support private-registry references with + port numbers (e.g. ``localhost:5000/myorg/img:tag``). +- ``_is_gr00t_process`` / ``_is_gr00t_host_process`` now require ``--port`` + in the process cmdline to match — prevents false-killing editors or + log-tailers that happen to touch ``inference_service.py``. +- ``PermissionError`` in process probes now logs at WARNING level instead + of being silently swallowed.
+- **BREAKING** Default ``host`` changed from ``0.0.0.0`` to ``127.0.0.1`` + (loopback-only, per AGENTS.md). Container actions (``start``/``restart``/ + ``lifecycle``) auto-flip to ``0.0.0.0`` internally when ``host`` is left at its default, since Docker's + ``-p {port}:{port}`` publish requires bind-all inside the container; an explicitly passed ``host`` is honoured as given. + **Migration:** if your downstream connects from another host, pass + ``host="0.0.0.0"`` explicitly. +- Host-system fallback (``pgrep``) is documented as Linux-only. On non-Linux + platforms the ``stop`` fallback returns an explicit error rather than silently succeeding. + +### Fixed + +- Duplicate ``torch_mock.manual_seed`` assignment in ``tests/mocks/torch_mock.py``. +- Option-injection guard: ``repo_url``, ``repo_tag``, ``policy_name`` starting + with ``-`` are rejected (prevents git/docker flag injection via subprocess argv). +- Host-system fallback (pgrep) now returns a clear error on non-Linux platforms + instead of silently reporting success. +- All-numeric hostname guard narrowed to multi-label patterns only — single-label + numerics (e.g. ``123``) are valid per RFC-1123. + +### Notes + +- Host validation is **broader** than before for hostnames (RFC-952 names like + ``localhost`` and ``host.docker.internal`` now pass) but **stricter** for + IP-like typos (multi-label all-numeric strings like ``127.0.01`` are rejected). +- Validation scope covers ``action``, ``port``, ``host``, ``protocol``, ``data_config``, + ``embodiment_tag``, ``container_name``, TRT dtypes, ``checkpoint_path``, + ``trt_engine_path``, ``image_name``, ``volumes``, and ``container_command``. + Parameters ``repo_url``, ``repo_tag`` and ``policy_name`` receive only the + leading-``-`` option-injection guard, and ``hf_repo`` is NOT validated here — they flow into argv-style subprocess calls which + are not shell-injection-vulnerable.
+ + ## Unreleased - #178 (LiberoOffScreenRenderEngine retired) ### Removed: ``LiberoOffScreenRenderEngine`` simulation backend (BREAKING) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 8bfd9c6f..0a03b2df 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -11,9 +11,12 @@ from a single prompt - see #148 for the motivation. """ +import ipaddress import os +import re import socket import subprocess +import sys import time from pathlib import Path from typing import Any @@ -41,6 +44,256 @@ def _checkpoints_dir() -> Path: return get_base_dir() / "checkpoints" +# ───────────────────────────────────────────────────────────────────── +# Input validation helpers +# ───────────────────────────────────────────────────────────────────── + +# Docker image reference pattern — supports registry:port/path:tag and @sha256:digest. +# Examples: "gr00t:latest", "nvcr.io/nvidia/gr00t:n1.7", "localhost:5000/myorg/img:tag" +# "nvcr.io/nvidia/gr00t@sha256:abcdef..." (digest-pinned, supply-chain recommended) +_DOCKER_IMAGE_RE = re.compile( + r"^[a-zA-Z0-9]" # must start with alnum + r"(?:[a-zA-Z0-9._\-]*[a-zA-Z0-9])?" # optional middle chars (host/path prefix) + r"(?::[0-9]{1,5})?" # optional registry port (:5000) + r"(?:/[a-zA-Z0-9][a-zA-Z0-9._\-]*)*" # path components (/org/img) + r"(?::[a-zA-Z0-9][a-zA-Z0-9._\-]*" # option A: :tag + r"|@sha256:[a-f0-9]{64})?$" # option B: @sha256:digest (mutually exclusive with tag) +) + +# Characters that cause harm in subprocess argv or shell interpolation. +# Narrowed per AGENTS.md review-learnings: quotes/bangs/parens/brackets +# appear in legitimate filesystem paths and all subprocess calls here are +# argv-style (no shell=True), so they pose no injection risk in path values. +# Backslash (\) is also legal on Linux (only / and NUL are forbidden by POSIX) +# and carries no special meaning in argv-style subprocess calls. 
+_SHELL_META = re.compile(r"[;&|`$<>\n\r\x00]") + +# Strict patterns for enumerable parameters. +_DATA_CONFIG_RE = re.compile(r"^[a-z][a-z0-9_]{0,63}$") +_EMBODIMENT_TAG_RE = re.compile(r"^[a-z][a-z0-9_]{0,31}$") +_CONTAINER_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}$") +# RFC-952 hostname pattern for host validation. +_HOSTNAME_RE = re.compile( + r"^[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?" + r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$" +) +# Reject multi-label all-numeric strings — prevents typos like "127.0.01" +# which pass _HOSTNAME_RE but are clearly malformed IP attempts, not hostnames. +# Single-label numerics (e.g. "123") are valid per RFC-1123. +_ALL_NUMERIC_RE = re.compile(r"^[0-9]+(?:\.[0-9]+)+$") + +# Factored pgrep pattern — single source of truth for both docker-exec and +# host-fallback discovery paths. ERE syntax (procps-ng on Linux). +# NOTE: _PGREP_INFERENCE_PORT_FMT uses `( |$)` (ERE, space-only boundary) while +# _PYTHON_PORT_RE_FMT uses `(?:\s|$)` (Python re, any-whitespace boundary). +# This is intentional: pgrep is constrained to ERE, and cmdlines are space-separated +# in practice (procps-ng converts NUL → space when reading /proc/*/cmdline). +# Matches both N1.5/N1.6 (inference_service.py) and N1.7 (gr00t.eval.run_gr00t_server) +_PGREP_INFERENCE_PORT_FMT = r"(inference_service\.py|gr00t\.eval\.run_gr00t_server).*--port[= ]{port}( |$)" +# Python-side equivalent for re.search — uses (?:\s|$) instead of ( |$) +# because Python re is always ERE-ish and \s is more precise. +_PYTHON_PORT_RE_FMT = r"--port[= ]{port}(?:\s|$)" + +# Allowlists for TensorRT dtype parameters. +_VALID_VIT_DTYPES = {"fp16", "fp8"} +_VALID_LLM_DTYPES = {"fp16", "nvfp4", "fp8"} +_VALID_DIT_DTYPES = {"fp16", "fp8"} + +# Complete allowlist of valid actions for the tool. 
+_VALID_ACTIONS = frozenset( + { + "find_containers", + "list", + "status", + "stop", + "start", + "restart", + "build_image", + "download_checkpoint", + "start_container", + "lifecycle", + } +) + + +def _validate_path(value: str, label: str, *, reject_colon: bool = False) -> None: + """Reject paths containing shell metacharacters, null bytes, or traversal sequences. + + Args: + value: The path string to validate. + label: Human-readable label for error messages. + reject_colon: When True, reject ':' in the value. Required for Docker + volume mount paths where ':' would be re-interpreted as + host:container:options separator by docker -v. + """ + if "\x00" in value: + raise ValueError(f"{label} must not contain null bytes") + if value.startswith("-"): + raise ValueError(f"{label} must not start with '-' (got {value!r})") + if any(part == ".." for part in re.split(r"[/\\]", value)): + raise ValueError(f"{label} must not contain '..' path traversal components") + if _SHELL_META.search(value): + raise ValueError(f"{label} contains disallowed characters: {value!r}") + if reject_colon and ":" in value: + raise ValueError( + f"{label} must not contain ':' (docker -v interprets it as host:container:options separator; got {value!r})" + ) + + +def validate_inputs( + *, + action: str, + data_config: str, + embodiment_tag: str, + port: int, + host: str, + vit_dtype: str, + llm_dtype: str, + dit_dtype: str, + checkpoint_path: str | None, + trt_engine_path: str, + container_name: str | None, + protocol: str, + image_name: str | None = None, + volumes: dict[str, str] | None = None, + container_command: str | None = None, + repo_url: str | None = None, + repo_tag: str | None = None, + policy_name: str | None = None, +) -> None: + """Validate all user-supplied parameters in one place. + + Raises ValueError for any invalid input. 
Callers exposing this through + an AgentTool MUST wrap in try/except and convert to the structured error + dict (``{"status": "error", "message": str(e)}``). + + This centralises validation so that the main tool function stays focused + on orchestration and each check is independently testable via this + single entry-point. + + Validation is scoped to the action: actions whose only user-controlled + surface is port/host/protocol (find_containers, list, status, stop) + skip full parameter validation; mutating actions (start, restart, + lifecycle, build_image, download_checkpoint, start_container) validate + the full parameter surface. + """ + # Action allowlist — reject unknown actions early with a clear error + if action not in _VALID_ACTIONS: + raise ValueError(f"Unknown action {action!r}. Valid actions: {sorted(_VALID_ACTIONS)}") + + # Protocol — always validated regardless of action + valid_protocols = ("n1.5", "n1.6", "n1.7") + if protocol not in valid_protocols: + raise ValueError(f"Unknown protocol {protocol!r}. Valid: {list(valid_protocols)}") + # Port range — always validated. Type-check first so callers get ValueError, not TypeError. + if not isinstance(port, int): + raise ValueError(f"port must be an integer, got {type(port).__name__}: {port!r}") + if not (1 <= port <= 65535): + raise ValueError(f"port must be between 1 and 65535, got {port}") + + # Host address validation — always validated (accept IPs and RFC-952 hostnames) + if not isinstance(host, str): + raise ValueError(f"host must be a string, got {type(host).__name__}: {host!r}") + # RFC 1035 §2.3.4: total hostname must not exceed 253 octets. + if len(host) > 253: + raise ValueError(f"host exceeds RFC 1035 maximum length of 253 chars (got {len(host)} chars)") + try: + ipaddress.ip_address(host) + except ValueError: + # Reject all-numeric labels (e.g. "127.0.01") — these are clearly IP typos + # not legitimate hostnames. Real hostnames must have at least one alpha label. 
+ # Rejects all-numeric multi-label strings including Linux IPv4 short-forms + # like "127.1" — use canonical dotted-quad for clarity in agent-driven contexts. + if _ALL_NUMERIC_RE.match(host) or not _HOSTNAME_RE.match(host): + raise ValueError( + f"host must be a valid IP address or hostname (got {host!r}). " + f"Use '127.0.0.1' for loopback, '0.0.0.0' for all interfaces, " + f"or a valid hostname like 'localhost'." + ) from None + + # Port-only actions (find_containers, list, status, stop) only need + # port/host/protocol validation — the other params are unused by dispatch. + _port_only_actions = ("find_containers", "list", "status", "stop") + if action in _port_only_actions: + return + + # Image/download actions only consume image_name, paths, and volumes — not + # inference-time params (data_config, embodiment_tag, dtypes). + _image_only_actions = ("build_image", "download_checkpoint", "start_container") + if action in _image_only_actions: + # Validate container_name (used by start_container, interpolated into docker run --name) + if container_name is not None and not _CONTAINER_NAME_RE.match(container_name): + raise ValueError(f"container_name must match Docker naming rules (got {container_name!r})") + # Validate image_name, volumes, container_command (relevant to these actions) + if image_name is not None and not _DOCKER_IMAGE_RE.match(image_name): + raise ValueError(f"image_name must be a valid Docker image reference (got {image_name!r})") + if volumes is not None: + for vol_host, vol_container in volumes.items(): + _validate_path(vol_host, "volumes key (host path)", reject_colon=True) + _validate_path(vol_container, "volumes value (container path)", reject_colon=True) + if container_command is not None and _SHELL_META.search(container_command): + raise ValueError(f"container_command contains disallowed characters: {container_command!r}") + if checkpoint_path is not None: + _validate_path(checkpoint_path, "checkpoint_path") + # Option-injection guard for 
params used by these actions + for param_name, param_value in [("repo_url", repo_url), ("repo_tag", repo_tag), ("policy_name", policy_name)]: + if param_value is not None and param_value.startswith("-"): + raise ValueError(f"{param_name} must not start with '-' (got {param_value!r})") + return + + # ── Full validation for inference-mutating actions (start, restart, lifecycle) ── + + # Enumerable string parameters + if not _DATA_CONFIG_RE.match(data_config): + raise ValueError( + f"data_config must be lowercase alphanumeric/underscore (got {data_config!r}). " + f"See the tool docstring for the full list of accepted configs." + ) + if not _EMBODIMENT_TAG_RE.match(embodiment_tag): + raise ValueError(f"embodiment_tag must be lowercase alphanumeric/underscore (got {embodiment_tag!r})") + + # Docker container name + if container_name is not None and not _CONTAINER_NAME_RE.match(container_name): + raise ValueError(f"container_name must match Docker naming rules (got {container_name!r})") + + # Filesystem paths — reject shell metacharacters and traversal + if checkpoint_path is not None: + _validate_path(checkpoint_path, "checkpoint_path") + _validate_path(trt_engine_path, "trt_engine_path") + + # TensorRT dtype allowlists + if vit_dtype not in _VALID_VIT_DTYPES: + raise ValueError(f"vit_dtype must be one of {_VALID_VIT_DTYPES}, got {vit_dtype!r}") + if llm_dtype not in _VALID_LLM_DTYPES: + raise ValueError(f"llm_dtype must be one of {_VALID_LLM_DTYPES}, got {llm_dtype!r}") + if dit_dtype not in _VALID_DIT_DTYPES: + raise ValueError(f"dit_dtype must be one of {_VALID_DIT_DTYPES}, got {dit_dtype!r}") + + # Docker image reference (if provided via kwargs) + if image_name is not None and not _DOCKER_IMAGE_RE.match(image_name): + raise ValueError(f"image_name must be a valid Docker image reference (got {image_name!r})") + + # Volume paths validation + if volumes is not None: + for vol_host, vol_container in volumes.items(): + _validate_path(vol_host, "volumes key (host path)", 
reject_colon=True) + _validate_path(vol_container, "volumes value (container path)", reject_colon=True) + + # Container command — reject shell metacharacters + if container_command is not None and _SHELL_META.search(container_command): + raise ValueError(f"container_command contains disallowed characters: {container_command!r}") + + # Option-injection guard: reject LLM-controlled values starting with '-' + # which could be parsed as flags by git/docker/pgrep in subprocess argv. + for param_name, param_value in [ + ("repo_url", repo_url), + ("repo_tag", repo_tag), + ("policy_name", policy_name), + ]: + if param_value is not None and param_value.startswith("-"): + raise ValueError(f"{param_name} must not start with '-' (got {param_value!r})") + + @tool def gr00t_inference( action: str, @@ -50,7 +303,7 @@ def gr00t_inference( data_config: str = "fourier_gr1_arms_only", embodiment_tag: str = "gr1", denoising_steps: int = 4, - host: str = "0.0.0.0", + host: str | None = None, container_name: str | None = None, timeout: int = 60, use_tensorrt: bool = False, @@ -192,7 +445,9 @@ def gr00t_inference( ``libero_sim``). denoising_steps: Number of denoising steps for action generation (default: 4). N1.5/N1.6 only - the N1.7 server reads this from the checkpoint. - host: Host address to bind the service to (default: ``0.0.0.0``). + host: Host address to bind the service to (default: ``127.0.0.1`` + loopback only). Container actions auto-flip to ``0.0.0.0`` internally + since Docker -p port-publish requires bind-all inside the container. container_name: Specific Docker container name. Auto-detected if omitted. timeout: Seconds to wait for service startup (default: 60). use_tensorrt: Enable TensorRT acceleration (default: False). @@ -297,14 +552,37 @@ def gr00t_inference( if api_token is None: api_token = os.environ.get("GROOT_API_TOKEN") - # Validate protocol up-front so users get a friendly error rather than - # an opaque docker-exec failure inside _start_service. 
- valid_protocols = ("n1.5", "n1.6", "n1.7") - if protocol not in valid_protocols: - return { - "status": "error", - "message": f"Unknown protocol {protocol!r}. Valid: {list(valid_protocols)}", - } + # Sentinel default: None means "user did not pass host=". + # Default to 127.0.0.1 (loopback, per AGENTS.md § LLM Input Safety). + # _start_service auto-flips to 0.0.0.0 ONLY when host was not explicitly set. + _host_was_explicit = host is not None + if host is None: + host = "127.0.0.1" + + # ── Validate all inputs in one call (scoped per action) ───────── + try: + validate_inputs( + action=action, + data_config=data_config, + embodiment_tag=embodiment_tag, + port=port, + host=host, + vit_dtype=vit_dtype, + llm_dtype=llm_dtype, + dit_dtype=dit_dtype, + checkpoint_path=checkpoint_path, + trt_engine_path=trt_engine_path, + container_name=container_name, + protocol=protocol, + image_name=image_name, + volumes=volumes, + container_command=container_command, + repo_url=repo_url, + repo_tag=repo_tag, + policy_name=policy_name, + ) + except ValueError as e: + return {"status": "error", "message": str(e)} if action == "find_containers": return _find_gr00t_containers() @@ -378,6 +656,7 @@ def gr00t_inference( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, + host_was_explicit=_host_was_explicit, ) elif action == "start": if checkpoint_path is None: @@ -404,6 +683,7 @@ def gr00t_inference( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, + host_was_explicit=_host_was_explicit, ) elif action == "restart": if checkpoint_path is None: @@ -430,9 +710,11 @@ def gr00t_inference( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, + host_was_explicit=_host_was_explicit, ) - else: - return {"status": "error", "message": f"Unknown action: {action}"} + + # Unreachable: validate_inputs() rejects unknown actions before dispatch. 
+ return {"status": "error", "message": f"Unknown action: {action}"} # pragma: no cover def _find_gr00t_containers() -> dict[str, Any]: @@ -479,7 +761,7 @@ def _list_running_services() -> dict[str, Any]: return {"status": "success", "services": services, "message": f"Found {len(services)} running services"} - except Exception as e: + except OSError as e: return {"status": "error", "message": f"Failed to list services: {e}"} @@ -491,7 +773,7 @@ def _is_service_running(port: int) -> bool: result = sock.connect_ex(("localhost", port)) sock.close() return result == 0 - except Exception: + except OSError: return False @@ -509,6 +791,88 @@ def _check_service_status(port: int) -> dict[str, Any]: } +def _is_gr00t_process(container_name: str, pid: str, *, port: int | None = None) -> bool: + """Verify that a PID inside a container belongs to a GR00T inference process. + + This prevents accidentally killing unrelated processes that happen to + be listening on the same port. + + Args: + container_name: Docker container name to inspect. + pid: Process ID to check. + port: If provided, also verify the process is bound to this port. + """ + try: + result = subprocess.run( + ["docker", "exec", container_name, "cat", f"/proc/{pid}/cmdline"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + cmdline = result.stdout.replace("\x00", " ") + # Require both a Python interpreter AND inference_service.py in cmdline + # to avoid false-matching unrelated processes (e.g. 
vim editing a gr00t file) + # Match both N1.5/N1.6 (inference_service.py) and N1.7 (gr00t.eval.run_gr00t_server) + is_gr00t = ( + ("inference_service.py" in cmdline or "gr00t.eval.run_gr00t_server" in cmdline) + and ("python" in cmdline.lower() or "gr00t" in cmdline.lower()) + and "--port" in cmdline # Must have a --port flag to be a running service + ) + if is_gr00t and port is not None: + # Verify the process is serving on the requested port + # Use word-boundary regex to avoid partial matches (e.g. port 80 vs 8000) + return bool(re.search(_PYTHON_PORT_RE_FMT.format(port=port), cmdline)) + return is_gr00t + except (OSError, subprocess.SubprocessError, UnicodeDecodeError) as exc: + import logging + + _logger = logging.getLogger(__name__) + if isinstance(exc, PermissionError): + _logger.warning("Permission denied probing container process %s -- treating as non-GR00T", pid) + else: + _logger.debug("Failed to probe container process %s: %s", pid, exc) + return False + + +def _is_gr00t_host_process(pid: str, *, port: int | None = None) -> bool: + """Verify that a host PID belongs to a GR00T inference process. + + Reads /proc//cmdline directly (no Docker) to confirm the process + is a GR00T inference service, optionally bound to a specific port. + + Note: This function reads from /proc and is Linux-only. + + Args: + pid: Process ID to check. + port: If provided, also verify the process is bound to this port. + """ + try: + cmdline_path = Path(f"/proc/{pid}/cmdline") + if cmdline_path.exists(): + cmdline = cmdline_path.read_text().replace("\x00", " ") + # Require both a Python interpreter AND inference_service.py in cmdline + # to avoid false-matching unrelated processes (e.g. 
vim editing a gr00t file) + # Match both N1.5/N1.6 (inference_service.py) and N1.7 (gr00t.eval.run_gr00t_server) + is_gr00t = ( + ("inference_service.py" in cmdline or "gr00t.eval.run_gr00t_server" in cmdline) + and ("python" in cmdline.lower() or "gr00t" in cmdline.lower()) + and "--port" in cmdline # Must have a --port flag to be a running service + ) + if is_gr00t and port is not None: + return bool(re.search(_PYTHON_PORT_RE_FMT.format(port=port), cmdline)) + return is_gr00t + except (OSError, UnicodeDecodeError) as exc: + import logging + + _logger = logging.getLogger(__name__) + if isinstance(exc, PermissionError): + _logger.warning("Permission denied reading /proc/%s/cmdline -- treating as non-GR00T", pid) + else: + _logger.debug("Failed to probe host process %s: %s", pid, exc) + return False + + def _stop_service(port: int) -> dict[str, Any]: """Stop GR00T inference service running on specific port.""" try: @@ -520,7 +884,14 @@ def _stop_service(port: int) -> dict[str, Any]: container_name = container["name"] try: result = subprocess.run( - ["docker", "exec", container_name, "pgrep", "-f", f"inference_service.py.*--port {port}"], + [ + "docker", + "exec", + container_name, + "pgrep", + "-f", + _PGREP_INFERENCE_PORT_FMT.format(port=port), + ], capture_output=True, text=True, check=False, @@ -529,13 +900,21 @@ def _stop_service(port: int) -> dict[str, Any]: if result.returncode == 0 and result.stdout.strip(): pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_process(container_name, pid, port=port): subprocess.run(["docker", "exec", container_name, "kill", "-TERM", pid], check=True) time.sleep(2) result = subprocess.run( - ["docker", "exec", container_name, "pgrep", "-f", f"inference_service.py.*--port {port}"], + [ + "docker", + "exec", + container_name, + "pgrep", + "-f", + _PGREP_INFERENCE_PORT_FMT.format(port=port), + ], capture_output=True, text=True, check=False, @@ -544,7 +923,8 @@ def 
_stop_service(port: int) -> dict[str, Any]: if result.returncode == 0 and result.stdout.strip(): pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_process(container_name, pid, port=port): subprocess.run(["docker", "exec", container_name, "kill", "-KILL", pid], check=True) return { @@ -557,30 +937,54 @@ def _stop_service(port: int) -> dict[str, Any]: except subprocess.CalledProcessError: continue - # Fallback: try host system - result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True) + # Fallback: try host system — verify via /proc//cmdline + # This path is Linux-only (ERE pgrep + /proc filesystem). + if sys.platform != "linux": + return { + "status": "error", + "message": ( + "No GR00T containers found. Host-fallback stop requires Linux " + "(pgrep + /proc). Run inside a Docker container or use action='find_containers' first." + ), + } + + result = subprocess.run( + # NOTE: ( |$) is ERE syntax; pgrep on Linux (procps-ng) defaults to ERE. + # This pattern is Linux-only; BSD pgrep may not match correctly. + ["pgrep", "-f", _PGREP_INFERENCE_PORT_FMT.format(port=port)], + capture_output=True, + text=True, + ) if result.returncode == 0: pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_host_process(pid, port=port): subprocess.run(["kill", "-TERM", pid], check=True) time.sleep(2) - result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True) + result = subprocess.run( + # NOTE: ( |$) is ERE syntax; pgrep on Linux (procps-ng) defaults to ERE. + # This pattern is Linux-only; BSD pgrep may not match correctly. 
+ ["pgrep", "-f", _PGREP_INFERENCE_PORT_FMT.format(port=port)], + capture_output=True, + text=True, + ) if result.returncode == 0: pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_host_process(pid, port=port): subprocess.run(["kill", "-KILL", pid], check=True) return {"status": "success", "port": port, "message": f"Service on port {port} stopped"} else: return {"status": "success", "port": port, "message": f"No service running on port {port}"} - except Exception as e: + except (OSError, subprocess.SubprocessError) as e: return {"status": "error", "message": f"Failed to stop service: {e}"} @@ -713,9 +1117,22 @@ def _start_service( api_token: str | None, protocol: str = "n1.5", use_sim_policy_wrapper: bool = False, + host_was_explicit: bool = False, ) -> dict[str, Any]: """Start GR00T inference service using Isaac-GR00T's native inference service.""" try: + # Auto-flip host for container actions: Docker's -p port-publish requires the + # service to bind all interfaces inside the container. Only auto-flip if the + # user accepted the default (sentinel was None → resolved to 127.0.0.1). + # Users who explicitly pass host="127.0.0.1" get it honoured (e.g. --network=host). + if host == "127.0.0.1" and not host_was_explicit: + import logging as _logging + + _logging.getLogger(__name__).warning( + "Auto-flipping host from 127.0.0.1 to 0.0.0.0 for container " + "port-publish (-p). Pass host='127.0.0.1' explicitly to keep loopback." 
+ ) + host = "0.0.0.0" # Find container if not specified if container_name is None: containers = _find_gr00t_containers() @@ -791,7 +1208,7 @@ def _start_service( except subprocess.CalledProcessError as e: return {"status": "error", "message": f"Failed to start service: {e.stderr or e}"} - except Exception as e: + except (OSError, RuntimeError) as e: return {"status": "error", "message": f"Unexpected error: {e}"} @@ -1187,6 +1604,7 @@ def _lifecycle( api_token: str | None, protocol: str, use_sim_policy_wrapper: bool, + host_was_explicit: bool = False, ) -> dict[str, Any]: """Orchestrate the four-step setup or tear down a previously-started container. @@ -1308,6 +1726,7 @@ def _lifecycle( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, + host_was_explicit=host_was_explicit, ) steps.append({"step": "start", "result": start_result}) @@ -1320,7 +1739,7 @@ def _lifecycle( if __name__ == "__main__": - print("🐳 GR00T Inference Service Manager (Isaac-GR00T Native)") + print("GR00T Inference Service Manager (Isaac-GR00T Native)") print("Supports ZMQ, HTTP, and TensorRT inference modes") print() print("Examples:") diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py new file mode 100644 index 00000000..f79ea266 --- /dev/null +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -0,0 +1,1625 @@ +"""Tests for gr00t_inference input validation. + +Covers the validate_inputs() function which centralises all parameter +validation for the gr00t_inference tool. +""" + +import pytest + +from strands_robots.tools.gr00t_inference import validate_inputs + +# Standard valid kwargs for validate_inputs — tests override individual fields. +# validate_inputs() no longer has defaults (gr00t_inference() is the single source +# of truth for defaults), so tests must supply all required params. 
+_VALID_KWARGS = { + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "127.0.0.1", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "checkpoint_path": None, + "trt_engine_path": "gr00t_engine", + "container_name": None, + "protocol": "n1.5", +} + + +class TestValidateInputs: + """Tests for the validate_inputs() public function.""" + + def test_valid_defaults(self): + """Default values must pass validation.""" + validate_inputs(**_VALID_KWARGS) + + def test_valid_with_all_optional(self): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100_dualcam", + "embodiment_tag": "so100", + "port": 8000, + "vit_dtype": "fp16", + "llm_dtype": "fp8", + "dit_dtype": "fp16", + "checkpoint_path": "/data/checkpoints/model", + "trt_engine_path": "/engines/cache", + "container_name": "gr00t-n17", + } + ) + + def test_invalid_data_config_uppercase(self): + with pytest.raises(ValueError, match="data_config"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "FourierGR1", + "embodiment_tag": "gr1", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_data_config_shell_chars(self): + with pytest.raises(ValueError, match="data_config"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "foo;rm -rf /", + "embodiment_tag": "gr1", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_embodiment_tag(self): + with pytest.raises(ValueError, match="embodiment_tag"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "GR1-Sonic!", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_port_zero(self): + with pytest.raises(ValueError, match="port"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 0, + 
"vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_port_too_high(self): + with pytest.raises(ValueError, match="port"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 70000, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_vit_dtype(self): + with pytest.raises(ValueError, match="vit_dtype"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "bf16", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_llm_dtype(self): + with pytest.raises(ValueError, match="llm_dtype"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "int4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_dit_dtype(self): + with pytest.raises(ValueError, match="dit_dtype"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "bf16", + } + ) + + def test_checkpoint_path_traversal(self): + with pytest.raises(ValueError, match="checkpoint_path"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "checkpoint_path": "/data/../../../etc/passwd", + } + ) + + def test_checkpoint_path_null_byte(self): + with pytest.raises(ValueError, match="checkpoint_path"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "checkpoint_path": "/data/model\x00.bin", + } + ) + + def test_trt_engine_path_shell_injection(self): + with pytest.raises(ValueError, match="trt_engine_path"): + 
validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "trt_engine_path": "engine;rm -rf /", + } + ) + + def test_invalid_container_name(self): + with pytest.raises(ValueError, match="container_name"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "container_name": "-invalid-start", + } + ) + + def test_container_name_none_is_ok(self): + """container_name=None should not raise.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "container_name": None, + } + ) + + +class TestIsGr00tProcess: + """Test the _is_gr00t_process helper verifies port binding.""" + + def test_rejects_wrong_port(self, monkeypatch): + """_is_gr00t_process should reject a GR00T process on a different port.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + # Simulate cmdline: "python inference_service.py --port 8000" + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port\x008000\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + # Asking for port 80 should return False even though it's a gr00t process + assert _is_gr00t_process("container", "123", port=80) is False + + def test_accepts_matching_port(self, monkeypatch): + """_is_gr00t_process should accept when port matches.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port\x008000\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + 
assert _is_gr00t_process("container", "123", port=8000) is True + + def test_no_port_check_when_none(self, monkeypatch): + """_is_gr00t_process without port param should not verify port.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port\x008000\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + # Without port, just checks if it's a gr00t process + assert _is_gr00t_process("container", "123") is True + + def test_accepts_equals_style_port(self, monkeypatch): + """_is_gr00t_process should accept --port=N style.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port=5555\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + assert _is_gr00t_process("container", "123", port=5555) is True + assert _is_gr00t_process("container", "123", port=6666) is False + + +class TestIsGr00tHostProcess: + """Test the _is_gr00t_host_process helper for host-system PID verification.""" + + def test_rejects_wrong_port(self, tmp_path, monkeypatch): + """_is_gr00t_host_process should reject a process on a different port.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + # Create a fake /proc//cmdline + proc_dir = tmp_path / "proc" / "123" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00inference_service.py\x00--port\x008000\x00") + + # Monkeypatch Path to point at our fake proc, with reachability check + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) 
+ + assert _is_gr00t_host_process("123", port=80) is False + assert called.get("p") == "/proc/123/cmdline" # patch was reached + + def test_accepts_matching_port(self, tmp_path, monkeypatch): + """_is_gr00t_host_process should accept when port matches.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "456" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00inference_service.py\x00--port\x008000\x00") + + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("456", port=8000) is True + assert called.get("p") == "/proc/456/cmdline" # patch was reached + + def test_rejects_non_gr00t_process(self, tmp_path, monkeypatch): + """_is_gr00t_host_process should reject non-GR00T processes.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "789" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00some_other_service.py\x00--port\x008000\x00") + + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("789", port=8000) is False + assert called.get("p") == "/proc/789/cmdline" # patch was reached + + def test_no_port_check_when_none(self, tmp_path, monkeypatch): + """_is_gr00t_host_process without port checks only process identity.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "321" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + 
cmdline_file.write_text("python\x00inference_service.py\x00--port\x009999\x00") + + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + # Without port kwarg, just checks identity + assert _is_gr00t_host_process("321") is True + assert called.get("p") == "/proc/321/cmdline" # patch was reached + + +class TestHostValidation: + """Tests for host address validation in validate_inputs().""" + + def test_valid_loopback(self): + """127.0.0.1 is valid.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "127.0.0.1", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_all_interfaces(self): + """0.0.0.0 is valid.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "0.0.0.0", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_ipv6_loopback(self): + """::1 is valid.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "::1", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_with_spaces(self): + """Host with spaces must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "foo bar", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_empty_labels(self): + """Host with empty labels (double dot) must be rejected.""" + with pytest.raises(ValueError, 
match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "a..b", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_hostname_localhost(self): + """Valid hostnames like localhost are now accepted.""" + # Should not raise — localhost is a valid RFC-952 hostname + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "localhost", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_hostname_docker_internal(self): + """Docker internal hostname is accepted.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "host.docker.internal", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_special_chars(self): + """Hostnames with special characters are rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "--invalid-host", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + +class TestGr00tInferenceToolIntegration: + """Integration tests verifying validate_inputs is wired into the tool entry point. + + These tests invoke gr00t_inference() directly and assert that invalid inputs + are caught and returned as error dicts, NOT silently passed through. + This pins the try/except ValueError wiring so a future refactor that drops + the validation call surfaces as a test failure. 
+ """ + + def test_shell_injection_in_data_config_returns_error(self): + """Shell metacharacters in data_config must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", data_config="foo;rm -rf /") + assert result["status"] == "error" + assert "data_config" in result["message"] + + def test_path_traversal_in_checkpoint_returns_error(self): + """Path traversal in checkpoint_path must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", checkpoint_path="/tmp/../../../etc/passwd") + assert result["status"] == "error" + assert "checkpoint_path" in result["message"] + + def test_invalid_host_returns_error(self): + """Invalid host address must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", host="--not-valid") + assert result["status"] == "error" + assert "host" in result["message"] + + def test_invalid_port_returns_error(self): + """Out-of-range port must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", port=99999) + assert result["status"] == "error" + assert "port" in result["message"] + + +class TestStopServiceCrossPortKill: + """End-to-end regression test for the cross-port-kill bug. + + Verifies that _stop_service(port=80) does NOT kill a GR00T process + running on port 8000. This pins the _is_gr00t_process(port=...) guard + so a future refactor that removes it will surface as a test failure. 
+ """ + + def test_stop_service_does_not_kill_wrong_port(self, monkeypatch): + """_stop_service(port=80) must NOT kill a process on port 8000.""" + from strands_robots.tools.gr00t_inference import _stop_service + + killed_pids = [] + call_log = [] + + def _fake_run(cmd, *args, **kwargs): + call_log.append(cmd) + + # Mock _find_gr00t_containers returning no containers (forces host fallback) + if "docker" in cmd and "ps" in cmd: + import subprocess + + result = subprocess.CompletedProcess(cmd, 0, stdout="", stderr="") + return result + + # Mock pgrep finding PID 999 on the host + if cmd[0] == "pgrep": + import subprocess + + return subprocess.CompletedProcess(cmd, 0, stdout="999\n", stderr="") + + # Mock kill — record it + if cmd[0] == "kill": + killed_pids.append(cmd[-1]) + import subprocess + + return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="") + + import subprocess + + return subprocess.CompletedProcess(cmd, 1, stdout="", stderr="") + + def _fake_host_process(pid, *, port=None): + """Simulate a process running on port 8000, NOT port 80.""" + # The process is a GR00T process but on port 8000 + if port == 80: + return False # Not on port 80 + if port == 8000: + return True # Yes on port 8000 + return True # Generic check without port + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.subprocess.run", _fake_run) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_gr00t_host_process", _fake_host_process) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._find_gr00t_containers", + lambda: {"status": "success", "containers": []}, + ) + + _stop_service(port=80) + + # No process should have been killed (the only process is on port 8000) + assert not killed_pids, ( + f"_stop_service(port=80) killed PIDs {killed_pids} but should not have " + f"(the only GR00T process is on port 8000)" + ) + + def test_stop_service_kills_correct_port(self, monkeypatch): + """_stop_service(port=8000) MUST kill a process on port 8000.""" + from 
strands_robots.tools.gr00t_inference import _stop_service + + killed_pids = [] + + def _fake_run(cmd, *args, **kwargs): + import subprocess + + if cmd[0] == "pgrep": + return subprocess.CompletedProcess(cmd, 0, stdout="999\n", stderr="") + if cmd[0] == "kill": + killed_pids.append(cmd[-1]) + return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="") + return subprocess.CompletedProcess(cmd, 1, stdout="", stderr="") + + def _fake_host_process(pid, *, port=None): + """Simulate a process running on port 8000.""" + if port == 8000: + return True + return False + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.subprocess.run", _fake_run) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_gr00t_host_process", _fake_host_process) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._find_gr00t_containers", + lambda: {"status": "success", "containers": []}, + ) + + _stop_service(port=8000) + + # Process should have been killed + assert "999" in killed_pids, f"_stop_service(port=8000) should have killed PID 999 but killed {killed_pids}" + + +class TestActionScopedValidation: + """Tests verifying that validate_inputs scopes checks per action. + + Read-only actions (find_containers, list, status, stop) should only + validate port/host/protocol, not the full parameter surface like + data_config, embodiment_tag, etc. 
+ """ + + def test_read_only_action_accepts_any_data_config(self): + """Read-only actions should not validate data_config.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + # This has hyphens/caps which WOULD fail for action="start" but passes for "list" + validate_inputs(**{**_VALID_KWARGS, "action": "list", "data_config": "Has-Hyphens-And-Caps"}) + + def test_read_only_action_still_validates_port(self): + """Read-only actions must still validate port.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="port must be between"): + validate_inputs(**{**_VALID_KWARGS, "action": "status", "port": 99999}) + + def test_read_only_action_still_validates_host(self): + """Read-only actions must still validate host.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="host must be a valid"): + validate_inputs(**{**_VALID_KWARGS, "action": "stop", "host": "--invalid"}) + + def test_read_only_action_still_validates_protocol(self): + """Read-only actions must still validate protocol.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="Unknown protocol"): + validate_inputs(**{**_VALID_KWARGS, "action": "list", "protocol": "invalid"}) + + def test_mutating_action_validates_data_config(self): + """Mutating actions must validate data_config.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="data_config"): + validate_inputs(**{**_VALID_KWARGS, "action": "start", "data_config": "foo;bar"}) + + def test_integration_read_only_action_skips_data_config_validation(self): + """gr00t_inference(action='list', data_config='invalid') must not error on data_config.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + # action="list" should not validate data_config + # It will fail at runtime (no docker) but NOT on 
validation + result = gr00t_inference(action="list", data_config="invalid;stuff") + # Should NOT be a validation error about data_config + if result.get("status") == "error": + assert "data_config" not in result.get("message", "") + + +class TestHostNumericTypoRejection: + """Regression tests for all-numeric hostname typos. + + Verifies that "127.0.01" (typo for 127.0.0.1) and "999.999.999.999" + are rejected by validate_inputs. These strings pass _HOSTNAME_RE but + are caught by the _ALL_NUMERIC_RE guard introduced in review round-4. + """ + + def test_invalid_host_typo_dotted_numeric(self): + """127.0.01 (typo for 127.0.0.1) must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "127.0.01", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_999_octets(self): + """999.999.999.999 (invalid IP, all-numeric) must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "999.999.999.999", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_single_numeric_label_is_valid_hostname(self): + """A bare number like '8080' is a valid single-label hostname (RFC-1123).""" + # Single-label numerics are valid hostnames; only multi-label patterns + # like '127.0.01' (IP typos) are rejected. + validate_inputs( + **{ + **_VALID_KWARGS, + "host": "8080", + } + ) + + +class TestActionAllowlistValidation: + """Tests for the action allowlist in validate_inputs. + + Verifies that unknown actions are rejected with a clear error that + lists the valid options, rather than falling through to validation + of unrelated parameters. 
+ """ + + def test_unknown_action_rejected(self): + """Typo'd action gets a clear error listing valid actions.""" + with pytest.raises(ValueError, match="Unknown action.*Valid actions"): + validate_inputs(**{**_VALID_KWARGS, "action": "strat"}) # typo for "start" + + def test_unknown_action_integration(self): + """gr00t_inference(action='typo') returns error about unknown action.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="typo") + assert result["status"] == "error" + assert "Unknown action" in result["message"] + + def test_all_valid_actions_accepted(self): + """All 10 valid actions pass action validation (may fail later).""" + from strands_robots.tools.gr00t_inference import _VALID_ACTIONS + + for action in _VALID_ACTIONS: + # Should not raise ValueError about unknown action + # (may raise about other params, but that's fine) + try: + validate_inputs(**{**_VALID_KWARGS, "action": action}) + except ValueError as e: + assert "Unknown action" not in str(e), f"Action {action!r} wrongly rejected" + + +class TestExpandedParamValidation: + """Tests for image_name, volumes, and container_command validation.""" + + def test_invalid_image_name_rejected(self): + """Docker image with shell chars must be rejected.""" + with pytest.raises(ValueError, match="image_name must be a valid Docker"): + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "image_name": "gr00t:latest; rm -rf /", + } + ) + + def test_valid_image_name_accepted(self): + """Standard Docker image references must pass.""" + # Should not raise + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "image_name": "nvcr.io/nvidia/gr00t:n1.7", + } + ) + + def test_volume_path_traversal_rejected(self): + """Volumes with path traversal must be rejected.""" + with pytest.raises(ValueError, 
match="volumes key"): + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "volumes": {"/../etc/passwd": "/data"}, + } + ) + + def test_container_command_shell_meta_rejected(self): + """Container command with shell metacharacters must be rejected.""" + with pytest.raises(ValueError, match="container_command contains disallowed"): + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "container_command": "tail -f /dev/null; rm -rf /", + } + ) + + def test_valid_container_command_accepted(self): + """Standard container commands must pass.""" + # Should not raise + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "container_command": "tail -f /dev/null", + } + ) + + +class TestHappyPathIntegration: + """Happy-path integration test for gr00t_inference. + + Verifies that valid inputs pass validation and proceed to runtime + (which will fail due to missing Docker, but NOT on validation). 
+ """ + + def test_valid_list_action_passes_validation(self): + """gr00t_inference(action='list') with valid params does not error on validation.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="list") + # The error should be about runtime (no docker), NOT validation + if result.get("status") == "error": + msg = result.get("message", "") + # Must not be a validation error + assert "must be" not in msg or "port" not in msg + assert "Unknown action" not in msg + assert "data_config" not in msg + + def test_valid_status_action_passes_validation(self): + """gr00t_inference(action='status') with valid params proceeds past validation.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="status", port=5555, host="127.0.0.1") + # Should not be a validation error + if result.get("status") == "error": + msg = result.get("message", "") + assert "Unknown action" not in msg + assert "host must be" not in msg + assert "port must be" not in msg + + +class TestDockerImageRegistryPort: + """Tests that _DOCKER_IMAGE_RE supports private registries with port numbers.""" + + def test_registry_with_port_accepted(self): + """localhost:5000/myorg/img:tag must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "localhost:5000/myorg/img:tag"}) + + def test_registry_with_port_no_tag(self): + """registry.internal:5000/img must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "registry.internal:5000/img"}) + + def test_nvcr_standard_format(self): + """nvcr.io/nvidia/gr00t:n1.7 must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "nvcr.io/nvidia/gr00t:n1.7"}) + + def test_simple_image_tag(self): + """gr00t:latest must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "gr00t:latest"}) + + +class TestProcessIdentificationRequiresPort: + """Tests that _is_gr00t_process requires --port in cmdline. 
+ + Prevents false-matching unrelated processes like editors or log-tailers + that happen to have 'inference_service.py' and 'python' in their cmdline. + """ + + def test_process_without_port_flag_rejected(self, monkeypatch): + """A process with 'python inference_service.py' but no --port flag is not a match.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + # Mock docker exec to return a cmdline without --port + def fake_run(*args, **kwargs): + class Result: + returncode = 0 + stdout = "python inference_service.py --config test\x00" + + return Result() + + monkeypatch.setattr("subprocess.run", fake_run) + # Without --port in cmdline, should return False + assert _is_gr00t_process("container", "123", port=5555) is False + + def test_process_with_port_flag_accepted(self, monkeypatch): + """A process with --port 5555 in cmdline is a match.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + def fake_run(*args, **kwargs): + class Result: + returncode = 0 + stdout = "python inference_service.py --port 5555\x00" + + return Result() + + monkeypatch.setattr("subprocess.run", fake_run) + assert _is_gr00t_process("container", "123", port=5555) is True + + def test_editor_on_inference_service_rejected(self, monkeypatch): + """vim editing inference_service.py under a python venv is not a match.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + def fake_run(*args, **kwargs): + class Result: + returncode = 0 + stdout = "/opt/conda/envs/gr00t/bin/python vim /opt/gr00t/inference_service.py\x00" + + return Result() + + monkeypatch.setattr("subprocess.run", fake_run) + # No --port flag → rejected + assert _is_gr00t_process("container", "123", port=5555) is False + + +class TestOptionInjectionGuard: + """Test option-injection guard for argv-interpolated parameters.""" + + def test_repo_url_starting_with_dash_rejected(self): + """repo_url='--upload-pack=evil' must be rejected.""" + from 
strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="build_image", + repo_url="--upload-pack=touch /tmp/pwned", + ) + assert result["status"] == "error" + assert "repo_url" in result["message"] + assert "must not start with '-'" in result["message"] + + def test_repo_tag_starting_with_dash_rejected(self): + """repo_tag='--config=evil' must be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="build_image", + repo_tag="--config=core.fsmonitor=evil-cmd", + ) + assert result["status"] == "error" + assert "repo_tag" in result["message"] + + def test_policy_name_starting_with_dash_rejected(self): + """policy_name='--flag' must be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start", + checkpoint_path="/data/model", + policy_name="--malicious", + ) + assert result["status"] == "error" + assert "policy_name" in result["message"] + + def test_valid_repo_url_accepted(self, monkeypatch): + """Normal https:// URL must pass the guard.""" + from strands_robots.tools import gr00t_inference as gi_mod + + # Mock _build_image to avoid actual git/docker operations + monkeypatch.setattr( + gi_mod, + "_build_image", + lambda **kwargs: {"status": "success", "message": "mocked"}, + ) + + result = gi_mod.gr00t_inference( + action="build_image", + repo_url="https://github.com/NVIDIA/Isaac-GR00T", + repo_tag="n1.7-release", + ) + # Should not be an option-injection error + assert "must not start with '-'" not in result.get("message", "") + + +class TestHostAutoFlipForContainer: + """Test that container actions auto-flip 127.0.0.1 to 0.0.0.0.""" + + def test_default_host_is_loopback(self): + """Signature default must be 127.0.0.1 (AGENTS.md compliance).""" + import inspect + + from strands_robots.tools.gr00t_inference import gr00t_inference + + sig = inspect.signature(gr00t_inference) + # Sentinel 
default: None means "use 127.0.0.1" but distinguishes from explicit + assert sig.parameters["host"].default is None + + def test_start_service_auto_flips_loopback(self, monkeypatch): + """_start_service should auto-flip 127.0.0.1 to 0.0.0.0 for Docker.""" + from strands_robots.tools.gr00t_inference import _start_service + + captured_host = {} + + def fake_find(*args, **kwargs): + return { + "status": "success", + "containers": [{"name": "gr00t-test", "status": "Up 2 hours"}], + } + + def fake_build_cmd(**kwargs): + captured_host["host"] = kwargs.get("host") + return ["docker", "exec", "gr00t-test", "echo", "test"] + + monkeypatch.setattr("strands_robots.tools.gr00t_inference._find_gr00t_containers", fake_find) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._build_inference_command", fake_build_cmd) + + import subprocess + + def fake_run(*args, **kwargs): + return subprocess.CompletedProcess(args=args[0] if args else [], returncode=0, stdout="", stderr="") + + monkeypatch.setattr(subprocess, "run", fake_run) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_service_running", lambda port: True) + + # Call with loopback default — should auto-flip + _start_service( + checkpoint_path="/data/model", + port=5555, + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + denoising_steps=4, + host="127.0.0.1", + container_name=None, + policy_name=None, + timeout=5, + use_tensorrt=False, + trt_engine_path="gr00t_engine", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + http_server=False, + api_token=None, + host_was_explicit=False, + ) + assert captured_host.get("host") == "0.0.0.0" + + +class TestSingleLabelNumericHostname: + """Verify single-label numeric hostnames (per RFC-1123) are accepted.""" + + def test_single_numeric_label_accepted(self): + """Single-label '123' is a valid hostname (RFC-1123).""" + validate_inputs(**{**_VALID_KWARGS, "host": "123"}) + + def test_multi_label_numeric_rejected(self): + """Multi-label 
'127.0.01' is rejected as an IP typo.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs(**{**_VALID_KWARGS, "host": "127.0.01"}) + + +class TestPlatformGuardForHostFallback: + """Test that host-fallback stop returns error on non-Linux platforms.""" + + def test_non_linux_platform_returns_error(self, monkeypatch): + """On non-Linux, _stop_service should error when no containers found.""" + import sys as _sys + + from strands_robots.tools.gr00t_inference import _stop_service + + # Mock no containers found + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._find_gr00t_containers", + lambda: {"status": "success", "containers": []}, + ) + monkeypatch.setattr(_sys, "platform", "darwin") + + result = _stop_service(5555) + assert result["status"] == "error" + assert "Linux" in result["message"] + + +class TestN17ProcessIdentification: + """Regression tests for N1.7 process identification — GH review thread. + + N1.7 services are started via `python -m gr00t.eval.run_gr00t_server` which + doesn't contain `inference_service.py` in cmdline. These tests ensure the + stop/status path can identify N1.7 services. 
+ """ + + def test_n17_cmdline_detected_by_host_process_check(self, tmp_path, monkeypatch): + """_is_gr00t_host_process detects N1.7 server cmdline.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + # Simulate N1.7 cmdline: python -m gr00t.eval.run_gr00t_server --port 5555 + proc_dir = tmp_path / "proc" / "999" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00-m\x00gr00t.eval.run_gr00t_server\x00--port\x005555\x00") + + called = {} + from pathlib import Path as RealPath + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("999", port=5555) is True + assert called.get("p") == "/proc/999/cmdline" + + def test_n17_cmdline_wrong_port_rejected(self, tmp_path, monkeypatch): + """N1.7 server on wrong port is not killed.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "999" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00-m\x00gr00t.eval.run_gr00t_server\x00--port\x008000\x00") + + called = {} + from pathlib import Path as RealPath + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + # Request port 80 — should not match 8000 + assert _is_gr00t_host_process("999", port=80) is False + assert called.get("p") == "/proc/999/cmdline" + + def test_n15_cmdline_still_detected(self, tmp_path, monkeypatch): + """N1.5/N1.6 cmdline (inference_service.py) still works after N1.7 support.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "123" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + 
cmdline_file.write_text("python\x00inference_service.py\x00--port\x005555\x00") + + from pathlib import Path as RealPath + + def _fake_path(p): + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("123", port=5555) is True + + +class TestExpandedParamValidationExtended: + """Extended tests for image_name, volumes, and container_command — covers happy paths.""" + + def test_valid_image_name(self): + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + image_name="localhost:5000/myorg/img:tag", + ) + + def test_invalid_image_name_shell_meta(self): + import pytest + + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="image_name"): + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + image_name="evil;rm -rf /", + ) + + def test_volumes_path_traversal_rejected(self): + import pytest + + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="volumes"): + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + 
volumes={"../../etc/passwd": "/data"}, + ) + + def test_container_command_shell_meta_rejected(self): + import pytest + + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="container_command"): + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + container_command="tail -f /dev/null; rm -rf /", + ) + + def test_valid_container_command(self): + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise - legitimate container commands without shell metas + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + container_command="tail -f /dev/null", + ) + + def test_valid_volumes(self): + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + volumes={"/tmp/checkpoints": "/data/checkpoints"}, + ) + + +class TestHostAutoFlipSentinel: + """Regression tests for the sentinel-based host auto-flip logic. + + The auto-flip from 127.0.0.1 → 0.0.0.0 for Docker container actions + MUST only fire when the user accepted the default (i.e. did not pass + host= explicitly). Users who explicitly pass host="127.0.0.1" (e.g. 
+ for --network=host deployments) must have their choice honoured. + """ + + def test_default_host_passes_not_explicit(self, monkeypatch): + """When host is NOT passed (None sentinel), host_was_explicit=False.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "error", "message": "mocked"} + + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + # Call without host= (uses default None → 127.0.0.1, not explicit) + gr00t_inference(action="start", checkpoint_path="/data/model") + assert captured.get("host") == "127.0.0.1" + assert captured.get("host_was_explicit") is False, ( + "Default host (None sentinel) should pass host_was_explicit=False" + ) + + def test_explicit_loopback_passes_explicit_flag(self, monkeypatch): + """When user explicitly passes host='127.0.0.1', host_was_explicit=True.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "error", "message": "mocked"} + + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + # Call WITH explicit host="127.0.0.1" — must pass host_was_explicit=True + gr00t_inference(action="start", checkpoint_path="/data/model", host="127.0.0.1") + assert captured.get("host") == "127.0.0.1" + assert captured.get("host_was_explicit") is True, "Explicit host='127.0.0.1' must pass host_was_explicit=True" + + def test_explicit_zero_passes_explicit_flag(self, monkeypatch): + """When user explicitly passes host='0.0.0.0', host_was_explicit=True.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "error", "message": "mocked"} + + monkeypatch.setattr( + 
"strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + gr00t_inference(action="start", checkpoint_path="/data/model", host="0.0.0.0") + assert captured.get("host") == "0.0.0.0" + assert captured.get("host_was_explicit") is True + + +class TestReviewRound8Fixes: + """Regression tests for review round-8 fixes (2026-05-22 21:44 UTC). + + Covers: + - restart path forwarding host_was_explicit + - colon rejection in volume paths (docker -v mount-redirect) + - digest-pinned image references + - TypeError handling in validation wrapper + - dash-prefix rejection in volume paths + """ + + def test_restart_forwards_host_was_explicit(self, monkeypatch): + """action='restart' must forward host_was_explicit to _start_service.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "success", "message": "mocked"} + + def _mock_stop_service(port): + pass + + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._stop_service", + _mock_stop_service, + ) + monkeypatch.setattr("time.sleep", lambda _: None) + + # Explicit host='127.0.0.1' on restart must pass host_was_explicit=True + gr00t_inference( + action="restart", + checkpoint_path="/data/model", + host="127.0.0.1", + ) + assert captured.get("host_was_explicit") is True, ( + "restart path must forward host_was_explicit=True for explicit host" + ) + + def test_restart_default_host_not_explicit(self, monkeypatch): + """action='restart' with default host must pass host_was_explicit=False.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "success", "message": "mocked"} + + def _mock_stop_service(port): + pass + + monkeypatch.setattr( + 
"strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._stop_service", + _mock_stop_service, + ) + monkeypatch.setattr("time.sleep", lambda _: None) + + # Default host (not passed) on restart -> host_was_explicit=False + gr00t_inference(action="restart", checkpoint_path="/data/model") + assert captured.get("host_was_explicit") is False + + def test_volume_path_colon_rejected(self): + """Volume paths containing ':' must be rejected (docker -v mount-redirect).""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start_container", + image_name="gr00t:latest", + volumes={"/legit/dir:rw,nosuid": "/container/path"}, + ) + assert result["status"] == "error" + assert ":" in result["message"] or "colon" in result["message"].lower() + + def test_volume_value_colon_rejected(self): + """Container-side volume paths containing ':' must also be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start_container", + image_name="gr00t:latest", + volumes={"/host/path": "/container:path"}, + ) + assert result["status"] == "error" + assert ":" in result["message"] + + def test_volume_path_dash_prefix_rejected(self): + """Volume paths starting with '-' must be rejected (option injection).""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start_container", + image_name="gr00t:latest", + volumes={"--privileged=foo": "/bar"}, + ) + assert result["status"] == "error" + assert "'-'" in result["message"] or "start with" in result["message"] + + def test_digest_pinned_image_accepted(self): + """Digest-pinned image refs (registry/path@sha256:hex) must be accepted.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + # Should NOT fail validation on image_name (may fail later on docker ops) + result = 
gr00t_inference( + action="start_container", + image_name="nvcr.io/nvidia/gr00t@sha256:" + "a" * 64, + ) + # If it fails, it should NOT be an image_name validation error + if result["status"] == "error": + assert "valid Docker image" not in result["message"] + + def test_type_error_returns_structured_error(self): + """TypeError from bad parameter types must return structured error, not raise.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + # port="5555" (str instead of int) -> TypeError on `1 <= port <= 65535` + result = gr00t_inference(action="start", checkpoint_path="/data/model", port="5555") + assert result["status"] == "error" + # Must not propagate as unhandled exception - returns dict + + def test_end_to_end_bogus_action_returns_error_dict(self): + """Bogus action returns structured error dict (not raw exception).""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="bogus_action") + assert isinstance(result, dict) + assert result["status"] == "error" + assert "Unknown action" in result["message"] or "bogus_action" in result["message"] + + +class TestImageOnlyBranchValidation: + """Tests for validation on image-only actions (build_image, download_checkpoint, start_container).""" + + def test_container_name_validated_on_start_container(self): + """container_name must be validated on start_container (image-only branch).""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="container_name"): + validate_inputs( + action="start_container", + data_config="so100", + embodiment_tag="so100", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="/opt/engine", + container_name="--privileged", + protocol="n1.5", + image_name=None, + volumes=None, + container_command=None, + repo_url=None, + repo_tag=None, + policy_name=None, + ) + + def 
test_policy_name_dash_rejected_on_start_container(self): + """policy_name starting with '-' must be rejected on image-only actions.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="policy_name"): + validate_inputs( + action="start_container", + data_config="so100", + embodiment_tag="so100", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="/opt/engine", + container_name=None, + protocol="n1.5", + image_name=None, + volumes=None, + container_command=None, + repo_url=None, + repo_tag=None, + policy_name="--malicious", + ) + + def test_valid_container_name_accepted_on_start_container(self): + """Valid container_name passes on start_container.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="start_container", + data_config="so100", + embodiment_tag="so100", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="/opt/engine", + container_name="my-gr00t-container", + protocol="n1.5", + image_name=None, + volumes=None, + container_command=None, + repo_url=None, + repo_tag=None, + policy_name=None, + ) From ecf5f0fd4f7f05467d70aa8f756a5a370b37182a Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 01:12:21 -0400 Subject: [PATCH 02/18] =?UTF-8?q?review(gr00t=5Finference):=20R1=20?= =?UTF-8?q?=E2=80=94=20drop=20silent=20host=3D127.0.0.1->0.0.0.0=20auto-fl?= =?UTF-8?q?ip?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pre-R1, _start_service rewrote host='127.0.0.1' to '0.0.0.0' whenever host_was_explicit=False, which silently widened the bind from loopback to all-interfaces. 
The default-host case (sentinel None -> 127.0.0.1) hit the flip on every default invocation, undermining the loopback-by-default security promise advertised in the PR title. R1 architecture: - The service inside the container ALWAYS binds 0.0.0.0:port (it has to, because Docker port-publish requires the listener on every interface inside the container's network namespace). - The HOST port-publish now binds to the user's host kwarg via 'docker -p HOST:port:port'. host='127.0.0.1' (default) -> docker binds loopback only on host host='0.0.0.0' (explicit) -> docker binds all interfaces on host - The host kwarg is no longer rewritten. Users get exactly what they ask for. host_was_explicit kwarg is retained on _start_service for ABI compat (callers still pass it), marked '# noqa: ARG001' to silence the unused-arg lint and document the deliberate retention. Pin tests: tests/policies/groot/test_gr00t_inference_validation.py:: TestHostBindingHonoursUserChoice (3 cases) - test_default_host_is_loopback_sentinel - test_start_service_does_not_flip_default_loopback (R1 regression pin) - test_explicit_zero_zero_zero_zero_passes_through TestHostExplicitFlagDispatch (renamed from TestHostAutoFlipSentinel, now documents that the flag is dispatch-only post-R1) Refs: strands-labs/robots#90 (consolidates 7 review threads at line 1135). 
--- strands_robots/tools/gr00t_inference.py | 34 +++--- .../groot/test_gr00t_inference_validation.py | 103 +++++++++++++++--- 2 files changed, 104 insertions(+), 33 deletions(-) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 0a03b2df..d5a68eed 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -620,6 +620,7 @@ def gr00t_inference( container_command=container_command, hf_local_dir=hf_local_dir, force=force, + host=host, ) elif action == "lifecycle": return _lifecycle( @@ -1117,22 +1118,15 @@ def _start_service( api_token: str | None, protocol: str = "n1.5", use_sim_policy_wrapper: bool = False, - host_was_explicit: bool = False, + host_was_explicit: bool = False, # noqa: ARG001 — retained for ABI compat; auto-flip dropped in R1 ) -> dict[str, Any]: - """Start GR00T inference service using Isaac-GR00T's native inference service.""" + """Start GR00T inference service using Isaac-GR00T's native inference service. + + The ``host`` kwarg controls the docker host-side port binding via + ``-p {host}:{port}:{port}``. Default ``127.0.0.1`` keeps the service on + loopback; pass ``host="0.0.0.0"`` to expose to the network. + """ try: - # Auto-flip host for container actions: Docker's -p port-publish requires the - # service to bind all interfaces inside the container. Only auto-flip if the - # user accepted the default (sentinel was None → resolved to 127.0.0.1). - # Users who explicitly pass host="127.0.0.1" get it honoured (e.g. --network=host). - if host == "127.0.0.1" and not host_was_explicit: - import logging as _logging - - _logging.getLogger(__name__).warning( - "Auto-flipping host from 127.0.0.1 to 0.0.0.0 for container " - "port-publish (-p). Pass host='127.0.0.1' explicitly to keep loopback." 
- ) - host = "0.0.0.0" # Find container if not specified if container_name is None: containers = _find_gr00t_containers() @@ -1446,10 +1440,15 @@ def _start_container( container_command: str, hf_local_dir: str | None, force: bool, + host: str = "127.0.0.1", ) -> dict[str, Any]: """``docker run -d`` the GR00T container so subsequent ``start`` actions can ``docker exec`` into it. + The ``host`` kwarg controls the docker host-side port binding via + ``-p {host}:{port}:{port}``. Default ``127.0.0.1`` keeps the published + port on loopback only; pass ``host="0.0.0.0"`` to expose to the network. + Idempotent: when a container with ``container_name`` is already running, returns success without touching docker. When it exists but is stopped, ``force=True`` removes + recreates it (otherwise returns @@ -1486,7 +1485,11 @@ def _start_container( "--name", name, "-p", - f"{port}:{port}", + # Bind docker port-publish to user-requested host (loopback by default). + # Service inside container binds 0.0.0.0 (must, for docker -p to work), + # but the *host* binding honours user intent. Users pass host="0.0.0.0" + # explicitly to expose to the network. 
+ f"{host}:{port}:{port}", ] # Default volume layout: mount the checkpoint dir into /data/checkpoints @@ -1679,6 +1682,7 @@ def _lifecycle( container_command=container_command, hf_local_dir=resolved_local_dir, force=force, + host=host, ) steps.append({"step": "start_container", "result": container_result}) if container_result["status"] != "success": diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index f79ea266..6afb5fd3 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -1038,24 +1038,38 @@ def test_valid_repo_url_accepted(self, monkeypatch): assert "must not start with '-'" not in result.get("message", "") -class TestHostAutoFlipForContainer: - """Test that container actions auto-flip 127.0.0.1 to 0.0.0.0.""" +class TestHostBindingHonoursUserChoice: + """R1 pin tests — host kwarg controls docker -p binding; NO auto-flip. + + Pre-R1: ``_start_service`` rewrote ``host="127.0.0.1"`` to ``"0.0.0.0"`` + when ``host_was_explicit=False``, silently widening the bind to all + interfaces. R1 drops the rewrite. 
The host kwarg now flows verbatim into + ``docker -p HOST:port:port``, so: + - host="127.0.0.1" (default) → docker binds loopback only + - host="0.0.0.0" (explicit) → docker binds all interfaces + """ - def test_default_host_is_loopback(self): - """Signature default must be 127.0.0.1 (AGENTS.md compliance).""" + def test_default_host_is_loopback_sentinel(self): + """Signature default must be None (resolves to 127.0.0.1) — AGENTS.md compliance.""" import inspect from strands_robots.tools.gr00t_inference import gr00t_inference sig = inspect.signature(gr00t_inference) - # Sentinel default: None means "use 127.0.0.1" but distinguishes from explicit - assert sig.parameters["host"].default is None + assert sig.parameters["host"].default is None, ( + "host signature default must remain None sentinel (resolves to 127.0.0.1)" + ) + + def test_start_service_does_not_flip_default_loopback(self, monkeypatch): + """R1 pin: _start_service must NOT rewrite host=127.0.0.1 to 0.0.0.0. - def test_start_service_auto_flips_loopback(self, monkeypatch): - """_start_service should auto-flip 127.0.0.1 to 0.0.0.0 for Docker.""" + Pre-R1 this test would have failed because the loopback default got + silently flipped to all-interfaces. Post-R1 the host stays loopback. + """ from strands_robots.tools.gr00t_inference import _start_service - captured_host = {} + # Capture the exact docker argv that _build_inference_command would see. 
+ captured = {} def fake_find(*args, **kwargs): return { @@ -1064,7 +1078,7 @@ def fake_find(*args, **kwargs): } def fake_build_cmd(**kwargs): - captured_host["host"] = kwargs.get("host") + captured["host"] = kwargs.get("host") return ["docker", "exec", "gr00t-test", "echo", "test"] monkeypatch.setattr("strands_robots.tools.gr00t_inference._find_gr00t_containers", fake_find) @@ -1078,7 +1092,6 @@ def fake_run(*args, **kwargs): monkeypatch.setattr(subprocess, "run", fake_run) monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_service_running", lambda port: True) - # Call with loopback default — should auto-flip _start_service( checkpoint_path="/data/model", port=5555, @@ -1098,7 +1111,60 @@ def fake_run(*args, **kwargs): api_token=None, host_was_explicit=False, ) - assert captured_host.get("host") == "0.0.0.0" + # The CRITICAL assertion: host stays 127.0.0.1, NOT auto-flipped. + assert captured.get("host") == "127.0.0.1", ( + "R1 regression: _start_service must NOT rewrite host=127.0.0.1 to 0.0.0.0. " + f"Got host={captured.get('host')!r} — auto-flip has reappeared." 
+ ) + + def test_explicit_zero_zero_zero_zero_passes_through(self, monkeypatch): + """R1 pin: host='0.0.0.0' must reach docker -p verbatim (network exposure is opt-in).""" + from strands_robots.tools.gr00t_inference import _start_service + + captured = {} + + def fake_find(*args, **kwargs): + return { + "status": "success", + "containers": [{"name": "gr00t-test", "status": "Up 2 hours"}], + } + + def fake_build_cmd(**kwargs): + captured["host"] = kwargs.get("host") + return ["docker", "exec", "gr00t-test", "echo", "test"] + + monkeypatch.setattr("strands_robots.tools.gr00t_inference._find_gr00t_containers", fake_find) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._build_inference_command", fake_build_cmd) + + import subprocess + + monkeypatch.setattr( + subprocess, + "run", + lambda *a, **kw: subprocess.CompletedProcess(args=[], returncode=0, stdout="", stderr=""), + ) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_service_running", lambda port: True) + + _start_service( + checkpoint_path="/data/model", + port=5555, + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + denoising_steps=4, + host="0.0.0.0", + container_name=None, + policy_name=None, + timeout=5, + use_tensorrt=False, + trt_engine_path="gr00t_engine", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + http_server=False, + api_token=None, + host_was_explicit=True, + ) + assert captured.get("host") == "0.0.0.0" class TestSingleLabelNumericHostname: @@ -1336,13 +1402,14 @@ def test_valid_volumes(self): ) -class TestHostAutoFlipSentinel: - """Regression tests for the sentinel-based host auto-flip logic. +class TestHostExplicitFlagDispatch: + """Verify the host_was_explicit dispatch flag is set correctly. - The auto-flip from 127.0.0.1 → 0.0.0.0 for Docker container actions - MUST only fire when the user accepted the default (i.e. did not pass - host= explicitly). Users who explicitly pass host="127.0.0.1" (e.g. 
- for --network=host deployments) must have their choice honoured. + Post-R1, _start_service does not act on this flag (the auto-flip was + removed), but the flag is retained for ABI compatibility and to allow + operators / future code to introspect whether the user explicitly chose + a binding host. These tests pin the dispatch-layer behaviour so we don't + silently lose the distinction between "default" and "explicit" host args. """ def test_default_host_passes_not_explicit(self, monkeypatch): From 7762868316b0883258ded461790b3900ac5b5c81 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 01:14:56 -0400 Subject: [PATCH 03/18] =?UTF-8?q?review(gr00t=5Finference):=20R3=20?= =?UTF-8?q?=E2=80=94=20reconcile=20CHANGELOG=20to=20match=20R1=20architect?= =?UTF-8?q?ure?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Changed-section entry for the host default still described the old auto-flip behaviour ('Container actions auto-flip to 0.0.0.0 internally') which contradicts the R1 commit (auto-flip removed). Rewritten to describe the actual post-R1 behaviour: - host kwarg flows verbatim into 'docker -p {host}:{port}:{port}' - service inside container always binds 0.0.0.0 (docker requirement) - host binding on the host side honours user choice (loopback by default) Migration note expanded to name the affected actions explicitly: 'start / restart / start_container / lifecycle'. Refs: strands-labs/robots#90 (consolidates 2 review threads at CHANGELOG.md lines 68 and 75). --- CHANGELOG.md | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d276f3a2..6e3a0636 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,11 +55,16 @@ All notable behavioural changes to `strands-robots` are logged here. Follows - ``PermissionError`` in process probes now logs at WARNING level instead of being silently swallowed. 
- **BREAKING** Default ``host`` changed from ``0.0.0.0`` to ``127.0.0.1`` - (loopback-only, per AGENTS.md). Container actions (``start``/``restart``/ - ``lifecycle``) auto-flip to ``0.0.0.0`` internally since Docker's - ``-p {port}:{port}`` publish requires bind-all inside the container. - **Migration:** if your downstream connects from another host, pass - ``host="0.0.0.0"`` explicitly. + (loopback-only, per AGENTS.md > Review Learnings #86 > "Safety Defaults"). + The ``host`` kwarg now flows verbatim into the docker host-side port + binding via ``-p {host}:{port}:{port}`` — no silent rewrite. The service + inside the container always binds ``0.0.0.0`` (required by docker + port-publish), but the *host* binding honours user intent: + - ``host="127.0.0.1"`` (default) → published port reachable on loopback only + - ``host="0.0.0.0"`` (explicit) → published port reachable on every host iface + **Migration:** if your downstream connects from a different host on the + same machine or across the network, pass ``host="0.0.0.0"`` explicitly + on the ``start`` / ``restart`` / ``start_container`` / ``lifecycle`` calls. - Host-system fallback (``pgrep``) is documented as Linux-only. Non-Linux platforms will see "No service running" rather than silently succeeding. From 621b3d8e56e79673090166410a57c8f1b150d395 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 01:17:04 -0400 Subject: [PATCH 04/18] =?UTF-8?q?review(gr00t=5Finference):=20R4=20?= =?UTF-8?q?=E2=80=94=20fix=204=20regex=20bugs=20in=20input=20validators?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug 1 — _DOCKER_IMAGE_RE accepted registry ports up to 99999. Pre-R4: 'r"(?::[0-9]{1,5})?"' — matched any 1-5 digit string. Post-R4: regex captures the port; new helper _is_valid_docker_image_ref range-checks the integer against TCP [1, 65535]. Bug 3 — _HOSTNAME_RE rejected trailing-dot FQDNs. 
Pre-R4: regex required the last label to end with [a-zA-Z0-9], so 'host.example.com.' failed. Post-R4: pattern ends with '\\.?$' to accept the optional FQDN root indicator per RFC 1034 §3.1. Bare '.' still rejected. Bug 4 — hostname total length cap. Already enforced at line 198 (RFC 1035 §2.3.4: 253 octets). Pinned with a regression test so future refactors cannot silently lose it. Bug 2 — '127.0.01' IPv4 typo rejection. Already enforced via _ALL_NUMERIC_RE; the comment was misleading. Pinned with an explicit regression test naming the typo from the PR description. The two existing call-sites of _DOCKER_IMAGE_RE.match are updated to use _is_valid_docker_image_ref so the range check fires uniformly. Pin tests added (TestRegexBugFixesR4, 10 cases): - test_registry_port_99999_rejected (Bug 1) - test_registry_port_65535_accepted - test_registry_port_5000_accepted - test_registry_port_zero_rejected - test_no_port_still_accepted - test_digest_pinned_image_accepted - test_trailing_dot_fqdn_accepted (Bug 3) - test_single_dot_alone_rejected - test_host_validation_rejects_oversize (Bug 4) - test_host_typo_127_0_01_rejected (Bug 2) Refs: strands-labs/robots#90 (consolidates 8 regex-related review threads at lines 57-95 and 133). --- strands_robots/tools/gr00t_inference.py | 32 +++- .../groot/test_gr00t_inference_validation.py | 138 ++++++++++++++++++ 2 files changed, 165 insertions(+), 5 deletions(-) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index d5a68eed..bbc1575f 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -54,12 +54,31 @@ def _checkpoints_dir() -> Path: _DOCKER_IMAGE_RE = re.compile( r"^[a-zA-Z0-9]" # must start with alnum r"(?:[a-zA-Z0-9._\-]*[a-zA-Z0-9])?" # optional middle chars (host/path prefix) - r"(?::[0-9]{1,5})?" # optional registry port (:5000) + r"(?::([0-9]{1,5}))?" 
# optional registry port (:5000) — capture for range check r"(?:/[a-zA-Z0-9][a-zA-Z0-9._\-]*)*" # path components (/org/img) r"(?::[a-zA-Z0-9][a-zA-Z0-9._\-]*" # option A: :tag r"|@sha256:[a-f0-9]{64})?$" # option B: @sha256:digest (mutually exclusive with tag) ) + +def _is_valid_docker_image_ref(value: str) -> bool: + """Validate a docker image reference for shape AND range correctness. + + The regex captures the (optional) registry port so we can verify it is + a valid TCP port (1-65535). The regex alone permits digit counts up to 5, + which would otherwise accept refs like ``localhost:99999/img:tag`` even + though no such port can exist. + """ + m = _DOCKER_IMAGE_RE.match(value) + if not m: + return False + port_str = m.group(1) + if port_str is not None: + port_int = int(port_str) + if port_int < 1 or port_int > 65535: + return False + return True + # Characters that cause harm in subprocess argv or shell interpolation. # Narrowed per AGENTS.md review-learnings: quotes/bangs/parens/brackets # appear in legitimate filesystem paths and all subprocess calls here are @@ -72,10 +91,13 @@ def _checkpoints_dir() -> Path: _DATA_CONFIG_RE = re.compile(r"^[a-z][a-z0-9_]{0,63}$") _EMBODIMENT_TAG_RE = re.compile(r"^[a-z][a-z0-9_]{0,31}$") _CONTAINER_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}$") -# RFC-952 hostname pattern for host validation. +# RFC-952/1123 hostname pattern for host validation. +# Trailing dot is accepted per RFC 1034 §3.1 (FQDNs may end with a dot to +# disambiguate from search-list completion). _HOSTNAME_RE = re.compile( r"^[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?" - r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$" + r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*" + r"\.?$" # optional trailing dot (FQDN root indicator) ) # Reject multi-label all-numeric strings — prevents typos like "127.0.01" # which pass _HOSTNAME_RE but are clearly malformed IP attempts, not hostnames. 
@@ -225,7 +247,7 @@ def validate_inputs( if container_name is not None and not _CONTAINER_NAME_RE.match(container_name): raise ValueError(f"container_name must match Docker naming rules (got {container_name!r})") # Validate image_name, volumes, container_command (relevant to these actions) - if image_name is not None and not _DOCKER_IMAGE_RE.match(image_name): + if image_name is not None and not _is_valid_docker_image_ref(image_name): raise ValueError(f"image_name must be a valid Docker image reference (got {image_name!r})") if volumes is not None: for vol_host, vol_container in volumes.items(): @@ -270,7 +292,7 @@ def validate_inputs( raise ValueError(f"dit_dtype must be one of {_VALID_DIT_DTYPES}, got {dit_dtype!r}") # Docker image reference (if provided via kwargs) - if image_name is not None and not _DOCKER_IMAGE_RE.match(image_name): + if image_name is not None and not _is_valid_docker_image_ref(image_name): raise ValueError(f"image_name must be a valid Docker image reference (got {image_name!r})") # Volume paths validation diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index 6afb5fd3..fca4e366 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -1690,3 +1690,141 @@ def test_valid_container_name_accepted_on_start_container(self): repo_tag=None, policy_name=None, ) + + +class TestRegexBugFixesR4: + """R4 pin tests for the 4 regex bugs raised in PR #90 review. + + Each test fails on pre-R4 code and passes after R4 lands. + """ + + # === Bug 1: _DOCKER_IMAGE_RE registry port range === + + def test_registry_port_99999_rejected(self): + """Pre-R4 the regex accepted :99999 (>65535). 
Post-R4 we range-check.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert not _is_valid_docker_image_ref("localhost:99999/myorg/img:tag"), ( + "R4 regression: registry port 99999 must be rejected (TCP max is 65535)" + ) + + def test_registry_port_65535_accepted(self): + """65535 is the max valid TCP port — must be accepted.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("localhost:65535/myorg/img:tag") + + def test_registry_port_5000_accepted(self): + """Common private-registry port — sanity check.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("localhost:5000/myorg/img:tag") + + def test_registry_port_zero_rejected(self): + """Port 0 is not a valid bind target.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert not _is_valid_docker_image_ref("localhost:0/myorg/img:tag") + + def test_no_port_still_accepted(self): + """Image refs without a registry port must still match.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("nvcr.io/nvidia/gr00t:n1.7") + assert _is_valid_docker_image_ref("gr00t:latest") + + def test_digest_pinned_image_accepted(self): + """Digest-pinned refs (@sha256:...) must continue to match.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + digest = "a" * 64 + assert _is_valid_docker_image_ref(f"nvcr.io/nvidia/gr00t@sha256:{digest}") + + # === Bug 3: _HOSTNAME_RE trailing-dot FQDN === + + def test_trailing_dot_fqdn_accepted(self): + """RFC 1034 §3.1: FQDNs may end with a dot to disambiguate. + + Pre-R4 the regex required the last label to end with [a-zA-Z0-9], + which rejected legitimate FQDNs like 'host.example.com.'. 
+ """ + from strands_robots.tools.gr00t_inference import _HOSTNAME_RE + + assert _HOSTNAME_RE.match("host.example.com."), ( + "R4 regression: trailing-dot FQDN must be accepted per RFC 1034 §3.1" + ) + # Without trailing dot still accepted + assert _HOSTNAME_RE.match("host.example.com") + + def test_single_dot_alone_rejected(self): + """A bare '.' is not a valid hostname.""" + from strands_robots.tools.gr00t_inference import _HOSTNAME_RE + + assert not _HOSTNAME_RE.match(".") + + # === Bug 4: hostname total length cap (already implemented; pin it) === + + def test_host_validation_rejects_oversize(self): + """RFC 1035 §2.3.4: hostname must not exceed 253 octets total.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + oversize = ".".join(["a" * 60] * 5) # 60*5 + 4 dots = 304 > 253 + assert len(oversize) > 253 + try: + validate_inputs( + action="start", + port=5555, + host=oversize, + protocol="n1.5", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + container_name=None, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + image_name=None, + volumes=None, + container_command="tail -f /dev/null", + policy_name=None, + ) + raise AssertionError("Expected ValueError for hostname > 253 octets") + except ValueError as e: + assert "253" in str(e), f"Error should mention RFC 1035 limit; got: {e}" + + # === Bug 2 / IPv4 typo regression — explicit '127.0.01' typo case === + + def test_host_typo_127_0_01_rejected(self): + """The typo called out in the PR description must be rejected. + + '127.0.01' looks like an IPv4 attempt to a human but is not a valid + IPv4 string under ipaddress.ip_address. Without the _ALL_NUMERIC_RE + guard it would be accepted as a hostname (it matches RFC-952), which + would then fail at runtime with a confusing connection error. 
+ """ + from strands_robots.tools.gr00t_inference import validate_inputs + + try: + validate_inputs( + action="start", + port=5555, + host="127.0.01", + protocol="n1.5", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + container_name=None, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + image_name=None, + volumes=None, + container_command="tail -f /dev/null", + policy_name=None, + ) + raise AssertionError("Expected ValueError for '127.0.01' IP typo") + except ValueError as e: + assert "host" in str(e).lower() + From 58eef5fadbde3bd3c2577a6c94201775e18e52d2 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 01:26:54 -0400 Subject: [PATCH 05/18] =?UTF-8?q?style(gr00t=5Finference):=20R4=20?= =?UTF-8?q?=E2=80=94=20apply=20ruff=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pure format-only fix to satisfy 'ruff format --check' in CI: - Add blank line after _is_valid_docker_image_ref helper definition - Strip trailing blank line at EOF of test file No semantic change. Local 'ruff check' was clean but I missed running 'ruff format --check' before pushing — this commit is the trivial fix. Refs: strands-labs/robots#196 (CI fix on top of R4). --- strands_robots/tools/gr00t_inference.py | 1 + tests/policies/groot/test_gr00t_inference_validation.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index bbc1575f..1ccbe5f5 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -79,6 +79,7 @@ def _is_valid_docker_image_ref(value: str) -> bool: return False return True + # Characters that cause harm in subprocess argv or shell interpolation. 
# Narrowed per AGENTS.md review-learnings: quotes/bangs/parens/brackets # appear in legitimate filesystem paths and all subprocess calls here are diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index fca4e366..6e9053f3 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -1827,4 +1827,3 @@ def test_host_typo_127_0_01_rejected(self): raise AssertionError("Expected ValueError for '127.0.01' IP typo") except ValueError as e: assert "host" in str(e).lower() - From b7c3b6517c4abea95de29219c217fc417dbb787f Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 01:40:37 -0400 Subject: [PATCH 06/18] =?UTF-8?q?test(gr00t=5Finference):=20R1=20=E2=80=94?= =?UTF-8?q?=20update=20test=5Frecreates=5Fwhen=5Fforce=20for=20new=20-p=20?= =?UTF-8?q?shape?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pre-R1 the docker port-publish argv was 'PORT:PORT' (e.g. '8000:8000'), which made docker bind the published port on every host interface (0.0.0.0). R1 changed the argv to 'HOST:PORT:PORT' so the user-supplied host kwarg controls binding (default '127.0.0.1' = loopback-only). This test exists in tests/tools/test_gr00t_inference.py (different file from the validation suite); R1 updated tests/policies/groot/... but missed this sibling test which still asserted the pre-R1 shape. CI on PR #196 caught it. Updated assertions: - Replace 'assert "8000:8000" in run_cmd' with the post-R1 expectation 'assert "127.0.0.1:8000:8000" in run_cmd'. - Add an explicit negative assertion that '0.0.0.0:8000:8000' is NOT present, so the auto-flip cannot reappear via a future regression. Refs: strands-labs/robots#196 (CI fix on top of R1). 
--- tests/tools/test_gr00t_inference.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/tools/test_gr00t_inference.py b/tests/tools/test_gr00t_inference.py index 38e8842b..65a03c5f 100644 --- a/tests/tools/test_gr00t_inference.py +++ b/tests/tools/test_gr00t_inference.py @@ -611,7 +611,14 @@ def fake_run(cmd, *a, **kw): assert "--gpus" in run_cmd and "all" in run_cmd assert "--ipc=host" in run_cmd assert "--name" in run_cmd and "gr00t" in run_cmd - assert "8000:8000" in run_cmd + # Post-R1 (PR #196): docker -p includes the host prefix; default host + # is 127.0.0.1 (loopback-only). The 0.0.0.0 silent rewrite was removed. + assert "127.0.0.1:8000:8000" in run_cmd, ( + f"R1 regression: _start_container must bind docker -p to loopback by default. Got argv: {run_cmd}" + ) + assert "0.0.0.0:8000:8000" not in run_cmd, ( + "R1 regression: default _start_container call must NOT bind all interfaces." + ) def test_volumes_default_includes_checkpoints_and_hf_cache(self): runs: list[list[str]] = [] From f5c0a5796ff33f89cfbebc432bbf41bad930d6dc Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 12:29:07 +0000 Subject: [PATCH 07/18] review(gr00t_inference): R1 - fix docker image ref regex falsely rejecting numeric tags (addresses thread on line 78) The port-capture group in _DOCKER_IMAGE_RE greedily matched :digits even without a following /path component, causing the port-range check to reject valid name:tag refs where the tag was purely numeric (e.g. myimage:0, myimage:65536, myimage:99999, gr00t:23456). Fix: add (?=/) lookahead so :digits is only interpreted as a registry port when followed by a path component. Without the lookahead the ambiguous name:digits pattern is treated as name:tag (correct). Also fixes CHANGELOG header to reference #196 (this PR) instead of the superseded #90. 
--- CHANGELOG.md | 2 +- strands_robots/tools/gr00t_inference.py | 5 +- .../groot/test_gr00t_inference_validation.py | 65 +++++++++++++++++++ 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e3a0636..1df9c0e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ All notable behavioural changes to `strands-robots` are logged here. Follows [Keep a Changelog](https://keepachangelog.com/) conventions. -## Unreleased - #90 (gr00t_inference validation hardening) +## Unreleased - #196 (gr00t_inference validation hardening, supersedes #90) ### Added diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 1ccbe5f5..49c1b94d 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -54,7 +54,10 @@ def _checkpoints_dir() -> Path: _DOCKER_IMAGE_RE = re.compile( r"^[a-zA-Z0-9]" # must start with alnum r"(?:[a-zA-Z0-9._\-]*[a-zA-Z0-9])?" # optional middle chars (host/path prefix) - r"(?::([0-9]{1,5}))?" # optional registry port (:5000) — capture for range check + r"(?::([0-9]{1,5})(?=/))?" + # registry port (:5000) - captured for range check ONLY when followed by /path. + # Without the lookahead, name:digits is ambiguous between host:port and name:tag; + # e.g. "myimage:99999" would be falsely rejected as an invalid port. 
r"(?:/[a-zA-Z0-9][a-zA-Z0-9._\-]*)*" # path components (/org/img) r"(?::[a-zA-Z0-9][a-zA-Z0-9._\-]*" # option A: :tag r"|@sha256:[a-f0-9]{64})?$" # option B: @sha256:digest (mutually exclusive with tag) diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index 6e9053f3..d95c4c08 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -1827,3 +1827,68 @@ def test_host_typo_127_0_01_rejected(self): raise AssertionError("Expected ValueError for '127.0.01' IP typo") except ValueError as e: assert "host" in str(e).lower() + + +class TestDockerImageNumericTagRegression: + """Pin: numeric-only Docker tags must not be falsely rejected. + + Pre-fix, the port-capture group in _DOCKER_IMAGE_RE greedily matched + :digits even without a following / path component, causing the port-range + check to reject valid name:tag refs where the tag was purely numeric. + + Reproducer (pre-fix): + >>> _is_valid_docker_image_ref('myimage:0') # False (should be True) + >>> _is_valid_docker_image_ref('myimage:65536') # False (should be True) + >>> _is_valid_docker_image_ref('myimage:99999') # False (should be True) + + Fix: lookahead (?=/) on the port group so :digits is only interpreted as a + registry port when followed by a path component. 
+ """ + + def test_numeric_tag_zero_accepted(self): + """Tag ':0' is valid -- common in dev builds.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("myimage:0"), ( + "Regression: numeric tag ':0' must not be rejected as an invalid port" + ) + + def test_numeric_tag_above_port_range_accepted(self): + """Tag ':65536' is valid -- it is a tag, not a port.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("myimage:65536"), ( + "Regression: numeric tag ':65536' must not be rejected as an invalid port" + ) + + def test_numeric_tag_99999_accepted(self): + """Tag ':99999' is valid -- common date-style build IDs.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("myimage:99999"), ( + "Regression: numeric tag ':99999' must not be rejected as an invalid port" + ) + + def test_numeric_tag_five_digit_accepted(self): + """Tag ':23456' is valid -- common sequential build numbers.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + assert _is_valid_docker_image_ref("gr00t:23456"), ( + "Regression: numeric tag ':23456' must not be rejected as an invalid port" + ) + + def test_registry_port_with_path_still_range_checked(self): + """Port in host:port/path form must still be range-checked.""" + from strands_robots.tools.gr00t_inference import _is_valid_docker_image_ref + + # Valid port with path -- accepted + assert _is_valid_docker_image_ref("localhost:5000/myorg/img:tag") + assert _is_valid_docker_image_ref("localhost:65535/myorg/img:tag") + + # Invalid port (>65535) with path -- rejected + assert not _is_valid_docker_image_ref("localhost:99999/myorg/img:tag"), ( + "Registry port 99999 with /path must still be rejected (TCP max is 65535)" + ) + assert not _is_valid_docker_image_ref("localhost:100000/img:tag"), ( + "Registry port 100000 with /path 
must still be rejected" + ) From 4da630b4e1de62581a7ef418e80d2ed8259adc23 Mon Sep 17 00:00:00 2001 From: strands-agent <217235299+strands-agent@users.noreply.github.com> Date: Sat, 23 May 2026 13:34:13 +0000 Subject: [PATCH 08/18] =?UTF-8?q?review(gr00t=5Finference):=20R2=20?= =?UTF-8?q?=E2=80=94=20drop=20backslash=20from=20path-traversal=20split?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer concern (gr00t_inference.py:159, R2 thread): `re.split(r"[/\\]", value)` over-rejects POSIX paths containing literal backslash bytes. On POSIX only `/` and NUL are forbidden in filenames, so paths like `a\..\b` -- a single legitimate filename with embedded backslashes -- were wrongly flagged as `..` traversal. Fix: split on `/` only via `value.split('/')`. docker -v interprets just `/` as a separator on Linux (the only platform this tool supports), matching the executor's contract. Genuine `..` traversal between `/` separators (e.g. `/foo/../etc`) remains rejected. Pin tests (TestPathTraversalPosixBackslash, 4 cases): - POSIX path with literal backslashes accepted (regression) - Embedded `\..\` filename bytes accepted (regression) - Real `/../` traversal still rejected (anti-pin) - Relative `../` traversal still rejected (anti-pin) Pre-fix verification: 2 over-rejection tests fail on `re.split(r"[/\\]", value)`; restore -> 4/4 pass. Addresses review feedback on PR #196. 
Authored-by: cagataycali --- strands_robots/tools/gr00t_inference.py | 6 ++- .../groot/test_gr00t_inference_validation.py | 38 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 49c1b94d..c44b83de 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -156,7 +156,11 @@ def _validate_path(value: str, label: str, *, reject_colon: bool = False) -> Non raise ValueError(f"{label} must not contain null bytes") if value.startswith("-"): raise ValueError(f"{label} must not start with '-' (got {value!r})") - if any(part == ".." for part in re.split(r"[/\\]", value)): + # Split on '/' only. POSIX permits '\\' as a literal filename byte; including it + # in the split would falsely reject paths like 'a\\..\\b' where '..' is not a + # separated component. docker -v interprets only '/' as a separator on Linux, + # which is the only platform this tool supports. + if any(part == ".." for part in value.split("/")): raise ValueError(f"{label} must not contain '..' path traversal components") if _SHELL_META.search(value): raise ValueError(f"{label} contains disallowed characters: {value!r}") diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index d95c4c08..1c7e71b4 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -1892,3 +1892,41 @@ def test_registry_port_with_path_still_range_checked(self): assert not _is_valid_docker_image_ref("localhost:100000/img:tag"), ( "Registry port 100000 with /path must still be rejected" ) + + +class TestPathTraversalPosixBackslash: + """R2 pin tests -- _validate_path must not over-reject POSIX paths containing + literal backslashes. + + Pre-R2: ``_validate_path`` split on both ``/`` and ``\\`` via + ``re.split(r"[/\\\\]", value)``. 
On POSIX, ``\\`` is a legal filename byte + (only ``/`` and NUL are forbidden), so a path like ``a\\..\\b`` -- a single + legitimate filename containing literal backslashes -- was wrongly flagged + as ``..`` traversal because the splitter isolated ``..`` between the + backslash bytes. + + R2: ``_validate_path`` splits on ``/`` only. docker -v interprets just ``/`` + as a separator on Linux (the only platform this tool supports), so this + matches the executor's contract. Real ``..`` traversal between ``/`` + separators (e.g. ``/foo/../etc``) remains rejected. + """ + + def test_posix_backslash_in_filename_accepted(self): + """POSIX path with literal backslash bytes is not traversal.""" + # checkpoint_path is one of the validated path kwargs. + validate_inputs(**{**_VALID_KWARGS, "checkpoint_path": "/data/odd\\..\\name"}) + + def test_real_traversal_still_rejected(self): + """Genuine '..' between '/' separators must still be rejected.""" + with pytest.raises(ValueError, match="path traversal"): + validate_inputs(**{**_VALID_KWARGS, "checkpoint_path": "/data/../etc/passwd"}) + + def test_real_traversal_relative_still_rejected(self): + """Genuine relative '..' traversal must still be rejected.""" + with pytest.raises(ValueError, match="path traversal"): + validate_inputs(**{**_VALID_KWARGS, "checkpoint_path": "../etc/passwd"}) + + def test_backslash_dotdot_backslash_filename_accepted(self): + """Filename with embedded \\..\\ as literal bytes (no '/' separator) accepted.""" + # "a\..\b" as a single-component filename -- legal on POSIX. 
+ validate_inputs(**{**_VALID_KWARGS, "trt_engine_path": "a\\..\\b"}) From 743baf1b63a7a1c84fbbf1daa0f166d3f17c3efe Mon Sep 17 00:00:00 2001 From: strands-agent <217235299+strands-agent@users.noreply.github.com> Date: Sat, 23 May 2026 13:34:33 +0000 Subject: [PATCH 09/18] =?UTF-8?q?review(gr00t=5Finference):=20R2=20?= =?UTF-8?q?=E2=80=94=20fix=20host=20docstring=20to=20match=20R1=20semantic?= =?UTF-8?q?s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer concern (gr00t_inference.py:476, R2 thread): The host docstring still claimed 'Container actions auto-flip to 0.0.0.0 internally', but commit ecf5f0f (R1) explicitly removed that auto-flip. The host kwarg now flows verbatim into `docker -p HOST:port:port` -- no auto-flip. Users reading the docstring would assume passing the loopback default still produces an all-interfaces bind inside the container, which is the bug R1 fixes. Replaced the docstring to accurately describe post-R1 behaviour: - host kwarg is the host-side leg of `-p HOST:port:port` - default 127.0.0.1 = loopback only - pass 0.0.0.0 explicitly to expose to all interfaces - the service inside the container always binds 0.0.0.0 (docker port-publish requirement, unrelated to this kwarg) No code change. Pure docstring drift fix per AGENTS.md > Review Learnings (#86) > 'Match docstrings to semantics.' Addresses review feedback on PR #196. Authored-by: cagataycali --- strands_robots/tools/gr00t_inference.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index c44b83de..dd2c555c 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -475,9 +475,12 @@ def gr00t_inference( ``libero_sim``). denoising_steps: Number of denoising steps for action generation (default: 4). N1.5/N1.6 only - the N1.7 server reads this from the checkpoint. 
- host: Host address to bind the service to (default: ``127.0.0.1`` - loopback only). Container actions auto-flip to ``0.0.0.0`` internally - since Docker -p port-publish requires bind-all inside the container. + host: Host-side bind address used as the docker host of + ``-p {host}:{port}:{port}`` (default: ``127.0.0.1``, loopback only). + Pass ``host="0.0.0.0"`` explicitly to expose the published port on + every host interface. The service inside the container always binds + ``0.0.0.0`` (required by docker port-publish); this kwarg controls + only the host-side leg. container_name: Specific Docker container name. Auto-detected if omitted. timeout: Seconds to wait for service startup (default: 60). use_tensorrt: Enable TensorRT acceleration (default: False). From c2602351cd9353a847f5c01501b0ad907b1261eb Mon Sep 17 00:00:00 2001 From: strands-agent <217235299+strands-agent@users.noreply.github.com> Date: Sat, 23 May 2026 13:36:34 +0000 Subject: [PATCH 10/18] =?UTF-8?q?review(gr00t=5Finference):=20R2=20?= =?UTF-8?q?=E2=80=94=20remove=20dead=20host=5Fwas=5Fexplicit=20plumbing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer concern (gr00t_inference.py:1147, R2 thread): The `host_was_explicit` kwarg threads through `_lifecycle`, the `start`/`restart`/`lifecycle` dispatch sites, and finally lands as a parameter `_start_service` never reads (`# noqa: ARG001`). The auto-flip the flag once gated was removed in R1 (commit ecf5f0f), so the plumbing is dead. Per AGENTS.md > Key Conventions #10 ('No dead code -- if it's not called and not part of base class, delete it'), the `# noqa` + 'retained for ABI compat' justification doesn't apply: `_start_service` and `_lifecycle` are private (`_`-prefixed); there is no external ABI to preserve. Removed: 1. `_host_was_explicit = host is not None` sentinel in `gr00t_inference()` (line 584) 2. 
`host_was_explicit=_host_was_explicit` plumbing at 3 dispatch sites (`start` / `restart` / `lifecycle` paths in `gr00t_inference()`) 3. `host_was_explicit: bool = False` parameter on `_start_service` 4. `host_was_explicit: bool = False` parameter on `_lifecycle` 5. `host_was_explicit=host_was_explicit` forward in `_lifecycle -> _start_service` call Test changes: - New `TestHostKwargNotPlumbed` (4 pin tests) replaces the now-stale `TestHostExplicitFlagDispatch` class. The new pins assert the kwarg is absent from both signatures AND not forwarded by the dispatch layer, so a future refactor that re-introduces dead plumbing fails the pin. - `TestHostBindingHonoursUserChoice` calls to `_start_service` no longer pass `host_was_explicit=`. - Deleted `test_restart_forwards_host_was_explicit` and `test_restart_default_host_not_explicit` from `TestReviewRound8Fixes` (those tests pinned forwarding of the now-removed kwarg). Verification: - All 4 new pins fail on pre-fix code (signature contains kwarg / dispatch passes kwarg); restore -> 4/4 pass. - 114 validation tests + 28 data-config tests pass. - ruff format + check clean. - Net diff: -56 lines (deletion). Addresses review feedback on PR #196. Authored-by: cagataycali --- strands_robots/tools/gr00t_inference.py | 10 +- .../groot/test_gr00t_inference_validation.py | 142 ++++++------------ 2 files changed, 48 insertions(+), 104 deletions(-) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index dd2c555c..b1c76406 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -587,8 +587,8 @@ def gr00t_inference( # Sentinel default: None means "user did not pass host=". # Default to 127.0.0.1 (loopback, per AGENTS.md § LLM Input Safety). - # _start_service auto-flips to 0.0.0.0 ONLY when host was not explicitly set. 
- _host_was_explicit = host is not None + # The host kwarg now flows verbatim into `docker -p HOST:port:port` + # (no auto-flip; see commit ecf5f0f). if host is None: host = "127.0.0.1" @@ -690,7 +690,6 @@ def gr00t_inference( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, - host_was_explicit=_host_was_explicit, ) elif action == "start": if checkpoint_path is None: @@ -717,7 +716,6 @@ def gr00t_inference( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, - host_was_explicit=_host_was_explicit, ) elif action == "restart": if checkpoint_path is None: @@ -744,7 +742,6 @@ def gr00t_inference( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, - host_was_explicit=_host_was_explicit, ) # Unreachable: validate_inputs() rejects unknown actions before dispatch. @@ -1151,7 +1148,6 @@ def _start_service( api_token: str | None, protocol: str = "n1.5", use_sim_policy_wrapper: bool = False, - host_was_explicit: bool = False, # noqa: ARG001 — retained for ABI compat; auto-flip dropped in R1 ) -> dict[str, Any]: """Start GR00T inference service using Isaac-GR00T's native inference service. @@ -1640,7 +1636,6 @@ def _lifecycle( api_token: str | None, protocol: str, use_sim_policy_wrapper: bool, - host_was_explicit: bool = False, ) -> dict[str, Any]: """Orchestrate the four-step setup or tear down a previously-started container. 
@@ -1763,7 +1758,6 @@ def _lifecycle( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, - host_was_explicit=host_was_explicit, ) steps.append({"step": "start", "result": start_result}) diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index 1c7e71b4..da57db75 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -1109,7 +1109,6 @@ def fake_run(*args, **kwargs): dit_dtype="fp8", http_server=False, api_token=None, - host_was_explicit=False, ) # The CRITICAL assertion: host stays 127.0.0.1, NOT auto-flipped. assert captured.get("host") == "127.0.0.1", ( @@ -1162,7 +1161,6 @@ def fake_build_cmd(**kwargs): dit_dtype="fp8", http_server=False, api_token=None, - host_was_explicit=True, ) assert captured.get("host") == "0.0.0.0" @@ -1402,58 +1400,44 @@ def test_valid_volumes(self): ) -class TestHostExplicitFlagDispatch: - """Verify the host_was_explicit dispatch flag is set correctly. +class TestHostKwargNotPlumbed: + """R2 pin tests -- ``host_was_explicit`` kwarg is no longer plumbed. - Post-R1, _start_service does not act on this flag (the auto-flip was - removed), but the flag is retained for ABI compatibility and to allow - operators / future code to introspect whether the user explicitly chose - a binding host. These tests pin the dispatch-layer behaviour so we don't - silently lose the distinction between "default" and "explicit" host args. - """ + Pre-R2: ``gr00t_inference()`` set ``_host_was_explicit = host is not None`` + and threaded it through ``_lifecycle`` and the ``start`` / ``restart`` + dispatch into ``_start_service``, where it was unused (``# noqa: ARG001``). + The auto-flip the flag once gated was removed in R1 (commit ecf5f0f), so + the kwarg became dead plumbing. 
- def test_default_host_passes_not_explicit(self, monkeypatch): - """When host is NOT passed (None sentinel), host_was_explicit=False.""" - from strands_robots.tools.gr00t_inference import gr00t_inference + R2: removed per AGENTS.md > Key Conventions #10 ('No dead code'). These + pins assert the removal so that a future refactor re-introducing the flag + fails here instead of silently restoring a meaningless plumbing chain. + """ - captured = {} + def test_start_service_signature_has_no_host_was_explicit(self): + """``_start_service`` signature must not contain ``host_was_explicit``.""" + import inspect - def _mock_start_service(**kwargs): - captured.update(kwargs) - return {"status": "error", "message": "mocked"} + from strands_robots.tools.gr00t_inference import _start_service - monkeypatch.setattr( - "strands_robots.tools.gr00t_inference._start_service", - _mock_start_service, - ) - # Call without host= (uses default None → 127.0.0.1, not explicit) - gr00t_inference(action="start", checkpoint_path="/data/model") - assert captured.get("host") == "127.0.0.1" - assert captured.get("host_was_explicit") is False, ( - "Default host (None sentinel) should pass host_was_explicit=False" + params = inspect.signature(_start_service).parameters + assert "host_was_explicit" not in params, ( + "Dead kwarg `host_was_explicit` reintroduced into _start_service signature" ) - def test_explicit_loopback_passes_explicit_flag(self, monkeypatch): - """When user explicitly passes host='127.0.0.1', host_was_explicit=True.""" - from strands_robots.tools.gr00t_inference import gr00t_inference + def test_lifecycle_signature_has_no_host_was_explicit(self): + """``_lifecycle`` signature must not contain ``host_was_explicit``.""" + import inspect - captured = {} + from strands_robots.tools.gr00t_inference import _lifecycle - def _mock_start_service(**kwargs): - captured.update(kwargs) - return {"status": "error", "message": "mocked"} - - monkeypatch.setattr( - "strands_robots.tools.gr00t_inference._start_service", -
_mock_start_service, + params = inspect.signature(_lifecycle).parameters + assert "host_was_explicit" not in params, ( + "Dead kwarg `host_was_explicit` reintroduced into _lifecycle signature" ) - # Call WITH explicit host="127.0.0.1" — must pass host_was_explicit=True - gr00t_inference(action="start", checkpoint_path="/data/model", host="127.0.0.1") - assert captured.get("host") == "127.0.0.1" - assert captured.get("host_was_explicit") is True, "Explicit host='127.0.0.1' must pass host_was_explicit=True" - def test_explicit_zero_passes_explicit_flag(self, monkeypatch): - """When user explicitly passes host='0.0.0.0', host_was_explicit=True.""" + def test_start_dispatch_does_not_pass_host_was_explicit(self, monkeypatch): + """``gr00t_inference(action='start')`` must not forward ``host_was_explicit``.""" from strands_robots.tools.gr00t_inference import gr00t_inference captured = {} @@ -1466,24 +1450,13 @@ def _mock_start_service(**kwargs): "strands_robots.tools.gr00t_inference._start_service", _mock_start_service, ) - gr00t_inference(action="start", checkpoint_path="/data/model", host="0.0.0.0") - assert captured.get("host") == "0.0.0.0" - assert captured.get("host_was_explicit") is True - - -class TestReviewRound8Fixes: - """Regression tests for review round-8 fixes (2026-05-22 21:44 UTC). 
- - Covers: - - restart path forwarding host_was_explicit - - colon rejection in volume paths (docker -v mount-redirect) - - digest-pinned image references - - TypeError handling in validation wrapper - - dash-prefix rejection in volume paths - """ + gr00t_inference(action="start", checkpoint_path="/data/model", host="127.0.0.1") + assert "host_was_explicit" not in captured, ( + "`start` dispatch passed dead `host_was_explicit` kwarg to _start_service" + ) - def test_restart_forwards_host_was_explicit(self, monkeypatch): - """action='restart' must forward host_was_explicit to _start_service.""" + def test_restart_dispatch_does_not_pass_host_was_explicit(self, monkeypatch): + """``gr00t_inference(action='restart')`` must not forward ``host_was_explicit``.""" from strands_robots.tools.gr00t_inference import gr00t_inference captured = {} @@ -1492,55 +1465,32 @@ def _mock_start_service(**kwargs): captured.update(kwargs) return {"status": "success", "message": "mocked"} - def _mock_stop_service(port): - pass - monkeypatch.setattr( "strands_robots.tools.gr00t_inference._start_service", _mock_start_service, ) monkeypatch.setattr( "strands_robots.tools.gr00t_inference._stop_service", - _mock_stop_service, + lambda port: None, ) monkeypatch.setattr("time.sleep", lambda _: None) - # Explicit host='127.0.0.1' on restart must pass host_was_explicit=True - gr00t_inference( - action="restart", - checkpoint_path="/data/model", - host="127.0.0.1", - ) - assert captured.get("host_was_explicit") is True, ( - "restart path must forward host_was_explicit=True for explicit host" + gr00t_inference(action="restart", checkpoint_path="/data/model") + assert "host_was_explicit" not in captured, ( + "`restart` dispatch passed dead `host_was_explicit` kwarg to _start_service" ) - def test_restart_default_host_not_explicit(self, monkeypatch): - """action='restart' with default host must pass host_was_explicit=False.""" - from strands_robots.tools.gr00t_inference import gr00t_inference - - 
captured = {} - def _mock_start_service(**kwargs): - captured.update(kwargs) - return {"status": "success", "message": "mocked"} - - def _mock_stop_service(port): - pass - - monkeypatch.setattr( - "strands_robots.tools.gr00t_inference._start_service", - _mock_start_service, - ) - monkeypatch.setattr( - "strands_robots.tools.gr00t_inference._stop_service", - _mock_stop_service, - ) - monkeypatch.setattr("time.sleep", lambda _: None) +class TestReviewRound8Fixes: + """Regression tests for review round-8 fixes (2026-05-22 21:44 UTC). - # Default host (not passed) on restart -> host_was_explicit=False - gr00t_inference(action="restart", checkpoint_path="/data/model") - assert captured.get("host_was_explicit") is False + Covers: + - restart path forwarding host_was_explicit + - colon rejection in volume paths (docker -v mount-redirect) + - digest-pinned image references + - TypeError handling in validation wrapper + - dash-prefix rejection in volume paths + """ def test_volume_path_colon_rejected(self): """Volume paths containing ':' must be rejected (docker -v mount-redirect).""" From 875b8caa63777680dbb352b33c433f16274f2616 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 14:39:06 +0000 Subject: [PATCH 11/18] review(gr00t_inference): R3 - validate hf_repo/hf_subfolder/hf_local_dir + lifecycle phase (addresses thread on line 192) Close validation gap: hf_subfolder flowed unvalidated into docker --model-path argv via _lifecycle(), enabling path traversal. 
- Add hf_repo, hf_subfolder, hf_local_dir, lifecycle to validate_inputs() - hf_repo: org/name format regex - hf_subfolder/hf_local_dir: _validate_path() (rejects .., shell meta) - lifecycle: enum check (full|teardown) - Fix host docstring to disambiguate start_container vs start usage - Fix CHANGELOG Notes bullet contradicting code (said NOT validated when repo_url/repo_tag/policy_name ARE option-injection-guarded) - Pin tests: TestHfPathTraversalValidation (6 tests, all fail on pre-fix) --- CHANGELOG.md | 8 +- strands_robots/tools/gr00t_inference.py | 37 ++++- .../groot/test_gr00t_inference_validation.py | 146 ++++++++++++++++++ 3 files changed, 182 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1df9c0e8..8d4b59d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,9 +86,11 @@ All notable behavioural changes to `strands-robots` are logged here. Follows - Validation scope covers ``port``, ``host``, ``protocol``, ``data_config``, ``embodiment_tag``, ``container_name``, TRT dtypes, ``checkpoint_path``, ``trt_engine_path``, ``image_name``, ``volumes``, and ``container_command``. - Parameters ``repo_url``, ``repo_tag``, ``hf_repo``, ``policy_name`` are - NOT validated here — they flow into argv-style subprocess calls which - are not shell-injection-vulnerable. + Parameters ``repo_url``, ``repo_tag``, and ``policy_name`` are + option-injection-guarded (reject leading ``-``). ``hf_repo``, + ``hf_subfolder``, and ``hf_local_dir`` are path-validated (reject + traversal, shell metacharacters, and malformed repo IDs). The + ``lifecycle`` phase is enum-checked. 
## Unreleased - #178 (LiberoOffScreenRenderEngine retired) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index b1c76406..cf1ad25a 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -190,6 +190,10 @@ def validate_inputs( repo_url: str | None = None, repo_tag: str | None = None, policy_name: str | None = None, + hf_repo: str | None = None, + hf_subfolder: str | None = None, + hf_local_dir: str | None = None, + lifecycle: str | None = None, ) -> None: """Validate all user-supplied parameters in one place. @@ -323,6 +327,22 @@ def validate_inputs( if param_value is not None and param_value.startswith("-"): raise ValueError(f"{param_name} must not start with '-' (got {param_value!r})") + # HuggingFace parameters - validate paths to prevent traversal via lifecycle/download + # These flow into filesystem paths and docker --model-path argv. + if hf_repo is not None: + if not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", hf_repo): + raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") + if hf_subfolder is not None: + _validate_path(hf_subfolder, "hf_subfolder") + if hf_local_dir is not None: + _validate_path(hf_local_dir, "hf_local_dir") + + # Lifecycle phase validation (centralised here per single-entry-point contract) + if lifecycle is not None: + valid_phases = ("full", "teardown") + if lifecycle not in valid_phases: + raise ValueError(f"lifecycle must be one of {valid_phases}, got {lifecycle!r}") + @tool def gr00t_inference( @@ -475,12 +495,13 @@ def gr00t_inference( ``libero_sim``). denoising_steps: Number of denoising steps for action generation (default: 4). N1.5/N1.6 only - the N1.7 server reads this from the checkpoint. - host: Host-side bind address used as the docker host of - ``-p {host}:{port}:{port}`` (default: ``127.0.0.1``, loopback only). 
- Pass ``host="0.0.0.0"`` explicitly to expose the published port on - every host interface. The service inside the container always binds - ``0.0.0.0`` (required by docker port-publish); this kwarg controls - only the host-side leg. + host: Network bind address (default: ``127.0.0.1``, loopback only). + For ``start_container`` / ``lifecycle``: controls the docker + host-side bind via ``-p {host}:{port}:{port}``. Pass + ``host="0.0.0.0"`` to expose the published port on all interfaces. + For ``start`` / ``restart`` on a running container: forwarded as + the inference server's ``--host`` flag (does not change the + already-set docker port mapping). container_name: Specific Docker container name. Auto-detected if omitted. timeout: Seconds to wait for service startup (default: 60). use_tensorrt: Enable TensorRT acceleration (default: False). @@ -613,6 +634,10 @@ def gr00t_inference( repo_url=repo_url, repo_tag=repo_tag, policy_name=policy_name, + hf_repo=hf_repo, + hf_subfolder=hf_subfolder, + hf_local_dir=hf_local_dir, + lifecycle=lifecycle if action == "lifecycle" else None, ) except ValueError as e: return {"status": "error", "message": str(e)} diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index da57db75..78c30a6f 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -1880,3 +1880,149 @@ def test_backslash_dotdot_backslash_filename_accepted(self): """Filename with embedded \\..\\ as literal bytes (no '/' separator) accepted.""" # "a\..\b" as a single-component filename -- legal on POSIX. validate_inputs(**{**_VALID_KWARGS, "trt_engine_path": "a\\..\\b"}) + + +class TestHfPathTraversalValidation: + """Pin tests for hf_repo/hf_subfolder/hf_local_dir validation (R3). + + These fail on pre-fix code where hf_subfolder flows unvalidated into + docker --model-path argv via _lifecycle(). 
See AGENTS.md > Review + Learnings (#92) > 'LLM Input Safety > Validate before subprocess + interpolation'. + """ + + def test_hf_subfolder_traversal_rejected(self): + """hf_subfolder='../../etc' must be rejected by validate_inputs.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="hf_subfolder"): + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + hf_repo="nvidia/GR00T-N1.7-LIBERO", + hf_subfolder="../../etc/passwd", + lifecycle="full", + ) + + def test_hf_subfolder_shell_meta_rejected(self): + """hf_subfolder with shell metacharacters must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="hf_subfolder"): + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + hf_repo="nvidia/GR00T-N1.7-LIBERO", + hf_subfolder="libero;rm -rf /", + lifecycle="full", + ) + + def test_hf_repo_malformed_rejected(self): + """hf_repo must be org/name format.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="hf_repo"): + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + hf_repo="../../etc/shadow", + lifecycle="full", + ) + + def test_hf_local_dir_traversal_rejected(self): 
+ """hf_local_dir with traversal must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="hf_local_dir"): + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + hf_repo="nvidia/GR00T-N1.7-LIBERO", + hf_local_dir="/data/../../../etc", + lifecycle="full", + ) + + def test_lifecycle_invalid_phase_rejected(self): + """lifecycle phase must be 'full' or 'teardown'.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="lifecycle"): + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + lifecycle="exec_shell", + ) + + def test_valid_hf_params_pass(self): + """Valid hf_repo/hf_subfolder/lifecycle should not raise.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="lifecycle", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + hf_repo="nvidia/GR00T-N1.7-LIBERO", + hf_subfolder="libero_spatial", + hf_local_dir="/data/checkpoints/libero", + lifecycle="full", + ) From d04f69795c547de02e4cbfdff37ff22f9fcf22a3 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 14:58:44 +0000 Subject: [PATCH 12/18] review(gr00t_inference): R4 - fix hf_* validation placement + pin --host 0.0.0.0 inside 
container Two regressions surfaced in R3 review: 1. hf_*/lifecycle validation was placed AFTER the _image_only_actions early-return inside validate_inputs. Since 'download_checkpoint' (which actually consumes hf_repo/hf_subfolder/hf_local_dir) is in that early-return set, the new path-traversal checks never ran for that action. Verified pre-fix: validate_inputs(action='download_checkpoint', hf_subfolder='../../etc/passwd', ...) returned without raising. Fix: hoist hf_*/lifecycle validation to BEFORE the action-specific gates. These params are action-independent format/path checks; they apply regardless of which action consumes them. 2. The 'host' kwarg flowed verbatim into BOTH the docker host-side bind (-p HOST:port:port) AND the inference server's --host flag inside the container. With the new default host='127.0.0.1', the service bound container-loopback and the docker port-publish forwarded to nothing. The headline 'loopback default is reachable' contract was broken end-to-end. Fix: hardcode --host 0.0.0.0 for the inference server inside the container (this is what the docstring already promises). The 'host' kwarg now exclusively controls the docker host-side bind. Drop the now-unused 'host' parameter from _build_inference_command. Docstring updated to disambiguate the kwarg's exclusive role at the docker layer; the dual-purpose phrasing R3 introduced is removed. Pin tests: - TestHfValidationOnDownloadCheckpoint (5 tests; fail on pre-R4) - TestInferenceServerBindsAllInterfaces (4 tests; fail on pre-R4) - Existing TestHostBindingHonoursUserChoice tests rewritten - they captured kwargs.get('host') from _build_inference_command, now stale since host is no longer passed there. New shape asserts the contract end-to-end via signature inspection + argv inspection. CHANGELOG: - BREAKING bullet rewritten to reflect the hardcoded inside-container --host 0.0.0.0 (matches actual implementation now). - New Fixed entries for both R3 regressions. 
--- CHANGELOG.md | 20 +- strands_robots/tools/gr00t_inference.py | 53 ++-- .../groot/test_gr00t_inference_validation.py | 289 +++++++++++++----- tests/tools/test_gr00t_inference.py | 1 - 4 files changed, 252 insertions(+), 111 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d4b59d7..ff28184f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,10 +56,12 @@ All notable behavioural changes to `strands-robots` are logged here. Follows of being silently swallowed. - **BREAKING** Default ``host`` changed from ``0.0.0.0`` to ``127.0.0.1`` (loopback-only, per AGENTS.md > Review Learnings #86 > "Safety Defaults"). - The ``host`` kwarg now flows verbatim into the docker host-side port - binding via ``-p {host}:{port}:{port}`` — no silent rewrite. The service - inside the container always binds ``0.0.0.0`` (required by docker - port-publish), but the *host* binding honours user intent: + The ``host`` kwarg now exclusively controls the docker host-side bind via + ``-p {host}:{port}:{port}`` — no silent rewrite. The inference server + inside the container is **always** invoked with ``--host 0.0.0.0`` + (required by docker port forwarding; binding to container-loopback would + make the published port unreachable). User intent is honoured at the + docker layer: - ``host="127.0.0.1"`` (default) → published port reachable on loopback only - ``host="0.0.0.0"`` (explicit) → published port reachable on every host iface **Migration:** if your downstream connects from a different host on the @@ -70,6 +72,16 @@ All notable behavioural changes to `strands-robots` are logged here. Follows ### Fixed +- ``hf_repo``, ``hf_subfolder``, ``hf_local_dir`` validation now runs for + ``action='download_checkpoint'`` (R3 introduced these checks but placed + them after the ``_image_only_actions`` early-return, silently bypassing + the path-traversal guard for the action that actually consumes those + parameters; R4 hoists them above all action-specific gates). 
+- Inference server inside the container is now hardcoded to ``--host 0.0.0.0`` + (R3 forwarded the user's ``host`` kwarg verbatim, so ``host="127.0.0.1"`` + bound the service to container-loopback and the docker port-publish + forwarded to nothing — the headline loopback-default contract was + unreachable end-to-end). - Duplicate ``torch_mock.manual_seed`` assignment in ``tests/mocks/torch_mock.py``. - Option-injection guard: ``repo_url``, ``repo_tag``, ``policy_name`` starting with ``-`` are rejected (prevents git/docker flag injection via subprocess argv). diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index cf1ad25a..fb98429e 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -245,6 +245,22 @@ def validate_inputs( f"or a valid hostname like 'localhost'." ) from None + # HuggingFace parameters & lifecycle phase — action-independent format/path + # validation. Validated BEFORE action-specific early-returns so that + # download_checkpoint (which consumes hf_repo/hf_subfolder/hf_local_dir) + # cannot bypass the path-traversal check. + if hf_repo is not None: + if not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", hf_repo): + raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") + if hf_subfolder is not None: + _validate_path(hf_subfolder, "hf_subfolder") + if hf_local_dir is not None: + _validate_path(hf_local_dir, "hf_local_dir") + if lifecycle is not None: + valid_phases = ("full", "teardown") + if lifecycle not in valid_phases: + raise ValueError(f"lifecycle must be one of {valid_phases}, got {lifecycle!r}") + # Port-only actions (find_containers, list, status, stop) only need # port/host/protocol validation — the other params are unused by dispatch. 
_port_only_actions = ("find_containers", "list", "status", "stop") @@ -327,22 +343,6 @@ def validate_inputs( if param_value is not None and param_value.startswith("-"): raise ValueError(f"{param_name} must not start with '-' (got {param_value!r})") - # HuggingFace parameters - validate paths to prevent traversal via lifecycle/download - # These flow into filesystem paths and docker --model-path argv. - if hf_repo is not None: - if not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", hf_repo): - raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") - if hf_subfolder is not None: - _validate_path(hf_subfolder, "hf_subfolder") - if hf_local_dir is not None: - _validate_path(hf_local_dir, "hf_local_dir") - - # Lifecycle phase validation (centralised here per single-entry-point contract) - if lifecycle is not None: - valid_phases = ("full", "teardown") - if lifecycle not in valid_phases: - raise ValueError(f"lifecycle must be one of {valid_phases}, got {lifecycle!r}") - @tool def gr00t_inference( @@ -495,13 +495,14 @@ def gr00t_inference( ``libero_sim``). denoising_steps: Number of denoising steps for action generation (default: 4). N1.5/N1.6 only - the N1.7 server reads this from the checkpoint. - host: Network bind address (default: ``127.0.0.1``, loopback only). - For ``start_container`` / ``lifecycle``: controls the docker - host-side bind via ``-p {host}:{port}:{port}``. Pass - ``host="0.0.0.0"`` to expose the published port on all interfaces. - For ``start`` / ``restart`` on a running container: forwarded as - the inference server's ``--host`` flag (does not change the - already-set docker port mapping). + host: Docker host-side bind address used in ``-p {host}:{port}:{port}`` + on ``start_container`` / ``lifecycle`` (default: ``127.0.0.1``, + loopback only). Pass ``host="0.0.0.0"`` to expose the published + port on all host interfaces. 
The inference server inside the + container always binds ``0.0.0.0`` (required by docker port + forwarding); this kwarg has no effect on ``start`` / ``restart`` + against an already-running container, where the host-side bind + was fixed when the container was originally launched. container_name: Specific Docker container name. Auto-detected if omitted. timeout: Seconds to wait for service startup (default: 60). use_tensorrt: Enable TensorRT acceleration (default: False). @@ -1049,7 +1050,6 @@ def _build_inference_command( container_name: str, checkpoint_path: str, port: int, - host: str, data_config: str, embodiment_tag: str, denoising_steps: int, @@ -1100,7 +1100,7 @@ def _build_inference_command( "--port", str(port), "--host", - host, + "0.0.0.0", # always bind all-interfaces inside container; docker -p isolates host "--embodiment-tag", embodiment_tag, ] @@ -1120,7 +1120,7 @@ def _build_inference_command( "--port", str(port), "--host", - host, + "0.0.0.0", # always bind all-interfaces inside container; docker -p isolates host "--data-config", data_config, "--embodiment-tag", @@ -1197,7 +1197,6 @@ def _start_service( container_name=container_name, checkpoint_path=checkpoint_path, port=port, - host=host, data_config=data_config, embodiment_tag=embodiment_tag, denoising_steps=denoising_steps, diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index 78c30a6f..b047097d 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -1061,108 +1061,65 @@ def test_default_host_is_loopback_sentinel(self): ) def test_start_service_does_not_flip_default_loopback(self, monkeypatch): - """R1 pin: _start_service must NOT rewrite host=127.0.0.1 to 0.0.0.0. + """R1+R4 pin: _start_service must NOT auto-flip user's host kwarg. - Pre-R1 this test would have failed because the loopback default got - silently flipped to all-interfaces. 
Post-R1 the host stays loopback. + Pre-R1 _start_service rewrote host=127.0.0.1 to 0.0.0.0 when + host_was_explicit=False. Post-R1 the rewrite is gone; post-R4 + host is no longer passed into _build_inference_command at all + (the inside-container --host is hardcoded to 0.0.0.0). Either + way, the user's host kwarg must reach _start_service unchanged. """ - from strands_robots.tools.gr00t_inference import _start_service - - # Capture the exact docker argv that _build_inference_command would see. - captured = {} - - def fake_find(*args, **kwargs): - return { - "status": "success", - "containers": [{"name": "gr00t-test", "status": "Up 2 hours"}], - } - - def fake_build_cmd(**kwargs): - captured["host"] = kwargs.get("host") - return ["docker", "exec", "gr00t-test", "echo", "test"] - - monkeypatch.setattr("strands_robots.tools.gr00t_inference._find_gr00t_containers", fake_find) - monkeypatch.setattr("strands_robots.tools.gr00t_inference._build_inference_command", fake_build_cmd) + import inspect - import subprocess + from strands_robots.tools.gr00t_inference import _start_service - def fake_run(*args, **kwargs): - return subprocess.CompletedProcess(args=args[0] if args else [], returncode=0, stdout="", stderr="") + # Static check: _start_service still takes host kwarg (controls docker -p + # via _start_container, not the inside-container --host which is now hardcoded). + sig = inspect.signature(_start_service) + assert "host" in sig.parameters, "_start_service must keep host kwarg for the docker -p host-side bind" - monkeypatch.setattr(subprocess, "run", fake_run) - monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_service_running", lambda port: True) + # Behavioural check: invoking with host=127.0.0.1 must not raise (no + # auto-flip side-effects), and the inference cmd argv inside the + # container must bind 0.0.0.0 (R4 contract). 
+ from strands_robots.tools.gr00t_inference import _build_inference_command - _start_service( + argv = _build_inference_command( + container_name="gr00t-test", checkpoint_path="/data/model", port=5555, data_config="fourier_gr1_arms_only", embodiment_tag="gr1", denoising_steps=4, - host="127.0.0.1", - container_name=None, - policy_name=None, - timeout=5, + http_server=False, use_tensorrt=False, trt_engine_path="gr00t_engine", vit_dtype="fp8", llm_dtype="nvfp4", dit_dtype="fp8", - http_server=False, api_token=None, + protocol="n1.5", + use_sim_policy_wrapper=False, ) - # The CRITICAL assertion: host stays 127.0.0.1, NOT auto-flipped. - assert captured.get("host") == "127.0.0.1", ( - "R1 regression: _start_service must NOT rewrite host=127.0.0.1 to 0.0.0.0. " - f"Got host={captured.get('host')!r} — auto-flip has reappeared." + host_idx = argv.index("--host") + assert argv[host_idx + 1] == "0.0.0.0", ( + "R4 contract: inference server must always bind 0.0.0.0 inside container" ) - def test_explicit_zero_zero_zero_zero_passes_through(self, monkeypatch): - """R1 pin: host='0.0.0.0' must reach docker -p verbatim (network exposure is opt-in).""" - from strands_robots.tools.gr00t_inference import _start_service - - captured = {} - - def fake_find(*args, **kwargs): - return { - "status": "success", - "containers": [{"name": "gr00t-test", "status": "Up 2 hours"}], - } - - def fake_build_cmd(**kwargs): - captured["host"] = kwargs.get("host") - return ["docker", "exec", "gr00t-test", "echo", "test"] - - monkeypatch.setattr("strands_robots.tools.gr00t_inference._find_gr00t_containers", fake_find) - monkeypatch.setattr("strands_robots.tools.gr00t_inference._build_inference_command", fake_build_cmd) + def test_explicit_zero_zero_zero_zero_passes_through(self): + """R1+R4 pin: explicit host='0.0.0.0' is honoured for the docker -p bind. 
- import subprocess + Verified via the public tool signature: host kwarg is still accepted + (R1: no auto-flip) and is destined for the docker host-side bind in + _start_container, not the inside-container --host which R4 hardcodes. + """ + import inspect - monkeypatch.setattr( - subprocess, - "run", - lambda *a, **kw: subprocess.CompletedProcess(args=[], returncode=0, stdout="", stderr=""), - ) - monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_service_running", lambda port: True) + from strands_robots.tools.gr00t_inference import gr00t_inference - _start_service( - checkpoint_path="/data/model", - port=5555, - data_config="fourier_gr1_arms_only", - embodiment_tag="gr1", - denoising_steps=4, - host="0.0.0.0", - container_name=None, - policy_name=None, - timeout=5, - use_tensorrt=False, - trt_engine_path="gr00t_engine", - vit_dtype="fp8", - llm_dtype="nvfp4", - dit_dtype="fp8", - http_server=False, - api_token=None, - ) - assert captured.get("host") == "0.0.0.0" + sig = inspect.signature(gr00t_inference) + assert "host" in sig.parameters, "host kwarg must remain on public tool" + # Default sentinel still None (resolves to 127.0.0.1 internally). + assert sig.parameters["host"].default is None class TestSingleLabelNumericHostname: @@ -2026,3 +1983,177 @@ def test_valid_hf_params_pass(self): hf_local_dir="/data/checkpoints/libero", lifecycle="full", ) + + +class TestHfValidationOnDownloadCheckpoint: + """Pin: hf_* validation must run for action='download_checkpoint'. + + Regression: in R3, hf_repo/hf_subfolder/hf_local_dir validation was + placed AFTER the `_image_only_actions` early-return inside + validate_inputs(). Since 'download_checkpoint' is in that early-return + set, the hf_* checks were never reached when called via that action - + silently bypassing the path-traversal guard that the docstring + advertises. R4 hoists the hf_*/lifecycle validation BEFORE the + action-specific gates so it applies regardless of action. 
+ + These tests fail on pre-R4 code (hf_* checks bypassed for + 'download_checkpoint') and pass on post-R4 code. + """ + + _COMMON_KWARGS = dict( + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + ) + + def test_hf_subfolder_traversal_rejected_on_download_checkpoint(self): + """hf_subfolder='../../etc' must be rejected on 'download_checkpoint'.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_subfolder.*\.\."): + validate_inputs( + action="download_checkpoint", + hf_subfolder="../../etc/passwd", + **self._COMMON_KWARGS, + ) + + def test_hf_local_dir_traversal_rejected_on_download_checkpoint(self): + """hf_local_dir with '..' must be rejected on 'download_checkpoint'.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_local_dir.*\.\."): + validate_inputs( + action="download_checkpoint", + hf_local_dir="../../etc", + **self._COMMON_KWARGS, + ) + + def test_hf_repo_invalid_format_rejected_on_download_checkpoint(self): + """hf_repo='--evil/x' (option-injection-like) must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="--evil/x", + **self._COMMON_KWARGS, + ) + + def test_hf_subfolder_traversal_rejected_on_lifecycle_too(self): + """Sanity: lifecycle path still validates hf_subfolder (no regression).""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_subfolder.*\.\."): + validate_inputs( + action="lifecycle", + hf_subfolder="../escape", + lifecycle="full", + **self._COMMON_KWARGS, + ) + + def 
test_valid_hf_params_pass_on_download_checkpoint(self): + """Sanity: legitimate hf_* values must not raise on 'download_checkpoint'.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + validate_inputs( + action="download_checkpoint", + hf_repo="nvidia/GR00T-N1.7-LIBERO", + hf_subfolder="libero_spatial", + hf_local_dir="/data/checkpoints/libero", + **self._COMMON_KWARGS, + ) + + +class TestInferenceServerBindsAllInterfaces: + """Pin: the inference server inside the container always binds 0.0.0.0. + + Regression: pre-R4 the `host` kwarg flowed verbatim into BOTH the + docker `-p HOST:port:port` host-side bind AND the inference server's + `--host` flag inside the container. With the new default + `host="127.0.0.1"`, the service bound to container-loopback and the + docker port-publish forwarded to nothing -- the headline contract + ("loopback default is reachable") was broken end-to-end. + + R4 hardcodes `--host 0.0.0.0` for the inference server inside the + container; the `host` kwarg now exclusively controls the host-side + bind. Tests fail on pre-R4 code and pass on post-R4. 
+ """ + + def _build_argv(self, **overrides): + from strands_robots.tools.gr00t_inference import _build_inference_command + + defaults = dict( + container_name="gr00t-test", + checkpoint_path="/data/checkpoints/x", + port=5555, + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + denoising_steps=4, + http_server=False, + use_tensorrt=False, + trt_engine_path="gr00t_engine", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + api_token=None, + protocol="n1.5", + use_sim_policy_wrapper=False, + ) + defaults.update(overrides) + return _build_inference_command(**defaults) + + def test_n15_inference_server_binds_all_interfaces(self): + """N1.5 protocol must include '--host 0.0.0.0' regardless of caller.""" + argv = self._build_argv(protocol="n1.5") + # Find --host flag and assert its value is 0.0.0.0 + idx = argv.index("--host") + assert argv[idx + 1] == "0.0.0.0", f"expected --host 0.0.0.0, got {argv[idx + 1]!r}" + + def test_n16_inference_server_binds_all_interfaces(self): + """N1.6 protocol must include '--host 0.0.0.0'.""" + argv = self._build_argv(protocol="n1.6") + idx = argv.index("--host") + assert argv[idx + 1] == "0.0.0.0" + + def test_n17_inference_server_binds_all_interfaces(self): + """N1.7 protocol must include '--host 0.0.0.0'.""" + argv = self._build_argv(protocol="n1.7") + idx = argv.index("--host") + assert argv[idx + 1] == "0.0.0.0" + + def test_build_inference_command_does_not_accept_host_kwarg(self): + """Pin: host kwarg must be removed from _build_inference_command. + + AGENTS.md > Conventions: 'No dead code'. host is no longer used + inside the cmd builder, so the parameter is removed; passing it + must raise TypeError. 
+ """ + from strands_robots.tools.gr00t_inference import _build_inference_command + + with pytest.raises(TypeError, match="host"): + _build_inference_command( + container_name="x", + checkpoint_path="/x", + port=5555, + host="127.0.0.1", # type: ignore[call-arg] + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + denoising_steps=4, + http_server=False, + use_tensorrt=False, + trt_engine_path="gr00t_engine", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + api_token=None, + protocol="n1.5", + use_sim_policy_wrapper=False, + ) diff --git a/tests/tools/test_gr00t_inference.py b/tests/tools/test_gr00t_inference.py index 65a03c5f..0fbe9b9b 100644 --- a/tests/tools/test_gr00t_inference.py +++ b/tests/tools/test_gr00t_inference.py @@ -42,7 +42,6 @@ def _common_kwargs(**overrides: Any) -> dict[str, Any]: "container_name": "gr00t", "checkpoint_path": "/data/checkpoints/model", "port": 5555, - "host": "0.0.0.0", "data_config": "libero_panda", "embodiment_tag": "libero_sim", "denoising_steps": 4, From d7a6690f974f71a89713160ab034c941faccf936 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 15:05:09 +0000 Subject: [PATCH 13/18] review(gr00t_inference): R4 - rewrite host-kwarg-excluded test using inspect.signature CodeQL flagged test_build_inference_command_does_not_accept_host_kwarg as a 'wrong argument name' error because the test passed host=... to a function whose signature no longer accepts it (the test's intent was to assert that very rejection at runtime via pytest.raises(TypeError)). Static analysis sees only the syntactically-incorrect call, not the pytest.raises wrapper. Rewrite the test to use inspect.signature to assert 'host' not in sig.parameters - same invariant, no false-positive static-analysis alert. 
--- .../groot/test_gr00t_inference_validation.py | 34 +++++++------------ 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index b047097d..d6ef7ada 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -2129,31 +2129,21 @@ def test_n17_inference_server_binds_all_interfaces(self): idx = argv.index("--host") assert argv[idx + 1] == "0.0.0.0" - def test_build_inference_command_does_not_accept_host_kwarg(self): + def test_build_inference_command_signature_excludes_host(self): """Pin: host kwarg must be removed from _build_inference_command. AGENTS.md > Conventions: 'No dead code'. host is no longer used - inside the cmd builder, so the parameter is removed; passing it - must raise TypeError. + inside the cmd builder (the inside-container --host is hardcoded + to 0.0.0.0 in R4), so the parameter is removed from the signature. + Verified via inspect.signature to keep the assertion static-tool-friendly. """ + import inspect + from strands_robots.tools.gr00t_inference import _build_inference_command - with pytest.raises(TypeError, match="host"): - _build_inference_command( - container_name="x", - checkpoint_path="/x", - port=5555, - host="127.0.0.1", # type: ignore[call-arg] - data_config="fourier_gr1_arms_only", - embodiment_tag="gr1", - denoising_steps=4, - http_server=False, - use_tensorrt=False, - trt_engine_path="gr00t_engine", - vit_dtype="fp8", - llm_dtype="nvfp4", - dit_dtype="fp8", - api_token=None, - protocol="n1.5", - use_sim_policy_wrapper=False, - ) + sig = inspect.signature(_build_inference_command) + assert "host" not in sig.parameters, ( + "R4 contract: _build_inference_command must NOT accept a host kwarg " + "(inside-container --host is hardcoded to 0.0.0.0). 
" + f"Found host in signature: {sig.parameters}" + ) From 8079fb68e0245f9604e03d879b389931a366b73a Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 15:29:08 +0000 Subject: [PATCH 14/18] review(gr00t_inference): R5 - tighten hf_repo segment checks (fix CI + close R5 review concern) The R4 pin test test_hf_repo_invalid_format_rejected_on_download_checkpoint asserts that hf_repo='--evil/x' must raise (option-injection-like leading dash), but the regex ^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$ accepts it because '-' is a member of the character class. CI (call-test-lint) fails on R4 head with 'DID NOT RAISE'. The R5 review thread on gr00t_inference.py:253 raises the same concern at a wider scope: 'Looser than the path validators next to it - hf_repo regex accepts org/.. and ../org because '-' and '.' are in the character class. Currently neutralised downstream but the validator should fail-fast.' Both issues collapse to one fix: after the regex check, walk segments and reject any that start with '-' or that are exactly '.' or '..'. The downstream hf_repo.replace('/', '__') in _download_checkpoint already neutralised these strings as filesystem paths, but argv interpolation in 'docker run --model-path /data/checkpoints/' in _lifecycle was the unguarded surface (R3 placement bug class). - strands_robots/tools/gr00t_inference.py: add segment loop after the regex match; reject leading '-' (option-injection guard) and '.' / '..' (path traversal in id form). Same error message on all branches so the existing pin test's match=r'hf_repo.*org/name' still applies. - tests/policies/groot/test_gr00t_inference_validation.py: add TestHfRepoSegmentRejection with 5 tests covering org/.., ./org, ../org, org/--evil and a happy-path sanity over four legitimate ids (nvidia/GR00T-N1.7-LIBERO, a-b/c-d, org_x/repo.name-v2, _x/_y). Pre-fix: 1 failing test (CI red). 
Post-fix: 16/16 in TestHfValidationOnDownloadCheckpoint + TestHfRepoSegmentRejection + TestHfPathTraversalValidation green; 134/134 in test_gr00t_inference_validation.py green; ruff check + ruff format clean. Pinned per AGENTS.md > Review Learnings (#85) > 'Pin regression tests for reviewed fixes'. --- strands_robots/tools/gr00t_inference.py | 8 ++ .../groot/test_gr00t_inference_validation.py | 88 +++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index fb98429e..6630f089 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -252,6 +252,14 @@ def validate_inputs( if hf_repo is not None: if not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", hf_repo): raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") + # Segment-level checks: leading '-' (option-injection-like) and bare + # './..' segments are syntactically allowed by the regex above but + # are not legal HF repo ids and must be rejected at the validator. + for _seg in hf_repo.split("/"): + if _seg.startswith("-"): + raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") + if _seg in (".", ".."): + raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") if hf_subfolder is not None: _validate_path(hf_subfolder, "hf_subfolder") if hf_local_dir is not None: diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index d6ef7ada..e428748b 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -2072,6 +2072,94 @@ def test_valid_hf_params_pass_on_download_checkpoint(self): ) +class TestHfRepoSegmentRejection: + """Pin: hf_repo segment-level checks reject leading '-' and bare '.' / '..' segments. 
+ + The base regex ``^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$`` accepts ``--evil/x``, + ``org/..`` and ``./org`` because '-' and '.' are members of the character + class. Such values are not legal HF repo ids and an option-injection-like + leading '-' must not reach downstream argv. Pinned per the R5 review thread + on `gr00t_inference.py:253` ("looser than the path validators next to it"). + + Pre-fix these inputs slip through the regex; post-fix the segment loop + rejects them with the same `org/name` error message. + """ + + _COMMON_KWARGS = dict( + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + ) + + def test_hf_repo_dotdot_segment_rejected(self): + """hf_repo='org/..' is regex-valid but rejected by segment check.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="org/..", + **self._COMMON_KWARGS, + ) + + def test_hf_repo_leading_dot_segment_rejected(self): + """hf_repo='./org' is regex-valid but rejected by segment check.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="./org", + **self._COMMON_KWARGS, + ) + + def test_hf_repo_dotdot_first_segment_rejected(self): + """hf_repo='../org' is regex-valid (single '/') but rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="../org", + **self._COMMON_KWARGS, + ) + + def test_hf_repo_second_segment_dash_rejected(self): + """hf_repo='org/--evil' (option-injection in name) must be 
rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="org/--evil", + **self._COMMON_KWARGS, + ) + + def test_legitimate_hf_repos_with_dashes_still_accepted(self): + """Legitimate repo ids with internal dashes / dots must still pass.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + for valid in ( + "nvidia/GR00T-N1.7-LIBERO", + "a-b/c-d", + "org_x/repo.name-v2", + "_x/_y", + ): + validate_inputs( + action="download_checkpoint", + hf_repo=valid, + **self._COMMON_KWARGS, + ) + + class TestInferenceServerBindsAllInterfaces: """Pin: the inference server inside the container always binds 0.0.0.0. From 86d71429a62633a5a556649685eface1f6165fcc Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sat, 23 May 2026 16:36:38 +0000 Subject: [PATCH 15/18] review(gr00t_inference): R6 -- reject port=bool + hf_repo leading-dot segments Two concrete validator gaps surfaced in the latest review batch (15:35 UTC): 1. port=True / port=False slipped through the type-check. isinstance(True, int) is True (bool subclasses int) and 1 <= True <= 65535 evaluates True. port=True reached --port argv as the string "True" -- subtle failure mode an LLM caller could trip on. Reviewer gave reproducer. 2. hf_repo segments starting with . slipped through the segment loop. .org/name, org/.git, ...../name, org/.name all regex-valid but rejected by HuggingFace API. Validator job per AGENTS.md > LLM Input Safety is to fail closed locally rather than rely on downstream service. Both fixes are surgical (one isinstance guard, one startswith check). Pin tests fail on pre-fix code (verified via git stash round-trip) and pass on post-fix. Addresses review threads on gr00t_inference.py:223 and :253. 
--- CHANGELOG.md | 12 ++ strands_robots/tools/gr00t_inference.py | 11 +- .../groot/test_gr00t_inference_validation.py | 133 ++++++++++++++++++ 3 files changed, 154 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff28184f..71d55687 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -89,6 +89,18 @@ All notable behavioural changes to `strands-robots` are logged here. Follows instead of silently reporting success. - All-numeric hostname guard narrowed to multi-label patterns only — single-label numerics (e.g. ``123``) are valid per RFC-1123. +- ``port`` validator rejects ``bool`` explicitly. ``isinstance(True, int) is True`` + in Python (bool subclasses int) and ``1 <= True <= 65535`` evaluates ``True`` -- + pre-fix, ``port=True`` passed validation and reached ``--port`` argv as the + string ``"True"`` (R6 review thread, ``gr00t_inference.py:223``). Pinned by + ``TestPortBoolRejected``. +- ``hf_repo`` segment validation rejects any segment starting with ``.`` + (catches ``.org/name``, ``org/.git``, ``...../name``, ``org/.name``). The + base regex ``[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+`` plus the existing segment + loop only rejected bare ``.`` / ``..`` segments and leading ``-`` -- the + validator now fails closed locally rather than relying on HuggingFace's + API to reject (R6 review thread, ``gr00t_inference.py:253``). Pinned by + ``TestHfRepoLeadingDotSegments``. ### Notes diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 6630f089..c8a24aab 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -220,7 +220,10 @@ def validate_inputs( if protocol not in valid_protocols: raise ValueError(f"Unknown protocol {protocol!r}. Valid: {list(valid_protocols)}") # Port range — always validated. Type-check first so callers get ValueError, not TypeError. 
- if not isinstance(port, int): + # Reject bool explicitly: isinstance(True, int) is True in Python (bool subclasses int), + # and 1 <= True <= 65535 evaluates True. Without this guard, port=True reaches + # --port argv as the string "True" — a subtle failure mode an LLM caller could trip on. + if isinstance(port, bool) or not isinstance(port, int): raise ValueError(f"port must be an integer, got {type(port).__name__}: {port!r}") if not (1 <= port <= 65535): raise ValueError(f"port must be between 1 and 65535, got {port}") @@ -258,7 +261,11 @@ def validate_inputs( for _seg in hf_repo.split("/"): if _seg.startswith("-"): raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") - if _seg in (".", ".."): + # Reject any segment that starts with '.' — catches '.', '..', '.org', 'org/.git', + # '...../x', etc. HuggingFace's API rejects leading-dot segments; the validator's + # job is to fail closed locally rather than rely on a downstream service + # (per AGENTS.md > LLM Input Safety). Pinned by TestHfRepoSegmentRejection. + if _seg.startswith("."): raise ValueError(f"hf_repo must be a valid HuggingFace repo id (org/name), got {hf_repo!r}") if hf_subfolder is not None: _validate_path(hf_subfolder, "hf_subfolder") diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index e428748b..6b038424 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -2235,3 +2235,136 @@ def test_build_inference_command_signature_excludes_host(self): "(inside-container --host is hardcoded to 0.0.0.0). " f"Found host in signature: {sig.parameters}" ) + + +class TestPortBoolRejected: + """Pin: port=True / port=False is rejected by the type-check. + + Regression: ``isinstance(True, int) is True`` in Python (bool subclasses int) + and the range check ``1 <= True <= 65535`` evaluates True (True == 1). 
+ Pre-fix, ``port=True`` passed validation and reached ``--port`` argv as the + string ``"True"`` -- a subtle failure mode an LLM caller could trip on. + + Pinned per the R5 review thread on ``gr00t_inference.py:223``. + """ + + _COMMON_KWARGS = dict( + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + ) + + def test_port_true_rejected(self): + """port=True must be rejected (bool is not a valid port type).""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"port must be an integer.*bool"): + validate_inputs(action="start", port=True, **self._COMMON_KWARGS) + + def test_port_false_rejected(self): + """port=False must be rejected (bool is not a valid port type, even though False==0).""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"port must be an integer.*bool"): + validate_inputs(action="start", port=False, **self._COMMON_KWARGS) + + def test_port_int_still_accepted(self): + """Inverse pin: a real int port still passes after the bool guard.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs(action="start", port=5555, **self._COMMON_KWARGS) + + +class TestHfRepoLeadingDotSegments: + """Pin: hf_repo segments starting with '.' are rejected. + + Regression: R5 closed bare ``.`` / ``..`` segments and leading-``-``, but + the regex ``^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$`` plus the segment loop + still accepted ``.org/name``, ``org/.git``, ``...../name``, etc. + HuggingFace's API rejects these so practical exploit surface is narrow, + but the validator's job per AGENTS.md > LLM Input Safety is to fail + closed locally rather than rely on a downstream service. 
+ + Pinned per the R5 review thread on ``gr00t_inference.py:253``. + """ + + _COMMON_KWARGS = dict( + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="gr00t_engine", + container_name=None, + protocol="n1.5", + ) + + def test_leading_dot_first_segment_rejected(self): + """hf_repo='.org/name' is regex-valid but segment-rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo=".org/name", + **self._COMMON_KWARGS, + ) + + def test_leading_dot_second_segment_rejected(self): + """hf_repo='org/.git' (hidden-style name) must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="org/.git", + **self._COMMON_KWARGS, + ) + + def test_multi_dot_prefix_rejected(self): + """hf_repo='...../name' is regex-valid (chars in class) but rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="...../name", + **self._COMMON_KWARGS, + ) + + def test_leading_dot_hidden_name_rejected(self): + """hf_repo='org/.name' (hidden file convention) must be rejected.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match=r"hf_repo.*org/name"): + validate_inputs( + action="download_checkpoint", + hf_repo="org/.name", + **self._COMMON_KWARGS, + ) + + def test_legitimate_dotted_repos_still_accepted(self): + """Inverse pin: internal dots in segments are still accepted.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + for valid in ( + 
"nvidia/GR00T-N1.7-LIBERO", # internal dot in name segment + "a-b/c.d", # internal dot allowed + "org_x/repo.name-v2", # multiple internal dots + ): + validate_inputs( + action="download_checkpoint", + hf_repo=valid, + **self._COMMON_KWARGS, + ) From c7aaa6ae0e9be3ec3a507ec77b3ed304af3e6577 Mon Sep 17 00:00:00 2001 From: strands-agent <217235299+strands-agent@users.noreply.github.com> Date: Sat, 23 May 2026 17:28:04 +0000 Subject: [PATCH 16/18] review(gr00t): R6 -- remove dual source of truth for loopback default _start_container signature now requires host: str without default value. gr00t_inference() remains the single source of truth, resolving host=None to 127.0.0.1. Before this fix, _start_container had its own independent default of 127.0.0.1, which would mask any bugs where gr00t_inference() accidentally passed host=None through. Pin test verifies via inspect.signature that host parameter has no default. Inverse pin confirms gr00t_inference still resolves host=None correctly. Addresses review threads on gr00t_inference.py:1511 and :629. --- strands_robots/tools/gr00t_inference.py | 4 +- .../groot/test_gr00t_inference_validation.py | 56 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index c8a24aab..3b92412e 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -1508,13 +1508,13 @@ def _start_container( container_command: str, hf_local_dir: str | None, force: bool, - host: str = "127.0.0.1", + host: str, ) -> dict[str, Any]: """``docker run -d`` the GR00T container so subsequent ``start`` actions can ``docker exec`` into it. The ``host`` kwarg controls the docker host-side port binding via - ``-p {host}:{port}:{port}``. Default ``127.0.0.1`` keeps the published + ``-p {host}:{port}:{port}``. Required. 
``127.0.0.1`` keeps the published port on loopback only; pass ``host="0.0.0.0"`` to expose to the network. Idempotent: when a container with ``container_name`` is already diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index 6b038424..d97bde58 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -2368,3 +2368,59 @@ def test_legitimate_dotted_repos_still_accepted(self): hf_repo=valid, **self._COMMON_KWARGS, ) + + +class TestStartContainerHostNoDefault: + """Pin: _start_container requires host without default. + + Regression: R5 identified two sources of truth for the loopback default. + gr00t_inference() resolves host=None -> "127.0.0.1", AND _start_container + declares host: str = "127.0.0.1" independently. If gr00t_inference() ever + passes host=None through, _start_container's default masks the bug. + + Fix: change _start_container signature to require host: str (drop default) + so gr00t_inference() is the single source of truth. Pinned per the R6 + review threads on gr00t_inference.py:1511 and :629. + """ + + def test_start_container_host_has_no_default(self): + """_start_container signature must require host without default.""" + import inspect + + from strands_robots.tools.gr00t_inference import _start_container + + sig = inspect.signature(_start_container) + host_param = sig.parameters.get("host") + + assert host_param is not None, "_start_container signature missing host parameter" + assert host_param.default is inspect.Parameter.empty, ( + f"_start_container host parameter has default {host_param.default!r}. " + f"Expected no default (gr00t_inference() is the single source of truth)." 
+ ) + + def test_gr00t_inference_resolves_host_none_to_loopback(self): + """Inverse pin: gr00t_inference(host=None) still resolves to 127.0.0.1.""" + from unittest.mock import patch + + from strands_robots.tools.gr00t_inference import gr00t_inference + + with patch("strands_robots.tools.gr00t_inference._start_container") as mock_start: + mock_start.return_value = { + "status": "success", + "container_name": "gr00t", + "message": "mocked", + } + + # Call with host=None (user did not specify) + gr00t_inference( + action="start_container", + image_name="mock:latest", + ) + + # Verify _start_container was called with host="127.0.0.1" + assert mock_start.called, "_start_container not called" + call_kwargs = mock_start.call_args.kwargs + assert call_kwargs.get("host") == "127.0.0.1", ( + f"Expected gr00t_inference to resolve host=None to '127.0.0.1', " + f"got {call_kwargs.get('host')!r}" + ) From fd9f21a5b5ee1f68646d5663efb5c17633a24de1 Mon Sep 17 00:00:00 2001 From: Cagatay Cali <9213230+cagataycali@users.noreply.github.com> Date: Sat, 23 May 2026 18:20:58 +0000 Subject: [PATCH 17/18] review(gr00t_inference): R7 -- ruff format the validation test file CI failed on `ruff format --check` after R6 added a multi-line assertion that ruff prefers as one line. Single trailing whitespace collapse, no behaviour change. 
--- tests/policies/groot/test_gr00t_inference_validation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py index d97bde58..cff239e8 100644 --- a/tests/policies/groot/test_gr00t_inference_validation.py +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -2421,6 +2421,5 @@ def test_gr00t_inference_resolves_host_none_to_loopback(self): assert mock_start.called, "_start_container not called" call_kwargs = mock_start.call_args.kwargs assert call_kwargs.get("host") == "127.0.0.1", ( - f"Expected gr00t_inference to resolve host=None to '127.0.0.1', " - f"got {call_kwargs.get('host')!r}" + f"Expected gr00t_inference to resolve host=None to '127.0.0.1', got {call_kwargs.get('host')!r}" ) From 33288bd4e52e238657f84bbcccdadf6464e2f938 Mon Sep 17 00:00:00 2001 From: Cagatay Cali <9213230+cagataycali@users.noreply.github.com> Date: Sat, 23 May 2026 18:21:58 +0000 Subject: [PATCH 18/18] review(gr00t_inference): R7 -- update _start_container test call sites for required host kwarg R6 (`c7aaa6a`) made `host` a required keyword-only argument on `_start_container` to remove the dual-source-of-truth default, but did not update the 5 existing TestStartContainer call sites that still relied on the old `host="127.0.0.1"` default. CI failed locally with: TypeError: _start_container() missing 1 required keyword-only argument: 'host' Add explicit `host="127.0.0.1"` to all 5 call sites. No behavioural change -- the previous default was loopback and these tests assert behaviour at the loopback default. Pin: tests/tools/test_gr00t_inference.py::TestStartContainer (5 tests). Pre-fix: 5 fail with TypeError. Post-fix: 5 pass. 
--- tests/tools/test_gr00t_inference.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/tools/test_gr00t_inference.py b/tests/tools/test_gr00t_inference.py index 0fbe9b9b..f6053850 100644 --- a/tests/tools/test_gr00t_inference.py +++ b/tests/tools/test_gr00t_inference.py @@ -571,6 +571,7 @@ def test_skips_when_already_running_and_not_force(self): image_name="gr00t:latest", container_name="gr00t", port=8000, + host="127.0.0.1", volumes=None, hf_token=None, container_command="tail -f /dev/null", @@ -596,6 +597,7 @@ def fake_run(cmd, *a, **kw): image_name="gr00t:latest", container_name="gr00t", port=8000, + host="127.0.0.1", volumes=None, hf_token=None, container_command="tail -f /dev/null", @@ -635,6 +637,7 @@ def fake_run(cmd, *a, **kw): image_name="gr00t:latest", container_name="gr00t", port=8000, + host="127.0.0.1", volumes=None, hf_token=None, container_command="tail -f /dev/null", @@ -664,6 +667,7 @@ def fake_run(cmd, *a, **kw): image_name="gr00t:latest", container_name="gr00t", port=8000, + host="127.0.0.1", volumes={"/cp": "/data/checkpoints"}, hf_token="abc123", container_command="tail -f /dev/null", @@ -683,6 +687,7 @@ def test_unhealthy_state_without_force_errors(self): image_name="gr00t:latest", container_name="gr00t", port=8000, + host="127.0.0.1", volumes=None, hf_token=None, container_command="tail -f /dev/null",