Complete The LightX2V's Support To Matrix-Game-3. #989
Michael20070814 wants to merge 26 commits into ModelTC:main from
Conversation
Code Review
This pull request adds support for Matrix-Game-3.0 (MG3): new configurations, network models, weight modules, and a dedicated runner for segment-based video generation with camera and action conditioning. It also introduces memory-aware self-attention and an action module for keyboard and mouse inputs. Review feedback suggests improving hardware compatibility by checking the configured AI_DEVICE instead of torch.cuda.is_available(), and making the token-count calculation for memory timesteps more robust.
```python
def get_latent_idx(frame_idx: int) -> int:
    return (frame_idx - 1) // 4 + 1 if frame_idx > 0 else 0
```

```python
selected_index_base = [current_end_frame_idx - offset for offset in range(1, 34, 8)]
selected_index = modules["cam_utils"].select_memory_idx_fov(
```
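The frame-to-latent mapping above can be checked in isolation. A minimal sketch, assuming the 4-frames-per-latent layout implied by the `// 4` in the function itself:

```python
def get_latent_idx(frame_idx: int) -> int:
    # Frame 0 maps to latent 0; after that, every 4 frames share one latent.
    return (frame_idx - 1) // 4 + 1 if frame_idx > 0 else 0

# Frames 1-4 collapse into latent 1, frames 5-8 into latent 2, and so on.
print([get_latent_idx(f) for f in range(9)])  # → [0, 1, 1, 1, 1, 2, 2, 2, 2]
```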
The use of torch.cuda.is_available() is restrictive for a library that supports multiple hardware backends (like NPU or XPU). It is better to check against the configured AI_DEVICE to determine if hardware acceleration should be used for this operation.
Suggested change:

```python
def get_latent_idx(frame_idx: int) -> int:
    return (frame_idx - 1) // 4 + 1 if frame_idx > 0 else 0

selected_index_base = [current_end_frame_idx - offset for offset in range(1, 34, 8)]
selected_index = modules["cam_utils"].select_memory_idx_fov(
    self._mg3_extrinsics_all,
    current_start_frame_idx,
    selected_index_base,
    use_gpu=(AI_DEVICE == "cuda"),
)
```
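The pattern the review suggests can be sketched without torch: gate accelerator-only code paths on the configured device string rather than on `torch.cuda.is_available()`. `AI_DEVICE` and `should_use_gpu_path` below are stand-ins for illustration, not names from the PR:

```python
# AI_DEVICE is a stand-in for the library's configured device constant;
# in practice it could be "cuda", "npu", "xpu", or "cpu".
AI_DEVICE = "cuda"

def should_use_gpu_path(ai_device: str) -> bool:
    # Hypothetical helper: only the CUDA backend takes the GPU path here,
    # so the check is against the configured device, not CUDA availability.
    return ai_device == "cuda"

print(should_use_gpu_path(AI_DEVICE))  # → True
print(should_use_gpu_path("npu"))      # → False
```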
```python
x_memory = src[:, valid_latent_idx].unsqueeze(0).to(device=AI_DEVICE, dtype=GET_DTYPE()) if valid_latent_idx else None
if x_memory is None:
    timestep_memory = None
    keyboard_cond_memory = None
```
The calculation of the number of tokens for timestep_memory using // 4 assumes that the latent spatial dimensions (H and W) are always even. While this is typically true for Wan2.2, it is more robust to calculate the token count based on the actual downsampled grid size logic (stride 2) used in the transformer's patch embedding.
Suggested change:

```python
keyboard_cond_memory = None
timestep_memory = x_memory.new_zeros((1, x_memory.shape[2] * (x_memory.shape[3] // 2) * (x_memory.shape[4] // 2)))
```
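The difference only shows up when a spatial dimension is odd, which is exactly the robustness point the review makes. A minimal arithmetic check (function names are illustrative, not from the PR):

```python
def tokens_per_axis(h: int, w: int) -> int:
    # Floor each axis separately, mirroring a stride-2 patch-embedding grid.
    return (h // 2) * (w // 2)

def tokens_product(h: int, w: int) -> int:
    # Single // 4 on the product; only correct when both dims are even.
    return (h * w) // 4

# Even dims agree: 15 * 26 == 1560 // 4 == 390.
print(tokens_per_axis(30, 52), tokens_product(30, 52))  # → 390 390
# An odd dim diverges: (5 // 2) * (4 // 2) == 4, but 20 // 4 == 5.
print(tokens_per_axis(5, 4), tokens_product(5, 4))      # → 4 5
```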