diff --git a/.gitmodules b/.gitmodules index 4a6dd61c1..cc618fe9b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,7 @@ [submodule "lightx2v_ros/src/simulator/simulator/libero_node/LIBERO"] path = lightx2v_ros/src/simulator/simulator/libero_node/LIBERO url = https://github.com/Lifelong-Robot-Learning/LIBERO.git +[submodule "lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin"] + path = lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin + url = https://github.com/robotwin-Platform/robotwin.git + branch = main diff --git a/configs/fastwam/libero_i2va.json b/configs/fastwam/libero_i2va.json index ac3405110..2c17b506f 100644 --- a/configs/fastwam/libero_i2va.json +++ b/configs/fastwam/libero_i2va.json @@ -10,6 +10,9 @@ "action_dim_hidden": 1024, "action_dim": 7, "robot_state_dim": 8, + "policy_profile": "libero", + "normalize_mode": "min-max", "binarize_gripper": true, + "gripper_postprocess": true, "default_prompt": "A video recorded from a robot's point of view executing the following instruction: {task_prompt}" } diff --git a/configs/fastwam/robotwin_i2va.json b/configs/fastwam/robotwin_i2va.json new file mode 100644 index 000000000..b80c77a5c --- /dev/null +++ b/configs/fastwam/robotwin_i2va.json @@ -0,0 +1,18 @@ +{ + "adapter_model_path": "/data/nvme7/yongyang/fastwam_release/robotwin_uncond_3cam_384.pt", + "dataset_stats_path": "/data/nvme7/yongyang/fastwam_release/robotwin_uncond_3cam_384_dataset_stats.json", + "camera_size": 384, + "action_chunk_size": 32, + "actions_per_plan": 8, + "num_steps_wait": 0, + "action_infer_steps": 20, + "action_sample_shift": 5.0, + "action_dim_hidden": 1024, + "action_dim": 14, + "robot_state_dim": 14, + "policy_profile": "robotwin", + "normalize_mode": "z-score", + "binarize_gripper": false, + "gripper_postprocess": false, + "default_prompt": "A video recorded from a robot's point of view executing the following instruction: {task_prompt}" +} diff --git a/lightx2v/models/runners/wan/fastwam_runner.py 
b/lightx2v/models/runners/wan/fastwam_runner.py index 0d89ddcc9..7c1ddbdbf 100644 --- a/lightx2v/models/runners/wan/fastwam_runner.py +++ b/lightx2v/models/runners/wan/fastwam_runner.py @@ -25,18 +25,47 @@ def resize_rgb(image, width, height): return np.asarray(pil.resize((width, height), resample=Image.BILINEAR), dtype=np.uint8) -class MinMaxNormalizer: +def _canonical_norm_mode(mode): + compact = str(mode).strip().lower().replace("-", "").replace("_", "").replace("/", "") + if compact == "minmax": + return "min/max" + if compact == "zscore": + return "z-score" + if compact == "q01q99": + return "q01/q99" + raise ValueError(f"unsupported normalize_mode: {mode!r}") + + +class LinearNormalizer: + """Affine normalizer matching FastWAM's SingleFieldLinearNormalizer (global stats). + + Modes used by the released checkpoints: + - "min/max" (LIBERO): maps [global_min, global_max] -> [-1, 1] + - "q01/q99": maps [global_q01, global_q99] -> [-1, 1] + - "z-score" (RoboTwin): (x - global_mean) / global_std + """ + + std_reg = 1e-8 range_tol = 1e-4 - - def __init__(self, stats): - min_v = torch.as_tensor(stats["global_min"], dtype=torch.float32) - max_v = torch.as_tensor(stats["global_max"], dtype=torch.float32) - input_range = max_v - min_v - ignore = input_range < self.range_tol - input_range[ignore] = 2.0 - self.scale = 2.0 / input_range - self.offset = -1.0 - self.scale * min_v - self.offset[ignore] = -min_v[ignore] + output_min = -1.0 + output_max = 1.0 + + def __init__(self, stats, mode="min/max"): + mode = _canonical_norm_mode(mode) + g = {key[len("global_") :]: torch.as_tensor(value, dtype=torch.float32) for key, value in stats.items() if key.startswith("global_")} + if mode == "z-score": + mean, std = g["mean"], g["std"] + self.scale = 1.0 / (std + self.std_reg) + self.offset = -mean / (std + self.std_reg) + else: + low, high = (g["min"].clone(), g["max"].clone()) if mode == "min/max" else (g["q01"].clone(), g["q99"].clone()) + input_range = high - low + ignore = 
input_range < self.range_tol + input_range[ignore] = self.output_max - self.output_min + self.scale = (self.output_max - self.output_min) / input_range + self.offset = self.output_min - self.scale * low + self.offset[ignore] = (self.output_max + self.output_min) / 2 - low[ignore] + self.mode = mode def forward(self, x): x = torch.as_tensor(x, dtype=torch.float32) @@ -47,6 +76,13 @@ def backward(self, x): return (x - self.offset) / self.scale +class MinMaxNormalizer(LinearNormalizer): + """Backwards-compatible alias: LIBERO used a global min/max normalizer.""" + + def __init__(self, stats): + super().__init__(stats, mode="min/max") + + class FastWAMPolicy: def __init__( self, @@ -65,6 +101,9 @@ def __init__( binarize_gripper=True, default_prompt=None, camera_size=224, + policy_profile="libero", + normalize_mode="min/max", + gripper_postprocess=None, config=None, ): self.device = torch.device(device) @@ -103,6 +142,10 @@ def __init__( self.seed = None if seed is None or int(seed) < 0 else int(seed) self.binarize_gripper = bool(binarize_gripper) self.default_prompt = str(default_prompt) + self.policy_profile = str(policy_profile).strip().lower() + self.normalize_mode = _canonical_norm_mode(normalize_mode) + # LIBERO post-processes the single gripper channel; RoboTwin (dual-arm qpos) does not. 
+ self.gripper_postprocess = (self.policy_profile == "libero") if gripper_postprocess is None else bool(gripper_postprocess) self.config = config self._check_model_config() self.pending_actions = deque() @@ -135,6 +178,9 @@ def from_config(cls, config): binarize_gripper=config.get("binarize_gripper", True), default_prompt=config.get("default_prompt"), camera_size=config.get("camera_size", 224), + policy_profile=config.get("policy_profile", "libero"), + normalize_mode=config.get("normalize_mode", "min/max"), + gripper_postprocess=config.get("gripper_postprocess"), config=config, ) @@ -147,8 +193,8 @@ def _load_normalizers(self): with open(self.dataset_stats_path, "r", encoding="utf-8") as f: stats = json.load(f) return ( - MinMaxNormalizer(stats["state"]["default"]), - MinMaxNormalizer(stats["action"]["default"]), + LinearNormalizer(stats["state"]["default"], self.normalize_mode), + LinearNormalizer(stats["action"]["default"], self.normalize_mode), ) def _load_text_encoder(self): @@ -188,17 +234,17 @@ def _find_model_dir(self, dirname): def reset(self): self.pending_actions.clear() - def next_action(self, agentview_rgb, wrist_rgb, state, task_description): + def next_action(self, images, state, task_description): if not self.pending_actions: - action_chunk = self.predict_action_chunk(agentview_rgb, wrist_rgb, state, task_description) + action_chunk = self.predict_action_chunk(images, state, task_description) for action in action_chunk[: self.actions_per_plan]: self.pending_actions.append(np.asarray(action, dtype=np.float32)) if not self.pending_actions: raise RuntimeError("FastWAM produced an empty action chunk") return self.pending_actions.popleft() - def predict_action_chunk(self, agentview_rgb, wrist_rgb, state, task_description, seed=None): - image = self.build_image_tensor(agentview_rgb, wrist_rgb) + def predict_action_chunk(self, images, state, task_description, seed=None): + image = self.build_image_tensor(images) first_frame_latents = 
self.encode_image_latents(image) context, context_mask = self.encode_prompt(self.default_prompt.format(task_prompt=task_description)) robot_state = self.state_normalizer.forward(np.asarray(state, dtype=np.float32)) @@ -216,10 +262,12 @@ def predict_action_chunk(self, agentview_rgb, wrist_rgb, state, task_description seed=self.seed if seed is None else seed, ) action = self.action_normalizer.backward(action).numpy() - action[..., -1] = action[..., -1] * 2 - 1 - action[..., -1] = action[..., -1] * -1.0 - if self.binarize_gripper: - action[..., -1] = np.sign(action[..., -1]) + if self.gripper_postprocess: + # LIBERO single-gripper channel: map [0,1] -> [-1,1], flip sign, optional binarize. + action[..., -1] = action[..., -1] * 2 - 1 + action[..., -1] = action[..., -1] * -1.0 + if self.binarize_gripper: + action[..., -1] = np.sign(action[..., -1]) return action.astype(np.float32) def _run_action_denoising(self, inputs, action_shape, action_infer_steps, seed): @@ -239,13 +287,35 @@ def _run_action_denoising(self, inputs, action_shape, action_infer_steps, seed): scheduler.clear() return action - def build_image_tensor(self, agentview_rgb, wrist_rgb): - primary = resize_rgb(agentview_rgb, self.camera_size, self.camera_size) - wrist = resize_rgb(wrist_rgb, self.camera_size, self.camera_size) - rgb = np.concatenate([primary, wrist], axis=1) + def build_image_tensor(self, images): + """Compose the policy image from a dict of {camera_name: HxWx3 RGB}. + + Layouts mirror the released FastWAM checkpoints: + - LIBERO : [agentview | wrist] side-by-side, each camera_size x camera_size. + - RoboTwin : head (320x256) on top, [left | right] (each 160x128) below, + i.e. final [384, 320, 3]; matches deploy_policy.py. 
+ """ + rgb = self._compose_rgb(images) tensor = torch.from_numpy(rgb).permute(2, 0, 1).to(device=self.device, dtype=GET_DTYPE()) return tensor * (2.0 / 255.0) - 1.0 + def _compose_rgb(self, images): + def _get(name): + if name not in images or images[name] is None: + raise KeyError(f"FastWAM profile '{self.policy_profile}' requires camera '{name}'") + return images[name] + + if self.policy_profile == "robotwin": + head = resize_rgb(_get("head_camera"), 320, 256) + left = resize_rgb(_get("left_camera"), 160, 128) + right = resize_rgb(_get("right_camera"), 160, 128) + bottom = np.concatenate([left, right], axis=1) + return np.concatenate([head, bottom], axis=0) + + primary = resize_rgb(_get("agentview"), self.camera_size, self.camera_size) + wrist = resize_rgb(_get("wrist"), self.camera_size, self.camera_size) + return np.concatenate([primary, wrist], axis=1) + def encode_image_latents(self, image): image = image.unsqueeze(1) latents = self.vae.encode(image.unsqueeze(0)) @@ -352,8 +422,7 @@ def run_pipeline(self, input_info): agentview, wrist = self._load_image_pair() state = self._load_state() actions = self.policy.predict_action_chunk( - agentview_rgb=agentview, - wrist_rgb=wrist, + images={"agentview": agentview, "wrist": wrist}, state=state, task_description=self.input_info.prompt, seed=self.input_info.seed, diff --git a/lightx2v_ros/src/common/common/__init__.py b/lightx2v_ros/src/common/common/__init__.py new file mode 100644 index 000000000..5e568f47b --- /dev/null +++ b/lightx2v_ros/src/common/common/__init__.py @@ -0,0 +1,15 @@ +from .contract import ( + CONTRACTS, + LIBERO_CONTRACT, + ROBOTWIN_CONTRACT, + EnvContract, + get_contract, +) + +__all__ = [ + "CONTRACTS", + "EnvContract", + "LIBERO_CONTRACT", + "ROBOTWIN_CONTRACT", + "get_contract", +] diff --git a/lightx2v_ros/src/common/common/contract.py b/lightx2v_ros/src/common/common/contract.py new file mode 100644 index 000000000..e6719eedd --- /dev/null +++ b/lightx2v_ros/src/common/common/contract.py 
@@ -0,0 +1,115 @@ +"""Single source of truth for the simulator/inference/visualization contract. + +This module is intentionally dependency-free (pure Python, no ROS/torch imports) +so it can be imported by every ROS node as well as plain scripts. It defines, per +simulation environment, the ROS topic namespace, the set of cameras, the +action/state dimensions and the inference profile. All three nodes derive their +topic names and tensor dimensions from the same `EnvContract`, which removes the +LIBERO-specific hard-coding that used to live in each node. +""" + +from dataclasses import dataclass, field +from typing import Dict, Tuple + + +@dataclass(frozen=True) +class EnvContract: + # Logical environment id, e.g. "libero" or "robotwin". + name: str + # ROS topic prefix, e.g. "/libero". + namespace: str + # All cameras the simulator publishes (logical names, also used by the viewer). + cameras: Tuple[str, ...] + # Subset of `cameras` fed to the policy, in the exact order the policy expects. + policy_input_cameras: Tuple[str, ...] + # Robot action / proprio-state vector dimensions exchanged over ROS. + action_dim: int + state_dim: int + # Square render/publish size hint (used by simulators that render on demand). + image_size: int + # Inference assembly profile understood by FastWAMPolicy ("libero"|"robotwin"). + policy_profile: str + # Action/state normalization mode ("minmax"|"zscore"). + normalize_mode: str + # Whether to apply the LIBERO single-gripper sign/binarize post-processing. + gripper_postprocess: bool + # Optional human-readable description. 
+ description: str = field(default="") + + # ----- derived topic helpers (kept in one place so nodes never disagree) ----- + @property + def action_topic(self) -> str: + return f"{self.namespace}/action" + + @property + def state_topic(self) -> str: + return f"{self.namespace}/state" + + @property + def success_topic(self) -> str: + return f"{self.namespace}/success" + + @property + def observation_ready_topic(self) -> str: + return f"{self.namespace}/observation_ready" + + @property + def task_topic(self) -> str: + return f"{self.namespace}/task_description" + + @property + def episode_topic(self) -> str: + # Monotonic episode counter; lets the policy reset its per-episode state when + # the simulator loops into a fresh episode. + return f"{self.namespace}/episode" + + def camera_topic(self, camera: str) -> str: + if camera not in self.cameras: + raise KeyError(f"camera '{camera}' is not part of contract '{self.name}': {self.cameras}") + return f"{self.namespace}/{camera}/image_raw" + + def camera_topics(self) -> Dict[str, str]: + return {camera: self.camera_topic(camera) for camera in self.cameras} + + +LIBERO_CONTRACT = EnvContract( + name="libero", + namespace="/libero", + cameras=("agentview", "wrist", "frontview", "galleryview"), + policy_input_cameras=("agentview", "wrist"), + action_dim=7, + state_dim=8, + image_size=224, + policy_profile="libero", + normalize_mode="minmax", + gripper_postprocess=True, + description="LIBERO (robosuite/mujoco) single-arm tabletop manipulation.", +) + + +ROBOTWIN_CONTRACT = EnvContract( + name="robotwin", + namespace="/robotwin", + cameras=("head_camera", "left_camera", "right_camera"), + policy_input_cameras=("head_camera", "left_camera", "right_camera"), + action_dim=14, + state_dim=14, + image_size=384, + policy_profile="robotwin", + normalize_mode="zscore", + gripper_postprocess=False, + description="RoboTwin 2.0 (SAPIEN) dual-arm manipulation.", +) + + +CONTRACTS: Dict[str, EnvContract] = { + LIBERO_CONTRACT.name: 
LIBERO_CONTRACT, + ROBOTWIN_CONTRACT.name: ROBOTWIN_CONTRACT, +} + + +def get_contract(name: str) -> EnvContract: + key = str(name).strip().lower() + if key not in CONTRACTS: + raise KeyError(f"unknown environment '{name}'. Available: {sorted(CONTRACTS)}") + return CONTRACTS[key] diff --git a/lightx2v_ros/src/common/package.xml b/lightx2v_ros/src/common/package.xml new file mode 100644 index 000000000..8dcee47f0 --- /dev/null +++ b/lightx2v_ros/src/common/package.xml @@ -0,0 +1,15 @@ + + + + common + 0.0.1 + Shared environment contract for LightX2V ROS nodes. + user + Apache-2.0 + + ament_python + + + ament_python + + diff --git a/lightx2v_ros/src/common/resource/common b/lightx2v_ros/src/common/resource/common new file mode 100644 index 000000000..e69de29bb diff --git a/lightx2v_ros/src/common/setup.cfg b/lightx2v_ros/src/common/setup.cfg new file mode 100644 index 000000000..2fbdbffa4 --- /dev/null +++ b/lightx2v_ros/src/common/setup.cfg @@ -0,0 +1,4 @@ +[develop] +script_dir=$base/lib/common +[install] +install_scripts=$base/lib/common diff --git a/lightx2v_ros/src/common/setup.py b/lightx2v_ros/src/common/setup.py new file mode 100644 index 000000000..ef2ca53f9 --- /dev/null +++ b/lightx2v_ros/src/common/setup.py @@ -0,0 +1,22 @@ +from setuptools import find_packages, setup + +package_name = "common" + +setup( + name=package_name, + version="0.0.1", + packages=find_packages(), + data_files=[ + ("share/ament_index/resource_index/packages", ["resource/" + package_name]), + ("share/" + package_name, ["package.xml"]), + ], + install_requires=["setuptools"], + zip_safe=True, + maintainer="user", + maintainer_email="user@example.com", + description="Shared environment contract for LightX2V ROS nodes.", + license="Apache-2.0", + entry_points={ + "console_scripts": [], + }, +) diff --git a/lightx2v_ros/src/inference/inference/fastwam_node/main.py b/lightx2v_ros/src/inference/inference/fastwam_node/main.py index df5dad84f..6676a2db7 100644 --- 
a/lightx2v_ros/src/inference/inference/fastwam_node/main.py +++ b/lightx2v_ros/src/inference/inference/fastwam_node/main.py @@ -1,5 +1,6 @@ import numpy as np import rclpy +from common.contract import get_contract from rclpy.node import Node from sensor_msgs.msg import Image from std_msgs.msg import Bool, Float32MultiArray, Int32, String @@ -7,72 +8,48 @@ from lightx2v.models.runners.wan.fastwam_runner import FastWAMPolicy from lightx2v.utils.set_config import auto_calc_config, get_default_config -DEFAULT_DUMMY_ACTION = np.asarray([0, 0, 0, 0, 0, 0, -1], dtype=np.float32) -ACTION_TOPIC = "/libero/action" -AGENTVIEW_TOPIC = "/libero/agentview/image_raw" -WRIST_TOPIC = "/libero/wrist/image_raw" -STATE_TOPIC = "/libero/state" -OBSERVATION_READY_TOPIC = "/libero/observation_ready" -SUCCESS_TOPIC = "/libero/success" -TASK_TOPIC = "/libero/task_description" - class FastWAMNode(Node): def __init__(self): super().__init__("fastwam_node") + self.declare_parameter("env", "libero") self.declare_parameter("config_json", "") self.declare_parameter("model_path", "") + self.declare_parameter("num_steps_wait", -1) + + env = str(self.get_parameter("env").value).strip().lower() + self.contract = get_contract(env) - self.get_logger().info("loading FastWAM policy") + self.get_logger().info(f"[{self.contract.name}] loading FastWAM policy") self.policy_config = self.build_policy_config() self.policy = FastWAMPolicy.from_config(self.policy_config) - self.get_logger().info("FastWAM policy loaded") + self.get_logger().info(f"[{self.contract.name}] FastWAM policy loaded") - self.agentview = None - self.wrist = None + self.images = {cam: None for cam in self.contract.policy_input_cameras} self.state = None self.task_description = None self.success = False + self.episode_index = 0 self.last_processed_observation = -1 - self.num_steps_wait = int(self.policy_config.get("num_steps_wait", 30)) - - self.action_pub = self.create_publisher(Float32MultiArray, ACTION_TOPIC, 10) - 
self.create_subscription( - Image, - AGENTVIEW_TOPIC, - self.on_agentview, - 10, - ) - self.create_subscription( - Image, - WRIST_TOPIC, - self.on_wrist, - 10, - ) - self.create_subscription( - Float32MultiArray, - STATE_TOPIC, - self.on_state, - 10, - ) - self.create_subscription( - String, - TASK_TOPIC, - self.on_task, - 10, - ) - self.create_subscription( - Bool, - SUCCESS_TOPIC, - self.on_success, - 10, - ) - self.create_subscription( - Int32, - OBSERVATION_READY_TOPIC, - self.on_observation_ready, - 10, + + ns_wait = int(self.get_parameter("num_steps_wait").value) + self.num_steps_wait = ns_wait if ns_wait >= 0 else int(self.policy_config.get("num_steps_wait", 0)) + self.dummy_action = self._build_dummy_action() + + self.action_pub = self.create_publisher(Float32MultiArray, self.contract.action_topic, 10) + self._camera_subs = [] + for cam in self.contract.policy_input_cameras: + self._camera_subs.append(self.create_subscription(Image, self.contract.camera_topic(cam), self._make_image_cb(cam), 10)) + self.create_subscription(Float32MultiArray, self.contract.state_topic, self.on_state, 10) + self.create_subscription(String, self.contract.task_topic, self.on_task, 10) + self.create_subscription(Bool, self.contract.success_topic, self.on_success, 10) + self.create_subscription(Int32, self.contract.episode_topic, self.on_episode, 10) + self.create_subscription(Int32, self.contract.observation_ready_topic, self.on_observation_ready, 10) + + self.get_logger().info( + f"[{self.contract.name}] fastwam_node ready: input_cameras={list(self.contract.policy_input_cameras)} " + f"action_dim={self.contract.action_dim} state_dim={self.contract.state_dim} num_steps_wait={self.num_steps_wait}" ) def build_policy_config(self): @@ -91,18 +68,33 @@ def build_policy_config(self): "config_json": config_json, } ) - return auto_calc_config(config) + config = auto_calc_config(config) + + # The config_json is authoritative for policy params; warn loudly on any + # mismatch with the 
environment contract so dimension bugs surface early. + for key, expected in (("action_dim", self.contract.action_dim), ("robot_state_dim", self.contract.state_dim)): + actual = int(config.get(key, expected)) + if actual != expected: + self.get_logger().warning(f"config `{key}`={actual} disagrees with env '{self.contract.name}' contract ({expected})") + return config - def on_agentview(self, msg): - self.agentview = image_msg_to_rgb(msg) + def _build_dummy_action(self): + action = np.zeros(self.contract.action_dim, dtype=np.float32) + if self.contract.gripper_postprocess: + # LIBERO warmup keeps the gripper open while the scene settles. + action[-1] = -1.0 + return action - def on_wrist(self, msg): - self.wrist = image_msg_to_rgb(msg) + def _make_image_cb(self, camera): + def _cb(msg): + self.images[camera] = image_msg_to_rgb(msg) + + return _cb def on_state(self, msg): state = np.asarray(msg.data, dtype=np.float32) - if state.shape != (8,): - self.get_logger().error(f"expected /libero/state length 8, got {state.size}") + if state.size != self.contract.state_dim: + self.get_logger().error(f"expected {self.contract.state_topic} length {self.contract.state_dim}, got {state.size}") return self.state = state @@ -112,6 +104,18 @@ def on_task(self, msg): def on_success(self, msg): self.success = bool(msg.data) + def on_episode(self, msg): + episode = int(msg.data) + if episode == self.episode_index: + return + # Simulator looped into a fresh episode: drop the queued action chunk and any + # stale success/observation bookkeeping so we start clean on the new rollout. 
+ self.episode_index = episode + self.success = False + self.last_processed_observation = -1 + self.policy.reset() + self.get_logger().info(f"new episode {episode}; policy state reset for fresh rollout") + def on_observation_ready(self, msg): observation_index = int(msg.data) if observation_index <= self.last_processed_observation: @@ -127,13 +131,12 @@ def on_observation_ready(self, msg): return if observation_index < self.num_steps_wait: - action = DEFAULT_DUMMY_ACTION.copy() + action = self.dummy_action.copy() self.get_logger().info(f"observation {observation_index}: publishing warmup dummy action") else: self.get_logger().info(f"observation {observation_index}: running FastWAM inference/action queue") action = self.policy.next_action( - agentview_rgb=self.agentview, - wrist_rgb=self.wrist, + images={cam: self.images[cam] for cam in self.contract.policy_input_cameras}, state=self.state, task_description=self.task_description, ) @@ -142,11 +145,7 @@ def on_observation_ready(self, msg): self.last_processed_observation = observation_index def missing_inputs(self): - missing = [] - if self.agentview is None: - missing.append("agentview") - if self.wrist is None: - missing.append("wrist") + missing = [cam for cam in self.contract.policy_input_cameras if self.images.get(cam) is None] if self.state is None: missing.append("state") if not self.task_description: @@ -155,8 +154,8 @@ def missing_inputs(self): def publish_action(self, action): action = np.asarray(action, dtype=np.float32).reshape(-1) - if action.shape != (7,): - raise ValueError(f"expected action length 7, got {action.size}") + if action.size != self.contract.action_dim: + raise ValueError(f"expected action length {self.contract.action_dim}, got {action.size}") msg = Float32MultiArray() msg.data = action.tolist() self.action_pub.publish(msg) diff --git a/lightx2v_ros/src/inference/package.xml b/lightx2v_ros/src/inference/package.xml index d83e8087e..f2fec9ccc 100644 --- 
a/lightx2v_ros/src/inference/package.xml +++ b/lightx2v_ros/src/inference/package.xml @@ -12,6 +12,7 @@ rclpy sensor_msgs std_msgs + common ament_python diff --git a/lightx2v_ros/src/simulator/package.xml b/lightx2v_ros/src/simulator/package.xml index 4fbd4c1e9..db47cc1c8 100644 --- a/lightx2v_ros/src/simulator/package.xml +++ b/lightx2v_ros/src/simulator/package.xml @@ -12,6 +12,7 @@ rclpy sensor_msgs std_msgs + common ament_python diff --git a/lightx2v_ros/src/simulator/setup.py b/lightx2v_ros/src/simulator/setup.py index f0609f1c6..52b721552 100644 --- a/lightx2v_ros/src/simulator/setup.py +++ b/lightx2v_ros/src/simulator/setup.py @@ -19,6 +19,7 @@ entry_points={ "console_scripts": [ "libero_node = simulator.libero_node.main:main", + "robotwin_node = simulator.robotwin_node.main:main", ], }, ) diff --git a/lightx2v_ros/src/simulator/simulator/libero_node/env.py b/lightx2v_ros/src/simulator/simulator/libero_node/env.py new file mode 100644 index 000000000..49f7a384c --- /dev/null +++ b/lightx2v_ros/src/simulator/simulator/libero_node/env.py @@ -0,0 +1,95 @@ +"""LIBERO implementation of the generic `BaseSimEnv` contract.""" + +import math + +import numpy as np +from common.contract import EnvContract + +from ..sim.base_env import BaseSimEnv, Observation +from .observer import LiberoActionObserver, default_libero_root + + +def quat_to_axis_angle(quat): + quat = np.asarray(quat, dtype=np.float32).copy() + quat[3] = np.clip(quat[3], -1.0, 1.0) + den = np.sqrt(1.0 - quat[3] * quat[3]) + if math.isclose(float(den), 0.0): + return np.zeros(3, dtype=np.float32) + return ((quat[:3] * 2.0 * math.acos(float(quat[3]))) / den).astype(np.float32) + + +class LiberoEnv(BaseSimEnv): + # logical camera name -> LIBERO observation key + CAMERA_OBS_KEYS = { + "agentview": "agentview_image", + "wrist": "robot0_eye_in_hand_image", + "frontview": "frontview_image", + "galleryview": "galleryview_image", + } + + def __init__( + self, + contract: EnvContract, + *, + 
benchmark="libero_spatial", + task_id=0, + init_state_id=0, + image_size=224, + seed=0, + libero_root=None, + ): + super().__init__(contract) + self.observer = LiberoActionObserver( + benchmark_name=benchmark, + task_id=int(task_id), + init_state_id=int(init_state_id), + image_size=int(image_size), + seed=int(seed), + libero_root=libero_root, + ) + + @property + def task_description(self) -> str: + return self.observer.task_description + + def reset(self) -> Observation: + return self._observation() + + def step(self, action): + _, _, success, _ = self.observer.step(action) + return self._observation(), bool(success) + + def _observation(self) -> Observation: + obs = self.observer.obs + # LIBERO renders upside-down/mirrored relative to the policy expectation. + images = {cam: np.ascontiguousarray(obs[key][::-1, ::-1]) for cam, key in self.CAMERA_OBS_KEYS.items() if cam in self.contract.cameras} + return Observation(images=images, state=self._state(obs)) + + def _state(self, obs) -> np.ndarray: + pos = np.asarray(obs["robot0_eef_pos"], dtype=np.float32) + axis_angle = quat_to_axis_angle(np.asarray(obs["robot0_eef_quat"], dtype=np.float32)) + gripper = np.asarray(obs["robot0_gripper_qpos"], dtype=np.float32) + return np.concatenate([pos, axis_angle, gripper]).astype(np.float32) + + def close(self) -> None: + self.observer.close() + + +def build_libero_env(node) -> LiberoEnv: + contract = node.contract + node.declare_parameter("libero_root", str(default_libero_root())) + node.declare_parameter("benchmark", "libero_spatial") + node.declare_parameter("task_id", 0) + node.declare_parameter("init_state_id", 0) + node.declare_parameter("image_size", contract.image_size) + node.declare_parameter("seed", 0) + + return LiberoEnv( + contract, + benchmark=node.get_parameter("benchmark").value, + task_id=int(node.get_parameter("task_id").value), + init_state_id=int(node.get_parameter("init_state_id").value), + image_size=int(node.get_parameter("image_size").value), + 
seed=int(node.get_parameter("seed").value), + libero_root=node.get_parameter("libero_root").value, + ) diff --git a/lightx2v_ros/src/simulator/simulator/libero_node/main.py b/lightx2v_ros/src/simulator/simulator/libero_node/main.py index 175387a51..f265a4171 100644 --- a/lightx2v_ros/src/simulator/simulator/libero_node/main.py +++ b/lightx2v_ros/src/simulator/simulator/libero_node/main.py @@ -1,162 +1,12 @@ -import math +from common.contract import get_contract -import numpy as np -import rclpy -from rclpy.node import Node -from sensor_msgs.msg import Image -from std_msgs.msg import Bool, Float32MultiArray, Int32, String - -from .observer import LiberoActionObserver, default_libero_root - - -class LiberoNode(Node): - def __init__(self): - super().__init__("libero_node") - - self.declare_parameter("libero_root", str(default_libero_root())) - self.declare_parameter("benchmark", "libero_spatial") - self.declare_parameter("task_id", 0) - self.declare_parameter("init_state_id", 0) - self.declare_parameter("image_size", 224) - self.declare_parameter("seed", 0) - self.declare_parameter("action_topic", "/libero/action") - self.declare_parameter("state_topic", "/libero/state") - self.declare_parameter("agentview_topic", "/libero/agentview/image_raw") - self.declare_parameter("wrist_topic", "/libero/wrist/image_raw") - self.declare_parameter("frontview_topic", "/libero/frontview/image_raw") - self.declare_parameter("galleryview_topic", "/libero/galleryview/image_raw") - self.declare_parameter("success_topic", "/libero/success") - self.declare_parameter("observation_ready_topic", "/libero/observation_ready") - self.declare_parameter("task_topic", "/libero/task_description") - - self.state_pub = self.create_publisher(Float32MultiArray, self.get_parameter("state_topic").value, 10) - self.agentview_pub = self.create_publisher(Image, self.get_parameter("agentview_topic").value, 10) - self.wrist_pub = self.create_publisher(Image, self.get_parameter("wrist_topic").value, 10) - 
self.frontview_pub = self.create_publisher(Image, self.get_parameter("frontview_topic").value, 10) - self.galleryview_pub = self.create_publisher(Image, self.get_parameter("galleryview_topic").value, 10) - self.success_pub = self.create_publisher(Bool, self.get_parameter("success_topic").value, 10) - self.observation_ready_pub = self.create_publisher(Int32, self.get_parameter("observation_ready_topic").value, 10) - self.task_pub = self.create_publisher(String, self.get_parameter("task_topic").value, 10) - self.action_sub = self.create_subscription( - Float32MultiArray, - self.get_parameter("action_topic").value, - self.on_action, - 10, - ) - - self.observer = LiberoActionObserver( - benchmark_name=self.get_parameter("benchmark").value, - task_id=int(self.get_parameter("task_id").value), - init_state_id=int(self.get_parameter("init_state_id").value), - image_size=int(self.get_parameter("image_size").value), - seed=int(self.get_parameter("seed").value), - libero_root=self.get_parameter("libero_root").value, - ) - self.step_index = 0 - self.success = False - self.observation_timer = self.create_timer(1.0, self.republish_observation) - - self.get_logger().info(f"listening for actions on {self.get_parameter('action_topic').value}") - self.publish_observation() - - def republish_observation(self): - if self.success: - self.observation_timer.cancel() - return - self.publish_observation() - - def on_action(self, msg): - if self.success: - self.get_logger().warning("episode already succeeded; ignoring action") - return - - action = np.asarray(msg.data, dtype=np.float32) - if action.shape != (7,): - self.get_logger().error(f"expected action length 7, got {action.size}") - return - - self.get_logger().info(f"received action: {action.tolist()}") - _, _, success, _ = self.observer.step(action) - self.step_index += 1 - self.success = bool(success) - self.publish_observation() - - if self.success: - self.get_logger().info(f"episode succeeded at step {self.step_index}") - - def 
publish_observation(self): - obs = self.observer.obs - stamp = self.get_clock().now().to_msg() - - self.state_pub.publish(self.make_state_msg(obs)) - self.agentview_pub.publish(self.make_image_msg(obs["agentview_image"][::-1, ::-1], stamp, "agentview")) - self.wrist_pub.publish(self.make_image_msg(obs["robot0_eye_in_hand_image"][::-1, ::-1], stamp, "wrist")) - self.frontview_pub.publish(self.make_image_msg(obs["frontview_image"][::-1, ::-1], stamp, "frontview")) - self.galleryview_pub.publish(self.make_image_msg(obs["galleryview_image"][::-1, ::-1], stamp, "galleryview")) - - task_msg = String() - task_msg.data = self.observer.task_description - self.task_pub.publish(task_msg) - - success_msg = Bool() - success_msg.data = self.success - self.success_pub.publish(success_msg) - - observation_ready_msg = Int32() - observation_ready_msg.data = self.step_index - self.observation_ready_pub.publish(observation_ready_msg) - - def make_state_msg(self, obs): - pos = np.asarray(obs["robot0_eef_pos"], dtype=np.float32) - axis_angle = quat_to_axis_angle(np.asarray(obs["robot0_eef_quat"], dtype=np.float32)) - gripper = np.asarray(obs["robot0_gripper_qpos"], dtype=np.float32) - - msg = Float32MultiArray() - msg.data = np.concatenate([pos, axis_angle, gripper]).astype(np.float32).tolist() - return msg - - def make_image_msg(self, image, stamp, frame_id): - image = np.ascontiguousarray(image) - msg = Image() - msg.header.stamp = stamp - msg.header.frame_id = frame_id - msg.height = int(image.shape[0]) - msg.width = int(image.shape[1]) - msg.encoding = "rgb8" - msg.is_bigendian = False - msg.step = int(image.strides[0]) - msg.data = image.tobytes() - return msg - - def destroy_node(self): - if hasattr(self, "observer"): - self.observer.close() - super().destroy_node() - - -def quat_to_axis_angle(quat): - quat = np.asarray(quat, dtype=np.float32).copy() - quat[3] = np.clip(quat[3], -1.0, 1.0) - den = np.sqrt(1.0 - quat[3] * quat[3]) - if math.isclose(float(den), 0.0): - return 
np.zeros(3, dtype=np.float32) - return ((quat[:3] * 2.0 * math.acos(float(quat[3]))) / den).astype(np.float32) +from ..sim.node import run_simulator_node +from .env import build_libero_env def main(args=None): - rclpy.init(args=args) - node = LiberoNode() - try: - rclpy.spin(node) - except KeyboardInterrupt: - pass - except Exception: - if rclpy.ok(): - raise - finally: - node.destroy_node() - if rclpy.ok(): - rclpy.shutdown() + contract = get_contract("libero") + run_simulator_node(contract, build_libero_env, node_name="libero_node", args=args) if __name__ == "__main__": diff --git a/lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin b/lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin new file mode 160000 index 000000000..bf44be51c --- /dev/null +++ b/lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin @@ -0,0 +1 @@ +Subproject commit bf44be51cf5717a5595ce59447f2cf5263d2aa95 diff --git a/lightx2v_ros/src/simulator/simulator/robotwin_node/__init__.py b/lightx2v_ros/src/simulator/simulator/robotwin_node/__init__.py new file mode 100644 index 000000000..06858927d --- /dev/null +++ b/lightx2v_ros/src/simulator/simulator/robotwin_node/__init__.py @@ -0,0 +1 @@ +# robotwin_node package diff --git a/lightx2v_ros/src/simulator/simulator/robotwin_node/env.py b/lightx2v_ros/src/simulator/simulator/robotwin_node/env.py new file mode 100644 index 000000000..d6f46e9df --- /dev/null +++ b/lightx2v_ros/src/simulator/simulator/robotwin_node/env.py @@ -0,0 +1,247 @@ +"""RoboTwin (SAPIEN, dual-arm) implementation of the generic `BaseSimEnv`. + +This adapter wraps a vendored RoboTwin task so the generic `SimulatorNode` can +drive it exactly like LIBERO. RoboTwin's own evaluation orchestration lives in +``third_party/RoboTwin/script/eval_policy.py``; that script is RoboTwin-driven +(it owns the rollout loop and calls a policy plugin). 
Here we invert the control +flow for ROS: we build the same task `args`, run ``setup_demo`` once, and then +expose ``reset`` / ``step`` (``get_obs`` + ``take_action(qpos)``), publishing the +three RoboTwin cameras and the 14-dim joint-state vector over ROS. + +RoboTwin heavy dependencies (sapien, mplib, curobo, ...) and assets are imported +lazily inside ``reset``/construction, so the ROS package builds and imports even +on machines where the RoboTwin runtime is not installed yet. +""" + +import importlib +import os +import sys +from pathlib import Path + +import numpy as np +from common.contract import EnvContract + +from ..sim.base_env import BaseSimEnv, Observation + + +def default_robotwin_root() -> Path: + return Path(__file__).resolve().parent / "RoboTwin" + + +def _add_python_path(path) -> None: + path = str(Path(path)) + if path not in sys.path: + sys.path.insert(0, path) + + +class RoboTwinEnv(BaseSimEnv): + """Single-episode RoboTwin environment exposed through the BaseSimEnv contract.""" + + def __init__( + self, + contract: EnvContract, + *, + task_name: str = "click_alarmclock", + task_config: str = "demo_clean", + embodiment: str = "aloha-agilex", + instruction_type: str = "unseen", + instruction: str = "", + seed: int = 0, + robotwin_root=None, + ): + super().__init__(contract) + self.robotwin_root = Path(robotwin_root or default_robotwin_root()).expanduser() + self.task_name = str(task_name) + self.task_config = str(task_config) + self.embodiment = str(embodiment).strip() + self.instruction_type = str(instruction_type) + self._fixed_instruction = str(instruction).strip() + self.seed = int(seed) + self._episode_index = 0 + + self._task_description = "" + self._configs_path = self.robotwin_root / "task_config" + + self._prepare_runtime() + self.args = self._build_task_args() + self.env = self._instantiate_task() + self._setup_episode() + + # ------------------------------------------------------------------ setup + def _prepare_runtime(self) -> None: 
+ root = self.robotwin_root + if not (root / "envs").is_dir(): + raise FileNotFoundError(f"RoboTwin is not vendored at {root}. See robotwin_node/RoboTwin/README and run the RoboTwin install/asset-download steps.") + # RoboTwin source uses root-relative imports such as `from envs import ...` + # and `from generate_episode_instructions import *`. + _add_python_path(root) + _add_python_path(root / "description" / "utils") + + def _require_config(self, *parts) -> Path: + path = self._configs_path.joinpath(*parts) + if not path.exists(): + raise FileNotFoundError(f"Missing RoboTwin config: {path}. Populate `task_config/` (and `assets/`) from the official RoboTwin repo (see robotwin_node/RoboTwin/script).") + return path + + def _build_task_args(self) -> dict: + """Replicates third_party/RoboTwin/script/eval_policy.py:main() arg assembly.""" + import yaml + + with open(self._require_config(f"{self.task_config}.yml"), "r", encoding="utf-8") as f: + args = yaml.load(f.read(), Loader=yaml.FullLoader) + + args["task_name"] = self.task_name + args["task_config"] = self.task_config + + # Allow the launch parameter to pin the embodiment (e.g. "aloha-agilex"). 
+ if self.embodiment: + args["embodiment"] = [self.embodiment] + embodiment_type = args.get("embodiment") + if not isinstance(embodiment_type, list): + raise ValueError(f"task_config embodiment must be a list, got {embodiment_type!r}") + + with open(self._require_config("_embodiment_config.yml"), "r", encoding="utf-8") as f: + embodiment_types = yaml.load(f.read(), Loader=yaml.FullLoader) + + def embodiment_file(key): + robot_file = embodiment_types[key]["file_path"] + if robot_file is None: + raise ValueError(f"No embodiment file for '{key}'") + return os.path.join(str(self.robotwin_root), robot_file) if not os.path.isabs(robot_file) else robot_file + + def embodiment_config(robot_file): + with open(os.path.join(robot_file, "config.yml"), "r", encoding="utf-8") as f: + return yaml.load(f.read(), Loader=yaml.FullLoader) + + with open(self._require_config("_camera_config.yml"), "r", encoding="utf-8") as f: + camera_config = yaml.load(f.read(), Loader=yaml.FullLoader) + + head_camera_type = args["camera"]["head_camera_type"] + args["head_camera_h"] = camera_config[head_camera_type]["h"] + args["head_camera_w"] = camera_config[head_camera_type]["w"] + + if len(embodiment_type) == 1: + args["left_robot_file"] = embodiment_file(embodiment_type[0]) + args["right_robot_file"] = embodiment_file(embodiment_type[0]) + args["dual_arm_embodied"] = True + elif len(embodiment_type) == 3: + args["left_robot_file"] = embodiment_file(embodiment_type[0]) + args["right_robot_file"] = embodiment_file(embodiment_type[1]) + args["embodiment_dis"] = embodiment_type[2] + args["dual_arm_embodied"] = False + else: + raise ValueError("embodiment items should be 1 or 3") + + args["left_embodiment_config"] = embodiment_config(args["left_robot_file"]) + args["right_embodiment_config"] = embodiment_config(args["right_robot_file"]) + args["eval_mode"] = True + # Headless: never spawn the on-screen SAPIEN viewer. 
+ args["render_freq"] = 0 + return args + + def _instantiate_task(self): + module = importlib.import_module(f"envs.{self.task_name}") + task_cls = getattr(module, self.task_name) + return task_cls() + + def _setup_episode(self) -> None: + self.env.setup_demo(now_ep_num=self._episode_index, seed=self.seed, is_test=True, **self.args) + instruction = self._resolve_instruction() + self.env.set_instruction(instruction=instruction) + self._task_description = instruction + + def _resolve_instruction(self) -> str: + if self._fixed_instruction: + return self._fixed_instruction + try: + from generate_episode_instructions import generate_episode_descriptions + + episode_info_list = [self.env.info["info"]] + results = generate_episode_descriptions(self.task_name, episode_info_list, 1) + return str(np.random.choice(results[0][self.instruction_type])) + except Exception: + return f"Complete the {self.task_name.replace('_', ' ')} task." + + # ------------------------------------------------------------- contract API + @property + def task_description(self) -> str: + return self._task_description + + def reset(self) -> Observation: + return self._observation() + + @property + def max_steps(self): + # RoboTwin sets a per-task rollout cap (`step_lim`) during setup_demo. + return getattr(self.env, "step_lim", None) + + def new_episode(self, max_setup_retries: int = 25) -> Observation: + """Tear down the current episode and set up a fresh one (new layout). + + Advances the seed so each episode gets a different object placement, and + retries the next seeds if setup raises (e.g. RoboTwin ``UnStableError``). + """ + last_err = None + for _ in range(max(1, max_setup_retries)): + self.seed += 1 + try: + try: + self.env.close_env(clear_cache=True) + except Exception: + pass + self._episode_index += 1 + self._setup_episode() + return self._observation() + except Exception as exc: # e.g. 
UnStableError on an unlucky seed + last_err = exc + raise RuntimeError(f"RoboTwin failed to set up a new episode after {max_setup_retries} seeds; last error: {last_err}") + + def step(self, action): + action = np.asarray(action, dtype=np.float32).reshape(-1) + # RoboTwin policies output absolute joint targets (qpos), matching FastWAM. + self.env.take_action(action, action_type="qpos") + obs = self._observation() + success = bool(getattr(self.env, "eval_success", False)) or bool(self.env.check_success()) + return obs, success + + def _observation(self) -> Observation: + raw = self.env.get_obs() + cameras = raw["observation"] + images = {} + for cam in self.contract.cameras: + if cam not in cameras or "rgb" not in cameras[cam]: + raise KeyError(f"RoboTwin observation missing camera '{cam}' rgb; got {list(cameras)}") + rgb = np.asarray(cameras[cam]["rgb"])[..., :3] + images[cam] = np.ascontiguousarray(rgb.astype(np.uint8)) + state = np.asarray(raw["joint_action"]["vector"], dtype=np.float32).reshape(-1) + return Observation(images=images, state=state) + + def close(self) -> None: + env = getattr(self, "env", None) + if env is not None: + try: + env.close_env() + except Exception: + pass + + +def build_robotwin_env(node) -> RoboTwinEnv: + contract = node.contract + node.declare_parameter("robotwin_root", str(default_robotwin_root())) + node.declare_parameter("task_name", "click_alarmclock") + node.declare_parameter("task_config", "demo_clean") + node.declare_parameter("embodiment", "aloha-agilex") + node.declare_parameter("instruction_type", "unseen") + node.declare_parameter("instruction", "") + node.declare_parameter("seed", 0) + + return RoboTwinEnv( + contract, + task_name=node.get_parameter("task_name").value, + task_config=node.get_parameter("task_config").value, + embodiment=node.get_parameter("embodiment").value, + instruction_type=node.get_parameter("instruction_type").value, + instruction=node.get_parameter("instruction").value, + 
seed=int(node.get_parameter("seed").value), + robotwin_root=node.get_parameter("robotwin_root").value, + ) diff --git a/lightx2v_ros/src/simulator/simulator/robotwin_node/main.py b/lightx2v_ros/src/simulator/simulator/robotwin_node/main.py new file mode 100644 index 000000000..4251514da --- /dev/null +++ b/lightx2v_ros/src/simulator/simulator/robotwin_node/main.py @@ -0,0 +1,13 @@ +from common.contract import get_contract + +from ..sim.node import run_simulator_node +from .env import build_robotwin_env + + +def main(args=None): + contract = get_contract("robotwin") + run_simulator_node(contract, build_robotwin_env, node_name="robotwin_node", args=args) + + +if __name__ == "__main__": + main() diff --git a/lightx2v_ros/src/simulator/simulator/sim/__init__.py b/lightx2v_ros/src/simulator/simulator/sim/__init__.py new file mode 100644 index 000000000..c7b0fec45 --- /dev/null +++ b/lightx2v_ros/src/simulator/simulator/sim/__init__.py @@ -0,0 +1,10 @@ +from .base_env import BaseSimEnv, Observation +from .node import SimulatorNode, rgb_to_image_msg, run_simulator_node + +__all__ = [ + "BaseSimEnv", + "Observation", + "SimulatorNode", + "rgb_to_image_msg", + "run_simulator_node", +] diff --git a/lightx2v_ros/src/simulator/simulator/sim/base_env.py b/lightx2v_ros/src/simulator/simulator/sim/base_env.py new file mode 100644 index 000000000..e30abf760 --- /dev/null +++ b/lightx2v_ros/src/simulator/simulator/sim/base_env.py @@ -0,0 +1,70 @@ +"""Environment-agnostic simulator interface. + +Every concrete simulator (LIBERO, RoboTwin, ...) is exposed to the ROS layer +through the same small contract: it produces an `Observation` (a dict of +camera RGB frames plus a flat proprio-state vector) and consumes an action +vector. The generic `SimulatorNode` only ever talks to this interface, so +adding a new environment never requires touching the node/topic plumbing. 
+""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, Tuple + +import numpy as np +from common.contract import EnvContract + + +@dataclass +class Observation: + # logical camera name -> HxWx3 uint8 RGB image, already oriented for publishing + images: Dict[str, np.ndarray] + # flat proprio-state vector, shape (contract.state_dim,) + state: np.ndarray + + +class BaseSimEnv(ABC): + def __init__(self, contract: EnvContract): + self.contract = contract + + @property + @abstractmethod + def task_description(self) -> str: + """Current natural-language task/instruction.""" + + @abstractmethod + def reset(self) -> Observation: + """Reset the environment and return the first observation.""" + + @abstractmethod + def step(self, action: np.ndarray) -> Tuple[Observation, bool]: + """Apply one action, returning (observation, success).""" + + def new_episode(self) -> Observation: + """Start a fresh episode and return its first observation. + + Used by the node's continuous-eval loop. The default simply calls + ``reset()``; environments that need to re-randomize a scene (e.g. + RoboTwin) override this to tear down and rebuild the episode. 
+ """ + return self.reset() + + @property + def max_steps(self): + """Optional per-episode step cap hint (None = let the node decide).""" + return None + + def close(self) -> None: + return None + + def validate(self, obs: Observation) -> None: + missing = [cam for cam in self.contract.cameras if cam not in obs.images] + if missing: + raise ValueError(f"env '{self.contract.name}' observation is missing cameras: {missing}") + for cam in self.contract.cameras: + image = np.asarray(obs.images[cam]) + if image.ndim != 3 or image.shape[2] != 3: + raise ValueError(f"env '{self.contract.name}' camera '{cam}' must be HxWx3, got {image.shape}") + state = np.asarray(obs.state, dtype=np.float32).reshape(-1) + if state.size != self.contract.state_dim: + raise ValueError(f"env '{self.contract.name}' state dim {state.size} != contract {self.contract.state_dim}") diff --git a/lightx2v_ros/src/simulator/simulator/sim/node.py b/lightx2v_ros/src/simulator/simulator/sim/node.py new file mode 100644 index 000000000..844993366 --- /dev/null +++ b/lightx2v_ros/src/simulator/simulator/sim/node.py @@ -0,0 +1,218 @@ +"""Generic, contract-driven simulator ROS node. + +`SimulatorNode` is environment-agnostic: it derives every topic name and the +action/state dimensions from the `EnvContract` and drives any `BaseSimEnv` +implementation. An `env_factory(node) -> BaseSimEnv` callback lets each concrete +environment declare its own ROS parameters on the node before construction. 
+""" + +from typing import Callable + +import numpy as np +import rclpy +from common.contract import EnvContract +from rclpy.node import Node +from sensor_msgs.msg import Image +from std_msgs.msg import Bool, Float32MultiArray, Int32, String + +from .base_env import BaseSimEnv + + +def rgb_to_image_msg(image, stamp, frame_id): + image = np.ascontiguousarray(image) + msg = Image() + msg.header.stamp = stamp + msg.header.frame_id = frame_id + msg.height = int(image.shape[0]) + msg.width = int(image.shape[1]) + msg.encoding = "rgb8" + msg.is_bigendian = False + msg.step = int(image.strides[0]) + msg.data = image.tobytes() + return msg + + +class SimulatorNode(Node): + def __init__( + self, + contract: EnvContract, + env_factory: Callable[["SimulatorNode"], BaseSimEnv], + *, + node_name: str = "simulator_node", + ): + super().__init__(node_name) + self.contract = contract + + self.declare_parameter("republish_period", 1.0) + self.declare_parameter("idle_republish_period", 2.0) + # Continuous-eval loop: when true, the node automatically starts a fresh + # episode after each success (or step cap) instead of stopping. + self.declare_parameter("loop", False) + # Per-episode step cap; <=0 means "use the env hint (env.max_steps) or run + # until success". Failed episodes rely on this cap to eventually loop. + self.declare_parameter("max_episode_steps", 0) + self.republish_period = float(self.get_parameter("republish_period").value) + self.idle_republish_period = float(self.get_parameter("idle_republish_period").value) + self.loop = bool(self.get_parameter("loop").value) + + # env_factory may declare/read its own parameters via `self`. 
+ self.env = env_factory(self) + if self.env.contract is not contract: + raise ValueError("env_factory returned an env bound to a different contract") + + param_max_steps = int(self.get_parameter("max_episode_steps").value) + env_hint = getattr(self.env, "max_steps", None) + if param_max_steps > 0: + self.max_episode_steps = param_max_steps + elif env_hint: + self.max_episode_steps = int(env_hint) + else: + self.max_episode_steps = 0 + + self.state_pub = self.create_publisher(Float32MultiArray, contract.state_topic, 10) + self.image_pubs = {cam: self.create_publisher(Image, contract.camera_topic(cam), 10) for cam in contract.cameras} + self.success_pub = self.create_publisher(Bool, contract.success_topic, 10) + self.observation_ready_pub = self.create_publisher(Int32, contract.observation_ready_topic, 10) + self.task_pub = self.create_publisher(String, contract.task_topic, 10) + self.episode_pub = self.create_publisher(Int32, contract.episode_topic, 10) + self.action_sub = self.create_subscription(Float32MultiArray, contract.action_topic, self.on_action, 10) + + # `step_index` is a monotonic global observation counter (never reset), so the + # policy's "process only newer observations" logic keeps working across episodes. + self.step_index = 0 + # `episode_step` counts steps within the current episode (drives the step cap). 
+ self.episode_step = 0 + self.episode_index = 0 + self.success = False + self._slowed = False + + self.obs = self.env.reset() + self.env.validate(self.obs) + + self.get_logger().info( + f"[{contract.name}] cameras={list(contract.cameras)} " + f"action_dim={contract.action_dim} state_dim={contract.state_dim}; " + f"loop={self.loop} max_episode_steps={self.max_episode_steps or 'unlimited'}; " + f"listening for actions on {contract.action_topic}" + ) + self.timer = self.create_timer(self.republish_period, self.republish) + self.publish_observation() + + def republish(self): + self.publish_observation() + + def on_action(self, msg): + if self.success: + # In loop mode `success` is reset to False synchronously in + # `_start_next_episode`, so this only drops late actions that raced the + # episode boundary; in single-episode mode it stops the rollout. + if not self.loop: + self.get_logger().warning("episode already succeeded; ignoring action") + return + + action = np.asarray(msg.data, dtype=np.float32).reshape(-1) + if action.size != self.contract.action_dim: + self.get_logger().error(f"expected action length {self.contract.action_dim}, got {action.size}") + return + + self.obs, success = self.env.step(action) + self.step_index += 1 + self.episode_step += 1 + self.success = bool(success) + + capped = self.max_episode_steps > 0 and self.episode_step >= self.max_episode_steps + if self.loop and (self.success or capped): + outcome = "SUCCESS" if self.success else f"step cap ({self.max_episode_steps})" + self.get_logger().info(f"episode {self.episode_index} ended [{outcome}] after {self.episode_step} steps (global step {self.step_index}); starting next episode...") + # Emit the final frame (success flag reflects the outcome) before rebuilding. 
+ self.publish_observation() + self._start_next_episode() + return + + self.publish_observation() + + if self.success and not self.loop: + self.get_logger().info(f"episode succeeded at step {self.step_index}") + self._slow_down_timer() + + def _start_next_episode(self): + try: + self.obs = self.env.new_episode() + self.env.validate(self.obs) + except Exception as exc: + self.get_logger().error(f"failed to start next episode: {exc}") + raise + self.episode_index += 1 + self.episode_step = 0 + # Keep the global observation counter strictly increasing so the policy always + # sees the new episode's first frame as "newer" than anything it processed. + self.step_index += 1 + self.success = False + self.get_logger().info(f"episode {self.episode_index} started (global step {self.step_index}): {self.env.task_description!r}") + self.publish_observation() + + def _slow_down_timer(self): + # After success the env stops stepping. Keep republishing the final frame at a + # low rate so the web viewer keeps showing the last image instead of going blank. 
+ if self._slowed: + return + self._slowed = True + try: + self.timer.cancel() + except Exception: + pass + self.timer = self.create_timer(self.idle_republish_period, self.republish) + + def publish_observation(self): + stamp = self.get_clock().now().to_msg() + for cam, pub in self.image_pubs.items(): + pub.publish(rgb_to_image_msg(self.obs.images[cam], stamp, cam)) + + state_msg = Float32MultiArray() + state_msg.data = np.asarray(self.obs.state, dtype=np.float32).reshape(-1).tolist() + self.state_pub.publish(state_msg) + + task_msg = String() + task_msg.data = self.env.task_description or "" + self.task_pub.publish(task_msg) + + success_msg = Bool() + success_msg.data = self.success + self.success_pub.publish(success_msg) + + episode_msg = Int32() + episode_msg.data = self.episode_index + self.episode_pub.publish(episode_msg) + + ready_msg = Int32() + ready_msg.data = self.step_index + self.observation_ready_pub.publish(ready_msg) + + def destroy_node(self): + try: + if hasattr(self, "env"): + self.env.close() + finally: + super().destroy_node() + + +def run_simulator_node( + contract: EnvContract, + env_factory: Callable[["SimulatorNode"], BaseSimEnv], + *, + node_name: str, + args=None, +): + rclpy.init(args=args) + node = SimulatorNode(contract, env_factory, node_name=node_name) + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + except Exception: + if rclpy.ok(): + raise + finally: + node.destroy_node() + if rclpy.ok(): + rclpy.shutdown() diff --git a/lightx2v_ros/src/visualization/package.xml b/lightx2v_ros/src/visualization/package.xml index 3a8edb9a3..623605f72 100644 --- a/lightx2v_ros/src/visualization/package.xml +++ b/lightx2v_ros/src/visualization/package.xml @@ -13,6 +13,7 @@ sensor_msgs std_msgs python3-opencv + common ament_python diff --git a/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/main.py b/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/main.py index ac5dd34fc..f1ac047b5 100644 --- 
a/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/main.py +++ b/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/main.py @@ -4,18 +4,17 @@ import cv2 import numpy as np import rclpy +from common.contract import get_contract from rclpy.node import Node from sensor_msgs.msg import Image from std_msgs.msg import String -from .page import INDEX_HTML +from .page import render_index -AGENTVIEW_TOPIC = "/libero/agentview/image_raw" -WRIST_TOPIC = "/libero/wrist/image_raw" -FRONTVIEW_TOPIC = "/libero/frontview/image_raw" -GALLERYVIEW_TOPIC = "/libero/galleryview/image_raw" -TASK_TOPIC = "/libero/task_description" -CAMERAS = ("agentview", "wrist", "frontview", "galleryview") +# How long a stalled MJPEG stream waits before re-sending the last frame. This, +# together with the simulator re-publishing the final frame after success, keeps +# the page from going blank when no new frames arrive. +STREAM_KEEPALIVE_S = 2.0 class ImageHttpServer(ThreadingHTTPServer): @@ -23,9 +22,9 @@ class ImageHttpServer(ThreadingHTTPServer): class FrameStore: - def __init__(self): + def __init__(self, cameras): self.condition = Condition() - self.frames = {name: (0, None) for name in CAMERAS} + self.frames = {name: (0, None) for name in cameras} self.task = "" def update(self, name, jpeg): @@ -34,9 +33,9 @@ def update(self, name, jpeg): self.frames[name] = (seq + 1, jpeg) self.condition.notify_all() - def wait_next(self, name, last_seq): + def wait_next(self, name, last_seq, timeout=STREAM_KEEPALIVE_S): with self.condition: - self.condition.wait_for(lambda: self.frames[name][0] != last_seq) + self.condition.wait_for(lambda: self.frames[name][0] != last_seq, timeout=timeout) return self.frames[name] def update_task(self, task): @@ -52,44 +51,48 @@ class ImageWebViewerNode(Node): def __init__(self): super().__init__("image_web_viewer") + self.declare_parameter("env", "libero") self.declare_parameter("host", "127.0.0.1") self.declare_parameter("port", 8080) - 
self.declare_parameter("agentview_topic", AGENTVIEW_TOPIC) - self.declare_parameter("wrist_topic", WRIST_TOPIC) - self.declare_parameter("frontview_topic", FRONTVIEW_TOPIC) - self.declare_parameter("galleryview_topic", GALLERYVIEW_TOPIC) - self.declare_parameter("task_topic", TASK_TOPIC) self.declare_parameter("jpeg_quality", 85) + self.declare_parameter("cameras", []) + self.declare_parameter("namespace", "") + self.declare_parameter("task_topic", "") + + env = str(self.get_parameter("env").value).strip().lower() + contract = get_contract(env) + self.contract = contract + + cameras_param = list(self.get_parameter("cameras").value or []) + self.cameras = cameras_param if cameras_param else list(contract.cameras) + namespace = str(self.get_parameter("namespace").value).strip() or contract.namespace + task_topic = str(self.get_parameter("task_topic").value).strip() or contract.task_topic self.jpeg_quality = int(self.get_parameter("jpeg_quality").value) - self.frame_store = FrameStore() + self.frame_store = FrameStore(self.cameras) self.http_server = None self.http_thread = None - for name in CAMERAS: + for name in self.cameras: + topic = f"{namespace}/{name}/image_raw" self.create_subscription( Image, - self.get_parameter(f"{name}_topic").value, + topic, lambda msg, camera_name=name: self.on_image(camera_name, msg), 10, ) - self.create_subscription( - String, - self.get_parameter("task_topic").value, - self.on_task, - 10, - ) + self.create_subscription(String, task_topic, self.on_task, 10) self.start_http_server() def start_http_server(self): host = str(self.get_parameter("host").value) port = int(self.get_parameter("port").value) - handler = make_handler(self.frame_store) + handler = make_handler(self.frame_store, self.cameras, self.contract.name) self.http_server = ImageHttpServer((host, port), handler) self.http_thread = Thread(target=self.http_server.serve_forever, daemon=True) self.http_thread.start() - self.get_logger().info(f"image web viewer listening on 
http://{host}:{port}") + self.get_logger().info(f"[{self.contract.name}] image web viewer on http://{host}:{port} cameras={self.cameras}") def on_image(self, name, msg): try: @@ -121,7 +124,10 @@ def destroy_node(self): super().destroy_node() -def make_handler(frame_store): +def make_handler(frame_store, cameras, title): + camera_set = set(cameras) + index_html = render_index(cameras, title=f"LightX2V ROS · {title}") + class ImageWebViewerHandler(BaseHTTPRequestHandler): def do_GET(self): if self.path in {"/", "/index.html"}: @@ -130,22 +136,15 @@ def do_GET(self): if self.path == "/task.txt": self.send_task() return - if self.path == "/agentview.mjpg": - self.send_stream("agentview") - return - if self.path == "/wrist.mjpg": - self.send_stream("wrist") - return - if self.path == "/frontview.mjpg": - self.send_stream("frontview") - return - if self.path == "/galleryview.mjpg": - self.send_stream("galleryview") - return + if self.path.endswith(".mjpg"): + name = self.path[1 : -len(".mjpg")] + if name in camera_set: + self.send_stream(name) + return self.send_error(404) def send_index(self): - body = INDEX_HTML.encode("utf-8") + body = index_html.encode("utf-8") self.send_response(200) self.send_header("Content-Type", "text/html; charset=utf-8") self.send_header("Content-Length", str(len(body))) diff --git a/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/page.py b/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/page.py index f1f297a87..414e87f08 100644 --- a/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/page.py +++ b/lightx2v_ros/src/visualization/visualization/image_web_viewer_node/page.py @@ -1,10 +1,12 @@ -INDEX_HTML = """ - - - - - LightX2V ROS - + .task span { color: var(--muted); } +""" + +_VIEW_TEMPLATE = """
+

{label}

+ {label} +
""" + + +def render_index(cameras, title="LightX2V ROS"): + columns = max(1, min(len(cameras), 4)) + style = _STYLE.replace("__COLUMNS__", str(columns)) + views = "\n".join(_VIEW_TEMPLATE.format(name=html.escape(str(cam)), label=html.escape(str(cam))) for cam in cameras) + safe_title = html.escape(str(title)) + return f""" + + + + + {safe_title} +
-

LightX2V ROS

+

{safe_title}

live streams
-
-

agentview

- agentview -
-
-

wrist

- wrist -
-
-

frontview

- frontview -
-
-

galleryview

- galleryview -
+{views}
waiting for task description