[feat] Add diffusion forcing / ode init for MatrixGame2.0 #1179
H1yori233 wants to merge 59 commits into hao-ai-lab:main from
Conversation
Feat/kaiqin/mg overlay validation
/gemini review
Code Review
This pull request introduces a new ODE trajectory preprocessing pipeline and a corresponding training pipeline for MatrixGame models. It adds a DiffusionForcingScheduler and integrates it into the training process, along with updates to utility functions to support this new scheduler. The changes also include a new shell script for finetuning, a PyArrow schema for ODE trajectory data, and enhancements to the validation logging to include reference videos. Review comments suggest refactoring duplicated code, optimizing timestep calculation, and correcting misleading comments in the code.
```python
# Optional PIL Image
if pil_image is not None:
    record.update({
        "pil_image_bytes": pil_image.tobytes(),
        "pil_image_shape": list(pil_image.shape),
        "pil_image_dtype": str(pil_image.dtype),
    })
else:
    record.update({
        "pil_image_bytes": b"",
        "pil_image_shape": [],
        "pil_image_dtype": "",
    })

# Actions
if keyboard_cond is not None:
    record.update({
        "keyboard_cond_bytes": keyboard_cond.tobytes(),
        "keyboard_cond_shape": list(keyboard_cond.shape),
        "keyboard_cond_dtype": str(keyboard_cond.dtype),
    })
else:
    record.update({
        "keyboard_cond_bytes": b"",
        "keyboard_cond_shape": [],
        "keyboard_cond_dtype": "",
    })

if mouse_cond is not None:
    record.update({
        "mouse_cond_bytes": mouse_cond.tobytes(),
        "mouse_cond_shape": list(mouse_cond.shape),
        "mouse_cond_dtype": str(mouse_cond.dtype),
    })
else:
    record.update({
        "mouse_cond_bytes": b"",
        "mouse_cond_shape": [],
        "mouse_cond_dtype": "",
    })
```
The logic for handling optional numpy arrays (`pil_image`, `keyboard_cond`, `mouse_cond`) is duplicated. It can be refactored into a loop to improve maintainability and reduce boilerplate.
Suggested change:

```python
# Optional PIL Image and Actions
for prefix, array in [("pil_image", pil_image),
                      ("keyboard_cond", keyboard_cond),
                      ("mouse_cond", mouse_cond)]:
    if array is not None:
        record.update({
            f"{prefix}_bytes": array.tobytes(),
            f"{prefix}_shape": list(array.shape),
            f"{prefix}_dtype": str(array.dtype),
        })
    else:
        record.update({
            f"{prefix}_bytes": b"",
            f"{prefix}_shape": [],
            f"{prefix}_dtype": "",
        })
```
```python
timestep_id = torch.argmin(
    (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(),
    dim=1,
)
```
Using `torch.argmin` to find the `timestep_id` can be inefficient, especially if `self.timesteps` is large, as it has a time complexity of O(N). Since `self.timesteps` is a sorted (monotonically decreasing) tensor, you can use `torch.searchsorted` for a more performant O(log N) lookup.
To use `torch.searchsorted`, you'll need to work with a monotonically increasing tensor. You could either flip `self.timesteps` before searching, or adjust its creation in `set_timesteps` to be ascending and then flip it for the parts of the code that expect descending order.
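A minimal sketch of the suggested lookup (the schedule values here are illustrative, and it assumes the query timesteps are exact values drawn from the schedule):

```python
import torch

# Illustrative descending schedule, as set_timesteps might produce it.
timesteps = torch.tensor([999., 749., 499., 249., 0.])
queries = torch.tensor([499., 0., 999.])  # per-sample timesteps to look up

# torch.searchsorted needs an ascending tensor, so search the flipped view.
ascending = timesteps.flip(0)
pos = torch.searchsorted(ascending, queries)      # O(log N) per query
# Map indices in the ascending view back to the descending schedule.
timestep_id = timesteps.numel() - 1 - pos
```

For exact matches this reproduces the `argmin` result; for off-schedule queries you would additionally compare the two neighbouring entries to pick the nearest one.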
```python
features["clip_feature"] = clip_features
"""Get VAE features from the first frame of each video"""
# Get CLIP features from the first frame of each video.
```
This comment is misleading. The code block that follows is responsible for calculating VAE features, not CLIP features. CLIP features are calculated in the block just before this comment.

Suggested change:

```diff
- # Get CLIP features from the first frame of each video.
+ # Get VAE features from the first frame of each video.
```
Yes, this should be VAE. You can fix it.
```python
self.get_module("vae").to(get_local_torch_device())

features = {}
"""Get CLIP features from the first frame of each video."""
```
This appears to be a docstring placed in the middle of a function body. According to Python style guides (like PEP 257), docstrings should only appear as the first statement in a module, function, class, or method definition. For comments within a function, please use the `#` syntax.

Suggested change:

```diff
- """Get CLIP features from the first frame of each video."""
+ # Get CLIP features from the first frame of each video.
```
```python
features["clip_feature"] = clip_features
features["pil_image"] = first_frame
# Get CLIP features from the first frame of each video.
```
This comment is misleading. The code block that follows is responsible for preparing `video_conditions` to be encoded by the VAE, not for getting CLIP features. CLIP features are calculated in the block before this.

Suggested change:

```diff
- # Get CLIP features from the first frame of each video.
+ # Get VAE features from the first frame of each video.
```
```python
logger.info("relevant_traj_latents: %s", relevant_traj_latents.shape)

indexes = self._get_timestep(  # [B, num_frames]
    0, len(self.dmd_denoising_steps), B, num_frames, 3, uniform_timestep=False)
```
How about just using `global_step == self.init_steps` to check whether the reference video has been logged, instead of setting `self.validation_ref_videos_logged = False` at init?
Merge Protections
Your pull request matches the following merge protections and will not be merged until they are valid.
🔴 PR merge requirements — waiting for:
This rule is failing.
Buildkite CI tests failed
Hi @H1yori233, some Buildkite CI tests have failed. Check the build for details. Common causes:
If the failure is unrelated to your changes, leave a comment explaining why.
Why is `sigma_from_timestep` needed? It seems the timesteps are all precomputed, and the sigmas are already stored in `scheduler.sigmas` based on `self.timesteps`.
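The reviewer's point can be sketched as follows — with `timesteps` and `sigmas` precomputed in matching order (the values here are illustrative, not the scheduler's actual schedule), a separate `sigma_from_timestep` helper reduces to an index lookup:

```python
import torch

# Illustrative precomputed schedule; in the scheduler these would be
# self.timesteps and self.sigmas, stored in the same (descending) order.
timesteps = torch.tensor([999., 749., 499., 249., 0.])
sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])

def sigma_for(t: float) -> torch.Tensor:
    # Exact-match lookup: find the timestep's index, reuse it for sigma.
    idx = (timesteps == t).nonzero(as_tuple=True)[0][0]
    return sigmas[idx]
```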
Reverts the DiffusionForcingScheduler dependency in dfsft.py. The original design using the student's scheduler is more generic and works with both flow-matching and DDPM models.
Add per-frame gaussian loss weighting matching Causal-Forcing's bsmntw scheme. Mid-noise timesteps get higher weight; the extremes (near-clean and pure-noise) get lower weight.
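A hypothetical sketch of such per-frame gaussian weighting (the actual bsmntw weights in Causal-Forcing may use a different width and normalisation):

```python
import torch

def gaussian_timestep_weight(timesteps: torch.Tensor,
                             num_train_timesteps: int = 1000,
                             width: float = 0.25) -> torch.Tensor:
    """Per-frame loss weights: a gaussian bump centred at mid-noise, so
    near-clean and pure-noise frames contribute less to the loss."""
    t = timesteps.float() / num_train_timesteps        # normalise to [0, 1]
    w = torch.exp(-0.5 * ((t - 0.5) / width) ** 2)     # peak at t = 0.5
    return w / w.mean()                                # keep the loss scale stable

# Per-frame timesteps for a batch of 2 clips with 4 frames each.
ts = torch.tensor([[0, 333, 666, 999], [100, 500, 800, 999]])
weights = gaussian_timestep_weight(ts)                 # [2, 4], peaks near t=500
# weighted_loss = (per_frame_loss * weights).mean()
```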
Pre-commit checks failed
Hi @H1yori233, the pre-commit checks have failed. To fix them locally:

```shell
# Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install

# Run all checks and auto-fix what's possible
pre-commit run --all-files
```

Common fixes:
After fixing, commit and push the changes. The checks will re-run automatically. For future commits,
Purpose
Add diffusion forcing and ODE init for MatrixGame 2.0.

Changes
- `fastvideo/models/schedulers/scheduling_diffusion_forcing.py`: scheduler for diffusion forcing.
- `fastvideo/training/matrixgame_ar_diffusion_pipeline.py`: diffusion forcing pipeline.
- `fastvideo/dataset/dataloader/schema.py` and the pipeline `fastvideo/training/matrixgame_ode_causal_pipeline.py` for ODE init.
- `fastvideo/training/training_pipeline.py`: can now upload reference videos to WandB for easy comparison.

Others
This PR is based on the previous training pipeline; changes for the new trainer will be delivered in another PR.

Checklist
- Ran `pre-commit run --all-files` and fixed all issues.
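As background on the core idea: diffusion forcing trains with an independent noise level per frame rather than one shared timestep per clip. A toy sketch (illustrative names and a toy linear schedule, not the actual `DiffusionForcingScheduler` API):

```python
import torch

def add_per_frame_noise(latents: torch.Tensor, num_train_timesteps: int = 1000):
    """Noise each frame of a [B, F, ...] latent at its own sampled timestep."""
    B, F = latents.shape[:2]
    t = torch.randint(0, num_train_timesteps, (B, F))       # per-frame timesteps
    alpha = 1.0 - t.float() / num_train_timesteps           # toy linear schedule
    alpha = alpha.view(B, F, *([1] * (latents.dim() - 2)))  # broadcast over C,H,W
    noise = torch.randn_like(latents)
    return alpha * latents + (1.0 - alpha) * noise, t

latents = torch.randn(2, 8, 4, 16, 16)    # [batch, frames, channels, H, W]
noisy, t = add_per_frame_noise(latents)   # noisy keeps the shape; t is [2, 8]
```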