fix(diffusion): honor local checkpoint dirs and task-derived model id in WAN inference#4408
fix(diffusion): honor local checkpoint dirs and task-derived model id in WAN inference#4408huvunvidia wants to merge 6 commits into
Conversation
… in WAN inference WAN inference failed on air-gapped nodes in two ways: - easydict was imported unconditionally in inference_wan.py but absent from the container; add it to scripts/install_diffusion_deps.sh and document the inference prerequisites in the README. - FlowInferencePipeline ignored the t5_checkpoint_dir / vae_checkpoint_dir params and always resolved the hardcoded 14B hub id, so offline runs hit the HF hub and crashed. Resolve text encoder, tokenizer, VAE, and scheduler from the provided local dirs (falling back to model_id), and derive model_id from --task instead of hardcoding 14B. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Huy Vu <huvu@nvidia.com>
ReviewClean bug-fix PR -- both fixes are correct and the README additions are helpful. Two small items: 1. Docstring inaccuracy (pre-existing, good to fix here) flow_inference_pipeline.py lines 97-100: the docstrings for t5_checkpoint_dir and vae_checkpoint_dir say 'falls back to checkpoint_dir if None' but the actual fallback (lines 125-126) is model_id. Now that this PR wires up these parameters, worth updating the docstrings to say 'falls back to model_id if None'. 2. Test helper missing scheduler_source The _make_pipeline helper in tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_inference_pipeline.py (line 52) sets model_id but not the new scheduler_source attribute. The existing generate tests still pass because FlowMatchEulerDiscreteScheduler.from_pretrained is monkeypatched to ignore its args, but any future test that exercises the scheduler path without that monkeypatch will hit AttributeError. Suggest adding scheduler_source next to the existing model_id default in the helper. Suggested test cases No perf tests impacted. Generated with Claude Code |
|
/ok to test 353214e |
…ion in WAN inference The generate() loop now reads self.scheduler_source; add it to the test helper that bypasses __init__ so the existing TestGenerate cases pass, and add TestComponentDirResolution to verify __init__ resolves the text encoder, tokenizer, VAE, and scheduler from --t5_checkpoint_dir / --vae_checkpoint_dir with fallback to model_id. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Huy Vu <huvu@nvidia.com>
…ion tests Trim the inference README to the easydict prerequisite and the basic example. Remove the added TestComponentDirResolution cases; keep only the scheduler_source default in the test helper so the existing TestGenerate cases pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Huy Vu <huvu@nvidia.com>
|
Fixed the Note: the /ok to test 0ce84dc |
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Huy Vu <huvu@nvidia.com>
|
/ok to test 546f0ea |
|
Confirmed manually: with this PR, WAN /ok to test 546f0ea |
|
Merged latest /ok to test d935ded |
|
Merged latest /ok to test 3469f8f |
|
/ok to test 3469f8f |
What
Fixes two independent failures when running
examples/models/wan/inference_wan.py --task t2v-1.3Bon an air-gapped node (no outbound internet).Bug 1 — Missing
easydictdependencyinference_wan.pyimportseasydictunconditionally at module load, but the package is absent from the container (it was a transitive dep dropped when the CVE-bearing codecsav/imageio/imageio-ffmpegwere removed). The script crashes at startup withModuleNotFoundError: No module named 'easydict'.Fix: add
easydicttoscripts/install_diffusion_deps.sh(alongside the codecs) and document inference prerequisites in the WAN README.Bug 2 — Hardcoded 14B model id / ignored
--t5_checkpoint_dir&--vae_checkpoint_dirFlowInferencePipelineacceptedt5_checkpoint_dir/vae_checkpoint_dirbut never used them — allfrom_pretrainedcalls resolved the hardcodedWan-AI/Wan2.1-T2V-14B-Diffusershub id. Offline runs hit the HF hub and crashed withLocalEntryNotFoundError, even with the correct local files supplied.inference_wan.pyalso hardcoded the 14B id regardless of--task.Fix:
flow_inference_pipeline.py: resolve the text encoder, tokenizer, VAE, and scheduler from the provided local dirs, falling back tomodel_id. (The schedulerfrom_pretrainedingenerate()also usedmodel_idand would have failed offline — it is now covered too.)inference_wan.py: derivemodel_idfrom--taskvia aTASK_TO_MODEL_IDmap instead of hardcoding 14B.--t5_checkpoint_dir/--vae_checkpoint_dirand add an offline-inference example.Notes
easydictis intentionally not added topyproject.tomlas a required dependency, consistent with how the CVE-bearing codecs are handled (example/test-time install only). Happy to make it an optional extra in a separate dependency PR if preferred.Testing
ruff checkandruff format --checkpass on the changed Python files.🤖 Generated with Claude Code