Skip to content

Commit 011b294

Browse files
authored
Merge branch 'main' into fix-torchao-groupoffloading
2 parents cb7402e + fbe8a75 commit 011b294

29 files changed

Lines changed: 3477 additions & 61 deletions

examples/profiling/README.md

Lines changed: 346 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""
2+
Profile diffusers pipelines with torch.profiler.
3+
4+
Usage:
5+
python profiling/profiling_pipelines.py --pipeline flux --mode eager
6+
python profiling/profiling_pipelines.py --pipeline flux --mode compile
7+
python profiling/profiling_pipelines.py --pipeline flux --mode both
8+
python profiling/profiling_pipelines.py --pipeline all --mode eager
9+
python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode
10+
python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4
11+
12+
Benchmarking (wall-clock time, no profiler overhead):
13+
python profiling/profiling_pipelines.py --pipeline flux --mode compile --benchmark
14+
python profiling/profiling_pipelines.py --pipeline flux --mode both --benchmark --num_runs 10 --num_warmups 3
15+
"""
16+
17+
import argparse
18+
import copy
19+
import logging
20+
21+
import torch
22+
from profiling_utils import PipelineProfiler, PipelineProfilingConfig
23+
24+
25+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
26+
logger = logging.getLogger(__name__)
27+
28+
PROMPT = "A cat holding a sign that says hello world"
29+
30+
31+
def build_registry():
32+
"""Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront."""
33+
from diffusers import Flux2KleinPipeline, FluxPipeline, LTX2Pipeline, QwenImagePipeline, WanPipeline
34+
35+
return {
36+
"flux": PipelineProfilingConfig(
37+
name="flux",
38+
pipeline_cls=FluxPipeline,
39+
pipeline_init_kwargs={
40+
"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
41+
"torch_dtype": torch.bfloat16,
42+
},
43+
pipeline_call_kwargs={
44+
"prompt": PROMPT,
45+
"height": 1024,
46+
"width": 1024,
47+
"num_inference_steps": 4,
48+
"guidance_scale": 3.5,
49+
"output_type": "latent",
50+
},
51+
),
52+
"flux2": PipelineProfilingConfig(
53+
name="flux2",
54+
pipeline_cls=Flux2KleinPipeline,
55+
pipeline_init_kwargs={
56+
"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B",
57+
"torch_dtype": torch.bfloat16,
58+
},
59+
pipeline_call_kwargs={
60+
"prompt": PROMPT,
61+
"height": 1024,
62+
"width": 1024,
63+
"num_inference_steps": 4,
64+
"guidance_scale": 3.5,
65+
"output_type": "latent",
66+
},
67+
),
68+
"wan": PipelineProfilingConfig(
69+
name="wan",
70+
pipeline_cls=WanPipeline,
71+
pipeline_init_kwargs={
72+
"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
73+
"torch_dtype": torch.bfloat16,
74+
},
75+
pipeline_call_kwargs={
76+
"prompt": PROMPT,
77+
"negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
78+
"height": 480,
79+
"width": 832,
80+
"num_frames": 81,
81+
"num_inference_steps": 4,
82+
"output_type": "latent",
83+
},
84+
),
85+
"ltx2": PipelineProfilingConfig(
86+
name="ltx2",
87+
pipeline_cls=LTX2Pipeline,
88+
pipeline_init_kwargs={
89+
"pretrained_model_name_or_path": "Lightricks/LTX-2",
90+
"torch_dtype": torch.bfloat16,
91+
},
92+
pipeline_call_kwargs={
93+
"prompt": PROMPT,
94+
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
95+
"height": 512,
96+
"width": 768,
97+
"num_frames": 121,
98+
"num_inference_steps": 4,
99+
"guidance_scale": 4.0,
100+
"output_type": "latent",
101+
},
102+
),
103+
"qwenimage": PipelineProfilingConfig(
104+
name="qwenimage",
105+
pipeline_cls=QwenImagePipeline,
106+
pipeline_init_kwargs={
107+
"pretrained_model_name_or_path": "Qwen/Qwen-Image",
108+
"torch_dtype": torch.bfloat16,
109+
},
110+
pipeline_call_kwargs={
111+
"prompt": PROMPT,
112+
"negative_prompt": " ",
113+
"height": 1024,
114+
"width": 1024,
115+
"num_inference_steps": 4,
116+
"true_cfg_scale": 4.0,
117+
"output_type": "latent",
118+
},
119+
),
120+
}
121+
122+
123+
def main():
124+
parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler")
125+
parser.add_argument(
126+
"--pipeline",
127+
choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"],
128+
required=True,
129+
help="Which pipeline to profile",
130+
)
131+
parser.add_argument(
132+
"--mode",
133+
choices=["eager", "compile", "both"],
134+
default="eager",
135+
help="Run in eager mode, compile mode, or both",
136+
)
137+
parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output")
138+
parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps")
139+
parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')")
140+
parser.add_argument(
141+
"--compile_mode",
142+
default="default",
143+
choices=["default", "reduce-overhead", "max-autotune"],
144+
help="torch.compile mode",
145+
)
146+
parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile")
147+
parser.add_argument(
148+
"--compile_regional",
149+
action="store_true",
150+
help="Use compile_repeated_blocks() instead of full model compile",
151+
)
152+
parser.add_argument(
153+
"--benchmark",
154+
action="store_true",
155+
help="Benchmark wall-clock time instead of profiling. Uses CUDA events, no profiler overhead.",
156+
)
157+
parser.add_argument("--num_runs", type=int, default=5, help="Number of timed runs for benchmarking")
158+
parser.add_argument("--num_warmups", type=int, default=2, help="Number of warmup runs for benchmarking")
159+
args = parser.parse_args()
160+
161+
registry = build_registry()
162+
163+
pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline]
164+
modes = ["eager", "compile"] if args.mode == "both" else [args.mode]
165+
166+
for pipeline_name in pipeline_names:
167+
for mode in modes:
168+
config = copy.deepcopy(registry[pipeline_name])
169+
170+
# Apply overrides
171+
if args.num_steps is not None:
172+
config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps
173+
if args.full_decode:
174+
config.pipeline_call_kwargs["output_type"] = "pil"
175+
if mode == "compile":
176+
config.compile_kwargs = {
177+
"fullgraph": args.compile_fullgraph,
178+
"mode": args.compile_mode,
179+
}
180+
config.compile_regional = args.compile_regional
181+
182+
profiler = PipelineProfiler(config, args.output_dir)
183+
try:
184+
if args.benchmark:
185+
logger.info(f"Benchmarking {pipeline_name} in {mode} mode...")
186+
profiler.benchmark(num_runs=args.num_runs, num_warmups=args.num_warmups)
187+
else:
188+
logger.info(f"Profiling {pipeline_name} in {mode} mode...")
189+
trace_file = profiler.run()
190+
logger.info(f"Done: {trace_file}")
191+
except Exception as e:
192+
logger.error(f"Failed to {'benchmark' if args.benchmark else 'profile'} {pipeline_name} ({mode}): {e}")
193+
194+
195+
if __name__ == "__main__":
196+
main()

0 commit comments

Comments
 (0)