|
| 1 | +""" |
| 2 | +Profile diffusers pipelines with torch.profiler. |
| 3 | +
|
| 4 | +Usage: |
| 5 | + python profiling/profiling_pipelines.py --pipeline flux --mode eager |
| 6 | + python profiling/profiling_pipelines.py --pipeline flux --mode compile |
| 7 | + python profiling/profiling_pipelines.py --pipeline flux --mode both |
| 8 | + python profiling/profiling_pipelines.py --pipeline all --mode eager |
| 9 | + python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode |
| 10 | + python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4 |
| 11 | +
|
| 12 | +Benchmarking (wall-clock time, no profiler overhead): |
| 13 | + python profiling/profiling_pipelines.py --pipeline flux --mode compile --benchmark |
| 14 | + python profiling/profiling_pipelines.py --pipeline flux --mode both --benchmark --num_runs 10 --num_warmups 3 |
| 15 | +""" |
| 16 | + |
| 17 | +import argparse |
| 18 | +import copy |
| 19 | +import logging |
| 20 | + |
| 21 | +import torch |
| 22 | +from profiling_utils import PipelineProfiler, PipelineProfilingConfig |
| 23 | + |
| 24 | + |
| 25 | +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") |
| 26 | +logger = logging.getLogger(__name__) |
| 27 | + |
| 28 | +PROMPT = "A cat holding a sign that says hello world" |
| 29 | + |
| 30 | + |
| 31 | +def build_registry(): |
| 32 | + """Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront.""" |
| 33 | + from diffusers import Flux2KleinPipeline, FluxPipeline, LTX2Pipeline, QwenImagePipeline, WanPipeline |
| 34 | + |
| 35 | + return { |
| 36 | + "flux": PipelineProfilingConfig( |
| 37 | + name="flux", |
| 38 | + pipeline_cls=FluxPipeline, |
| 39 | + pipeline_init_kwargs={ |
| 40 | + "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", |
| 41 | + "torch_dtype": torch.bfloat16, |
| 42 | + }, |
| 43 | + pipeline_call_kwargs={ |
| 44 | + "prompt": PROMPT, |
| 45 | + "height": 1024, |
| 46 | + "width": 1024, |
| 47 | + "num_inference_steps": 4, |
| 48 | + "guidance_scale": 3.5, |
| 49 | + "output_type": "latent", |
| 50 | + }, |
| 51 | + ), |
| 52 | + "flux2": PipelineProfilingConfig( |
| 53 | + name="flux2", |
| 54 | + pipeline_cls=Flux2KleinPipeline, |
| 55 | + pipeline_init_kwargs={ |
| 56 | + "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B", |
| 57 | + "torch_dtype": torch.bfloat16, |
| 58 | + }, |
| 59 | + pipeline_call_kwargs={ |
| 60 | + "prompt": PROMPT, |
| 61 | + "height": 1024, |
| 62 | + "width": 1024, |
| 63 | + "num_inference_steps": 4, |
| 64 | + "guidance_scale": 3.5, |
| 65 | + "output_type": "latent", |
| 66 | + }, |
| 67 | + ), |
| 68 | + "wan": PipelineProfilingConfig( |
| 69 | + name="wan", |
| 70 | + pipeline_cls=WanPipeline, |
| 71 | + pipeline_init_kwargs={ |
| 72 | + "pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers", |
| 73 | + "torch_dtype": torch.bfloat16, |
| 74 | + }, |
| 75 | + pipeline_call_kwargs={ |
| 76 | + "prompt": PROMPT, |
| 77 | + "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", |
| 78 | + "height": 480, |
| 79 | + "width": 832, |
| 80 | + "num_frames": 81, |
| 81 | + "num_inference_steps": 4, |
| 82 | + "output_type": "latent", |
| 83 | + }, |
| 84 | + ), |
| 85 | + "ltx2": PipelineProfilingConfig( |
| 86 | + name="ltx2", |
| 87 | + pipeline_cls=LTX2Pipeline, |
| 88 | + pipeline_init_kwargs={ |
| 89 | + "pretrained_model_name_or_path": "Lightricks/LTX-2", |
| 90 | + "torch_dtype": torch.bfloat16, |
| 91 | + }, |
| 92 | + pipeline_call_kwargs={ |
| 93 | + "prompt": PROMPT, |
| 94 | + "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", |
| 95 | + "height": 512, |
| 96 | + "width": 768, |
| 97 | + "num_frames": 121, |
| 98 | + "num_inference_steps": 4, |
| 99 | + "guidance_scale": 4.0, |
| 100 | + "output_type": "latent", |
| 101 | + }, |
| 102 | + ), |
| 103 | + "qwenimage": PipelineProfilingConfig( |
| 104 | + name="qwenimage", |
| 105 | + pipeline_cls=QwenImagePipeline, |
| 106 | + pipeline_init_kwargs={ |
| 107 | + "pretrained_model_name_or_path": "Qwen/Qwen-Image", |
| 108 | + "torch_dtype": torch.bfloat16, |
| 109 | + }, |
| 110 | + pipeline_call_kwargs={ |
| 111 | + "prompt": PROMPT, |
| 112 | + "negative_prompt": " ", |
| 113 | + "height": 1024, |
| 114 | + "width": 1024, |
| 115 | + "num_inference_steps": 4, |
| 116 | + "true_cfg_scale": 4.0, |
| 117 | + "output_type": "latent", |
| 118 | + }, |
| 119 | + ), |
| 120 | + } |
| 121 | + |
| 122 | + |
| 123 | +def main(): |
| 124 | + parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler") |
| 125 | + parser.add_argument( |
| 126 | + "--pipeline", |
| 127 | + choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"], |
| 128 | + required=True, |
| 129 | + help="Which pipeline to profile", |
| 130 | + ) |
| 131 | + parser.add_argument( |
| 132 | + "--mode", |
| 133 | + choices=["eager", "compile", "both"], |
| 134 | + default="eager", |
| 135 | + help="Run in eager mode, compile mode, or both", |
| 136 | + ) |
| 137 | + parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output") |
| 138 | + parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps") |
| 139 | + parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')") |
| 140 | + parser.add_argument( |
| 141 | + "--compile_mode", |
| 142 | + default="default", |
| 143 | + choices=["default", "reduce-overhead", "max-autotune"], |
| 144 | + help="torch.compile mode", |
| 145 | + ) |
| 146 | + parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile") |
| 147 | + parser.add_argument( |
| 148 | + "--compile_regional", |
| 149 | + action="store_true", |
| 150 | + help="Use compile_repeated_blocks() instead of full model compile", |
| 151 | + ) |
| 152 | + parser.add_argument( |
| 153 | + "--benchmark", |
| 154 | + action="store_true", |
| 155 | + help="Benchmark wall-clock time instead of profiling. Uses CUDA events, no profiler overhead.", |
| 156 | + ) |
| 157 | + parser.add_argument("--num_runs", type=int, default=5, help="Number of timed runs for benchmarking") |
| 158 | + parser.add_argument("--num_warmups", type=int, default=2, help="Number of warmup runs for benchmarking") |
| 159 | + args = parser.parse_args() |
| 160 | + |
| 161 | + registry = build_registry() |
| 162 | + |
| 163 | + pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline] |
| 164 | + modes = ["eager", "compile"] if args.mode == "both" else [args.mode] |
| 165 | + |
| 166 | + for pipeline_name in pipeline_names: |
| 167 | + for mode in modes: |
| 168 | + config = copy.deepcopy(registry[pipeline_name]) |
| 169 | + |
| 170 | + # Apply overrides |
| 171 | + if args.num_steps is not None: |
| 172 | + config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps |
| 173 | + if args.full_decode: |
| 174 | + config.pipeline_call_kwargs["output_type"] = "pil" |
| 175 | + if mode == "compile": |
| 176 | + config.compile_kwargs = { |
| 177 | + "fullgraph": args.compile_fullgraph, |
| 178 | + "mode": args.compile_mode, |
| 179 | + } |
| 180 | + config.compile_regional = args.compile_regional |
| 181 | + |
| 182 | + profiler = PipelineProfiler(config, args.output_dir) |
| 183 | + try: |
| 184 | + if args.benchmark: |
| 185 | + logger.info(f"Benchmarking {pipeline_name} in {mode} mode...") |
| 186 | + profiler.benchmark(num_runs=args.num_runs, num_warmups=args.num_warmups) |
| 187 | + else: |
| 188 | + logger.info(f"Profiling {pipeline_name} in {mode} mode...") |
| 189 | + trace_file = profiler.run() |
| 190 | + logger.info(f"Done: {trace_file}") |
| 191 | + except Exception as e: |
| 192 | + logger.error(f"Failed to {'benchmark' if args.benchmark else 'profile'} {pipeline_name} ({mode}): {e}") |
| 193 | + |
| 194 | + |
| 195 | +if __name__ == "__main__": |
| 196 | + main() |
0 commit comments