Skip to content

Commit 7244a4b

Browse files
authored
[CI] Add LoRA inference tests (#546)
1 parent 7e5ebb4 commit 7244a4b

44 files changed

Lines changed: 533 additions & 515 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.buildkite/pipeline.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,21 @@ steps:
6565
- TEST_TYPE=ssim
6666
agents:
6767
queue: "default"
68+
- path:
69+
- "fastvideo/v1/tests/lora/**"
70+
- "fastvideo/v1/models/loader/**"
71+
- "fastvideo/v1/tests/transformers/**"
72+
- "fastvideo/v1/pipelines/**"
73+
- "fastvideo/v1/layers/lora/**"
74+
- "pyproject.toml"
75+
- "docker/Dockerfile.python3.12"
76+
config:
77+
command: "timeout 15m .buildkite/scripts/pr_test.sh"
78+
label: "LoRA Inference Tests"
79+
env:
80+
- TEST_TYPE=inference_lora
81+
agents:
82+
queue: "default"
6883
- path:
6984
- "fastvideo/v1/**"
7085
- "pyproject.toml"

.buildkite/scripts/pr_test.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ case "$TEST_TYPE" in
9797
log "Running precision VSA tests..."
9898
MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_precision_tests_VSA"
9999
;;
100+
"inference_lora")
101+
log "Running LoRA tests..."
102+
MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_inference_lora_tests"
103+
;;
100104
*)
101105
log "Error: Unknown test type: $TEST_TYPE"
102106
exit 1

.github/workflows/matchers/mypy.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
]
1414
}
1515
]
16-
}
16+
}

.github/workflows/pr-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,4 +372,4 @@ jobs:
372372
JOB_IDS: '["encoder-test", "vae-test", "transformer-test", "ssim-test-py3.10", "ssim-test-py3.11", "ssim-test-py3.12", "training-test", "training-test-VSA", "inference-test-STA", "precision-test-STA", "precision-test-VSA"]'
373373
RUNPOD_API_KEY: ${{ secrets.RUNPOD_API_KEY }}
374374
GITHUB_RUN_ID: ${{ github.run_id }}
375-
run: python .github/scripts/runpod_cleanup.py
375+
run: python .github/scripts/runpod_cleanup.py

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ repos:
6060
rev: v1.15.0
6161
hooks:
6262
- id: mypy
63-
args: [--python-version, '3.10', --follow-imports, "skip", ]
63+
args: [--python-version, '3.10', --follow-imports, "skip" ]
6464
additional_dependencies: [types-cachetools, types-setuptools, types-PyYAML, types-requests]
6565
- repo: local
6666
hooks:
@@ -69,7 +69,7 @@ repos:
6969
entry: bash
7070
args:
7171
- -c
72-
- 'git ls-files | grep -v "^fastvideo/v1/tests/ssim/" | grep " " && echo "Filenames should not contain spaces!" && exit 1 || exit 0'
72+
- 'git ls-files | grep -v "^fastvideo/v1/tests/ssim/" | grep -v "^fastvideo/v1/tests/inference/lora/L40S_reference_videos/" | grep " " && echo "Filenames should not contain spaces!" && exit 1 || exit 0'
7373
language: system
7474
always_run: true
7575
pass_filenames: false

