
[feat] Add diffusion forcing / ode init for MatrixGame2.0#1179

Open
H1yori233 wants to merge 59 commits into hao-ai-lab:main from H1yori233:feat/kaiqin/add-mg-df

Conversation

@H1yori233
Collaborator

Purpose

Add diffusion forcing and ode init for MatrixGame 2.0

Changes

  • Introduce fastvideo/models/schedulers/scheduling_diffusion_forcing.py, a scheduler for diffusion forcing.
  • Add fastvideo/training/matrixgame_ar_diffusion_pipeline.py, a diffusion forcing training pipeline.
  • Add a schema in fastvideo/dataset/dataloader/schema.py and a pipeline in fastvideo/training/matrixgame_ode_causal_pipeline.py for ODE init.
  • fastvideo/training/training_pipeline.py can now upload reference videos to WandB for easy comparison.

Others

This PR targets the previous training pipeline; changes for the new trainer will be delivered in a separate PR.

Checklist

  • I ran pre-commit run --all-files and fixed all issues

@H1yori233
Collaborator Author

/gemini review

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a new ODE trajectory preprocessing pipeline and a corresponding training pipeline for MatrixGame models. It adds a DiffusionForcingScheduler and integrates it into the training process, along with updates to utility functions to support this new scheduler. The changes also include a new shell script for finetuning, a PyArrow schema for ODE trajectory data, and enhancements to the validation logging to include reference videos. Review comments suggest refactoring duplicated code, optimizing timestep calculation, and correcting misleading comments in the code.

Comment on lines +244 to +283
```python
# Optional PIL Image
if pil_image is not None:
    record.update({
        "pil_image_bytes": pil_image.tobytes(),
        "pil_image_shape": list(pil_image.shape),
        "pil_image_dtype": str(pil_image.dtype),
    })
else:
    record.update({
        "pil_image_bytes": b"",
        "pil_image_shape": [],
        "pil_image_dtype": "",
    })

# Actions
if keyboard_cond is not None:
    record.update({
        "keyboard_cond_bytes": keyboard_cond.tobytes(),
        "keyboard_cond_shape": list(keyboard_cond.shape),
        "keyboard_cond_dtype": str(keyboard_cond.dtype),
    })
else:
    record.update({
        "keyboard_cond_bytes": b"",
        "keyboard_cond_shape": [],
        "keyboard_cond_dtype": "",
    })

if mouse_cond is not None:
    record.update({
        "mouse_cond_bytes": mouse_cond.tobytes(),
        "mouse_cond_shape": list(mouse_cond.shape),
        "mouse_cond_dtype": str(mouse_cond.dtype),
    })
else:
    record.update({
        "mouse_cond_bytes": b"",
        "mouse_cond_shape": [],
        "mouse_cond_dtype": "",
    })
```
Contributor


medium

The logic for handling optional numpy arrays (pil_image, keyboard_cond, mouse_cond) is duplicated. This can be refactored into a loop to improve maintainability and reduce boilerplate code.

Suggested change

```python
# Optional PIL Image and Actions
for prefix, array in [("pil_image", pil_image),
                      ("keyboard_cond", keyboard_cond),
                      ("mouse_cond", mouse_cond)]:
    if array is not None:
        record.update({
            f"{prefix}_bytes": array.tobytes(),
            f"{prefix}_shape": list(array.shape),
            f"{prefix}_dtype": str(array.dtype),
        })
    else:
        record.update({
            f"{prefix}_bytes": b"",
            f"{prefix}_shape": [],
            f"{prefix}_dtype": "",
        })
```

Comment on lines +97 to +100
```python
timestep_id = torch.argmin(
    (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(),
    dim=1,
)
```
Contributor


medium

Using torch.argmin to find timestep_id can be inefficient, especially when self.timesteps is large, as it has O(N) time complexity. Since self.timesteps is a sorted tensor (monotonically decreasing), you can use torch.searchsorted for a more performant O(log N) lookup.

To use torch.searchsorted, you'll need a monotonically increasing tensor. You could either flip self.timesteps before searching, or change set_timesteps to build the schedule in ascending order and flip it for the parts of the code that expect descending order.
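As a hedged sketch of that suggestion (with made-up schedule values, not the scheduler's actual ones), the flip-and-search approach could look like:

```python
import torch

# Replace the O(N) argmin lookup with an O(log N) torch.searchsorted on the
# flipped (ascending) schedule, then map the index back to descending order.
timesteps = torch.tensor([1000., 750., 500., 250., 0.])  # descending schedule
timestep = torch.tensor([760., 240.])                    # query timesteps

asc = timesteps.flip(0)  # searchsorted requires ascending order
right = torch.searchsorted(asc, timestep).clamp(max=asc.numel() - 1)
left = (right - 1).clamp(min=0)
# pick whichever neighbor is nearer, matching argmin's nearest-value behavior
nearest_asc = torch.where(
    (asc[right] - timestep).abs() < (timestep - asc[left]).abs(), right, left)
timestep_id = timesteps.numel() - 1 - nearest_asc  # back to descending indices
print(timestep_id)  # tensor([1, 3]) -> timesteps 750. and 250.
```

The two-neighbor comparison keeps the result identical to the argmin version, including for queries that fall between schedule entries.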


```python
features["clip_feature"] = clip_features
"""Get VAE features from the first frame of each video"""
# Get CLIP features from the first frame of each video.
```
Contributor


medium

This comment is misleading. The code block that follows is responsible for calculating VAE features, not CLIP features. CLIP features are calculated in the block just before this comment.

Suggested change

```diff
- # Get CLIP features from the first frame of each video.
+ # Get VAE features from the first frame of each video.
```

Collaborator


Yes, this should be VAE; you can fix it.

```python
self.get_module("vae").to(get_local_torch_device())

features = {}
"""Get CLIP features from the first frame of each video."""
```
Contributor


medium

This appears to be a docstring placed in the middle of a function body. According to Python style guides (like PEP 257), docstrings should only appear as the first statement in a module, function, class, or method definition. For comments within a function, please use the # syntax.

Suggested change

```diff
- """Get CLIP features from the first frame of each video."""
+ # Get CLIP features from the first frame of each video.
```


```python
features["clip_feature"] = clip_features
features["pil_image"] = first_frame
# Get CLIP features from the first frame of each video.
```
Contributor


medium

This comment is misleading. The code block that follows is responsible for preparing video_conditions to be encoded by the VAE, not for getting CLIP features. CLIP features are calculated in the block before this.

Suggested change

```diff
- # Get CLIP features from the first frame of each video.
+ # Get VAE features from the first frame of each video.
```

```python
logger.info("relevant_traj_latents: %s", relevant_traj_latents.shape)

indexes = self._get_timestep(  # [B, num_frames]
    0, len(self.dmd_denoising_steps), B, num_frames, 3,
    uniform_timestep=False)
```
Contributor


medium

The num_frame_per_block argument is hardcoded as 3. This should be a configurable parameter, similar to how it's handled in MatrixGameARDiffusionPipeline, to improve flexibility and maintain consistency across pipelines. Consider making it a class attribute initialized from the training arguments.
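A minimal sketch of that suggestion (the class and attribute names here are assumed for illustration, not taken from the PR):

```python
from types import SimpleNamespace

class PipelineSketch:
    """Sketch: read num_frame_per_block from training args instead of
    hardcoding 3 at the call site."""
    def __init__(self, training_args):
        # fall back to the current hardcoded value when the arg is absent
        self.num_frame_per_block = getattr(training_args,
                                           "num_frame_per_block", 3)

print(PipelineSketch(SimpleNamespace(num_frame_per_block=4)).num_frame_per_block)  # 4
print(PipelineSketch(SimpleNamespace()).num_frame_per_block)                       # 3
```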

Collaborator


How about just using global_step == self.init_steps to check whether the reference video has been logged, instead of setting self.validation_ref_videos_logged = False at init?

@mergify Bot added the labels type: feat (New feature or capability), scope: training (Training pipeline, methods, configs), scope: data (Data preprocessing, datasets), and scope: model (Model architecture: DiTs, encoders, VAEs) on Mar 30, 2026
@mergify
Contributor

mergify Bot commented Mar 30, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for:

  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit

This rule is failing:

  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • check-success=fastcheck-passed
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model)\]

@mergify
Contributor

mergify Bot commented Mar 30, 2026

Buildkite CI tests failed

Hi @H1yori233, some Buildkite CI tests have failed. Check the build for details:
View Buildkite build →

Common causes:

  • Test failures: Check the failing step's output for assertion errors or tracebacks
  • Import errors: Make sure new dependencies are added to pyproject.toml
  • GPU memory: Some tests require specific GPU types (L40S, H100 NVL)
  • Kernel build: If you changed fastvideo-kernel/, the build may have failed

If the failure is unrelated to your changes, leave a comment explaining why.

Comment thread fastvideo/models/utils.py
Collaborator


Why is sigma_from_timestep needed? It seems the timesteps are all precomputed, and the sigmas are stored in scheduler.sigmas based on self.timesteps.
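A sketch of the lookup this comment describes, using stand-in values for the precomputed tables (the real scheduler attributes may differ):

```python
import torch

# Since timesteps and sigmas are precomputed together, a sigma can be looked
# up from the stored table by timestep index rather than recomputed per call.
timesteps = torch.tensor([1000., 750., 500., 250., 0.])  # stand-in for scheduler.timesteps
sigmas = torch.tensor([1.00, 0.75, 0.50, 0.25, 0.00])    # stand-in for scheduler.sigmas

def sigma_for(timestep: torch.Tensor) -> torch.Tensor:
    # nearest-timestep index into the precomputed sigma table
    idx = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(),
                       dim=1)
    return sigmas[idx]

print(sigma_for(torch.tensor([760., 240.])))  # tensor([0.7500, 0.2500])
```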

@alexzms alexzms self-requested a review March 31, 2026 22:45
alexzms and others added 6 commits March 31, 2026 23:04

  • Revert the DiffusionForcingScheduler dependency in dfsft.py. The original design using the student's scheduler is more generic and works with both flow matching and DDPM models.
  • Add per-frame gaussian loss weighting matching Causal-Forcing's bsmntw scheme. Mid-noise timesteps get higher weight; the extremes (near-clean and pure-noise) get lower weight.
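The weighting that commit describes could be sketched as follows; the exact constants of Causal-Forcing's bsmntw scheme are not shown in this thread, so the sigma value here is an illustrative assumption:

```python
import torch

def gaussian_timestep_weight(t: torch.Tensor,
                             num_train_timesteps: int = 1000,
                             sigma: float = 0.25) -> torch.Tensor:
    # normalize to [0, 1]; weight peaks at mid-noise (u = 0.5) and decays
    # toward the near-clean (u = 0) and pure-noise (u = 1) extremes
    u = t.float() / num_train_timesteps
    return torch.exp(-((u - 0.5) ** 2) / (2 * sigma ** 2))

w = gaussian_timestep_weight(torch.tensor([0., 500., 1000.]))
print(w)  # mid-noise weight is 1.0; both extremes are ~0.135
```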
@mergify
Contributor

mergify Bot commented Apr 24, 2026

Pre-commit checks failed

Hi @H1yori233, the pre-commit checks have failed. To fix them locally:

```shell
# Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install

# Run all checks and auto-fix what's possible
pre-commit run --all-files
```

Common fixes:

  • yapf: yapf -i <file> (formatting)
  • ruff: ruff check --fix <file> (linting)
  • codespell: codespell --write-changes <file> (spelling)

After fixing, commit and push the changes. The checks will re-run automatically.

For future commits, pre-commit will run automatically on changed files before each commit.


