diff --git a/.gitmodules b/.gitmodules
index 4a6dd61c1..cc618fe9b 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,7 @@
[submodule "lightx2v_ros/src/simulator/simulator/libero_node/LIBERO"]
path = lightx2v_ros/src/simulator/simulator/libero_node/LIBERO
url = https://github.com/Lifelong-Robot-Learning/LIBERO.git
+[submodule "lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin"]
+ path = lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin
+ url = https://github.com/robotwin-Platform/robotwin.git
+ branch = main
diff --git a/configs/fastwam/libero_i2va.json b/configs/fastwam/libero_i2va.json
index ac3405110..2c17b506f 100644
--- a/configs/fastwam/libero_i2va.json
+++ b/configs/fastwam/libero_i2va.json
@@ -10,6 +10,9 @@
"action_dim_hidden": 1024,
"action_dim": 7,
"robot_state_dim": 8,
+ "policy_profile": "libero",
+ "normalize_mode": "min-max",
"binarize_gripper": true,
+ "gripper_postprocess": true,
"default_prompt": "A video recorded from a robot's point of view executing the following instruction: {task_prompt}"
}
diff --git a/configs/fastwam/robotwin_i2va.json b/configs/fastwam/robotwin_i2va.json
new file mode 100644
index 000000000..b80c77a5c
--- /dev/null
+++ b/configs/fastwam/robotwin_i2va.json
@@ -0,0 +1,18 @@
+{
+ "adapter_model_path": "/data/nvme7/yongyang/fastwam_release/robotwin_uncond_3cam_384.pt",
+ "dataset_stats_path": "/data/nvme7/yongyang/fastwam_release/robotwin_uncond_3cam_384_dataset_stats.json",
+ "camera_size": 384,
+ "action_chunk_size": 32,
+ "actions_per_plan": 8,
+ "num_steps_wait": 0,
+ "action_infer_steps": 20,
+ "action_sample_shift": 5.0,
+ "action_dim_hidden": 1024,
+ "action_dim": 14,
+ "robot_state_dim": 14,
+ "policy_profile": "robotwin",
+ "normalize_mode": "z-score",
+ "binarize_gripper": false,
+ "gripper_postprocess": false,
+ "default_prompt": "A video recorded from a robot's point of view executing the following instruction: {task_prompt}"
+}
diff --git a/lightx2v/models/runners/wan/fastwam_runner.py b/lightx2v/models/runners/wan/fastwam_runner.py
index 0d89ddcc9..7c1ddbdbf 100644
--- a/lightx2v/models/runners/wan/fastwam_runner.py
+++ b/lightx2v/models/runners/wan/fastwam_runner.py
@@ -25,18 +25,47 @@ def resize_rgb(image, width, height):
return np.asarray(pil.resize((width, height), resample=Image.BILINEAR), dtype=np.uint8)
-class MinMaxNormalizer:
+def _canonical_norm_mode(mode):
+ compact = str(mode).strip().lower().replace("-", "").replace("_", "").replace("/", "")
+ if compact == "minmax":
+ return "min/max"
+ if compact == "zscore":
+ return "z-score"
+ if compact == "q01q99":
+ return "q01/q99"
+ raise ValueError(f"unsupported normalize_mode: {mode!r}")
+
+
+class LinearNormalizer:
+ """Affine normalizer matching FastWAM's SingleFieldLinearNormalizer (global stats).
+
+ Modes used by the released checkpoints:
+ - "min/max" (LIBERO): maps [global_min, global_max] -> [-1, 1]
+ - "q01/q99": maps [global_q01, global_q99] -> [-1, 1]
+ - "z-score" (RoboTwin): (x - global_mean) / global_std
+ """
+
+ std_reg = 1e-8
range_tol = 1e-4
-
- def __init__(self, stats):
- min_v = torch.as_tensor(stats["global_min"], dtype=torch.float32)
- max_v = torch.as_tensor(stats["global_max"], dtype=torch.float32)
- input_range = max_v - min_v
- ignore = input_range < self.range_tol
- input_range[ignore] = 2.0
- self.scale = 2.0 / input_range
- self.offset = -1.0 - self.scale * min_v
- self.offset[ignore] = -min_v[ignore]
+ output_min = -1.0
+ output_max = 1.0
+
+ def __init__(self, stats, mode="min/max"):
+ mode = _canonical_norm_mode(mode)
+ g = {key[len("global_") :]: torch.as_tensor(value, dtype=torch.float32) for key, value in stats.items() if key.startswith("global_")}
+ if mode == "z-score":
+ mean, std = g["mean"], g["std"]
+ self.scale = 1.0 / (std + self.std_reg)
+ self.offset = -mean / (std + self.std_reg)
+ else:
+ low, high = (g["min"].clone(), g["max"].clone()) if mode == "min/max" else (g["q01"].clone(), g["q99"].clone())
+ input_range = high - low
+ ignore = input_range < self.range_tol
+ input_range[ignore] = self.output_max - self.output_min
+ self.scale = (self.output_max - self.output_min) / input_range
+ self.offset = self.output_min - self.scale * low
+ self.offset[ignore] = (self.output_max + self.output_min) / 2 - low[ignore]
+ self.mode = mode
def forward(self, x):
x = torch.as_tensor(x, dtype=torch.float32)
@@ -47,6 +76,13 @@ def backward(self, x):
return (x - self.offset) / self.scale
+class MinMaxNormalizer(LinearNormalizer):
+ """Backwards-compatible alias: LIBERO used a global min/max normalizer."""
+
+ def __init__(self, stats):
+ super().__init__(stats, mode="min/max")
+
+
class FastWAMPolicy:
def __init__(
self,
@@ -65,6 +101,9 @@ def __init__(
binarize_gripper=True,
default_prompt=None,
camera_size=224,
+ policy_profile="libero",
+ normalize_mode="min/max",
+ gripper_postprocess=None,
config=None,
):
self.device = torch.device(device)
@@ -103,6 +142,10 @@ def __init__(
self.seed = None if seed is None or int(seed) < 0 else int(seed)
self.binarize_gripper = bool(binarize_gripper)
self.default_prompt = str(default_prompt)
+ self.policy_profile = str(policy_profile).strip().lower()
+ self.normalize_mode = _canonical_norm_mode(normalize_mode)
+ # LIBERO post-processes the single gripper channel; RoboTwin (dual-arm qpos) does not.
+ self.gripper_postprocess = (self.policy_profile == "libero") if gripper_postprocess is None else bool(gripper_postprocess)
self.config = config
self._check_model_config()
self.pending_actions = deque()
@@ -135,6 +178,9 @@ def from_config(cls, config):
binarize_gripper=config.get("binarize_gripper", True),
default_prompt=config.get("default_prompt"),
camera_size=config.get("camera_size", 224),
+ policy_profile=config.get("policy_profile", "libero"),
+ normalize_mode=config.get("normalize_mode", "min/max"),
+ gripper_postprocess=config.get("gripper_postprocess"),
config=config,
)
@@ -147,8 +193,8 @@ def _load_normalizers(self):
with open(self.dataset_stats_path, "r", encoding="utf-8") as f:
stats = json.load(f)
return (
- MinMaxNormalizer(stats["state"]["default"]),
- MinMaxNormalizer(stats["action"]["default"]),
+ LinearNormalizer(stats["state"]["default"], self.normalize_mode),
+ LinearNormalizer(stats["action"]["default"], self.normalize_mode),
)
def _load_text_encoder(self):
@@ -188,17 +234,17 @@ def _find_model_dir(self, dirname):
def reset(self):
self.pending_actions.clear()
- def next_action(self, agentview_rgb, wrist_rgb, state, task_description):
+ def next_action(self, images, state, task_description):
if not self.pending_actions:
- action_chunk = self.predict_action_chunk(agentview_rgb, wrist_rgb, state, task_description)
+ action_chunk = self.predict_action_chunk(images, state, task_description)
for action in action_chunk[: self.actions_per_plan]:
self.pending_actions.append(np.asarray(action, dtype=np.float32))
if not self.pending_actions:
raise RuntimeError("FastWAM produced an empty action chunk")
return self.pending_actions.popleft()
- def predict_action_chunk(self, agentview_rgb, wrist_rgb, state, task_description, seed=None):
- image = self.build_image_tensor(agentview_rgb, wrist_rgb)
+ def predict_action_chunk(self, images, state, task_description, seed=None):
+ image = self.build_image_tensor(images)
first_frame_latents = self.encode_image_latents(image)
context, context_mask = self.encode_prompt(self.default_prompt.format(task_prompt=task_description))
robot_state = self.state_normalizer.forward(np.asarray(state, dtype=np.float32))
@@ -216,10 +262,12 @@ def predict_action_chunk(self, agentview_rgb, wrist_rgb, state, task_description
seed=self.seed if seed is None else seed,
)
action = self.action_normalizer.backward(action).numpy()
- action[..., -1] = action[..., -1] * 2 - 1
- action[..., -1] = action[..., -1] * -1.0
- if self.binarize_gripper:
- action[..., -1] = np.sign(action[..., -1])
+ if self.gripper_postprocess:
+ # LIBERO single-gripper channel: map [0,1] -> [-1,1], flip sign, optional binarize.
+ action[..., -1] = action[..., -1] * 2 - 1
+ action[..., -1] = action[..., -1] * -1.0
+ if self.binarize_gripper:
+ action[..., -1] = np.sign(action[..., -1])
return action.astype(np.float32)
def _run_action_denoising(self, inputs, action_shape, action_infer_steps, seed):
@@ -239,13 +287,35 @@ def _run_action_denoising(self, inputs, action_shape, action_infer_steps, seed):
scheduler.clear()
return action
- def build_image_tensor(self, agentview_rgb, wrist_rgb):
- primary = resize_rgb(agentview_rgb, self.camera_size, self.camera_size)
- wrist = resize_rgb(wrist_rgb, self.camera_size, self.camera_size)
- rgb = np.concatenate([primary, wrist], axis=1)
+ def build_image_tensor(self, images):
+ """Compose the policy image from a dict of {camera_name: HxWx3 RGB}.
+
+ Layouts mirror the released FastWAM checkpoints:
+ - LIBERO : [agentview | wrist] side-by-side, each camera_size x camera_size.
+ - RoboTwin : head (320x256) on top, [left | right] (each 160x128) below,
+ i.e. final [384, 320, 3]; matches deploy_policy.py.
+ """
+ rgb = self._compose_rgb(images)
tensor = torch.from_numpy(rgb).permute(2, 0, 1).to(device=self.device, dtype=GET_DTYPE())
return tensor * (2.0 / 255.0) - 1.0
+ def _compose_rgb(self, images):
+ def _get(name):
+ if name not in images or images[name] is None:
+ raise KeyError(f"FastWAM profile '{self.policy_profile}' requires camera '{name}'")
+ return images[name]
+
+ if self.policy_profile == "robotwin":
+ head = resize_rgb(_get("head_camera"), 320, 256)
+ left = resize_rgb(_get("left_camera"), 160, 128)
+ right = resize_rgb(_get("right_camera"), 160, 128)
+ bottom = np.concatenate([left, right], axis=1)
+ return np.concatenate([head, bottom], axis=0)
+
+ primary = resize_rgb(_get("agentview"), self.camera_size, self.camera_size)
+ wrist = resize_rgb(_get("wrist"), self.camera_size, self.camera_size)
+ return np.concatenate([primary, wrist], axis=1)
+
def encode_image_latents(self, image):
image = image.unsqueeze(1)
latents = self.vae.encode(image.unsqueeze(0))
@@ -352,8 +422,7 @@ def run_pipeline(self, input_info):
agentview, wrist = self._load_image_pair()
state = self._load_state()
actions = self.policy.predict_action_chunk(
- agentview_rgb=agentview,
- wrist_rgb=wrist,
+ images={"agentview": agentview, "wrist": wrist},
state=state,
task_description=self.input_info.prompt,
seed=self.input_info.seed,
diff --git a/lightx2v_ros/src/common/common/__init__.py b/lightx2v_ros/src/common/common/__init__.py
new file mode 100644
index 000000000..5e568f47b
--- /dev/null
+++ b/lightx2v_ros/src/common/common/__init__.py
@@ -0,0 +1,15 @@
+from .contract import (
+ CONTRACTS,
+ LIBERO_CONTRACT,
+ ROBOTWIN_CONTRACT,
+ EnvContract,
+ get_contract,
+)
+
+__all__ = [
+ "CONTRACTS",
+ "EnvContract",
+ "LIBERO_CONTRACT",
+ "ROBOTWIN_CONTRACT",
+ "get_contract",
+]
diff --git a/lightx2v_ros/src/common/common/contract.py b/lightx2v_ros/src/common/common/contract.py
new file mode 100644
index 000000000..e6719eedd
--- /dev/null
+++ b/lightx2v_ros/src/common/common/contract.py
@@ -0,0 +1,115 @@
+"""Single source of truth for the simulator/inference/visualization contract.
+
+This module is intentionally dependency-free (pure Python, no ROS/torch imports)
+so it can be imported by every ROS node as well as plain scripts. It defines, per
+simulation environment, the ROS topic namespace, the set of cameras, the
+action/state dimensions and the inference profile. All three nodes derive their
+topic names and tensor dimensions from the same `EnvContract`, which removes the
+LIBERO-specific hard-coding that used to live in each node.
+"""
+
+from dataclasses import dataclass, field
+from typing import Dict, Tuple
+
+
+@dataclass(frozen=True)
+class EnvContract:
+ # Logical environment id, e.g. "libero" or "robotwin".
+ name: str
+ # ROS topic prefix, e.g. "/libero".
+ namespace: str
+ # All cameras the simulator publishes (logical names, also used by the viewer).
+ cameras: Tuple[str, ...]
+ # Subset of `cameras` fed to the policy, in the exact order the policy expects.
+ policy_input_cameras: Tuple[str, ...]
+ # Robot action / proprio-state vector dimensions exchanged over ROS.
+ action_dim: int
+ state_dim: int
+ # Square render/publish size hint (used by simulators that render on demand).
+ image_size: int
+ # Inference assembly profile understood by FastWAMPolicy ("libero"|"robotwin").
+ policy_profile: str
+ # Action/state normalization mode ("minmax"|"zscore").
+ normalize_mode: str
+ # Whether to apply the LIBERO single-gripper sign/binarize post-processing.
+ gripper_postprocess: bool
+ # Optional human-readable description.
+ description: str = field(default="")
+
+ # ----- derived topic helpers (kept in one place so nodes never disagree) -----
+ @property
+ def action_topic(self) -> str:
+ return f"{self.namespace}/action"
+
+ @property
+ def state_topic(self) -> str:
+ return f"{self.namespace}/state"
+
+ @property
+ def success_topic(self) -> str:
+ return f"{self.namespace}/success"
+
+ @property
+ def observation_ready_topic(self) -> str:
+ return f"{self.namespace}/observation_ready"
+
+ @property
+ def task_topic(self) -> str:
+ return f"{self.namespace}/task_description"
+
+ @property
+ def episode_topic(self) -> str:
+ # Monotonic episode counter; lets the policy reset its per-episode state when
+ # the simulator loops into a fresh episode.
+ return f"{self.namespace}/episode"
+
+ def camera_topic(self, camera: str) -> str:
+ if camera not in self.cameras:
+ raise KeyError(f"camera '{camera}' is not part of contract '{self.name}': {self.cameras}")
+ return f"{self.namespace}/{camera}/image_raw"
+
+ def camera_topics(self) -> Dict[str, str]:
+ return {camera: self.camera_topic(camera) for camera in self.cameras}
+
+
+LIBERO_CONTRACT = EnvContract(
+ name="libero",
+ namespace="/libero",
+ cameras=("agentview", "wrist", "frontview", "galleryview"),
+ policy_input_cameras=("agentview", "wrist"),
+ action_dim=7,
+ state_dim=8,
+ image_size=224,
+ policy_profile="libero",
+ normalize_mode="minmax",
+ gripper_postprocess=True,
+ description="LIBERO (robosuite/mujoco) single-arm tabletop manipulation.",
+)
+
+
+ROBOTWIN_CONTRACT = EnvContract(
+ name="robotwin",
+ namespace="/robotwin",
+ cameras=("head_camera", "left_camera", "right_camera"),
+ policy_input_cameras=("head_camera", "left_camera", "right_camera"),
+ action_dim=14,
+ state_dim=14,
+ image_size=384,
+ policy_profile="robotwin",
+ normalize_mode="zscore",
+ gripper_postprocess=False,
+ description="RoboTwin 2.0 (SAPIEN) dual-arm manipulation.",
+)
+
+
+CONTRACTS: Dict[str, EnvContract] = {
+ LIBERO_CONTRACT.name: LIBERO_CONTRACT,
+ ROBOTWIN_CONTRACT.name: ROBOTWIN_CONTRACT,
+}
+
+
+def get_contract(name: str) -> EnvContract:
+ key = str(name).strip().lower()
+ if key not in CONTRACTS:
+ raise KeyError(f"unknown environment '{name}'. Available: {sorted(CONTRACTS)}")
+ return CONTRACTS[key]
diff --git a/lightx2v_ros/src/common/package.xml b/lightx2v_ros/src/common/package.xml
new file mode 100644
index 000000000..8dcee47f0
--- /dev/null
+++ b/lightx2v_ros/src/common/package.xml
@@ -0,0 +1,15 @@
+
+
+
+ common
+ 0.0.1
+ Shared environment contract for LightX2V ROS nodes.
+ user
+ Apache-2.0
+
+ ament_python
+
+
+ ament_python
+
+
diff --git a/lightx2v_ros/src/common/resource/common b/lightx2v_ros/src/common/resource/common
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightx2v_ros/src/common/setup.cfg b/lightx2v_ros/src/common/setup.cfg
new file mode 100644
index 000000000..2fbdbffa4
--- /dev/null
+++ b/lightx2v_ros/src/common/setup.cfg
@@ -0,0 +1,4 @@
+[develop]
+script_dir=$base/lib/common
+[install]
+install_scripts=$base/lib/common
diff --git a/lightx2v_ros/src/common/setup.py b/lightx2v_ros/src/common/setup.py
new file mode 100644
index 000000000..ef2ca53f9
--- /dev/null
+++ b/lightx2v_ros/src/common/setup.py
@@ -0,0 +1,22 @@
+from setuptools import find_packages, setup
+
+package_name = "common"
+
+setup(
+ name=package_name,
+ version="0.0.1",
+ packages=find_packages(),
+ data_files=[
+ ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
+ ("share/" + package_name, ["package.xml"]),
+ ],
+ install_requires=["setuptools"],
+ zip_safe=True,
+ maintainer="user",
+ maintainer_email="user@example.com",
+ description="Shared environment contract for LightX2V ROS nodes.",
+ license="Apache-2.0",
+ entry_points={
+ "console_scripts": [],
+ },
+)
diff --git a/lightx2v_ros/src/inference/inference/fastwam_node/main.py b/lightx2v_ros/src/inference/inference/fastwam_node/main.py
index df5dad84f..6676a2db7 100644
--- a/lightx2v_ros/src/inference/inference/fastwam_node/main.py
+++ b/lightx2v_ros/src/inference/inference/fastwam_node/main.py
@@ -1,5 +1,6 @@
import numpy as np
import rclpy
+from common.contract import get_contract
from rclpy.node import Node
from sensor_msgs.msg import Image
from std_msgs.msg import Bool, Float32MultiArray, Int32, String
@@ -7,72 +8,48 @@
from lightx2v.models.runners.wan.fastwam_runner import FastWAMPolicy
from lightx2v.utils.set_config import auto_calc_config, get_default_config
-DEFAULT_DUMMY_ACTION = np.asarray([0, 0, 0, 0, 0, 0, -1], dtype=np.float32)
-ACTION_TOPIC = "/libero/action"
-AGENTVIEW_TOPIC = "/libero/agentview/image_raw"
-WRIST_TOPIC = "/libero/wrist/image_raw"
-STATE_TOPIC = "/libero/state"
-OBSERVATION_READY_TOPIC = "/libero/observation_ready"
-SUCCESS_TOPIC = "/libero/success"
-TASK_TOPIC = "/libero/task_description"
-
class FastWAMNode(Node):
def __init__(self):
super().__init__("fastwam_node")
+ self.declare_parameter("env", "libero")
self.declare_parameter("config_json", "")
self.declare_parameter("model_path", "")
+ self.declare_parameter("num_steps_wait", -1)
+
+ env = str(self.get_parameter("env").value).strip().lower()
+ self.contract = get_contract(env)
- self.get_logger().info("loading FastWAM policy")
+ self.get_logger().info(f"[{self.contract.name}] loading FastWAM policy")
self.policy_config = self.build_policy_config()
self.policy = FastWAMPolicy.from_config(self.policy_config)
- self.get_logger().info("FastWAM policy loaded")
+ self.get_logger().info(f"[{self.contract.name}] FastWAM policy loaded")
- self.agentview = None
- self.wrist = None
+ self.images = {cam: None for cam in self.contract.policy_input_cameras}
self.state = None
self.task_description = None
self.success = False
+ self.episode_index = 0
self.last_processed_observation = -1
- self.num_steps_wait = int(self.policy_config.get("num_steps_wait", 30))
-
- self.action_pub = self.create_publisher(Float32MultiArray, ACTION_TOPIC, 10)
- self.create_subscription(
- Image,
- AGENTVIEW_TOPIC,
- self.on_agentview,
- 10,
- )
- self.create_subscription(
- Image,
- WRIST_TOPIC,
- self.on_wrist,
- 10,
- )
- self.create_subscription(
- Float32MultiArray,
- STATE_TOPIC,
- self.on_state,
- 10,
- )
- self.create_subscription(
- String,
- TASK_TOPIC,
- self.on_task,
- 10,
- )
- self.create_subscription(
- Bool,
- SUCCESS_TOPIC,
- self.on_success,
- 10,
- )
- self.create_subscription(
- Int32,
- OBSERVATION_READY_TOPIC,
- self.on_observation_ready,
- 10,
+
+ ns_wait = int(self.get_parameter("num_steps_wait").value)
+ self.num_steps_wait = ns_wait if ns_wait >= 0 else int(self.policy_config.get("num_steps_wait", 0))
+ self.dummy_action = self._build_dummy_action()
+
+ self.action_pub = self.create_publisher(Float32MultiArray, self.contract.action_topic, 10)
+ self._camera_subs = []
+ for cam in self.contract.policy_input_cameras:
+ self._camera_subs.append(self.create_subscription(Image, self.contract.camera_topic(cam), self._make_image_cb(cam), 10))
+ self.create_subscription(Float32MultiArray, self.contract.state_topic, self.on_state, 10)
+ self.create_subscription(String, self.contract.task_topic, self.on_task, 10)
+ self.create_subscription(Bool, self.contract.success_topic, self.on_success, 10)
+ self.create_subscription(Int32, self.contract.episode_topic, self.on_episode, 10)
+ self.create_subscription(Int32, self.contract.observation_ready_topic, self.on_observation_ready, 10)
+
+ self.get_logger().info(
+ f"[{self.contract.name}] fastwam_node ready: input_cameras={list(self.contract.policy_input_cameras)} "
+ f"action_dim={self.contract.action_dim} state_dim={self.contract.state_dim} num_steps_wait={self.num_steps_wait}"
)
def build_policy_config(self):
@@ -91,18 +68,33 @@ def build_policy_config(self):
"config_json": config_json,
}
)
- return auto_calc_config(config)
+ config = auto_calc_config(config)
+
+ # The config_json is authoritative for policy params; warn loudly on any
+ # mismatch with the environment contract so dimension bugs surface early.
+ for key, expected in (("action_dim", self.contract.action_dim), ("robot_state_dim", self.contract.state_dim)):
+ actual = int(config.get(key, expected))
+ if actual != expected:
+ self.get_logger().warning(f"config `{key}`={actual} disagrees with env '{self.contract.name}' contract ({expected})")
+ return config
- def on_agentview(self, msg):
- self.agentview = image_msg_to_rgb(msg)
+ def _build_dummy_action(self):
+ action = np.zeros(self.contract.action_dim, dtype=np.float32)
+ if self.contract.gripper_postprocess:
+ # LIBERO warmup keeps the gripper open while the scene settles.
+ action[-1] = -1.0
+ return action
- def on_wrist(self, msg):
- self.wrist = image_msg_to_rgb(msg)
+ def _make_image_cb(self, camera):
+ def _cb(msg):
+ self.images[camera] = image_msg_to_rgb(msg)
+
+ return _cb
def on_state(self, msg):
state = np.asarray(msg.data, dtype=np.float32)
- if state.shape != (8,):
- self.get_logger().error(f"expected /libero/state length 8, got {state.size}")
+ if state.size != self.contract.state_dim:
+ self.get_logger().error(f"expected {self.contract.state_topic} length {self.contract.state_dim}, got {state.size}")
return
self.state = state
@@ -112,6 +104,18 @@ def on_task(self, msg):
def on_success(self, msg):
self.success = bool(msg.data)
+ def on_episode(self, msg):
+ episode = int(msg.data)
+ if episode == self.episode_index:
+ return
+ # Simulator looped into a fresh episode: drop the queued action chunk and any
+ # stale success/observation bookkeeping so we start clean on the new rollout.
+ self.episode_index = episode
+ self.success = False
+ self.last_processed_observation = -1
+ self.policy.reset()
+ self.get_logger().info(f"new episode {episode}; policy state reset for fresh rollout")
+
def on_observation_ready(self, msg):
observation_index = int(msg.data)
if observation_index <= self.last_processed_observation:
@@ -127,13 +131,12 @@ def on_observation_ready(self, msg):
return
if observation_index < self.num_steps_wait:
- action = DEFAULT_DUMMY_ACTION.copy()
+ action = self.dummy_action.copy()
self.get_logger().info(f"observation {observation_index}: publishing warmup dummy action")
else:
self.get_logger().info(f"observation {observation_index}: running FastWAM inference/action queue")
action = self.policy.next_action(
- agentview_rgb=self.agentview,
- wrist_rgb=self.wrist,
+ images={cam: self.images[cam] for cam in self.contract.policy_input_cameras},
state=self.state,
task_description=self.task_description,
)
@@ -142,11 +145,7 @@ def on_observation_ready(self, msg):
self.last_processed_observation = observation_index
def missing_inputs(self):
- missing = []
- if self.agentview is None:
- missing.append("agentview")
- if self.wrist is None:
- missing.append("wrist")
+ missing = [cam for cam in self.contract.policy_input_cameras if self.images.get(cam) is None]
if self.state is None:
missing.append("state")
if not self.task_description:
@@ -155,8 +154,8 @@ def missing_inputs(self):
def publish_action(self, action):
action = np.asarray(action, dtype=np.float32).reshape(-1)
- if action.shape != (7,):
- raise ValueError(f"expected action length 7, got {action.size}")
+ if action.size != self.contract.action_dim:
+ raise ValueError(f"expected action length {self.contract.action_dim}, got {action.size}")
msg = Float32MultiArray()
msg.data = action.tolist()
self.action_pub.publish(msg)
diff --git a/lightx2v_ros/src/inference/package.xml b/lightx2v_ros/src/inference/package.xml
index d83e8087e..f2fec9ccc 100644
--- a/lightx2v_ros/src/inference/package.xml
+++ b/lightx2v_ros/src/inference/package.xml
@@ -12,6 +12,7 @@
rclpy
sensor_msgs
std_msgs
+ common
ament_python
diff --git a/lightx2v_ros/src/simulator/package.xml b/lightx2v_ros/src/simulator/package.xml
index 4fbd4c1e9..db47cc1c8 100644
--- a/lightx2v_ros/src/simulator/package.xml
+++ b/lightx2v_ros/src/simulator/package.xml
@@ -12,6 +12,7 @@
rclpy
sensor_msgs
std_msgs
+ common
ament_python
diff --git a/lightx2v_ros/src/simulator/setup.py b/lightx2v_ros/src/simulator/setup.py
index f0609f1c6..52b721552 100644
--- a/lightx2v_ros/src/simulator/setup.py
+++ b/lightx2v_ros/src/simulator/setup.py
@@ -19,6 +19,7 @@
entry_points={
"console_scripts": [
"libero_node = simulator.libero_node.main:main",
+ "robotwin_node = simulator.robotwin_node.main:main",
],
},
)
diff --git a/lightx2v_ros/src/simulator/simulator/libero_node/env.py b/lightx2v_ros/src/simulator/simulator/libero_node/env.py
new file mode 100644
index 000000000..49f7a384c
--- /dev/null
+++ b/lightx2v_ros/src/simulator/simulator/libero_node/env.py
@@ -0,0 +1,95 @@
+"""LIBERO implementation of the generic `BaseSimEnv` contract."""
+
+import math
+
+import numpy as np
+from common.contract import EnvContract
+
+from ..sim.base_env import BaseSimEnv, Observation
+from .observer import LiberoActionObserver, default_libero_root
+
+
+def quat_to_axis_angle(quat):
+ quat = np.asarray(quat, dtype=np.float32).copy()
+ quat[3] = np.clip(quat[3], -1.0, 1.0)
+ den = np.sqrt(1.0 - quat[3] * quat[3])
+ if math.isclose(float(den), 0.0):
+ return np.zeros(3, dtype=np.float32)
+ return ((quat[:3] * 2.0 * math.acos(float(quat[3]))) / den).astype(np.float32)
+
+
+class LiberoEnv(BaseSimEnv):
+ # logical camera name -> LIBERO observation key
+ CAMERA_OBS_KEYS = {
+ "agentview": "agentview_image",
+ "wrist": "robot0_eye_in_hand_image",
+ "frontview": "frontview_image",
+ "galleryview": "galleryview_image",
+ }
+
+ def __init__(
+ self,
+ contract: EnvContract,
+ *,
+ benchmark="libero_spatial",
+ task_id=0,
+ init_state_id=0,
+ image_size=224,
+ seed=0,
+ libero_root=None,
+ ):
+ super().__init__(contract)
+ self.observer = LiberoActionObserver(
+ benchmark_name=benchmark,
+ task_id=int(task_id),
+ init_state_id=int(init_state_id),
+ image_size=int(image_size),
+ seed=int(seed),
+ libero_root=libero_root,
+ )
+
+ @property
+ def task_description(self) -> str:
+ return self.observer.task_description
+
+ def reset(self) -> Observation:
+ return self._observation()
+
+ def step(self, action):
+ _, _, success, _ = self.observer.step(action)
+ return self._observation(), bool(success)
+
+ def _observation(self) -> Observation:
+ obs = self.observer.obs
+ # LIBERO renders upside-down/mirrored relative to the policy expectation.
+ images = {cam: np.ascontiguousarray(obs[key][::-1, ::-1]) for cam, key in self.CAMERA_OBS_KEYS.items() if cam in self.contract.cameras}
+ return Observation(images=images, state=self._state(obs))
+
+ def _state(self, obs) -> np.ndarray:
+ pos = np.asarray(obs["robot0_eef_pos"], dtype=np.float32)
+ axis_angle = quat_to_axis_angle(np.asarray(obs["robot0_eef_quat"], dtype=np.float32))
+ gripper = np.asarray(obs["robot0_gripper_qpos"], dtype=np.float32)
+ return np.concatenate([pos, axis_angle, gripper]).astype(np.float32)
+
+ def close(self) -> None:
+ self.observer.close()
+
+
+def build_libero_env(node) -> LiberoEnv:
+ contract = node.contract
+ node.declare_parameter("libero_root", str(default_libero_root()))
+ node.declare_parameter("benchmark", "libero_spatial")
+ node.declare_parameter("task_id", 0)
+ node.declare_parameter("init_state_id", 0)
+ node.declare_parameter("image_size", contract.image_size)
+ node.declare_parameter("seed", 0)
+
+ return LiberoEnv(
+ contract,
+ benchmark=node.get_parameter("benchmark").value,
+ task_id=int(node.get_parameter("task_id").value),
+ init_state_id=int(node.get_parameter("init_state_id").value),
+ image_size=int(node.get_parameter("image_size").value),
+ seed=int(node.get_parameter("seed").value),
+ libero_root=node.get_parameter("libero_root").value,
+ )
diff --git a/lightx2v_ros/src/simulator/simulator/libero_node/main.py b/lightx2v_ros/src/simulator/simulator/libero_node/main.py
index 175387a51..f265a4171 100644
--- a/lightx2v_ros/src/simulator/simulator/libero_node/main.py
+++ b/lightx2v_ros/src/simulator/simulator/libero_node/main.py
@@ -1,162 +1,12 @@
-import math
+from common.contract import get_contract
-import numpy as np
-import rclpy
-from rclpy.node import Node
-from sensor_msgs.msg import Image
-from std_msgs.msg import Bool, Float32MultiArray, Int32, String
-
-from .observer import LiberoActionObserver, default_libero_root
-
-
-class LiberoNode(Node):
- def __init__(self):
- super().__init__("libero_node")
-
- self.declare_parameter("libero_root", str(default_libero_root()))
- self.declare_parameter("benchmark", "libero_spatial")
- self.declare_parameter("task_id", 0)
- self.declare_parameter("init_state_id", 0)
- self.declare_parameter("image_size", 224)
- self.declare_parameter("seed", 0)
- self.declare_parameter("action_topic", "/libero/action")
- self.declare_parameter("state_topic", "/libero/state")
- self.declare_parameter("agentview_topic", "/libero/agentview/image_raw")
- self.declare_parameter("wrist_topic", "/libero/wrist/image_raw")
- self.declare_parameter("frontview_topic", "/libero/frontview/image_raw")
- self.declare_parameter("galleryview_topic", "/libero/galleryview/image_raw")
- self.declare_parameter("success_topic", "/libero/success")
- self.declare_parameter("observation_ready_topic", "/libero/observation_ready")
- self.declare_parameter("task_topic", "/libero/task_description")
-
- self.state_pub = self.create_publisher(Float32MultiArray, self.get_parameter("state_topic").value, 10)
- self.agentview_pub = self.create_publisher(Image, self.get_parameter("agentview_topic").value, 10)
- self.wrist_pub = self.create_publisher(Image, self.get_parameter("wrist_topic").value, 10)
- self.frontview_pub = self.create_publisher(Image, self.get_parameter("frontview_topic").value, 10)
- self.galleryview_pub = self.create_publisher(Image, self.get_parameter("galleryview_topic").value, 10)
- self.success_pub = self.create_publisher(Bool, self.get_parameter("success_topic").value, 10)
- self.observation_ready_pub = self.create_publisher(Int32, self.get_parameter("observation_ready_topic").value, 10)
- self.task_pub = self.create_publisher(String, self.get_parameter("task_topic").value, 10)
- self.action_sub = self.create_subscription(
- Float32MultiArray,
- self.get_parameter("action_topic").value,
- self.on_action,
- 10,
- )
-
- self.observer = LiberoActionObserver(
- benchmark_name=self.get_parameter("benchmark").value,
- task_id=int(self.get_parameter("task_id").value),
- init_state_id=int(self.get_parameter("init_state_id").value),
- image_size=int(self.get_parameter("image_size").value),
- seed=int(self.get_parameter("seed").value),
- libero_root=self.get_parameter("libero_root").value,
- )
- self.step_index = 0
- self.success = False
- self.observation_timer = self.create_timer(1.0, self.republish_observation)
-
- self.get_logger().info(f"listening for actions on {self.get_parameter('action_topic').value}")
- self.publish_observation()
-
- def republish_observation(self):
- if self.success:
- self.observation_timer.cancel()
- return
- self.publish_observation()
-
- def on_action(self, msg):
- if self.success:
- self.get_logger().warning("episode already succeeded; ignoring action")
- return
-
- action = np.asarray(msg.data, dtype=np.float32)
- if action.shape != (7,):
- self.get_logger().error(f"expected action length 7, got {action.size}")
- return
-
- self.get_logger().info(f"received action: {action.tolist()}")
- _, _, success, _ = self.observer.step(action)
- self.step_index += 1
- self.success = bool(success)
- self.publish_observation()
-
- if self.success:
- self.get_logger().info(f"episode succeeded at step {self.step_index}")
-
- def publish_observation(self):
- obs = self.observer.obs
- stamp = self.get_clock().now().to_msg()
-
- self.state_pub.publish(self.make_state_msg(obs))
- self.agentview_pub.publish(self.make_image_msg(obs["agentview_image"][::-1, ::-1], stamp, "agentview"))
- self.wrist_pub.publish(self.make_image_msg(obs["robot0_eye_in_hand_image"][::-1, ::-1], stamp, "wrist"))
- self.frontview_pub.publish(self.make_image_msg(obs["frontview_image"][::-1, ::-1], stamp, "frontview"))
- self.galleryview_pub.publish(self.make_image_msg(obs["galleryview_image"][::-1, ::-1], stamp, "galleryview"))
-
- task_msg = String()
- task_msg.data = self.observer.task_description
- self.task_pub.publish(task_msg)
-
- success_msg = Bool()
- success_msg.data = self.success
- self.success_pub.publish(success_msg)
-
- observation_ready_msg = Int32()
- observation_ready_msg.data = self.step_index
- self.observation_ready_pub.publish(observation_ready_msg)
-
- def make_state_msg(self, obs):
- pos = np.asarray(obs["robot0_eef_pos"], dtype=np.float32)
- axis_angle = quat_to_axis_angle(np.asarray(obs["robot0_eef_quat"], dtype=np.float32))
- gripper = np.asarray(obs["robot0_gripper_qpos"], dtype=np.float32)
-
- msg = Float32MultiArray()
- msg.data = np.concatenate([pos, axis_angle, gripper]).astype(np.float32).tolist()
- return msg
-
- def make_image_msg(self, image, stamp, frame_id):
- image = np.ascontiguousarray(image)
- msg = Image()
- msg.header.stamp = stamp
- msg.header.frame_id = frame_id
- msg.height = int(image.shape[0])
- msg.width = int(image.shape[1])
- msg.encoding = "rgb8"
- msg.is_bigendian = False
- msg.step = int(image.strides[0])
- msg.data = image.tobytes()
- return msg
-
- def destroy_node(self):
- if hasattr(self, "observer"):
- self.observer.close()
- super().destroy_node()
-
-
-def quat_to_axis_angle(quat):
- quat = np.asarray(quat, dtype=np.float32).copy()
- quat[3] = np.clip(quat[3], -1.0, 1.0)
- den = np.sqrt(1.0 - quat[3] * quat[3])
- if math.isclose(float(den), 0.0):
- return np.zeros(3, dtype=np.float32)
- return ((quat[:3] * 2.0 * math.acos(float(quat[3]))) / den).astype(np.float32)
+from ..sim.node import run_simulator_node
+from .env import build_libero_env
def main(args=None):
- rclpy.init(args=args)
- node = LiberoNode()
- try:
- rclpy.spin(node)
- except KeyboardInterrupt:
- pass
- except Exception:
- if rclpy.ok():
- raise
- finally:
- node.destroy_node()
- if rclpy.ok():
- rclpy.shutdown()
+ contract = get_contract("libero")
+ run_simulator_node(contract, build_libero_env, node_name="libero_node", args=args)
if __name__ == "__main__":
diff --git a/lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin b/lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin
new file mode 160000
index 000000000..bf44be51c
--- /dev/null
+++ b/lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin
@@ -0,0 +1 @@
+Subproject commit bf44be51cf5717a5595ce59447f2cf5263d2aa95
diff --git a/lightx2v_ros/src/simulator/simulator/robotwin_node/__init__.py b/lightx2v_ros/src/simulator/simulator/robotwin_node/__init__.py
new file mode 100644
index 000000000..06858927d
--- /dev/null
+++ b/lightx2v_ros/src/simulator/simulator/robotwin_node/__init__.py
@@ -0,0 +1 @@
+# robotwin_node package
diff --git a/lightx2v_ros/src/simulator/simulator/robotwin_node/env.py b/lightx2v_ros/src/simulator/simulator/robotwin_node/env.py
new file mode 100644
index 000000000..d6f46e9df
--- /dev/null
+++ b/lightx2v_ros/src/simulator/simulator/robotwin_node/env.py
@@ -0,0 +1,247 @@
+"""RoboTwin (SAPIEN, dual-arm) implementation of the generic `BaseSimEnv`.
+
+This adapter wraps a vendored RoboTwin task so the generic `SimulatorNode` can
+drive it exactly like LIBERO. RoboTwin's own evaluation orchestration lives in
+``third_party/RoboTwin/script/eval_policy.py``; that script is RoboTwin-driven
+(it owns the rollout loop and calls a policy plugin). Here we invert the control
+flow for ROS: we build the same task `args`, run ``setup_demo`` once, and then
+expose ``reset`` / ``step`` (``get_obs`` + ``take_action(qpos)``), publishing the
+three RoboTwin cameras and the 14-dim joint-state vector over ROS.
+
+RoboTwin heavy dependencies (sapien, mplib, curobo, ...) and assets are imported
+lazily inside ``reset``/construction, so the ROS package builds and imports even
+on machines where the RoboTwin runtime is not installed yet.
+"""
+
+import importlib
+import os
+import sys
+from pathlib import Path
+
+import numpy as np
+from common.contract import EnvContract
+
+from ..sim.base_env import BaseSimEnv, Observation
+
+
+def default_robotwin_root() -> Path:
+ return Path(__file__).resolve().parent / "RoboTwin"
+
+
+def _add_python_path(path) -> None:
+ path = str(Path(path))
+ if path not in sys.path:
+ sys.path.insert(0, path)
+
+
+class RoboTwinEnv(BaseSimEnv):
+ """Single-episode RoboTwin environment exposed through the BaseSimEnv contract."""
+
+ def __init__(
+ self,
+ contract: EnvContract,
+ *,
+ task_name: str = "click_alarmclock",
+ task_config: str = "demo_clean",
+ embodiment: str = "aloha-agilex",
+ instruction_type: str = "unseen",
+ instruction: str = "",
+ seed: int = 0,
+ robotwin_root=None,
+ ):
+ super().__init__(contract)
+ self.robotwin_root = Path(robotwin_root or default_robotwin_root()).expanduser()
+ self.task_name = str(task_name)
+ self.task_config = str(task_config)
+ self.embodiment = str(embodiment).strip()
+ self.instruction_type = str(instruction_type)
+ self._fixed_instruction = str(instruction).strip()
+ self.seed = int(seed)
+ self._episode_index = 0
+
+ self._task_description = ""
+ self._configs_path = self.robotwin_root / "task_config"
+
+ self._prepare_runtime()
+ self.args = self._build_task_args()
+ self.env = self._instantiate_task()
+ self._setup_episode()
+
+ # ------------------------------------------------------------------ setup
+ def _prepare_runtime(self) -> None:
+ root = self.robotwin_root
+ if not (root / "envs").is_dir():
+ raise FileNotFoundError(f"RoboTwin is not vendored at {root}. See robotwin_node/RoboTwin/README and run the RoboTwin install/asset-download steps.")
+ # RoboTwin source uses root-relative imports such as `from envs import ...`
+ # and `from generate_episode_instructions import *`.
+ _add_python_path(root)
+ _add_python_path(root / "description" / "utils")
+
+ def _require_config(self, *parts) -> Path:
+ path = self._configs_path.joinpath(*parts)
+ if not path.exists():
+ raise FileNotFoundError(f"Missing RoboTwin config: {path}. Populate `task_config/` (and `assets/`) from the official RoboTwin repo (see robotwin_node/RoboTwin/script).")
+ return path
+
+ def _build_task_args(self) -> dict:
+ """Replicates third_party/RoboTwin/script/eval_policy.py:main() arg assembly."""
+ import yaml
+
+ with open(self._require_config(f"{self.task_config}.yml"), "r", encoding="utf-8") as f:
+ args = yaml.load(f.read(), Loader=yaml.FullLoader)
+
+ args["task_name"] = self.task_name
+ args["task_config"] = self.task_config
+
+ # Allow the launch parameter to pin the embodiment (e.g. "aloha-agilex").
+ if self.embodiment:
+ args["embodiment"] = [self.embodiment]
+ embodiment_type = args.get("embodiment")
+ if not isinstance(embodiment_type, list):
+ raise ValueError(f"task_config embodiment must be a list, got {embodiment_type!r}")
+
+ with open(self._require_config("_embodiment_config.yml"), "r", encoding="utf-8") as f:
+ embodiment_types = yaml.load(f.read(), Loader=yaml.FullLoader)
+
+ def embodiment_file(key):
+ robot_file = embodiment_types[key]["file_path"]
+ if robot_file is None:
+ raise ValueError(f"No embodiment file for '{key}'")
+ return os.path.join(str(self.robotwin_root), robot_file) if not os.path.isabs(robot_file) else robot_file
+
+ def embodiment_config(robot_file):
+ with open(os.path.join(robot_file, "config.yml"), "r", encoding="utf-8") as f:
+ return yaml.load(f.read(), Loader=yaml.FullLoader)
+
+ with open(self._require_config("_camera_config.yml"), "r", encoding="utf-8") as f:
+ camera_config = yaml.load(f.read(), Loader=yaml.FullLoader)
+
+ head_camera_type = args["camera"]["head_camera_type"]
+ args["head_camera_h"] = camera_config[head_camera_type]["h"]
+ args["head_camera_w"] = camera_config[head_camera_type]["w"]
+
+ if len(embodiment_type) == 1:
+ args["left_robot_file"] = embodiment_file(embodiment_type[0])
+ args["right_robot_file"] = embodiment_file(embodiment_type[0])
+ args["dual_arm_embodied"] = True
+ elif len(embodiment_type) == 3:
+ args["left_robot_file"] = embodiment_file(embodiment_type[0])
+ args["right_robot_file"] = embodiment_file(embodiment_type[1])
+ args["embodiment_dis"] = embodiment_type[2]
+ args["dual_arm_embodied"] = False
+ else:
+ raise ValueError("embodiment items should be 1 or 3")
+
+ args["left_embodiment_config"] = embodiment_config(args["left_robot_file"])
+ args["right_embodiment_config"] = embodiment_config(args["right_robot_file"])
+ args["eval_mode"] = True
+ # Headless: never spawn the on-screen SAPIEN viewer.
+ args["render_freq"] = 0
+ return args
+
+ def _instantiate_task(self):
+ module = importlib.import_module(f"envs.{self.task_name}")
+ task_cls = getattr(module, self.task_name)
+ return task_cls()
+
+ def _setup_episode(self) -> None:
+ self.env.setup_demo(now_ep_num=self._episode_index, seed=self.seed, is_test=True, **self.args)
+ instruction = self._resolve_instruction()
+ self.env.set_instruction(instruction=instruction)
+ self._task_description = instruction
+
+ def _resolve_instruction(self) -> str:
+ if self._fixed_instruction:
+ return self._fixed_instruction
+ try:
+ from generate_episode_instructions import generate_episode_descriptions
+
+ episode_info_list = [self.env.info["info"]]
+ results = generate_episode_descriptions(self.task_name, episode_info_list, 1)
+ return str(np.random.choice(results[0][self.instruction_type]))
+ except Exception:
+ return f"Complete the {self.task_name.replace('_', ' ')} task."
+
+ # ------------------------------------------------------------- contract API
+ @property
+ def task_description(self) -> str:
+ return self._task_description
+
+ def reset(self) -> Observation:
+ return self._observation()
+
+ @property
+ def max_steps(self):
+ # RoboTwin sets a per-task rollout cap (`step_lim`) during setup_demo.
+ return getattr(self.env, "step_lim", None)
+
+ def new_episode(self, max_setup_retries: int = 25) -> Observation:
+ """Tear down the current episode and set up a fresh one (new layout).
+
+ Advances the seed so each episode gets a different object placement, and
+ retries the next seeds if setup raises (e.g. RoboTwin ``UnStableError``).
+ """
+ last_err = None
+ for _ in range(max(1, max_setup_retries)):
+ self.seed += 1
+ try:
+ try:
+ self.env.close_env(clear_cache=True)
+ except Exception:
+ pass
+ self._episode_index += 1
+ self._setup_episode()
+ return self._observation()
+ except Exception as exc: # e.g. UnStableError on an unlucky seed
+ last_err = exc
+ raise RuntimeError(f"RoboTwin failed to set up a new episode after {max_setup_retries} seeds; last error: {last_err}")
+
+ def step(self, action):
+ action = np.asarray(action, dtype=np.float32).reshape(-1)
+ # RoboTwin policies output absolute joint targets (qpos), matching FastWAM.
+ self.env.take_action(action, action_type="qpos")
+ obs = self._observation()
+ success = bool(getattr(self.env, "eval_success", False)) or bool(self.env.check_success())
+ return obs, success
+
+ def _observation(self) -> Observation:
+ raw = self.env.get_obs()
+ cameras = raw["observation"]
+ images = {}
+ for cam in self.contract.cameras:
+ if cam not in cameras or "rgb" not in cameras[cam]:
+ raise KeyError(f"RoboTwin observation missing camera '{cam}' rgb; got {list(cameras)}")
+ rgb = np.asarray(cameras[cam]["rgb"])[..., :3]
+ images[cam] = np.ascontiguousarray(rgb.astype(np.uint8))
+ state = np.asarray(raw["joint_action"]["vector"], dtype=np.float32).reshape(-1)
+ return Observation(images=images, state=state)
+
+ def close(self) -> None:
+ env = getattr(self, "env", None)
+ if env is not None:
+ try:
+ env.close_env()
+ except Exception:
+ pass
+
+
+def build_robotwin_env(node) -> RoboTwinEnv:
+ contract = node.contract
+ node.declare_parameter("robotwin_root", str(default_robotwin_root()))
+ node.declare_parameter("task_name", "click_alarmclock")
+ node.declare_parameter("task_config", "demo_clean")
+ node.declare_parameter("embodiment", "aloha-agilex")
+ node.declare_parameter("instruction_type", "unseen")
+ node.declare_parameter("instruction", "")
+ node.declare_parameter("seed", 0)
+
+ return RoboTwinEnv(
+ contract,
+ task_name=node.get_parameter("task_name").value,
+ task_config=node.get_parameter("task_config").value,
+ embodiment=node.get_parameter("embodiment").value,
+ instruction_type=node.get_parameter("instruction_type").value,
+ instruction=node.get_parameter("instruction").value,
+ seed=int(node.get_parameter("seed").value),
+ robotwin_root=node.get_parameter("robotwin_root").value,
+ )
diff --git a/lightx2v_ros/src/simulator/simulator/robotwin_node/main.py b/lightx2v_ros/src/simulator/simulator/robotwin_node/main.py
new file mode 100644
index 000000000..4251514da
--- /dev/null
+++ b/lightx2v_ros/src/simulator/simulator/robotwin_node/main.py
@@ -0,0 +1,13 @@
+from common.contract import get_contract
+
+from ..sim.node import run_simulator_node
+from .env import build_robotwin_env
+
+
+def main(args=None):
+ contract = get_contract("robotwin")
+ run_simulator_node(contract, build_robotwin_env, node_name="robotwin_node", args=args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/lightx2v_ros/src/simulator/simulator/sim/__init__.py b/lightx2v_ros/src/simulator/simulator/sim/__init__.py
new file mode 100644
index 000000000..c7b0fec45
--- /dev/null
+++ b/lightx2v_ros/src/simulator/simulator/sim/__init__.py
@@ -0,0 +1,10 @@
+from .base_env import BaseSimEnv, Observation
+from .node import SimulatorNode, rgb_to_image_msg, run_simulator_node
+
+__all__ = [
+ "BaseSimEnv",
+ "Observation",
+ "SimulatorNode",
+ "rgb_to_image_msg",
+ "run_simulator_node",
+]
diff --git a/lightx2v_ros/src/simulator/simulator/sim/base_env.py b/lightx2v_ros/src/simulator/simulator/sim/base_env.py
new file mode 100644
index 000000000..e30abf760
--- /dev/null
+++ b/lightx2v_ros/src/simulator/simulator/sim/base_env.py
@@ -0,0 +1,70 @@
+"""Environment-agnostic simulator interface.
+
+Every concrete simulator (LIBERO, RoboTwin, ...) is exposed to the ROS layer
+through the same small contract: it produces an `Observation` (a dict of
+camera RGB frames plus a flat proprio-state vector) and consumes an action
+vector. The generic `SimulatorNode` only ever talks to this interface, so
+adding a new environment never requires touching the node/topic plumbing.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Dict, Tuple
+
+import numpy as np
+from common.contract import EnvContract
+
+
+@dataclass
+class Observation:
+ # logical camera name -> HxWx3 uint8 RGB image, already oriented for publishing
+ images: Dict[str, np.ndarray]
+ # flat proprio-state vector, shape (contract.state_dim,)
+ state: np.ndarray
+
+
+class BaseSimEnv(ABC):
+ def __init__(self, contract: EnvContract):
+ self.contract = contract
+
+ @property
+ @abstractmethod
+ def task_description(self) -> str:
+ """Current natural-language task/instruction."""
+
+ @abstractmethod
+ def reset(self) -> Observation:
+ """Reset the environment and return the first observation."""
+
+ @abstractmethod
+ def step(self, action: np.ndarray) -> Tuple[Observation, bool]:
+ """Apply one action, returning (observation, success)."""
+
+ def new_episode(self) -> Observation:
+ """Start a fresh episode and return its first observation.
+
+ Used by the node's continuous-eval loop. The default simply calls
+ ``reset()``; environments that need to re-randomize a scene (e.g.
+ RoboTwin) override this to tear down and rebuild the episode.
+ """
+ return self.reset()
+
+ @property
+ def max_steps(self):
+ """Optional per-episode step cap hint (None = let the node decide)."""
+ return None
+
+ def close(self) -> None:
+ return None
+
+ def validate(self, obs: Observation) -> None:
+ missing = [cam for cam in self.contract.cameras if cam not in obs.images]
+ if missing:
+ raise ValueError(f"env '{self.contract.name}' observation is missing cameras: {missing}")
+ for cam in self.contract.cameras:
+ image = np.asarray(obs.images[cam])
+ if image.ndim != 3 or image.shape[2] != 3:
+ raise ValueError(f"env '{self.contract.name}' camera '{cam}' must be HxWx3, got {image.shape}")
+ state = np.asarray(obs.state, dtype=np.float32).reshape(-1)
+ if state.size != self.contract.state_dim:
+ raise ValueError(f"env '{self.contract.name}' state dim {state.size} != contract {self.contract.state_dim}")
diff --git a/lightx2v_ros/src/simulator/simulator/sim/node.py b/lightx2v_ros/src/simulator/simulator/sim/node.py
new file mode 100644
index 000000000..844993366
--- /dev/null
+++ b/lightx2v_ros/src/simulator/simulator/sim/node.py
@@ -0,0 +1,218 @@
+"""Generic, contract-driven simulator ROS node.
+
+`SimulatorNode` is environment-agnostic: it derives every topic name and the
+action/state dimensions from the `EnvContract` and drives any `BaseSimEnv`
+implementation. An `env_factory(node) -> BaseSimEnv` callback lets each concrete
+environment declare its own ROS parameters on the node before construction.
+"""
+
+from typing import Callable
+
+import numpy as np
+import rclpy
+from common.contract import EnvContract
+from rclpy.node import Node
+from sensor_msgs.msg import Image
+from std_msgs.msg import Bool, Float32MultiArray, Int32, String
+
+from .base_env import BaseSimEnv
+
+
+def rgb_to_image_msg(image, stamp, frame_id):
+ image = np.ascontiguousarray(image)
+ msg = Image()
+ msg.header.stamp = stamp
+ msg.header.frame_id = frame_id
+ msg.height = int(image.shape[0])
+ msg.width = int(image.shape[1])
+ msg.encoding = "rgb8"
+ msg.is_bigendian = False
+ msg.step = int(image.strides[0])
+ msg.data = image.tobytes()
+ return msg
+
+
+class SimulatorNode(Node):
+ def __init__(
+ self,
+ contract: EnvContract,
+ env_factory: Callable[["SimulatorNode"], BaseSimEnv],
+ *,
+ node_name: str = "simulator_node",
+ ):
+ super().__init__(node_name)
+ self.contract = contract
+
+ self.declare_parameter("republish_period", 1.0)
+ self.declare_parameter("idle_republish_period", 2.0)
+ # Continuous-eval loop: when true, the node automatically starts a fresh
+ # episode after each success (or step cap) instead of stopping.
+ self.declare_parameter("loop", False)
+ # Per-episode step cap; <=0 means "use the env hint (env.max_steps) or run
+ # until success". Failed episodes rely on this cap to eventually loop.
+ self.declare_parameter("max_episode_steps", 0)
+ self.republish_period = float(self.get_parameter("republish_period").value)
+ self.idle_republish_period = float(self.get_parameter("idle_republish_period").value)
+ self.loop = bool(self.get_parameter("loop").value)
+
+ # env_factory may declare/read its own parameters via `self`.
+ self.env = env_factory(self)
+ if self.env.contract is not contract:
+ raise ValueError("env_factory returned an env bound to a different contract")
+
+ param_max_steps = int(self.get_parameter("max_episode_steps").value)
+ env_hint = getattr(self.env, "max_steps", None)
+ if param_max_steps > 0:
+ self.max_episode_steps = param_max_steps
+ elif env_hint:
+ self.max_episode_steps = int(env_hint)
+ else:
+ self.max_episode_steps = 0
+
+ self.state_pub = self.create_publisher(Float32MultiArray, contract.state_topic, 10)
+ self.image_pubs = {cam: self.create_publisher(Image, contract.camera_topic(cam), 10) for cam in contract.cameras}
+ self.success_pub = self.create_publisher(Bool, contract.success_topic, 10)
+ self.observation_ready_pub = self.create_publisher(Int32, contract.observation_ready_topic, 10)
+ self.task_pub = self.create_publisher(String, contract.task_topic, 10)
+ self.episode_pub = self.create_publisher(Int32, contract.episode_topic, 10)
+ self.action_sub = self.create_subscription(Float32MultiArray, contract.action_topic, self.on_action, 10)
+
+ # `step_index` is a monotonic global observation counter (never reset), so the
+ # policy's "process only newer observations" logic keeps working across episodes.
+ self.step_index = 0
+ # `episode_step` counts steps within the current episode (drives the step cap).
+ self.episode_step = 0
+ self.episode_index = 0
+ self.success = False
+ self._slowed = False
+
+ self.obs = self.env.reset()
+ self.env.validate(self.obs)
+
+ self.get_logger().info(
+ f"[{contract.name}] cameras={list(contract.cameras)} "
+ f"action_dim={contract.action_dim} state_dim={contract.state_dim}; "
+ f"loop={self.loop} max_episode_steps={self.max_episode_steps or 'unlimited'}; "
+ f"listening for actions on {contract.action_topic}"
+ )
+ self.timer = self.create_timer(self.republish_period, self.republish)
+ self.publish_observation()
+
+ def republish(self):
+ self.publish_observation()
+
+ def on_action(self, msg):
+ if self.success:
+ # In loop mode `success` is reset to False synchronously in
+ # `_start_next_episode`, so this only drops late actions that raced the
+ # episode boundary; in single-episode mode it stops the rollout.
+ if not self.loop:
+ self.get_logger().warning("episode already succeeded; ignoring action")
+ return
+
+ action = np.asarray(msg.data, dtype=np.float32).reshape(-1)
+ if action.size != self.contract.action_dim:
+ self.get_logger().error(f"expected action length {self.contract.action_dim}, got {action.size}")
+ return
+
+ self.obs, success = self.env.step(action)
+ self.step_index += 1
+ self.episode_step += 1
+ self.success = bool(success)
+
+ capped = self.max_episode_steps > 0 and self.episode_step >= self.max_episode_steps
+ if self.loop and (self.success or capped):
+ outcome = "SUCCESS" if self.success else f"step cap ({self.max_episode_steps})"
+ self.get_logger().info(f"episode {self.episode_index} ended [{outcome}] after {self.episode_step} steps (global step {self.step_index}); starting next episode...")
+ # Emit the final frame (success flag reflects the outcome) before rebuilding.
+ self.publish_observation()
+ self._start_next_episode()
+ return
+
+ self.publish_observation()
+
+ if self.success and not self.loop:
+ self.get_logger().info(f"episode succeeded at step {self.step_index}")
+ self._slow_down_timer()
+
+ def _start_next_episode(self):
+ try:
+ self.obs = self.env.new_episode()
+ self.env.validate(self.obs)
+ except Exception as exc:
+ self.get_logger().error(f"failed to start next episode: {exc}")
+ raise
+ self.episode_index += 1
+ self.episode_step = 0
+ # Keep the global observation counter strictly increasing so the policy always
+ # sees the new episode's first frame as "newer" than anything it processed.
+ self.step_index += 1
+ self.success = False
+ self.get_logger().info(f"episode {self.episode_index} started (global step {self.step_index}): {self.env.task_description!r}")
+ self.publish_observation()
+
+ def _slow_down_timer(self):
+ # After success the env stops stepping. Keep republishing the final frame at a
+ # low rate so the web viewer keeps showing the last image instead of going blank.
+ if self._slowed:
+ return
+ self._slowed = True
+ try:
+ self.timer.cancel()
+ except Exception:
+ pass
+ self.timer = self.create_timer(self.idle_republish_period, self.republish)
+
+ def publish_observation(self):
+ stamp = self.get_clock().now().to_msg()
+ for cam, pub in self.image_pubs.items():
+ pub.publish(rgb_to_image_msg(self.obs.images[cam], stamp, cam))
+
+ state_msg = Float32MultiArray()
+ state_msg.data = np.asarray(self.obs.state, dtype=np.float32).reshape(-1).tolist()
+ self.state_pub.publish(state_msg)
+
+ task_msg = String()
+ task_msg.data = self.env.task_description or ""
+ self.task_pub.publish(task_msg)
+
+ success_msg = Bool()
+ success_msg.data = self.success
+ self.success_pub.publish(success_msg)
+
+ episode_msg = Int32()
+ episode_msg.data = self.episode_index
+ self.episode_pub.publish(episode_msg)
+
+ ready_msg = Int32()
+ ready_msg.data = self.step_index
+ self.observation_ready_pub.publish(ready_msg)
+
+ def destroy_node(self):
+ try:
+ if hasattr(self, "env"):
+ self.env.close()
+ finally:
+ super().destroy_node()
+
+
+def run_simulator_node(
+ contract: EnvContract,
+ env_factory: Callable[["SimulatorNode"], BaseSimEnv],
+ *,
+ node_name: str,
+ args=None,
+):
+ rclpy.init(args=args)
+ node = SimulatorNode(contract, env_factory, node_name=node_name)
+ try:
+ rclpy.spin(node)
+ except KeyboardInterrupt:
+ pass
+ except Exception:
+ if rclpy.ok():
+ raise
+ finally:
+ node.destroy_node()
+ if rclpy.ok():
+ rclpy.shutdown()
diff --git a/lightx2v_ros/src/visualization/package.xml b/lightx2v_ros/src/visualization/package.xml
index 3a8edb9a3..623605f72 100644
--- a/lightx2v_ros/src/visualization/package.xml
+++ b/lightx2v_ros/src/visualization/package.xml
@@ -13,6 +13,7 @@
sensor_msgs
std_msgs
python3-opencv
+ common
ament_python
diff --git a/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/main.py b/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/main.py
index ac5dd34fc..f1ac047b5 100644
--- a/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/main.py
+++ b/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/main.py
@@ -4,18 +4,17 @@
import cv2
import numpy as np
import rclpy
+from common.contract import get_contract
from rclpy.node import Node
from sensor_msgs.msg import Image
from std_msgs.msg import String
-from .page import INDEX_HTML
+from .page import render_index
-AGENTVIEW_TOPIC = "/libero/agentview/image_raw"
-WRIST_TOPIC = "/libero/wrist/image_raw"
-FRONTVIEW_TOPIC = "/libero/frontview/image_raw"
-GALLERYVIEW_TOPIC = "/libero/galleryview/image_raw"
-TASK_TOPIC = "/libero/task_description"
-CAMERAS = ("agentview", "wrist", "frontview", "galleryview")
+# How long a stalled MJPEG stream waits before re-sending the last frame. This,
+# together with the simulator re-publishing the final frame after success, keeps
+# the page from going blank when no new frames arrive.
+STREAM_KEEPALIVE_S = 2.0
class ImageHttpServer(ThreadingHTTPServer):
@@ -23,9 +22,9 @@ class ImageHttpServer(ThreadingHTTPServer):
class FrameStore:
- def __init__(self):
+ def __init__(self, cameras):
self.condition = Condition()
- self.frames = {name: (0, None) for name in CAMERAS}
+ self.frames = {name: (0, None) for name in cameras}
self.task = ""
def update(self, name, jpeg):
@@ -34,9 +33,9 @@ def update(self, name, jpeg):
self.frames[name] = (seq + 1, jpeg)
self.condition.notify_all()
- def wait_next(self, name, last_seq):
+ def wait_next(self, name, last_seq, timeout=STREAM_KEEPALIVE_S):
with self.condition:
- self.condition.wait_for(lambda: self.frames[name][0] != last_seq)
+ self.condition.wait_for(lambda: self.frames[name][0] != last_seq, timeout=timeout)
return self.frames[name]
def update_task(self, task):
@@ -52,44 +51,48 @@ class ImageWebViewerNode(Node):
def __init__(self):
super().__init__("image_web_viewer")
+ self.declare_parameter("env", "libero")
self.declare_parameter("host", "127.0.0.1")
self.declare_parameter("port", 8080)
- self.declare_parameter("agentview_topic", AGENTVIEW_TOPIC)
- self.declare_parameter("wrist_topic", WRIST_TOPIC)
- self.declare_parameter("frontview_topic", FRONTVIEW_TOPIC)
- self.declare_parameter("galleryview_topic", GALLERYVIEW_TOPIC)
- self.declare_parameter("task_topic", TASK_TOPIC)
self.declare_parameter("jpeg_quality", 85)
+ self.declare_parameter("cameras", [])
+ self.declare_parameter("namespace", "")
+ self.declare_parameter("task_topic", "")
+
+ env = str(self.get_parameter("env").value).strip().lower()
+ contract = get_contract(env)
+ self.contract = contract
+
+ cameras_param = list(self.get_parameter("cameras").value or [])
+ self.cameras = cameras_param if cameras_param else list(contract.cameras)
+ namespace = str(self.get_parameter("namespace").value).strip() or contract.namespace
+ task_topic = str(self.get_parameter("task_topic").value).strip() or contract.task_topic
self.jpeg_quality = int(self.get_parameter("jpeg_quality").value)
- self.frame_store = FrameStore()
+ self.frame_store = FrameStore(self.cameras)
self.http_server = None
self.http_thread = None
- for name in CAMERAS:
+ for name in self.cameras:
+ topic = f"{namespace}/{name}/image_raw"
self.create_subscription(
Image,
- self.get_parameter(f"{name}_topic").value,
+ topic,
lambda msg, camera_name=name: self.on_image(camera_name, msg),
10,
)
- self.create_subscription(
- String,
- self.get_parameter("task_topic").value,
- self.on_task,
- 10,
- )
+ self.create_subscription(String, task_topic, self.on_task, 10)
self.start_http_server()
def start_http_server(self):
host = str(self.get_parameter("host").value)
port = int(self.get_parameter("port").value)
- handler = make_handler(self.frame_store)
+ handler = make_handler(self.frame_store, self.cameras, self.contract.name)
self.http_server = ImageHttpServer((host, port), handler)
self.http_thread = Thread(target=self.http_server.serve_forever, daemon=True)
self.http_thread.start()
- self.get_logger().info(f"image web viewer listening on http://{host}:{port}")
+ self.get_logger().info(f"[{self.contract.name}] image web viewer on http://{host}:{port} cameras={self.cameras}")
def on_image(self, name, msg):
try:
@@ -121,7 +124,10 @@ def destroy_node(self):
super().destroy_node()
-def make_handler(frame_store):
+def make_handler(frame_store, cameras, title):
+ camera_set = set(cameras)
+ index_html = render_index(cameras, title=f"LightX2V ROS ยท {title}")
+
class ImageWebViewerHandler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path in {"/", "/index.html"}:
@@ -130,22 +136,15 @@ def do_GET(self):
if self.path == "/task.txt":
self.send_task()
return
- if self.path == "/agentview.mjpg":
- self.send_stream("agentview")
- return
- if self.path == "/wrist.mjpg":
- self.send_stream("wrist")
- return
- if self.path == "/frontview.mjpg":
- self.send_stream("frontview")
- return
- if self.path == "/galleryview.mjpg":
- self.send_stream("galleryview")
- return
+ if self.path.endswith(".mjpg"):
+ name = self.path[1 : -len(".mjpg")]
+ if name in camera_set:
+ self.send_stream(name)
+ return
self.send_error(404)
def send_index(self):
- body = INDEX_HTML.encode("utf-8")
+ body = index_html.encode("utf-8")
self.send_response(200)
self.send_header("Content-Type", "text/html; charset=utf-8")
self.send_header("Content-Length", str(len(body)))
diff --git a/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/page.py b/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/page.py
index f1f297a87..414e87f08 100644
--- a/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/page.py
+++ b/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/page.py
@@ -1,10 +1,12 @@
-INDEX_HTML = """
-
-
-
-
- LightX2V ROS
-
+ .task span { color: var(--muted); }
+"""
+
+_VIEW_TEMPLATE = """
+ {label}
+
+ """
+
+
+def render_index(cameras, title="LightX2V ROS"):
+ columns = max(1, min(len(cameras), 4))
+ style = _STYLE.replace("__COLUMNS__", str(columns))
+ views = "\n".join(_VIEW_TEMPLATE.format(name=html.escape(str(cam)), label=html.escape(str(cam))) for cam in cameras)
+ safe_title = html.escape(str(title))
+ return f"""
+
+
+
+
+ {safe_title}
+
- LightX2V ROS
+ {safe_title}
live streams
-
- agentview
-
-
-
- wrist
-
-
-
- frontview
-
-
-
- galleryview
-
-
+{views}
waiting for task description