examples/inference/lora/wan_lora_inference.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ def main():
66
# Initialize VideoGenerator with the Wan model
77
generator = VideoGenerator.from_pretrained(
88
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
9-
num_gpus=2,
9+
num_gpus=1,
1010
lora_path="benjamin-paine/steamboat-willie-1.3b",
1111
lora_nickname="steamboat"
1212
)
@@ -16,6 +16,7 @@ def main():
1616
"num_frames": 81,
1717
"guidance_scale": 5.0,
1818
"num_inference_steps": 32,
19+
"seed": 42,
1920
}
2021
# Generate video with LoRA style
2122
prompt = "steamboat willie style, golden era animation, close-up of a short fluffy monster kneeling beside a melting red candle. the mood is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image."
@@ -29,8 +30,17 @@ def main():
2930
negative_prompt=negative_prompt,
3031
**kwargs
3132
)
32-
33-
generator.set_lora_adapter(lora_nickname="flat_color", lora_path="motimalu/wan-flat-color-1.3b-v2")
33+
del generator
34+
35+
# Until the FSDP resharding bug is fixed, multi-LoRA requires reloading the model
36+
# see https://github.com/pytorch/pytorch/issues/157209
37+
generator = VideoGenerator.from_pretrained(
38+
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
39+
num_gpus=1,
40+
lora_path="motimalu/wan-flat-color-1.3b-v2",
41+
lora_nickname="flat_color"
42+
)
43+
# generator.set_lora_adapter(lora_nickname="flat_color", lora_path="motimalu/wan-flat-color-1.3b-v2")
3444
prompt = "flat color, no lineart, blending, negative space, artist:[john kafka|ponsuke kaikai|hara id 21|yoneyama mai|fuzichoco], 1girl, sakura miko, pink hair, cowboy shot, white shirt, floral print, off shoulder, outdoors, cherry blossom, tree shade, wariza, looking up, falling petals, half-closed eyes, white sky, clouds, live2d animation, upper body, high quality cinematic video of a woman sitting under a sakura tree. Dreamy and lonely, the camera close-ups on the face of the woman as she turns towards the viewer. The Camera is steady, This is a cowboy shot. The animation is smooth and fluid."
3545
negative_prompt = "bad quality video,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
3646
video = generator.generate_video(

fastvideo/utils/collect_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
DEFAULT_CONDA_PATTERNS = {
6363
"torch",
6464
"numpy",
65+
"mypy",
6566
"cudatoolkit",
6667
"soumith",
6768
"mkl",
@@ -80,7 +81,6 @@
8081
DEFAULT_PIP_PATTERNS = {
8182
"torch",
8283
"numpy",
83-
"mypy",
8484
"flake8",
8585
"triton",
8686
"optree",

fastvideo/v1/configs/fasthunyuan_t2v.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"use_cpu_offload": false,
55
"disable_autocast": false,
66
"precision": "bf16",
7-
"vae_precision": "fp16",
7+
"vae_precision": "fp32",
88
"vae_tiling": true,
99
"vae_sp": true,
1010
"vae_config": {

fastvideo/v1/configs/models/dits/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
class DiTArchConfig(ArchConfig):
1212
_fsdp_shard_conditions: list = field(default_factory=list)
1313
_compile_conditions: list = field(default_factory=list)
14-
_param_names_mapping: dict = field(default_factory=dict)
15-
_reverse_param_names_mapping: dict = field(default_factory=dict)
16-
_lora_param_names_mapping: dict = field(default_factory=dict)
14+
param_names_mapping: dict = field(default_factory=dict)
15+
reverse_param_names_mapping: dict = field(default_factory=dict)
16+
lora_param_names_mapping: dict = field(default_factory=dict)
1717
_supported_attention_backends: tuple[AttentionBackendEnum, ...] = (
1818
AttentionBackendEnum.SLIDING_TILE_ATTN, AttentionBackendEnum.SAGE_ATTN,
1919
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA,

fastvideo/v1/configs/models/dits/hunyuanvideo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class HunyuanVideoArchConfig(DiTArchConfig):
3131
_compile_conditions: list = field(
3232
default_factory=lambda: [is_double_block, is_single_block, is_txt_in])
3333

34-
_param_names_mapping: dict = field(
34+
param_names_mapping: dict = field(
3535
default_factory=lambda: {
3636
# 1. context_embedder.time_text_embed submodules (specific rules, applied first):
3737
r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_1\.(.*)$":
@@ -146,8 +146,8 @@ class HunyuanVideoArchConfig(DiTArchConfig):
146146
r"final_layer.linear.\1",
147147
})
148148

149-
# Reverse mapping for saving checkpoints: training -> diffusers
150-
_reverse_param_names_mapping: dict = field(default_factory=lambda: {})
149+
# Reverse mapping for saving checkpoints: custom -> hf
150+
reverse_param_names_mapping: dict = field(default_factory=lambda: {})
151151

152152
patch_size: int = 2
153153
patch_size_t: int = 1

0 commit comments

Comments
 (0)