Skip to content

Commit b114620

Browse files
sayakpaul, stevhliu, and dg845
authored
Add examples on how to profile a pipeline (#13356)
* add a profiling worflow. * fix * fix * more clarification * add points. * up * cache hooks * improve readme. * propagate deletion. * up * up * wan fixes. * more * up * add more traces. * up * better title * cuda graphs. * up * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * add torch.compile link. * approach -> How the tooling works * table * unavoidable gaps. * make important * note on regional compilation * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * make regional compilation note clearer. * Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * clarify scheduler related changes. * Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update examples/profiling/README.md Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * up * formatting * benchmarking runtime * up * up * up * up * Update examples/profiling/README.md Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent 447e571 commit b114620

File tree

8 files changed

+841
-14
lines changed

8 files changed

+841
-14
lines changed

examples/profiling/README.md

Lines changed: 342 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
"""Profile diffusers pipelines with torch.profiler.

Usage:
    python profiling/profiling_pipelines.py --pipeline flux --mode eager
    python profiling/profiling_pipelines.py --pipeline flux --mode compile
    python profiling/profiling_pipelines.py --pipeline flux --mode both
    python profiling/profiling_pipelines.py --pipeline all --mode eager
    python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode
    python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4

Benchmarking (wall-clock time, no profiler overhead):
    python profiling/profiling_pipelines.py --pipeline flux --mode compile --benchmark
    python profiling/profiling_pipelines.py --pipeline flux --mode both --benchmark --num_runs 10 --num_warmups 3
"""

import argparse
import copy
import logging

import torch
from profiling_utils import PipelineProfiler, PipelineProfilingConfig


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)

# Shared text prompt fed to every pipeline configuration in the registry.
PROMPT = "A cat holding a sign that says hello world"
def build_registry():
    """Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront."""
    from diffusers import Flux2KleinPipeline, FluxPipeline, LTX2Pipeline, QwenImagePipeline, WanPipeline

    def make_config(name, pipeline_cls, repo_id, call_kwargs):
        # Every registry entry loads its checkpoint in bfloat16 from the Hub;
        # only the pipeline class, repo id, and call kwargs vary.
        return PipelineProfilingConfig(
            name=name,
            pipeline_cls=pipeline_cls,
            pipeline_init_kwargs={
                "pretrained_model_name_or_path": repo_id,
                "torch_dtype": torch.bfloat16,
            },
            pipeline_call_kwargs=call_kwargs,
        )

    return {
        "flux": make_config(
            "flux",
            FluxPipeline,
            "black-forest-labs/FLUX.1-dev",
            {
                "prompt": PROMPT,
                "height": 1024,
                "width": 1024,
                "num_inference_steps": 4,
                "guidance_scale": 3.5,
                "output_type": "latent",
            },
        ),
        "flux2": make_config(
            "flux2",
            Flux2KleinPipeline,
            "black-forest-labs/FLUX.2-klein-base-9B",
            {
                "prompt": PROMPT,
                "height": 1024,
                "width": 1024,
                "num_inference_steps": 4,
                "guidance_scale": 3.5,
                "output_type": "latent",
            },
        ),
        "wan": make_config(
            "wan",
            WanPipeline,
            "Wan-AI/Wan2.1-T2V-14B-Diffusers",
            {
                "prompt": PROMPT,
                "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
                "height": 480,
                "width": 832,
                "num_frames": 81,
                "num_inference_steps": 4,
                "output_type": "latent",
            },
        ),
        "ltx2": make_config(
            "ltx2",
            LTX2Pipeline,
            "Lightricks/LTX-2",
            {
                "prompt": PROMPT,
                "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
                "height": 512,
                "width": 768,
                "num_frames": 121,
                "num_inference_steps": 4,
                "guidance_scale": 4.0,
                "output_type": "latent",
            },
        ),
        "qwenimage": make_config(
            "qwenimage",
            QwenImagePipeline,
            "Qwen/Qwen-Image",
            {
                "prompt": PROMPT,
                "negative_prompt": " ",
                "height": 1024,
                "width": 1024,
                "num_inference_steps": 4,
                "true_cfg_scale": 4.0,
                "output_type": "latent",
            },
        ),
    }
121+
122+
123+
def _build_parser():
    # Assemble the CLI; flag names, defaults, and help strings are part of the
    # tool's interface and must not change.
    parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler")
    parser.add_argument(
        "--pipeline",
        choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"],
        required=True,
        help="Which pipeline to profile",
    )
    parser.add_argument(
        "--mode",
        choices=["eager", "compile", "both"],
        default="eager",
        help="Run in eager mode, compile mode, or both",
    )
    parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output")
    parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps")
    parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')")
    parser.add_argument(
        "--compile_mode",
        default="default",
        choices=["default", "reduce-overhead", "max-autotune"],
        help="torch.compile mode",
    )
    parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile")
    parser.add_argument(
        "--compile_regional",
        action="store_true",
        help="Use compile_repeated_blocks() instead of full model compile",
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Benchmark wall-clock time instead of profiling. Uses CUDA events, no profiler overhead.",
    )
    parser.add_argument("--num_runs", type=int, default=5, help="Number of timed runs for benchmarking")
    parser.add_argument("--num_warmups", type=int, default=2, help="Number of warmup runs for benchmarking")
    return parser


def main():
    """Profile or benchmark the selected pipeline(s) in the selected mode(s)."""
    args = _build_parser().parse_args()

    registry = build_registry()

    # "all"/"both" expand to every registered pipeline / both execution modes.
    pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline]
    modes = ["eager", "compile"] if args.mode == "both" else [args.mode]
    action = "benchmark" if args.benchmark else "profile"

    for pipeline_name in pipeline_names:
        for mode in modes:
            # Deep-copy so per-run overrides never leak into the shared registry entry.
            config = copy.deepcopy(registry[pipeline_name])

            if args.num_steps is not None:
                config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps
            if args.full_decode:
                config.pipeline_call_kwargs["output_type"] = "pil"
            if mode == "compile":
                config.compile_kwargs = {
                    "fullgraph": args.compile_fullgraph,
                    "mode": args.compile_mode,
                }
                config.compile_regional = args.compile_regional

            profiler = PipelineProfiler(config, args.output_dir)
            try:
                if args.benchmark:
                    logger.info("Benchmarking %s in %s mode...", pipeline_name, mode)
                    profiler.benchmark(num_runs=args.num_runs, num_warmups=args.num_warmups)
                else:
                    logger.info("Profiling %s in %s mode...", pipeline_name, mode)
                    trace_file = profiler.run()
                    logger.info("Done: %s", trace_file)
            except Exception as e:
                # A failing pipeline/mode combo should not abort the remaining runs.
                logger.error(f"Failed to {action} {pipeline_name} ({mode}): {e}")
193+
194+
195+
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)