Skip to content
This repository was archived by the owner on May 20, 2026. It is now read-only.

Commit bc1699d

Browse files
committed
add support for fastgen precomupted dataset, update the negative prompt, remove 4 step generation of the student and just call fastgen approach.
Signed-off-by: sajadn <snorouzi@nvidia.com>
1 parent 492c11f commit bc1699d

8 files changed

Lines changed: 240 additions & 63 deletions

File tree

dfm/src/megatron/data/wan/wan_energon_datamodule.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,35 @@
1717
from dataclasses import dataclass
1818

1919
from megatron.bridge.data.utils import DatasetBuildContext
20-
from torch import int_repr
2120

2221
from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModule, DiffusionDataModuleConfig
22+
from dfm.src.megatron.data.wan.wan_latent_taskencoder import WanLatentTaskEncoder
2323
from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder
2424

2525

2626
@dataclass(kw_only=True)
2727
class WanDataModuleConfig(DiffusionDataModuleConfig):
28-
path: str
29-
seq_length: int
30-
packing_buffer_size: int
31-
micro_batch_size: int
32-
global_batch_size: int
33-
num_workers: int_repr
34-
dataloader_type: str = "external"
28+
# Only define new fields here; inherited fields come from DiffusionDataModuleConfig
29+
use_fastgen_dataset: bool = False # Flag to determine which task encoder to use
3530

