Skip to content

Commit 68e2841

Browse files
Davids048SolitaryThinkerRandNMR73XOR-opjzhang38
committed
[feat] Add LTX2 refine and upsampler support
Apply Dreamverse monorepo changes for stack slice 12/13 from the source branch. Source-Branch: will/dreamverse-monorepo Source-SHA: 03d3e61 Dreamverse-Stack: 12/13 Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com> Co-authored-by: Matthew Noto <99706358+RandNMR73@users.noreply.github.com> Co-authored-by: XOR-op <17672363+XOR-op@users.noreply.github.com> Co-authored-by: Zhang Peiyuan <42993249+jzhang38@users.noreply.github.com>
1 parent a110a24 commit 68e2841

18 files changed

Lines changed: 2778 additions & 207 deletions

fastvideo/models/dits/ltx2.py

Lines changed: 234 additions & 49 deletions
Large diffs are not rendered by default.

fastvideo/models/encoders/gemma.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,10 @@ def named_parameters(self, prefix: str = "", recurse: bool = True):
361361
continue
362362
yield name, param
363363

364+
def prepare_for_compile(self) -> None:
365+
# Load Gemma outside Dynamo so torch.compile does not trace HF file-system checks.
366+
_ = self.gemma_model
367+
364368
@property
365369
def gemma_model(self) -> Gemma3ForConditionalGeneration:
366370
if self._gemma_model is None:
@@ -517,18 +521,21 @@ def forward(
517521
attention_mask = torch.ones_like(input_ids)
518522

519523
model = self.gemma_model
520-
orig_device = model.device
521-
model.to(device=get_local_torch_device())
522-
# input_ids = input_ids.to(device=model.device)
523-
# attention_mask = attention_mask.to(device=model.device)
524+
target_device = get_local_torch_device()
525+
# Do not invoke model.to() inside the compiled forward path.
526+
# _parse_to returns a non-Tensor torch.device, which Dynamo cannot
527+
# trace under fullgraph=True. The model is already moved to device
528+
# when first loaded (see gemma_model property + prepare_for_compile),
529+
# so this guard is a runtime no-op and Dynamo can DCE it.
530+
if model.device != target_device:
531+
model.to(device=target_device)
524532
outputs = model(
525533
input_ids=input_ids,
526534
attention_mask=attention_mask,
527535
output_hidden_states=True,
528536
return_dict=True,
529537
)
530-
model.to(device=orig_device)
531-
538+
532539
encoded_inputs = self._run_feature_extractor(
533540
outputs.hidden_states,
534541
attention_mask,

fastvideo/models/loader/component_loader.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ def for_module_type(
9999
# NumberConditioners; not a pure text encoder, so it gets
100100
# its own loader.
101101
"conditioner": (ConditionerLoader, "fastvideo"),
102+
# LTX-2 spatial / temporal upsamplers — share the
103+
# UpsamplerLoader path with the upsampler/upsampler_2 keys
104+
# so the SR pipeline picks up real weights instead of the
105+
# generic config-only loader.
106+
"spatial_upsampler": (UpsamplerLoader, "diffusers"),
107+
"temporal_upsampler": (UpsamplerLoader, "diffusers"),
102108
}
103109

104110
if module_type in module_loaders:
@@ -1058,7 +1064,7 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs):
10581064

10591065

10601066
class UpsamplerLoader(ComponentLoader):
1061-
"""Loader for upsamplers."""
1067+
"""Loader for upsamplers (incl. LTX-2 spatial/temporal upsamplers)."""
10621068

10631069
def load(self, model_path: str, fastvideo_args: FastVideoArgs):
10641070
"""Load the upsampler based on the model path, and inference args."""
@@ -1068,36 +1074,65 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs):
10681074
if class_name is None:
10691075
raise ValueError(
10701076
"Model config does not contain a _class_name attribute. "
1071-
"Only diffusers format is supported."
1072-
)
1073-
1074-
try:
1075-
upsampler_cfg = deepcopy(fastvideo_args.pipeline_config.upsampler_config[0])
1076-
upsampler_cfg.update_model_config(config_dict)
1077-
except Exception as e:
1078-
upsampler_cfg = deepcopy(fastvideo_args.pipeline_config.upsampler_config[1])
1079-
upsampler_cfg.update_model_config(config_dict)
1077+
"Only diffusers format is supported.")
1078+
1079+
# The base PipelineConfig declares ``upsampler_config`` as a
1080+
# single ``UpsamplerConfig`` instance, but Hunyuan15 narrows it
1081+
# to a tuple of two configs (one per SR target). We only treat
1082+
# the attribute as a multi-config when it actually is one;
1083+
# otherwise the LTX-2 branch below handles the single-class
1084+
# path that takes the diffusers config dict directly.
1085+
upsampler_config_attr = getattr(fastvideo_args.pipeline_config,
1086+
"upsampler_config", None)
1087+
if isinstance(upsampler_config_attr, list | tuple):
1088+
try:
1089+
upsampler_cfg = deepcopy(upsampler_config_attr[0])
1090+
upsampler_cfg.update_model_config(config_dict)
1091+
except Exception:
1092+
upsampler_cfg = deepcopy(upsampler_config_attr[1])
1093+
upsampler_cfg.update_model_config(config_dict)
1094+
elif class_name == "LTX2LatentUpsampler":
1095+
# LTX-2 pipeline_config does not declare upsampler_config; the
1096+
# `LTX2LatentUpsampler` wrapper takes the raw diffusers config
1097+
# dict directly via LatentUpsamplerConfigurator.
1098+
upsampler_cfg = deepcopy(config_dict)
1099+
else:
1100+
raise AttributeError(
1101+
"pipeline_config.upsampler_config is missing; cannot build "
1102+
f"upsampler config for class {class_name}")
10801103

