|
14 | 14 | limitations under the License. |
15 | 15 | --> |
16 | 16 |
|
17 | | -[](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml) |
| 17 | +[![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml) |
18 | 18 |
|
19 | 19 | # What's new? |
20 | | -- **`2026/1/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported |
| 20 | +- **`2026/01/29`**: Wan LoRA for inference is now supported |
| 21 | +- **`2026/01/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported |
21 | 22 | - **`2025/11/11`**: Wan2.2 txt2vid generation is now supported |
22 | 23 | - **`2025/10/14`**: NVIDIA DGX Spark Flux support. |
23 | 24 | - **`2025/10/10`**: Wan2.1 txt2vid training and generation are now supported. |
24 | | -- **`2025/8/14`**: LTX-Video img2vid generation is now supported. |
25 | | -- **`2025/7/29`**: LTX-Video text2vid generation is now supported. |
| 25 | +- **`2025/08/14`**: LTX-Video img2vid generation is now supported. |
| 26 | +- **`2025/07/29`**: LTX-Video text2vid generation is now supported. |
26 | 27 | - **`2025/04/17`**: Flux Finetuning. |
27 | 28 | - **`2025/02/12`**: Flux LoRA for inference. |
28 | 29 | - **`2025/02/08`**: Flux schnell & dev inference. |
29 | 30 | - **`2024/12/12`**: Load multiple LoRAs for inference. |
30 | 31 | - **`2024/10/22`**: LoRA support for Hyper SDXL. |
31 | | -- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format. |
32 | | -- **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x,2.x is now supported. |
| 32 | +- **`2024/08/01`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format. |
| 33 | +- **`2024/07/20`**: Dreambooth training for Stable Diffusion 1.x and 2.x is now supported. |
33 | 34 |
|
34 | 35 | # Overview |
35 | 36 |
|
@@ -68,14 +69,15 @@ MaxDiffusion supports |
68 | 69 | - [SD 1.4](#stable-diffusion-14-training) |
69 | 70 | - [Dreambooth](#dreambooth) |
70 | 71 | - [Inference](#inference) |
71 | | - - [Wan2.1](#wan21) |
72 | | - - [Wan2.2](#wan22) |
| 72 | + - [Wan](#wan-models) |
73 | 73 | - [LTX-Video](#ltx-video) |
74 | 74 | - [Flux](#flux) |
75 | 75 | - [Fused Attention for GPU](#fused-attention-for-gpu) |
76 | 76 | - [SDXL](#stable-diffusion-xl) |
77 | 77 | - [SD 2 base](#stable-diffusion-2-base) |
78 | 78 | - [SD 2.1](#stable-diffusion-21) |
| 79 | + - [Wan LoRA](#wan-lora) |
| 80 | + - [Flux LoRA](#flux-lora) |
79 | 81 | - [Hyper SDXL LoRA](#hyper-sdxl-lora) |
80 | 82 | - [Load Multiple LoRA](#load-multiple-lora) |
81 | 83 | - [SDXL Lightning](#sdxl-lightning) |
@@ -482,41 +484,48 @@ To generate images, run the following command: |
482 | 484 |
|
483 | 485 | Add the conditioning image path as `conditioning_media_paths` in the form `["IMAGE_PATH"]`, along with the other generation parameters, in the `ltx_video.yml` file, then follow the same instructions as above. |
484 | 486 |
|
485 | | - ## Wan2.1 |
| 487 | + ## Wan Models |
486 | 488 |
|
487 | 489 | Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). |
488 | 490 |
|
489 | | - ### Text2Vid |
| 491 | + Wan supports both Text2Vid and Img2Vid pipelines. |
490 | 492 |
|
491 | | - ```bash |
492 | | - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ |
493 | | - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 |
494 | | - ``` |
495 | | - |
496 | | - ### Img2Vid |
497 | | - |
498 | | - ```bash |
499 | | - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ |
500 | | - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_i2v_14b.yml attention="flash" num_inference_steps=30 num_frames=81 width=832 height=480 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=3.0 enable_profiler=True run_name=wan-i2v-inference-testing-480p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 |
501 | | - ``` |
502 | | - |
503 | | - ## Wan2.2 |
504 | | - |
505 | | - Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). |
506 | | - |
507 | | - ### Text2Vid |
| 493 | + The following command runs Wan2.1 T2V generation: |
508 | 494 |
|
509 | 495 | ```bash |
510 | | - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ |
511 | | - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 |
| 496 | + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \ |
| 497 | + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \ |
| 498 | + --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \ |
| 499 | + --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ |
| 500 | + --xla_tpu_overlap_compute_collective_tc=true \ |
| 501 | + --xla_enable_async_all_reduce=true" \ |
| 502 | + HF_HUB_ENABLE_HF_TRANSFER=1 \ |
| 503 | + python src/maxdiffusion/generate_wan.py \ |
| 504 | + src/maxdiffusion/configs/base_wan_14b.yml \ |
| 505 | + attention="flash" \ |
| 506 | + num_inference_steps=50 \ |
| 507 | + num_frames=81 \ |
| 508 | + width=1280 \ |
| 509 | + height=720 \ |
| 510 | + jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ \ |
| 511 | + per_device_batch_size=.125 \ |
| 512 | + ici_data_parallelism=2 \ |
| 513 | + ici_context_parallelism=2 \ |
| 514 | + flow_shift=5.0 \ |
| 515 | + enable_profiler=True \ |
| 516 | + run_name=wan-inference-testing-720p \ |
| 517 | + output_dir=gs://jfacevedo-maxdiffusion \
| 518 | + fps=16 \ |
| 519 | + flash_min_seq_length=0 \ |
| 520 | + flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' \ |
| 521 | + seed=118445 |
512 | 522 | ``` |
513 | 523 |
|
514 | | - ### Img2Vid |
| 524 | + To run other Wan model inference pipelines, change the config file in the command above: |
515 | 525 |
|
516 | | - ```bash |
517 | | - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ |
518 | | - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_i2v_27b.yml attention="flash" num_inference_steps=30 num_frames=81 width=832 height=480 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=3.0 enable_profiler=True run_name=wan-i2v-inference-testing-480p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 |
519 | | - ``` |
| 526 | + * For Wan2.1 I2V, use `base_wan_i2v_14b.yml`. |
| 527 | + * For Wan2.2 T2V, use `base_wan_27b.yml`. |
| 528 | + * For Wan2.2 I2V, use `base_wan_i2v_27b.yml`. |
520 | 529 |
|
521 | 530 | ## Flux |
522 | 531 |
|
@@ -568,6 +577,33 @@ To generate images, run the following command: |
568 | 577 | ```bash |
569 | 578 | NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu |
570 | 579 | ``` |
| 580 | + ## Wan LoRA |
| 581 | +
|
| 582 | + Disclaimer: not all LoRA formats have been tested. ComfyUI and AI Toolkit formats are currently supported. If a specific LoRA fails to load, please let us know. |
| 583 | +
|
| 584 | + First create a copy of the relevant config file, e.g. `src/maxdiffusion/configs/base_wan_{*}.yml`, and update the prompt and LoRA details in the copy, making sure to set `enable_lora: True`. Then run the following command, passing the path to your config copy: |
| 585 | +
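For reference, the LoRA section of the copied config might look like the sketch below. This is an assumption modeled on the Flux LoRA `lora_config` used elsewhere in this repository; only `enable_lora` is confirmed by this README, so verify the exact key names against your copy of `base_wan_{*}.yml`:

```yaml
# Hypothetical sketch of the LoRA settings in a copied Wan config.
# Key names below (other than enable_lora) are assumptions modeled on
# the Flux LoRA config; check your config copy for the exact schema.
enable_lora: True
lora_config:
  lora_model_name_or_path: ["username/my-wan-lora"]   # local path or HF repo
  weight_name: ["my_wan_lora.safetensors"]
  adapter_name: ["my_adapter"]
  scale: [0.8]
  from_pt: [True]
```

If the fields are lists as sketched here, loading multiple LoRAs would mean appending one entry per LoRA to each list.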
|
| 586 | + ```bash |
| 587 | + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \ |
| 588 | + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \ |
| 589 | + --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \ |
| 590 | + --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ |
| 591 | + --xla_tpu_overlap_compute_collective_tc=true \ |
| 592 | + --xla_enable_async_all_reduce=true" \ |
| 593 | + HF_HUB_ENABLE_HF_TRANSFER=1 \ |
| 594 | + python src/maxdiffusion/generate_wan.py \ |
| 595 | + src/maxdiffusion/configs/base_wan_i2v_14b.yml \
| 596 | + jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ \ |
| 597 | + per_device_batch_size=.125 \ |
| 598 | + ici_data_parallelism=2 \ |
| 599 | + ici_context_parallelism=2 \ |
| 600 | + run_name=wan-lora-inference-testing-720p \ |
| 601 | + output_dir=gs://jfacevedo-maxdiffusion \
| 602 | + seed=118445 \ |
| 603 | + enable_lora=True
| 604 | + ``` |
| 605 | +
|
| 606 | + Loading multiple LoRAs is supported as well. |
571 | 607 |
|
572 | 608 | ## Flux LoRA |
573 | 609 |
|
|