3631
def __post_init__(self):
32+
# Instantiate the appropriate task encoder based on the flag
33+
if self.use_fastgen_dataset:
34+
task_encoder = WanLatentTaskEncoder(
35+
seq_length=self.task_encoder_seq_length,
36+
packing_buffer_size=self.packing_buffer_size,
37+
)
38+
else:
39+
task_encoder = WanTaskEncoder(
40+
seq_length=self.task_encoder_seq_length,
41+
packing_buffer_size=self.packing_buffer_size,
42+
)
43+
3744
self.dataset = DiffusionDataModule(
3845
path=self.path,
3946
seq_length=self.seq_length,
4047
packing_buffer_size=self.packing_buffer_size,
41-
task_encoder=WanTaskEncoder(
42-
seq_length=self.task_encoder_seq_length, # Use task_encoder_seq_length for packing
43-
packing_buffer_size=self.packing_buffer_size,
44-
),
48+
task_encoder=task_encoder,
4549
micro_batch_size=self.micro_batch_size,
4650
global_batch_size=self.global_batch_size,
4751
num_workers=self.num_workers,
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pylint: disable=C0115,C0116,C0301
16+
17+
"""
18+
Task encoder for WAN dataset with precomputed latents.
19+
20+
This module provides WanLatentTaskEncoder which handles precomputed VAE-encoded
21+
video latents and text embeddings. It differs from WanTaskEncoder in that it
22+
expects latents that are already VAE-encoded rather than raw video files.
23+
24+
Expected Energon dataset structure per sample:
25+
- latent.pth: RGB video latents (precomputed, already VAE-encoded) [C, T, H, W]
26+
- txt_emb.pth: Text embeddings (already padded to [512, dim])
27+
- depth_latent.pth: Depth latents (optional)
28+
- json: Metadata (resolution, fps, etc.)
29+
30+
The cook function maps these to the format expected by the parent WanTaskEncoder:
31+
- latent.pth -> pth (video latents)
32+
- txt_emb.pth -> pickle (text embeddings)
33+
"""
34+
35+
from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys
36+
37+
from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder
38+
39+
40+
def cook_latent(sample: dict) -> dict:
41+
"""
42+
Cook function for precomputed latent samples.
43+
44+
Maps the precomputed latent file naming convention to the format
45+
expected by the parent WanTaskEncoder class.
46+
47+
Args:
48+
sample (dict): Raw sample from Energon dataset containing:
49+
- "latent.pth": RGB video latents (precomputed)
50+
- "txt_emb.pth": Text embeddings
51+
- "depth_latent.pth" (optional): Depth latents
52+
- "json": Metadata
53+
54+
Returns:
55+
dict: Processed sample with keys mapped to parent's expected format:
56+
- "pth": Video latent tensor (from latent.pth)
57+
- "pickle": Text embeddings (from txt_emb.pth)
58+
- "json": Metadata
59+
"""
60+
return dict(
61+
**basic_sample_keys(sample),
62+
pth=sample["latent.pth"], # Map latent.pth -> pth
63+
pickle=sample["txt_emb.pth"], # Map txt_emb.pth -> pickle
64+
json=sample.get("json", {}),
65+
)
66+
67+
68+
class WanLatentTaskEncoder(WanTaskEncoder):
69+
"""
70+
Task encoder for WAN dataset with precomputed latents.
71+
72+
This class inherits from WanTaskEncoder and only overrides the cook function
73+
to handle the different file naming convention used for precomputed latents:
74+
- latent.pth (precomputed VAE-encoded video) instead of raw video in pth
75+
- txt_emb.pth (pre-encoded text embeddings) instead of pickle
76+
77+
All other processing is handled by the parent class:
78+
- Patchifying video latents
79+
- Grid size calculation
80+
- Text embedding padding to 512 tokens
81+
- Context parallelism padding
82+
- Sequence packing
83+
84+
Attributes:
85+
use_depth_latent (bool): Whether to load and use depth latents.
86+
Note: Currently depth latents are loaded but not actively used
87+
in the encoding pipeline. They can be accessed via video_metadata.
88+
89+
Example usage:
90+
task_encoder = WanLatentTaskEncoder(
91+
seq_length=500,
92+
packing_buffer_size=100,
93+
patch_spatial=2,
94+
patch_temporal=1,
95+
use_depth_latent=False, # Set to True if needed
96+
)
97+
"""
98+
99+
cookers = [
100+
Cooker(cook_latent),
101+
]
102+
103+
def __init__(
104+
self,
105+
*args,
106+
use_depth_latent: bool = False,
107+
**kwargs,
108+
):
109+
"""
110+
Initialize the WanLatentTaskEncoder.
111+
112+
Args:
113+
use_depth_latent (bool): Flag to enable depth latent loading.
114+
Defaults to False for memory optimization.
115+
*args: Additional positional arguments passed to parent WanTaskEncoder.
116+
**kwargs: Additional keyword arguments passed to parent WanTaskEncoder.
117+
Common kwargs include:
118+
- seq_length (int): Maximum sequence length
119+
- packing_buffer_size (int): Buffer size for sequence packing
120+
- patch_spatial (int): Spatial patch size (default: 2)
121+
- patch_temporal (int): Temporal patch size (default: 1)
122+
"""
123+
super().__init__(*args, **kwargs)
124+
self.use_depth_latent = use_depth_latent
125+
# All other initialization (patchifying, grid calculation, etc.)
126+
# is handled by the parent WanTaskEncoder class

dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ def __init__(
164164

165165
if dist.is_initialized():
166166
dist.barrier()
167-
self.model.to(self.device)
167+
# Move model to device and convert to correct dtype in one call
168+
self.model.to(device=self.device, dtype=self.param_dtype)
168169

169170
self.sample_neg_prompt = inference_cfg.english_sample_neg_prompt
170171

@@ -317,11 +318,7 @@ def _decode_latents(self, latents, sample=True):
317318
latents.device, latents.dtype
318319
)
319320
latents = latents / latents_std + latents_mean
320-
videos = self.vae.decode(latents)
321-
if sample:
322-
videos = videos.sample()
323-
else:
324-
videos = videos[0].clip_(-1.0, 1.0)
321+
videos = self.vae.decode(latents).sample
325322
return videos
326323

327324
def generate(

dfm/src/megatron/model/wan/wan_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,27 @@ def _set_embedder_weights_replica_id(
350350
replica_id=replica_id,
351351
allow_shape_mismatch=False,
352352
)
353+
354+
def load_state_dict(self, state_dict, strict=True):
355+
"""Load state dict with automatic handling of 'module.' prefix mismatch.
356+
357+
This method handles the case where checkpoints saved with DistributedDataParallel
358+
have a 'module.' prefix that needs to be removed when loading.
359+
360+
Args:
361+
state_dict (dict): The state dictionary to load
362+
strict (bool): Whether to strictly enforce that the keys match
363+
364+
Returns:
365+
NamedTuple: with 'missing_keys' and 'unexpected_keys' fields
366+
"""
367+
# Check if state_dict has 'module.' prefix but model doesn't
368+
has_module_prefix = any(k.startswith("module.") for k in state_dict.keys())
369+
if has_module_prefix:
370+
new_state_dict = {}
371+
for key, value in state_dict.items():
372+
new_key = key.replace("module.", "", 1) if key.startswith("module.") else key
373+
new_state_dict[new_key] = value
374+
state_dict = new_state_dict
375+
376+
return super().load_state_dict(state_dict, strict=strict)

dfm/src/megatron/model/wan_dmd/wan_dmd_step.py

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from megatron.core.utils import get_model_config, unwrap_model
3333

3434
import wandb
35-
from dfm.src.fastgen.fastgen.methods.model import FastGenModel
3635
from dfm.src.megatron.model.wan.flow_matching.flow_inference_pipeline import FlowInferencePipeline
3736
from dfm.src.megatron.model.wan.inference import SIZE_CONFIGS
3837
from dfm.src.megatron.model.wan.wan_step import wan_data_step
@@ -128,11 +127,11 @@ def __init__(
128127
def _get_neg_condition(self, unwrapped_model):
129128
"""
130129
Get the negative condition embedding, computing and caching it on first call.
131-
The negative condition is the embedding of an empty string "".
130+
The negative condition uses the prompt from self.inference_cfg.english_sample_neg_prompt.
132131
"""
133132
if self._neg_condition is None:
134133
logger.info("Computing and caching negative condition embedding...")
135-
neg_prompt = [""]
134+
neg_prompt = [self.inference_cfg.english_sample_neg_prompt]
136135
neg_condition = unwrapped_model.get_text_encoder().encode(neg_prompt, precision=torch.bfloat16)
137136
self._neg_condition = neg_condition.transpose(0, 1).contiguous()
138137
logger.info(f"Negative condition cached with shape: {self._neg_condition.shape}")
@@ -176,7 +175,7 @@ def on_train_start(self, student, teacher, fake_score, state: GlobalState):
176175

177176
def on_validation_start(self, single_step_outputs, batch, student, teacher, state: GlobalState):
178177
"""
179-
Generate validation videos from teacher (50 steps) and student (1 step).
178+
Generate validation videos from teacher (50 steps) and student (N steps based on config).
180179
Logs videos to Weights & Biases.
181180
"""
182181
if self._inference_pipeline is None:
@@ -187,8 +186,10 @@ def on_validation_start(self, single_step_outputs, batch, student, teacher, stat
187186
torch.cuda.empty_cache()
188187

189188
# Create pipeline with teacher model (we'll swap for student later)
190-
191189
gen_latent = single_step_outputs["gen_rand"]
190+
if callable(gen_latent):
191+
logger.info("gen_rand is callable (multi-step generation), invoking it to get latents...")
192+
gen_latent = gen_latent()
192193
with torch.no_grad():
193194
gen_videos = self._inference_pipeline._decode_latents(gen_latent, sample=False)
194195
fps = self.inference_cfg.sample_fps
@@ -205,10 +206,11 @@ def on_validation_start(self, single_step_outputs, batch, student, teacher, stat
205206
prompt = "The video captures a series of images showing a group of children seated in an outdoor setting, possibly at a sports event. The children are dressed in casual attire, with one wearing a red top and another in a white top with a rainbow design. The background is filled with other spectators, some of whom are wearing baseball caps. The lighting suggests it's either late afternoon or early evening, and the atmosphere appears to be casual and relaxed."
206207

207208
print("prompt", prompt)
209+
student_steps = student.config.student_sample_steps
208210
self._log_videos_to_wandb(
209211
videos=gen_videos,
210-
video_name="student_prediction",
211-
caption=f"Student (1 step): {prompt}",
212+
video_name=f"student_{student_steps}step_prediction",
213+
caption=f"{prompt}",
212214
fps=fps,
213215
state=state,
214216
)
@@ -218,50 +220,13 @@ def on_validation_start(self, single_step_outputs, batch, student, teacher, stat
218220
gc.collect()
219221
torch.cuda.empty_cache()
220222

221-
student_steps = 4
222-
input_rand = single_step_outputs.get("input_rand", None)
223-
logger.info(f"Generating validation video from student with {student_steps} steps using generator_fn...")
224-
225-
# Get condition from batch
226-
condition = batch.get("context_embeddings", None)
227-
# Extract prompt for caption
228-
229-
with torch.no_grad():
230-
# Wrap student to adapt interface for FastGenModel.generator_fn
231-
wrapped_student = MegatronFastGenInferenceWrapper(student, batch)
232-
# Use FastGenModel.generator_fn directly
233-
student_4step_latents = FastGenModel.generator_fn(
234-
net=wrapped_student,
235-
noise=input_rand, # [B, C, T, H, W] unit Gaussian
236-
condition=condition,
237-
student_sample_steps=student_steps,
238-
student_sample_type="sde", # stochastic sampling
239-
)
240-
241-
# Decode latents to video
242-
student_4step_videos = self._inference_pipeline._decode_latents(student_4step_latents, sample=False)
243-
self._log_videos_to_wandb(
244-
videos=student_4step_videos,
245-
video_name="student_4step_prediction",
246-
caption=f"Student ({student_steps} steps): {prompt}",
247-
fps=fps,
248-
state=state,
249-
)
250-
251-
del student_4step_videos, student_4step_latents
252-
gc.collect()
253-
torch.cuda.empty_cache()
254-
255223
# Generation parameters
256224
size_key = "832*480"
257225
size = SIZE_CONFIGS[size_key]
258226
frame_num = 81
259227
shift = 5.0
260228
guide_scale = 5.0
261-
262229
seed = parallel_state.get_data_parallel_rank()
263-
264-
# Get the same initial noise that was used by the student
265230
# input_rand is the unit Gaussian noise (input_student / max_sigma)
266231
input_rand = single_step_outputs.get("input_rand", None)
267232
if input_rand is not None:

dfm/src/megatron/recipes/wan/wan_dmd.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def wan_dmd_config(
179179
test_data_path: Optional[List[str]] = None,
180180
per_split_data_args_path: Optional[str] = None,
181181
mock: bool = False,
182+
use_fastgen_dataset: bool = False,
182183
# Model configuration
183184
tensor_parallelism: int = 1,
184185
pipeline_parallelism: int = 1,
@@ -212,6 +213,8 @@ def wan_dmd_config(
212213
test_data_path (Optional[List[str]]): List of test data paths.
213214
per_split_data_args_path (Optional[str]): Path to JSON file with per-split data configuration.
214215
mock (bool): Whether to use mock data. If True, ignores data_paths.
216+
use_fastgen_dataset (bool): Whether to use WanLatentTaskEncoder for precomputed latents (True)
217+
or WanTaskEncoder for raw data (False). Defaults to False.
215218
tensor_parallelism (int): Degree of tensor model parallelism.
216219
pipeline_parallelism (int): Degree of pipeline model parallelism.
217220
pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism.
@@ -295,6 +298,7 @@ def wan_dmd_config(
295298
num_workers=10,
296299
task_encoder_seq_length=None,
297300
packing_buffer_size=40, # 131,072 = 2^17 tokens, each 5 secs of 832*480 is about 45k tokens
301+
use_fastgen_dataset=use_fastgen_dataset, # Pass flag instead of instance
298302
)
299303

300304
# Config Container

examples/megatron/recipes/wan/wan_dmd.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
9393
default="finetune",
9494
help="Set training mode, 'pretrain' or 'finetune'.",
9595
)
96+
parser.add_argument(
97+
"--use-fastgen-dataset",
98+
action="store_true",
99+
help="Use WanLatentTaskEncoder for precomputed latents instead of WanTaskEncoder.",
100+
)
96101
parser.add_argument(
97102
"--config-file",
98103
type=str,
@@ -138,7 +143,11 @@ def main() -> None:
138143
logger.info("------------------------------------------------------------------")
139144

140145
# Load base configuration from the recipe as a Python dataclass
141-
cfg: ConfigContainer = wan_dmd_config(mock=args.mock, training_mode=args.training_mode)
146+
cfg: ConfigContainer = wan_dmd_config(
147+
mock=args.mock,
148+
training_mode=args.training_mode,
149+
use_fastgen_dataset=args.use_fastgen_dataset,
150+
)
142151
logger.info("Loaded base configuration")
143152

144153
# Print configuration on rank 0

0 commit comments

Comments
 (0)