10811104
model_cls, _ = ModelRegistry.resolve_model_cls(class_name)
10821105
model = model_cls(upsampler_cfg)
10831106

10841107
target_device = get_local_torch_device()
1085-
model = model.to(target_device, dtype=PRECISION_TO_TYPE[fastvideo_args.pipeline_config.upsampler_precision])
1108+
upsampler_precision = getattr(fastvideo_args.pipeline_config,
1109+
"upsampler_precision", "bf16")
1110+
model = model.to(target_device,
1111+
dtype=PRECISION_TO_TYPE[upsampler_precision])
10861112

1087-
# Find all safetensors files
10881113
safetensors_list = glob.glob(
10891114
os.path.join(str(model_path), "*.safetensors"))
10901115
if not safetensors_list:
10911116
raise ValueError(f"No safetensors files found in {model_path}")
1092-
1117+
10931118
if len(safetensors_list) == 1:
10941119
loaded = safetensors_load_file(safetensors_list[0])
10951120
else:
10961121
loaded = {}
10971122
for sf_file in safetensors_list:
10981123
loaded.update(safetensors_load_file(sf_file))
1099-
1100-
model.load_state_dict(loaded, strict=True)
1124+
1125+
# The LTX-2 latent upsampler wrapper exposes the actual conv
1126+
# stack at ``self.model``; checkpoint state_dicts may be saved
1127+
# without the ``model.`` prefix when the inner module was
1128+
# serialised directly. Strip / forward as needed so both layouts
1129+
# load cleanly.
1130+
target_module = getattr(model, "model", model)
1131+
if loaded and all(k.startswith("model.") for k in loaded):
1132+
stripped = {k[len("model."):]: v for k, v in loaded.items()}
1133+
target_module.load_state_dict(stripped, strict=True)
1134+
else:
1135+
target_module.load_state_dict(loaded, strict=True)
11011136

11021137
return model.eval()
11031138

fastvideo/models/loader/fsdp_load.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,37 @@
2828
logger = init_logger(__name__)
2929

3030

31+
def _maybe_convert_model_to_nvfp4(model: nn.Module) -> None:
32+
"""Quantize NVFP4-tagged linear layers in-place after weights are loaded.
33+
34+
Walks the module tree once, looking for layers whose ``quant_method``
35+
is an :class:`NVFP4QuantizeMethod` (attached at construction time by
36+
:meth:`NVFP4Config.get_quant_method`). When at least one such layer
37+
exists, calls :func:`convert_model_to_nvfp4` to register the
38+
``_nvfp4_weight*`` / ``_nvfp4_alpha`` / ``_weight_global_sf`` buffers
39+
on each targeted layer.
40+
41+
The walk returns on the first NVFP4 layer found so non-NVFP4 callers
42+
pay only an ``isinstance`` check per module. flashinfer is imported
43+
lazily inside :func:`convert_model_to_nvfp4` so this helper is a
44+
no-op on hosts without the NVFP4 backend.
45+
"""
46+
# Defer the import: nvfp4_config imports heavy diffusers /
47+
# torch.distributed symbols at module-load time, and unconditional
48+
# import would penalize every loader call regardless of whether
49+
# NVFP4 is wired.
50+
from fastvideo.layers.quantization.nvfp4_config import (
51+
NVFP4QuantizeMethod, convert_model_to_nvfp4,
52+
)
53+
54+
for mod in model.modules():
55+
if isinstance(getattr(mod, "quant_method", None),
56+
NVFP4QuantizeMethod):
57+
logger.info("Converting loaded model weights for NVFP4 linear layers")
58+
convert_model_to_nvfp4(model)
59+
return
60+
61+
3162
# TODO(PY): move this to utils elsewhere
3263
@contextlib.contextmanager
3364
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
@@ -158,6 +189,15 @@ def maybe_load_fsdp_model(
158189
if isinstance(p, torch.nn.Parameter):
159190
p.requires_grad = False
160191

192+
# NVFP4 weight prequantization. We detect by the registered
193+
# ``quant_method`` on linear layers rather than by a separate flag —
194+
# construction-time ``NVFP4Config.get_quant_method`` already attached
195+
# ``NVFP4QuantizeMethod`` to every targeted layer, so the loader's
196+
# responsibility is just to materialize the per-layer nvfp4 weight /
197+
# scale buffers from the freshly-loaded bf16 weights. No-op when
198+
# ``flashinfer`` is not installed (lazy import inside the helper).
199+
_maybe_convert_model_to_nvfp4(model)
200+
161201
compile_in_loader = enable_torch_compile and training_mode
162202
if compile_in_loader:
163203
compile_kwargs = torch_compile_kwargs or {}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from fastvideo.models.upsamplers.ltx2_upsampler import (
4+
BlurDownsample,
5+
LTX2LatentUpsampler,
6+
LatentUpsampler,
7+
LatentUpsamplerConfigurator,
8+
PixelShuffleND,
9+
ResBlock,
10+
SpatialRationalResampler,
11+
upsample_video,
12+
)
13+
14+
__all__ = [
15+
"BlurDownsample",
16+
"LTX2LatentUpsampler",
17+
"LatentUpsampler",
18+
"LatentUpsamplerConfigurator",
19+
"PixelShuffleND",
20+
"ResBlock",
21+
"SpatialRationalResampler",
22+
"upsample_video",
23+
]

0 commit comments

Comments
 (0)