-
Notifications
You must be signed in to change notification settings - Fork 156
Expand file tree
/
Copy pathgenerate.py
More file actions
499 lines (392 loc) · 18 KB
/
generate.py
File metadata and controls
499 lines (392 loc) · 18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unified Diffusion Generation Script
Single entry point for generating images and videos from all supported diffusion
models (FLUX, Wan 2.1/2.2, HunyuanVideo). Supports single-GPU and distributed
inference with optional checkpoint loading.
Usage:
# Single-GPU
python examples/diffusion/generate/generate.py \
-c examples/diffusion/generate/configs/generate_wan.yaml
# Multi-GPU distributed
torchrun --nproc-per-node=8 \
examples/diffusion/generate/generate.py \
-c examples/diffusion/generate/configs/generate_wan_distributed.yaml
# With checkpoint and custom prompts
python examples/diffusion/generate/generate.py \
-c examples/diffusion/generate/configs/generate_wan.yaml \
--model.checkpoint ./checkpoints/step_1000 \
--inference.prompts '["A dog running on a beach"]'
"""
import inspect
import logging
import os
from pathlib import Path
import torch
import torch.distributed as dist
from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
from nemo_automodel.components.loggers.log_utils import setup_logging
from nemo_automodel.shared.transformers_patches import patch_t5_layer_norm
logger = logging.getLogger(__name__)

# Lookup table: diffusers pipeline class name -> kind of media it produces.
# Classes missing from this table are classified by detect_output_type()'s
# ``num_frames`` signature check instead.
_PIPELINE_OUTPUT_TYPES = dict(
    FluxPipeline="image",
    QwenImagePipeline="image",
    WanPipeline="video",
    HunyuanVideoPipeline="video",
    HunyuanVideo15Pipeline="video",
)
def maybe_init_distributed(cfg):
    """Set up torch.distributed when the config asks for it.

    Args:
        cfg: Config node; only its optional ``distributed`` section is read.
            Within it, ``backend`` defaults to "nccl" and ``timeout_minutes``
            defaults to 10.

    Returns:
        The DistInfo returned by ``initialize_distributed`` when a
        ``distributed`` section is present, otherwise None.
    """
    dist_cfg = getattr(cfg, "distributed", None)
    if dist_cfg is not None:
        from nemo_automodel.components.distributed.init_utils import initialize_distributed

        info = initialize_distributed(
            backend=getattr(dist_cfg, "backend", "nccl"),
            timeout_minutes=getattr(dist_cfg, "timeout_minutes", 10),
        )
        logger.info("Distributed initialized: rank=%d, world_size=%d", info.rank, info.world_size)
        return info
    return None
def load_pipeline(cfg, dist_info):
    """Build the diffusion pipeline through NeMoAutoDiffusionPipeline.

    One loader serves both single-GPU and distributed runs. A parallel scheme
    is built only when ``dist_info`` is set and ``cfg.distributed`` defines
    ``parallel_scheme``; otherwise ``parallel_scheme=None`` is passed and no
    parallelization is requested.

    Args:
        cfg: Config node providing ``model.pretrained_model_name_or_path``,
            optional ``inference.dtype`` (default "bfloat16") and optional
            ``vae.enable_cpu_offload``.
        dist_info: DistInfo from maybe_init_distributed, or None.

    Returns:
        The pipeline object returned by NeMoAutoDiffusionPipeline.from_pretrained.
    """
    from nemo_automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline

    model_id = cfg.model.pretrained_model_name_or_path
    torch_dtype = _resolve_dtype(getattr(cfg.inference, "dtype", "bfloat16"))

    # Apex's FusedRMSNorm does not handle bf16, so T5LayerNorm is patched
    # before any T5-based text encoder (FLUX, HunyuanVideo, ...) is created.
    if torch_dtype == torch.bfloat16:
        patch_t5_layer_norm()

    wants_parallel = dist_info is not None and hasattr(cfg.distributed, "parallel_scheme")
    if wants_parallel:
        parallel_scheme = _build_parallel_scheme(cfg.distributed.parallel_scheme, dist_info)
    else:
        parallel_scheme = None

    # With CPU offload the modules must stay on CPU, so that
    # enable_model_cpu_offload() (run later from apply_optimizations) can
    # attach its per-module device hooks.
    vae_cfg = getattr(cfg, "vae", None)
    keep_on_cpu = vae_cfg is not None and getattr(vae_cfg, "enable_cpu_offload", False)

    pipe, _ = NeMoAutoDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        parallel_scheme=parallel_scheme,
        move_to_device=not keep_on_cpu,
    )
    _fix_text_encoder_weight_tying(pipe)
    logger.info("Loaded pipeline: %s (distributed=%s)", type(pipe).__name__, parallel_scheme is not None)
    return pipe
def _fix_text_encoder_weight_tying(pipe):
    """Re-tie the UMT5 token embedding to ``shared.weight`` (transformers>=5.0.0).

    Wan 2.1 checkpoints store the token embedding only as ``shared.weight``.
    transformers<5 tied it to ``encoder.embed_tokens.weight`` automatically.
    v5+ no longer does this in from_pretrained(), which leaves embed_tokens
    zero-initialized and yields all-zero text embeddings.

    This is a no-op when the pipeline has no ``text_encoder``, when the encoder
    lacks ``shared`` or ``encoder.embed_tokens``, or when both already share
    the same storage.
    """
    encoder_model = getattr(pipe, "text_encoder", None)
    if encoder_model is None:
        return
    if not hasattr(encoder_model, "shared"):
        return
    if not hasattr(encoder_model, "encoder") or not hasattr(encoder_model.encoder, "embed_tokens"):
        return

    embed_tokens = encoder_model.encoder.embed_tokens
    # Identical data pointers mean the tie is already in place.
    if embed_tokens.weight.data_ptr() == encoder_model.shared.weight.data_ptr():
        return
    embed_tokens.weight = encoder_model.shared.weight
    logger.info("Fixed UMT5 text encoder weight tying (shared.weight -> embed_tokens.weight)")
def _build_parallel_scheme(scheme_cfg, dist_info):
    """Build the ``parallel_scheme`` dict consumed by NeMoAutoDiffusionPipeline.

    Every public attribute of ``scheme_cfg`` whose value is neither None nor a
    method is treated as one pipeline component. The attribute name is the
    component name (e.g. ``transformer``).

    Args:
        scheme_cfg: Config node mapping component names to their parallelism settings.
        dist_info: DistInfo with distributed environment details; only
            ``world_size`` is read.

    Returns:
        Dict mapping component names to manager kwargs dicts. Each dict holds
        ``backend``, ``world_size`` and ``use_hf_tp_plan``, plus whichever of
        ``tp_size``, ``cp_size``, ``pp_size``, ``dp_size`` and
        ``dp_replicate_size`` are set (non-None) on the component config.
    """
    size_keys = ("tp_size", "cp_size", "pp_size", "dp_size", "dp_replicate_size")
    parallel_scheme = {}
    for comp_name in dir(scheme_cfg):
        if comp_name.startswith("_"):
            continue
        comp_cfg = getattr(scheme_cfg, comp_name)
        # dir() also lists the config object's own public methods (helpers
        # defined on its class, e.g. ``to_dict``). Those are not pipeline
        # components and must not end up in the scheme.
        if comp_cfg is None or inspect.isroutine(comp_cfg):
            continue
        # NOTE(review): the backend is always "nccl" here, independent of
        # ``distributed.backend`` in the config -- confirm this is intended.
        manager_args = {
            "backend": "nccl",
            "world_size": dist_info.world_size,
            "use_hf_tp_plan": False,
        }
        # Copy parallelism sizes from config; unset (None) sizes are omitted
        # so the manager's own defaults apply.
        for key in size_keys:
            val = getattr(comp_cfg, key, None)
            if val is not None:
                manager_args[key] = val
        parallel_scheme[comp_name] = manager_args
    return parallel_scheme
def load_checkpoint_into_pipeline(pipe, cfg):
    """Load a training checkpoint into the pipeline's transformer.

    The directory given by ``model.checkpoint`` is probed for the following
    layouts in priority order; the first match wins:

    1. ``ema_shadow.pt``: EMA weights, loaded with ``load_state_dict(strict=True)``.
    2. ``consolidated_model.bin``: a torch-saved state dict, optionally nested
       under a ``"model_state_dict"`` key, loaded with ``strict=True``.
    3. ``model/consolidated/*.safetensors``: a consolidated HF/diffusers
       checkpoint. Training produces it with ``model_save_format: safetensors``,
       ``save_consolidated: true`` and ``diffusers_compatible: true``. The
       transformer is rebuilt with the diffusers ``from_pretrained()`` API.
    4. ``model/*.distcp``: a sharded FSDP/DCP checkpoint.

    If none of these is found, a warning is logged and the base weights are kept.

    Args:
        pipe: The diffusion pipeline with a ``.transformer`` attribute.
        cfg: Config node with optional ``model.checkpoint`` path and optional
            ``inference.dtype`` (default "bfloat16").

    Raises:
        FileNotFoundError: If ``model.checkpoint`` is set but the path does not exist.
    """
    checkpoint = getattr(cfg.model, "checkpoint", None)
    if not checkpoint:
        return

    checkpoint_dir = Path(checkpoint)
    # Validate the path first so a typo fails fast with a clear error.
    if not checkpoint_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")

    torch_dtype = _resolve_dtype(getattr(cfg.inference, "dtype", "bfloat16"))

    ema_path = checkpoint_dir / "ema_shadow.pt"
    consolidated_path = checkpoint_dir / "consolidated_model.bin"
    consolidated_st_dir = checkpoint_dir / "model" / "consolidated"
    sharded_dir = checkpoint_dir / "model"

    if ema_path.exists():
        logger.info("Loading EMA checkpoint from %s", ema_path)
        ema_state = torch.load(ema_path, map_location="cuda", weights_only=True)
        pipe.transformer.load_state_dict(ema_state, strict=True)
        logger.info("Loaded EMA checkpoint")
    elif consolidated_path.exists():
        logger.info("Loading consolidated checkpoint from %s", consolidated_path)
        state_dict = torch.load(consolidated_path, map_location="cuda", weights_only=True)
        # The file may wrap the weights in {"model_state_dict": ...}; unwrap if so.
        if "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]
        pipe.transformer.load_state_dict(state_dict, strict=True)
        logger.info("Loaded consolidated checkpoint")
    elif _dir_has_file_with_suffix(consolidated_st_dir, ".safetensors"):
        logger.info("Loading consolidated safetensors checkpoint from %s", consolidated_st_dir)
        # NOTE(review): this replaces pipe.transformer with a freshly built
        # module; presumably any parallelization applied while loading the
        # pipeline is not re-applied here -- confirm for distributed runs.
        pipe.transformer = type(pipe.transformer).from_pretrained(str(consolidated_st_dir), torch_dtype=torch_dtype)
        pipe.transformer.to("cuda")
        logger.info("Loaded consolidated safetensors checkpoint")
    elif _dir_has_file_with_suffix(sharded_dir, ".distcp"):
        logger.info("Loading sharded FSDP checkpoint from %s", sharded_dir)
        pipe.transformer = _load_sharded_fsdp_checkpoint(pipe.transformer, str(sharded_dir), torch_dtype)
        pipe.transformer.to("cuda", dtype=torch_dtype)
        logger.info("Loaded sharded FSDP checkpoint")
    else:
        logger.warning("No recognized checkpoint format found in %s, using base model weights", checkpoint_dir)


def _dir_has_file_with_suffix(directory, suffix):
    """Return True if ``directory`` is a directory containing an entry whose name ends with ``suffix``."""
    return directory.is_dir() and any(name.endswith(suffix) for name in os.listdir(directory))
def load_lora_weights_into_pipeline(pipe, cfg):
    """Inject LoRA adapters into the pipeline's transformer and load their weights.

    Reads ``adapter_config.json`` and ``adapter_model.safetensors`` from the
    directory given by ``model.lora_weights``. Does nothing when
    ``lora_weights`` is null or unset. LoRA modules are added to the
    transformer's linear layers according to the adapter config before the
    weights are loaded.

    Args:
        pipe: The diffusion pipeline with a ``.transformer`` attribute.
        cfg: Config node with optional ``model.lora_weights``.

    Raises:
        FileNotFoundError: If the LoRA directory or its ``adapter_config.json``
            does not exist.
    """
    lora_weights = getattr(cfg.model, "lora_weights", None)
    if not lora_weights:
        return

    lora_path = Path(lora_weights)
    # Validate the path before importing the LoRA / safetensors dependencies.
    if not lora_path.exists():
        raise FileNotFoundError(f"LoRA weights directory not found: {lora_path}")

    import json

    from safetensors.torch import load_file

    from nemo_automodel.components._peft.lora import PeftConfig, apply_lora_to_linear_modules

    with open(lora_path / "adapter_config.json", encoding="utf-8") as f:
        peft_config = PeftConfig.from_dict(json.load(f))

    apply_lora_to_linear_modules(pipe.transformer, peft_config, skip_freeze=True)
    state_dict = load_file(lora_path / "adapter_model.safetensors", device="cuda")

    # strict=False because the adapter file is expected to hold only the LoRA
    # tensors, so base weights legitimately show up as "missing". Unexpected
    # keys, however, are adapter tensors that matched no parameter and were
    # dropped, so surface them instead of failing silently.
    incompatible = pipe.transformer.load_state_dict(state_dict, strict=False)
    unexpected = list(getattr(incompatible, "unexpected_keys", None) or [])
    if unexpected:
        logger.warning(
            "%d LoRA tensor(s) from %s matched no transformer parameter and were ignored (first few: %s)",
            len(unexpected),
            lora_path,
            unexpected[:5],
        )
    logger.info("Loaded LoRA adapter weights from %s (%d tensors)", lora_path, len(state_dict))
def _load_sharded_fsdp_checkpoint(transformer, sharded_dir, torch_dtype=torch.bfloat16):
    """Load a sharded FSDP/DCP checkpoint into a transformer module.

    If torch.distributed is not initialized yet, a temporary single-process
    gloo group (rank 0, world_size 1) is created for the load and destroyed
    again before returning. An already-initialized process group is left
    untouched.

    Args:
        transformer: The transformer nn.Module to load weights into. It is
            moved to CUDA and cast to ``torch_dtype`` in place before loading.
        sharded_dir: Path to the directory containing .distcp shard files.
        torch_dtype: The dtype to cast the transformer to before loading.

    Returns:
        The unwrapped transformer module (``fsdp_transformer.module``) with
        the checkpoint weights loaded.

    NOTE(review): when torch.distributed is already initialized with more
    than one rank, FSDP wraps the module over that existing group. Whether
    the returned ``.module`` then holds full (unsharded) parameters has not
    been verified here -- confirm this path is only used for single-process
    loads.
    """
    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint import load as dist_load
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType
    from torch.distributed.fsdp.api import ShardedStateDictConfig

    # Tracks whether this function created the process group, so it only
    # tears down what it set up.
    init_dist = False
    if not dist.is_initialized():
        # Rendezvous defaults for a local single-process group; values already
        # present in the environment take precedence (setdefault).
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)
        init_dist = True
    try:
        transformer.to(device="cuda", dtype=torch_dtype)
        fsdp_transformer = FSDP(transformer, use_orig_params=True)
        # Ask FSDP for the sharded state-dict format (offloaded to CPU),
        # presumably matching how the .distcp checkpoint was saved.
        FSDP.set_state_dict_type(
            fsdp_transformer,
            StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        )
        # Build a state-dict template; dist_load fills its tensors in place
        # from the shard files, and the result is then loaded into the model.
        model_state = fsdp_transformer.state_dict()
        dist_load(state_dict=model_state, storage_reader=FileSystemReader(sharded_dir))
        fsdp_transformer.load_state_dict(model_state)
        # Unwrap back to the original module for inference
        return fsdp_transformer.module
    finally:
        # Only destroy a process group that this function created.
        if init_dist:
            dist.destroy_process_group()
def apply_optimizations(pipe, cfg):
    """Turn on the VAE / memory optimizations requested in ``cfg.vae``.

    Recognized boolean flags (all default to False):

    * ``enable_slicing`` calls ``pipe.vae.enable_slicing()``.
    * ``enable_tiling`` calls ``pipe.vae.enable_tiling()``.
    * ``enable_cpu_offload`` calls ``pipe.enable_model_cpu_offload()``.

    Slicing and tiling are applied only when the pipeline has a ``vae``
    attribute. Nothing happens when the config has no ``vae`` section.

    Args:
        pipe: The diffusion pipeline.
        cfg: Config node with optional ``vae`` section.
    """
    vae_cfg = getattr(cfg, "vae", None)
    if vae_cfg is None:
        return

    if hasattr(pipe, "vae"):
        # The config flag name doubles as the VAE method name to invoke.
        vae_switches = (
            ("enable_slicing", "Enabled VAE slicing"),
            ("enable_tiling", "Enabled VAE tiling"),
        )
        for method_name, message in vae_switches:
            if getattr(vae_cfg, method_name, False):
                getattr(pipe.vae, method_name)()
                logger.info(message)

    if getattr(vae_cfg, "enable_cpu_offload", False):
        pipe.enable_model_cpu_offload()
        logger.info("Enabled model CPU offload")
def detect_output_type(pipe):
    """Classify a pipeline as producing "image" or "video" output.

    Known pipeline classes are looked up by class name in
    ``_PIPELINE_OUTPUT_TYPES``. For any other class, a pipeline whose
    ``__call__`` accepts a ``num_frames`` argument is treated as video. If the
    signature cannot be inspected, or lacks ``num_frames``, the result is "image".

    Args:
        pipe: The diffusion pipeline instance.

    Returns:
        "image" or "video"
    """
    known = _PIPELINE_OUTPUT_TYPES.get(type(pipe).__name__)
    if known is not None:
        return known

    # Fallback heuristic for pipelines missing from the lookup table.
    try:
        accepts_frames = "num_frames" in inspect.signature(pipe.__call__).parameters
    except (ValueError, TypeError):
        accepts_frames = False
    return "video" if accepts_frames else "image"
def run_inference(pipe, cfg, is_rank0):
    """Run inference on all configured prompts and save outputs.

    Prompts come from ``cfg.inference.prompts`` (a single string is treated as
    one prompt) and are truncated to ``inference.max_samples`` when that is
    set. Prompt ``i`` is generated with a CUDA generator seeded with
    ``seed + i``, where ``cfg.seed`` defaults to 42. Every rank calls the
    pipeline; only rank 0 writes ``sample_XXX_<prompt>.mp4`` (video) or
    ``sample_XXX_<prompt>.png`` (image) files into ``output.output_dir``
    (default "./inference_outputs").

    Args:
        pipe: The diffusion pipeline.
        cfg: Config node with ``inference`` and ``model`` sections and an
            optional ``output`` section.
        is_rank0: Whether this is the main process (for saving outputs).
    """
    from diffusers.utils import export_to_video

    output_type = detect_output_type(pipe)
    prompts = _resolve_prompts(cfg.inference)

    # A missing ``output`` section falls back to the defaults below.
    output_cfg = getattr(cfg, "output", None)
    output_dir = Path(getattr(output_cfg, "output_dir", "./inference_outputs"))
    fps = getattr(output_cfg, "fps", 16)
    if is_rank0:
        output_dir.mkdir(parents=True, exist_ok=True)

    pipe_kwargs = _build_pipe_kwargs(pipe, cfg)

    seed = getattr(cfg, "seed", 42)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    logger.info("Generating %d samples (%s mode)", len(prompts), output_type)
    logger.info("Pipeline kwargs: %s", pipe_kwargs)

    for i, prompt_text in enumerate(prompts):
        logger.info("[%d/%d] Prompt: %s", i + 1, len(prompts), prompt_text[:80])
        # Per-prompt seed keeps each sample reproducible independently.
        generator = torch.Generator(device="cuda").manual_seed(seed + i)
        with torch.no_grad():
            output = pipe(prompt=prompt_text, generator=generator, **pipe_kwargs)

        # All ranks run the pipeline call above; only rank 0 saves results.
        if not is_rank0:
            continue

        stem = f"sample_{i:03d}_{_safe_filename(prompt_text)}"
        if output_type == "video":
            output_path = output_dir / f"{stem}.mp4"
            export_to_video(output.frames[0], str(output_path), fps=fps)
        else:
            output_path = output_dir / f"{stem}.png"
            output.images[0].save(str(output_path))
        logger.info("Saved: %s", output_path)


def _resolve_prompts(inference_cfg):
    """Return the list of prompts to run, truncated to ``max_samples`` when it is set."""
    prompts = inference_cfg.prompts
    # A bare string (e.g. ``--inference.prompts "A dog"``) would otherwise be
    # iterated character by character.
    if isinstance(prompts, str):
        prompts = [prompts]
    prompts = list(prompts)
    max_samples = getattr(inference_cfg, "max_samples", None)
    if max_samples is not None:
        prompts = prompts[:max_samples]
    return prompts


def _build_pipe_kwargs(pipe, cfg):
    """Collect the keyword arguments forwarded to every ``pipe(...)`` call."""
    pipe_kwargs = {}
    for key in ("num_inference_steps", "guidance_scale", "height", "width"):
        val = getattr(cfg.inference, key, None)
        if val is not None:
            pipe_kwargs[key] = val

    # Model-specific extras. A ConfigNode is converted with to_dict(); a plain
    # mapping is copied as-is.
    extra_kwargs = getattr(cfg.inference, "pipeline_kwargs", None)
    if extra_kwargs is not None:
        pipe_kwargs.update(extra_kwargs.to_dict() if hasattr(extra_kwargs, "to_dict") else dict(extra_kwargs))

    # LoRA scale is passed as attention_kwargs (newer diffusers) or
    # cross_attention_kwargs (older diffusers), so that the transformer's
    # forward() applies the correct contribution weight. It is merged into any
    # kwargs the user supplied via pipeline_kwargs rather than replacing them.
    if getattr(cfg.model, "lora_weights", None):
        lora_scale = getattr(cfg.model, "lora_scale", 1.0)
        call_params = inspect.signature(pipe.__call__).parameters
        for attn_key in ("attention_kwargs", "cross_attention_kwargs"):
            if attn_key in call_params:
                pipe_kwargs[attn_key] = {**(pipe_kwargs.get(attn_key) or {}), "scale": lora_scale}
                break
    return pipe_kwargs


def _safe_filename(text, max_len=50):
    """Reduce ``text`` to a filename fragment.

    Keeps alphanumerics, spaces, ``_`` and ``-``, truncates to ``max_len``
    characters, strips the ends and replaces spaces with ``_``.
    """
    kept = "".join(c for c in text if c.isalnum() or c in " _-")
    return kept[:max_len].strip().replace(" ", "_")
def _resolve_dtype(dtype_str):
    """Convert a dtype name (or a ``torch.dtype``) to a ``torch.dtype``.

    Accepted inputs:

    * The canonical names "float16", "bfloat16" and "float32".
    * The aliases "fp16", "half", "bf16", "fp32" and "float", in any letter
      case, with or without a "torch." prefix.
    * ``torch.dtype`` objects, which are returned unchanged.

    ``None`` maps to ``torch.bfloat16``. Any other value also falls back to
    ``torch.bfloat16``, but a warning is logged so that a typo does not
    silently change precision.
    """
    if isinstance(dtype_str, torch.dtype):
        return dtype_str
    if dtype_str is None:
        return torch.bfloat16

    dtype_map = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
        "float": torch.float32,
    }
    key = str(dtype_str).strip().lower()
    if key.startswith("torch."):
        key = key[len("torch."):]

    resolved = dtype_map.get(key)
    if resolved is None:
        logger.warning("Unrecognized dtype %r; falling back to torch.bfloat16", dtype_str)
        return torch.bfloat16
    return resolved
def main():
    """Script entry point: parse the config, load the pipeline, generate, and clean up.

    Steps: optional distributed init, pipeline load, optional checkpoint and
    LoRA adapter loading, VAE/memory optimizations, then inference. When
    distributed is configured, the default process group is destroyed even if
    a step raises, so a failure does not leave the backend initialized.
    """
    cfg = parse_args_and_load_config()
    setup_logging()

    # 1. Initialize distributed (if configured)
    dist_info = maybe_init_distributed(cfg)
    is_rank0 = dist_info is None or dist_info.is_main

    try:
        # 2. Load pipeline
        pipe = load_pipeline(cfg, dist_info)
        # 3. Load checkpoint (if configured)
        load_checkpoint_into_pipeline(pipe, cfg)
        # 3b. Load LoRA adapter weights (if configured)
        load_lora_weights_into_pipeline(pipe, cfg)
        # 4. Apply VAE / memory optimizations
        apply_optimizations(pipe, cfg)
        # 5. Synchronize before inference
        if dist_info is not None:
            dist.barrier()
        # 6. Run inference
        run_inference(pipe, cfg, is_rank0)
        # 7. Wait for every rank to finish before tearing down
        if dist_info is not None:
            dist.barrier()
    finally:
        # No barrier on this path: after a failure on one rank, a barrier could
        # block indefinitely waiting for ranks that will never reach it.
        if dist_info is not None and dist.is_initialized():
            dist.destroy_process_group()

    if dist_info is not None:
        logger.info("Distributed inference complete")


if __name__ == "__main__":
    main()