11# SPDX-License-Identifier: Apache-2.0
22import json
3- import importlib
4- import sys
5- import types
63from pathlib import Path
74
85import pytest
96
107pytest .importorskip ("torchvision" )
118
129
13- def _bootstrap_fastvideo_namespace () -> None :
14- """Avoid importing fastvideo/__init__.py during local registry tests."""
15- if "fastvideo" in sys .modules :
16- return
17-
18- repo_root = Path (__file__ ).resolve ().parents [2 ]
19- package_dir = repo_root / "fastvideo"
20- fastvideo_pkg = types .ModuleType ("fastvideo" )
21- fastvideo_pkg .__path__ = [str (package_dir )] # type: ignore[attr-defined]
22- fastvideo_pkg .__file__ = str (package_dir / "__init__.py" )
23- sys .modules ["fastvideo" ] = fastvideo_pkg
24-
25-
26- def _get_registry_test_symbols () -> tuple [type , type , type , object , object ]:
27- _bootstrap_fastvideo_namespace ()
28-
29- pipeline_module = importlib .import_module ("fastvideo.configs.pipelines.ltx2" )
30- sample_module = importlib .import_module ("fastvideo.configs.sample.ltx2" )
31- registry_module = importlib .import_module ("fastvideo.registry" )
32-
33- return (
34- pipeline_module .LTX2T2VConfig ,
35- sample_module .LTX2BaseSamplingParam ,
36- sample_module .LTX2DistilledSamplingParam ,
37- registry_module .get_pipeline_config_cls_from_name ,
38- registry_module .get_sampling_param_cls_for_name ,
39- )
10+ from fastvideo .api .sampling_param import SamplingParam
11+ from fastvideo .registry import (
12+ get_pipeline_config_cls_from_name ,
13+ get_sampling_param_cls_for_name ,
14+ )
15+ from fastvideo .configs .pipelines .ltx2 import LTX2T2VConfig
4016
4117
4218@pytest .mark .parametrize (
@@ -47,12 +23,23 @@ def _get_registry_test_symbols() -> tuple[type, type, type, object, object]:
4723 ("FastVideo/LTX2-Distilled-Diffusers" , "distilled" ),
4824 ],
4925)
50- def test_ltx2_sampling_registry_exact_ids (model_id : str ,
51- expected_variant : str ) -> None :
52- _ , base_cls , distilled_cls , _ , get_sampling_param_cls_for_name = _get_registry_test_symbols ()
53- expected_cls = base_cls if expected_variant == "base" else distilled_cls
54- resolved_cls = get_sampling_param_cls_for_name (model_id )
55- assert resolved_cls is expected_cls
26+ def test_ltx2_sampling_registry_exact_ids (
27+ model_id : str ,
28+ expected_variant : str ,
29+ ) -> None :
30+ # All sampling_param_cls are None after profile migration.
31+ assert get_sampling_param_cls_for_name (model_id ) is None
32+
33+ # Profile-based path should return correct defaults.
34+ sp = SamplingParam .from_pretrained (model_id )
35+ if expected_variant == "base" :
36+ assert sp .num_inference_steps == 40
37+ assert sp .height == 512
38+ assert sp .width == 768
39+ else :
40+ assert sp .num_inference_steps == 8
41+ assert sp .height == 1024
42+ assert sp .width == 1536
5643
5744
5845@pytest .mark .parametrize (
@@ -64,45 +51,48 @@ def test_ltx2_sampling_registry_exact_ids(model_id: str,
6451 ],
6552)
6653def test_ltx2_pipeline_registry_exact_ids (model_id : str ) -> None :
67- pipeline_cls , _ , _ , get_pipeline_config_cls_from_name , _ = _get_registry_test_symbols ()
6854 resolved_cls = get_pipeline_config_cls_from_name (model_id )
69- assert resolved_cls is pipeline_cls
55+ assert resolved_cls is LTX2T2VConfig
7056
7157
72- def _write_minimal_diffusers_repo (model_dir : Path , class_name : str ) -> None :
58+ def _write_minimal_diffusers_repo (
59+ model_dir : Path , class_name : str
60+ ) -> None :
7361 model_dir .mkdir (parents = True , exist_ok = True )
7462 (model_dir / "transformer" ).mkdir (exist_ok = True )
7563 (model_dir / "vae" ).mkdir (exist_ok = True )
76- with (model_dir / "model_index.json" ).open ("w" , encoding = "utf-8" ) as f :
64+ with (model_dir / "model_index.json" ).open (
65+ "w" , encoding = "utf-8"
66+ ) as f :
7767 json .dump (
7868 {
7969 "_class_name" : class_name ,
8070 "_diffusers_version" : "0.33.0.dev0" ,
81- "transformer" : ["diffusers" , "LTX2Transformer3DModel" ],
71+ "transformer" : [
72+ "diffusers" ,
73+ "LTX2Transformer3DModel" ,
74+ ],
8275 "vae" : ["diffusers" , "CausalVideoAutoencoder" ],
8376 },
8477 f ,
8578 )
8679
8780
8881def test_ltx2_ambiguous_local_path_has_no_sampling_fallback (
89- tmp_path : Path ) -> None :
90- _ , _ , _ , _ , get_sampling_param_cls_for_name = _get_registry_test_symbols ()
91- # Simulate a user-local converted LTX2 path that might be either base or
92- # distilled. Registry must not assume a variant for local converted paths.
82+ tmp_path : Path ,
83+ ) -> None :
9384 model_dir = tmp_path / "converted" / "ltx2_diffusers"
9485 _write_minimal_diffusers_repo (model_dir , "LTX2Pipeline" )
95-
9686 resolved_cls = get_sampling_param_cls_for_name (str (model_dir ))
9787 assert resolved_cls is None
9888
9989
100- def test_ltx2_ambiguous_local_path_has_no_pipeline_mapping (
101- tmp_path : Path ) -> None :
102- _ , _ , _ , get_pipeline_config_cls_from_name , _ = _get_registry_test_symbols ()
90+ def test_ltx2_ambiguous_local_path_resolves_via_detector (
91+ tmp_path : Path ,
92+ ) -> None :
93+ # Both base and distilled share LTX2T2VConfig, so the
94+ # "ltx2" detector correctly resolves ambiguous local paths.
10395 model_dir = tmp_path / "converted" / "ltx2_diffusers"
10496 _write_minimal_diffusers_repo (model_dir , "LTX2Pipeline" )
105-
106- with pytest .raises (ValueError ,
107- match = "No match found for pipeline .*check the pipeline name or path" ):
108- get_pipeline_config_cls_from_name (str (model_dir ))
97+ resolved_cls = get_pipeline_config_cls_from_name (str (model_dir ))
98+ assert resolved_cls is LTX2T2VConfig
0 commit comments