3232from megatron .core .utils import get_model_config , unwrap_model
3333
3434import wandb
35- from dfm .src .fastgen .fastgen .methods .model import FastGenModel
3635from dfm .src .megatron .model .wan .flow_matching .flow_inference_pipeline import FlowInferencePipeline
3736from dfm .src .megatron .model .wan .inference import SIZE_CONFIGS
3837from dfm .src .megatron .model .wan .wan_step import wan_data_step
@@ -128,11 +127,11 @@ def __init__(
128127 def _get_neg_condition (self , unwrapped_model ):
129128 """
130129 Get the negative condition embedding, computing and caching it on first call.
131- The negative condition is the embedding of an empty string "" .
130+ The negative condition uses the prompt from self.inference_cfg.english_sample_neg_prompt .
132131 """
133132 if self ._neg_condition is None :
134133 logger .info ("Computing and caching negative condition embedding..." )
135- neg_prompt = ["" ]
134+ neg_prompt = [self . inference_cfg . english_sample_neg_prompt ]
136135 neg_condition = unwrapped_model .get_text_encoder ().encode (neg_prompt , precision = torch .bfloat16 )
137136 self ._neg_condition = neg_condition .transpose (0 , 1 ).contiguous ()
138137 logger .info (f"Negative condition cached with shape: { self ._neg_condition .shape } " )
@@ -176,7 +175,7 @@ def on_train_start(self, student, teacher, fake_score, state: GlobalState):
176175
177176 def on_validation_start (self , single_step_outputs , batch , student , teacher , state : GlobalState ):
178177 """
179- Generate validation videos from teacher (50 steps) and student (1 step ).
178+ Generate validation videos from teacher (50 steps) and student (N steps based on config ).
180179 Logs videos to Weights & Biases.
181180 """
182181 if self ._inference_pipeline is None :
@@ -187,8 +186,10 @@ def on_validation_start(self, single_step_outputs, batch, student, teacher, stat
187186 torch .cuda .empty_cache ()
188187
189188 # Create pipeline with teacher model (we'll swap for student later)
190-
191189 gen_latent = single_step_outputs ["gen_rand" ]
190+ if callable (gen_latent ):
191+ logger .info ("gen_rand is callable (multi-step generation), invoking it to get latents..." )
192+ gen_latent = gen_latent ()
192193 with torch .no_grad ():
193194 gen_videos = self ._inference_pipeline ._decode_latents (gen_latent , sample = False )
194195 fps = self .inference_cfg .sample_fps
@@ -205,10 +206,11 @@ def on_validation_start(self, single_step_outputs, batch, student, teacher, stat
205206 prompt = "The video captures a series of images showing a group of children seated in an outdoor setting, possibly at a sports event. The children are dressed in casual attire, with one wearing a red top and another in a white top with a rainbow design. The background is filled with other spectators, some of whom are wearing baseball caps. The lighting suggests it's either late afternoon or early evening, and the atmosphere appears to be casual and relaxed."
206207
207208 print ("prompt" , prompt )
209+ student_steps = student .config .student_sample_steps
208210 self ._log_videos_to_wandb (
209211 videos = gen_videos ,
210- video_name = "student_prediction " ,
211- caption = f"Student (1 step): { prompt } " ,
212+ video_name = f"student_ { student_steps } step_prediction " ,
213+ caption = f"{ prompt } " ,
212214 fps = fps ,
213215 state = state ,
214216 )
@@ -218,50 +220,13 @@ def on_validation_start(self, single_step_outputs, batch, student, teacher, stat
218220 gc .collect ()
219221 torch .cuda .empty_cache ()
220222
221- student_steps = 4
222- input_rand = single_step_outputs .get ("input_rand" , None )
223- logger .info (f"Generating validation video from student with { student_steps } steps using generator_fn..." )
224-
225- # Get condition from batch
226- condition = batch .get ("context_embeddings" , None )
227- # Extract prompt for caption
228-
229- with torch .no_grad ():
230- # Wrap student to adapt interface for FastGenModel.generator_fn
231- wrapped_student = MegatronFastGenInferenceWrapper (student , batch )
232- # Use FastGenModel.generator_fn directly
233- student_4step_latents = FastGenModel .generator_fn (
234- net = wrapped_student ,
235- noise = input_rand , # [B, C, T, H, W] unit Gaussian
236- condition = condition ,
237- student_sample_steps = student_steps ,
238- student_sample_type = "sde" , # stochastic sampling
239- )
240-
241- # Decode latents to video
242- student_4step_videos = self ._inference_pipeline ._decode_latents (student_4step_latents , sample = False )
243- self ._log_videos_to_wandb (
244- videos = student_4step_videos ,
245- video_name = "student_4step_prediction" ,
246- caption = f"Student ({ student_steps } steps): { prompt } " ,
247- fps = fps ,
248- state = state ,
249- )
250-
251- del student_4step_videos , student_4step_latents
252- gc .collect ()
253- torch .cuda .empty_cache ()
254-
255223 # Generation parameters
256224 size_key = "832*480"
257225 size = SIZE_CONFIGS [size_key ]
258226 frame_num = 81
259227 shift = 5.0
260228 guide_scale = 5.0
261-
262229 seed = parallel_state .get_data_parallel_rank ()
263-
264- # Get the same initial noise that was used by the student
265230 # input_rand is the unit Gaussian noise (input_student / max_sigma)
266231 input_rand = single_step_outputs .get ("input_rand" , None )
267232 if input_rand is not None :
0 commit comments