Skip to content

Commit 02534ac

Browse files
[bugfix] [4/n] Fix profile loading bugs: shared mutable state, error swallowing, version sort
- deepcopy profile defaults to prevent shared mutable list references between SamplingParam instances (e.g. ltx2_stg_blocks_video, sigmas) - narrow bare except Exception to (ValueError, RuntimeError) in _from_profile() so real errors aren't silently swallowed - use numeric-aware version sorting in get_profile() to handle multi-segment versions correctly (e.g. "2" > "10" string sort bug) - remove dead get_sampling_param_cls_for_name fallback from from_pretrained() since all sampling_param_cls are now None - update test_ltx2_registry.py for profile-based behavior
1 parent f478a38 commit 02534ac

3 files changed

Lines changed: 48 additions & 62 deletions

File tree

fastvideo/api/profiles.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def get_profile(
124124
f"{model_family!r}; "
125125
f"registered: {_format_registered(model_family)}",
126126
)
127-
candidates.sort(key=lambda p: p.version)
127+
candidates.sort(key=lambda p: [int(x) if x.isdigit() else x for x in p.version.split(".")])
128128
return candidates[-1]
129129

130130

fastvideo/api/sampling_param.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
2+
import copy
23
from dataclasses import dataclass, field
34
from typing import Any
45

@@ -130,13 +131,8 @@ def from_pretrained(cls, model_path: str) -> "SamplingParam":
130131
if sampling_param is not None:
131132
return sampling_param
132133

