Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[submodule "lightx2v_ros/src/simulator/simulator/libero_node/LIBERO"]
path = lightx2v_ros/src/simulator/simulator/libero_node/LIBERO
url = https://github.com/Lifelong-Robot-Learning/LIBERO.git
[submodule "lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin"]
path = lightx2v_ros/src/simulator/simulator/robotwin_node/RoboTwin
url = https://github.com/robotwin-Platform/robotwin.git
branch = main
3 changes: 3 additions & 0 deletions configs/fastwam/libero_i2va.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
"action_dim_hidden": 1024,
"action_dim": 7,
"robot_state_dim": 8,
"policy_profile": "libero",
"normalize_mode": "min-max",
"binarize_gripper": true,
"gripper_postprocess": true,
"default_prompt": "A video recorded from a robot's point of view executing the following instruction: {task_prompt}"
}
18 changes: 18 additions & 0 deletions configs/fastwam/robotwin_i2va.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"adapter_model_path": "/data/nvme7/yongyang/fastwam_release/robotwin_uncond_3cam_384.pt",
"dataset_stats_path": "/data/nvme7/yongyang/fastwam_release/robotwin_uncond_3cam_384_dataset_stats.json",
"camera_size": 384,
"action_chunk_size": 32,
"actions_per_plan": 8,
"num_steps_wait": 0,
"action_infer_steps": 20,
"action_sample_shift": 5.0,
"action_dim_hidden": 1024,
"action_dim": 14,
"robot_state_dim": 14,
"policy_profile": "robotwin",
"normalize_mode": "z-score",
"binarize_gripper": false,
"gripper_postprocess": false,
"default_prompt": "A video recorded from a robot's point of view executing the following instruction: {task_prompt}"
}
123 changes: 96 additions & 27 deletions lightx2v/models/runners/wan/fastwam_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,47 @@ def resize_rgb(image, width, height):
return np.asarray(pil.resize((width, height), resample=Image.BILINEAR), dtype=np.uint8)


class MinMaxNormalizer:
def _canonical_norm_mode(mode):
compact = str(mode).strip().lower().replace("-", "").replace("_", "").replace("/", "")
if compact == "minmax":
return "min/max"
if compact == "zscore":
return "z-score"
if compact == "q01q99":
return "q01/q99"
raise ValueError(f"unsupported normalize_mode: {mode!r}")


class LinearNormalizer:
"""Affine normalizer matching FastWAM's SingleFieldLinearNormalizer (global stats).

Modes used by the released checkpoints:
- "min/max" (LIBERO): maps [global_min, global_max] -> [-1, 1]
- "q01/q99": maps [global_q01, global_q99] -> [-1, 1]
- "z-score" (RoboTwin): (x - global_mean) / global_std
"""

std_reg = 1e-8
range_tol = 1e-4

def __init__(self, stats):
min_v = torch.as_tensor(stats["global_min"], dtype=torch.float32)
max_v = torch.as_tensor(stats["global_max"], dtype=torch.float32)
input_range = max_v - min_v
ignore = input_range < self.range_tol
input_range[ignore] = 2.0
self.scale = 2.0 / input_range
self.offset = -1.0 - self.scale * min_v
self.offset[ignore] = -min_v[ignore]
output_min = -1.0
output_max = 1.0

def __init__(self, stats, mode="min/max"):
mode = _canonical_norm_mode(mode)
g = {key[len("global_") :]: torch.as_tensor(value, dtype=torch.float32) for key, value in stats.items() if key.startswith("global_")}
if mode == "z-score":
mean, std = g["mean"], g["std"]
self.scale = 1.0 / (std + self.std_reg)
self.offset = -mean / (std + self.std_reg)
else:
low, high = (g["min"].clone(), g["max"].clone()) if mode == "min/max" else (g["q01"].clone(), g["q99"].clone())
input_range = high - low
ignore = input_range < self.range_tol
input_range[ignore] = self.output_max - self.output_min
self.scale = (self.output_max - self.output_min) / input_range
self.offset = self.output_min - self.scale * low
self.offset[ignore] = (self.output_max + self.output_min) / 2 - low[ignore]
self.mode = mode

