
Commit dc594e4

Integrate torchax custom attention kernel into ulysses
1 parent c98002f

5 files changed: 1079 additions & 129 deletions

src/maxdiffusion/generate_wan.py: 40 additions & 12 deletions
@@ -255,6 +255,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
       f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
   )
   videos = call_pipeline(config, pipeline, prompt, negative_prompt)
+  if isinstance(videos, tuple):
+    videos = videos[0]
 
   max_logging.log("===================== Model details =======================")
   max_logging.log(f"model name: {config.model_name}")
@@ -278,7 +280,12 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
     upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)
 
   s0 = time.perf_counter()
-  videos = call_pipeline(config, pipeline, prompt, negative_prompt)
+  outputs = call_pipeline(config, pipeline, prompt, negative_prompt)
+  if isinstance(outputs, tuple):
+    videos, trace = outputs
+  else:
+    videos = outputs
+    trace = {}
   generation_time = time.perf_counter() - s0
   max_logging.log(f"generation_time: {generation_time}")
   if writer and jax.process_index() == 0:
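
The s0 bracket above times the measured generation run with time.perf_counter(). One JAX-specific caveat: dispatch is asynchronous, so a wall-clock bracket only captures device execution if something blocks on the result before the second clock read. A minimal sketch of an explicitly blocking variant, assuming the callable returns a JAX array or a pytree of arrays:

import time

import jax


def timed_call(fn, *args, **kwargs):
  """Run fn and return (result, elapsed_seconds). jax.block_until_ready
  forces asynchronously dispatched device work to finish before the
  clock is read, so the elapsed time includes device execution."""
  start = time.perf_counter()
  result = jax.block_until_ready(fn(*args, **kwargs))
  return result, time.perf_counter() - start
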
@@ -291,14 +298,21 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
     max_logging.log(f"generation time per video: {generation_time_per_video}")
   else:
     max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
-  max_logging.log(
-      f"\n{'=' * 50}\n"
-      f" TIMING SUMMARY\n"
-      f"{'=' * 50}\n"
-      f" Load (checkpoint): {load_time:>7.1f}s\n"
-      f" Compile: {compile_time:>7.1f}s\n"
-      f" {'─' * 40}\n"
-      f" Inference: {generation_time:>7.1f}s\n"
-      f"{'=' * 50}"
-  )
+  summary = [
+      f"\n{'=' * 50}",
+      " TIMING SUMMARY",
+      f"{'=' * 50}",
+      f" Load (checkpoint): {load_time:>7.1f}s",
+      f" Compile: {compile_time:>7.1f}s",
+      f" {'─' * 40}",
+      f" Inference: {generation_time:>7.1f}s",
+  ]
+  if trace:
+    summary.extend([
+        f" Conditioning: {trace.get('conditioning', 0.0):>7.1f}s",
+        f" Denoise Total: {trace.get('denoise_total', 0.0):>7.1f}s",
+        f" VAE Decode: {trace.get('vae_decode', 0.0):>7.1f}s",
+    ])
+  summary.append(f"{'=' * 50}")
+  max_logging.log("\n".join(summary))
 
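The summary reads per-phase timings from the trace dict under the keys 'conditioning', 'denoise_total', and 'vae_decode'. The pipeline side that fills those keys is not shown in this file; a minimal sketch of one way to populate such a dict, with the phase bodies assumed purely for illustration:

import time
from contextlib import contextmanager

trace = {}

@contextmanager
def phase(name):
  """Accumulate wall-clock seconds for one pipeline phase under the
  given key, matching the keys the timing summary reads with trace.get()."""
  start = time.perf_counter()
  try:
    yield
  finally:
    trace[name] = trace.get(name, 0.0) + (time.perf_counter() - start)

# Hypothetical usage inside the pipeline; encode_prompt, denoise, and
# vae.decode are assumed names, not from this commit:
#   with phase("conditioning"):
#     embeds = encode_prompt(prompt)
#   with phase("denoise_total"):
#     latents = denoise(embeds)
#   with phase("vae_decode"):
#     videos = vae.decode(latents)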

@@ -305,4 +319,18 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
-  videos = call_pipeline(config, pipeline, prompt, negative_prompt)
+  s0 = time.perf_counter()
+  if max_utils.profiler_enabled(config):
+    # Injecting user requested XLA tracing flags
+    xla_flags = os.environ.get("XLA_FLAGS", "")
+    new_flags = "--xla_enable_mxu_trace=true --xla_jf_dump_llo_html=true --xla_tpu_enable_llo_profiling=true"
+    os.environ["XLA_FLAGS"] = f"{xla_flags} {new_flags}"
+    max_logging.log(f"Injected XLA_FLAGS for profiling: {new_flags}")
+
+  videos = call_pipeline(config, pipeline, prompt, negative_prompt)
+  if isinstance(videos, tuple):
+    videos = videos[0]
+  generation_time_with_profiler = time.perf_counter() - s0
+  max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
+  if writer and jax.process_index() == 0:
+    writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
 
   return saved_video_path
 
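A note on the flag injection in the profiled run: XLA typically parses XLA_FLAGS once, when the backend is first initialized, and by this point in run() the pipeline has already executed on device, so flags injected here may not take effect for the current process. If the flags must apply, setting them before the first JAX import is the safer pattern; a minimal sketch using the same flags as the commit:

import os

# Append the tracing flags before any JAX/XLA initialization; XLA reads
# XLA_FLAGS at backend startup, and flags set after that point are
# generally ignored for the rest of the process.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_enable_mxu_trace=true"
    + " --xla_jf_dump_llo_html=true"
    + " --xla_tpu_enable_llo_profiling=true"
).strip()

import jax  # imported only after the environment is configured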
