Skip to content
This repository was archived by the owner on May 20, 2026. It is now read-only.

Commit a31128f

Browse files
huvunvidia and Huy Vu2 authored
Unifying Flow-matching Automodel <-> Mbridge (#84)
* initial commit * runnable codes but not convergence tested * runnable codes but not convergence tested * workable code * workable code, verified training on 30 videos, functional tests passed for wan Automodel and wan Mbridge * fix lint * fix line * fix lint * fix lint * update unit test * fix lint --------- Co-authored-by: Huy Vu2 <huvu@login-eos02.eos.clusters.nvidia.com>
1 parent 9eaace1 commit a31128f

11 files changed

Lines changed: 467 additions & 457 deletions

File tree

dfm/src/automodel/flow_matching/flow_matching_pipeline.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class FlowMatchingPipeline:
104104
)
105105
106106
# Training step
107-
loss, metrics = pipeline.step(model, batch, device, dtype, global_step)
107+
weighted_loss, average_weighted_loss, loss_mask, metrics = pipeline.step(model, batch, device, dtype, global_step)
108108
"""
109109

110110
def __init__(
@@ -262,6 +262,7 @@ def compute_loss(
262262
model_pred: torch.Tensor,
263263
target: torch.Tensor,
264264
sigma: torch.Tensor,
265+
batch: Dict[str, Any],
265266
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
266267
"""
267268
Compute flow matching loss with optional weighting.
@@ -279,6 +280,7 @@ def compute_loss(
279280
loss_weight: Applied weights
280281
"""
281282
loss = nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none")
283+
loss_mask = batch["loss_mask"] if "loss_mask" in batch else None
282284

283285
if self.use_loss_weighting:
284286
loss_weight = 1.0 + self.flow_shift * sigma
@@ -288,17 +290,19 @@ def compute_loss(
288290

289291
loss_weight = loss_weight.to(model_pred.device)
290292

291-
unweighted_loss = loss.mean()
292-
weighted_loss = (loss * loss_weight).mean()
293+
unweighted_loss = loss
294+
weighted_loss = loss * loss_weight
295+
average_unweighted_loss = unweighted_loss.mean()
296+
average_weighted_loss = weighted_loss.mean()
293297

294-
return weighted_loss, unweighted_loss, loss_weight
298+
return weighted_loss, average_weighted_loss, unweighted_loss, average_unweighted_loss, loss_weight, loss_mask
295299

296300
def step(
297301
self,
298302
model: nn.Module,
299303
batch: Dict[str, Any],
300-
device: torch.device,
301-
dtype: torch.dtype,
304+
device: torch.device = torch.device("cuda"),
305+
dtype: torch.dtype = torch.bfloat16,
302306
global_step: int = 0,
303307
) -> Tuple[torch.Tensor, Dict[str, Any]]:
304308
"""
@@ -398,26 +402,35 @@ def step(
398402
# ====================================================================
399403
# Loss Computation
400404
# ====================================================================
401-
weighted_loss, unweighted_loss, loss_weight = self.compute_loss(model_pred, target, sigma)
405+
weighted_loss, average_weighted_loss, unweighted_loss, average_unweighted_loss, loss_weight, loss_mask = (
406+
self.compute_loss(model_pred, target, sigma, batch)
407+
)
402408

403409
# Safety check
404-
if torch.isnan(weighted_loss) or weighted_loss > 100:
405-
logger.error(f"[ERROR] Loss explosion! Loss={weighted_loss.item():.3f}")
406-
raise ValueError(f"Loss exploded: {weighted_loss.item()}")
410+
if torch.isnan(average_weighted_loss) or average_weighted_loss > 100:
411+
logger.error(f"[ERROR] Loss explosion! Loss={average_weighted_loss.item():.3f}")
412+
raise ValueError(f"Loss exploded: {average_weighted_loss.item()}")
407413

408414
# Logging
409415
if detailed_log or debug_mode:
410-
self._log_loss_detailed(global_step, model_pred, target, loss_weight, unweighted_loss, weighted_loss)
416+
self._log_loss_detailed(
417+
global_step,
418+
model_pred,
419+
target,
420+
loss_weight,
421+
average_unweighted_loss,
422+
average_weighted_loss,
423+
)
411424
elif summary_log:
412425
logger.info(
413-
f"[STEP {global_step}] Loss: {weighted_loss.item():.6f} | "
426+
f"[STEP {global_step}] Loss: {average_weighted_loss.item():.6f} | "
414427
f"w=[{loss_weight.min():.2f},{loss_weight.max():.2f}]"
415428
)
416429

417430
# Collect metrics
418431
metrics = {
419-
"loss": weighted_loss.item(),
420-
"unweighted_loss": unweighted_loss.item(),
432+
"loss": average_weighted_loss.item(),
433+
"unweighted_loss": average_unweighted_loss.item(),
421434
"sigma_min": sigma.min().item(),
422435
"sigma_max": sigma.max().item(),
423436
"sigma_mean": sigma.mean().item(),
@@ -432,7 +445,7 @@ def step(
432445
"data_type": data_type,
433446
}
434447

435-
return weighted_loss, metrics
448+
return weighted_loss, average_weighted_loss, loss_mask, metrics
436449

437450
def _log_detailed(
438451
self,

dfm/src/automodel/recipes/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def run_train_validation_loop(self):
382382
micro_losses = []
383383
for micro_batch in batch_group:
384384
try:
385-
loss, metrics = self.flow_matching_pipeline.step(
385+
_, loss, _, metrics = self.flow_matching_pipeline.step(
386386
model=self.model,
387387
batch=micro_batch,
388388
device=self.device,

dfm/src/megatron/model/dit/dit_layer_spec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def forward(
184184
sequence_len_offset=None,
185185
inference_context=None,
186186
rotary_pos_cos_sin=None,
187+
**kwargs,
187188
):
188189
timestep_emb = attention_mask
189190

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict, Tuple
16+
17+
import torch
18+
import torch.nn as nn
19+
from megatron.core import parallel_state
20+
21+
from dfm.src.automodel.flow_matching.adapters.base import FlowMatchingContext, ModelAdapter
22+
from dfm.src.automodel.flow_matching.flow_matching_pipeline import FlowMatchingPipeline
23+
from dfm.src.megatron.model.wan.utils import thd_split_inputs_cp
24+
25+
26+
class WanAdapter(ModelAdapter):
    """
    Model adapter for Wan model (Megatron version).

    Handles mapping of standard FlowMatchingContext to Wan specific inputs.
    """

    def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]:
        """
        Build the keyword inputs for the Wan model from a flow matching context.

        Reads ``grid_sizes``, ``context_embeddings`` and ``packed_seq_params``
        from ``context.batch`` and ``noisy_latents`` / ``timesteps`` from the
        context itself. Latents, context embeddings and timesteps are cast to
        bf16, and latents / context embeddings are split across the context
        parallel group when its world size is greater than 1.

        Args:
            context: Flow matching context produced by the pipeline.

        Returns:
            Dictionary with keys ``noisy_latents``, ``grid_sizes``,
            ``timesteps``, ``context_embeddings`` and ``packed_seq_params``,
            consumed by :meth:`forward`.
        """
        batch = context.batch
        grid_sizes = batch["grid_sizes"]
        context_embeddings = batch["context_embeddings"]
        packed_seq_params = batch["packed_seq_params"]

        # Transpose back to have shape "sbhd"
        # (before we reshaped to "bshd" to be compatible with flow matching pipeline)
        noisy_latents = context.noisy_latents.transpose(0, 1)

        # ========================================================================
        # Cast model inputs to bf16
        # ========================================================================

        noisy_latents = noisy_latents.to(torch.bfloat16)
        context_embeddings = context_embeddings.to(torch.bfloat16)

        # TODO: investigate the effect of bf16 timesteps on embedding precision.
        # Timesteps are currently cast to bf16 like the other inputs; keeping
        # them in fp32 (``timesteps.float()``) is the alternative to evaluate.
        timesteps = context.timesteps.to(torch.bfloat16)

        # ========================================================================
        # Split across context parallelism
        # ========================================================================

        if parallel_state.get_context_parallel_world_size() > 1:
            cp_group = parallel_state.get_context_parallel_group()
            noisy_latents = thd_split_inputs_cp(
                noisy_latents,
                packed_seq_params["self_attention"].cu_seqlens_q_padded,
                cp_group,
            )
            # TODO (pmannan): Disable CP for CrossAttention as KV context is small.
            # We don't need to split context embeddings across context parallelism
            # if we disable context parallelism for cross-attention
            context_embeddings = thd_split_inputs_cp(
                context_embeddings,
                packed_seq_params["cross_attention"].cu_seqlens_kv_padded,
                cp_group,
            )

        return {
            "noisy_latents": noisy_latents,
            "grid_sizes": grid_sizes,
            "timesteps": timesteps,
            "context_embeddings": context_embeddings,
            "packed_seq_params": packed_seq_params,
        }

    def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor:
        """
        Execute forward pass for Wan model.

        Args:
            model: Wan model
            inputs: Dictionary from prepare_inputs()

        Returns:
            Model prediction tensor, after ``post_process_prediction``
        """
        model_pred = model(
            x=inputs["noisy_latents"],
            grid_sizes=inputs["grid_sizes"],
            t=inputs["timesteps"],
            context=inputs["context_embeddings"],
            packed_seq_params=inputs["packed_seq_params"],
        )
        return self.post_process_prediction(model_pred)
108+
109+
110+
class WanFlowMatchingPipeline(FlowMatchingPipeline):
    """
    Wan-specific Flow Matching pipeline handling Context Parallelism and Custom Noise.

    This pipeline extends the standard FlowMatchingPipeline to support:
    1. Wan-specific noise generation (patching + padding)
    2. Context Parallelism (CP) splitting of inputs
    3. Masked loss computation
    """

    def determine_task_type(self, data_type: str) -> str:
        """Return the task type; always "t2v", ``data_type`` is ignored."""
        return "t2v"

    def compute_loss(
        self,
        model_pred: torch.Tensor,
        target: torch.Tensor,
        sigma: torch.Tensor,
        batch: Dict[str, Any],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the flow matching loss for Wan, splitting across context parallelism.

        The target is transposed back to the layout of the model prediction
        and, when the context parallel world size is greater than 1, the
        target and the loss mask are split across the context parallel group
        before delegating to the base class implementation.

        Args:
            model_pred: Model prediction tensor.
            target: Flow matching target tensor.
            sigma: Noise levels used for loss weighting.
            batch: Batch dictionary; must contain ``loss_mask`` and
                ``packed_seq_params``. It is not modified.

        Returns:
            The base class result: ``(weighted_loss, average_weighted_loss,
            unweighted_loss, average_unweighted_loss, loss_weight, loss_mask)``,
            where ``loss_mask`` is the (possibly CP-split) mask.
        """
        loss_mask = batch["loss_mask"]
        packed_seq_params = batch["packed_seq_params"]

        # Transpose back to have shape "sbhd"
        # (before we reshaped to "bshd" to be compatible with flow matching pipeline)
        target = target.transpose(0, 1)

        # ========================================================================
        # Split across context parallelism
        # ========================================================================

        if parallel_state.get_context_parallel_world_size() > 1:
            cp_group = parallel_state.get_context_parallel_group()
            cu_seqlens = packed_seq_params["self_attention"].cu_seqlens_q_padded
            target = thd_split_inputs_cp(target, cu_seqlens, cp_group)
            loss_mask = thd_split_inputs_cp(loss_mask, cu_seqlens, cp_group)

        # Hand the parent a shallow copy carrying the split mask instead of
        # overwriting the caller's batch: mutating it in place would make a
        # second call on the same batch split an already-split mask.
        local_batch = {**batch, "loss_mask": loss_mask}
        return super().compute_loss(model_pred, target, sigma, local_batch)

0 commit comments

Comments
 (0)