133-
from fastvideo.registry import get_sampling_param_cls_for_name
134-
sampling_cls = get_sampling_param_cls_for_name(model_path)
135-
if sampling_cls is not None:
136-
return sampling_cls()
137-
138134
logger.warning(
139-
"Couldn't find an optimal sampling param for %s."
135+
"Couldn't find a profile for %s."
140136
" Using the default sampling param.",
141137
model_path,
142138
)
@@ -160,7 +156,7 @@ def _from_profile(
160156

161157
try:
162158
profile_name = get_default_profile(model_path)
163-
except Exception:
159+
except (ValueError, RuntimeError):
164160
return None
165161
if profile_name is None:
166162
return None
@@ -175,7 +171,7 @@ def _from_profile(
175171
sp = cls()
176172
for key, value in profile.defaults.items():
177173
if hasattr(sp, key):
178-
setattr(sp, key, value)
174+
setattr(sp, key, copy.deepcopy(value))
179175
sp.__post_init__()
180176
return sp
181177

Lines changed: 43 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,18 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import json
3-
import importlib
4-
import sys
5-
import types
63
from pathlib import Path
74

85
import pytest
96

107
pytest.importorskip("torchvision")
118

129

13-
def _bootstrap_fastvideo_namespace() -> None:
14-
"""Avoid importing fastvideo/__init__.py during local registry tests."""
15-
if "fastvideo" in sys.modules:
16-
return
17-
18-
repo_root = Path(__file__).resolve().parents[2]
19-
package_dir = repo_root / "fastvideo"
20-
fastvideo_pkg = types.ModuleType("fastvideo")
21-
fastvideo_pkg.__path__ = [str(package_dir)] # type: ignore[attr-defined]
22-
fastvideo_pkg.__file__ = str(package_dir / "__init__.py")
23-
sys.modules["fastvideo"] = fastvideo_pkg
24-
25-
26-
def _get_registry_test_symbols() -> tuple[type, type, type, object, object]:
27-
_bootstrap_fastvideo_namespace()
28-
29-
pipeline_module = importlib.import_module("fastvideo.configs.pipelines.ltx2")
30-
sample_module = importlib.import_module("fastvideo.configs.sample.ltx2")
31-
registry_module = importlib.import_module("fastvideo.registry")
32-
33-
return (
34-
pipeline_module.LTX2T2VConfig,
35-
sample_module.LTX2BaseSamplingParam,
36-
sample_module.LTX2DistilledSamplingParam,
37-
registry_module.get_pipeline_config_cls_from_name,
38-
registry_module.get_sampling_param_cls_for_name,
39-
)
10+
from fastvideo.api.sampling_param import SamplingParam
11+
from fastvideo.registry import (
12+
get_pipeline_config_cls_from_name,
13+
get_sampling_param_cls_for_name,
14+
)
15+
from fastvideo.configs.pipelines.ltx2 import LTX2T2VConfig
4016

4117

4218
@pytest.mark.parametrize(
@@ -47,12 +23,23 @@ def _get_registry_test_symbols() -> tuple[type, type, type, object, object]:
4723
("FastVideo/LTX2-Distilled-Diffusers", "distilled"),
4824
],
4925
)
50-
def test_ltx2_sampling_registry_exact_ids(model_id: str,
51-
expected_variant: str) -> None:
52-
_, base_cls, distilled_cls, _, get_sampling_param_cls_for_name = _get_registry_test_symbols()
53-
expected_cls = base_cls if expected_variant == "base" else distilled_cls
54-
resolved_cls = get_sampling_param_cls_for_name(model_id)
55-
assert resolved_cls is expected_cls
26+
def test_ltx2_sampling_registry_exact_ids(
27+
model_id: str,
28+
expected_variant: str,
29+
) -> None:
30+
# All sampling_param_cls are None after profile migration.
31+
assert get_sampling_param_cls_for_name(model_id) is None
32+
33+
# Profile-based path should return correct defaults.
34+
sp = SamplingParam.from_pretrained(model_id)
35+
if expected_variant == "base":
36+
assert sp.num_inference_steps == 40
37+
assert sp.height == 512
38+
assert sp.width == 768
39+
else:
40+
assert sp.num_inference_steps == 8
41+
assert sp.height == 1024
42+
assert sp.width == 1536
5643

5744

5845
@pytest.mark.parametrize(
@@ -64,45 +51,48 @@ def test_ltx2_sampling_registry_exact_ids(model_id: str,
6451
],
6552
)
6653
def test_ltx2_pipeline_registry_exact_ids(model_id: str) -> None:
67-
pipeline_cls, _, _, get_pipeline_config_cls_from_name, _ = _get_registry_test_symbols()
6854
resolved_cls = get_pipeline_config_cls_from_name(model_id)
69-
assert resolved_cls is pipeline_cls
55+
assert resolved_cls is LTX2T2VConfig
7056

7157

72-
def _write_minimal_diffusers_repo(model_dir: Path, class_name: str) -> None:
58+
def _write_minimal_diffusers_repo(
59+
model_dir: Path, class_name: str
60+
) -> None:
7361
model_dir.mkdir(parents=True, exist_ok=True)
7462
(model_dir / "transformer").mkdir(exist_ok=True)
7563
(model_dir / "vae").mkdir(exist_ok=True)
76-
with (model_dir / "model_index.json").open("w", encoding="utf-8") as f:
64+
with (model_dir / "model_index.json").open(
65+
"w", encoding="utf-8"
66+
) as f:
7767
json.dump(
7868
{
7969
"_class_name": class_name,
8070
"_diffusers_version": "0.33.0.dev0",
81-
"transformer": ["diffusers", "LTX2Transformer3DModel"],
71+
"transformer": [
72+
"diffusers",
73+
"LTX2Transformer3DModel",
74+
],
8275
"vae": ["diffusers", "CausalVideoAutoencoder"],
8376
},
8477
f,
8578
)
8679

8780

8881
def test_ltx2_ambiguous_local_path_has_no_sampling_fallback(
89-
tmp_path: Path) -> None:
90-
_, _, _, _, get_sampling_param_cls_for_name = _get_registry_test_symbols()
91-
# Simulate a user-local converted LTX2 path that might be either base or
92-
# distilled. Registry must not assume a variant for local converted paths.
82+
tmp_path: Path,
83+
) -> None:
9384
model_dir = tmp_path / "converted" / "ltx2_diffusers"
9485
_write_minimal_diffusers_repo(model_dir, "LTX2Pipeline")
95-
9686
resolved_cls = get_sampling_param_cls_for_name(str(model_dir))
9787
assert resolved_cls is None
9888

9989

100-
def test_ltx2_ambiguous_local_path_has_no_pipeline_mapping(
101-
tmp_path: Path) -> None:
102-
_, _, _, get_pipeline_config_cls_from_name, _ = _get_registry_test_symbols()
90+
def test_ltx2_ambiguous_local_path_resolves_via_detector(
91+
tmp_path: Path,
92+
) -> None:
93+
# Both base and distilled share LTX2T2VConfig, so the
94+
# "ltx2" detector correctly resolves ambiguous local paths.
10395
model_dir = tmp_path / "converted" / "ltx2_diffusers"
10496
_write_minimal_diffusers_repo(model_dir, "LTX2Pipeline")
105-
106-
with pytest.raises(ValueError,
107-
match="No match found for pipeline .*check the pipeline name or path"):
108-
get_pipeline_config_cls_from_name(str(model_dir))
97+
resolved_cls = get_pipeline_config_cls_from_name(str(model_dir))
98+
assert resolved_cls is LTX2T2VConfig

0 commit comments

Comments
 (0)