From 20705c60c721156f005ddefb38a6cec61724edd3 Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 12:42:34 +0800 Subject: [PATCH 01/25] Add the runner --- lightx2v/infer.py | 2 + .../runners/wan/wan_matrix_game3_runner.py | 932 ++++++++++++++++++ lightx2v/pipeline.py | 8 +- 3 files changed, 941 insertions(+), 1 deletion(-) create mode 100644 lightx2v/models/runners/wan/wan_matrix_game3_runner.py diff --git a/lightx2v/infer.py b/lightx2v/infer.py index 7a295c706..015c3a354 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -19,6 +19,7 @@ from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401 from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner # noqa: F401 +from lightx2v.models.runners.wan.wan_matrix_game3_runner import WanMatrixGame3Runner # noqa: F401 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401 from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner # noqa: F401 from lightx2v.models.runners.wan.wan_vace_runner import Wan22MoeVaceRunner, WanVaceRunner # noqa: F401 @@ -60,6 +61,7 @@ def main(): "wan2.2_moe", "lingbot_world", "wan2.2", + "wan2.2_matrix_game3", "wan2.2_moe_audio", "wan2.2_audio", "wan2.2_moe_distill", diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py new file mode 100644 index 000000000..094d868ac --- /dev/null +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -0,0 +1,932 @@ +import importlib.util +import json +import sys +import types +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import torch +import torch.distributed as dist +import torchvision.transforms.functional as TF +from einops import rearrange +from PIL import Image +from loguru import logger + +from 
lightx2v.models.runners.wan.wan_runner import Wan22DenseRunner, build_wan_model_with_lora +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.envs import GET_DTYPE, torch_device_module +from lightx2v.utils.profiler import GET_RECORDER_MODE, ProfilingContext4DebugL1, ProfilingContext4DebugL2 +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.utils import best_output_size +from lightx2v_platform.base.global_var import AI_DEVICE + + +DEFAULT_MATRIX_GAME3_OFFICIAL_ROOT = Path("/home/michael/Project/LightX2V/Matrix-Game-3/Matrix-Game-3") +DEFAULT_MATRIX_GAME3_BASE_CONFIG = Path("/home/michael/Project/LightX2V/Matrix-Game-3.0/base_model/config.json") +DEFAULT_MATRIX_GAME3_DISTILLED_CONFIG = Path("/home/michael/Project/LightX2V/Matrix-Game-3.0/base_distilled_model/config.json") +_MATRIX_GAME3_OFFICIAL_PACKAGE = "_lightx2v_matrix_game3_official" + + +@dataclass +class MatrixGame3SegmentState: + segment_idx: int + first_clip: bool + current_start_frame_idx: int + current_end_frame_idx: int + frame_count: int + fixed_latent_frames: int + latent_shape: list[int] + decode_trim_frames: int + append_latent_start: int + keyboard_cond: torch.Tensor + mouse_cond: torch.Tensor + vae_encoder_out: torch.Tensor + dit_cond_dict: dict[str, Any] + + +def _load_module_from_path(module_name: str, file_path: Path): + if module_name in sys.modules: + return sys.modules[module_name] + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + raise ImportError(f"failed to load module {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _ensure_namespace_package(package_name: str, package_path: Path): + if package_name in sys.modules: + return sys.modules[package_name] + module = types.ModuleType(package_name) + module.__path__ = [str(package_path)] + 
sys.modules[package_name] = module + return module + + +@RUNNER_REGISTER("wan2.2_matrix_game3") +class WanMatrixGame3Runner(Wan22DenseRunner): + """Runner-only Matrix-Game-3 adapter on top of the existing Wan2.2 lifecycle. + + Official provenance: + - CLI / mode semantics: Matrix-Game-3/generate.py + - Segment lengths / condition assembly: pipeline/inference_pipeline.py + - Interactive action refresh: pipeline/inference_interactive_pipeline.py + - Keyboard / mouse dimensions: utils/conditions.py + - Pose / plucker helpers: utils/cam_utils.py and utils/utils.py + - Structural config truth: Matrix-Game-3.0/*/config.json + """ + + def __init__(self, config): + with config.temporarily_unlocked(): + original_model_cls = str(config.get("model_cls", "wan2.2_matrix_game3")) + config["runner_model_cls"] = original_model_cls + config["model_cls"] = "wan2.2" + config["mode"] = "matrix_game3" + config["use_image_encoder"] = False + config["use_base_model"] = bool(config.get("use_base_model", False)) + if "sub_model_folder" not in config: + config["sub_model_folder"] = "base_model" if config["use_base_model"] else "base_distilled_model" + config["num_channels_latents"] = int(config.get("num_channels_latents", 48)) + config["vae_stride"] = tuple(config.get("vae_stride", (4, 16, 16))) + config["patch_size"] = tuple(config.get("patch_size", (1, 2, 2))) + super().__init__(config) + + self.matrix_game3_model_cls = original_model_cls + self.first_clip_frame = 57 + self.clip_frame = 56 + self.incremental_segment_frames = 40 + self.past_frame = 16 + self.conditioning_latent_frames = 4 + self.mouse_dim_in = 2 + self.keyboard_dim_in = 6 + self._segment_states: dict[int, MatrixGame3SegmentState] = {} + self._official_modules: Optional[dict[str, Any]] = None + self._mg3_lat_h: Optional[int] = None + self._mg3_lat_w: Optional[int] = None + self._mg3_target_h: Optional[int] = None + self._mg3_target_w: Optional[int] = None + self._mg3_base_intrinsics: Optional[torch.Tensor] = None + 
self._mg3_intrinsics_all: Optional[torch.Tensor] = None + self._mg3_keyboard_all: Optional[torch.Tensor] = None + self._mg3_mouse_all: Optional[torch.Tensor] = None + self._mg3_extrinsics_all: Optional[torch.Tensor] = None + self._mg3_num_iterations: int = 1 + self._mg3_expected_total_frames: int = self.first_clip_frame + self._mg3_interactive = bool(self.config.get("interactive", False)) + self._mg3_last_pose = np.zeros(5, dtype=np.float32) + self._mg3_current_segment_state: Optional[MatrixGame3SegmentState] = None + self._mg3_current_segment_full_latents: Optional[torch.Tensor] = None + self._mg3_generated_latent_history: list[torch.Tensor] = [] + self._mg3_tail_latents: Optional[torch.Tensor] = None + self._mg3_noise_generator: Optional[torch.Generator] = None + self._load_matrix_game3_model_config() + + def set_inputs(self, inputs): + super().set_inputs(inputs) + if "action_path" in self.input_info.__dataclass_fields__: + self.input_info.action_path = inputs.get("action_path", inputs.get("pose", "")) + if "pose" in self.input_info.__dataclass_fields__: + self.input_info.pose = inputs.get("pose", inputs.get("action_path", "")) + + def load_transformer(self): + from lightx2v.models.networks.wan.matrix_game3_model import WanMtxg3Model + + model_kwargs = { + "model_path": self.config["model_path"], + "config": self.config, + "device": self.init_device, + } + lora_configs = self.config.get("lora_configs") + if not lora_configs: + return WanMtxg3Model(**model_kwargs) + return build_wan_model_with_lora(WanMtxg3Model, self.config, model_kwargs, lora_configs, model_type="wan2.2") + + def _load_matrix_game3_model_config(self): + config_path = Path(self.config["model_path"]) / self.config["sub_model_folder"] / "config.json" + if not config_path.exists(): + config_path = DEFAULT_MATRIX_GAME3_BASE_CONFIG if self.config["use_base_model"] else DEFAULT_MATRIX_GAME3_DISTILLED_CONFIG + if not config_path.exists(): + logger.warning("matrix-game-3 config.json not found at {}", 
config_path) + return + + with config_path.open("r") as f: + model_config = json.load(f) + + with self.config.temporarily_unlocked(): + self.config.update(model_config) + self.config["num_channels_latents"] = int(model_config.get("in_dim", self.config.get("num_channels_latents", 48))) + self.config["vae_stride"] = tuple(self.config.get("vae_stride", (4, 16, 16))) + self.config["patch_size"] = tuple(model_config.get("patch_size", self.config.get("patch_size", (1, 2, 2)))) + + action_config = self.config.get("action_config", {}) + self.keyboard_dim_in = int(action_config.get("keyboard_dim_in", 6)) + self.mouse_dim_in = int(action_config.get("mouse_dim_in", 2)) + + def _get_official_modules(self) -> dict[str, Any]: + if self._official_modules is not None: + return self._official_modules + + official_root = Path(self.config.get("matrix_game3_official_root", DEFAULT_MATRIX_GAME3_OFFICIAL_ROOT)) + if not official_root.exists(): + raise FileNotFoundError(f"Matrix-Game-3 official root not found: {official_root}") + + _ensure_namespace_package(_MATRIX_GAME3_OFFICIAL_PACKAGE, official_root) + utils_pkg = f"{_MATRIX_GAME3_OFFICIAL_PACKAGE}.utils" + _ensure_namespace_package(utils_pkg, official_root / "utils") + + modules = { + "conditions": _load_module_from_path(f"{utils_pkg}.conditions", official_root / "utils" / "conditions.py"), + "cam_utils": _load_module_from_path(f"{utils_pkg}.cam_utils", official_root / "utils" / "cam_utils.py"), + "transform": _load_module_from_path(f"{utils_pkg}.transform", official_root / "utils" / "transform.py"), + "utils": _load_module_from_path(f"{utils_pkg}.utils", official_root / "utils" / "utils.py"), + } + self._official_modules = modules + return modules + + def _get_expected_total_frames(self, raw_total_frames: Optional[int] = None) -> tuple[int, int]: + num_iterations = self.config.get("num_iterations", None) + if num_iterations is not None: + num_iterations = max(int(num_iterations), 1) + return num_iterations, self.first_clip_frame + 
(num_iterations - 1) * self.incremental_segment_frames + + if raw_total_frames is None: + return 1, self.first_clip_frame + + if raw_total_frames <= self.first_clip_frame: + return 1, self.first_clip_frame + + additional_frames = raw_total_frames - self.first_clip_frame + num_iterations = 1 + max(additional_frames // self.incremental_segment_frames, 0) + expected_total_frames = self.first_clip_frame + (num_iterations - 1) * self.incremental_segment_frames + if additional_frames % self.incremental_segment_frames != 0: + logger.warning( + "[matrix-game-3] raw control sequence has {} frames; truncating tail to {} frames so it matches 57 + 40*k.", + raw_total_frames, + expected_total_frames, + ) + return num_iterations, expected_total_frames + + def _segment_latent_shape(self, lat_h: int, lat_w: int, frame_count: int) -> list[int]: + return [ + self.config.get("num_channels_latents", 48), + (frame_count - 1) // self.config["vae_stride"][0] + 1, + lat_h, + lat_w, + ] + + @ProfilingContext4DebugL1( + "Run VAE Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, + metrics_labels=["WanMatrixGame3Runner"], + ) + def run_vae_encoder(self, img): + max_area = self.config.target_height * self.config.target_width + ih, iw = img.height, img.width + dh = self.config.patch_size[1] * self.config.vae_stride[1] + dw = self.config.patch_size[2] * self.config.vae_stride[2] + ow, oh = best_output_size(iw, ih, dw, dh, max_area) + + scale = max(ow / iw, oh / ih) + img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS) + x1 = (img.width - ow) // 2 + y1 = (img.height - oh) // 2 + img = img.crop((x1, y1, x1 + ow, y1 + oh)) + + image_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(AI_DEVICE).unsqueeze(1) + first_frame_latent = self.get_vae_encoder_output(image_tensor) + lat_h = oh // self.config["vae_stride"][1] + lat_w = ow // self.config["vae_stride"][2] + latent_shape = self._segment_latent_shape(lat_h, lat_w, 
self.first_clip_frame) + vae_encoder_out = torch.zeros(latent_shape, device=first_frame_latent.device, dtype=first_frame_latent.dtype) + vae_encoder_out[:, : first_frame_latent.shape[1]] = first_frame_latent + return vae_encoder_out, latent_shape + + @ProfilingContext4DebugL2("Run Encoders") + def _run_input_encoder_local_i2v(self): + _, img_ori = self.read_image_input(self.input_info.image_path) + vae_encoder_out, latent_shape = self.run_vae_encoder(img_ori) + self.input_info.latent_shape = latent_shape + text_encoder_output = self.run_text_encoder(self.input_info) + self._prepare_matrix_game3_session(img_ori, latent_shape, vae_encoder_out) + torch_device_module.empty_cache() + return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output) + + def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img=None): + image_encoder_output = { + "clip_encoder_out": clip_encoder_out, + "vae_encoder_out": vae_encoder_out, + "dit_cond_dict": {}, + } + return { + "text_encoder_output": text_encoder_output, + "image_encoder_output": image_encoder_output, + } + + def _prepare_matrix_game3_session(self, pil_image: Image.Image, latent_shape: list[int], vae_encoder_out: torch.Tensor): + # Official source: + # - Non-interactive path mirrors pipeline/inference_pipeline.py + # - Interactive segment refreshing mirrors pipeline/inference_interactive_pipeline.py + # - Camera/action fallback semantics follow the user's requested runner contract + self._get_official_modules() + self._segment_states.clear() + self._mg3_generated_latent_history = [] + self._mg3_tail_latents = None + self._mg3_current_segment_state = None + self._mg3_current_segment_full_latents = None + self._mg3_interactive = bool(self.config.get("interactive", False)) + self._mg3_last_pose = np.zeros(5, dtype=np.float32) + self._mg3_lat_h = int(latent_shape[-2]) + self._mg3_lat_w = int(latent_shape[-1]) + self._mg3_target_h = self._mg3_lat_h * self.config["vae_stride"][1] + 
self._mg3_target_w = self._mg3_lat_w * self.config["vae_stride"][2] + self._mg3_base_intrinsics = self._default_intrinsics().to(dtype=torch.float32) + + if self._mg3_interactive: + num_iterations = self.config.get("num_iterations", 1) + self._mg3_num_iterations = max(int(num_iterations), 1) + self._mg3_expected_total_frames = self.first_clip_frame + (self._mg3_num_iterations - 1) * self.incremental_segment_frames + self._mg3_keyboard_all = None + self._mg3_mouse_all = None + self._mg3_extrinsics_all = None + self._mg3_intrinsics_all = None + return + + action_path = self.input_info.action_path or self.input_info.pose or "" + raw_controls = self._load_control_payload(action_path) + raw_total_frames = self._infer_raw_total_frames(raw_controls) + self._mg3_num_iterations, self._mg3_expected_total_frames = self._get_expected_total_frames(raw_total_frames) + self._mg3_keyboard_all, self._mg3_mouse_all, self._mg3_extrinsics_all, self._mg3_intrinsics_all = self._build_noninteractive_controls(raw_controls) + + def _infer_raw_total_frames(self, payload: dict[str, Any]) -> Optional[int]: + lengths = [] + for value in payload.values(): + if value is None: + continue + if isinstance(value, np.ndarray): + if value.ndim >= 1: + lengths.append(int(value.shape[0])) + elif isinstance(value, torch.Tensor): + if value.ndim >= 1: + lengths.append(int(value.shape[0])) + elif isinstance(value, list): + lengths.append(len(value)) + return max(lengths) if lengths else None + + def _load_control_payload(self, action_path: str) -> dict[str, Any]: + if not action_path: + logger.warning("[matrix-game-3] action_path missing, fallback to zero keyboard/mouse and identity poses.") + return {} + + path = Path(action_path) + if not path.exists(): + logger.warning("[matrix-game-3] action_path not found: {}. 
Fallback to zero keyboard/mouse and identity poses.", action_path) + return {} + + if path.is_dir(): + return self._load_control_payload_from_dir(path) + return self._load_control_payload_from_file(path) + + def _load_control_payload_from_dir(self, path: Path) -> dict[str, Any]: + payload: dict[str, Any] = {} + name_groups = { + "keyboard_cond": ["keyboard_cond.npy", "keyboard_condition.npy", "keyboard_cond.pt", "keyboard_condition.pt", "keyboard_cond.json", "keyboard_condition.json"], + "mouse_cond": ["mouse_cond.npy", "mouse_condition.npy", "mouse_cond.pt", "mouse_condition.pt", "mouse_cond.json", "mouse_condition.json"], + "poses": ["poses.npy", "pose.npy", "poses.pt", "pose.pt", "poses.json", "pose.json", "c2ws.npy", "c2w.npy"], + "intrinsics": ["intrinsics.npy", "intrinsics.pt", "intrinsics.json", "Ks.npy", "K.npy"], + } + for key, names in name_groups.items(): + for file_name in names: + candidate = path / file_name + if not candidate.exists(): + continue + payload[key] = self._load_control_payload_from_file(candidate).get(key) + break + return payload + + def _load_control_payload_from_file(self, path: Path) -> dict[str, Any]: + suffix = path.suffix.lower() + stem = path.stem.lower() + if suffix == ".npz": + data = dict(np.load(path, allow_pickle=True)) + return self._normalize_payload_keys(data) + if suffix == ".json": + with path.open("r") as f: + data = json.load(f) + return self._normalize_payload_keys(data) + if suffix == ".npy": + array = np.load(path, allow_pickle=True) + elif suffix in {".pt", ".pth"}: + array = torch.load(path, map_location="cpu") + if isinstance(array, dict): + return self._normalize_payload_keys(array) + else: + raise ValueError(f"unsupported action_path format: {path}") + + if "keyboard" in stem: + return {"keyboard_cond": array} + if "mouse" in stem: + return {"mouse_cond": array} + if "intrinsic" in stem or stem in {"k", "ks"}: + return {"intrinsics": array} + if "pose" in stem or "c2w" in stem: + return {"poses": array} + 
raise ValueError(f"unsupported action_path file name: {path}") + + def _normalize_payload_keys(self, data: dict[str, Any]) -> dict[str, Any]: + payload: dict[str, Any] = {} + key_aliases = { + "keyboard_cond": {"keyboard_cond", "keyboard_condition"}, + "mouse_cond": {"mouse_cond", "mouse_condition"}, + "poses": {"poses", "pose", "c2ws", "c2w", "extrinsics"}, + "intrinsics": {"intrinsics", "k", "ks"}, + } + for target_key, aliases in key_aliases.items(): + for key, value in data.items(): + if str(key).lower() in aliases: + payload[target_key] = value + break + return payload + + def _default_intrinsics(self) -> torch.Tensor: + modules = self._get_official_modules() + assert self._mg3_target_h is not None and self._mg3_target_w is not None + return modules["cam_utils"].get_intrinsics(self._mg3_target_h, self._mg3_target_w) + + def _to_tensor(self, value: Any, dtype=torch.float32) -> Optional[torch.Tensor]: + if value is None: + return None + if isinstance(value, torch.Tensor): + return value.detach().cpu().to(dtype=dtype) + if isinstance(value, np.ndarray): + return torch.from_numpy(value).to(dtype=dtype) + if isinstance(value, list): + return torch.tensor(value, dtype=dtype) + return torch.tensor(value, dtype=dtype) + + def _resize_time_axis(self, tensor: torch.Tensor, total_frames: int) -> torch.Tensor: + if tensor.shape[0] == total_frames: + return tensor + if tensor.shape[0] == 1: + return tensor.repeat(total_frames, *([1] * (tensor.ndim - 1))) + if tensor.shape[0] < total_frames: + pad = tensor[-1:].repeat(total_frames - tensor.shape[0], *([1] * (tensor.ndim - 1))) + logger.warning( + "[matrix-game-3] control length {} shorter than expected {}, padding with the last value.", + tensor.shape[0], + total_frames, + ) + return torch.cat([tensor, pad], dim=0) + logger.warning( + "[matrix-game-3] control length {} longer than expected {}, truncating the tail.", + tensor.shape[0], + total_frames, + ) + return tensor[:total_frames] + + def _normalize_keyboard_cond(self, 
value: Any, total_frames: int) -> torch.Tensor: + if value is None: + return torch.zeros((1, total_frames, self.keyboard_dim_in), dtype=torch.float32) + tensor = self._to_tensor(value) + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + if tensor.ndim == 3 and tensor.shape[0] == 1: + tensor = tensor.squeeze(0) + if tensor.ndim != 2 or tensor.shape[-1] != self.keyboard_dim_in: + raise ValueError(f"keyboard_cond shape mismatch, expected [T,{self.keyboard_dim_in}], got {tuple(tensor.shape)}") + tensor = self._resize_time_axis(tensor, total_frames) + return tensor.unsqueeze(0) + + def _normalize_mouse_cond(self, value: Any, total_frames: int) -> torch.Tensor: + if value is None: + return torch.zeros((1, total_frames, self.mouse_dim_in), dtype=torch.float32) + tensor = self._to_tensor(value) + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + if tensor.ndim == 3 and tensor.shape[0] == 1: + tensor = tensor.squeeze(0) + if tensor.ndim != 2 or tensor.shape[-1] != self.mouse_dim_in: + raise ValueError(f"mouse_cond shape mismatch, expected [T,{self.mouse_dim_in}], got {tuple(tensor.shape)}") + tensor = self._resize_time_axis(tensor, total_frames) + return tensor.unsqueeze(0) + + def _normalize_intrinsics(self, value: Any, total_frames: int) -> Optional[torch.Tensor]: + if value is None: + return None + tensor = self._to_tensor(value) + if tensor.ndim == 1: + if tensor.shape[0] == 4: + tensor = tensor.unsqueeze(0) + elif tensor.shape[0] == 9: + tensor = tensor.view(3, 3).unsqueeze(0) + if tensor.ndim == 3 and tensor.shape[-2:] == (3, 3): + tensor = torch.stack([tensor[..., 0, 0], tensor[..., 1, 1], tensor[..., 0, 2], tensor[..., 1, 2]], dim=-1) + if tensor.ndim != 2 or tensor.shape[-1] != 4: + raise ValueError(f"intrinsics shape mismatch, expected [T,4] or [T,3,3], got {tuple(tensor.shape)}") + return self._resize_time_axis(tensor, total_frames) + + def _normalize_poses(self, value: Any, total_frames: int) -> Optional[torch.Tensor]: + if value is None: + return None + 
tensor = self._to_tensor(value) + if tensor.ndim == 2 and tensor.shape[-1] == 5: + modules = self._get_official_modules() + rotations = np.concatenate([np.zeros((tensor.shape[0], 1), dtype=np.float32), tensor[:, 3:5].numpy()], axis=1).tolist() + positions = tensor[:, :3].numpy().tolist() + tensor = modules["cam_utils"].get_extrinsics(rotations, positions).to(dtype=torch.float32) + if tensor.ndim == 3 and tensor.shape[-2:] == (4, 4): + tensor = self._resize_time_axis(tensor, total_frames) + return tensor + raise ValueError(f"poses shape mismatch, expected [T,4,4] or [T,5], got {tuple(tensor.shape)}") + + def _build_noninteractive_controls(self, payload: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Official source: + # - utils/conditions.py defines keyboard_dim_in=6 and mouse_dim_in=2 semantics + # - utils/utils.py computes poses from actions when explicit poses are absent + total_frames = self._mg3_expected_total_frames + keyboard_cond = self._normalize_keyboard_cond(payload.get("keyboard_cond"), total_frames) + mouse_cond = self._normalize_mouse_cond(payload.get("mouse_cond"), total_frames) + intrinsics_all = self._normalize_intrinsics(payload.get("intrinsics"), total_frames) + + poses = self._normalize_poses(payload.get("poses"), total_frames) + if poses is None: + modules = self._get_official_modules() + if not payload: + identity_pose = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat(total_frames, 1, 1) + poses = identity_pose + else: + first_pose = np.zeros(5, dtype=np.float32) + all_poses = modules["utils"].compute_all_poses_from_actions( + keyboard_cond.squeeze(0).cpu(), + mouse_cond.squeeze(0).cpu(), + first_pose=first_pose, + ) + positions = all_poses[:, :3].tolist() + rotations = np.concatenate([np.zeros((all_poses.shape[0], 1), dtype=np.float32), all_poses[:, 3:5]], axis=1).tolist() + poses = modules["cam_utils"].get_extrinsics(rotations, positions).to(dtype=torch.float32) + return keyboard_cond, 
mouse_cond, poses, intrinsics_all + + def get_video_segment_num(self): + self.video_segment_num = self._mg3_num_iterations + + def init_run(self): + self.gen_video_final = None + self.get_video_segment_num() + self._mg3_noise_generator = torch.Generator(device=AI_DEVICE).manual_seed(self.input_info.seed) + self._mg3_generated_latent_history = [] + self._mg3_tail_latents = None + self._mg3_current_segment_full_latents = None + self._mg3_current_segment_state = None + + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + self.model = self.load_transformer() + self.model.set_scheduler(self.scheduler) + + self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"]) + self._apply_segment_scheduler_state(self._build_or_get_segment_state(0)) + self.inputs["image_encoder_output"]["vae_encoder_out"] = None + + def _append_interactive_segment_controls(self, segment_idx: int): + modules = self._get_official_modules() + first_clip = segment_idx == 0 + action_frames = self.first_clip_frame if first_clip else self.incremental_segment_frames + + if not dist.is_initialized() or dist.get_rank() == 0: + actions = self._prompt_current_action() + keyboard_curr = actions["keyboard"].repeat(action_frames, 1) + mouse_curr = actions["mouse"].repeat(action_frames, 1) + if first_clip: + first_pose = np.zeros(5, dtype=np.float32) + else: + first_pose = self._mg3_last_pose + all_poses, last_pose = modules["utils"].compute_all_poses_from_actions( + keyboard_curr.cpu(), + mouse_curr.cpu(), + first_pose=first_pose, + return_last_pose=True, + ) + positions = all_poses[:, :3].tolist() + rotations = np.concatenate([np.zeros((all_poses.shape[0], 1), dtype=np.float32), all_poses[:, 3:5]], axis=1).tolist() + extrinsics_curr = modules["cam_utils"].get_extrinsics(rotations, positions).to(dtype=torch.float32) + payload = [ + keyboard_curr.numpy(), + mouse_curr.numpy(), + 
extrinsics_curr.numpy(), + last_pose.astype(np.float32), + ] + else: + payload = [None, None, None, None] + + if dist.is_initialized(): + dist.broadcast_object_list(payload, src=0) + + keyboard_curr = torch.from_numpy(payload[0]).to(dtype=torch.float32).unsqueeze(0) + mouse_curr = torch.from_numpy(payload[1]).to(dtype=torch.float32).unsqueeze(0) + extrinsics_curr = torch.from_numpy(payload[2]).to(dtype=torch.float32) + self._mg3_last_pose = np.array(payload[3], dtype=np.float32) + + if self._mg3_keyboard_all is None: + self._mg3_keyboard_all = keyboard_curr + self._mg3_mouse_all = mouse_curr + self._mg3_extrinsics_all = extrinsics_curr + else: + self._mg3_keyboard_all = torch.cat([self._mg3_keyboard_all, keyboard_curr], dim=1) + self._mg3_mouse_all = torch.cat([self._mg3_mouse_all, mouse_curr], dim=1) + self._mg3_extrinsics_all = torch.cat([self._mg3_extrinsics_all, extrinsics_curr], dim=0) + + def _prompt_current_action(self) -> dict[str, torch.Tensor]: + cam_value = 0.1 + print() + print("-" * 30) + print("PRESS [I, K, J, L, U] FOR CAMERA TRANSFORM") + print("(I: up, K: down, J: left, L: right, U: no move)") + print("PRESS [W, S, A, D, Q] FOR MOVEMENT") + print("(W: forward, S: back, A: left, D: right, Q: no move)") + print("-" * 30) + + camera_value_map = { + "i": [cam_value, 0.0], + "k": [-cam_value, 0.0], + "j": [0.0, -cam_value], + "l": [0.0, cam_value], + "u": [0.0, 0.0], + } + keyboard_idx = { + "w": [1, 0, 0, 0, 0, 0], + "s": [0, 1, 0, 0, 0, 0], + "a": [0, 0, 1, 0, 0, 0], + "d": [0, 0, 0, 1, 0, 0], + "q": [0, 0, 0, 0, 0, 0], + } + while True: + idx_mouse = input("Please input the mouse action (e.g. `U`):\n").strip().lower() + idx_keyboard = input("Please input the keyboard action (e.g. 
`W`):\n").strip().lower() + if idx_mouse in camera_value_map and idx_keyboard in keyboard_idx: + return { + "mouse": torch.tensor(camera_value_map[idx_mouse], dtype=torch.float32), + "keyboard": torch.tensor(keyboard_idx[idx_keyboard], dtype=torch.float32), + } + + def _interpolate_intrinsics(self, intrinsics_seq: Optional[torch.Tensor], src_indices: np.ndarray, tgt_indices: np.ndarray) -> torch.Tensor: + assert self._mg3_base_intrinsics is not None + if intrinsics_seq is None: + return self._mg3_base_intrinsics.to(dtype=torch.float32).repeat(len(tgt_indices), 1) + + intrinsics_seq = intrinsics_seq.to(dtype=torch.float32) + if intrinsics_seq.shape[0] == 1: + return intrinsics_seq.repeat(len(tgt_indices), 1) + + src_indices = np.asarray(src_indices, dtype=np.float32) + tgt_indices = np.asarray(tgt_indices, dtype=np.float32) + src_indices = np.clip(np.round(src_indices).astype(np.int64), 0, intrinsics_seq.shape[0] - 1) + src_intrinsics = intrinsics_seq[src_indices] + out = [] + for column_idx in range(src_intrinsics.shape[-1]): + column = np.interp(tgt_indices, src_indices.astype(np.float32), src_intrinsics[:, column_idx].cpu().numpy()) + out.append(torch.from_numpy(column).to(dtype=torch.float32)) + return torch.stack(out, dim=-1) + + def _build_plucker_from_c2ws( + self, + c2ws_seq: torch.Tensor, + src_indices: np.ndarray, + tgt_indices: np.ndarray, + framewise: bool, + intrinsics_seq: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Official source: + # - utils/cam_utils.py: interpolate poses, compute relative poses, build plucker rays + # - utils/utils.py: build_plucker_from_c2ws reshaping convention + modules = self._get_official_modules() + assert self._mg3_target_h is not None and self._mg3_target_w is not None + assert self._mg3_lat_h is not None and self._mg3_lat_w is not None + c2ws_np = c2ws_seq.cpu().numpy() + c2ws_infer = modules["cam_utils"]._interpolate_camera_poses_handedness( + src_indices=src_indices, + src_rot_mat=c2ws_np[:, :3, :3], + 
src_trans_vec=c2ws_np[:, :3, 3], + tgt_indices=tgt_indices, + ).to(device=c2ws_seq.device) + c2ws_infer = modules["cam_utils"].compute_relative_poses(c2ws_infer, framewise=framewise) + Ks = self._interpolate_intrinsics(intrinsics_seq, src_indices, tgt_indices).to(device=c2ws_infer.device, dtype=c2ws_infer.dtype) + plucker = modules["cam_utils"].get_plucker_embeddings(c2ws_infer, Ks, self._mg3_target_h, self._mg3_target_w) + c1 = self._mg3_target_h // self._mg3_lat_h + c2 = self._mg3_target_w // self._mg3_lat_w + plucker = rearrange( + plucker, + "f (h c1) (w c2) c -> (f h w) (c c1 c2)", + c1=c1, + c2=c2, + ) + plucker = plucker[None, ...] + return rearrange( + plucker, + "b (f h w) c -> b c f h w", + f=len(tgt_indices), + h=self._mg3_lat_h, + w=self._mg3_lat_w, + ) + + def _build_plucker_from_pose(self, c2ws_pose: torch.Tensor, intrinsics_seq: Optional[torch.Tensor] = None) -> torch.Tensor: + modules = self._get_official_modules() + assert self._mg3_target_h is not None and self._mg3_target_w is not None + assert self._mg3_lat_h is not None and self._mg3_lat_w is not None + if intrinsics_seq is None: + Ks = self._mg3_base_intrinsics.to(device=c2ws_pose.device, dtype=c2ws_pose.dtype).repeat(c2ws_pose.shape[0], 1) + else: + Ks = intrinsics_seq.to(device=c2ws_pose.device, dtype=c2ws_pose.dtype) + plucker = modules["cam_utils"].get_plucker_embeddings(c2ws_pose, Ks, self._mg3_target_h, self._mg3_target_w) + c1 = self._mg3_target_h // self._mg3_lat_h + c2 = self._mg3_target_w // self._mg3_lat_w + plucker = rearrange( + plucker, + "f (h c1) (w c2) c -> (f h w) (c c1 c2)", + c1=c1, + c2=c2, + ) + plucker = plucker[None, ...] + return rearrange( + plucker, + "b (f h w) c -> b c f h w", + f=c2ws_pose.shape[0], + h=self._mg3_lat_h, + w=self._mg3_lat_w, + ) + + def _build_memory_metadata(self, segment_idx: int, current_start_frame_idx: int, current_end_frame_idx: int) -> dict[str, Any]: + # Official source: pipeline/inference_pipeline.py and utils/cam_utils.py. 
+ # Current downstream model code only requires c2ws_plucker_emb / keyboard_cond / mouse_cond, + # but we still stage the memory-facing metadata here so the runner owns segment bookkeeping. + if segment_idx == 0 or not self._mg3_generated_latent_history: + return { + "x_memory": None, + "timestep_memory": None, + "keyboard_cond_memory": None, + "mouse_cond_memory": None, + "memory_latent_idx": None, + "plucker_emb_with_memory": None, + } + + modules = self._get_official_modules() + assert self._mg3_extrinsics_all is not None + assert self._mg3_base_intrinsics is not None + + def align_frame_to_block(frame_idx: int) -> int: + return (frame_idx - 1) // 4 * 4 + 1 if frame_idx > 0 else 1 + + def get_latent_idx(frame_idx: int) -> int: + return (frame_idx - 1) // 4 + 1 if frame_idx > 0 else 0 + + selected_index_base = [current_end_frame_idx - offset for offset in range(1, 34, 8)] + selected_index = modules["cam_utils"].select_memory_idx_fov( + self._mg3_extrinsics_all, + current_start_frame_idx, + selected_index_base, + use_gpu=torch.cuda.is_available(), + ) + if selected_index: + selected_index[-1] = 4 + + memory_pluckers = [] + latent_idx = [] + for mem_idx, reference_idx in zip(selected_index, selected_index_base): + latent_idx.append(get_latent_idx(mem_idx)) + mem_idx_aligned = align_frame_to_block(mem_idx) + mem_block = self._mg3_extrinsics_all[mem_idx_aligned : mem_idx_aligned + 4] + mem_src = np.linspace(mem_idx_aligned, mem_idx_aligned + 3, mem_block.shape[0]) + mem_tgt = np.array([mem_idx_aligned + 3], dtype=np.float32) + mem_pose = modules["cam_utils"]._interpolate_camera_poses_handedness( + src_indices=mem_src, + src_rot_mat=mem_block[:, :3, :3].cpu().numpy(), + src_trans_vec=mem_block[:, :3, 3].cpu().numpy(), + tgt_indices=mem_tgt, + ) + reference_pose = self._mg3_extrinsics_all[reference_idx : reference_idx + 1] + rel_pair = torch.cat([reference_pose, mem_pose], dim=0) + rel_pose = modules["cam_utils"].compute_relative_poses(rel_pair, framewise=False)[1:2] + 
memory_pluckers.append(self._build_plucker_from_pose(rel_pose.to(device=AI_DEVICE))) + + current_plucker = self._build_or_get_segment_camera_only(segment_idx) + plucker_with_memory = torch.cat(memory_pluckers + [current_plucker], dim=2) if memory_pluckers else current_plucker + src = torch.cat(self._mg3_generated_latent_history, dim=1) + valid_latent_idx = [idx for idx in latent_idx if 0 <= idx < src.shape[1]] + if valid_latent_idx != latent_idx: + logger.warning( + "[matrix-game-3] memory latent index truncated from {} to {} because generated latent history is shorter.", + latent_idx, + valid_latent_idx, + ) + x_memory = src[:, valid_latent_idx].unsqueeze(0).to(device=AI_DEVICE, dtype=GET_DTYPE()) if valid_latent_idx else None + if x_memory is None: + timestep_memory = None + keyboard_cond_memory = None + mouse_cond_memory = None + else: + timestep_memory = x_memory.new_zeros((1, x_memory.shape[2] * x_memory.shape[3] * x_memory.shape[4] // 4)) + keyboard_cond_memory = -torch.ones((1, len(valid_latent_idx), self.keyboard_dim_in), device=x_memory.device, dtype=x_memory.dtype) + mouse_cond_memory = torch.ones((1, len(valid_latent_idx), self.mouse_dim_in), device=x_memory.device, dtype=x_memory.dtype) + + return { + "x_memory": x_memory, + "timestep_memory": timestep_memory, + "keyboard_cond_memory": keyboard_cond_memory, + "mouse_cond_memory": mouse_cond_memory, + "memory_latent_idx": valid_latent_idx, + "plucker_emb_with_memory": plucker_with_memory, + } + + def _build_or_get_segment_camera_only(self, segment_idx: int) -> torch.Tensor: + state = self._segment_states.get(segment_idx) + if state is not None and "c2ws_plucker_emb" in state.dit_cond_dict: + return state.dit_cond_dict["c2ws_plucker_emb"] + state = self._build_or_get_segment_state(segment_idx) + return state.dit_cond_dict["c2ws_plucker_emb"] + + def _build_or_get_segment_state(self, segment_idx: int) -> MatrixGame3SegmentState: + if segment_idx in self._segment_states: + return 
self._segment_states[segment_idx] + + if self._mg3_interactive and (self._mg3_keyboard_all is None or self._mg3_keyboard_all.shape[1] < self.first_clip_frame + segment_idx * self.incremental_segment_frames): + self._append_interactive_segment_controls(segment_idx) + + assert self._mg3_keyboard_all is not None + assert self._mg3_mouse_all is not None + assert self._mg3_extrinsics_all is not None + first_clip = segment_idx == 0 + + def get_latent_idx(frame_idx: int) -> int: + return (frame_idx - 1) // 4 + 1 if frame_idx > 0 else 0 + + current_end_frame_idx = self.first_clip_frame if first_clip else self.first_clip_frame + segment_idx * self.incremental_segment_frames + current_start_frame_idx = 0 if first_clip else current_end_frame_idx - self.clip_frame + frame_count = self.first_clip_frame if first_clip else self.clip_frame + latent_start_idx = get_latent_idx(current_start_frame_idx) + latent_end_idx = get_latent_idx(current_end_frame_idx) + fixed_latent_frames = 1 if first_clip else self.conditioning_latent_frames + decode_trim_frames = 0 if first_clip else 1 + self.config["vae_stride"][0] * (fixed_latent_frames - 1) + append_latent_start = 0 if first_clip else fixed_latent_frames + + c2ws_chunk = self._mg3_extrinsics_all[current_start_frame_idx:current_end_frame_idx].to(device=AI_DEVICE) + src_indices = np.linspace(current_start_frame_idx, current_end_frame_idx - 1, frame_count) + + intrinsics_chunk = None + if self._mg3_intrinsics_all is not None: + intrinsics_chunk = self._mg3_intrinsics_all[current_start_frame_idx:current_end_frame_idx] + + latent_shape = self._segment_latent_shape(self._mg3_lat_h, self._mg3_lat_w, frame_count) + tgt_indices = np.linspace(0 if first_clip else current_start_frame_idx + 3, current_end_frame_idx - 1, latent_shape[1]) + + camera_only = self._build_plucker_from_c2ws( + c2ws_chunk, + src_indices=src_indices, + tgt_indices=tgt_indices, + framewise=True, + intrinsics_seq=intrinsics_chunk, + ).to(device=AI_DEVICE, dtype=GET_DTYPE()) + 
+ keyboard_cond = self._mg3_keyboard_all[:, current_start_frame_idx:current_end_frame_idx].to(device=AI_DEVICE, dtype=GET_DTYPE()) + mouse_cond = self._mg3_mouse_all[:, current_start_frame_idx:current_end_frame_idx].to(device=AI_DEVICE, dtype=GET_DTYPE()) + + vae_encoder_out = torch.zeros(latent_shape, device=AI_DEVICE, dtype=GET_DTYPE()) + if first_clip: + vae_encoder_out[:, :1] = self.inputs["image_encoder_output"]["vae_encoder_out"][:, :1] + else: + if self._mg3_tail_latents is None: + raise RuntimeError("matrix-game-3 segment requested without previous tail latents") + vae_encoder_out[:, : self.conditioning_latent_frames] = self._mg3_tail_latents.to(device=AI_DEVICE, dtype=GET_DTYPE()) + + # Fields below intentionally stay in the standard LightX2V image_encoder_output["dit_cond_dict"] + # container so downstream model / infer / weights code can consume them without a new top-level protocol. + dit_cond_dict: dict[str, Any] = { + "keyboard_cond": keyboard_cond, + "mouse_cond": mouse_cond, + "c2ws_plucker_emb": camera_only, + "predict_latent_idx": (latent_start_idx, latent_end_idx), + "segment_frame_range": (current_start_frame_idx, current_end_frame_idx), + "segment_idx": segment_idx, + "first_clip": first_clip, + } + dit_cond_dict.update(self._build_memory_metadata(segment_idx, current_start_frame_idx, current_end_frame_idx)) + + state = MatrixGame3SegmentState( + segment_idx=segment_idx, + first_clip=first_clip, + current_start_frame_idx=current_start_frame_idx, + current_end_frame_idx=current_end_frame_idx, + frame_count=frame_count, + fixed_latent_frames=fixed_latent_frames, + latent_shape=latent_shape, + decode_trim_frames=decode_trim_frames, + append_latent_start=append_latent_start, + keyboard_cond=keyboard_cond, + mouse_cond=mouse_cond, + vae_encoder_out=vae_encoder_out, + dit_cond_dict=dit_cond_dict, + ) + self._segment_states[segment_idx] = state + return state + + def _apply_segment_scheduler_state(self, segment_state: MatrixGame3SegmentState): + 
scheduler = self.model.scheduler + latents = torch.randn( + tuple(segment_state.latent_shape), + device=AI_DEVICE, + dtype=torch.float32, + generator=self._mg3_noise_generator, + ) + scheduler.vae_encoder_out = segment_state.vae_encoder_out.to(device=AI_DEVICE, dtype=torch.float32) + scheduler.mask = torch.ones_like(latents) + scheduler.mask[:, : segment_state.fixed_latent_frames] = 0 + scheduler.latents = (1.0 - scheduler.mask) * scheduler.vae_encoder_out + scheduler.mask * latents + + @ProfilingContext4DebugL1( + "Init run segment", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_init_run_segment_duration, + metrics_labels=["WanMatrixGame3Runner"], + ) + def init_run_segment(self, segment_idx): + # Official source: pipeline/inference_pipeline.py and inference_interactive_pipeline.py + # refresh per-segment action / camera / latent-conditioning state here so the outer lifecycle + # remains the standard LightX2V segment loop. + self.segment_idx = segment_idx + segment_state = self._build_or_get_segment_state(segment_idx) + self._mg3_current_segment_state = segment_state + self.input_info.latent_shape = segment_state.latent_shape + self.inputs["image_encoder_output"]["dit_cond_dict"] = segment_state.dit_cond_dict + self.inputs["image_encoder_output"]["vae_encoder_out"] = segment_state.vae_encoder_out + if segment_idx > 0: + self.model.scheduler.reset(self.input_info.seed, segment_state.latent_shape) + self._apply_segment_scheduler_state(segment_state) + + def run_segment(self, segment_idx=0): + latents = super().run_segment(segment_idx) + self._mg3_current_segment_full_latents = latents.detach().clone() + return latents + + def end_run_segment(self, segment_idx=None): + if self._mg3_current_segment_state is None or self._mg3_current_segment_full_latents is None: + raise RuntimeError("matrix-game-3 end_run_segment called before the current segment state was prepared") + + full_latents = self._mg3_current_segment_full_latents + # 
full_latents follows Wan2.2 runner convention: [C, T, H, W]. + self._mg3_tail_latents = full_latents[:, -self.conditioning_latent_frames :].detach().clone() + new_latents = full_latents[:, self._mg3_current_segment_state.append_latent_start :].detach().clone() + self._mg3_generated_latent_history.append(new_latents) + + segment_video = self.gen_video + if self._mg3_current_segment_state.decode_trim_frames > 0: + segment_video = segment_video[:, :, self._mg3_current_segment_state.decode_trim_frames :] + self.gen_video = segment_video + self.gen_video_final = segment_video if self.gen_video_final is None else torch.cat([self.gen_video_final, segment_video], dim=2) + self._mg3_current_segment_state = None + self._mg3_current_segment_full_latents = None + + def process_images_after_vae_decoder(self): + if self.gen_video_final is None: + self.gen_video_final = self.gen_video + return super().process_images_after_vae_decoder() diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py index f4c569ea8..7b4cd7f46 100755 --- a/lightx2v/pipeline.py +++ b/lightx2v/pipeline.py @@ -21,6 +21,7 @@ from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401 from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner # noqa: F401 +from lightx2v.models.runners.wan.wan_matrix_game3_runner import WanMatrixGame3Runner # noqa: F401 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401 from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner # noqa: F401 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner # noqa: F401 @@ -99,9 +100,11 @@ def __init__( self.vae_stride = (4, 8, 8) if self.model_cls.startswith("wan2.2"): self.use_image_encoder = False - elif self.model_cls in ["wan2.2"]: + elif self.model_cls in ["wan2.2", "wan2.2_matrix_game3"]: self.vae_stride = (4, 16, 16) 
self.num_channels_latents = 48 + if self.model_cls == "wan2.2_matrix_game3": + self.use_image_encoder = False elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: self.vae_stride = (4, 16, 16) self.num_channels_latents = 32 @@ -320,6 +323,7 @@ def enable_offload( "seko_talk", "wan2.2_moe", "wan2.2", + "wan2.2_matrix_game3", "wan2.2_moe_audio", "wan2.2_audio", "wan2.2_moe_distill", @@ -409,6 +413,7 @@ def generate( negative_prompt="", save_result_path="lightx2v_gen_result.png", image_path=None, + action_path=None, video_path=None, # For SR task (video super-resolution) image_strength=None, last_frame_path=None, @@ -425,6 +430,7 @@ def generate( # image_strength can be a scalar (float/int) or a list matching the number of images self.seed = seed self.image_path = image_path + self.action_path = action_path self.video_path = video_path # For SR task self.sr_ratio = sr_ratio self.last_frame_path = last_frame_path From ea6cae9b5477abb71f6aefd9c456873cb78aa5ca Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 14:17:41 +0800 Subject: [PATCH 02/25] Complete the residual work --- configs/matrix_game3/matrix_game3_base.json | 42 ++ .../matrix_game3/matrix_game3_distilled.json | 42 ++ .../wan/infer/matrix_game3/pre_infer.py | 170 +++++++ .../infer/matrix_game3/transformer_infer.py | 443 ++++++++++++++++++ .../models/networks/wan/matrix_game3_model.py | 84 ++++ .../wan/weights/matrix_game3/pre_weights.py | 77 +++ .../matrix_game3/transformer_weights.py | 342 ++++++++++++++ .../runners/wan/wan_matrix_game3_runner.py | 293 ++++++++++-- scripts/matrix_game3/run_matrix_game3_base.sh | 19 + .../run_matrix_game3_distilled.sh | 19 + 10 files changed, 1504 insertions(+), 27 deletions(-) create mode 100644 configs/matrix_game3/matrix_game3_base.json create mode 100644 configs/matrix_game3/matrix_game3_distilled.json create mode 100644 lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py create mode 100644 
lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py create mode 100644 lightx2v/models/networks/wan/matrix_game3_model.py create mode 100644 lightx2v/models/networks/wan/weights/matrix_game3/pre_weights.py create mode 100644 lightx2v/models/networks/wan/weights/matrix_game3/transformer_weights.py create mode 100755 scripts/matrix_game3/run_matrix_game3_base.sh create mode 100755 scripts/matrix_game3/run_matrix_game3_distilled.sh diff --git a/configs/matrix_game3/matrix_game3_base.json b/configs/matrix_game3/matrix_game3_base.json new file mode 100644 index 000000000..6c6d70b25 --- /dev/null +++ b/configs/matrix_game3/matrix_game3_base.json @@ -0,0 +1,42 @@ +{ + "model_cls": "wan2.2_matrix_game3", + "task": "i2v", + "model_path": "", + "sub_model_folder": "base_model", + "use_base_model": true, + + "target_video_length": 57, + "target_height": 704, + "target_width": 1280, + "vae_stride": [4, 16, 16], + "patch_size": [1, 2, 2], + + "num_channels_latents": 48, + "num_inference_steps": 50, + "sample_shift": 5.0, + "sample_guide_scale": 5.0, + "enable_cfg": true, + + "first_clip_frame": 57, + "clip_frame": 56, + "incremental_segment_frames": 40, + "past_frame": 16, + "conditioning_latent_frames": 4, + "num_iterations": 1, + "interactive": false, + "streaming": false, + + "mouse_dim_in": 2, + "keyboard_dim_in": 6, + + "self_attn_1_type": "flash_attn2", + "cross_attn_1_type": "flash_attn2", + "cross_attn_2_type": "flash_attn2", + + "cpu_offload": false, + "cpu_offload_activations": false, + "offload_granularity": "block", + "seq_parallel": false, + "parallel": {}, + "dit_quantized": false +} diff --git a/configs/matrix_game3/matrix_game3_distilled.json b/configs/matrix_game3/matrix_game3_distilled.json new file mode 100644 index 000000000..011fcc0a2 --- /dev/null +++ b/configs/matrix_game3/matrix_game3_distilled.json @@ -0,0 +1,42 @@ +{ + "model_cls": "wan2.2_matrix_game3", + "task": "i2v", + "model_path": "", + "sub_model_folder": "base_distilled_model", 
+ "use_base_model": false, + + "target_video_length": 57, + "target_height": 704, + "target_width": 1280, + "vae_stride": [4, 16, 16], + "patch_size": [1, 2, 2], + + "num_channels_latents": 48, + "num_inference_steps": 3, + "sample_shift": 5.0, + "sample_guide_scale": 1.0, + "enable_cfg": false, + + "first_clip_frame": 57, + "clip_frame": 56, + "incremental_segment_frames": 40, + "past_frame": 16, + "conditioning_latent_frames": 4, + "num_iterations": 1, + "interactive": false, + "streaming": false, + + "mouse_dim_in": 2, + "keyboard_dim_in": 6, + + "self_attn_1_type": "flash_attn2", + "cross_attn_1_type": "flash_attn2", + "cross_attn_2_type": "flash_attn2", + + "cpu_offload": false, + "cpu_offload_activations": false, + "offload_granularity": "block", + "seq_parallel": false, + "parallel": {}, + "dit_quantized": false +} diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py new file mode 100644 index 000000000..1243b6fa9 --- /dev/null +++ b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py @@ -0,0 +1,170 @@ +import torch +import torch.nn.functional as F + +from lightx2v.models.networks.wan.infer.module_io import GridOutput +from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer +from lightx2v.models.networks.wan.infer.utils import sinusoidal_embedding_1d +from lightx2v.utils.envs import * +from lightx2v_platform.base.global_var import AI_DEVICE + + +class WanMtxg3PreInferOutput: + """Container for MG3 pre-inference outputs passed to the transformer.""" + + __slots__ = [ + "x", "embed", "embed0", "grid_sizes", "cos_sin", "context", + "plucker_emb", "mouse_cond", "keyboard_cond", + "mouse_cond_memory", "keyboard_cond_memory", + "memory_length", "memory_latent_idx", "predict_latent_idx", + ] + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class WanMtxg3PreInfer(WanPreInfer): + """Pre-inference for Matrix-Game-3.0. 
+ + Builds: + - Patch embeddings + plucker camera embeddings + - Text embeddings (no CLIP image encoder — MG3 uses direct text conditioning) + - Time embeddings + - Passes through conditioning signals (keyboard, mouse, plucker, memory) + """ + + def __init__(self, config): + super().__init__(config) + self.use_memory = True + self.sigma_theta = config.get("sigma_theta", 0.0) + + # Build RoPE frequencies with optional sigma_theta head-specific theta + d = config["dim"] // config["num_heads"] + num_heads = config["num_heads"] + if self.sigma_theta > 0: + self.freqs = self._build_sigma_theta_freqs(d, num_heads, self.sigma_theta) + else: + self.freqs = torch.cat( + [ + self.rope_params(2048, d - 4 * (d // 6)), + self.rope_params(2048, 2 * (d // 6)), + self.rope_params(2048, 2 * (d // 6)), + ], + dim=1, + ).to(torch.device(AI_DEVICE)) + + def _build_sigma_theta_freqs(self, d, num_heads, sigma_theta): + """Build head-specific RoPE with sigma_theta perturbation as in official MG3.""" + c = d // 2 + c_t = c - 2 * (c // 3) + c_h = c // 3 + c_w = c // 3 + max_seq_len = 2048 + + rope_epsilon = torch.linspace(-1, 1, num_heads, dtype=torch.float64) + theta_base = 10000.0 + theta_hat = theta_base * (1 + sigma_theta * rope_epsilon) + + def build_freqs(seq_len, c_part): + exp = torch.arange(c_part, dtype=torch.float64) / c_part + omega = 1.0 / torch.pow(theta_hat.unsqueeze(1), exp.unsqueeze(0)) + pos = torch.arange(seq_len, dtype=torch.float64) + angles = pos.view(1, -1, 1) * omega.unsqueeze(1) + return torch.polar(torch.ones_like(angles), angles) + + freqs_t = build_freqs(max_seq_len, c_t) + freqs_h = build_freqs(max_seq_len, c_h) + freqs_w = build_freqs(max_seq_len, c_w) + return torch.cat([freqs_t, freqs_h, freqs_w], dim=2).to(torch.device(AI_DEVICE)) + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + @torch.no_grad() + def infer(self, weights, inputs, kv_start=0, kv_end=0): + """Build pre-inference outputs for the MG3.0 transformer.""" + x = 
self.scheduler.latents + t = self.scheduler.timestep_input + + # Text context (MG3 uses text conditioning only, no CLIP image encoder) + if self.scheduler.infer_condition: + context = inputs["text_encoder_output"]["context"] + else: + context = inputs["text_encoder_output"]["context_null"] + + # Patch embedding + x = weights.patch_embedding.apply(x.unsqueeze(0)) + grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:] + x = x.flatten(2).transpose(1, 2).contiguous() + + # Time embedding + embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten()) + if self.sensitive_layer_dtype != self.infer_dtype: + embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype)) + else: + embed = weights.time_embedding_0.apply(embed) + embed = torch.nn.functional.silu(embed) + embed = weights.time_embedding_2.apply(embed) + embed0 = torch.nn.functional.silu(embed) + embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim)) + + # Text embedding + if self.sensitive_layer_dtype != self.infer_dtype: + out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype)) + else: + out = weights.text_embedding_0.apply(context.squeeze(0)) + out = torch.nn.functional.gelu(out, approximate="tanh") + context = weights.text_embedding_2.apply(out) + + # Grid sizes and RoPE + grid_sizes = GridOutput( + tensor=torch.tensor( + [[grid_sizes_t, grid_sizes_h, grid_sizes_w]], + dtype=torch.int32, + device=x.device, + ), + tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w), + ) + + if self.cos_sin is None or self.grid_sizes != grid_sizes.tuple: + freqs = self.freqs.clone() + self.grid_sizes = grid_sizes.tuple + self.cos_sin = self.prepare_cos_sin(grid_sizes.tuple, freqs) + + # Extract conditioning signals from the runner's inputs + mg3_cond = inputs.get("mg3_conditions", {}) + plucker_emb = mg3_cond.get("plucker_emb", None) + mouse_cond = mg3_cond.get("mouse_cond", None) + keyboard_cond = mg3_cond.get("keyboard_cond", None) + mouse_cond_memory = 
mg3_cond.get("mouse_cond_memory", None) + keyboard_cond_memory = mg3_cond.get("keyboard_cond_memory", None) + memory_length = mg3_cond.get("memory_length", 0) + memory_latent_idx = mg3_cond.get("memory_latent_idx", None) + predict_latent_idx = mg3_cond.get("predict_latent_idx", None) + + # Process plucker embedding through the global camera layers + if plucker_emb is not None: + plucker_emb = weights.patch_embedding_wancamctrl.apply(plucker_emb.squeeze(0)) + plucker_hidden = weights.c2ws_hidden_states_layer2.apply( + torch.nn.functional.silu( + weights.c2ws_hidden_states_layer1.apply(plucker_emb) + ) + ) + plucker_emb = plucker_emb + plucker_hidden + + return WanMtxg3PreInferOutput( + embed=embed, + grid_sizes=grid_sizes, + x=x.squeeze(0), + embed0=embed0.squeeze(0), + context=context, + cos_sin=self.cos_sin, + plucker_emb=plucker_emb, + mouse_cond=mouse_cond, + keyboard_cond=keyboard_cond, + mouse_cond_memory=mouse_cond_memory, + keyboard_cond_memory=keyboard_cond_memory, + memory_length=memory_length, + memory_latent_idx=memory_latent_idx, + predict_latent_idx=predict_latent_idx, + ) diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py new file mode 100644 index 000000000..a38a2632d --- /dev/null +++ b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py @@ -0,0 +1,443 @@ +"""Transformer inference for Matrix-Game-3.0. + +Implements the MG3.0 WanAttentionBlock forward pass in LightX2V's +decomposed weight/infer architecture. The block execution order is: + + self_attn → cam_injection → cross_attn → action_model → ffn + +This closely follows the official MG3 `WanAttentionBlock.forward()`. 
+""" + +import math + +import torch +from einops import rearrange + +try: + import flash_attn_interface + + FLASH_ATTN_3_AVAILABLE = True +except ImportError: + try: + from flash_attn import flash_attn_func + + FLASH_ATTN_3_AVAILABLE = False + except ImportError: + FLASH_ATTN_3_AVAILABLE = False + +from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer +from lightx2v.utils.envs import * +from lightx2v.utils.registry_factory import * +from lightx2v_platform.base.global_var import AI_DEVICE + +torch_device_module = getattr(torch, AI_DEVICE) + + +def rope_apply_with_indices(x, grid_sizes, freqs, indices): + """Apply RoPE using explicit frame indices (for memory-aware attention). + + Rather than assuming sequential frame positions 0..F-1, this uses the + provided ``indices`` list so that memory frames and prediction frames + can have non-contiguous positional encodings. + + Args: + x: Shape [B, S, num_heads, head_dim] + grid_sizes: Shape [B, 3] (F, H, W) + freqs: Pre-computed RoPE frequencies + indices: List of frame indices to use for positional encoding + """ + n = x.shape[2] # num_heads + f, h, w = grid_sizes[0].tolist() + + if freqs.dim() == 2: + # Standard freqs: [max_seq, head_dim//2] + c = freqs.shape[1] + c_t = c - 2 * (c // 3) + c_h = c // 3 + c_w = c // 3 + freq_parts = freqs.split([c_t, c_h, c_w], dim=1) + + freq_t = freq_parts[0][indices] + cos_sin = torch.cat( + [ + freq_t.view(f, 1, 1, -1).expand(f, h, w, -1), + freq_parts[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freq_parts[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(f * h * w, 1, -1) + elif freqs.dim() == 3: + # Head-specific freqs: [num_heads, max_seq, head_dim_per_head//2] + c = freqs.shape[2] + c_t = c - 2 * (c // 3) + c_h = c // 3 + c_w = c // 3 + freq_parts = freqs.split([c_t, c_h, c_w], dim=2) + + freq_t = freq_parts[0][:, indices, :] # [n, f, c_t] + cos_sin = torch.cat( + [ + freq_t.permute(1, 0, 
2).unsqueeze(2).unsqueeze(3).expand(-1, -1, h, w, -1), + freq_parts[1][:, :h, :].permute(1, 0, 2).unsqueeze(0).unsqueeze(3).expand(f, -1, -1, w, -1), + freq_parts[2][:, :w, :].permute(1, 0, 2).unsqueeze(0).unsqueeze(2).expand(f, -1, h, -1, -1), + ], + dim=-1, + ).reshape(f * h * w, n, -1) + else: + raise ValueError(f"Unexpected freqs shape: {freqs.shape}") + + cos_sin = cos_sin.to(x.device) + # Apply RoPE + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + out = torch.view_as_real(x_complex * cos_sin).flatten(-2) + return out.type_as(x) + + +class WanMtxg3TransformerInfer(WanTransformerInfer): + """Transformer inference backend for Matrix-Game-3.0. + + Extends the base ``WanTransformerInfer`` to handle: + - Memory-aware self-attention with indexed RoPE + - Per-block camera plucker injection (scale/shift) + - ActionModule forward pass for keyboard/mouse conditioning + """ + + def __init__(self, config): + super().__init__(config) + self.action_config = config.get("action_config", {}) + self.action_blocks = set(self.action_config.get("blocks", [])) + + @torch.no_grad() + def infer(self, weights, pre_infer_out): + self.cos_sin = pre_infer_out.cos_sin + self.reset_infer_states() + x = self.infer_main_blocks(weights.blocks, pre_infer_out) + return self.infer_non_blocks(weights, x, pre_infer_out.embed) + + def infer_main_blocks(self, blocks, pre_infer_out): + x = pre_infer_out.x + for block_idx in range(len(blocks)): + self.block_idx = block_idx + x = self.infer_block(blocks[block_idx], x, pre_infer_out) + return x + + def infer_block(self, block, x, pre_infer_out): + """Execute one MG3.0 transformer block. 
+ + Phase order: + 0: self_attn + 1: cam_injection + 2: cross_attn + 3: action_model (only on action blocks) + 4 (or 3): ffn + """ + has_action = self.block_idx in self.action_blocks + + # --- Modulation (6-way) --- + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process( + block.compute_phases[0].modulation, + pre_infer_out.embed0, + ) + + # --- Phase 0: Self-Attention (with memory-aware RoPE) --- + y_out = self._infer_self_attn_mg3( + block.compute_phases[0], + x, + shift_msa, + scale_msa, + pre_infer_out, + ) + + # Gate and residual + if self.sensitive_layer_dtype != self.infer_dtype: + x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze() + else: + x = x + y_out * gate_msa.squeeze() + + # --- Phase 1: Camera Plucker Injection --- + if pre_infer_out.plucker_emb is not None: + x = self._infer_cam_injection(block.compute_phases[1], x, pre_infer_out.plucker_emb) + + # --- Phase 2: Cross-Attention --- + cross_phase = block.compute_phases[2] + norm3_out = cross_phase.norm3.apply(x) + n, d = self.num_heads, self.head_dim + q = cross_phase.cross_attn_norm_q.apply(cross_phase.cross_attn_q.apply(norm3_out)).view(-1, n, d) + k = cross_phase.cross_attn_norm_k.apply(cross_phase.cross_attn_k.apply(pre_infer_out.context)).view(-1, n, d) + v = cross_phase.cross_attn_v.apply(pre_infer_out.context).view(-1, n, d) + + if self.cross_attn_cu_seqlens_q is None: + self.cross_attn_cu_seqlens_q = torch.tensor([0, q.shape[0]], dtype=torch.int32).to(q.device) + if self.cross_attn_cu_seqlens_kv is None: + self.cross_attn_cu_seqlens_kv = torch.tensor([0, k.shape[0]], dtype=torch.int32).to(k.device) + + attn_out = cross_phase.cross_attn_1.apply( + q=q, k=k, v=v, + cu_seqlens_q=self.cross_attn_cu_seqlens_q, + cu_seqlens_kv=self.cross_attn_cu_seqlens_kv, + max_seqlen_q=q.size(0), + max_seqlen_kv=k.size(0), + ) + attn_out = cross_phase.cross_attn_o.apply(attn_out) + x = x + attn_out + + # --- Phase 3: ActionModule 
(optional) --- + if has_action: + action_phase_idx = 3 + action_phase = block.compute_phases[action_phase_idx] + x = self._infer_action_module( + action_phase, x, pre_infer_out + ) + + # --- Phase 4 (or 3): FFN --- + ffn_phase_idx = 4 if has_action else 3 + ffn_phase = block.compute_phases[ffn_phase_idx] + norm2_out = ffn_phase.norm2.apply(x) + if self.sensitive_layer_dtype != self.infer_dtype: + norm2_out = norm2_out.to(self.sensitive_layer_dtype) + norm2_out = norm2_out * (1 + c_scale_msa.squeeze()) + c_shift_msa.squeeze() + if self.sensitive_layer_dtype != self.infer_dtype: + norm2_out = norm2_out.to(self.infer_dtype) + + y = ffn_phase.ffn_0.apply(norm2_out) + y = torch.nn.functional.gelu(y, approximate="tanh") + y = ffn_phase.ffn_2.apply(y) + + # FFN gate + residual + if self.sensitive_layer_dtype != self.infer_dtype: + x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze() + else: + x = x + y * c_gate_msa.squeeze() + + return x + + def _infer_self_attn_mg3(self, phase, x, shift_msa, scale_msa, pre_infer_out): + """Self-attention with memory-aware indexed RoPE.""" + cos_sin = self.cos_sin + + norm1_out = phase.norm1.apply(x) + if self.sensitive_layer_dtype != self.infer_dtype: + norm1_out = norm1_out.to(self.sensitive_layer_dtype) + norm1_out = norm1_out * (1 + scale_msa.squeeze()) + shift_msa.squeeze() + if self.sensitive_layer_dtype != self.infer_dtype: + norm1_out = norm1_out.to(self.infer_dtype) + + s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim + q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d) + k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d) + v = phase.self_attn_v.apply(norm1_out).view(s, n, d) + + # Memory-aware RoPE + memory_length = getattr(pre_infer_out, "memory_length", 0) + memory_latent_idx = getattr(pre_infer_out, "memory_latent_idx", None) + predict_latent_idx = getattr(pre_infer_out, "predict_latent_idx", None) + 
grid_sizes = pre_infer_out.grid_sizes + + if memory_length > 0: + hw = grid_sizes.tuple[1] * grid_sizes.tuple[2] + # Split into memory and prediction parts + q_memory = q[:memory_length * hw].unsqueeze(0) + k_memory = k[:memory_length * hw].unsqueeze(0) + q_pred = q[memory_length * hw:].unsqueeze(0) + k_pred = k[memory_length * hw:].unsqueeze(0) + + # Build grid_sizes tensors + f_total = grid_sizes.tuple[0] + h, w = grid_sizes.tuple[1], grid_sizes.tuple[2] + grid_sizes_mem = torch.tensor([[memory_length, h, w]], dtype=torch.long, device=q.device) + grid_sizes_pred = torch.tensor([[f_total - memory_length, h, w]], dtype=torch.long, device=q.device) + + # RoPE with explicit indices + mem_indices = memory_latent_idx if memory_latent_idx is not None else list(range(memory_length)) + q_memory = rope_apply_with_indices(q_memory, grid_sizes_mem, self.freqs, mem_indices) + k_memory = rope_apply_with_indices(k_memory, grid_sizes_mem, self.freqs, mem_indices) + + if predict_latent_idx is not None: + if isinstance(predict_latent_idx, tuple) and len(predict_latent_idx) == 2: + pred_indices = list(range(predict_latent_idx[0], predict_latent_idx[1])) + else: + pred_indices = predict_latent_idx + else: + pred_indices = list(range(grid_sizes_pred[0, 0].item())) + + q_pred = rope_apply_with_indices(q_pred, grid_sizes_pred, self.freqs, pred_indices) + k_pred = rope_apply_with_indices(k_pred, grid_sizes_pred, self.freqs, pred_indices) + + q = torch.cat([q_memory.squeeze(0), q_pred.squeeze(0)], dim=0) + k = torch.cat([k_memory.squeeze(0), k_pred.squeeze(0)], dim=0) + else: + # No memory — standard RoPE or indexed RoPE + if predict_latent_idx is not None: + q_unsq = q.unsqueeze(0) + k_unsq = k.unsqueeze(0) + grid_sizes_t = torch.tensor( + [[grid_sizes.tuple[0], grid_sizes.tuple[1], grid_sizes.tuple[2]]], + dtype=torch.long, device=q.device, + ) + if isinstance(predict_latent_idx, tuple) and len(predict_latent_idx) == 2: + pred_indices = list(range(predict_latent_idx[0], 
predict_latent_idx[1])) + else: + pred_indices = predict_latent_idx + q = rope_apply_with_indices(q_unsq, grid_sizes_t, self.freqs, pred_indices).squeeze(0) + k = rope_apply_with_indices(k_unsq, grid_sizes_t, self.freqs, pred_indices).squeeze(0) + else: + q, k = self.apply_rope_func(q, k, cos_sin) + + img_qkv_len = q.shape[0] + if self.self_attn_cu_seqlens_qkv is None: + self.self_attn_cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32).to(q.device) + + attn_out = phase.self_attn_1.apply( + q=q, k=k, v=v, + cu_seqlens_q=self.self_attn_cu_seqlens_qkv, + cu_seqlens_kv=self.self_attn_cu_seqlens_qkv, + max_seqlen_q=img_qkv_len, + max_seqlen_kv=img_qkv_len, + ) + + y = phase.self_attn_o.apply(attn_out) + return y + + def _infer_cam_injection(self, cam_phase, x, plucker_emb): + """Apply per-block camera plucker injection via scale/shift modulation. + + From official MG3: + c2ws_hidden = cam_injector_layer2(silu(cam_injector_layer1(plucker_emb))) + c2ws_hidden = c2ws_hidden + plucker_emb + cam_scale = cam_scale_layer(c2ws_hidden) + cam_shift = cam_shift_layer(c2ws_hidden) + x = (1 + cam_scale) * x + cam_shift + """ + hidden = cam_phase.cam_injector_layer1.apply(plucker_emb) + hidden = torch.nn.functional.silu(hidden) + hidden = cam_phase.cam_injector_layer2.apply(hidden) + hidden = hidden + plucker_emb + + cam_scale = cam_phase.cam_scale_layer.apply(hidden) + cam_shift = cam_phase.cam_shift_layer.apply(hidden) + x = (1.0 + cam_scale) * x + cam_shift + return x + + def _infer_action_module(self, phase, x, pre_infer_out): + """ActionModule forward: keyboard + mouse conditioning via cross-attention. + + This implements the official MG3 ActionModule logic in the LightX2V + weight/infer separation style. The module: + 1. Processes mouse condition through mouse_mlp + 2. Applies temporal self-attention with QKV (t_qkv) + 3. Projects back via proj_mouse + 4. Processes keyboard condition through keyboard_embed + 5. Applies keyboard cross-attention + 6. 
Projects back via proj_keyboard + """ + grid_sizes = pre_infer_out.grid_sizes + f, h, w = grid_sizes.tuple + S = h * w + + mouse_cond = pre_infer_out.mouse_cond + keyboard_cond = pre_infer_out.keyboard_cond + + x_in = x.unsqueeze(0) # [1, FHW, C] + + # --- Mouse conditioning --- + if mouse_cond is not None: + hidden_states = rearrange(x_in, "B (T S) C -> (B S) T C", T=f, S=S) + + # Mouse MLP + mouse_input = torch.cat([hidden_states, mouse_cond.expand(S, -1, -1) if mouse_cond.shape[0] == 1 else mouse_cond], dim=-1) + mouse_out = phase.mouse_mlp_0.apply(mouse_input.reshape(-1, mouse_input.shape[-1])) + mouse_out = torch.nn.functional.gelu(mouse_out, approximate="tanh") + mouse_out = phase.mouse_mlp_2.apply(mouse_out) + mouse_out = phase.mouse_mlp_3.apply(mouse_out) + mouse_out = mouse_out.reshape(S, f, -1) + + # Mouse temporal self-attention with QKV + mouse_qkv = phase.t_qkv.apply(mouse_out.reshape(-1, mouse_out.shape[-1])) + mouse_qkv = mouse_qkv.reshape(S, f, 3, self.num_heads, self.head_dim) + q_m, k_m, v_m = mouse_qkv.permute(2, 0, 1, 3, 4).unbind(0) + + # QK norm (RMSNorm) + q_m = phase.img_attn_q_norm.apply(q_m.reshape(-1, self.head_dim)).reshape(S, f, self.num_heads, self.head_dim) + k_m = phase.img_attn_k_norm.apply(k_m.reshape(-1, self.head_dim)).reshape(S, f, self.num_heads, self.head_dim) + + # Flash attention + if FLASH_ATTN_3_AVAILABLE: + mouse_attn = flash_attn_interface.flash_attn_func(q_m, k_m, v_m) + else: + mouse_attn = flash_attn_func(q_m, k_m, v_m) + + mouse_attn = rearrange(mouse_attn, "(B S) T h d -> B (T S) (h d)", B=1, S=S) + mouse_proj = phase.proj_mouse.apply(mouse_attn.squeeze(0)).unsqueeze(0) + x_in = x_in + mouse_proj + + # --- Keyboard conditioning --- + if keyboard_cond is not None: + # Keyboard embed + kb_emb = phase.keyboard_embed_0.apply(keyboard_cond.reshape(-1, keyboard_cond.shape[-1])) + kb_emb = torch.nn.functional.silu(kb_emb) + kb_emb = phase.keyboard_embed_2.apply(kb_emb) + kb_emb = kb_emb.reshape(keyboard_cond.shape[0], 
keyboard_cond.shape[1], -1) + + # Keyboard cross-attention: query from hidden states, key/value from keyboard + mouse_q = phase.mouse_attn_q.apply(x_in.squeeze(0)).unsqueeze(0) + keyboard_kv = phase.keyboard_attn_kv.apply(kb_emb.reshape(-1, kb_emb.shape[-1])) + keyboard_kv = keyboard_kv.reshape(1, -1, keyboard_kv.shape[-1]) + + HD = mouse_q.shape[-1] + D = HD // self.num_heads + q_k = mouse_q.view(1, -1, self.num_heads, D) + kv_split = keyboard_kv.view(1, -1, 2, self.num_heads, D) + k_k, v_k = kv_split.permute(2, 0, 1, 3, 4).unbind(0) + + # QK norm + q_k_flat = q_k.reshape(-1, D) + k_k_flat = k_k.reshape(-1, D) + q_k = phase.key_attn_q_norm.apply(q_k_flat).reshape(1, -1, self.num_heads, D) + k_k = phase.key_attn_k_norm.apply(k_k_flat).reshape(1, -1, self.num_heads, D) + + # Flash attention + if FLASH_ATTN_3_AVAILABLE: + kb_attn = flash_attn_interface.flash_attn_func(q_k, k_k, v_k) + else: + kb_attn = flash_attn_func(q_k, k_k, v_k) + + kb_attn = rearrange(kb_attn, "B L H D -> B L (H D)") + kb_proj = phase.proj_keyboard.apply(kb_attn.squeeze(0)).unsqueeze(0) + x_in = x_in + kb_proj + + return x_in.squeeze(0) + + @property + def freqs(self): + """Access the pre-infer's freqs for RoPE with indices.""" + return self._freqs + + @freqs.setter + def freqs(self, value): + self._freqs = value + + def infer_non_blocks(self, weights, x, e): + """Head processing — same as base but handles per-token time embeddings.""" + if e.dim() == 2: + modulation = weights.head_modulation.tensor + e_parts = (modulation + e.unsqueeze(1)).chunk(2, dim=1) + elif e.dim() == 3: + modulation = weights.head_modulation.tensor.unsqueeze(2) + e_parts = (modulation + e.unsqueeze(1)).chunk(2, dim=1) + e_parts = [ei.squeeze(1) for ei in e_parts] + else: + modulation = weights.head_modulation.tensor + e_parts = (modulation + e.unsqueeze(1)).chunk(2, dim=1) + + x = weights.norm.apply(x) + if self.sensitive_layer_dtype != self.infer_dtype: + x = x.to(self.sensitive_layer_dtype) + x = x * (1 + 
e_parts[1].squeeze()) + e_parts[0].squeeze() + if self.sensitive_layer_dtype != self.infer_dtype: + x = x.to(self.infer_dtype) + x = weights.head.apply(x) + return x + + def set_freqs(self, freqs): + """Set RoPE frequencies from pre_infer.""" + self._freqs = freqs diff --git a/lightx2v/models/networks/wan/matrix_game3_model.py b/lightx2v/models/networks/wan/matrix_game3_model.py new file mode 100644 index 000000000..17555c945 --- /dev/null +++ b/lightx2v/models/networks/wan/matrix_game3_model.py @@ -0,0 +1,84 @@ +import json +import os + +import torch +from safetensors import safe_open + +from lightx2v.models.networks.wan.infer.matrix_game3.pre_infer import WanMtxg3PreInfer +from lightx2v.models.networks.wan.infer.matrix_game3.transformer_infer import WanMtxg3TransformerInfer +from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer +from lightx2v.models.networks.wan.model import WanModel +from lightx2v.models.networks.wan.weights.matrix_game3.pre_weights import WanMtxg3PreWeights +from lightx2v.models.networks.wan.weights.matrix_game3.transformer_weights import WanMtxg3TransformerWeights +from lightx2v.utils.envs import * +from lightx2v.utils.utils import * + + +class WanMtxg3Model(WanModel): + """Network model for Matrix-Game-3.0. + + Extends the base Wan2.2 DiT backbone with: + - Per-block ActionModule weights for keyboard/mouse conditioning + - Camera plucker ray injection layers (cam_injector, cam_scale, cam_shift) + - Memory-aware self-attention with indexed RoPE + - Global plucker embedding (patch_embedding_wancamctrl, c2ws_hidden_states_layer1/2) + + The model loads diffusers-format safetensors from the MG3.0 checkpoint + directory (base_model/ or base_distilled_model/). 
+ """ + + pre_weight_class = WanMtxg3PreWeights + transformer_weight_class = WanMtxg3TransformerWeights + + def __init__(self, model_path, config, device, model_type="wan2.2", lora_path=None, lora_strength=1.0): + super().__init__(model_path, config, device, model_type, lora_path, lora_strength) + + def _init_infer_class(self): + # Merge the official MG3 model config so that all dimension / action fields + # are available for weight and infer construction. + sub_model_folder = self.config.get("sub_model_folder", "base_distilled_model") + config_path = os.path.join(self.config["model_path"], sub_model_folder, "config.json") + if os.path.exists(config_path): + with open(config_path) as f: + model_config = json.load(f) + for k in model_config.keys(): + self.config[k] = model_config[k] + + self.pre_infer_class = WanMtxg3PreInfer + self.post_infer_class = WanPostInfer + self.transformer_infer_class = WanMtxg3TransformerInfer + + def _load_ckpt(self, unified_dtype, sensitive_layer): + """Load MG3.0 safetensors checkpoint. + + The MG3.0 checkpoint uses diffusers format with keys like + ``model.blocks.0.self_attn.q.weight`` (prefixed with ``model.``). + We strip the ``model.`` prefix so the keys match our weight module names. + """ + sub_model_folder = self.config.get("sub_model_folder", "base_distilled_model") + model_dir = os.path.join(self.config["model_path"], sub_model_folder) + + # Find safetensor files + safetensor_files = [f for f in os.listdir(model_dir) if f.endswith(".safetensors")] + if not safetensor_files: + raise FileNotFoundError( + f"No safetensors files found in {model_dir}. " + "Please download the Matrix-Game-3.0 model weights." 
+ ) + + weight_dict = {} + for sf_file in sorted(safetensor_files): + file_path = os.path.join(model_dir, sf_file) + with safe_open(file_path, framework="pt", device=str(self.device)) as f: + for key in f.keys(): + tensor = f.get_tensor(key) + # Strip the common diffusers prefix if present + name = key + if name.startswith("model."): + name = name[len("model."):] + # Cast to appropriate dtype + if unified_dtype or all(s not in name for s in sensitive_layer): + weight_dict[name] = tensor.to(GET_DTYPE()) + else: + weight_dict[name] = tensor.to(GET_SENSITIVE_DTYPE()) + return weight_dict diff --git a/lightx2v/models/networks/wan/weights/matrix_game3/pre_weights.py b/lightx2v/models/networks/wan/weights/matrix_game3/pre_weights.py new file mode 100644 index 000000000..f187e9e8b --- /dev/null +++ b/lightx2v/models/networks/wan/weights/matrix_game3/pre_weights.py @@ -0,0 +1,77 @@ +from lightx2v.common.modules.weight_module import WeightModule +from lightx2v.utils.registry_factory import ( + CONV3D_WEIGHT_REGISTER, + LN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, +) + + +class WanMtxg3PreWeights(WeightModule): + """Pre-processing weights for Matrix-Game-3.0. 
+ + Handles: + - patch_embedding (Conv3D) + - text_embedding (2-layer MLP) + - time_embedding + time_projection + - patch_embedding_wancamctrl (plucker ray Linear) + - c2ws_hidden_states_layer1/2 (camera injection MLP) + """ + + def __init__(self, config): + super().__init__() + self.in_dim = config["in_dim"] + self.dim = config["dim"] + self.patch_size = tuple(config.get("patch_size", (1, 2, 2))) + self.config = config + + # Patch embedding + self.add_module( + "patch_embedding", + CONV3D_WEIGHT_REGISTER["Default"]( + "patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size + ), + ) + + # Text embedding (2-layer MLP with GELU) + self.add_module( + "text_embedding_0", + MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias"), + ) + self.add_module( + "text_embedding_2", + MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias"), + ) + + # Time embedding + self.add_module( + "time_embedding_0", + MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"), + ) + self.add_module( + "time_embedding_2", + MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"), + ) + self.add_module( + "time_projection_1", + MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"), + ) + + # Camera plucker embedding (global, before blocks) + self.add_module( + "patch_embedding_wancamctrl", + MM_WEIGHT_REGISTER["Default"]( + "patch_embedding_wancamctrl.weight", "patch_embedding_wancamctrl.bias" + ), + ) + self.add_module( + "c2ws_hidden_states_layer1", + MM_WEIGHT_REGISTER["Default"]( + "c2ws_hidden_states_layer1.weight", "c2ws_hidden_states_layer1.bias" + ), + ) + self.add_module( + "c2ws_hidden_states_layer2", + MM_WEIGHT_REGISTER["Default"]( + "c2ws_hidden_states_layer2.weight", "c2ws_hidden_states_layer2.bias" + ), + ) diff --git a/lightx2v/models/networks/wan/weights/matrix_game3/transformer_weights.py 
b/lightx2v/models/networks/wan/weights/matrix_game3/transformer_weights.py new file mode 100644 index 000000000..15e64db34 --- /dev/null +++ b/lightx2v/models/networks/wan/weights/matrix_game3/transformer_weights.py @@ -0,0 +1,342 @@ +from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList +from lightx2v.models.networks.wan.weights.transformer_weights import ( + WanCrossAttention, + WanFFN, + WanSelfAttention, +) +from lightx2v.utils.registry_factory import ( + ATTN_WEIGHT_REGISTER, + LN_WEIGHT_REGISTER, + MM_WEIGHT_REGISTER, + RMS_WEIGHT_REGISTER, + TENSOR_REGISTER, +) + + +class WanMtxg3TransformerWeights(WeightModule): + """Transformer weights for Matrix-Game-3.0. + + Each block has: + - SelfAttention (reused from base) + - CrossAttention (with norm3 / cross_attn_norm support) + - Camera injection layers (cam_injector_layer1/2, cam_scale_layer, cam_shift_layer) + - ActionModule (keyboard_embed, mouse_mlp, mouse/keyboard cross-attn, only on specified blocks) + - FFN (reused from base) + """ + + def __init__(self, config): + super().__init__() + self.blocks_num = config["num_layers"] + self.task = config["task"] + self.config = config + self.mm_type = config.get("dit_quant_scheme", "Default") + if self.mm_type != "Default": + assert config.get("dit_quantized") is True + + action_config = config.get("action_config", {}) + action_blocks = action_config.get("blocks", []) + + block_list = [] + for i in range(self.blocks_num): + has_action = i in action_blocks + block_list.append( + WanMtxg3TransformerBlock( + i, self.task, self.mm_type, self.config, has_action=has_action + ) + ) + self.blocks = WeightModuleList(block_list) + self.add_module("blocks", self.blocks) + + # Non-block weights (head) + self.register_parameter("norm", LN_WEIGHT_REGISTER["torch"]()) + self.add_module( + "head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias") + ) + self.register_parameter( + "head_modulation", 
TENSOR_REGISTER["Default"]("head.modulation") + ) + + def non_block_weights_to_cuda(self): + self.norm.to_cuda() + self.head.to_cuda() + self.head_modulation.to_cuda() + + def non_block_weights_to_cpu(self): + self.norm.to_cpu() + self.head.to_cpu() + self.head_modulation.to_cpu() + + +class WanMtxg3TransformerBlock(WeightModule): + """Single transformer block for MG3.0. + + Phases: + 0: SelfAttention + 1: CamInjection (per-block camera plucker scale/shift) + 2: CrossAttention + 3: ActionModule (only on action blocks; None placeholder otherwise) + 4: FFN + """ + + def __init__(self, block_index, task, mm_type, config, has_action=False, block_prefix="blocks"): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.has_action = has_action + + phases = [ + WanSelfAttention(block_index, block_prefix, task, mm_type, config), + WanMtxg3CamInjection(block_index, block_prefix, mm_type, config), + WanMtxg3CrossAttention(block_index, block_prefix, task, mm_type, config), + ] + if has_action: + phases.append(WanMtxg3ActionModule(block_index, block_prefix, task, mm_type, config)) + phases.append(WanFFN(block_index, block_prefix, task, mm_type, config)) + + self.compute_phases = WeightModuleList(phases) + self.add_module("compute_phases", self.compute_phases) + + +class WanMtxg3CamInjection(WeightModule): + """Per-block camera plucker injection weights. 
+ + From the official MG3 WanAttentionBlock: + cam_injector_layer1, cam_injector_layer2: SiLU + residual MLP + cam_scale_layer, cam_shift_layer: scale/shift modulation + """ + + def __init__(self, block_index, block_prefix, mm_type, config): + super().__init__() + self.block_index = block_index + + self.add_module( + "cam_injector_layer1", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.cam_injector_layer1.weight", + f"{block_prefix}.{block_index}.cam_injector_layer1.bias", + ), + ) + self.add_module( + "cam_injector_layer2", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.cam_injector_layer2.weight", + f"{block_prefix}.{block_index}.cam_injector_layer2.bias", + ), + ) + self.add_module( + "cam_scale_layer", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.cam_scale_layer.weight", + f"{block_prefix}.{block_index}.cam_scale_layer.bias", + ), + ) + self.add_module( + "cam_shift_layer", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.cam_shift_layer.weight", + f"{block_prefix}.{block_index}.cam_shift_layer.bias", + ), + ) + + +class WanMtxg3CrossAttention(WeightModule): + """Cross-attention weights for MG3.0 blocks. + + Same as base WanCrossAttention but without the image encoder K/V/norm_k_img + branches (MG3.0 does not use a separate image encoder cross-attention). + Also includes norm3 with elementwise_affine=True (cross_attn_norm=True in MG3.0). 
+ """ + + def __init__(self, block_index, block_prefix, task, mm_type, config): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.config = config + + if self.config.get("sf_config", False): + self.attn_rms_norm_type = self.config.get("rms_norm_type", "self_forcing") + else: + self.attn_rms_norm_type = self.config.get("rms_norm_type", "sgl-kernel") + + # norm3 with elementwise_affine=True for cross_attn_norm + self.add_module( + "norm3", + LN_WEIGHT_REGISTER["torch"]( + f"{block_prefix}.{block_index}.norm3.weight", + f"{block_prefix}.{block_index}.norm3.bias", + ), + ) + self.add_module( + "cross_attn_q", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.cross_attn.q.weight", + f"{block_prefix}.{block_index}.cross_attn.q.bias", + ), + ) + self.add_module( + "cross_attn_k", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.cross_attn.k.weight", + f"{block_prefix}.{block_index}.cross_attn.k.bias", + ), + ) + self.add_module( + "cross_attn_v", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.cross_attn.v.weight", + f"{block_prefix}.{block_index}.cross_attn.v.bias", + ), + ) + self.add_module( + "cross_attn_o", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.cross_attn.o.weight", + f"{block_prefix}.{block_index}.cross_attn.o.bias", + ), + ) + self.add_module( + "cross_attn_norm_q", + RMS_WEIGHT_REGISTER[self.attn_rms_norm_type]( + f"{block_prefix}.{block_index}.cross_attn.norm_q.weight", + ), + ) + self.add_module( + "cross_attn_norm_k", + RMS_WEIGHT_REGISTER[self.attn_rms_norm_type]( + f"{block_prefix}.{block_index}.cross_attn.norm_k.weight", + ), + ) + self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]()) + + +class WanMtxg3ActionModule(WeightModule): + """ActionModule weights for MG3.0 blocks. 
+ + From the official MG3 ActionModule: + - keyboard_embed (2-layer MLP with SiLU) + - mouse_mlp (Linear -> GELU -> Linear -> LayerNorm) + - t_qkv (combined QKV projection for mouse) + - proj_mouse, proj_keyboard (output projections) + - mouse_attn_q, keyboard_attn_kv (keyboard cross-attn projections) + - img_attn_q_norm, img_attn_k_norm, key_attn_q_norm, key_attn_k_norm (RMSNorm) + """ + + def __init__(self, block_index, block_prefix, task, mm_type, config): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.config = config + + attn_rms_norm_type = config.get("rms_norm_type", "sgl-kernel") + + # Keyboard embed (2-layer MLP) + self.add_module( + "keyboard_embed_0", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.action_model.keyboard_embed.0.weight", + f"{block_prefix}.{block_index}.action_model.keyboard_embed.0.bias", + ), + ) + self.add_module( + "keyboard_embed_2", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.action_model.keyboard_embed.2.weight", + f"{block_prefix}.{block_index}.action_model.keyboard_embed.2.bias", + ), + ) + + # Mouse MLP + self.add_module( + "mouse_mlp_0", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.action_model.mouse_mlp.0.weight", + f"{block_prefix}.{block_index}.action_model.mouse_mlp.0.bias", + ), + ) + self.add_module( + "mouse_mlp_2", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.action_model.mouse_mlp.2.weight", + f"{block_prefix}.{block_index}.action_model.mouse_mlp.2.bias", + ), + ) + self.add_module( + "mouse_mlp_3", + LN_WEIGHT_REGISTER["torch"]( + f"{block_prefix}.{block_index}.action_model.mouse_mlp.3.weight", + f"{block_prefix}.{block_index}.action_model.mouse_mlp.3.bias", + eps=1e-5, + ), + ) + + # Mouse QKV and projection + self.add_module( + "t_qkv", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.action_model.t_qkv.weight", + bias_name=None, + ), + ) + self.add_module( + "proj_mouse", + 
MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.action_model.proj_mouse.weight", + bias_name=None, + ), + ) + + # Mouse attention RMS norms + self.add_module( + "img_attn_q_norm", + RMS_WEIGHT_REGISTER[attn_rms_norm_type]( + f"{block_prefix}.{block_index}.action_model.img_attn_q_norm.weight", + ), + ) + self.add_module( + "img_attn_k_norm", + RMS_WEIGHT_REGISTER[attn_rms_norm_type]( + f"{block_prefix}.{block_index}.action_model.img_attn_k_norm.weight", + ), + ) + + # Keyboard cross-attn + self.add_module( + "mouse_attn_q", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.action_model.mouse_attn_q.weight", + bias_name=None, + ), + ) + self.add_module( + "keyboard_attn_kv", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.action_model.keyboard_attn_kv.weight", + bias_name=None, + ), + ) + self.add_module( + "proj_keyboard", + MM_WEIGHT_REGISTER[mm_type]( + f"{block_prefix}.{block_index}.action_model.proj_keyboard.weight", + bias_name=None, + ), + ) + + # Keyboard attention RMS norms + self.add_module( + "key_attn_q_norm", + RMS_WEIGHT_REGISTER[attn_rms_norm_type]( + f"{block_prefix}.{block_index}.action_model.key_attn_q_norm.weight", + ), + ) + self.add_module( + "key_attn_k_norm", + RMS_WEIGHT_REGISTER[attn_rms_norm_type]( + f"{block_prefix}.{block_index}.action_model.key_attn_k_norm.weight", + ), + ) + + # Flash attention module for action cross-attn + self.add_module("action_attn", ATTN_WEIGHT_REGISTER[self.config.get("cross_attn_1_type", "flash_attn2")]()) diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index 094d868ac..84c7ed9a4 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -23,14 +23,27 @@ from lightx2v_platform.base.global_var import AI_DEVICE -DEFAULT_MATRIX_GAME3_OFFICIAL_ROOT = Path("/home/michael/Project/LightX2V/Matrix-Game-3/Matrix-Game-3") 
-DEFAULT_MATRIX_GAME3_BASE_CONFIG = Path("/home/michael/Project/LightX2V/Matrix-Game-3.0/base_model/config.json") -DEFAULT_MATRIX_GAME3_DISTILLED_CONFIG = Path("/home/michael/Project/LightX2V/Matrix-Game-3.0/base_distilled_model/config.json") +_PROJECT_ROOT = Path(__file__).resolve().parents[4] +_MATRIX_GAME3_OFFICIAL_ROOT_RELATIVE_CANDIDATES = ( + Path("Matrix-Game-3") / "Matrix-Game-3", + Path("Matrix-Game-3"), +) +_MATRIX_GAME3_CONFIG_ROOT_RELATIVE = Path("Matrix-Game-3.0") _MATRIX_GAME3_OFFICIAL_PACKAGE = "_lightx2v_matrix_game3_official" @dataclass class MatrixGame3SegmentState: + """Precomputed inputs and bookkeeping for one Matrix-Game-3 segment. + + The runner generates video in overlapping chunks. For each chunk we cache: + - the absolute frame window covered by this segment; + - the latent tensor shape the scheduler should sample; + - how many latent frames are fixed by conditioning instead of sampled; + - the condition tensors that will be forwarded through `dit_cond_dict`; + - how many decoded RGB frames should be trimmed before concatenation. 
+ """ + segment_idx: int first_clip: bool current_start_frame_idx: int @@ -47,6 +60,7 @@ class MatrixGame3SegmentState: def _load_module_from_path(module_name: str, file_path: Path): + """Import an official Matrix-Game-3 helper module by filesystem path once.""" if module_name in sys.modules: return sys.modules[module_name] spec = importlib.util.spec_from_file_location(module_name, file_path) @@ -59,6 +73,7 @@ def _load_module_from_path(module_name: str, file_path: Path): def _ensure_namespace_package(package_name: str, package_path: Path): + """Register a synthetic namespace package so relative imports inside official code work.""" if package_name in sys.modules: return sys.modules[package_name] module = types.ModuleType(package_name) @@ -67,6 +82,18 @@ def _ensure_namespace_package(package_name: str, package_path: Path): return module +def _expand_path_candidates(path_value: Any) -> list[Path]: + """Resolve a user-provided path against cwd and the project root when needed.""" + raw_path = Path(str(path_value)).expanduser() + if raw_path.is_absolute(): + return [raw_path] + candidates = [Path.cwd() / raw_path] + project_relative = _PROJECT_ROOT / raw_path + if project_relative != candidates[0]: + candidates.append(project_relative) + return candidates + + @RUNNER_REGISTER("wan2.2_matrix_game3") class WanMatrixGame3Runner(Wan22DenseRunner): """Runner-only Matrix-Game-3 adapter on top of the existing Wan2.2 lifecycle. @@ -78,10 +105,18 @@ class WanMatrixGame3Runner(Wan22DenseRunner): - Keyboard / mouse dimensions: utils/conditions.py - Pose / plucker helpers: utils/cam_utils.py and utils/utils.py - Structural config truth: Matrix-Game-3.0/*/config.json + + Execution model: + - Reuse Wan2.2 text encoder / scheduler / VAE lifecycle from `Wan22DenseRunner`. + - Replace the normal i2v input path with a first-frame-only conditioning scheme. + - Convert keyboard, mouse, and camera trajectories into per-segment DiT conditions. 
+ - Roll latent history across overlapping segments, then trim duplicated decoded frames. """ def __init__(self, config): with config.temporarily_unlocked(): + # The public pipeline still instantiates us as "wan2.2_matrix_game3", but + # the shared Wan2.2 runner expects `model_cls == "wan2.2"` for common setup. original_model_cls = str(config.get("model_cls", "wan2.2_matrix_game3")) config["runner_model_cls"] = original_model_cls config["model_cls"] = "wan2.2" @@ -96,13 +131,21 @@ def __init__(self, config): super().__init__(config) self.matrix_game3_model_cls = original_model_cls - self.first_clip_frame = 57 - self.clip_frame = 56 - self.incremental_segment_frames = 40 - self.past_frame = 16 - self.conditioning_latent_frames = 4 - self.mouse_dim_in = 2 - self.keyboard_dim_in = 6 + # Official MG3 timeline convention: + # - first segment predicts 57 frames from the input image; + # - later segments operate on a 56-frame window; + # - every new segment contributes 40 new frames and reuses 16 historical frames. + action_config = self.config.get("action_config", {}) + self.first_clip_frame = int(self.config.get("first_clip_frame", 57)) + self.clip_frame = int(self.config.get("clip_frame", 56)) + self.incremental_segment_frames = int(self.config.get("incremental_segment_frames", 40)) + self.past_frame = int(self.config.get("past_frame", 16)) + self.conditioning_latent_frames = int(self.config.get("conditioning_latent_frames", 4)) + self.mouse_dim_in = int(self.config.get("mouse_dim_in", action_config.get("mouse_dim_in", 2))) + self.keyboard_dim_in = int(self.config.get("keyboard_dim_in", action_config.get("keyboard_dim_in", 6))) + + # Session-scoped caches filled by `_prepare_matrix_game3_session()` and then + # consumed incrementally as each segment is initialized and decoded. 
self._segment_states: dict[int, MatrixGame3SegmentState] = {} self._official_modules: Optional[dict[str, Any]] = None self._mg3_lat_h: Optional[int] = None @@ -127,6 +170,8 @@ def __init__(self, config): def set_inputs(self, inputs): super().set_inputs(inputs) + # Some callers still use `pose`, others use `action_path`. Mirror both so the + # runner remains compatible with older LightX2V entry points. if "action_path" in self.input_info.__dataclass_fields__: self.input_info.action_path = inputs.get("action_path", inputs.get("pose", "")) if "pose" in self.input_info.__dataclass_fields__: @@ -135,6 +180,8 @@ def set_inputs(self, inputs): def load_transformer(self): from lightx2v.models.networks.wan.matrix_game3_model import WanMtxg3Model + # The backbone is still a Wan2.2 DiT, but Matrix-Game-3 swaps in a dedicated + # network wrapper that understands keyboard / mouse / camera conditions. model_kwargs = { "model_path": self.config["model_path"], "config": self.config, @@ -145,14 +192,87 @@ def load_transformer(self): return WanMtxg3Model(**model_kwargs) return build_wan_model_with_lora(WanMtxg3Model, self.config, model_kwargs, lora_configs, model_type="wan2.2") - def _load_matrix_game3_model_config(self): - config_path = Path(self.config["model_path"]) / self.config["sub_model_folder"] / "config.json" - if not config_path.exists(): - config_path = DEFAULT_MATRIX_GAME3_BASE_CONFIG if self.config["use_base_model"] else DEFAULT_MATRIX_GAME3_DISTILLED_CONFIG - if not config_path.exists(): - logger.warning("matrix-game-3 config.json not found at {}", config_path) - return + def _get_sub_model_folder(self) -> str: + """Resolve which MG3 sub-model folder should be used for config lookup.""" + return str(self.config.get("sub_model_folder", "base_model" if self.config.get("use_base_model", False) else "base_distilled_model")) + + def _resolve_official_root_candidate(self, candidate: Path) -> Optional[Path]: + """Accept either the inner package root or its parent repository 
directory.""" + direct_root = candidate.expanduser() + if (direct_root / "generate.py").is_file() and (direct_root / "pipeline").is_dir() and (direct_root / "utils").is_dir(): + return direct_root + + nested_root = direct_root / "Matrix-Game-3" + if (nested_root / "generate.py").is_file() and (nested_root / "pipeline").is_dir() and (nested_root / "utils").is_dir(): + return nested_root + return None + + def resolve_official_root(self) -> Path: + """Resolve the official Matrix-Game-3 source root using config-first priority.""" + configured_root = self.config.get("matrix_game3_official_root") + if configured_root: + for candidate in _expand_path_candidates(configured_root): + resolved = self._resolve_official_root_candidate(candidate) + if resolved is not None: + return resolved + raise FileNotFoundError( + "Matrix-Game-3 official source root is missing or invalid for " + f"matrix_game3_official_root={configured_root!r}. " + "The runner needs the official utils/pipeline files to build camera and action conditions. " + "Please set config['matrix_game3_official_root'] to the official source root directory." + ) + + for relative_path in _MATRIX_GAME3_OFFICIAL_ROOT_RELATIVE_CANDIDATES: + resolved = self._resolve_official_root_candidate(_PROJECT_ROOT / relative_path) + if resolved is not None: + return resolved + + raise FileNotFoundError( + "Matrix-Game-3 official source root could not be resolved from the project layout. " + "The runner needs it to import official utils/conditions.py, utils/cam_utils.py, utils/utils.py, " + "and pipeline helpers. Please set config['matrix_game3_official_root'] explicitly." 
+ ) + def resolve_model_config_path(self) -> Path: + """Resolve the MG3 base/distilled config.json with explicit override support.""" + configured_path = self.config.get("matrix_game3_config_path") + if configured_path: + for candidate in _expand_path_candidates(configured_path): + if candidate.is_file(): + return candidate + raise FileNotFoundError( + "Matrix-Game-3 config.json is missing for " + f"matrix_game3_config_path={configured_path!r}. " + "The runner needs this file to align latent channels, patch size, and action_config with the checkpoint. " + "Please set config['matrix_game3_config_path'] to a valid config.json path." + ) + + sub_model_folder = self._get_sub_model_folder() + candidates: list[Path] = [] + model_path = self.config.get("model_path") + if model_path: + for candidate_root in _expand_path_candidates(model_path): + candidate = candidate_root / sub_model_folder / "config.json" + if candidate not in candidates: + candidates.append(candidate) + candidates.append(_PROJECT_ROOT / _MATRIX_GAME3_CONFIG_ROOT_RELATIVE / sub_model_folder / "config.json") + + for candidate in candidates: + if candidate.is_file(): + return candidate + + checked_locations = ", ".join(str(candidate) for candidate in candidates) + raise FileNotFoundError( + "Matrix-Game-3 sub-model config.json could not be resolved. " + f"Checked: {checked_locations}. " + "The runner needs this file to determine the official base/distilled structure. " + "Please set config['matrix_game3_config_path'], or provide a valid config['model_path'] and " + "config['sub_model_folder'] (defaulted from config['use_base_model'])." 
+ ) + + def _load_matrix_game3_model_config(self): + """Merge the official MG3 config so latent/channel sizes match the checkpoint.""" + config_path = self.resolve_model_config_path() with config_path.open("r") as f: model_config = json.load(f) @@ -163,31 +283,61 @@ def _load_matrix_game3_model_config(self): self.config["patch_size"] = tuple(model_config.get("patch_size", self.config.get("patch_size", (1, 2, 2)))) action_config = self.config.get("action_config", {}) - self.keyboard_dim_in = int(action_config.get("keyboard_dim_in", 6)) - self.mouse_dim_in = int(action_config.get("mouse_dim_in", 2)) + self.keyboard_dim_in = int(self.config.get("keyboard_dim_in", action_config.get("keyboard_dim_in", 6))) + self.mouse_dim_in = int(self.config.get("mouse_dim_in", action_config.get("mouse_dim_in", 2))) def _get_official_modules(self) -> dict[str, Any]: + """Lazy-load helper code from the official Matrix-Game-3 repository. + + We intentionally reuse the official camera/action utilities instead of + re-implementing pose math in the LightX2V runner. + """ if self._official_modules is not None: return self._official_modules - official_root = Path(self.config.get("matrix_game3_official_root", DEFAULT_MATRIX_GAME3_OFFICIAL_ROOT)) - if not official_root.exists(): - raise FileNotFoundError(f"Matrix-Game-3 official root not found: {official_root}") + official_root = self.resolve_official_root() + utils_root = official_root / "utils" + if not utils_root.is_dir(): + raise FileNotFoundError( + f"Matrix-Game-3 utils directory is missing under {official_root}. " + "The runner needs the official utils modules to construct action and camera conditions. " + "Please set config['matrix_game3_official_root'] to the official source root directory." 
+ ) + + required_utils = { + "conditions": utils_root / "conditions.py", + "cam_utils": utils_root / "cam_utils.py", + "transform": utils_root / "transform.py", + "utils": utils_root / "utils.py", + } + missing_utils = [str(path) for path in required_utils.values() if not path.is_file()] + if missing_utils: + raise FileNotFoundError( + "Matrix-Game-3 official utility files are incomplete. " + f"Missing: {missing_utils}. " + "The runner needs these files to reuse the official action/camera preprocessing. " + "Please set config['matrix_game3_official_root'] to a complete official source checkout." + ) _ensure_namespace_package(_MATRIX_GAME3_OFFICIAL_PACKAGE, official_root) utils_pkg = f"{_MATRIX_GAME3_OFFICIAL_PACKAGE}.utils" - _ensure_namespace_package(utils_pkg, official_root / "utils") + _ensure_namespace_package(utils_pkg, utils_root) modules = { - "conditions": _load_module_from_path(f"{utils_pkg}.conditions", official_root / "utils" / "conditions.py"), - "cam_utils": _load_module_from_path(f"{utils_pkg}.cam_utils", official_root / "utils" / "cam_utils.py"), - "transform": _load_module_from_path(f"{utils_pkg}.transform", official_root / "utils" / "transform.py"), - "utils": _load_module_from_path(f"{utils_pkg}.utils", official_root / "utils" / "utils.py"), + "conditions": _load_module_from_path(f"{utils_pkg}.conditions", required_utils["conditions"]), + "cam_utils": _load_module_from_path(f"{utils_pkg}.cam_utils", required_utils["cam_utils"]), + "transform": _load_module_from_path(f"{utils_pkg}.transform", required_utils["transform"]), + "utils": _load_module_from_path(f"{utils_pkg}.utils", required_utils["utils"]), } self._official_modules = modules return modules def _get_expected_total_frames(self, raw_total_frames: Optional[int] = None) -> tuple[int, int]: + """Resolve how many segments to run. + + Matrix-Game-3 only supports lengths of `57 + 40 * k`. If a control sequence + does not align exactly, the tail is ignored so the segment schedule stays valid. 
+ """ num_iterations = self.config.get("num_iterations", None) if num_iterations is not None: num_iterations = max(int(num_iterations), 1) @@ -211,6 +361,7 @@ def _get_expected_total_frames(self, raw_total_frames: Optional[int] = None) -> return num_iterations, expected_total_frames def _segment_latent_shape(self, lat_h: int, lat_w: int, frame_count: int) -> list[int]: + """Compute `[C, T, H, W]` latent shape for one segment window.""" return [ self.config.get("num_channels_latents", 48), (frame_count - 1) // self.config["vae_stride"][0] + 1, @@ -225,6 +376,8 @@ def _segment_latent_shape(self, lat_h: int, lat_w: int, frame_count: int) -> lis metrics_labels=["WanMatrixGame3Runner"], ) def run_vae_encoder(self, img): + # Unlike the generic Wan2.2 i2v path, MG3 only encodes the first frame. The + # remaining temporal slots are left zeroed and later mixed with scheduler noise. max_area = self.config.target_height * self.config.target_width ih, iw = img.height, img.width dh = self.config.patch_size[1] * self.config.vae_stride[1] @@ -248,6 +401,10 @@ def run_vae_encoder(self, img): @ProfilingContext4DebugL2("Run Encoders") def _run_input_encoder_local_i2v(self): + # MG3 does not use the CLIP image encoder branch. The conditioning payload is: + # - text encoder output from the normal Wan pipeline; + # - a first-frame VAE latent; + # - segment metadata prepared for later `init_run_segment()` calls. _, img_ori = self.read_image_input(self.input_info.image_path) vae_encoder_out, latent_shape = self.run_vae_encoder(img_ori) self.input_info.latent_shape = latent_shape @@ -257,6 +414,8 @@ def _run_input_encoder_local_i2v(self): return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output) def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img=None): + # Keep the standard LightX2V output contract so downstream scheduler / model + # code can stay unchanged. Segment-specific conditions are injected later. 
image_encoder_output = { "clip_encoder_out": clip_encoder_out, "vae_encoder_out": vae_encoder_out, @@ -272,6 +431,12 @@ def _prepare_matrix_game3_session(self, pil_image: Image.Image, latent_shape: li # - Non-interactive path mirrors pipeline/inference_pipeline.py # - Interactive segment refreshing mirrors pipeline/inference_interactive_pipeline.py # - Camera/action fallback semantics follow the user's requested runner contract + # + # This method performs all once-per-request setup: + # - resolve spatial sizes used by camera/plucker helpers; + # - reset cached segment state and latent history; + # - pre-load the entire control sequence for offline mode; or + # - defer control acquisition to segment time for interactive mode. self._get_official_modules() self._segment_states.clear() self._mg3_generated_latent_history = [] @@ -303,6 +468,7 @@ def _prepare_matrix_game3_session(self, pil_image: Image.Image, latent_shape: li self._mg3_keyboard_all, self._mg3_mouse_all, self._mg3_extrinsics_all, self._mg3_intrinsics_all = self._build_noninteractive_controls(raw_controls) def _infer_raw_total_frames(self, payload: dict[str, Any]) -> Optional[int]: + """Infer sequence length from whichever control tensor is present.""" lengths = [] for value in payload.values(): if value is None: @@ -318,6 +484,7 @@ def _infer_raw_total_frames(self, payload: dict[str, Any]) -> Optional[int]: return max(lengths) if lengths else None def _load_control_payload(self, action_path: str) -> dict[str, Any]: + """Load keyboard/mouse/pose/intrinsics controls from a file or a directory.""" if not action_path: logger.warning("[matrix-game-3] action_path missing, fallback to zero keyboard/mouse and identity poses.") return {} @@ -332,6 +499,7 @@ def _load_control_payload(self, action_path: str) -> dict[str, Any]: return self._load_control_payload_from_file(path) def _load_control_payload_from_dir(self, path: Path) -> dict[str, Any]: + """Best-effort directory loader that accepts several common file 
names.""" payload: dict[str, Any] = {} name_groups = { "keyboard_cond": ["keyboard_cond.npy", "keyboard_condition.npy", "keyboard_cond.pt", "keyboard_condition.pt", "keyboard_cond.json", "keyboard_condition.json"], @@ -349,6 +517,7 @@ def _load_control_payload_from_dir(self, path: Path) -> dict[str, Any]: return payload def _load_control_payload_from_file(self, path: Path) -> dict[str, Any]: + """Parse one control file and map it onto the normalized payload schema.""" suffix = path.suffix.lower() stem = path.stem.lower() if suffix == ".npz": @@ -378,6 +547,7 @@ def _load_control_payload_from_file(self, path: Path) -> dict[str, Any]: raise ValueError(f"unsupported action_path file name: {path}") def _normalize_payload_keys(self, data: dict[str, Any]) -> dict[str, Any]: + """Collapse different naming conventions into the runner's canonical keys.""" payload: dict[str, Any] = {} key_aliases = { "keyboard_cond": {"keyboard_cond", "keyboard_condition"}, @@ -393,11 +563,13 @@ def _normalize_payload_keys(self, data: dict[str, Any]) -> dict[str, Any]: return payload def _default_intrinsics(self) -> torch.Tensor: + """Generate the default camera intrinsics for the current output resolution.""" modules = self._get_official_modules() assert self._mg3_target_h is not None and self._mg3_target_w is not None return modules["cam_utils"].get_intrinsics(self._mg3_target_h, self._mg3_target_w) def _to_tensor(self, value: Any, dtype=torch.float32) -> Optional[torch.Tensor]: + """Convert numpy/list/scalar inputs into CPU tensors for normalization.""" if value is None: return None if isinstance(value, torch.Tensor): @@ -409,6 +581,9 @@ def _to_tensor(self, value: Any, dtype=torch.float32) -> Optional[torch.Tensor]: return torch.tensor(value, dtype=dtype) def _resize_time_axis(self, tensor: torch.Tensor, total_frames: int) -> torch.Tensor: + # MG3 expects exact per-frame control lengths. 
To make the runner tolerant of + # slightly malformed inputs, short sequences are padded by repeating the last + # value and long sequences are truncated. if tensor.shape[0] == total_frames: return tensor if tensor.shape[0] == 1: @@ -429,6 +604,7 @@ def _resize_time_axis(self, tensor: torch.Tensor, total_frames: int) -> torch.Te return tensor[:total_frames] def _normalize_keyboard_cond(self, value: Any, total_frames: int) -> torch.Tensor: + """Normalize keyboard controls into `[1, T, keyboard_dim_in]`.""" if value is None: return torch.zeros((1, total_frames, self.keyboard_dim_in), dtype=torch.float32) tensor = self._to_tensor(value) @@ -442,6 +618,7 @@ def _normalize_keyboard_cond(self, value: Any, total_frames: int) -> torch.Tenso return tensor.unsqueeze(0) def _normalize_mouse_cond(self, value: Any, total_frames: int) -> torch.Tensor: + """Normalize mouse controls into `[1, T, mouse_dim_in]`.""" if value is None: return torch.zeros((1, total_frames, self.mouse_dim_in), dtype=torch.float32) tensor = self._to_tensor(value) @@ -455,6 +632,7 @@ def _normalize_mouse_cond(self, value: Any, total_frames: int) -> torch.Tensor: return tensor.unsqueeze(0) def _normalize_intrinsics(self, value: Any, total_frames: int) -> Optional[torch.Tensor]: + """Accept either flattened `[fx, fy, cx, cy]` or 3x3 intrinsics matrices.""" if value is None: return None tensor = self._to_tensor(value) @@ -470,10 +648,13 @@ def _normalize_intrinsics(self, value: Any, total_frames: int) -> Optional[torch return self._resize_time_axis(tensor, total_frames) def _normalize_poses(self, value: Any, total_frames: int) -> Optional[torch.Tensor]: + """Normalize poses into `[T, 4, 4]` camera-to-world extrinsics.""" if value is None: return None tensor = self._to_tensor(value) if tensor.ndim == 2 and tensor.shape[-1] == 5: + # The official action pipeline also uses a compact 5D pose + # `[x, y, z, pitch, yaw]`. Convert it here to full extrinsics. 
modules = self._get_official_modules() rotations = np.concatenate([np.zeros((tensor.shape[0], 1), dtype=np.float32), tensor[:, 3:5].numpy()], axis=1).tolist() positions = tensor[:, :3].numpy().tolist() @@ -487,6 +668,9 @@ def _build_noninteractive_controls(self, payload: dict[str, Any]) -> tuple[torch # Official source: # - utils/conditions.py defines keyboard_dim_in=6 and mouse_dim_in=2 semantics # - utils/utils.py computes poses from actions when explicit poses are absent + # + # Offline mode materializes the whole control trajectory up front so later + # segments only need cheap slicing instead of re-reading user inputs. total_frames = self._mg3_expected_total_frames keyboard_cond = self._normalize_keyboard_cond(payload.get("keyboard_cond"), total_frames) mouse_cond = self._normalize_mouse_cond(payload.get("mouse_cond"), total_frames) @@ -496,9 +680,12 @@ def _build_noninteractive_controls(self, payload: dict[str, Any]) -> tuple[torch if poses is None: modules = self._get_official_modules() if not payload: + # No action file at all: keep the camera fixed at identity. identity_pose = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat(total_frames, 1, 1) poses = identity_pose else: + # Action file exists but explicit poses do not: reconstruct camera motion + # with the official action-to-pose integrator. first_pose = np.zeros(5, dtype=np.float32) all_poses = modules["utils"].compute_all_poses_from_actions( keyboard_cond.squeeze(0).cpu(), @@ -514,6 +701,8 @@ def get_video_segment_num(self): self.video_segment_num = self._mg3_num_iterations def init_run(self): + # This mostly mirrors `DefaultRunner.init_run()`, but we immediately override + # the scheduler state with the first segment's custom latent/mask setup. 
self.gen_video_final = None self.get_video_segment_num() self._mg3_noise_generator = torch.Generator(device=AI_DEVICE).manual_seed(self.input_info.seed) @@ -531,12 +720,15 @@ def init_run(self): self.inputs["image_encoder_output"]["vae_encoder_out"] = None def _append_interactive_segment_controls(self, segment_idx: int): + """Collect one segment worth of controls from stdin in interactive mode.""" modules = self._get_official_modules() first_clip = segment_idx == 0 action_frames = self.first_clip_frame if first_clip else self.incremental_segment_frames if not dist.is_initialized() or dist.get_rank() == 0: actions = self._prompt_current_action() + # The prompt returns one action token; MG3 applies it uniformly across the + # newly generated frame span for that segment. keyboard_curr = actions["keyboard"].repeat(action_frames, 1) mouse_curr = actions["mouse"].repeat(action_frames, 1) if first_clip: @@ -574,11 +766,13 @@ def _append_interactive_segment_controls(self, segment_idx: int): self._mg3_mouse_all = mouse_curr self._mg3_extrinsics_all = extrinsics_curr else: + # Interactive mode grows the global control timeline as segments progress. 
self._mg3_keyboard_all = torch.cat([self._mg3_keyboard_all, keyboard_curr], dim=1) self._mg3_mouse_all = torch.cat([self._mg3_mouse_all, mouse_curr], dim=1) self._mg3_extrinsics_all = torch.cat([self._mg3_extrinsics_all, extrinsics_curr], dim=0) def _prompt_current_action(self) -> dict[str, torch.Tensor]: + """Minimal CLI UX for interactive MG3 generation.""" cam_value = 0.1 print() print("-" * 30) @@ -612,6 +806,7 @@ def _prompt_current_action(self) -> dict[str, torch.Tensor]: } def _interpolate_intrinsics(self, intrinsics_seq: Optional[torch.Tensor], src_indices: np.ndarray, tgt_indices: np.ndarray) -> torch.Tensor: + """Interpolate intrinsics onto the latent timeline used by the DiT.""" assert self._mg3_base_intrinsics is not None if intrinsics_seq is None: return self._mg3_base_intrinsics.to(dtype=torch.float32).repeat(len(tgt_indices), 1) @@ -641,6 +836,9 @@ def _build_plucker_from_c2ws( # Official source: # - utils/cam_utils.py: interpolate poses, compute relative poses, build plucker rays # - utils/utils.py: build_plucker_from_c2ws reshaping convention + # + # The model consumes camera control as plucker ray embeddings aligned to latent + # time and latent spatial resolution, not as raw pose matrices. modules = self._get_official_modules() assert self._mg3_target_h is not None and self._mg3_target_w is not None assert self._mg3_lat_h is not None and self._mg3_lat_w is not None @@ -651,6 +849,8 @@ def _build_plucker_from_c2ws( src_trans_vec=c2ws_np[:, :3, 3], tgt_indices=tgt_indices, ).to(device=c2ws_seq.device) + # `framewise=True` means each timestep is represented relative to its own local + # frame history, which matches the official per-segment conditioning path. 
c2ws_infer = modules["cam_utils"].compute_relative_poses(c2ws_infer, framewise=framewise) Ks = self._interpolate_intrinsics(intrinsics_seq, src_indices, tgt_indices).to(device=c2ws_infer.device, dtype=c2ws_infer.dtype) plucker = modules["cam_utils"].get_plucker_embeddings(c2ws_infer, Ks, self._mg3_target_h, self._mg3_target_w) @@ -672,6 +872,7 @@ def _build_plucker_from_c2ws( ) def _build_plucker_from_pose(self, c2ws_pose: torch.Tensor, intrinsics_seq: Optional[torch.Tensor] = None) -> torch.Tensor: + """Build plucker embeddings when poses are already on the target timeline.""" modules = self._get_official_modules() assert self._mg3_target_h is not None and self._mg3_target_w is not None assert self._mg3_lat_h is not None and self._mg3_lat_w is not None @@ -701,6 +902,10 @@ def _build_memory_metadata(self, segment_idx: int, current_start_frame_idx: int, # Official source: pipeline/inference_pipeline.py and utils/cam_utils.py. # Current downstream model code only requires c2ws_plucker_emb / keyboard_cond / mouse_cond, # but we still stage the memory-facing metadata here so the runner owns segment bookkeeping. + # + # Later segments can attend to a sparse set of previously generated latent + # frames. This method selects those frames, prepares their latent indices, and + # builds the matching plucker embeddings for the memory branch. if segment_idx == 0 or not self._mg3_generated_latent_history: return { "x_memory": None, @@ -729,6 +934,7 @@ def get_latent_idx(frame_idx: int) -> int: use_gpu=torch.cuda.is_available(), ) if selected_index: + # The official code hard-pins the oldest memory anchor to frame 4. 
selected_index[-1] = 4 memory_pluckers = [] @@ -780,6 +986,7 @@ def get_latent_idx(frame_idx: int) -> int: } def _build_or_get_segment_camera_only(self, segment_idx: int) -> torch.Tensor: + """Access just the camera plucker embedding without rebuilding other state.""" state = self._segment_states.get(segment_idx) if state is not None and "c2ws_plucker_emb" in state.dit_cond_dict: return state.dit_cond_dict["c2ws_plucker_emb"] @@ -787,6 +994,14 @@ def _build_or_get_segment_camera_only(self, segment_idx: int) -> torch.Tensor: return state.dit_cond_dict["c2ws_plucker_emb"] def _build_or_get_segment_state(self, segment_idx: int) -> MatrixGame3SegmentState: + """Materialize one segment's complete conditioning package. + + This is the core of the adapter. It decides: + - which absolute frames this segment covers; + - which latent timesteps are fixed from prior context; + - which camera/action conditions should be sliced for this window; + - which overlap should be trimmed after decoding. + """ if segment_idx in self._segment_states: return self._segment_states[segment_idx] @@ -807,6 +1022,8 @@ def get_latent_idx(frame_idx: int) -> int: latent_start_idx = get_latent_idx(current_start_frame_idx) latent_end_idx = get_latent_idx(current_end_frame_idx) fixed_latent_frames = 1 if first_clip else self.conditioning_latent_frames + # After decoding, the first RGB frames of every later segment correspond to + # history that was already emitted by the previous segment, so they are dropped. 
decode_trim_frames = 0 if first_clip else 1 + self.config["vae_stride"][0] * (fixed_latent_frames - 1) append_latent_start = 0 if first_clip else fixed_latent_frames @@ -818,6 +1035,9 @@ def get_latent_idx(frame_idx: int) -> int: intrinsics_chunk = self._mg3_intrinsics_all[current_start_frame_idx:current_end_frame_idx] latent_shape = self._segment_latent_shape(self._mg3_lat_h, self._mg3_lat_w, frame_count) + # The latent timeline is coarser than RGB time because Wan2.2 uses a temporal + # VAE stride of 4. Later segments start interpolation at `start + 3` so the + # first 4 latent slots line up with the carried-over conditioning tail. tgt_indices = np.linspace(0 if first_clip else current_start_frame_idx + 3, current_end_frame_idx - 1, latent_shape[1]) camera_only = self._build_plucker_from_c2ws( @@ -833,10 +1053,13 @@ def get_latent_idx(frame_idx: int) -> int: vae_encoder_out = torch.zeros(latent_shape, device=AI_DEVICE, dtype=GET_DTYPE()) if first_clip: + # Segment 0 is anchored by the input image latent in the first temporal slot. vae_encoder_out[:, :1] = self.inputs["image_encoder_output"]["vae_encoder_out"][:, :1] else: if self._mg3_tail_latents is None: raise RuntimeError("matrix-game-3 segment requested without previous tail latents") + # Later segments are conditioned on the last 4 latent frames produced by the + # previous segment, which creates temporal continuity across chunk boundaries. 
vae_encoder_out[:, : self.conditioning_latent_frames] = self._mg3_tail_latents.to(device=AI_DEVICE, dtype=GET_DTYPE()) # Fields below intentionally stay in the standard LightX2V image_encoder_output["dit_cond_dict"] @@ -871,6 +1094,7 @@ def get_latent_idx(frame_idx: int) -> int: return state def _apply_segment_scheduler_state(self, segment_state: MatrixGame3SegmentState): + """Seed the scheduler latents and mask for the current segment window.""" scheduler = self.model.scheduler latents = torch.randn( tuple(segment_state.latent_shape), @@ -880,6 +1104,8 @@ def _apply_segment_scheduler_state(self, segment_state: MatrixGame3SegmentState) ) scheduler.vae_encoder_out = segment_state.vae_encoder_out.to(device=AI_DEVICE, dtype=torch.float32) scheduler.mask = torch.ones_like(latents) + # Mask value 0 means "keep the provided latent conditioning", while 1 means + # "sample this slot from noise through the diffusion process". scheduler.mask[:, : segment_state.fixed_latent_frames] = 0 scheduler.latents = (1.0 - scheduler.mask) * scheduler.vae_encoder_out + scheduler.mask * latents @@ -893,6 +1119,10 @@ def init_run_segment(self, segment_idx): # Official source: pipeline/inference_pipeline.py and inference_interactive_pipeline.py # refresh per-segment action / camera / latent-conditioning state here so the outer lifecycle # remains the standard LightX2V segment loop. + # + # The base runner calls this before every segment. We use that hook to swap in + # the next segment's control tensors and, for later segments, reset the scheduler + # so it samples against the new latent shape and conditioning mask. 
self.segment_idx = segment_idx segment_state = self._build_or_get_segment_state(segment_idx) self._mg3_current_segment_state = segment_state @@ -904,22 +1134,29 @@ def init_run_segment(self, segment_idx): self._apply_segment_scheduler_state(segment_state) def run_segment(self, segment_idx=0): + # Save the raw latent output before the VAE decoder trims or converts anything; + # the next segment needs these latents for temporal conditioning and memory. latents = super().run_segment(segment_idx) self._mg3_current_segment_full_latents = latents.detach().clone() return latents def end_run_segment(self, segment_idx=None): + """Carry segment outputs forward and remove overlap from decoded frames.""" if self._mg3_current_segment_state is None or self._mg3_current_segment_full_latents is None: raise RuntimeError("matrix-game-3 end_run_segment called before the current segment state was prepared") full_latents = self._mg3_current_segment_full_latents # full_latents follows Wan2.2 runner convention: [C, T, H, W]. + # Keep only the tail that should condition the next segment. self._mg3_tail_latents = full_latents[:, -self.conditioning_latent_frames :].detach().clone() + # Only append genuinely new latent timesteps to history; the carried-over prefix + # belongs to the previous segment and would otherwise duplicate memory entries. new_latents = full_latents[:, self._mg3_current_segment_state.append_latent_start :].detach().clone() self._mg3_generated_latent_history.append(new_latents) segment_video = self.gen_video if self._mg3_current_segment_state.decode_trim_frames > 0: + # Remove RGB frames that correspond to the reused latent prefix. 
segment_video = segment_video[:, :, self._mg3_current_segment_state.decode_trim_frames :] self.gen_video = segment_video self.gen_video_final = segment_video if self.gen_video_final is None else torch.cat([self.gen_video_final, segment_video], dim=2) @@ -927,6 +1164,8 @@ def end_run_segment(self, segment_idx=None): self._mg3_current_segment_full_latents = None def process_images_after_vae_decoder(self): + # `DefaultRunner.process_images_after_vae_decoder()` expects `gen_video_final` + # to already contain the full stitched clip. if self.gen_video_final is None: self.gen_video_final = self.gen_video return super().process_images_after_vae_decoder() diff --git a/scripts/matrix_game3/run_matrix_game3_base.sh b/scripts/matrix_game3/run_matrix_game3_base.sh new file mode 100755 index 000000000..e49810fab --- /dev/null +++ b/scripts/matrix_game3/run_matrix_game3_base.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# Run Matrix-Game-3.0 base model inference via LightX2V +# Usage: ./run_matrix_game3_base.sh + +# Set model path (update this to your local Matrix-Game-3.0 model directory) +MODEL_PATH="${MODEL_PATH:-/path/to/Matrix-Game-3.0}" +CONFIG_JSON="configs/matrix_game3/matrix_game3_base.json" +SAVE_PATH="${SAVE_PATH:-save_results/matrix_game3_base}" + +python -m lightx2v.infer \ + --model_cls wan2.2_matrix_game3 \ + --task i2v \ + --model_path "${MODEL_PATH}" \ + --config_json "${CONFIG_JSON}" \ + --prompt "a city street scene with cars and pedestrians" \ + --image_path "${IMAGE_PATH:-Matrix-Game-3/Matrix-Game-3/demo_images/001/image.png}" \ + --action_path "${ACTION_PATH:-}" \ + --save_result_path "${SAVE_PATH}" \ + --seed 42 diff --git a/scripts/matrix_game3/run_matrix_game3_distilled.sh b/scripts/matrix_game3/run_matrix_game3_distilled.sh new file mode 100755 index 000000000..9cb8e078a --- /dev/null +++ b/scripts/matrix_game3/run_matrix_game3_distilled.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# Run Matrix-Game-3.0 distilled model inference via LightX2V +# Usage: 
./run_matrix_game3_distilled.sh + +# Set model path (update this to your local Matrix-Game-3.0 model directory) +MODEL_PATH="${MODEL_PATH:-/path/to/Matrix-Game-3.0}" +CONFIG_JSON="configs/matrix_game3/matrix_game3_distilled.json" +SAVE_PATH="${SAVE_PATH:-save_results/matrix_game3_distilled}" + +python -m lightx2v.infer \ + --model_cls wan2.2_matrix_game3 \ + --task i2v \ + --model_path "${MODEL_PATH}" \ + --config_json "${CONFIG_JSON}" \ + --prompt "a city street scene with cars and pedestrians" \ + --image_path "${IMAGE_PATH:-Matrix-Game-3/Matrix-Game-3/demo_images/001/image.png}" \ + --action_path "${ACTION_PATH:-}" \ + --save_result_path "${SAVE_PATH}" \ + --seed 42 From 0e9d244b0370c81fa0b095700ea32bb06de24c16 Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 14:51:05 +0800 Subject: [PATCH 03/25] Add the fix --- lightx2v/pipeline.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py index b6b48096c..75658fa6d 100755 --- a/lightx2v/pipeline.py +++ b/lightx2v/pipeline.py @@ -10,7 +10,10 @@ import torch.distributed as dist from loguru import logger -from lightx2v.models.runners.flux2_klein.flux2_klein_runner import Flux2KleinRunner # noqa: F401 +try: + from lightx2v.models.runners.flux2_klein.flux2_klein_runner import Flux2KleinRunner # noqa: F401 +except (ImportError, ModuleNotFoundError) as e: + logger.warning(f"Flux2KleinRunner not available: {e}") from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401 from lightx2v.models.runners.longcat_image.longcat_image_runner import LongCatImageRunner # noqa: F401 from lightx2v.models.runners.ltx2.ltx2_runner import LTX2Runner # noqa: F401 From a70d8f9b1970cdfcc6ce010188aa62e987c4e0fd Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 15:04:10 +0800 Subject: [PATCH 04/25] Modify the runner --- lightx2v/models/runners/wan/wan_matrix_game3_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 
deletion(-) diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index 84c7ed9a4..b2be0bc38 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -16,12 +16,14 @@ from lightx2v.models.runners.wan.wan_runner import Wan22DenseRunner, build_wan_model_with_lora from lightx2v.server.metrics import monitor_cli -from lightx2v.utils.envs import GET_DTYPE, torch_device_module +from lightx2v.utils.envs import GET_DTYPE from lightx2v.utils.profiler import GET_RECORDER_MODE, ProfilingContext4DebugL1, ProfilingContext4DebugL2 from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.utils import best_output_size from lightx2v_platform.base.global_var import AI_DEVICE +torch_device_module = getattr(torch, AI_DEVICE) + _PROJECT_ROOT = Path(__file__).resolve().parents[4] _MATRIX_GAME3_OFFICIAL_ROOT_RELATIVE_CANDIDATES = ( From 6c0c891e74e64a0065624861e076bdd8aceb2fcc Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 15:07:30 +0800 Subject: [PATCH 05/25] Add the fix --- lightx2v/infer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lightx2v/infer.py b/lightx2v/infer.py index f2029a8aa..eeb4fbf33 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -7,7 +7,10 @@ from lightx2v.common.ops import * from lightx2v.models.runners.bagel.bagel_runner import BagelRunner # noqa: F401 -from lightx2v.models.runners.flux2_klein.flux2_klein_runner import Flux2KleinRunner # noqa: F401 +try: + from lightx2v.models.runners.flux2_klein.flux2_klein_runner import Flux2KleinRunner # noqa: F401 +except (ImportError, ModuleNotFoundError): + pass # Already warned in pipeline.py from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner # noqa: F401 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: 
F401 from lightx2v.models.runners.longcat_image.longcat_image_runner import LongCatImageRunner # noqa: F401 From 9e20a4e586ec6c64ec9eacef1afbf6c4dc670b05 Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 15:21:42 +0800 Subject: [PATCH 06/25] modify the json --- configs/matrix_game3/matrix_game3_base.json | 1 - configs/matrix_game3/matrix_game3_distilled.json | 1 - 2 files changed, 2 deletions(-) diff --git a/configs/matrix_game3/matrix_game3_base.json b/configs/matrix_game3/matrix_game3_base.json index 6c6d70b25..ce90d0da5 100644 --- a/configs/matrix_game3/matrix_game3_base.json +++ b/configs/matrix_game3/matrix_game3_base.json @@ -1,7 +1,6 @@ { "model_cls": "wan2.2_matrix_game3", "task": "i2v", - "model_path": "", "sub_model_folder": "base_model", "use_base_model": true, diff --git a/configs/matrix_game3/matrix_game3_distilled.json b/configs/matrix_game3/matrix_game3_distilled.json index 011fcc0a2..e2d179f88 100644 --- a/configs/matrix_game3/matrix_game3_distilled.json +++ b/configs/matrix_game3/matrix_game3_distilled.json @@ -1,7 +1,6 @@ { "model_cls": "wan2.2_matrix_game3", "task": "i2v", - "model_path": "", "sub_model_folder": "base_distilled_model", "use_base_model": false, From 090746ec45af615e00888aa80474856cf8f8b6df Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 15:34:26 +0800 Subject: [PATCH 07/25] Fix the infer_step --- configs/matrix_game3/matrix_game3_base.json | 2 +- configs/matrix_game3/matrix_game3_distilled.json | 2 +- lightx2v/utils/set_config.py | 5 +++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/configs/matrix_game3/matrix_game3_base.json b/configs/matrix_game3/matrix_game3_base.json index ce90d0da5..5c8b0cc81 100644 --- a/configs/matrix_game3/matrix_game3_base.json +++ b/configs/matrix_game3/matrix_game3_base.json @@ -11,7 +11,7 @@ "patch_size": [1, 2, 2], "num_channels_latents": 48, - "num_inference_steps": 50, + "infer_steps": 50, "sample_shift": 5.0, "sample_guide_scale": 5.0, "enable_cfg": true, diff 
--git a/configs/matrix_game3/matrix_game3_distilled.json b/configs/matrix_game3/matrix_game3_distilled.json index e2d179f88..720f488e5 100644 --- a/configs/matrix_game3/matrix_game3_distilled.json +++ b/configs/matrix_game3/matrix_game3_distilled.json @@ -11,7 +11,7 @@ "patch_size": [1, 2, 2], "num_channels_latents": 48, - "num_inference_steps": 3, + "infer_steps": 3, "sample_shift": 5.0, "sample_guide_scale": 1.0, "enable_cfg": false, diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py index 3017c9f67..3a6601078 100755 --- a/lightx2v/utils/set_config.py +++ b/lightx2v/utils/set_config.py @@ -115,6 +115,11 @@ def auto_calc_config(config): model_config = json.load(f) config.update(model_config) + # Some upstream/official configs use `num_inference_steps`, while the shared + # LightX2V scheduler stack expects `infer_steps`. + if "infer_steps" not in config and "num_inference_steps" in config: + config["infer_steps"] = config["num_inference_steps"] + if config["task"] in ["i2v", "s2v", "rs2v"]: if config["target_video_length"] % config["vae_stride"][0] != 1: logger.warning(f"`num_frames - 1` has to be divisible by {config['vae_stride'][0]}. 
Rounding to the nearest number.") From 8b4375a60b654dd6544643e3d08453e3641d7adc Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 15:38:46 +0800 Subject: [PATCH 08/25] Fix the runner --- lightx2v/models/runners/wan/wan_matrix_game3_runner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index b2be0bc38..7c461ab13 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -130,6 +130,12 @@ def __init__(self, config): config["num_channels_latents"] = int(config.get("num_channels_latents", 48)) config["vae_stride"] = tuple(config.get("vae_stride", (4, 16, 16))) config["patch_size"] = tuple(config.get("patch_size", (1, 2, 2))) + # Load the official MG3 sub-model config before the parent runner + # constructs the scheduler. The shared Wan scheduler expects fields + # like `dim` and `num_heads` to already exist in `self.config`. 
+ self.config = config + self.matrix_game3_model_cls = original_model_cls + self._load_matrix_game3_model_config() super().__init__(config) self.matrix_game3_model_cls = original_model_cls @@ -168,7 +174,6 @@ def __init__(self, config): self._mg3_generated_latent_history: list[torch.Tensor] = [] self._mg3_tail_latents: Optional[torch.Tensor] = None self._mg3_noise_generator: Optional[torch.Generator] = None - self._load_matrix_game3_model_config() def set_inputs(self, inputs): super().set_inputs(inputs) From 92a2f487c90665b3e57721a41216cdccb77a83d3 Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 16:00:42 +0800 Subject: [PATCH 09/25] add the navigation --- .../runners/wan/wan_matrix_game3_runner.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index 7c461ab13..79e3050d8 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -96,6 +96,11 @@ def _expand_path_candidates(path_value: Any) -> list[Path]: return candidates +def _append_unique_path(paths: list[Path], candidate: Path): + if candidate not in paths: + paths.append(candidate) + + @RUNNER_REGISTER("wan2.2_matrix_game3") class WanMatrixGame3Runner(Wan22DenseRunner): """Runner-only Matrix-Game-3 adapter on top of the existing Wan2.2 lifecycle. @@ -229,8 +234,37 @@ def resolve_official_root(self) -> Path: "Please set config['matrix_game3_official_root'] to the official source root directory." 
) + auto_candidates: list[Path] = [] for relative_path in _MATRIX_GAME3_OFFICIAL_ROOT_RELATIVE_CANDIDATES: - resolved = self._resolve_official_root_candidate(_PROJECT_ROOT / relative_path) + _append_unique_path(auto_candidates, _PROJECT_ROOT / relative_path) + _append_unique_path(auto_candidates, _PROJECT_ROOT) + + path_hints = [ + self.config.get("model_path"), + self.config.get("config_json"), + self.input_info.image_path if getattr(self, "input_info", None) is not None else None, + ] + for path_hint in path_hints: + if not path_hint: + continue + for candidate in _expand_path_candidates(path_hint): + looks_like_file = candidate.is_file() or candidate.suffix.lower() in { + ".json", + ".jpg", + ".jpeg", + ".png", + ".webp", + ".bmp", + ".gif", + ".npy", + } + current = candidate.parent if looks_like_file else candidate + for ancestor in (current, *current.parents): + _append_unique_path(auto_candidates, ancestor) + _append_unique_path(auto_candidates, ancestor / "Matrix-Game-3") + + for candidate in auto_candidates: + resolved = self._resolve_official_root_candidate(candidate) if resolved is not None: return resolved From ebebd0ae94e25da95d066c38c876d0742a661292 Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 16:09:54 +0800 Subject: [PATCH 10/25] Fix the sigma_theta --- .../wan/infer/matrix_game3/pre_infer.py | 12 +++++++---- .../infer/matrix_game3/transformer_infer.py | 21 ++++++++++--------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py index 1243b6fa9..0283a3e4d 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py @@ -13,6 +13,7 @@ class WanMtxg3PreInferOutput: __slots__ = [ "x", "embed", "embed0", "grid_sizes", "cos_sin", "context", + "freqs", "plucker_emb", "mouse_cond", "keyboard_cond", "mouse_cond_memory", 
"keyboard_cond_memory", "memory_length", "memory_latent_idx", "predict_latent_idx", @@ -126,10 +127,12 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w), ) - if self.cos_sin is None or self.grid_sizes != grid_sizes.tuple: - freqs = self.freqs.clone() - self.grid_sizes = grid_sizes.tuple - self.cos_sin = self.prepare_cos_sin(grid_sizes.tuple, freqs) + # MG3 can use head-specific 3D RoPE frequencies when `sigma_theta > 0`. + # The shared LightX2V `prepare_cos_sin()` only handles the standard 2D + # RoPE table, so MG3 keeps passing raw `freqs` downstream and lets the + # MG3 transformer apply indexed RoPE itself. + self.grid_sizes = grid_sizes.tuple + self.cos_sin = None # Extract conditioning signals from the runner's inputs mg3_cond = inputs.get("mg3_conditions", {}) @@ -159,6 +162,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): embed0=embed0.squeeze(0), context=context, cos_sin=self.cos_sin, + freqs=self.freqs, plucker_emb=plucker_emb, mouse_cond=mouse_cond, keyboard_cond=keyboard_cond, diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py index a38a2632d..3783fefe4 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py @@ -110,6 +110,7 @@ def __init__(self, config): @torch.no_grad() def infer(self, weights, pre_infer_out): self.cos_sin = pre_infer_out.cos_sin + self.freqs = pre_infer_out.freqs self.reset_infer_states() x = self.infer_main_blocks(weights.blocks, pre_infer_out) return self.infer_non_blocks(weights, x, pre_infer_out.embed) @@ -266,22 +267,22 @@ def _infer_self_attn_mg3(self, phase, x, shift_msa, scale_msa, pre_infer_out): q = torch.cat([q_memory.squeeze(0), q_pred.squeeze(0)], dim=0) k = torch.cat([k_memory.squeeze(0), k_pred.squeeze(0)], dim=0) else: - # No memory — standard RoPE or 
indexed RoPE + # No memory — MG3 official behavior still uses indexed RoPE. + q_unsq = q.unsqueeze(0) + k_unsq = k.unsqueeze(0) + grid_sizes_t = torch.tensor( + [[grid_sizes.tuple[0], grid_sizes.tuple[1], grid_sizes.tuple[2]]], + dtype=torch.long, device=q.device, + ) if predict_latent_idx is not None: - q_unsq = q.unsqueeze(0) - k_unsq = k.unsqueeze(0) - grid_sizes_t = torch.tensor( - [[grid_sizes.tuple[0], grid_sizes.tuple[1], grid_sizes.tuple[2]]], - dtype=torch.long, device=q.device, - ) if isinstance(predict_latent_idx, tuple) and len(predict_latent_idx) == 2: pred_indices = list(range(predict_latent_idx[0], predict_latent_idx[1])) else: pred_indices = predict_latent_idx - q = rope_apply_with_indices(q_unsq, grid_sizes_t, self.freqs, pred_indices).squeeze(0) - k = rope_apply_with_indices(k_unsq, grid_sizes_t, self.freqs, pred_indices).squeeze(0) else: - q, k = self.apply_rope_func(q, k, cos_sin) + pred_indices = list(range(grid_sizes.tuple[0])) + q = rope_apply_with_indices(q_unsq, grid_sizes_t, self.freqs, pred_indices).squeeze(0) + k = rope_apply_with_indices(k_unsq, grid_sizes_t, self.freqs, pred_indices).squeeze(0) img_qkv_len = q.shape[0] if self.self_attn_cu_seqlens_qkv is None: From 8812beeae34b80c73c0bf711169fa12145feaa7a Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 16:17:53 +0800 Subject: [PATCH 11/25] replace the order --- .../networks/wan/infer/matrix_game3/transformer_infer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py index 3783fefe4..37c63d82a 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py @@ -77,9 +77,9 @@ def rope_apply_with_indices(x, grid_sizes, freqs, indices): freq_t = freq_parts[0][:, indices, :] # [n, f, c_t] cos_sin = torch.cat( [ - freq_t.permute(1, 
0, 2).unsqueeze(2).unsqueeze(3).expand(-1, -1, h, w, -1), - freq_parts[1][:, :h, :].permute(1, 0, 2).unsqueeze(0).unsqueeze(3).expand(f, -1, -1, w, -1), - freq_parts[2][:, :w, :].permute(1, 0, 2).unsqueeze(0).unsqueeze(2).expand(f, -1, h, -1, -1), + freq_t.permute(1, 0, 2).view(f, 1, 1, n, -1).expand(f, h, w, n, -1), + freq_parts[1][:, :h, :].permute(1, 0, 2).view(1, h, 1, n, -1).expand(f, h, w, n, -1), + freq_parts[2][:, :w, :].permute(1, 0, 2).view(1, 1, w, n, -1).expand(f, h, w, n, -1), ], dim=-1, ).reshape(f * h * w, n, -1) From 1af1d862fd568b671840e8f606422f55736109ba Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 17:09:40 +0800 Subject: [PATCH 12/25] resolve the vague problem --- .../wan/infer/matrix_game3/post_infer.py | 25 ++ .../wan/infer/matrix_game3/pre_infer.py | 45 ++- .../infer/matrix_game3/transformer_infer.py | 304 +++++++++++++----- .../models/networks/wan/matrix_game3_model.py | 4 +- 4 files changed, 290 insertions(+), 88 deletions(-) create mode 100644 lightx2v/models/networks/wan/infer/matrix_game3/post_infer.py diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/post_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/post_infer.py new file mode 100644 index 000000000..f643f452b --- /dev/null +++ b/lightx2v/models/networks/wan/infer/matrix_game3/post_infer.py @@ -0,0 +1,25 @@ +import torch + +from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer + + +class WanMtxg3PostInfer(WanPostInfer): + """Matrix-Game-3 post-processing. + + The official MG3 model prepends memory latents before patch embedding, then + drops those memory frames from the final model output. Keep that behavior + local to the MG3 adapter instead of changing the shared Wan post-infer path. 
+ """ + + @torch.no_grad() + def infer(self, x, pre_infer_out): + x = self.unpatchify(x, pre_infer_out.grid_sizes.tuple) + + memory_length = getattr(pre_infer_out, "memory_length", 0) + if memory_length > 0: + x = [u[:, memory_length:] for u in x] + + if self.clean_cuda_cache: + torch.cuda.empty_cache() + + return [u.float() for u in x] diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py index 0283a3e4d..4b88820dd 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py @@ -93,6 +93,40 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): else: context = inputs["text_encoder_output"]["context_null"] + # Matrix-Game-3 conditions are staged in the standard LightX2V + # `image_encoder_output["dit_cond_dict"]` container by the runner. + image_encoder_output = inputs.get("image_encoder_output", {}) + dit_cond_dict = image_encoder_output.get("dit_cond_dict") or {} + + if self.scheduler.infer_condition: + plucker_emb = dit_cond_dict.get("plucker_emb_with_memory", dit_cond_dict.get("c2ws_plucker_emb", None)) + mouse_cond = dit_cond_dict.get("mouse_cond", None) + keyboard_cond = dit_cond_dict.get("keyboard_cond", None) + x_memory = dit_cond_dict.get("x_memory", None) + timestep_memory = dit_cond_dict.get("timestep_memory", None) + mouse_cond_memory = dit_cond_dict.get("mouse_cond_memory", None) + keyboard_cond_memory = dit_cond_dict.get("keyboard_cond_memory", None) + memory_latent_idx = dit_cond_dict.get("memory_latent_idx", None) + else: + plucker_emb = dit_cond_dict.get("c2ws_plucker_emb", None) + mouse_source = dit_cond_dict.get("mouse_cond", None) + keyboard_source = dit_cond_dict.get("keyboard_cond", None) + mouse_cond = torch.ones_like(mouse_source) if mouse_source is not None else None + keyboard_cond = -torch.ones_like(keyboard_source) if keyboard_source is not None else None + x_memory = None 
+ timestep_memory = None + mouse_cond_memory = None + keyboard_cond_memory = None + memory_latent_idx = None + predict_latent_idx = dit_cond_dict.get("predict_latent_idx", None) + + memory_length = 0 + if x_memory is not None: + memory_length = int(x_memory.shape[2]) + x = torch.cat([x_memory.squeeze(0).to(device=x.device, dtype=x.dtype), x], dim=1) + if timestep_memory is not None: + t = torch.cat([timestep_memory.squeeze(0).to(device=t.device, dtype=t.dtype), t], dim=0) + # Patch embedding x = weights.patch_embedding.apply(x.unsqueeze(0)) grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:] @@ -134,17 +168,6 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): self.grid_sizes = grid_sizes.tuple self.cos_sin = None - # Extract conditioning signals from the runner's inputs - mg3_cond = inputs.get("mg3_conditions", {}) - plucker_emb = mg3_cond.get("plucker_emb", None) - mouse_cond = mg3_cond.get("mouse_cond", None) - keyboard_cond = mg3_cond.get("keyboard_cond", None) - mouse_cond_memory = mg3_cond.get("mouse_cond_memory", None) - keyboard_cond_memory = mg3_cond.get("keyboard_cond_memory", None) - memory_length = mg3_cond.get("memory_length", 0) - memory_latent_idx = mg3_cond.get("memory_latent_idx", None) - predict_latent_idx = mg3_cond.get("predict_latent_idx", None) - # Process plucker embedding through the global camera layers if plucker_emb is not None: plucker_emb = weights.patch_embedding_wancamctrl.apply(plucker_emb.squeeze(0)) diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py index 37c63d82a..b4681975a 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py @@ -26,6 +26,7 @@ FLASH_ATTN_3_AVAILABLE = False from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer +from 
lightx2v.models.networks.wan.infer.matrix_game2.posemb_layers import apply_rotary_emb, get_nd_rotary_pos_embed from lightx2v.utils.envs import * from lightx2v.utils.registry_factory import * from lightx2v_platform.base.global_var import AI_DEVICE @@ -106,6 +107,62 @@ def __init__(self, config): super().__init__(config) self.action_config = config.get("action_config", {}) self.action_blocks = set(self.action_config.get("blocks", [])) + self.vae_time_compression_ratio = int(self.action_config.get("vae_time_compression_ratio", 4)) + self.windows_size = int(self.action_config.get("windows_size", 3)) + self.action_patch_size = list(self.action_config.get("patch_size", [1, 2, 2])) + self.action_rope_theta = float(self.action_config.get("rope_theta", 256)) + self.enable_mouse = bool(self.action_config.get("enable_mouse", True)) + self.enable_keyboard = bool(self.action_config.get("enable_keyboard", True)) + self.action_heads_num = int(self.action_config.get("heads_num", 16)) + self.mouse_hidden_dim = int(self.action_config.get("mouse_hidden_dim", 1024)) + self.keyboard_hidden_dim = int(self.action_config.get("keyboard_hidden_dim", 1024)) + self.mouse_qk_dim_list = list(self.action_config.get("mouse_qk_dim_list", [8, 28, 28])) + self.rope_dim_list = list(self.action_config.get("rope_dim_list", [8, 28, 28])) + + def _get_action_rotary_pos_embed(self, video_length, head_dim, rope_dim_list=None): + target_ndim = 3 + latents_size = [video_length, self.action_patch_size[1], self.action_patch_size[2]] + + if isinstance(self.action_patch_size, int): + rope_sizes = [s // self.action_patch_size for s in latents_size] + patch_t = self.action_patch_size + else: + rope_sizes = [s // self.action_patch_size[idx] for idx, s in enumerate(latents_size)] + patch_t = self.action_patch_size[0] + + if len(rope_sizes) != target_ndim: + rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes + + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in 
range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the action attention head dim" + + freqs_cos, freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list, + rope_sizes, + theta=self.action_rope_theta, + use_real=True, + theta_rescale_factor=1, + ) + usable = video_length * rope_sizes[1] * rope_sizes[2] // patch_t + return freqs_cos[-usable:], freqs_sin[-usable:] + + def _run_flash_attention(self, q, k, v, causal=False): + if FLASH_ATTN_3_AVAILABLE: + try: + return flash_attn_interface.flash_attn_func(q, k, v, causal=causal) + except TypeError: + return flash_attn_interface.flash_attn_func(q, k, v) + if "flash_attn_func" in globals(): + try: + return flash_attn_func(q, k, v, causal=causal) + except TypeError: + return flash_attn_func(q, k, v) + + q_pt = q.transpose(1, 2) + k_pt = k.transpose(1, 2) + v_pt = v.transpose(1, 2) + return torch.nn.functional.scaled_dot_product_attention(q_pt, k_pt, v_pt, is_causal=causal).transpose(1, 2).contiguous() @torch.no_grad() def infer(self, weights, pre_infer_out): @@ -320,93 +377,190 @@ def _infer_cam_injection(self, cam_phase, x, plucker_emb): return x def _infer_action_module(self, phase, x, pre_infer_out): - """ActionModule forward: keyboard + mouse conditioning via cross-attention. - - This implements the official MG3 ActionModule logic in the LightX2V - weight/infer separation style. The module: - 1. Processes mouse condition through mouse_mlp - 2. Applies temporal self-attention with QKV (t_qkv) - 3. Projects back via proj_mouse - 4. Processes keyboard condition through keyboard_embed - 5. Applies keyboard cross-attention - 6. 
Projects back via proj_keyboard - """ - grid_sizes = pre_infer_out.grid_sizes - f, h, w = grid_sizes.tuple - S = h * w + """ActionModule forward aligned with the official MG3 implementation.""" + tt, th, tw = pre_infer_out.grid_sizes.tuple + spatial_tokens = th * tw + pad_t = self.vae_time_compression_ratio * self.windows_size mouse_cond = pre_infer_out.mouse_cond keyboard_cond = pre_infer_out.keyboard_cond + mouse_cond_memory = pre_infer_out.mouse_cond_memory + keyboard_cond_memory = pre_infer_out.keyboard_cond_memory - x_in = x.unsqueeze(0) # [1, FHW, C] - - # --- Mouse conditioning --- - if mouse_cond is not None: - hidden_states = rearrange(x_in, "B (T S) C -> (B S) T C", T=f, S=S) - - # Mouse MLP - mouse_input = torch.cat([hidden_states, mouse_cond.expand(S, -1, -1) if mouse_cond.shape[0] == 1 else mouse_cond], dim=-1) - mouse_out = phase.mouse_mlp_0.apply(mouse_input.reshape(-1, mouse_input.shape[-1])) - mouse_out = torch.nn.functional.gelu(mouse_out, approximate="tanh") - mouse_out = phase.mouse_mlp_2.apply(mouse_out) - mouse_out = phase.mouse_mlp_3.apply(mouse_out) - mouse_out = mouse_out.reshape(S, f, -1) + x_in = x.unsqueeze(0) # [1, T*S, C] + hidden_states = x_in + memory_length = 0 - # Mouse temporal self-attention with QKV - mouse_qkv = phase.t_qkv.apply(mouse_out.reshape(-1, mouse_out.shape[-1])) - mouse_qkv = mouse_qkv.reshape(S, f, 3, self.num_heads, self.head_dim) + if self.enable_mouse and mouse_cond is not None: + batch_size, num_frames, mouse_dim = mouse_cond.shape + assert (((num_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0) or ( + num_frames % self.vae_time_compression_ratio == 0 + ) + if ((num_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0: + num_feats = int((num_frames - 1) / self.vae_time_compression_ratio) + 1 + mouse_cond = torch.cat([mouse_cond[:, 0:1, :].repeat(1, pad_t, 1), mouse_cond], dim=1) + else: + num_feats = num_frames // 
self.vae_time_compression_ratio + mouse_cond = torch.cat( + [mouse_cond[:, 0:1, :].repeat(1, pad_t - self.vae_time_compression_ratio, 1), mouse_cond], + dim=1, + ) + + mouse_groups = [ + mouse_cond[ + :, + self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, + :, + ] + for i in range(num_feats) + ] + mouse_groups = torch.stack(mouse_groups, dim=1) + if mouse_cond_memory is not None: + memory_length = mouse_cond_memory.shape[1] + mouse_memory = mouse_cond_memory.unsqueeze(2).repeat(1, 1, pad_t, 1) + mouse_groups = torch.cat([mouse_memory, mouse_groups], dim=1) + + hidden_states_mouse = rearrange(x_in, "B (T S) C -> (B S) T C", T=tt, S=spatial_tokens) + mouse_groups = mouse_groups.unsqueeze(-1).repeat(1, 1, 1, 1, spatial_tokens) + mouse_groups = rearrange(mouse_groups, "b t window d s -> (b s) t (window d)") + if mouse_groups.shape[1] != tt: + raise ValueError( + f"matrix-game-3 mouse condition window mismatch: expected latent T={tt}, got {mouse_groups.shape[1]}" + ) + + mouse_input = torch.cat([hidden_states_mouse, mouse_groups], dim=-1) + mouse_hidden = phase.mouse_mlp_0.apply(mouse_input.reshape(-1, mouse_input.shape[-1])) + mouse_hidden = torch.nn.functional.gelu(mouse_hidden, approximate="tanh") + mouse_hidden = phase.mouse_mlp_2.apply(mouse_hidden) + mouse_hidden = phase.mouse_mlp_3.apply(mouse_hidden) + mouse_hidden = mouse_hidden.reshape(batch_size * spatial_tokens, tt, -1) + + mouse_head_dim = self.mouse_hidden_dim // self.action_heads_num + mouse_qkv = phase.t_qkv.apply(mouse_hidden.reshape(-1, mouse_hidden.shape[-1])) + mouse_qkv = mouse_qkv.reshape(batch_size * spatial_tokens, tt, 3, self.action_heads_num, mouse_head_dim) q_m, k_m, v_m = mouse_qkv.permute(2, 0, 1, 3, 4).unbind(0) - # QK norm (RMSNorm) - q_m = phase.img_attn_q_norm.apply(q_m.reshape(-1, self.head_dim)).reshape(S, f, self.num_heads, self.head_dim) - k_m = phase.img_attn_k_norm.apply(k_m.reshape(-1, self.head_dim)).reshape(S, f, 
self.num_heads, self.head_dim) + q_m = phase.img_attn_q_norm.apply(q_m.reshape(-1, mouse_head_dim)).reshape( + batch_size * spatial_tokens, tt, self.action_heads_num, mouse_head_dim + ) + k_m = phase.img_attn_k_norm.apply(k_m.reshape(-1, mouse_head_dim)).reshape( + batch_size * spatial_tokens, tt, self.action_heads_num, mouse_head_dim + ) - # Flash attention - if FLASH_ATTN_3_AVAILABLE: - mouse_attn = flash_attn_interface.flash_attn_func(q_m, k_m, v_m) + if memory_length > 0: + freqs_memory = self._get_action_rotary_pos_embed(memory_length, mouse_head_dim, self.mouse_qk_dim_list) + q_mem, k_mem = apply_rotary_emb(q_m[:, :memory_length], k_m[:, :memory_length], freqs_memory, head_first=False) + q_m[:, :memory_length] = q_mem + k_m[:, :memory_length] = k_mem + + pred_length = tt - memory_length + if pred_length > 0: + freqs_pred = self._get_action_rotary_pos_embed(pred_length, mouse_head_dim, self.mouse_qk_dim_list) + q_pred, k_pred = apply_rotary_emb(q_m[:, memory_length:], k_m[:, memory_length:], freqs_pred, head_first=False) + q_m[:, memory_length:] = q_pred + k_m[:, memory_length:] = k_pred else: - mouse_attn = flash_attn_func(q_m, k_m, v_m) - - mouse_attn = rearrange(mouse_attn, "(B S) T h d -> B (T S) (h d)", B=1, S=S) - mouse_proj = phase.proj_mouse.apply(mouse_attn.squeeze(0)).unsqueeze(0) - x_in = x_in + mouse_proj - - # --- Keyboard conditioning --- - if keyboard_cond is not None: - # Keyboard embed - kb_emb = phase.keyboard_embed_0.apply(keyboard_cond.reshape(-1, keyboard_cond.shape[-1])) - kb_emb = torch.nn.functional.silu(kb_emb) - kb_emb = phase.keyboard_embed_2.apply(kb_emb) - kb_emb = kb_emb.reshape(keyboard_cond.shape[0], keyboard_cond.shape[1], -1) - - # Keyboard cross-attention: query from hidden states, key/value from keyboard - mouse_q = phase.mouse_attn_q.apply(x_in.squeeze(0)).unsqueeze(0) - keyboard_kv = phase.keyboard_attn_kv.apply(kb_emb.reshape(-1, kb_emb.shape[-1])) - keyboard_kv = keyboard_kv.reshape(1, -1, keyboard_kv.shape[-1]) - - HD = 
mouse_q.shape[-1] - D = HD // self.num_heads - q_k = mouse_q.view(1, -1, self.num_heads, D) - kv_split = keyboard_kv.view(1, -1, 2, self.num_heads, D) - k_k, v_k = kv_split.permute(2, 0, 1, 3, 4).unbind(0) - - # QK norm - q_k_flat = q_k.reshape(-1, D) - k_k_flat = k_k.reshape(-1, D) - q_k = phase.key_attn_q_norm.apply(q_k_flat).reshape(1, -1, self.num_heads, D) - k_k = phase.key_attn_k_norm.apply(k_k_flat).reshape(1, -1, self.num_heads, D) - - # Flash attention - if FLASH_ATTN_3_AVAILABLE: - kb_attn = flash_attn_interface.flash_attn_func(q_k, k_k, v_k) + freqs = self._get_action_rotary_pos_embed(tt, mouse_head_dim, self.mouse_qk_dim_list) + q_m, k_m = apply_rotary_emb(q_m, k_m, freqs, head_first=False) + + mouse_attn = self._run_flash_attention(q_m, k_m, v_m, causal=False) + mouse_attn = rearrange(mouse_attn, "(b s) t h d -> b (t s) (h d)", b=batch_size, s=spatial_tokens) + mouse_proj = phase.proj_mouse.apply(mouse_attn.reshape(-1, mouse_attn.shape[-1])).reshape( + batch_size, tt * spatial_tokens, -1 + ) + hidden_states = x_in + mouse_proj + else: + hidden_states = x_in + + if self.enable_keyboard and keyboard_cond is not None: + batch_size, num_frames, _ = keyboard_cond.shape + assert (((num_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0) or ( + num_frames % self.vae_time_compression_ratio == 0 + ) + if ((num_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0: + num_feats = int((num_frames - 1) / self.vae_time_compression_ratio) + 1 + keyboard_cond = torch.cat([keyboard_cond[:, 0:1, :].repeat(1, pad_t, 1), keyboard_cond], dim=1) else: - kb_attn = flash_attn_func(q_k, k_k, v_k) + num_feats = num_frames // self.vae_time_compression_ratio + keyboard_cond = torch.cat( + [keyboard_cond[:, 0:1, :].repeat(1, pad_t - self.vae_time_compression_ratio, 1), keyboard_cond], + dim=1, + ) + + keyboard_hidden = phase.keyboard_embed_0.apply(keyboard_cond.reshape(-1, keyboard_cond.shape[-1])) + 
keyboard_hidden = torch.nn.functional.silu(keyboard_hidden) + keyboard_hidden = phase.keyboard_embed_2.apply(keyboard_hidden) + keyboard_hidden = keyboard_hidden.reshape(batch_size, keyboard_cond.shape[1], -1) + + keyboard_groups = [ + keyboard_hidden[ + :, + self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, + :, + ] + for i in range(num_feats) + ] + keyboard_groups = torch.stack(keyboard_groups, dim=1) + if keyboard_cond_memory is not None: + memory_length = keyboard_cond_memory.shape[1] + keyboard_memory = phase.keyboard_embed_0.apply(keyboard_cond_memory.reshape(-1, keyboard_cond_memory.shape[-1])) + keyboard_memory = torch.nn.functional.silu(keyboard_memory) + keyboard_memory = phase.keyboard_embed_2.apply(keyboard_memory) + keyboard_memory = keyboard_memory.reshape(batch_size, memory_length, -1) + keyboard_memory = keyboard_memory.unsqueeze(2).repeat(1, 1, pad_t, 1) + keyboard_groups = torch.cat([keyboard_memory, keyboard_groups], dim=1) + + if keyboard_groups.shape[1] != tt: + raise ValueError( + f"matrix-game-3 keyboard condition window mismatch: expected latent T={tt}, got {keyboard_groups.shape[1]}" + ) + + keyboard_groups = keyboard_groups.reshape(batch_size, keyboard_groups.shape[1], -1) + mouse_q = phase.mouse_attn_q.apply(hidden_states.reshape(-1, hidden_states.shape[-1])).reshape( + batch_size, tt * spatial_tokens, -1 + ) + keyboard_kv = phase.keyboard_attn_kv.apply(keyboard_groups.reshape(-1, keyboard_groups.shape[-1])) + keyboard_kv = keyboard_kv.reshape(batch_size, keyboard_groups.shape[1], -1) + + keyboard_head_dim = self.keyboard_hidden_dim // self.action_heads_num + q_k = mouse_q.view(batch_size, -1, self.action_heads_num, keyboard_head_dim) + kv = keyboard_kv.view(batch_size, -1, 2, self.action_heads_num, keyboard_head_dim) + k_k, v_k = kv.permute(2, 0, 1, 3, 4).unbind(0) - kb_attn = rearrange(kb_attn, "B L H D -> B L (H D)") - kb_proj = 
phase.proj_keyboard.apply(kb_attn.squeeze(0)).unsqueeze(0) - x_in = x_in + kb_proj + q_k = phase.key_attn_q_norm.apply(q_k.reshape(-1, keyboard_head_dim)).reshape( + batch_size, -1, self.action_heads_num, keyboard_head_dim + ) + k_k = phase.key_attn_k_norm.apply(k_k.reshape(-1, keyboard_head_dim)).reshape( + batch_size, -1, self.action_heads_num, keyboard_head_dim + ) + + q_k = rearrange(q_k, "b (t s) h d -> (b s) t h d", s=spatial_tokens) + if memory_length > 0: + freqs_memory = self._get_action_rotary_pos_embed(memory_length, keyboard_head_dim, self.mouse_qk_dim_list) + q_mem, k_mem = apply_rotary_emb(q_k[:, :memory_length], k_k[:, :memory_length], freqs_memory, head_first=False) + q_k[:, :memory_length] = q_mem + k_k[:, :memory_length] = k_mem + + pred_length = tt - memory_length + if pred_length > 0: + freqs_pred = self._get_action_rotary_pos_embed(pred_length, keyboard_head_dim, self.mouse_qk_dim_list) + q_pred, k_pred = apply_rotary_emb(q_k[:, memory_length:], k_k[:, memory_length:], freqs_pred, head_first=False) + q_k[:, memory_length:] = q_pred + k_k[:, memory_length:] = k_pred + else: + freqs = self._get_action_rotary_pos_embed(tt, keyboard_head_dim, self.rope_dim_list) + q_k, k_k = apply_rotary_emb(q_k, k_k, freqs, head_first=False) + + k_k = k_k.repeat(spatial_tokens, 1, 1, 1) + v_k = v_k.repeat(spatial_tokens, 1, 1, 1) + kb_attn = self._run_flash_attention(q_k, k_k, v_k, causal=False) + kb_attn = rearrange(kb_attn, "(b s) t h d -> b (t s) (h d)", b=batch_size, s=spatial_tokens) + kb_proj = phase.proj_keyboard.apply(kb_attn.reshape(-1, kb_attn.shape[-1])).reshape( + batch_size, tt * spatial_tokens, -1 + ) + hidden_states = hidden_states + kb_proj - return x_in.squeeze(0) + return hidden_states.squeeze(0) @property def freqs(self): diff --git a/lightx2v/models/networks/wan/matrix_game3_model.py b/lightx2v/models/networks/wan/matrix_game3_model.py index 17555c945..c3c754eb6 100644 --- a/lightx2v/models/networks/wan/matrix_game3_model.py +++ 
b/lightx2v/models/networks/wan/matrix_game3_model.py @@ -5,8 +5,8 @@ from safetensors import safe_open from lightx2v.models.networks.wan.infer.matrix_game3.pre_infer import WanMtxg3PreInfer +from lightx2v.models.networks.wan.infer.matrix_game3.post_infer import WanMtxg3PostInfer from lightx2v.models.networks.wan.infer.matrix_game3.transformer_infer import WanMtxg3TransformerInfer -from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer from lightx2v.models.networks.wan.model import WanModel from lightx2v.models.networks.wan.weights.matrix_game3.pre_weights import WanMtxg3PreWeights from lightx2v.models.networks.wan.weights.matrix_game3.transformer_weights import WanMtxg3TransformerWeights @@ -45,7 +45,7 @@ def _init_infer_class(self): self.config[k] = model_config[k] self.pre_infer_class = WanMtxg3PreInfer - self.post_infer_class = WanPostInfer + self.post_infer_class = WanMtxg3PostInfer self.transformer_infer_class = WanMtxg3TransformerInfer def _load_ckpt(self, unified_dtype, sensitive_layer): From 8ad18e8dc58a9e02b2a141b445ef57cb2c18cebe Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 18:45:22 +0800 Subject: [PATCH 13/25] Add the move --- .../models/runners/wan/wan_matrix_game3_runner.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index 79e3050d8..a35e2fbf6 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -506,6 +506,18 @@ def _prepare_matrix_game3_session(self, pil_image: Image.Image, latent_shape: li raw_controls = self._load_control_payload(action_path) raw_total_frames = self._infer_raw_total_frames(raw_controls) self._mg3_num_iterations, self._mg3_expected_total_frames = self._get_expected_total_frames(raw_total_frames) + + # Match the official Matrix-Game-3 demo pipeline: when the user does not + # provide an external 
action file, fall back to the benchmark universal + # action sequence instead of a fully static zero-control clip. + if not raw_controls: + modules = self._get_official_modules() + logger.warning( + "[matrix-game-3] action_path missing or empty; falling back to official Bench_actions_universal({}).", + self._mg3_expected_total_frames, + ) + raw_controls = self._normalize_payload_keys(modules["conditions"].Bench_actions_universal(self._mg3_expected_total_frames)) + self._mg3_keyboard_all, self._mg3_mouse_all, self._mg3_extrinsics_all, self._mg3_intrinsics_all = self._build_noninteractive_controls(raw_controls) def _infer_raw_total_frames(self, payload: dict[str, Any]) -> Optional[int]: From a3b9c29ae6a93a2efe4e0b7c9eecd81a17078bb5 Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 19:33:18 +0800 Subject: [PATCH 14/25] Fix the biaes --- .../networks/wan/infer/matrix_game3/pre_infer.py | 14 +++++++++++++- .../wan/infer/matrix_game3/transformer_infer.py | 12 +++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py index 4b88820dd..877fcd9af 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py @@ -87,6 +87,18 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): x = self.scheduler.latents t = self.scheduler.timestep_input + # Official MG3 feeds a per-token timestep map where the fixed conditioning + # latent slots are forced to zero. LightX2V's generic Wan scheduler only + # builds that map for plain `wan2.2`, so the MG3 adapter reconstructs it + # here when the scheduler exposes a scalar timestep. 
+ if t.numel() == 1: + mask = getattr(self.scheduler, "mask", None) + if mask is not None: + timestep_scalar = t.reshape(1).to(device=x.device, dtype=x.dtype) + t = (mask[0][:, ::2, ::2].to(device=x.device, dtype=x.dtype) * timestep_scalar).flatten() + else: + t = t.reshape(-1).to(device=x.device, dtype=x.dtype) + # Text context (MG3 uses text conditioning only, no CLIP image encoder) if self.scheduler.infer_condition: context = inputs["text_encoder_output"]["context"] @@ -125,7 +137,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): memory_length = int(x_memory.shape[2]) x = torch.cat([x_memory.squeeze(0).to(device=x.device, dtype=x.dtype), x], dim=1) if timestep_memory is not None: - t = torch.cat([timestep_memory.squeeze(0).to(device=t.device, dtype=t.dtype), t], dim=0) + t = torch.cat([timestep_memory.squeeze(0).to(device=x.device, dtype=x.dtype), t.to(device=x.device, dtype=x.dtype)], dim=0) # Patch embedding x = weights.patch_embedding.apply(x.unsqueeze(0)) diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py index b4681975a..18621cb1c 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py @@ -105,6 +105,9 @@ class WanMtxg3TransformerInfer(WanTransformerInfer): def __init__(self, config): super().__init__(config) + # Official Matrix-Game-3 blocks are always instantiated with + # `use_memory=True`, which slightly changes the cross-attention residual path. 
+ self.use_memory = True self.action_config = config.get("action_config", {}) self.action_blocks = set(self.action_config.get("blocks", [])) self.vae_time_compression_ratio = int(self.action_config.get("vae_time_compression_ratio", 4)) @@ -218,7 +221,14 @@ def infer_block(self, block, x, pre_infer_out): # --- Phase 2: Cross-Attention --- cross_phase = block.compute_phases[2] - norm3_out = cross_phase.norm3.apply(x) + # Match the official MG3 block semantics: + # when `use_memory=True`, norm3 is applied in-place on the residual stream + # before cross-attention, so the action/ffn branches see the normalized x. + if pre_infer_out.mouse_cond is not None or self.use_memory: + x = cross_phase.norm3.apply(x) + norm3_out = x + else: + norm3_out = cross_phase.norm3.apply(x) n, d = self.num_heads, self.head_dim q = cross_phase.cross_attn_norm_q.apply(cross_phase.cross_attn_q.apply(norm3_out)).view(-1, n, d) k = cross_phase.cross_attn_norm_k.apply(cross_phase.cross_attn_k.apply(pre_infer_out.context)).view(-1, n, d) From 079d188b5e5bd387c6e7386a04bc557bfb45f04c Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 20:42:54 +0800 Subject: [PATCH 15/25] fix the plucker --- .../wan/infer/matrix_game3/pre_infer.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py index 877fcd9af..3b907030a 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +from einops import rearrange from lightx2v.models.networks.wan.infer.module_io import GridOutput from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer @@ -182,6 +183,36 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): # Process plucker embedding through the global camera layers if plucker_emb is not None: + # Match the 
official MG3 implementation: plucker embeddings arrive as + # [B, C, F, H, W] (or an equivalent list form), must be patchified into + # [B, L, C'] tokens, and only then can they pass through the global + # camera-control linear projection. + if torch.is_tensor(plucker_emb): + plucker_items = [u.unsqueeze(0) for u in plucker_emb] + else: + plucker_items = [u.unsqueeze(0) if u.dim() == 4 else u for u in plucker_emb] + + patch_t, patch_h, patch_w = self.config.get("patch_size", (1, 2, 2)) + plucker_emb = [ + rearrange( + item, + "1 c (f c1) (h c2) (w c3) -> 1 (f h w) (c c1 c2 c3)", + c1=patch_t, + c2=patch_h, + c3=patch_w, + ) + for item in plucker_items + ] + plucker_emb = torch.cat(plucker_emb, dim=1) + if plucker_emb.size(1) < x.size(1): + plucker_emb = torch.cat( + [ + plucker_emb, + plucker_emb.new_zeros(plucker_emb.size(0), x.size(1) - plucker_emb.size(1), plucker_emb.size(2)), + ], + dim=1, + ) + plucker_emb = weights.patch_embedding_wancamctrl.apply(plucker_emb.squeeze(0)) plucker_hidden = weights.c2ws_hidden_states_layer2.apply( torch.nn.functional.silu( From bf928e82cbe2457ec520d10dfb4bbdd754b0a6ba Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 6 Apr 2026 21:18:40 +0800 Subject: [PATCH 16/25] Add the negative prompt --- .../runners/wan/wan_matrix_game3_runner.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index a35e2fbf6..eede461a1 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -32,6 +32,13 @@ ) _MATRIX_GAME3_CONFIG_ROOT_RELATIVE = Path("Matrix-Game-3.0") _MATRIX_GAME3_OFFICIAL_PACKAGE = "_lightx2v_matrix_game3_official" +_MATRIX_GAME3_DEFAULT_NEGATIVE_PROMPT = ( + "Vibrant colors, overexposure, static, blurred details, subtitles, style, artwork, " + "painting, still image, overall grayness, worst quality, low quality, JPEG compression " + 
"residue, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn faces, " + "deformed, disfigured, malformed limbs, fused fingers, still image, cluttered background, " + "three legs, crowded background, walking backwards" +) @dataclass @@ -189,6 +196,15 @@ def set_inputs(self, inputs): if "pose" in self.input_info.__dataclass_fields__: self.input_info.pose = inputs.get("pose", inputs.get("action_path", "")) + def run_text_encoder(self, input_info): + # Official Matrix-Game-3 base inference uses a non-empty default negative + # prompt for CFG. If the caller leaves `--negative_prompt` empty, reuse the + # official default so the unconditional branch matches the reference path. + if self.config.get("enable_cfg", False) and not getattr(input_info, "negative_prompt", ""): + input_info.negative_prompt = self.config.get("sample_neg_prompt", _MATRIX_GAME3_DEFAULT_NEGATIVE_PROMPT) + logger.info("[matrix-game-3] negative_prompt not provided; falling back to the official sample_neg_prompt for CFG.") + return super().run_text_encoder(input_info) + def load_transformer(self): from lightx2v.models.networks.wan.matrix_game3_model import WanMtxg3Model @@ -322,6 +338,7 @@ def _load_matrix_game3_model_config(self): self.config["num_channels_latents"] = int(model_config.get("in_dim", self.config.get("num_channels_latents", 48))) self.config["vae_stride"] = tuple(self.config.get("vae_stride", (4, 16, 16))) self.config["patch_size"] = tuple(model_config.get("patch_size", self.config.get("patch_size", (1, 2, 2)))) + self.config["sample_neg_prompt"] = self.config.get("sample_neg_prompt", _MATRIX_GAME3_DEFAULT_NEGATIVE_PROMPT) action_config = self.config.get("action_config", {}) self.keyboard_dim_in = int(self.config.get("keyboard_dim_in", action_config.get("keyboard_dim_in", 6))) From 8c08eb94cddacc5761bdca0c32c04d6c2140be00 Mon Sep 17 00:00:00 2001 From: Yang Date: Tue, 7 Apr 2026 10:02:02 +0800 Subject: [PATCH 17/25] Add the official package --- 
.../runners/wan/wan_matrix_game3_runner.py | 95 ++++++++++++++++++- 1 file changed, 93 insertions(+), 2 deletions(-) diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index eede461a1..1df30aa2e 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -15,6 +15,7 @@ from loguru import logger from lightx2v.models.runners.wan.wan_runner import Wan22DenseRunner, build_wan_model_with_lora +from lightx2v.models.schedulers.scheduler import BaseScheduler from lightx2v.server.metrics import monitor_cli from lightx2v.utils.envs import GET_DTYPE from lightx2v.utils.profiler import GET_RECORDER_MODE, ProfilingContext4DebugL1, ProfilingContext4DebugL2 @@ -108,6 +109,71 @@ def _append_unique_path(paths: list[Path], candidate: Path): paths.append(candidate) +class MatrixGame3OfficialSchedulerAdapter(BaseScheduler): + """Adapt the official MG3 FlowUniPC scheduler to LightX2V's scheduler interface. + + The distilled path is fairly tolerant of LightX2V's generic Wan scheduler, but the + base model is much more sensitive to scheduler semantics under 50-step CFG. This + adapter keeps the rest of the LightX2V lifecycle untouched while delegating the + actual UniPC stepping logic to the official Matrix-Game-3 implementation. 
+ """ + + def __init__(self, config, scheduler_cls): + super().__init__(config) + self.scheduler_cls = scheduler_cls + self.sample_shift = self.config["sample_shift"] + self.sample_guide_scale = self.config["sample_guide_scale"] + self.noise_pred = None + self.mask = None + self.vae_encoder_out = None + self.timestep_input = None + self._solver = None + self._generator = None + + def _reset_solver(self): + self._solver = self.scheduler_cls() + self._solver.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift) + + def prepare(self, seed, latent_shape, image_encoder_output=None): + self._generator = torch.Generator(device=AI_DEVICE).manual_seed(seed) + self.latents = torch.randn(tuple(latent_shape), dtype=GET_DTYPE(), device=AI_DEVICE, generator=self._generator) + self.vae_encoder_out = image_encoder_output.get("vae_encoder_out") if image_encoder_output is not None else None + if self.vae_encoder_out is not None: + self.vae_encoder_out = self.vae_encoder_out.to(device=AI_DEVICE, dtype=GET_DTYPE()) + self.mask = torch.ones_like(self.latents) + self._reset_solver() + + def reset(self, seed, latent_shape, step_index=None): + self._generator = torch.Generator(device=AI_DEVICE).manual_seed(seed) + self.latents = torch.randn(tuple(latent_shape), dtype=GET_DTYPE(), device=AI_DEVICE, generator=self._generator) + if self.vae_encoder_out is not None: + self.vae_encoder_out = self.vae_encoder_out.to(device=AI_DEVICE, dtype=GET_DTYPE()) + if self.mask is not None: + self.mask = self.mask.to(device=AI_DEVICE, dtype=GET_DTYPE()) + self._reset_solver() + if step_index is not None: + self.step_index = step_index + + def step_pre(self, step_index): + super().step_pre(step_index) + self.timestep_input = torch.stack([self._solver.timesteps[self.step_index].to(device=AI_DEVICE)]) + + def step_post(self): + timestep = self._solver.timesteps[self.step_index].to(device=self.latents.device) + prev_sample = self._solver.step( + 
self.noise_pred.to(dtype=self.latents.dtype), + timestep, + self.latents, + return_dict=False, + )[0] + if self.mask is not None and self.vae_encoder_out is not None: + prev_sample = (1.0 - self.mask) * self.vae_encoder_out + self.mask * prev_sample + self.latents = prev_sample.to(dtype=GET_DTYPE()) + + def clear(self): + self._solver = None + + @RUNNER_REGISTER("wan2.2_matrix_game3") class WanMatrixGame3Runner(Wan22DenseRunner): """Runner-only Matrix-Game-3 adapter on top of the existing Wan2.2 lifecycle. @@ -220,6 +286,31 @@ def load_transformer(self): return WanMtxg3Model(**model_kwargs) return build_wan_model_with_lora(WanMtxg3Model, self.config, model_kwargs, lora_configs, model_type="wan2.2") + def init_scheduler(self): + # Distilled MG3 is already stable on the shared Wan scheduler path. Base MG3 + # is far more sensitive to the exact UniPC implementation, so route it + # through the official FlowUniPC scheduler instead of the generic adapter. + if self.config.get("use_base_model", False): + try: + official_root = self.resolve_official_root() + wan_root = official_root / "wan" + utils_root = wan_root / "utils" + _ensure_namespace_package(f"{_MATRIX_GAME3_OFFICIAL_PACKAGE}.wan", wan_root) + _ensure_namespace_package(f"{_MATRIX_GAME3_OFFICIAL_PACKAGE}.wan.utils", utils_root) + scheduler_module = _load_module_from_path( + f"{_MATRIX_GAME3_OFFICIAL_PACKAGE}.wan.utils.fm_solvers_unipc", + utils_root / "fm_solvers_unipc.py", + ) + self.scheduler = MatrixGame3OfficialSchedulerAdapter(self.config, scheduler_module.FlowUniPCMultistepScheduler) + logger.info("[matrix-game-3] using official FlowUniPCMultistepScheduler for base-model sampling.") + return + except Exception as exc: + logger.warning( + "[matrix-game-3] failed to initialize official base scheduler ({}); falling back to LightX2V WanScheduler.", + exc, + ) + super().init_scheduler() + def _get_sub_model_folder(self) -> str: """Resolve which MG3 sub-model folder should be used for config lookup.""" return 
str(self.config.get("sub_model_folder", "base_model" if self.config.get("use_base_model", False) else "base_distilled_model")) @@ -1169,10 +1260,10 @@ def _apply_segment_scheduler_state(self, segment_state: MatrixGame3SegmentState) latents = torch.randn( tuple(segment_state.latent_shape), device=AI_DEVICE, - dtype=torch.float32, + dtype=GET_DTYPE(), generator=self._mg3_noise_generator, ) - scheduler.vae_encoder_out = segment_state.vae_encoder_out.to(device=AI_DEVICE, dtype=torch.float32) + scheduler.vae_encoder_out = segment_state.vae_encoder_out.to(device=AI_DEVICE, dtype=GET_DTYPE()) scheduler.mask = torch.ones_like(latents) # Mask value 0 means "keep the provided latent conditioning", while 1 means # "sample this slot from noise through the diffusion process". From b181ff1beceb8dd3dc38c8cda5587635a9997059 Mon Sep 17 00:00:00 2001 From: Yang Date: Tue, 7 Apr 2026 14:17:28 +0800 Subject: [PATCH 18/25] Fix the Lint Error --- lightx2v/infer.py | 1 + .../wan/infer/matrix_game3/pre_infer.py | 25 ++++--- .../infer/matrix_game3/transformer_infer.py | 71 +++++++------------ .../models/networks/wan/matrix_game3_model.py | 10 +-- .../wan/weights/matrix_game3/pre_weights.py | 17 ++--- .../matrix_game3/transformer_weights.py | 15 +--- .../runners/wan/wan_matrix_game3_runner.py | 18 +++-- 7 files changed, 62 insertions(+), 95 deletions(-) diff --git a/lightx2v/infer.py b/lightx2v/infer.py index eeb4fbf33..9b1877402 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -7,6 +7,7 @@ from lightx2v.common.ops import * from lightx2v.models.runners.bagel.bagel_runner import BagelRunner # noqa: F401 + try: from lightx2v.models.runners.flux2_klein.flux2_klein_runner import Flux2KleinRunner # noqa: F401 except (ImportError, ModuleNotFoundError): diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py index 3b907030a..2284588ff 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py 
+++ b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F from einops import rearrange from lightx2v.models.networks.wan.infer.module_io import GridOutput @@ -13,11 +12,21 @@ class WanMtxg3PreInferOutput: """Container for MG3 pre-inference outputs passed to the transformer.""" __slots__ = [ - "x", "embed", "embed0", "grid_sizes", "cos_sin", "context", + "x", + "embed", + "embed0", + "grid_sizes", + "cos_sin", + "context", "freqs", - "plucker_emb", "mouse_cond", "keyboard_cond", - "mouse_cond_memory", "keyboard_cond_memory", - "memory_length", "memory_latent_idx", "predict_latent_idx", + "plucker_emb", + "mouse_cond", + "keyboard_cond", + "mouse_cond_memory", + "keyboard_cond_memory", + "memory_length", + "memory_latent_idx", + "predict_latent_idx", ] def __init__(self, **kwargs): @@ -214,11 +223,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): ) plucker_emb = weights.patch_embedding_wancamctrl.apply(plucker_emb.squeeze(0)) - plucker_hidden = weights.c2ws_hidden_states_layer2.apply( - torch.nn.functional.silu( - weights.c2ws_hidden_states_layer1.apply(plucker_emb) - ) - ) + plucker_hidden = weights.c2ws_hidden_states_layer2.apply(torch.nn.functional.silu(weights.c2ws_hidden_states_layer1.apply(plucker_emb))) plucker_emb = plucker_emb + plucker_hidden return WanMtxg3PreInferOutput( diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py index 18621cb1c..198d3324b 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py @@ -8,8 +8,6 @@ This closely follows the official MG3 `WanAttentionBlock.forward()`. 
""" -import math - import torch from einops import rearrange @@ -25,8 +23,8 @@ except ImportError: FLASH_ATTN_3_AVAILABLE = False -from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer from lightx2v.models.networks.wan.infer.matrix_game2.posemb_layers import apply_rotary_emb, get_nd_rotary_pos_embed +from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer from lightx2v.utils.envs import * from lightx2v.utils.registry_factory import * from lightx2v_platform.base.global_var import AI_DEVICE @@ -240,7 +238,9 @@ def infer_block(self, block, x, pre_infer_out): self.cross_attn_cu_seqlens_kv = torch.tensor([0, k.shape[0]], dtype=torch.int32).to(k.device) attn_out = cross_phase.cross_attn_1.apply( - q=q, k=k, v=v, + q=q, + k=k, + v=v, cu_seqlens_q=self.cross_attn_cu_seqlens_q, cu_seqlens_kv=self.cross_attn_cu_seqlens_kv, max_seqlen_q=q.size(0), @@ -253,9 +253,7 @@ def infer_block(self, block, x, pre_infer_out): if has_action: action_phase_idx = 3 action_phase = block.compute_phases[action_phase_idx] - x = self._infer_action_module( - action_phase, x, pre_infer_out - ) + x = self._infer_action_module(action_phase, x, pre_infer_out) # --- Phase 4 (or 3): FFN --- ffn_phase_idx = 4 if has_action else 3 @@ -304,10 +302,10 @@ def _infer_self_attn_mg3(self, phase, x, shift_msa, scale_msa, pre_infer_out): if memory_length > 0: hw = grid_sizes.tuple[1] * grid_sizes.tuple[2] # Split into memory and prediction parts - q_memory = q[:memory_length * hw].unsqueeze(0) - k_memory = k[:memory_length * hw].unsqueeze(0) - q_pred = q[memory_length * hw:].unsqueeze(0) - k_pred = k[memory_length * hw:].unsqueeze(0) + q_memory = q[: memory_length * hw].unsqueeze(0) + k_memory = k[: memory_length * hw].unsqueeze(0) + q_pred = q[memory_length * hw :].unsqueeze(0) + k_pred = k[memory_length * hw :].unsqueeze(0) # Build grid_sizes tensors f_total = grid_sizes.tuple[0] @@ -339,7 +337,8 @@ def _infer_self_attn_mg3(self, phase, x, shift_msa, 
scale_msa, pre_infer_out): k_unsq = k.unsqueeze(0) grid_sizes_t = torch.tensor( [[grid_sizes.tuple[0], grid_sizes.tuple[1], grid_sizes.tuple[2]]], - dtype=torch.long, device=q.device, + dtype=torch.long, + device=q.device, ) if predict_latent_idx is not None: if isinstance(predict_latent_idx, tuple) and len(predict_latent_idx) == 2: @@ -356,7 +355,9 @@ def _infer_self_attn_mg3(self, phase, x, shift_msa, scale_msa, pre_infer_out): self.self_attn_cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32).to(q.device) attn_out = phase.self_attn_1.apply( - q=q, k=k, v=v, + q=q, + k=k, + v=v, cu_seqlens_q=self.self_attn_cu_seqlens_qkv, cu_seqlens_kv=self.self_attn_cu_seqlens_qkv, max_seqlen_q=img_qkv_len, @@ -403,9 +404,7 @@ def _infer_action_module(self, phase, x, pre_infer_out): if self.enable_mouse and mouse_cond is not None: batch_size, num_frames, mouse_dim = mouse_cond.shape - assert (((num_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0) or ( - num_frames % self.vae_time_compression_ratio == 0 - ) + assert (((num_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0) or (num_frames % self.vae_time_compression_ratio == 0) if ((num_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0: num_feats = int((num_frames - 1) / self.vae_time_compression_ratio) + 1 mouse_cond = torch.cat([mouse_cond[:, 0:1, :].repeat(1, pad_t, 1), mouse_cond], dim=1) @@ -434,9 +433,7 @@ def _infer_action_module(self, phase, x, pre_infer_out): mouse_groups = mouse_groups.unsqueeze(-1).repeat(1, 1, 1, 1, spatial_tokens) mouse_groups = rearrange(mouse_groups, "b t window d s -> (b s) t (window d)") if mouse_groups.shape[1] != tt: - raise ValueError( - f"matrix-game-3 mouse condition window mismatch: expected latent T={tt}, got {mouse_groups.shape[1]}" - ) + raise ValueError(f"matrix-game-3 mouse condition window mismatch: expected latent T={tt}, got {mouse_groups.shape[1]}") 
mouse_input = torch.cat([hidden_states_mouse, mouse_groups], dim=-1) mouse_hidden = phase.mouse_mlp_0.apply(mouse_input.reshape(-1, mouse_input.shape[-1])) @@ -450,12 +447,8 @@ def _infer_action_module(self, phase, x, pre_infer_out): mouse_qkv = mouse_qkv.reshape(batch_size * spatial_tokens, tt, 3, self.action_heads_num, mouse_head_dim) q_m, k_m, v_m = mouse_qkv.permute(2, 0, 1, 3, 4).unbind(0) - q_m = phase.img_attn_q_norm.apply(q_m.reshape(-1, mouse_head_dim)).reshape( - batch_size * spatial_tokens, tt, self.action_heads_num, mouse_head_dim - ) - k_m = phase.img_attn_k_norm.apply(k_m.reshape(-1, mouse_head_dim)).reshape( - batch_size * spatial_tokens, tt, self.action_heads_num, mouse_head_dim - ) + q_m = phase.img_attn_q_norm.apply(q_m.reshape(-1, mouse_head_dim)).reshape(batch_size * spatial_tokens, tt, self.action_heads_num, mouse_head_dim) + k_m = phase.img_attn_k_norm.apply(k_m.reshape(-1, mouse_head_dim)).reshape(batch_size * spatial_tokens, tt, self.action_heads_num, mouse_head_dim) if memory_length > 0: freqs_memory = self._get_action_rotary_pos_embed(memory_length, mouse_head_dim, self.mouse_qk_dim_list) @@ -475,18 +468,14 @@ def _infer_action_module(self, phase, x, pre_infer_out): mouse_attn = self._run_flash_attention(q_m, k_m, v_m, causal=False) mouse_attn = rearrange(mouse_attn, "(b s) t h d -> b (t s) (h d)", b=batch_size, s=spatial_tokens) - mouse_proj = phase.proj_mouse.apply(mouse_attn.reshape(-1, mouse_attn.shape[-1])).reshape( - batch_size, tt * spatial_tokens, -1 - ) + mouse_proj = phase.proj_mouse.apply(mouse_attn.reshape(-1, mouse_attn.shape[-1])).reshape(batch_size, tt * spatial_tokens, -1) hidden_states = x_in + mouse_proj else: hidden_states = x_in if self.enable_keyboard and keyboard_cond is not None: batch_size, num_frames, _ = keyboard_cond.shape - assert (((num_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0) or ( - num_frames % self.vae_time_compression_ratio == 0 - ) + assert (((num_frames - 1) + 
self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0) or (num_frames % self.vae_time_compression_ratio == 0) if ((num_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0: num_feats = int((num_frames - 1) / self.vae_time_compression_ratio) + 1 keyboard_cond = torch.cat([keyboard_cond[:, 0:1, :].repeat(1, pad_t, 1), keyboard_cond], dim=1) @@ -521,14 +510,10 @@ def _infer_action_module(self, phase, x, pre_infer_out): keyboard_groups = torch.cat([keyboard_memory, keyboard_groups], dim=1) if keyboard_groups.shape[1] != tt: - raise ValueError( - f"matrix-game-3 keyboard condition window mismatch: expected latent T={tt}, got {keyboard_groups.shape[1]}" - ) + raise ValueError(f"matrix-game-3 keyboard condition window mismatch: expected latent T={tt}, got {keyboard_groups.shape[1]}") keyboard_groups = keyboard_groups.reshape(batch_size, keyboard_groups.shape[1], -1) - mouse_q = phase.mouse_attn_q.apply(hidden_states.reshape(-1, hidden_states.shape[-1])).reshape( - batch_size, tt * spatial_tokens, -1 - ) + mouse_q = phase.mouse_attn_q.apply(hidden_states.reshape(-1, hidden_states.shape[-1])).reshape(batch_size, tt * spatial_tokens, -1) keyboard_kv = phase.keyboard_attn_kv.apply(keyboard_groups.reshape(-1, keyboard_groups.shape[-1])) keyboard_kv = keyboard_kv.reshape(batch_size, keyboard_groups.shape[1], -1) @@ -537,12 +522,8 @@ def _infer_action_module(self, phase, x, pre_infer_out): kv = keyboard_kv.view(batch_size, -1, 2, self.action_heads_num, keyboard_head_dim) k_k, v_k = kv.permute(2, 0, 1, 3, 4).unbind(0) - q_k = phase.key_attn_q_norm.apply(q_k.reshape(-1, keyboard_head_dim)).reshape( - batch_size, -1, self.action_heads_num, keyboard_head_dim - ) - k_k = phase.key_attn_k_norm.apply(k_k.reshape(-1, keyboard_head_dim)).reshape( - batch_size, -1, self.action_heads_num, keyboard_head_dim - ) + q_k = phase.key_attn_q_norm.apply(q_k.reshape(-1, keyboard_head_dim)).reshape(batch_size, -1, self.action_heads_num, 
keyboard_head_dim) + k_k = phase.key_attn_k_norm.apply(k_k.reshape(-1, keyboard_head_dim)).reshape(batch_size, -1, self.action_heads_num, keyboard_head_dim) q_k = rearrange(q_k, "b (t s) h d -> (b s) t h d", s=spatial_tokens) if memory_length > 0: @@ -565,9 +546,7 @@ def _infer_action_module(self, phase, x, pre_infer_out): v_k = v_k.repeat(spatial_tokens, 1, 1, 1) kb_attn = self._run_flash_attention(q_k, k_k, v_k, causal=False) kb_attn = rearrange(kb_attn, "(b s) t h d -> b (t s) (h d)", b=batch_size, s=spatial_tokens) - kb_proj = phase.proj_keyboard.apply(kb_attn.reshape(-1, kb_attn.shape[-1])).reshape( - batch_size, tt * spatial_tokens, -1 - ) + kb_proj = phase.proj_keyboard.apply(kb_attn.reshape(-1, kb_attn.shape[-1])).reshape(batch_size, tt * spatial_tokens, -1) hidden_states = hidden_states + kb_proj return hidden_states.squeeze(0) diff --git a/lightx2v/models/networks/wan/matrix_game3_model.py b/lightx2v/models/networks/wan/matrix_game3_model.py index c3c754eb6..9cbdfd501 100644 --- a/lightx2v/models/networks/wan/matrix_game3_model.py +++ b/lightx2v/models/networks/wan/matrix_game3_model.py @@ -1,11 +1,10 @@ import json import os -import torch from safetensors import safe_open -from lightx2v.models.networks.wan.infer.matrix_game3.pre_infer import WanMtxg3PreInfer from lightx2v.models.networks.wan.infer.matrix_game3.post_infer import WanMtxg3PostInfer +from lightx2v.models.networks.wan.infer.matrix_game3.pre_infer import WanMtxg3PreInfer from lightx2v.models.networks.wan.infer.matrix_game3.transformer_infer import WanMtxg3TransformerInfer from lightx2v.models.networks.wan.model import WanModel from lightx2v.models.networks.wan.weights.matrix_game3.pre_weights import WanMtxg3PreWeights @@ -61,10 +60,7 @@ def _load_ckpt(self, unified_dtype, sensitive_layer): # Find safetensor files safetensor_files = [f for f in os.listdir(model_dir) if f.endswith(".safetensors")] if not safetensor_files: - raise FileNotFoundError( - f"No safetensors files found in {model_dir}. 
" - "Please download the Matrix-Game-3.0 model weights." - ) + raise FileNotFoundError(f"No safetensors files found in {model_dir}. Please download the Matrix-Game-3.0 model weights.") weight_dict = {} for sf_file in sorted(safetensor_files): @@ -75,7 +71,7 @@ def _load_ckpt(self, unified_dtype, sensitive_layer): # Strip the common diffusers prefix if present name = key if name.startswith("model."): - name = name[len("model."):] + name = name[len("model.") :] # Cast to appropriate dtype if unified_dtype or all(s not in name for s in sensitive_layer): weight_dict[name] = tensor.to(GET_DTYPE()) diff --git a/lightx2v/models/networks/wan/weights/matrix_game3/pre_weights.py b/lightx2v/models/networks/wan/weights/matrix_game3/pre_weights.py index f187e9e8b..514e238d2 100644 --- a/lightx2v/models/networks/wan/weights/matrix_game3/pre_weights.py +++ b/lightx2v/models/networks/wan/weights/matrix_game3/pre_weights.py @@ -1,7 +1,6 @@ from lightx2v.common.modules.weight_module import WeightModule from lightx2v.utils.registry_factory import ( CONV3D_WEIGHT_REGISTER, - LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, ) @@ -27,9 +26,7 @@ def __init__(self, config): # Patch embedding self.add_module( "patch_embedding", - CONV3D_WEIGHT_REGISTER["Default"]( - "patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size - ), + CONV3D_WEIGHT_REGISTER["Default"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size), ) # Text embedding (2-layer MLP with GELU) @@ -59,19 +56,13 @@ def __init__(self, config): # Camera plucker embedding (global, before blocks) self.add_module( "patch_embedding_wancamctrl", - MM_WEIGHT_REGISTER["Default"]( - "patch_embedding_wancamctrl.weight", "patch_embedding_wancamctrl.bias" - ), + MM_WEIGHT_REGISTER["Default"]("patch_embedding_wancamctrl.weight", "patch_embedding_wancamctrl.bias"), ) self.add_module( "c2ws_hidden_states_layer1", - MM_WEIGHT_REGISTER["Default"]( - "c2ws_hidden_states_layer1.weight", "c2ws_hidden_states_layer1.bias" 
- ), + MM_WEIGHT_REGISTER["Default"]("c2ws_hidden_states_layer1.weight", "c2ws_hidden_states_layer1.bias"), ) self.add_module( "c2ws_hidden_states_layer2", - MM_WEIGHT_REGISTER["Default"]( - "c2ws_hidden_states_layer2.weight", "c2ws_hidden_states_layer2.bias" - ), + MM_WEIGHT_REGISTER["Default"]("c2ws_hidden_states_layer2.weight", "c2ws_hidden_states_layer2.bias"), ) diff --git a/lightx2v/models/networks/wan/weights/matrix_game3/transformer_weights.py b/lightx2v/models/networks/wan/weights/matrix_game3/transformer_weights.py index 15e64db34..7e99de57f 100644 --- a/lightx2v/models/networks/wan/weights/matrix_game3/transformer_weights.py +++ b/lightx2v/models/networks/wan/weights/matrix_game3/transformer_weights.py @@ -1,6 +1,5 @@ from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList from lightx2v.models.networks.wan.weights.transformer_weights import ( - WanCrossAttention, WanFFN, WanSelfAttention, ) @@ -39,22 +38,14 @@ def __init__(self, config): block_list = [] for i in range(self.blocks_num): has_action = i in action_blocks - block_list.append( - WanMtxg3TransformerBlock( - i, self.task, self.mm_type, self.config, has_action=has_action - ) - ) + block_list.append(WanMtxg3TransformerBlock(i, self.task, self.mm_type, self.config, has_action=has_action)) self.blocks = WeightModuleList(block_list) self.add_module("blocks", self.blocks) # Non-block weights (head) self.register_parameter("norm", LN_WEIGHT_REGISTER["torch"]()) - self.add_module( - "head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias") - ) - self.register_parameter( - "head_modulation", TENSOR_REGISTER["Default"]("head.modulation") - ) + self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")) + self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation")) def non_block_weights_to_cuda(self): self.norm.to_cuda() diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py 
b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index 1df30aa2e..336b994aa 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -10,8 +10,8 @@ import torch import torch.distributed as dist import torchvision.transforms.functional as TF -from einops import rearrange from PIL import Image +from einops import rearrange from loguru import logger from lightx2v.models.runners.wan.wan_runner import Wan22DenseRunner, build_wan_model_with_lora @@ -1004,12 +1004,16 @@ def _build_plucker_from_c2ws( assert self._mg3_target_h is not None and self._mg3_target_w is not None assert self._mg3_lat_h is not None and self._mg3_lat_w is not None c2ws_np = c2ws_seq.cpu().numpy() - c2ws_infer = modules["cam_utils"]._interpolate_camera_poses_handedness( - src_indices=src_indices, - src_rot_mat=c2ws_np[:, :3, :3], - src_trans_vec=c2ws_np[:, :3, 3], - tgt_indices=tgt_indices, - ).to(device=c2ws_seq.device) + c2ws_infer = ( + modules["cam_utils"] + ._interpolate_camera_poses_handedness( + src_indices=src_indices, + src_rot_mat=c2ws_np[:, :3, :3], + src_trans_vec=c2ws_np[:, :3, 3], + tgt_indices=tgt_indices, + ) + .to(device=c2ws_seq.device) + ) # `framewise=True` means each timestep is represented relative to its own local # frame history, which matches the official per-segment conditioning path. 
c2ws_infer = modules["cam_utils"].compute_relative_poses(c2ws_infer, framewise=framewise) From ee387159c48a5e1e5ee5e13ef38471b9427d0d16 Mon Sep 17 00:00:00 2001 From: Yang Date: Wed, 8 Apr 2026 13:32:20 +0800 Subject: [PATCH 19/25] set the accuracy --- .../wan/infer/matrix_game3/pre_infer.py | 6 +- .../infer/matrix_game3/transformer_infer.py | 55 +++++++------------ scripts/matrix_game3/run_matrix_game3_base.sh | 10 ++++ 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py index 2284588ff..a86427cd1 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py @@ -161,9 +161,11 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): else: embed = weights.time_embedding_0.apply(embed) embed = torch.nn.functional.silu(embed) - embed = weights.time_embedding_2.apply(embed) + embed = weights.time_embedding_2.apply(embed).float() + # Official MG3 keeps both the time embedding and its 6-way modulation + # projection in fp32 before each block consumes them. 
embed0 = torch.nn.functional.silu(embed) - embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim)) + embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim)).float() # Text embedding if self.sensitive_layer_dtype != self.infer_dtype: diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py index 198d3324b..a4686beef 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/transformer_infer.py @@ -208,10 +208,8 @@ def infer_block(self, block, x, pre_infer_out): ) # Gate and residual - if self.sensitive_layer_dtype != self.infer_dtype: - x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze() - else: - x = x + y_out * gate_msa.squeeze() + x_dtype = x.dtype + x = (x.float() + y_out.float() * gate_msa.squeeze().float()).to(x_dtype) # --- Phase 1: Camera Plucker Injection --- if pre_infer_out.plucker_emb is not None: @@ -258,22 +256,16 @@ def infer_block(self, block, x, pre_infer_out): # --- Phase 4 (or 3): FFN --- ffn_phase_idx = 4 if has_action else 3 ffn_phase = block.compute_phases[ffn_phase_idx] - norm2_out = ffn_phase.norm2.apply(x) - if self.sensitive_layer_dtype != self.infer_dtype: - norm2_out = norm2_out.to(self.sensitive_layer_dtype) - norm2_out = norm2_out * (1 + c_scale_msa.squeeze()) + c_shift_msa.squeeze() - if self.sensitive_layer_dtype != self.infer_dtype: - norm2_out = norm2_out.to(self.infer_dtype) + norm2_out = ffn_phase.norm2.apply(x).float() + norm2_out = norm2_out * (1 + c_scale_msa.squeeze().float()) + c_shift_msa.squeeze().float() + norm2_out = norm2_out.to(x_dtype) y = ffn_phase.ffn_0.apply(norm2_out) y = torch.nn.functional.gelu(y, approximate="tanh") y = ffn_phase.ffn_2.apply(y) # FFN gate + residual - if self.sensitive_layer_dtype != self.infer_dtype: - x = x.to(self.sensitive_layer_dtype) 
+ y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze() - else: - x = x + y * c_gate_msa.squeeze() + x = (x.float() + y.float() * c_gate_msa.squeeze().float()).to(x_dtype) return x @@ -281,12 +273,11 @@ def _infer_self_attn_mg3(self, phase, x, shift_msa, scale_msa, pre_infer_out): """Self-attention with memory-aware indexed RoPE.""" cos_sin = self.cos_sin - norm1_out = phase.norm1.apply(x) - if self.sensitive_layer_dtype != self.infer_dtype: - norm1_out = norm1_out.to(self.sensitive_layer_dtype) - norm1_out = norm1_out * (1 + scale_msa.squeeze()) + shift_msa.squeeze() - if self.sensitive_layer_dtype != self.infer_dtype: - norm1_out = norm1_out.to(self.infer_dtype) + # Official MG3 performs the norm1 modulation in fp32, then casts back to + # the model dtype right before the QKV projections. + norm1_out = phase.norm1.apply(x).float() + norm1_out = norm1_out * (1 + scale_msa.squeeze().float()) + shift_msa.squeeze().float() + norm1_out = norm1_out.to(x.dtype) s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d) @@ -563,23 +554,19 @@ def freqs(self, value): def infer_non_blocks(self, weights, x, e): """Head processing — same as base but handles per-token time embeddings.""" if e.dim() == 2: - modulation = weights.head_modulation.tensor - e_parts = (modulation + e.unsqueeze(1)).chunk(2, dim=1) + modulation = weights.head_modulation.tensor.float() + e_parts = (modulation + e.float().unsqueeze(1)).chunk(2, dim=1) elif e.dim() == 3: - modulation = weights.head_modulation.tensor.unsqueeze(2) - e_parts = (modulation + e.unsqueeze(1)).chunk(2, dim=1) + modulation = weights.head_modulation.tensor.float().unsqueeze(2) + e_parts = (modulation + e.float().unsqueeze(1)).chunk(2, dim=1) e_parts = [ei.squeeze(1) for ei in e_parts] else: - modulation = weights.head_modulation.tensor - e_parts = (modulation + e.unsqueeze(1)).chunk(2, dim=1) - - x = weights.norm.apply(x) - if 
self.sensitive_layer_dtype != self.infer_dtype: - x = x.to(self.sensitive_layer_dtype) - x = x * (1 + e_parts[1].squeeze()) + e_parts[0].squeeze() - if self.sensitive_layer_dtype != self.infer_dtype: - x = x.to(self.infer_dtype) - x = weights.head.apply(x) + modulation = weights.head_modulation.tensor.float() + e_parts = (modulation + e.float().unsqueeze(1)).chunk(2, dim=1) + + x = weights.norm.apply(x).float() + x = x * (1 + e_parts[1].squeeze().float()) + e_parts[0].squeeze().float() + x = weights.head.apply(x.to(self.infer_dtype)) return x def set_freqs(self, freqs): diff --git a/scripts/matrix_game3/run_matrix_game3_base.sh b/scripts/matrix_game3/run_matrix_game3_base.sh index e49810fab..513ade7e5 100755 --- a/scripts/matrix_game3/run_matrix_game3_base.sh +++ b/scripts/matrix_game3/run_matrix_game3_base.sh @@ -7,6 +7,16 @@ MODEL_PATH="${MODEL_PATH:-/path/to/Matrix-Game-3.0}" CONFIG_JSON="configs/matrix_game3/matrix_game3_base.json" SAVE_PATH="${SAVE_PATH:-save_results/matrix_game3_base}" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +# Reuse the repo's standard runtime defaults. Base MG3 is notably more sensitive +# to precision than the distilled path, so default the sensitive layers to fp32. 
+export lightx2v_path="${lightx2v_path:-${REPO_ROOT}}" +export model_path="${model_path:-${MODEL_PATH}}" +source "${lightx2v_path}/scripts/base/base.sh" +export SENSITIVE_LAYER_DTYPE="${SENSITIVE_LAYER_DTYPE:-FP32}" + python -m lightx2v.infer \ --model_cls wan2.2_matrix_game3 \ --task i2v \ From aa0f290a2f6827450b137b429785eeaae739876e Mon Sep 17 00:00:00 2001 From: Yang Date: Wed, 8 Apr 2026 16:26:36 +0800 Subject: [PATCH 20/25] Fix the dtype --- lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py index a86427cd1..8703f3a5d 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py @@ -164,7 +164,8 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): embed = weights.time_embedding_2.apply(embed).float() # Official MG3 keeps both the time embedding and its 6-way modulation # projection in fp32 before each block consumes them. 
- embed0 = torch.nn.functional.silu(embed) + modulation_dtype = self.sensitive_layer_dtype if self.sensitive_layer_dtype != self.infer_dtype else self.infer_dtype + embed0 = torch.nn.functional.silu(embed).to(modulation_dtype) embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim)).float() # Text embedding From f23e68164e11d6cba6eb5d11220791b870306e57 Mon Sep 17 00:00:00 2001 From: Yang Date: Fri, 10 Apr 2026 08:29:27 +0800 Subject: [PATCH 21/25] remove fp32 --- .../models/networks/wan/infer/matrix_game3/pre_infer.py | 6 +++--- scripts/matrix_game3/run_matrix_game3_base.sh | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py index 8703f3a5d..0e22a9786 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game3/pre_infer.py @@ -150,7 +150,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): t = torch.cat([timestep_memory.squeeze(0).to(device=x.device, dtype=x.dtype), t.to(device=x.device, dtype=x.dtype)], dim=0) # Patch embedding - x = weights.patch_embedding.apply(x.unsqueeze(0)) + x = weights.patch_embedding.apply(x.unsqueeze(0)).to(self.infer_dtype) grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:] x = x.flatten(2).transpose(1, 2).contiguous() @@ -174,7 +174,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): else: out = weights.text_embedding_0.apply(context.squeeze(0)) out = torch.nn.functional.gelu(out, approximate="tanh") - context = weights.text_embedding_2.apply(out) + context = weights.text_embedding_2.apply(out).to(self.infer_dtype) # Grid sizes and RoPE grid_sizes = GridOutput( @@ -227,7 +227,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): plucker_emb = weights.patch_embedding_wancamctrl.apply(plucker_emb.squeeze(0)) plucker_hidden = 
weights.c2ws_hidden_states_layer2.apply(torch.nn.functional.silu(weights.c2ws_hidden_states_layer1.apply(plucker_emb))) - plucker_emb = plucker_emb + plucker_hidden + plucker_emb = (plucker_emb + plucker_hidden).to(self.infer_dtype) return WanMtxg3PreInferOutput( embed=embed, diff --git a/scripts/matrix_game3/run_matrix_game3_base.sh b/scripts/matrix_game3/run_matrix_game3_base.sh index 513ade7e5..b2188a721 100755 --- a/scripts/matrix_game3/run_matrix_game3_base.sh +++ b/scripts/matrix_game3/run_matrix_game3_base.sh @@ -10,12 +10,12 @@ SAVE_PATH="${SAVE_PATH:-save_results/matrix_game3_base}" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" -# Reuse the repo's standard runtime defaults. Base MG3 is notably more sensitive -# to precision than the distilled path, so default the sensitive layers to fp32. +# Reuse the repo's standard runtime defaults. MG3 base needs localized fp32 math +# around time modulation and residuals, but broad fp32-sensitive-layer overrides +# skew the main bf16 execution path away from the official implementation. 
export lightx2v_path="${lightx2v_path:-${REPO_ROOT}}" export model_path="${model_path:-${MODEL_PATH}}" source "${lightx2v_path}/scripts/base/base.sh" -export SENSITIVE_LAYER_DTYPE="${SENSITIVE_LAYER_DTYPE:-FP32}" python -m lightx2v.infer \ --model_cls wan2.2_matrix_game3 \ From 438449e42f0f43ec5d91e6dc9e60d1f174b08ac1 Mon Sep 17 00:00:00 2001 From: Yang Date: Fri, 10 Apr 2026 10:01:07 +0800 Subject: [PATCH 22/25] integrate the official root to LightX2V --- .../runners/wan/wan_matrix_game3_runner.py | 1020 ++++++++++++++--- 1 file changed, 869 insertions(+), 151 deletions(-) diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index 336b994aa..7eb7566a3 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -1,18 +1,22 @@ -import importlib.util import json -import sys -import types +import math +import random from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional +from typing import Any, List, Optional, Tuple, Union import numpy as np import torch import torch.distributed as dist import torchvision.transforms.functional as TF +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from diffusers.utils import deprecate from PIL import Image from einops import rearrange from loguru import logger +from scipy.interpolate import interp1d +from scipy.spatial.transform import Rotation, Slerp from lightx2v.models.runners.wan.wan_runner import Wan22DenseRunner, build_wan_model_with_lora from lightx2v.models.schedulers.scheduler import BaseScheduler @@ -27,12 +31,7 @@ _PROJECT_ROOT = Path(__file__).resolve().parents[4] -_MATRIX_GAME3_OFFICIAL_ROOT_RELATIVE_CANDIDATES = ( - Path("Matrix-Game-3") / "Matrix-Game-3", - Path("Matrix-Game-3"), -) 
_MATRIX_GAME3_CONFIG_ROOT_RELATIVE = Path("Matrix-Game-3.0") -_MATRIX_GAME3_OFFICIAL_PACKAGE = "_lightx2v_matrix_game3_official" _MATRIX_GAME3_DEFAULT_NEGATIVE_PROMPT = ( "Vibrant colors, overexposure, static, blurred details, subtitles, style, artwork, " "painting, still image, overall grayness, worst quality, low quality, JPEG compression " @@ -40,6 +39,11 @@ "deformed, disfigured, malformed limbs, fused fingers, still image, cluttered background, " "three legs, crowded background, walking backwards" ) +_MATRIX_GAME3_WSAD_OFFSET = 12.35 +_MATRIX_GAME3_DIAGONAL_OFFSET = 8.73 +_MATRIX_GAME3_MOUSE_PITCH_SENSITIVITY = 15.0 +_MATRIX_GAME3_MOUSE_YAW_SENSITIVITY = 15.0 +_MATRIX_GAME3_MOUSE_THRESHOLD = 0.02 @dataclass @@ -69,29 +73,6 @@ class MatrixGame3SegmentState: dit_cond_dict: dict[str, Any] -def _load_module_from_path(module_name: str, file_path: Path): - """Import an official Matrix-Game-3 helper module by filesystem path once.""" - if module_name in sys.modules: - return sys.modules[module_name] - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None or spec.loader is None: - raise ImportError(f"failed to load module {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - -def _ensure_namespace_package(package_name: str, package_path: Path): - """Register a synthetic namespace package so relative imports inside official code work.""" - if package_name in sys.modules: - return sys.modules[package_name] - module = types.ModuleType(package_name) - module.__path__ = [str(package_path)] - sys.modules[package_name] = module - return module - - def _expand_path_candidates(path_value: Any) -> list[Path]: """Resolve a user-provided path against cwd and the project root when needed.""" raw_path = Path(str(path_value)).expanduser() @@ -104,9 +85,854 @@ def _expand_path_candidates(path_value: Any) -> list[Path]: return 
candidates -def _append_unique_path(paths: list[Path], candidate: Path): - if candidate not in paths: - paths.append(candidate) +def _matrix_game3_combine_data(data, num_frames=57, keyboard_dim=4, mouse=True): + assert num_frames % 4 == 1 + keyboard_condition = torch.zeros((num_frames, keyboard_dim)) + if mouse: + mouse_condition = torch.zeros((num_frames, 2)) + + current_frame = 0 + selections = [12] + + while current_frame < num_frames: + rd_frame = selections[random.randint(0, len(selections) - 1)] + rd = random.randint(0, len(data) - 1) + keyboard_sample = data[rd]["keyboard_condition"] + if mouse: + mouse_sample = data[rd]["mouse_condition"] + + if current_frame == 0: + keyboard_condition[:1] = keyboard_sample[:1] + if mouse: + mouse_condition[:1] = mouse_sample[:1] + current_frame = 1 + else: + rd_frame = min(rd_frame, num_frames - current_frame) + repeat_time = rd_frame // 4 + keyboard_condition[current_frame : current_frame + rd_frame] = keyboard_sample.repeat(repeat_time, 1) + if mouse: + mouse_condition[current_frame : current_frame + rd_frame] = mouse_sample.repeat(repeat_time, 1) + current_frame += rd_frame + + if mouse: + return { + "keyboard_condition": keyboard_condition, + "mouse_condition": mouse_condition, + } + return {"keyboard_condition": keyboard_condition} + + +def _matrix_game3_bench_actions_universal(num_frames, num_samples_per_action=4): + actions_single_action = [ + "forward", + "left", + "right", + ] + actions_double_action = [ + "forward_left", + "forward_right", + ] + + actions_single_camera = [ + "camera_l", + "camera_r", + ] + actions_to_test = actions_double_action * 5 + actions_single_camera * 5 + actions_single_action * 5 + for action in (actions_single_action + actions_double_action): + for camera in actions_single_camera: + actions_to_test.append(f"{action}_{camera}") + + base_action = actions_single_action + actions_single_camera + keyboard_idx = { + "forward": 0, + "back": 1, + "left": 2, + "right": 3, + } + cam_value = 0.1 + 
camera_value_map = { + "camera_up": [cam_value, 0], + "camera_down": [-cam_value, 0], + "camera_l": [0, -cam_value], + "camera_r": [0, cam_value], + "camera_ur": [cam_value, cam_value], + "camera_ul": [cam_value, -cam_value], + "camera_dr": [-cam_value, cam_value], + "camera_dl": [-cam_value, -cam_value], + } + + data = [] + for action_name in actions_to_test: + keyboard_condition = [[0, 0, 0, 0, 0, 0] for _ in range(num_samples_per_action)] + mouse_condition = [[0, 0] for _ in range(num_samples_per_action)] + + for sub_action in base_action: + if sub_action not in action_name: + continue + if sub_action in camera_value_map: + mouse_condition = [camera_value_map[sub_action] for _ in range(num_samples_per_action)] + elif sub_action in keyboard_idx: + col = keyboard_idx[sub_action] + for row in keyboard_condition: + row[col] = 1 + + data.append( + { + "keyboard_condition": torch.tensor(keyboard_condition), + "mouse_condition": torch.tensor(mouse_condition), + } + ) + + return _matrix_game3_combine_data(data, num_frames, keyboard_dim=6, mouse=True) + + +def _matrix_game3_compute_next_pose_from_action(current_pose, keyboard_action, mouse_action): + x, y, z, pitch, yaw = current_pose + w, s, a, d = keyboard_action[:4] + mouse_x, mouse_y = mouse_action[:2] + + delta_pitch = _MATRIX_GAME3_MOUSE_PITCH_SENSITIVITY * mouse_x if abs(mouse_x) >= _MATRIX_GAME3_MOUSE_THRESHOLD else 0.0 + delta_yaw = _MATRIX_GAME3_MOUSE_YAW_SENSITIVITY * mouse_y if abs(mouse_y) >= _MATRIX_GAME3_MOUSE_THRESHOLD else 0.0 + + new_pitch = pitch + delta_pitch + new_yaw = yaw + delta_yaw + + while new_yaw > 180: + new_yaw -= 360 + while new_yaw < -180: + new_yaw += 360 + + local_forward = 0.0 + if w > 0.5 and s < 0.5: + local_forward = _MATRIX_GAME3_WSAD_OFFSET + elif s > 0.5 and w < 0.5: + local_forward = -_MATRIX_GAME3_WSAD_OFFSET + + local_right = 0.0 + if d > 0.5 and a < 0.5: + local_right = _MATRIX_GAME3_WSAD_OFFSET + elif a > 0.5 and d < 0.5: + local_right = -_MATRIX_GAME3_WSAD_OFFSET + + if 
abs(local_forward) > 0.1 and abs(local_right) > 0.1: + local_forward = np.sign(local_forward) * _MATRIX_GAME3_DIAGONAL_OFFSET + local_right = np.sign(local_right) * _MATRIX_GAME3_DIAGONAL_OFFSET + + avg_yaw = float((yaw + new_yaw) / 2.0) + yaw_rad = float(np.deg2rad(avg_yaw)) + cos_yaw = np.cos(yaw_rad) + sin_yaw = np.sin(yaw_rad) + + delta_x = cos_yaw * local_forward - sin_yaw * local_right + delta_y = sin_yaw * local_forward + cos_yaw * local_right + return np.array([x + delta_x, y + delta_y, z, new_pitch, new_yaw], dtype=np.float32) + + +def _matrix_game3_compute_all_poses_from_actions(keyboard_conditions, mouse_conditions, first_pose=None, return_last_pose=False): + total_frames = len(keyboard_conditions) + all_poses = np.zeros((total_frames, 5), dtype=np.float32) + if first_pose is not None: + all_poses[0] = first_pose + + for idx in range(total_frames - 1): + all_poses[idx + 1] = _matrix_game3_compute_next_pose_from_action( + all_poses[idx], + keyboard_conditions[idx], + mouse_conditions[idx], + ) + + if return_last_pose: + last_pose = _matrix_game3_compute_next_pose_from_action( + all_poses[-1], + keyboard_conditions[-1], + mouse_conditions[-1], + ) + return all_poses, last_pose + return all_poses + + +def _matrix_game3_interpolate_camera_poses(src_indices, src_rot_mat, src_trans_vec, tgt_indices): + interp_func_trans = interp1d( + src_indices, + src_trans_vec, + axis=0, + kind="linear", + bounds_error=False, + fill_value="extrapolate", + ) + interpolated_trans_vec = interp_func_trans(tgt_indices) + + src_quat_vec = Rotation.from_matrix(src_rot_mat) + quats = src_quat_vec.as_quat().copy() + for idx in range(1, len(quats)): + if np.dot(quats[idx], quats[idx - 1]) < 0: + quats[idx] = -quats[idx] + src_quat_vec = Rotation.from_quat(quats) + slerp_func_rot = Slerp(src_indices, src_quat_vec) + interpolated_rot_quat = slerp_func_rot(tgt_indices) + interpolated_rot_mat = interpolated_rot_quat.as_matrix() + + poses = np.zeros((len(tgt_indices), 4, 4), 
dtype=np.float32) + poses[:, :3, :3] = interpolated_rot_mat + poses[:, :3, 3] = interpolated_trans_vec + poses[:, 3, 3] = 1.0 + return torch.from_numpy(poses).float() + + +def _matrix_game3_se3_inverse(transform): + rotation = transform[:, :3, :3] + translation = transform[:, :3, 3:] + rotation_inv = rotation.transpose(-1, -2) + translation_inv = -torch.bmm(rotation_inv, translation) + inverse = torch.eye(4, device=transform.device, dtype=transform.dtype)[None, :, :].repeat(transform.shape[0], 1, 1) + inverse[:, :3, :3] = rotation_inv + inverse[:, :3, 3:] = translation_inv + return inverse + + +def _matrix_game3_compute_relative_poses(c2ws_mat, framewise=False, normalize_trans=True): + ref_w2cs = _matrix_game3_se3_inverse(c2ws_mat[0:1]) + relative_poses = torch.matmul(ref_w2cs, c2ws_mat) + relative_poses[0] = torch.eye(4, device=c2ws_mat.device, dtype=c2ws_mat.dtype) + if framewise: + relative_poses_framewise = torch.bmm(_matrix_game3_se3_inverse(relative_poses[:-1]), relative_poses[1:]) + relative_poses[1:] = relative_poses_framewise + if normalize_trans: + translations = relative_poses[:, :3, 3] + max_norm = torch.norm(translations, dim=-1).max() + if max_norm > 0: + relative_poses[:, :3, 3] = translations / max_norm + return relative_poses + + +@torch.no_grad() +def _matrix_game3_create_meshgrid(n_frames, height, width, bias=0.5, device="cuda", dtype=torch.float32): + x_range = torch.arange(width, device=device, dtype=dtype) + y_range = torch.arange(height, device=device, dtype=dtype) + grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing="ij") + grid_xy = torch.stack([grid_x, grid_y], dim=-1).view([-1, 2]) + bias + return grid_xy[None, ...].repeat(n_frames, 1, 1) + + +def _matrix_game3_get_plucker_embeddings(c2ws_mat, intrinsics, height, width): + n_frames = c2ws_mat.shape[0] + grid_xy = _matrix_game3_create_meshgrid(n_frames, height, width, device=c2ws_mat.device, dtype=c2ws_mat.dtype) + fx, fy, cx, cy = intrinsics.chunk(4, dim=-1) + i = grid_xy[..., 0] 
+ j = grid_xy[..., 1] + zs = torch.ones_like(i) + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + + directions = torch.stack([xs, ys, zs], dim=-1) + directions = directions / directions.norm(dim=-1, keepdim=True) + rays_d = directions @ c2ws_mat[:, :3, :3].transpose(-1, -2) + rays_o = c2ws_mat[:, :3, 3] + rays_o = rays_o[:, None, :].expand_as(rays_d) + return torch.cat([rays_o, rays_d], dim=-1).view([n_frames, height, width, 6]) + + +def _matrix_game3_select_memory_idx_fov(extrinsics_all, current_start_frame_idx, selected_index_base, return_confidence=False, use_gpu=False): + if not use_gpu: + use_gpu = True + + device = extrinsics_all.device if isinstance(extrinsics_all, torch.Tensor) else torch.device("cuda" if torch.cuda.is_available() else "cpu") + if isinstance(extrinsics_all, np.ndarray): + extrinsics_tensor = torch.from_numpy(extrinsics_all).to(device).float() + else: + extrinsics_tensor = extrinsics_all.to(device).float() + + video_w, video_h = 1280, 720 + fov_rad = np.deg2rad(90) + fx = video_w / (2 * np.tan(fov_rad / 2)) + fy = video_h / (2 * np.tan(fov_rad / 2)) + + if current_start_frame_idx <= 1: + empty_index = [0] * len(selected_index_base) + empty_conf = [0.0] * len(selected_index_base) + return (empty_index, empty_conf) if return_confidence else empty_index + + candidate_indices = torch.arange(1, current_start_frame_idx, device=device) + rotation = extrinsics_tensor[candidate_indices, :3, :3] + translation = extrinsics_tensor[candidate_indices, :3, 3:4] + rotation_inv = rotation.transpose(1, 2) + translation_inv = -torch.bmm(rotation_inv, translation) + + selected_index = [] + selected_confidence = [] + near, far = 0.1, 30.0 + num_side = 10 + z_samples = torch.linspace(near, far, num_side, device=device) + x_samples = torch.linspace(-1, 1, num_side, device=device) + y_samples = torch.linspace(-1, 1, num_side, device=device) + grid_x, grid_y, grid_z = torch.meshgrid(x_samples, y_samples, z_samples, indexing="ij") + points_cam_base = torch.stack( 
+ [ + grid_x.reshape(-1) * grid_z.reshape(-1) * (video_w / (2 * fx)), + grid_y.reshape(-1) * grid_z.reshape(-1) * (video_h / (2 * fy)), + grid_z.reshape(-1), + ], + dim=0, + ) + + for base_idx in selected_index_base: + extrinsics = extrinsics_tensor[base_idx] + points_world = extrinsics[:3, :3] @ points_cam_base + extrinsics[:3, 3:4] + points_world_batched = points_world.unsqueeze(0) + points_in_candidates = torch.bmm(rotation_inv, points_world_batched.expand(len(candidate_indices), -1, -1)) + translation_inv + + x = points_in_candidates[:, 0, :] + y = points_in_candidates[:, 1, :] + z = points_in_candidates[:, 2, :] + u = (x * fx / torch.clamp(z, min=1e-6)) + video_w / 2 + v = (y * fy / torch.clamp(z, min=1e-6)) + video_h / 2 + + in_view = (z > near) & (z < far) & (u >= 0) & (u <= video_w) & (v >= 0) & (v <= video_h) + ratios = in_view.float().mean(dim=1) + best_idx = torch.argmax(ratios) + selected_index.append(candidate_indices[best_idx].item()) + selected_confidence.append(ratios[best_idx].item()) + + return (selected_index, selected_confidence) if return_confidence else selected_index + + +def _matrix_game3_get_extrinsics(video_rotation, video_position): + num_frames = len(video_rotation) + extrinsics_vid = [] + for idx in range(num_frames): + roll, pitch, yaw = video_rotation[idx] + roll, pitch, yaw = np.radians([roll, pitch, yaw]) + + rotation_z = np.array([[np.cos(yaw), -np.sin(yaw), 0], [np.sin(yaw), np.cos(yaw), 0], [0, 0, 1]]) + rotation_y = np.array([[np.cos(pitch), 0, np.sin(pitch)], [0, 1, 0], [-np.sin(pitch), 0, np.cos(pitch)]]) + rotation_x = np.array([[1, 0, 0], [0, np.cos(roll), -np.sin(roll)], [0, np.sin(roll), np.cos(roll)]]) + rotation = rotation_z @ rotation_y @ rotation_x + + extrinsics = np.eye(4, dtype=np.float32) + extrinsics[:3, :3] = rotation + extrinsics[:3, 3] = video_position[idx] + extrinsics_vid.append(extrinsics) + + rotation_init = np.array( + [ + [0, 0, 1], + [1, 0, 0], + [0, -1, 0], + ], + dtype=np.float32, + ) + extrinsics = 
torch.from_numpy(np.array(extrinsics_vid, dtype=np.float32)) + extrinsics[:, :3, :3] = extrinsics[:, :3, :3] @ rotation_init + extrinsics[:, :3, 3] = extrinsics[:, :3, 3] * 0.01 + return extrinsics + + +def _matrix_game3_get_intrinsics(height, width): + fov_deg = 90 + fov_rad = np.deg2rad(fov_deg) + fx = width / (2 * np.tan(fov_rad / 2)) + fy = height / (2 * np.tan(fov_rad / 2)) + cx = width / 2 + cy = height / 2 + return torch.tensor([fx, fy, cx, cy]) + + +def _matrix_game3_interpolate_camera_poses_handedness(src_indices, src_rot_mat, src_trans_vec, tgt_indices): + dets = np.linalg.det(src_rot_mat) + flip_handedness = dets.size > 0 and np.median(dets) < 0.0 + if flip_handedness: + flip_mat = np.diag([1.0, 1.0, -1.0]).astype(src_rot_mat.dtype) + src_rot_mat = src_rot_mat @ flip_mat + c2ws = _matrix_game3_interpolate_camera_poses( + src_indices=src_indices, + src_rot_mat=src_rot_mat, + src_trans_vec=src_trans_vec, + tgt_indices=tgt_indices, + ) + if flip_handedness: + flip_mat_t = torch.from_numpy(flip_mat).to(c2ws.device, dtype=c2ws.dtype) + c2ws[:, :3, :3] = c2ws[:, :3, :3] @ flip_mat_t + return c2ws + + +class _MatrixGame3ConditionsShim: + Bench_actions_universal = staticmethod(_matrix_game3_bench_actions_universal) + + +class _MatrixGame3UtilsShim: + compute_all_poses_from_actions = staticmethod(_matrix_game3_compute_all_poses_from_actions) + + +class _MatrixGame3CamUtilsShim: + _interpolate_camera_poses_handedness = staticmethod(_matrix_game3_interpolate_camera_poses_handedness) + compute_relative_poses = staticmethod(_matrix_game3_compute_relative_poses) + get_plucker_embeddings = staticmethod(_matrix_game3_get_plucker_embeddings) + select_memory_idx_fov = staticmethod(_matrix_game3_select_memory_idx_fov) + get_extrinsics = staticmethod(_matrix_game3_get_extrinsics) + get_intrinsics = staticmethod(_matrix_game3_get_intrinsics) + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """Inlined Matrix-Game-3 FlowUniPC scheduler from the official 
implementation.""" + + _compatibles = [scheduler.name for scheduler in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting: bool = False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", + ): + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + return self._step_index + + @property + def begin_index(self): + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + self._begin_index = begin_index + + def set_timesteps( + self, + 
num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError(f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}") + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + self.num_inference_steps = len(timesteps) + self.model_outputs = [None] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") + + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + if dtype not in (torch.float32, torch.float64): + sample = sample.float() + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + abs_sample = sample.abs() + s = torch.quantile(abs_sample, 
self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp(s, min=1, max=self.config.sample_max_value) + s = s.unsqueeze(1) + sample = torch.clamp(sample, -s, s) / s + sample = sample.reshape(batch_size, channels, *remaining_dims) + return sample.to(dtype) + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def convert_model_output(self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs) -> torch.Tensor: + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + _, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + return x0_pred + + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
+ ) + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + return epsilon + + def multistep_uni_p_bh_update(self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, order: int = None, **kwargs) -> torch.Tensor: + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError("missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + d1s = [] + for idx in range(1, order): + si = self.step_index - idx + mi = model_output_list[-(idx + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + d1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + matrix_r = [] + vector_b = [] + hh = -h if self.predict_x0 else h + h_phi_1 = 
torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + factorial_i = 1 + + if self.config.solver_type == "bh1": + b_h = hh + elif self.config.solver_type == "bh2": + b_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for idx in range(1, order + 1): + matrix_r.append(torch.pow(rks, idx - 1)) + vector_b.append(h_phi_k * factorial_i / b_h) + factorial_i *= idx + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + matrix_r = torch.stack(matrix_r) + vector_b = torch.tensor(vector_b, device=device) + + if len(d1s) > 0: + d1s = torch.stack(d1s, dim=1) + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(matrix_r[:-1, :-1], vector_b[:-1]).to(device).to(x.dtype) + else: + d1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, d1s) if d1s is not None else 0 + x_t = x_t_ - alpha_t * b_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, d1s) if d1s is not None else 0 + x_t = x_t_ - sigma_t * b_h * pred_res + return x_t.to(x.dtype) + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError("missing `last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError("missing `this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError("missing `order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing 
`this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + d1s = [] + for idx in range(1, order): + si = self.step_index - (idx + 1) + mi = model_output_list[-(idx + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + d1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + matrix_r = [] + vector_b = [] + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + factorial_i = 1 + + if self.config.solver_type == "bh1": + b_h = hh + elif self.config.solver_type == "bh2": + b_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for idx in range(1, order + 1): + matrix_r.append(torch.pow(rks, idx - 1)) + vector_b.append(h_phi_k * factorial_i / b_h) + factorial_i *= idx + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + matrix_r = torch.stack(matrix_r) + vector_b = torch.tensor(vector_b, device=device) + d1s = torch.stack(d1s, dim=1) if len(d1s) > 0 else None + + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(matrix_r, vector_b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], d1s) if d1s is 
not None else 0 + d1_t = model_t - m0 + x_t = x_t_ - alpha_t * b_h * (corr_res + rhos_c[-1] * d1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], d1s) if d1s is not None else 0 + d1_t = model_t - m0 + x_t = x_t_ - sigma_t * b_h * (corr_res + rhos_c[-1] * d1_t) + return x_t.to(x.dtype) + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + indices = (schedule_timesteps == timestep).nonzero() + pos = 1 if len(indices) > 1 else 0 + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None, + ) -> Union[SchedulerOutput, Tuple]: + if self.num_inference_steps is None: + raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler") + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for idx in range(self.config.solver_order - 1): + self.model_outputs[idx] = self.model_outputs[idx + 1] + self.timestep_list[idx] = self.timestep_list[idx + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + 
this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return sample + + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + if self.begin_index is None: + step_indices = [self.index_for_timestep(timestep, schedule_timesteps) for timestep in timesteps] + elif self.step_index is not None: + step_indices = [self.step_index] * timesteps.shape[0] + else: + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + return alpha_t * original_samples + sigma_t * noise + + def __len__(self): + return self.config.num_train_timesteps class MatrixGame3OfficialSchedulerAdapter(BaseScheduler): @@ -289,24 +1115,15 @@ def load_transformer(self): def init_scheduler(self): # 
Distilled MG3 is already stable on the shared Wan scheduler path. Base MG3 # is far more sensitive to the exact UniPC implementation, so route it - # through the official FlowUniPC scheduler instead of the generic adapter. + # through the inlined official FlowUniPC scheduler instead of the generic adapter. if self.config.get("use_base_model", False): try: - official_root = self.resolve_official_root() - wan_root = official_root / "wan" - utils_root = wan_root / "utils" - _ensure_namespace_package(f"{_MATRIX_GAME3_OFFICIAL_PACKAGE}.wan", wan_root) - _ensure_namespace_package(f"{_MATRIX_GAME3_OFFICIAL_PACKAGE}.wan.utils", utils_root) - scheduler_module = _load_module_from_path( - f"{_MATRIX_GAME3_OFFICIAL_PACKAGE}.wan.utils.fm_solvers_unipc", - utils_root / "fm_solvers_unipc.py", - ) - self.scheduler = MatrixGame3OfficialSchedulerAdapter(self.config, scheduler_module.FlowUniPCMultistepScheduler) - logger.info("[matrix-game-3] using official FlowUniPCMultistepScheduler for base-model sampling.") + self.scheduler = MatrixGame3OfficialSchedulerAdapter(self.config, FlowUniPCMultistepScheduler) + logger.info("[matrix-game-3] using inlined official FlowUniPCMultistepScheduler for base-model sampling.") return except Exception as exc: logger.warning( - "[matrix-game-3] failed to initialize official base scheduler ({}); falling back to LightX2V WanScheduler.", + "[matrix-game-3] failed to initialize inlined official base scheduler ({}); falling back to LightX2V WanScheduler.", exc, ) super().init_scheduler() @@ -315,72 +1132,6 @@ def _get_sub_model_folder(self) -> str: """Resolve which MG3 sub-model folder should be used for config lookup.""" return str(self.config.get("sub_model_folder", "base_model" if self.config.get("use_base_model", False) else "base_distilled_model")) - def _resolve_official_root_candidate(self, candidate: Path) -> Optional[Path]: - """Accept either the inner package root or its parent repository directory.""" - direct_root = candidate.expanduser() - 
if (direct_root / "generate.py").is_file() and (direct_root / "pipeline").is_dir() and (direct_root / "utils").is_dir(): - return direct_root - - nested_root = direct_root / "Matrix-Game-3" - if (nested_root / "generate.py").is_file() and (nested_root / "pipeline").is_dir() and (nested_root / "utils").is_dir(): - return nested_root - return None - - def resolve_official_root(self) -> Path: - """Resolve the official Matrix-Game-3 source root using config-first priority.""" - configured_root = self.config.get("matrix_game3_official_root") - if configured_root: - for candidate in _expand_path_candidates(configured_root): - resolved = self._resolve_official_root_candidate(candidate) - if resolved is not None: - return resolved - raise FileNotFoundError( - "Matrix-Game-3 official source root is missing or invalid for " - f"matrix_game3_official_root={configured_root!r}. " - "The runner needs the official utils/pipeline files to build camera and action conditions. " - "Please set config['matrix_game3_official_root'] to the official source root directory." 
- ) - - auto_candidates: list[Path] = [] - for relative_path in _MATRIX_GAME3_OFFICIAL_ROOT_RELATIVE_CANDIDATES: - _append_unique_path(auto_candidates, _PROJECT_ROOT / relative_path) - _append_unique_path(auto_candidates, _PROJECT_ROOT) - - path_hints = [ - self.config.get("model_path"), - self.config.get("config_json"), - self.input_info.image_path if getattr(self, "input_info", None) is not None else None, - ] - for path_hint in path_hints: - if not path_hint: - continue - for candidate in _expand_path_candidates(path_hint): - looks_like_file = candidate.is_file() or candidate.suffix.lower() in { - ".json", - ".jpg", - ".jpeg", - ".png", - ".webp", - ".bmp", - ".gif", - ".npy", - } - current = candidate.parent if looks_like_file else candidate - for ancestor in (current, *current.parents): - _append_unique_path(auto_candidates, ancestor) - _append_unique_path(auto_candidates, ancestor / "Matrix-Game-3") - - for candidate in auto_candidates: - resolved = self._resolve_official_root_candidate(candidate) - if resolved is not None: - return resolved - - raise FileNotFoundError( - "Matrix-Game-3 official source root could not be resolved from the project layout. " - "The runner needs it to import official utils/conditions.py, utils/cam_utils.py, utils/utils.py, " - "and pipeline helpers. Please set config['matrix_game3_official_root'] explicitly." - ) - def resolve_model_config_path(self) -> Path: """Resolve the MG3 base/distilled config.json with explicit override support.""" configured_path = self.config.get("matrix_game3_config_path") @@ -436,47 +1187,14 @@ def _load_matrix_game3_model_config(self): self.mouse_dim_in = int(self.config.get("mouse_dim_in", action_config.get("mouse_dim_in", 2))) def _get_official_modules(self) -> dict[str, Any]: - """Lazy-load helper code from the official Matrix-Game-3 repository. - - We intentionally reuse the official camera/action utilities instead of - re-implementing pose math in the LightX2V runner. 
- """ + """Expose inlined Matrix-Game-3 helpers through a module-like interface.""" if self._official_modules is not None: return self._official_modules - official_root = self.resolve_official_root() - utils_root = official_root / "utils" - if not utils_root.is_dir(): - raise FileNotFoundError( - f"Matrix-Game-3 utils directory is missing under {official_root}. " - "The runner needs the official utils modules to construct action and camera conditions. " - "Please set config['matrix_game3_official_root'] to the official source root directory." - ) - - required_utils = { - "conditions": utils_root / "conditions.py", - "cam_utils": utils_root / "cam_utils.py", - "transform": utils_root / "transform.py", - "utils": utils_root / "utils.py", - } - missing_utils = [str(path) for path in required_utils.values() if not path.is_file()] - if missing_utils: - raise FileNotFoundError( - "Matrix-Game-3 official utility files are incomplete. " - f"Missing: {missing_utils}. " - "The runner needs these files to reuse the official action/camera preprocessing. " - "Please set config['matrix_game3_official_root'] to a complete official source checkout." 
- ) - - _ensure_namespace_package(_MATRIX_GAME3_OFFICIAL_PACKAGE, official_root) - utils_pkg = f"{_MATRIX_GAME3_OFFICIAL_PACKAGE}.utils" - _ensure_namespace_package(utils_pkg, utils_root) - modules = { - "conditions": _load_module_from_path(f"{utils_pkg}.conditions", required_utils["conditions"]), - "cam_utils": _load_module_from_path(f"{utils_pkg}.cam_utils", required_utils["cam_utils"]), - "transform": _load_module_from_path(f"{utils_pkg}.transform", required_utils["transform"]), - "utils": _load_module_from_path(f"{utils_pkg}.utils", required_utils["utils"]), + "conditions": _MatrixGame3ConditionsShim, + "cam_utils": _MatrixGame3CamUtilsShim, + "utils": _MatrixGame3UtilsShim, } self._official_modules = modules return modules From e923727ae7219a68b3ef0e7ba59724e745bd0393 Mon Sep 17 00:00:00 2001 From: Yang Date: Fri, 10 Apr 2026 10:23:12 +0800 Subject: [PATCH 23/25] modify the wan2.2 --- configs/matrix_game3/matrix_game3_base.json | 2 + .../matrix_game3/matrix_game3_distilled.json | 2 + .../runners/wan/wan_matrix_game3_runner.py | 6 + lightx2v/models/runners/wan/wan_runner.py | 61 ++++ .../models/video_encoders/hf/wan/vae_2_2.py | 271 ++++++++++++++++-- 5 files changed, 323 insertions(+), 19 deletions(-) diff --git a/configs/matrix_game3/matrix_game3_base.json b/configs/matrix_game3/matrix_game3_base.json index 5c8b0cc81..8df68663d 100644 --- a/configs/matrix_game3/matrix_game3_base.json +++ b/configs/matrix_game3/matrix_game3_base.json @@ -3,6 +3,8 @@ "task": "i2v", "sub_model_folder": "base_model", "use_base_model": true, + "vae_type": "mg_lightvae_v2", + "lightvae_pruning_rate": 0.75, "target_video_length": 57, "target_height": 704, diff --git a/configs/matrix_game3/matrix_game3_distilled.json b/configs/matrix_game3/matrix_game3_distilled.json index 720f488e5..391c02dd9 100644 --- a/configs/matrix_game3/matrix_game3_distilled.json +++ b/configs/matrix_game3/matrix_game3_distilled.json @@ -3,6 +3,8 @@ "task": "i2v", "sub_model_folder": "base_distilled_model", 
"use_base_model": false, + "vae_type": "mg_lightvae_v2", + "lightvae_pruning_rate": 0.75, "target_video_length": 57, "target_height": 704, diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index 7eb7566a3..c8e0257c0 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -1029,6 +1029,12 @@ def __init__(self, config): config["mode"] = "matrix_game3" config["use_image_encoder"] = False config["use_base_model"] = bool(config.get("use_base_model", False)) + config["vae_type"] = str(config.get("vae_type", "mg_lightvae_v2")) + if "lightvae_pruning_rate" not in config: + if config["vae_type"] == "mg_lightvae": + config["lightvae_pruning_rate"] = 0.5 + elif config["vae_type"] == "mg_lightvae_v2": + config["lightvae_pruning_rate"] = 0.75 if "sub_model_folder" not in config: config["sub_model_folder"] = "base_model" if config["use_base_model"] else "base_distilled_model" config["num_channels_latents"] = int(config.get("num_channels_latents", 48)) diff --git a/lightx2v/models/runners/wan/wan_runner.py b/lightx2v/models/runners/wan/wan_runner.py index 6aaefc087..f10b509b2 100755 --- a/lightx2v/models/runners/wan/wan_runner.py +++ b/lightx2v/models/runners/wan/wan_runner.py @@ -684,6 +684,67 @@ def __init__(self, config): self.vae_name = "Wan2.2_VAE.pth" self.tiny_vae_name = "taew2_2.pth" + def _resolve_wan22_vae_paths(self): + requested_vae_type = str(self.config.get("vae_type", "wan")).lower() + if requested_vae_type not in {"wan", "wan2.2", "mg_lightvae", "mg_lightvae_v2"}: + raise ValueError(f"Unsupported wan2.2 vae_type: {requested_vae_type}") + + if requested_vae_type in {"wan", "wan2.2"}: + decoder_filename = "Wan2.2_VAE.pth" + resolved_vae_type = "wan2.2" + default_pruning_rate = None + elif requested_vae_type == "mg_lightvae": + decoder_filename = "MG-LightVAE.pth" + resolved_vae_type = "mg_lightvae" + 
default_pruning_rate = 0.5 + else: + decoder_filename = "MG-LightVAE_v2.pth" + resolved_vae_type = "mg_lightvae" + default_pruning_rate = 0.75 + + decoder_path = self.config.get("vae_path") or find_torch_model_path(self.config, filename=decoder_filename) + teacher_encoder_path = ( + self.config.get("lightvae_encoder_vae_pth") + or self.config.get("lightvae_encoder_path") + or find_torch_model_path(self.config, filename="Wan2.2_VAE.pth") + ) + + return { + "vae_path": decoder_path, + "vae_type": resolved_vae_type, + "lightvae_pruning_rate": self.config.get("lightvae_pruning_rate", default_pruning_rate), + "lightvae_encoder_vae_pth": teacher_encoder_path, + } + + def _build_wan22_vae_config(self, vae_offload): + vae_device = torch.device("cpu") if vae_offload else torch.device(AI_DEVICE) + resolved_paths = self._resolve_wan22_vae_paths() + return { + "vae_path": resolved_paths["vae_path"], + "device": vae_device, + "parallel": self.get_vae_parallel(), + "use_tiling": self.config.get("use_tiling_vae", False), + "cpu_offload": vae_offload, + "dtype": GET_DTYPE(), + "load_from_rank0": self.config.get("load_from_rank0", False), + "vae_type": resolved_paths["vae_type"], + "lightvae_pruning_rate": resolved_paths["lightvae_pruning_rate"], + "lightvae_encoder_vae_pth": resolved_paths["lightvae_encoder_vae_pth"], + } + + def load_vae_encoder(self): + if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v", "rs2v"]: + return None + vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload")) + return self.vae_cls(**self._build_wan22_vae_config(vae_offload)) + + def load_vae_decoder(self): + vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload")) + if self.config.get("use_tae", False): + tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name) + return self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to(AI_DEVICE) + return 
self.vae_cls(**self._build_wan22_vae_config(vae_offload)) + @ProfilingContext4DebugL1( "Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), diff --git a/lightx2v/models/video_encoders/hf/wan/vae_2_2.py b/lightx2v/models/video_encoders/hf/wan/vae_2_2.py index 6f8651574..f062c9884 100755 --- a/lightx2v/models/video_encoders/hf/wan/vae_2_2.py +++ b/lightx2v/models/video_encoders/hf/wan/vae_2_2.py @@ -1,5 +1,6 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import logging +import os import torch import torch.nn as nn @@ -19,6 +20,140 @@ CACHE_T = 2 +def _extract_checkpoint_state_dict(raw): + state = raw + if isinstance(state, dict) and "state_dict" in state: + state = state["state_dict"] + if isinstance(state, dict) and "gen_model" in state: + state = state["gen_model"] + if isinstance(state, dict) and "generator" in state: + state = state["generator"] + if not isinstance(state, dict): + raise ValueError("Unsupported checkpoint format: expected a dict-like state_dict.") + return state + + +def _map_lightvae_key_to_wanvae(key): + def _map_resnet_tail(tail): + if tail.startswith("norm1."): + return "residual.0." + tail[len("norm1.") :] + if tail.startswith("conv1."): + return "residual.2." + tail[len("conv1.") :] + if tail.startswith("norm2."): + return "residual.3." + tail[len("norm2.") :] + if tail.startswith("conv2."): + return "residual.6." + tail[len("conv2.") :] + if tail.startswith("conv_shortcut."): + return "shortcut." 
+ tail[len("conv_shortcut.") :] + return tail + + if key.startswith("dynamic_feature_projection_heads."): + return None + + if key.startswith("quant_conv."): + return key.replace("quant_conv.", "conv1.", 1) + if key.startswith("post_quant_conv."): + return key.replace("post_quant_conv.", "conv2.", 1) + + if key.startswith("encoder.conv_in."): + return key.replace("encoder.conv_in.", "encoder.conv1.", 1) + if key.startswith("encoder.mid_block.resnets.0."): + tail = key[len("encoder.mid_block.resnets.0.") :] + return "encoder.middle.0." + _map_resnet_tail(tail) + if key.startswith("encoder.mid_block.attentions.0."): + return key.replace("encoder.mid_block.attentions.0.", "encoder.middle.1.", 1) + if key.startswith("encoder.mid_block.resnets.1."): + tail = key[len("encoder.mid_block.resnets.1.") :] + return "encoder.middle.2." + _map_resnet_tail(tail) + if key.startswith("encoder.norm_out."): + return key.replace("encoder.norm_out.", "encoder.head.0.", 1) + if key.startswith("encoder.conv_out."): + return key.replace("encoder.conv_out.", "encoder.head.2.", 1) + + if key.startswith("encoder.down_blocks."): + parts = key.split(".") + if len(parts) >= 6 and parts[3] == "resnets": + tail = ".".join(parts[5:]) + return f"encoder.downsamples.{parts[2]}.downsamples.{parts[4]}." + _map_resnet_tail(tail) + if len(parts) >= 7 and parts[3] == "downsampler" and parts[4] == "resample": + return f"encoder.downsamples.{parts[2]}.downsamples.2.resample.{parts[5]}." + ".".join(parts[6:]) + if len(parts) >= 6 and parts[3] == "downsampler" and parts[4] == "time_conv": + return f"encoder.downsamples.{parts[2]}.downsamples.2.time_conv." + ".".join(parts[5:]) + + if key.startswith("decoder.conv_in."): + return key.replace("decoder.conv_in.", "decoder.conv1.", 1) + if key.startswith("decoder.mid_block.resnets.0."): + tail = key[len("decoder.mid_block.resnets.0.") :] + return "decoder.middle.0." 
+ _map_resnet_tail(tail) + if key.startswith("decoder.mid_block.attentions.0."): + return key.replace("decoder.mid_block.attentions.0.", "decoder.middle.1.", 1) + if key.startswith("decoder.mid_block.resnets.1."): + tail = key[len("decoder.mid_block.resnets.1.") :] + return "decoder.middle.2." + _map_resnet_tail(tail) + if key.startswith("decoder.norm_out."): + return key.replace("decoder.norm_out.", "decoder.head.0.", 1) + if key.startswith("decoder.conv_out."): + return key.replace("decoder.conv_out.", "decoder.head.2.", 1) + + if key.startswith("decoder.up_blocks."): + parts = key.split(".") + if len(parts) >= 6 and parts[3] == "resnets": + tail = ".".join(parts[5:]) + return f"decoder.upsamples.{parts[2]}.upsamples.{parts[4]}." + _map_resnet_tail(tail) + if len(parts) >= 7 and parts[3] == "upsampler" and parts[4] == "resample": + return f"decoder.upsamples.{parts[2]}.upsamples.3.resample.{parts[5]}." + ".".join(parts[6:]) + if len(parts) >= 6 and parts[3] == "upsampler" and parts[4] == "time_conv": + return f"decoder.upsamples.{parts[2]}.upsamples.3.time_conv." 
+ ".".join(parts[5:]) + + return key + + +def _normalize_vae_state_dict(raw_state): + state = _extract_checkpoint_state_dict(raw_state) + norm = {} + for key, value in state.items(): + normalized_key = _map_lightvae_key_to_wanvae(key) + if normalized_key is None: + continue + norm[normalized_key] = value + return norm + + +def infer_lightvae_pruning_rate_from_ckpt(vae_pth, full_decoder_conv1_out=1024): + if vae_pth is None or not os.path.exists(vae_pth): + return None + try: + raw_state = load_weights(vae_pth) + state = _extract_checkpoint_state_dict(raw_state) + except Exception as exc: + logging.warning(f"Failed to load checkpoint for pruning-rate inference: {exc}") + return None + + weight = None + if isinstance(state, dict): + if "decoder.conv_in.weight" in state: + weight = state["decoder.conv_in.weight"] + elif "decoder.conv1.weight" in state: + weight = state["decoder.conv1.weight"] + + if weight is None: + try: + norm_state = _normalize_vae_state_dict(state) + weight = norm_state.get("decoder.conv1.weight", None) + except Exception: + weight = None + + if weight is None or not hasattr(weight, "shape") or len(weight.shape) < 1: + return None + + student_out = int(weight.shape[0]) + if full_decoder_conv1_out <= 0: + return None + pruning_rate = 1.0 - (float(student_out) / float(full_decoder_conv1_out)) + pruning_rate = max(0.0, min(0.99, pruning_rate)) + return round(pruning_rate, 6) + + class CausalConv3d(nn.Conv3d): """ Causal 3d convolution. 
@@ -722,6 +857,7 @@ def __init__( attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, + pruning_rate=0.0, ): super().__init__() self.dim = dim @@ -732,6 +868,9 @@ def __init__( self.temperal_downsample = temperal_downsample self.temperal_upsample = temperal_downsample[::-1] + dim = max(1, int(round(dim * (1.0 - pruning_rate)))) + dec_dim = max(1, int(round(dec_dim * (1.0 - pruning_rate)))) + # modules self.encoder = Encoder3d( dim, @@ -849,7 +988,18 @@ def decode_video(self, x, scale=[0, 1]): return y.transpose(1, 2).to(x) -def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, dtype=torch.float32, load_from_rank0=False, **kwargs): +def _video_vae( + pretrained_path=None, + z_dim=16, + dim=160, + device="cpu", + cpu_offload=False, + dtype=torch.float32, + load_from_rank0=False, + normalize_state_dict=False, + strict=True, + **kwargs, +): # params cfg = dict( dim=dim, @@ -868,11 +1018,16 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offloa # load checkpoint logging.info(f"loading {pretrained_path}") - weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0) - for k in weights_dict.keys(): - if weights_dict[k].dtype != dtype: - weights_dict[k] = weights_dict[k].to(dtype) - model.load_state_dict(weights_dict, assign=True) + raw_state = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0) + weights_dict = _normalize_vae_state_dict(raw_state) if normalize_state_dict else raw_state + for key in list(weights_dict.keys()): + if hasattr(weights_dict[key], "dtype") and weights_dict[key].dtype != dtype: + weights_dict[key] = weights_dict[key].to(dtype) + if strict: + model.load_state_dict(weights_dict, assign=True) + else: + missing, unexpected = model.load_state_dict(weights_dict, strict=False, assign=True) + logging.info(f"VAE checkpoint loaded with strict=False (missing={len(missing)}, 
unexpected={len(unexpected)})") # Convert Conv3d weights to channels_last_3d for cuDNN optimization if GET_USE_CHANNELS_LAST_3D(): @@ -895,12 +1050,17 @@ def __init__( cpu_offload=False, offload_cache=False, load_from_rank0=False, + vae_type="wan2.2", + lightvae_pruning_rate=None, + lightvae_encoder_vae_pth=None, **kwargs, ): self.dtype = dtype self.device = device self.cpu_offload = cpu_offload self.offload_cache = offload_cache + self.vae_type = vae_type + self.encoder_model = None self.mean = torch.tensor( [ @@ -954,7 +1114,7 @@ def __init__( -0.0667, ], dtype=dtype, - device=AI_DEVICE, + device=device, ) self.std = torch.tensor( [ @@ -1008,22 +1168,89 @@ def __init__( 0.7744, ], dtype=dtype, - device=AI_DEVICE, + device=device, ) self.inv_std = 1.0 / self.std self.scale = [self.mean, self.inv_std] - # init model - self.model = ( - _video_vae( - pretrained_path=vae_path, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0 + if self.vae_type == "wan2.2": + self.model = ( + _video_vae( + pretrained_path=vae_path, + z_dim=z_dim, + dim=c_dim, + dim_mult=dim_mult, + temperal_downsample=temperal_downsample, + cpu_offload=cpu_offload, + dtype=dtype, + load_from_rank0=load_from_rank0, + normalize_state_dict=False, + strict=True, + pruning_rate=0.0, + ) + .eval() + .requires_grad_(False) + .to(device) + .to(dtype) ) - .eval() - .requires_grad_(False) - .to(device) - .to(dtype) - ) + elif self.vae_type == "mg_lightvae": + resolved_pruning_rate = lightvae_pruning_rate + if resolved_pruning_rate is None: + resolved_pruning_rate = infer_lightvae_pruning_rate_from_ckpt(vae_path) + if resolved_pruning_rate is None: + resolved_pruning_rate = 0.75 + logging.warning("Unable to infer LightVAE pruning rate from checkpoint; fallback to 0.75.") + + teacher_vae_path = lightvae_encoder_vae_pth or vae_path + logging.info( + f"Loading mg_lightvae decoder from {vae_path} 
(pruning_rate={resolved_pruning_rate}), " + f"while keeping teacher encoder from {teacher_vae_path}." + ) + self.encoder_model = ( + _video_vae( + pretrained_path=teacher_vae_path, + z_dim=z_dim, + dim=c_dim, + dim_mult=dim_mult, + temperal_downsample=temperal_downsample, + cpu_offload=cpu_offload, + dtype=dtype, + load_from_rank0=load_from_rank0, + normalize_state_dict=False, + strict=True, + pruning_rate=0.0, + ) + .eval() + .requires_grad_(False) + .to(device) + .to(dtype) + ) + self.model = ( + _video_vae( + pretrained_path=vae_path, + z_dim=z_dim, + dim=c_dim, + dim_mult=dim_mult, + temperal_downsample=temperal_downsample, + cpu_offload=cpu_offload, + dtype=dtype, + load_from_rank0=load_from_rank0, + normalize_state_dict=True, + strict=False, + pruning_rate=resolved_pruning_rate, + ) + .eval() + .requires_grad_(False) + .to(device) + .to(dtype) + ) + else: + raise ValueError(f"Unsupported vae_type: {self.vae_type}") def to_cpu(self): + if self.encoder_model is not None: + self.encoder_model.encoder = self.encoder_model.encoder.to("cpu") + self.encoder_model.decoder = self.encoder_model.decoder.to("cpu") + self.encoder_model = self.encoder_model.to("cpu") self.model.encoder = self.model.encoder.to("cpu") self.model.decoder = self.model.decoder.to("cpu") self.model = self.model.to("cpu") @@ -1032,6 +1259,10 @@ def to_cpu(self): self.scale = [self.mean, self.inv_std] def to_cuda(self): + if self.encoder_model is not None: + self.encoder_model.encoder = self.encoder_model.encoder.to(AI_DEVICE) + self.encoder_model.decoder = self.encoder_model.decoder.to(AI_DEVICE) + self.encoder_model = self.encoder_model.to(AI_DEVICE) self.model.encoder = self.model.encoder.to(AI_DEVICE) self.model.decoder = self.model.decoder.to(AI_DEVICE) self.model = self.model.to(AI_DEVICE) @@ -1042,7 +1273,8 @@ def to_cuda(self): def encode(self, video): if self.cpu_offload: self.to_cuda() - out = self.model.encode(video, self.scale).float().squeeze(0) + encode_model = self.encoder_model if 
self.vae_type == "mg_lightvae" and self.encoder_model is not None else self.model + out = encode_model.encode(video, self.scale).float().squeeze(0) if self.cpu_offload: self.to_cpu() return out @@ -1057,7 +1289,8 @@ def decode(self, zs): return images def encode_video(self, vid): - return self.model.encode_video(vid) + encode_model = self.encoder_model if self.vae_type == "mg_lightvae" and self.encoder_model is not None else self.model + return encode_model.encode_video(vid) def decode_video(self, vid_enc): return self.model.decode_video(vid_enc) From bb06f9e9afdb9a8aa2b6e3ef85fe7bcf73f4ddec Mon Sep 17 00:00:00 2001 From: Yang Date: Sun, 12 Apr 2026 13:04:03 +0800 Subject: [PATCH 24/25] remove the latent type --- lightx2v/models/runners/wan/wan_matrix_game3_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index c8e0257c0..83cbfdf91 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -987,7 +987,11 @@ def step_pre(self, step_index): def step_post(self): timestep = self._solver.timesteps[self.step_index].to(device=self.latents.device) prev_sample = self._solver.step( - self.noise_pred.to(dtype=self.latents.dtype), + # Keep the model output in its original precision. The official MG3 + # pipeline feeds float32 noise predictions into UniPC even when the + # latent state is bf16, and downcasting here introduces a tiny + # per-step drift that the 50-step base model visibly amplifies. 
+ self.noise_pred, timestep, self.latents, return_dict=False, From 9796026e62c6036b0a4f6240ba2954336eddd5f8 Mon Sep 17 00:00:00 2001 From: Yang Date: Sun, 12 Apr 2026 13:24:50 +0800 Subject: [PATCH 25/25] Alter the DiT --- .../models/networks/wan/matrix_game3_model.py | 141 ++++++++++++++++++ .../runners/wan/wan_matrix_game3_runner.py | 14 +- 2 files changed, 154 insertions(+), 1 deletion(-) diff --git a/lightx2v/models/networks/wan/matrix_game3_model.py b/lightx2v/models/networks/wan/matrix_game3_model.py index 9cbdfd501..6ef67a0ff 100644 --- a/lightx2v/models/networks/wan/matrix_game3_model.py +++ b/lightx2v/models/networks/wan/matrix_game3_model.py @@ -1,6 +1,10 @@ import json import os +import sys +from functools import lru_cache +from pathlib import Path +import torch from safetensors import safe_open from lightx2v.models.networks.wan.infer.matrix_game3.post_infer import WanMtxg3PostInfer @@ -13,6 +17,143 @@ from lightx2v.utils.utils import * +@lru_cache(maxsize=1) +def _import_official_matrix_game3_wan_model(): + """Load the official Matrix-Game-3 WanModel implementation on demand.""" + official_root = Path(__file__).resolve().parents[4] / "Matrix-Game-3" / "Matrix-Game-3" + if not official_root.is_dir(): + raise FileNotFoundError(f"Official Matrix-Game-3 source directory not found: {official_root}") + official_root_str = str(official_root) + if official_root_str not in sys.path: + sys.path.insert(0, official_root_str) + from wan.modules.model import WanModel as OfficialWanModel + + return OfficialWanModel + + +class WanMtxg3OfficialBaseModel: + """Base-model wrapper that delegates denoising to the official MG3 forward. + + The distilled MG3 path is numerically tolerant enough to run through the + custom LightX2V weight/infer stack, but the base checkpoint is much more + sensitive under 50-step CFG. Reusing the official DiT forward here removes + the remaining block/head precision mismatches from the adaptation. 
+ """ + + def __init__(self, model_path, config, device, model_type="wan2.2", lora_path=None, lora_strength=1.0): + del model_type, lora_path, lora_strength + self.model_path = model_path + self.config = config + self.device = device + self.scheduler = None + self.transformer_infer = None + self._official_model = self._load_official_model() + + def _load_official_model(self): + sub_model_folder = self.config.get("sub_model_folder", "base_model") + model_dir = os.path.join(self.config["model_path"], sub_model_folder) + if not os.path.isdir(model_dir): + raise FileNotFoundError(f"Matrix-Game-3 base checkpoint directory not found: {model_dir}") + OfficialWanModel = _import_official_matrix_game3_wan_model() + model = OfficialWanModel.from_pretrained(model_dir, torch_dtype=torch.bfloat16) + model = model.eval().requires_grad_(False) + model.to(device=self.device, dtype=torch.bfloat16) + return model + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + def _build_official_timestep(self, latents): + t = self.scheduler.timestep_input + if t is None: + raise RuntimeError("Matrix-Game-3 base forward requested before scheduler.timestep_input was prepared") + if t.numel() != 1: + return t.reshape(1, -1).to(device=latents.device, dtype=latents.dtype) + + timestep_scalar = t.reshape(1).to(device=latents.device, dtype=latents.dtype) + timestep = latents.new_full( + (latents.shape[1], latents.shape[2] * latents.shape[3] // 4), + timestep_scalar.squeeze(0), + ) + mask = getattr(self.scheduler, "mask", None) + if mask is not None: + fixed_latent_frames = int((mask[0].flatten(1).sum(dim=1) == 0).sum().item()) + if fixed_latent_frames > 0: + timestep[:fixed_latent_frames].zero_() + return timestep.flatten().unsqueeze(0) + + def _build_forward_kwargs(self, inputs, infer_condition): + if self.scheduler is None: + raise RuntimeError("Matrix-Game-3 base model used before scheduler was attached") + + latents = self.scheduler.latents.unsqueeze(0) + timestep = 
self._build_official_timestep(self.scheduler.latents) + image_encoder_output = inputs.get("image_encoder_output", {}) + dit_cond_dict = image_encoder_output.get("dit_cond_dict") or {} + + if infer_condition: + context = inputs["text_encoder_output"]["context"] + plucker_emb = dit_cond_dict.get("plucker_emb_with_memory", dit_cond_dict.get("c2ws_plucker_emb")) + mouse_cond = dit_cond_dict.get("mouse_cond") + keyboard_cond = dit_cond_dict.get("keyboard_cond") + x_memory = dit_cond_dict.get("x_memory") + timestep_memory = dit_cond_dict.get("timestep_memory") + mouse_cond_memory = dit_cond_dict.get("mouse_cond_memory") + keyboard_cond_memory = dit_cond_dict.get("keyboard_cond_memory") + memory_latent_idx = dit_cond_dict.get("memory_latent_idx") + else: + context = inputs["text_encoder_output"]["context_null"] + mouse_source = dit_cond_dict.get("mouse_cond") + keyboard_source = dit_cond_dict.get("keyboard_cond") + plucker_emb = dit_cond_dict.get("c2ws_plucker_emb") + mouse_cond = torch.ones_like(mouse_source) if mouse_source is not None else None + keyboard_cond = -torch.ones_like(keyboard_source) if keyboard_source is not None else None + x_memory = None + timestep_memory = None + mouse_cond_memory = None + keyboard_cond_memory = None + memory_latent_idx = None + + total_latent_frames = latents.shape[2] + (int(x_memory.shape[2]) if x_memory is not None else 0) + patch_h, patch_w = tuple(self.config.get("patch_size", (1, 2, 2)))[1:] + seq_len = total_latent_frames * latents.shape[3] * latents.shape[4] // (patch_h * patch_w) + + return { + "x": latents, + "t": timestep, + "context": context, + "seq_len": seq_len, + "mouse_cond": mouse_cond, + "keyboard_cond": keyboard_cond, + "x_memory": x_memory, + "timestep_memory": timestep_memory, + "mouse_cond_memory": mouse_cond_memory, + "keyboard_cond_memory": keyboard_cond_memory, + "plucker_emb": plucker_emb, + "memory_latent_idx": memory_latent_idx, + "predict_latent_idx": dit_cond_dict.get("predict_latent_idx"), + } + + 
@torch.no_grad() + def _infer_cond_uncond(self, inputs, infer_condition=True): + self.scheduler.infer_condition = infer_condition + noise_pred = self._official_model(**self._build_forward_kwargs(inputs, infer_condition)) + if isinstance(noise_pred, list): + noise_pred = torch.stack(noise_pred) + if noise_pred.dim() == 5 and noise_pred.shape[0] == 1: + noise_pred = noise_pred.squeeze(0) + return noise_pred.float() + + @torch.no_grad() + def infer(self, inputs): + if self.config.get("enable_cfg", False): + noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True) + noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False) + self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond) + else: + self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True) + + class WanMtxg3Model(WanModel): """Network model for Matrix-Game-3.0. diff --git a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py index 83cbfdf91..0fbc0eb20 100644 --- a/lightx2v/models/runners/wan/wan_matrix_game3_runner.py +++ b/lightx2v/models/runners/wan/wan_matrix_game3_runner.py @@ -1108,7 +1108,10 @@ def run_text_encoder(self, input_info): return super().run_text_encoder(input_info) def load_transformer(self): - from lightx2v.models.networks.wan.matrix_game3_model import WanMtxg3Model + from lightx2v.models.networks.wan.matrix_game3_model import ( + WanMtxg3Model, + WanMtxg3OfficialBaseModel, + ) # The backbone is still a Wan2.2 DiT, but Matrix-Game-3 swaps in a dedicated # network wrapper that understands keyboard / mouse / camera conditions. 
@@ -1119,6 +1122,15 @@ def load_transformer(self): } lora_configs = self.config.get("lora_configs") if not lora_configs: + if self.config.get("use_base_model", False): + try: + logger.info("[matrix-game-3] base-model path will use the official WanModel forward for denoising.") + return WanMtxg3OfficialBaseModel(**model_kwargs) + except Exception as exc: + logger.warning( + "[matrix-game-3] failed to initialize official base-model forward ({}); falling back to the custom LightX2V MG3 model.", + exc, + ) return WanMtxg3Model(**model_kwargs) return build_wan_model_with_lora(WanMtxg3Model, self.config, model_kwargs, lora_configs, model_type="wan2.2")