Complete The LightX2V's Support To Motus with i2v task. #992
zowiezhang wants to merge 3 commits into ModelTC:main from
Conversation
Code Review
This pull request introduces the Motus model for image-to-video tasks, integrating it into the LightX2V framework with new network components, inference modules, and a dedicated runner. The implementation includes an action expert for robot state conditioning and an understanding expert for VLM-based features. Feedback focuses on improving robustness and performance: specifically, handling cases where register tokens are absent to avoid attribute errors and empty slices, optimizing VAE decoding by processing batches instead of individual samples, and fixing ineffective lru_cache usage within local function scopes. Additionally, removing unused legacy code and optimizing tensor initialization on the target device are recommended to enhance maintainability and efficiency.
```python
scheduler.step_pre(step_index)
video_tokens = model.video_module.prepare_input(scheduler.video_latents.to(model.dtype))
state_tokens = pre_infer_out.state.unsqueeze(1).to(model.dtype)
registers = model.action_expert.registers.expand(state_tokens.shape[0], -1, -1)
```
The registers attribute can be None if num_registers is set to 0 in the configuration (e.g., during pre-training). Calling .expand() on None will raise an AttributeError. This should be handled safely.
Suggested change:

```python
registers = model.action_expert.registers
if registers is not None:
    registers = registers.expand(state_tokens.shape[0], -1, -1)
```
```python
video_velocity = model.video_module.apply_output_head(video_tokens, video_head_time_emb)
action_pred_full = model.action_expert.decoder(action_tokens, action_head_time_emb)
action_velocity = action_pred_full[:, 1 : -model.action_expert.config.num_registers, :]
```
When num_registers is 0, the slice 1 : -model.action_expert.config.num_registers becomes 1 : 0, which returns an empty tensor. This will likely cause a failure in the scheduler's step method. The slicing logic should account for the case where no registers are used.
Suggested change:

```python
num_regs = model.action_expert.config.num_registers
action_velocity = action_pred_full[:, 1 : -num_regs, :] if num_regs > 0 else action_pred_full[:, 1:, :]
```
```python
def decode_video(self, video_latents: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return torch.stack([self.vae.decode([video_latents[i]])[0] for i in range(video_latents.shape[0])], dim=0)
```
Decoding video latents sample-by-sample in a loop is inefficient because the underlying VAE model is designed to process batches. Processing the entire batch at once would significantly improve performance.
Suggested change:

```python
with torch.no_grad():
    return self.vae.model.decode(video_latents, self.vae.scale).float().clamp(-1, 1)
```
```python
try:
    if not isinstance(zs, list):
        raise TypeError("zs should be a list")
    with amp.autocast(dtype=self.dtype):
        return [
            self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0)
            for u in zs
        ]
except TypeError as e:
    logging.info(e)
    return None
```
The error handling in the decode method catches TypeError and returns None. This is problematic because the caller (e.g., WanVideoModel.decode_video) expects a list and will crash when attempting to index or iterate over None. It is better to let the exception propagate or return an empty list if appropriate.
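A minimal, pure-Python sketch of this failure mode (names and bodies are illustrative, not the actual LightX2V implementation):

```python
def decode(zs):
    # Mirrors the pattern above: swallow the TypeError and return None.
    try:
        if not isinstance(zs, list):
            raise TypeError("zs should be a list")
        return [z * 2 for z in zs]
    except TypeError as e:
        print(e)
        return None

def decode_video(latents):
    # The caller assumes a list; iterating over None raises
    # "TypeError: 'NoneType' object is not iterable" far from the real cause.
    return [frame for frame in decode(latents)]

assert decode_video([1, 2, 3]) == [2, 4, 6]
try:
    decode_video("not a list")
except TypeError:
    print("caller crashed on None, not on the original bad input")
```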
Suggested change:

```python
def decode(self, zs):
    if not isinstance(zs, list):
        raise TypeError("zs should be a list")
    with amp.autocast(dtype=self.dtype):
        return [
            self.model.decode(u.unsqueeze(0), self.scale).float().clamp(-1, 1).squeeze(0)
            for u in zs
        ]
```

```python
@lru_cache(maxsize=256)
def _make_freq_grid(f: int, h: int, w: int):
```
The _make_freq_grid function is decorated with @lru_cache but is defined inside rope_apply. This means a new function object (and a new empty cache) is created on every call to rope_apply, making the cache ineffective. This function should be moved outside rope_apply to allow the cache to persist across calls.
Suggested change (move this definition out of the local scope):

```python
@lru_cache(maxsize=256)
def _make_freq_grid(f: int, h: int, w: int):
```
```python
def rope_apply_original(x, grid_sizes, freqs):
    n, c = x.size(2), x.size(3) // 2

    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
            seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, seq_len:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()
```
```python
if q_lens is None:
    q = half(q.flatten(0, 1))
    q_lens = torch.tensor([q_len] * batch, dtype=torch.int32).to(device=q.device, non_blocking=True)
```
Creating a tensor on CPU and moving it to the GPU in a hot loop is inefficient. Using torch.full directly on the target device is faster.
Suggested change:

```python
q_lens = torch.full((batch,), q_len, dtype=torch.int32, device=q.device)
```
Add Motus support to LightX2V for the i2v task, where the "i" (image) refers to the first frame.