Complete The LightX2V's Support To Motus with i2v task. #992

Open

zowiezhang wants to merge 3 commits into ModelTC:main from zowiezhang:feature/add-motus-model
Conversation

@zowiezhang

Add the Motus feature to LightX2V with the i2v task, where "i" here represents the first frame.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces the Motus model for image-to-video tasks, integrating it into the LightX2V framework with new network components, inference modules, and a dedicated runner. The implementation includes an action expert for robot state conditioning and an understanding expert for VLM-based features. Feedback focuses on improving robustness and performance: specifically, handling cases where register tokens are absent to avoid attribute errors and empty slices, optimizing VAE decoding by processing batches instead of individual samples, and fixing ineffective lru_cache usage within local function scopes. Additionally, removing unused legacy code and optimizing tensor initialization on the target device are recommended to enhance maintainability and efficiency.

```python
scheduler.step_pre(step_index)
video_tokens = model.video_module.prepare_input(scheduler.video_latents.to(model.dtype))
state_tokens = pre_infer_out.state.unsqueeze(1).to(model.dtype)
registers = model.action_expert.registers.expand(state_tokens.shape[0], -1, -1)
```
Contributor


Severity: high

The registers attribute can be None if num_registers is set to 0 in the configuration (e.g., during pre-training). Calling .expand() on None will raise an AttributeError. This should be handled safely.

Suggested change:

```diff
- registers = model.action_expert.registers.expand(state_tokens.shape[0], -1, -1)
+ registers = model.action_expert.registers
+ if registers is not None:
+     registers = registers.expand(state_tokens.shape[0], -1, -1)
```


```python
video_velocity = model.video_module.apply_output_head(video_tokens, video_head_time_emb)
action_pred_full = model.action_expert.decoder(action_tokens, action_head_time_emb)
action_velocity = action_pred_full[:, 1 : -model.action_expert.config.num_registers, :]
```
Contributor


Severity: high

When num_registers is 0, the slice 1 : -model.action_expert.config.num_registers becomes 1 : 0, which returns an empty tensor. This will likely cause a failure in the scheduler's step method. The slicing logic should account for the case where no registers are used.

Suggested change:

```diff
- action_velocity = action_pred_full[:, 1 : -model.action_expert.config.num_registers, :]
+ num_regs = model.action_expert.config.num_registers
+ action_velocity = action_pred_full[:, 1 : -num_regs, :] if num_regs > 0 else action_pred_full[:, 1:, :]
```
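The pitfall can be shown with plain Python slicing; the token list below is hypothetical, not taken from the PR:

```python
# Hypothetical token layout: one leading state token, three action tokens,
# and no register tokens appended (num_registers == 0).
tokens = ["state", "a1", "a2", "a3"]
num_regs = 0

# The original slice 1:-num_regs collapses to 1:0 and returns nothing.
print(tokens[1:-num_regs])  # -> []

# Guarding on num_regs keeps the action tokens intact.
print(tokens[1:-num_regs] if num_regs > 0 else tokens[1:])  # -> ['a1', 'a2', 'a3']
```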


```python
def decode_video(self, video_latents: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return torch.stack([self.vae.decode([video_latents[i]])[0] for i in range(video_latents.shape[0])], dim=0)
```
Contributor


Severity: medium

Decoding video latents sample-by-sample in a loop is inefficient because the underlying VAE model is designed to process batches. Processing the entire batch at once would significantly improve performance.

Suggested change:

```diff
  with torch.no_grad():
-     return torch.stack([self.vae.decode([video_latents[i]])[0] for i in range(video_latents.shape[0])], dim=0)
+     return self.vae.model.decode(video_latents, self.vae.scale).float().clamp(-1, 1)
```

Comment on lines +1029 to +1041
```python
try:
    if not isinstance(zs, list):
        raise TypeError("zs should be a list")
    with amp.autocast(dtype=self.dtype):
        return [
            self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0)
            for u in zs
        ]
except TypeError as e:
    logging.info(e)
    return None
```
Contributor


Severity: medium

The error handling in the decode method catches TypeError and returns None. This is problematic because the caller (e.g., WanVideoModel.decode_video) expects a list and will crash when attempting to index or iterate over None. It is better to let the exception propagate or return an empty list if appropriate.

```python
def decode(self, zs):
    if not isinstance(zs, list):
        raise TypeError("zs should be a list")
    with amp.autocast(dtype=self.dtype):
        return [
            self.model.decode(u.unsqueeze(0), self.scale).float().clamp(-1, 1).squeeze(0)
            for u in zs
        ]
```

Comment on lines +31 to +32
```python
@lru_cache(maxsize=256)
def _make_freq_grid(f: int, h: int, w: int):
```
Contributor


Severity: medium

The _make_freq_grid function is decorated with @lru_cache but is defined inside rope_apply. This means a new function object (and a new empty cache) is created on every call to rope_apply, making the cache ineffective. This function should be moved outside rope_apply to allow the cache to persist across calls.

Comment on lines +58 to +59
```python
@lru_cache(maxsize=256)
def _make_freq_grid(f: int, h: int, w: int):
```
Contributor


Severity: medium

The _make_freq_grid function is decorated with @lru_cache but is defined inside rope_apply. This renders the cache ineffective as it is recreated on every call. Consider moving this function outside rope_apply or using a persistent cache.

Comment on lines +82 to +109
```python
def rope_apply_original(x, grid_sizes, freqs):
    n, c = x.size(2), x.size(3) // 2

    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ], dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, seq_len:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()
```
Contributor


Severity: medium

The rope_apply_original function is defined but never used in the codebase. Removing dead code improves maintainability and reduces technical debt.


```python
if q_lens is None:
    q = half(q.flatten(0, 1))
    q_lens = torch.tensor([q_len] * batch, dtype=torch.int32).to(device=q.device, non_blocking=True)
```
Contributor


Severity: medium

Creating a tensor on CPU and moving it to the GPU in a hot loop is inefficient. Using torch.full directly on the target device is faster.

Suggested change:

```diff
- q_lens = torch.tensor([q_len] * batch, dtype=torch.int32).to(device=q.device, non_blocking=True)
+ q_lens = torch.full((batch,), q_len, dtype=torch.int32, device=q.device)
```