def forward(self, x):
x = torch.as_tensor(x, dtype=torch.float32)
Expand All @@ -47,6 +76,13 @@ def backward(self, x):
return (x - self.offset) / self.scale


class MinMaxNormalizer(LinearNormalizer):
"""Backwards-compatible alias: LIBERO used a global min/max normalizer."""

def __init__(self, stats):
super().__init__(stats, mode="min/max")


class FastWAMPolicy:
def __init__(
self,
Expand All @@ -65,6 +101,9 @@ def __init__(
binarize_gripper=True,
default_prompt=None,
camera_size=224,
policy_profile="libero",
normalize_mode="min/max",
gripper_postprocess=None,
config=None,
):
self.device = torch.device(device)
Expand Down Expand Up @@ -103,6 +142,10 @@ def __init__(
self.seed = None if seed is None or int(seed) < 0 else int(seed)
self.binarize_gripper = bool(binarize_gripper)
self.default_prompt = str(default_prompt)
self.policy_profile = str(policy_profile).strip().lower()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

目前代码中没有对 policy_profile 的有效性进行校验。如果用户传入了不支持的 profile(例如拼写错误),代码会默认回退到 "libero" 的处理逻辑,这可能会导致难以排查的维度不匹配或 KeyError 错误。

建议在初始化时对 policy_profile 进行校验,仅允许 "libero""robotwin"

Suggested change
self.policy_profile = str(policy_profile).strip().lower()
self.policy_profile = str(policy_profile).strip().lower()
if self.policy_profile not in {"libero", "robotwin"}:
raise ValueError(f"Unsupported policy_profile: {self.policy_profile}. Expected 'libero' or 'robotwin'.")

self.normalize_mode = _canonical_norm_mode(normalize_mode)
# LIBERO post-processes the single gripper channel; RoboTwin (dual-arm qpos) does not.
self.gripper_postprocess = (self.policy_profile == "libero") if gripper_postprocess is None else bool(gripper_postprocess)
self.config = config
self._check_model_config()
self.pending_actions = deque()
Expand Down Expand Up @@ -135,6 +178,9 @@ def from_config(cls, config):
binarize_gripper=config.get("binarize_gripper", True),
default_prompt=config.get("default_prompt"),
camera_size=config.get("camera_size", 224),
policy_profile=config.get("policy_profile", "libero"),
normalize_mode=config.get("normalize_mode", "min/max"),
gripper_postprocess=config.get("gripper_postprocess"),
config=config,
)

Expand All @@ -147,8 +193,8 @@ def _load_normalizers(self):
with open(self.dataset_stats_path, "r", encoding="utf-8") as f:
stats = json.load(f)
return (
MinMaxNormalizer(stats["state"]["default"]),
MinMaxNormalizer(stats["action"]["default"]),
LinearNormalizer(stats["state"]["default"], self.normalize_mode),
LinearNormalizer(stats["action"]["default"], self.normalize_mode),
)

def _load_text_encoder(self):
Expand Down Expand Up @@ -188,17 +234,17 @@ def _find_model_dir(self, dirname):
def reset(self):
self.pending_actions.clear()

def next_action(self, agentview_rgb, wrist_rgb, state, task_description):
def next_action(self, images, state, task_description):
if not self.pending_actions:
action_chunk = self.predict_action_chunk(agentview_rgb, wrist_rgb, state, task_description)
action_chunk = self.predict_action_chunk(images, state, task_description)
for action in action_chunk[: self.actions_per_plan]:
self.pending_actions.append(np.asarray(action, dtype=np.float32))
if not self.pending_actions:
raise RuntimeError("FastWAM produced an empty action chunk")
return self.pending_actions.popleft()

def predict_action_chunk(self, agentview_rgb, wrist_rgb, state, task_description, seed=None):
image = self.build_image_tensor(agentview_rgb, wrist_rgb)
def predict_action_chunk(self, images, state, task_description, seed=None):
image = self.build_image_tensor(images)
first_frame_latents = self.encode_image_latents(image)
context, context_mask = self.encode_prompt(self.default_prompt.format(task_prompt=task_description))
robot_state = self.state_normalizer.forward(np.asarray(state, dtype=np.float32))
Expand All @@ -216,10 +262,12 @@ def predict_action_chunk(self, agentview_rgb, wrist_rgb, state, task_description
seed=self.seed if seed is None else seed,
)
action = self.action_normalizer.backward(action).numpy()
action[..., -1] = action[..., -1] * 2 - 1
action[..., -1] = action[..., -1] * -1.0
if self.binarize_gripper:
action[..., -1] = np.sign(action[..., -1])
if self.gripper_postprocess:
# LIBERO single-gripper channel: map [0,1] -> [-1,1], flip sign, optional binarize.
action[..., -1] = action[..., -1] * 2 - 1
action[..., -1] = action[..., -1] * -1.0
if self.binarize_gripper:
action[..., -1] = np.sign(action[..., -1])
return action.astype(np.float32)

def _run_action_denoising(self, inputs, action_shape, action_infer_steps, seed):
Expand All @@ -239,13 +287,35 @@ def _run_action_denoising(self, inputs, action_shape, action_infer_steps, seed):
scheduler.clear()
return action

def build_image_tensor(self, agentview_rgb, wrist_rgb):
primary = resize_rgb(agentview_rgb, self.camera_size, self.camera_size)
wrist = resize_rgb(wrist_rgb, self.camera_size, self.camera_size)
rgb = np.concatenate([primary, wrist], axis=1)
def build_image_tensor(self, images):
"""Compose the policy image from a dict of {camera_name: HxWx3 RGB}.

Layouts mirror the released FastWAM checkpoints:
- LIBERO : [agentview | wrist] side-by-side, each camera_size x camera_size.
- RoboTwin : head (320x256) on top, [left | right] (each 160x128) below,
i.e. final [384, 320, 3]; matches deploy_policy.py.
"""
rgb = self._compose_rgb(images)
tensor = torch.from_numpy(rgb).permute(2, 0, 1).to(device=self.device, dtype=GET_DTYPE())
return tensor * (2.0 / 255.0) - 1.0

def _compose_rgb(self, images):
def _get(name):
if name not in images or images[name] is None:
raise KeyError(f"FastWAM profile '{self.policy_profile}' requires camera '{name}'")
return images[name]

if self.policy_profile == "robotwin":
head = resize_rgb(_get("head_camera"), 320, 256)
left = resize_rgb(_get("left_camera"), 160, 128)
right = resize_rgb(_get("right_camera"), 160, 128)
bottom = np.concatenate([left, right], axis=1)
return np.concatenate([head, bottom], axis=0)

primary = resize_rgb(_get("agentview"), self.camera_size, self.camera_size)
wrist = resize_rgb(_get("wrist"), self.camera_size, self.camera_size)
return np.concatenate([primary, wrist], axis=1)

def encode_image_latents(self, image):
image = image.unsqueeze(1)
latents = self.vae.encode(image.unsqueeze(0))
Expand Down Expand Up @@ -352,8 +422,7 @@ def run_pipeline(self, input_info):
agentview, wrist = self._load_image_pair()
state = self._load_state()
actions = self.policy.predict_action_chunk(
agentview_rgb=agentview,
wrist_rgb=wrist,
images={"agentview": agentview, "wrist": wrist},
state=state,
task_description=self.input_info.prompt,
seed=self.input_info.seed,
Expand Down
15 changes: 15 additions & 0 deletions lightx2v_ros/src/common/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .contract import (
CONTRACTS,
LIBERO_CONTRACT,
ROBOTWIN_CONTRACT,
EnvContract,
get_contract,
)

__all__ = [
"CONTRACTS",
"EnvContract",
"LIBERO_CONTRACT",
"ROBOTWIN_CONTRACT",
"get_contract",
]
115 changes: 115 additions & 0 deletions lightx2v_ros/src/common/common/contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Single source of truth for the simulator/inference/visualization contract.

This module is intentionally dependency-free (pure Python, no ROS/torch imports)
so it can be imported by every ROS node as well as plain scripts. It defines, per
simulation environment, the ROS topic namespace, the set of cameras, the
action/state dimensions and the inference profile. All three nodes derive their
topic names and tensor dimensions from the same `EnvContract`, which removes the
LIBERO-specific hard-coding that used to live in each node.
"""

from dataclasses import dataclass, field
from typing import Dict, Tuple


@dataclass(frozen=True)
class EnvContract:
# Logical environment id, e.g. "libero" or "robotwin".
name: str
# ROS topic prefix, e.g. "/libero".
namespace: str
# All cameras the simulator publishes (logical names, also used by the viewer).
cameras: Tuple[str, ...]
# Subset of `cameras` fed to the policy, in the exact order the policy expects.
policy_input_cameras: Tuple[str, ...]
# Robot action / proprio-state vector dimensions exchanged over ROS.
action_dim: int
state_dim: int
# Square render/publish size hint (used by simulators that render on demand).
image_size: int
# Inference assembly profile understood by FastWAMPolicy ("libero"|"robotwin").
policy_profile: str
# Action/state normalization mode ("minmax"|"zscore").
normalize_mode: str
# Whether to apply the LIBERO single-gripper sign/binarize post-processing.
gripper_postprocess: bool
# Optional human-readable description.
description: str = field(default="")

# ----- derived topic helpers (kept in one place so nodes never disagree) -----
@property
def action_topic(self) -> str:
return f"{self.namespace}/action"

@property
def state_topic(self) -> str:
return f"{self.namespace}/state"

@property
def success_topic(self) -> str:
return f"{self.namespace}/success"

@property
def observation_ready_topic(self) -> str:
return f"{self.namespace}/observation_ready"

@property
def task_topic(self) -> str:
return f"{self.namespace}/task_description"

@property
def episode_topic(self) -> str:
# Monotonic episode counter; lets the policy reset its per-episode state when
# the simulator loops into a fresh episode.
return f"{self.namespace}/episode"

def camera_topic(self, camera: str) -> str:
if camera not in self.cameras:
raise KeyError(f"camera '{camera}' is not part of contract '{self.name}': {self.cameras}")
return f"{self.namespace}/{camera}/image_raw"

def camera_topics(self) -> Dict[str, str]:
return {camera: self.camera_topic(camera) for camera in self.cameras}


LIBERO_CONTRACT = EnvContract(
name="libero",
namespace="/libero",
cameras=("agentview", "wrist", "frontview", "galleryview"),
policy_input_cameras=("agentview", "wrist"),
action_dim=7,
state_dim=8,
image_size=224,
policy_profile="libero",
normalize_mode="minmax",
gripper_postprocess=True,
description="LIBERO (robosuite/mujoco) single-arm tabletop manipulation.",
)


ROBOTWIN_CONTRACT = EnvContract(
name="robotwin",
namespace="/robotwin",
cameras=("head_camera", "left_camera", "right_camera"),
policy_input_cameras=("head_camera", "left_camera", "right_camera"),
action_dim=14,
state_dim=14,
image_size=384,
policy_profile="robotwin",
normalize_mode="zscore",
gripper_postprocess=False,
description="RoboTwin 2.0 (SAPIEN) dual-arm manipulation.",
)


CONTRACTS: Dict[str, EnvContract] = {
LIBERO_CONTRACT.name: LIBERO_CONTRACT,
ROBOTWIN_CONTRACT.name: ROBOTWIN_CONTRACT,
}


def get_contract(name: str) -> EnvContract:
key = str(name).strip().lower()
if key not in CONTRACTS:
raise KeyError(f"unknown environment '{name}'. Available: {sorted(CONTRACTS)}")
return CONTRACTS[key]
15 changes: 15 additions & 0 deletions lightx2v_ros/src/common/package.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>common</name>
<version>0.0.1</version>
<description>Shared environment contract for LightX2V ROS nodes.</description>
<maintainer email="user@example.com">user</maintainer>
<license>Apache-2.0</license>

<buildtool_depend>ament_python</buildtool_depend>

<export>
<build_type>ament_python</build_type>
</export>
</package>
Empty file.
4 changes: 4 additions & 0 deletions lightx2v_ros/src/common/setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/common
[install]
install_scripts=$base/lib/common
Loading
Loading