Skip to content

Commit 0170e83

Browse files
committed
Add support for postprocessing the exported model for block-scale swizzling, and support different padding strategies
Signed-off-by: ynankani <ynankani@nvidia.com>
1 parent 80d2f02 commit 0170e83

File tree

6 files changed

+326
-97
lines changed

6 files changed

+326
-97
lines changed

examples/diffusers/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ python quantize.py \
133133
--extra-param merged_base_safetensor_path=./ltx-2-19b-dev-fp8.safetensors
134134
```
135135

136+
To additionally apply NVFP4 scale swizzle and padding, add:
137+
138+
```sh
139+
--extra-param enable_swizzle_layout=true \
140+
--extra-param padding_strategy=row_col
141+
```
142+
136143
#### Important Parameters
137144

138145
- `percentile`: Control quantization scaling factors (amax) collecting range, meaning that we will collect the chosen amax in the range of `(n_steps * percentile)` steps. Recommendation: 1.0

examples/diffusers/quantization/pipeline_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ def _create_ltx2_pipeline(self) -> Any:
217217
"fp8transformer", False
218218
)
219219
params.pop("merged_base_safetensor_path", None)
220+
params.pop("enable_swizzle_layout", None)
221+
params.pop("padding_strategy", None)
222+
params.pop("enable_layerwise_quant_metadata", None)
220223

221224
if not checkpoint_path:
222225
raise ValueError("Missing required extra_param: checkpoint_path.")

examples/diffusers/quantization/quantize.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,14 @@ def export_hf_ckpt(self, pipe: Any, model_config: ModelConfig | None = None) ->
320320
if merged_path:
321321
self.logger.info(f"Merging base safetensors from {merged_path} for LTX2 export")
322322
kwargs["merged_base_safetensor_path"] = merged_path
323+
if model_config:
324+
for key in ("enable_swizzle_layout", "enable_layerwise_quant_metadata"):
325+
val = model_config.extra_params.get(key)
326+
if val is not None:
327+
kwargs[key] = str(val).lower() in ("true", "1", "yes")
328+
padding = model_config.extra_params.get("padding_strategy")
329+
if padding:
330+
kwargs["padding_strategy"] = padding
323331
export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir, **kwargs)
324332
self.logger.info("HuggingFace checkpoint export completed successfully")
325333

modelopt/torch/export/diffusers_utils.py

Lines changed: 143 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,39 @@ def _merge_ltx2(
766766
}
767767

768768

769+
def build_layerwise_quant_metadata(
    state_dict: dict[str, torch.Tensor],
    hf_quant_config: dict,
) -> str:
    """Build per-layer ``_quantization_metadata`` JSON from scale keys in state dict.

    Scans the state dict for ``weight_scale`` / ``weight_scale_2`` suffixes and
    maps each quantized layer to its quantization format.

    Args:
        state_dict: The (possibly merged) state dict with quantization scale tensors.
        hf_quant_config: The quantization config dict (must contain ``quant_algo``).

    Returns:
        JSON string with ``format_version`` and ``layers`` mapping.
    """
    fmt = hf_quant_config.get("quant_algo", "unknown").lower()
    layers: dict[str, dict[str, str]] = {}
    for key in state_dict:
        if not key.endswith((".weight_scale", ".weight_scale_2")):
            continue
        # Strip the scale suffix, then a trailing ".weight" component if
        # present, to recover the owning layer's name.
        name = key.rsplit(".", 1)[0]
        if name.endswith(".weight"):
            name = name.rsplit(".", 1)[0]
        layers.setdefault(name, {"format": fmt})
    return json.dumps({"format_version": "1.0", "layers": layers})
800+
801+
769802
def merge_diffusion_checkpoint(
770803
state_dict: dict[str, torch.Tensor],
771804
merged_base_safetensor_path: str,
@@ -775,17 +808,17 @@ def merge_diffusion_checkpoint(
775808
"""Merge transformer weights with a base checkpoint and build ComfyUI metadata.
776809
777810
Dispatches to the model-specific merge function in ``DIFFUSION_MERGE_FUNCTIONS``
778-
and, when ``hf_quant_config`` is provided, embeds ``quantization_config`` and
779-
per-layer ``_quantization_metadata`` in the safetensors metadata for ComfyUI.
811+
and, when ``hf_quant_config`` is provided, embeds ``quantization_config`` in the
812+
safetensors metadata for ComfyUI.
780813
781814
Args:
782815
state_dict: The transformer state dict (already on CPU).
783816
merged_base_safetensor_path: Path to the full base model ``.safetensors`` file
784817
containing all components (transformer, VAE, vocoder, etc.),
785818
e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
786819
model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge.
787-
hf_quant_config: If provided, embed quantization config and per-layer
788-
``_quantization_metadata`` in the returned metadata dict.
820+
hf_quant_config: If provided, embed quantization config in the returned
821+
metadata dict.
789822
790823
Returns:
791824
Tuple of (merged_state_dict, metadata) where *metadata* is the base checkpoint's
@@ -797,25 +830,115 @@ def merge_diffusion_checkpoint(
797830
if hf_quant_config is not None:
798831
metadata["quantization_config"] = json.dumps(hf_quant_config)
799832

800-
quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
801-
layer_metadata = {}
802-
for k in merged_state_dict:
803-
if k.endswith((".weight_scale", ".weight_scale_2")):
804-
layer_name = k.rsplit(".", 1)[0]
805-
if layer_name.endswith(".weight"):
806-
layer_name = layer_name.rsplit(".", 1)[0]
807-
if layer_name not in layer_metadata:
808-
layer_metadata[layer_name] = {"format": quant_algo}
809-
metadata["_quantization_metadata"] = json.dumps(
810-
{
811-
"format_version": "1.0",
812-
"layers": layer_metadata,
813-
}
814-
)
815-
816833
return merged_state_dict, metadata
817834

818835

836+
def _find_nvfp4_layers(state_dict: dict[str, torch.Tensor]) -> set[str]:
837+
"""Find all NVFP4 layer prefixes in a state dict.
838+
839+
A layer is NVFP4 if it has ``weight`` (uint8), ``weight_scale`` (float8_e4m3fn),
840+
and ``weight_scale_2`` (float32) entries.
841+
"""
842+
layers: set[str] = set()
843+
for key in state_dict:
844+
if key.endswith(".weight_scale_2"):
845+
layer = key[: -len(".weight_scale_2")]
846+
w_key = f"{layer}.weight"
847+
s_key = f"{layer}.weight_scale"
848+
if s_key not in state_dict or w_key not in state_dict:
849+
continue
850+
if state_dict[w_key].dtype == torch.uint8 and state_dict[s_key].dtype == torch.float8_e4m3fn:
851+
layers.add(layer)
852+
return layers
853+
854+
855+
def pad_nvfp4_weights(
    state_dict: dict[str, torch.Tensor],
    padding_strategy: str = "row",
) -> dict[str, torch.Tensor]:
    """Pad NVFP4 weight and scale tensors so dimensions are multiples of 16.

    Args:
        state_dict: The state dict to pad (modified in-place and returned).
        padding_strategy: ``"row"`` (default) pads only rows to multiples of 16;
            ``"row_col"`` pads both rows and columns.

    Returns:
        The same ``state_dict`` with padded tensors swapped in.

    Raises:
        ValueError: If ``padding_strategy`` is not ``"row"`` or ``"row_col"``.
    """
    if padding_strategy not in ("row", "row_col"):
        raise ValueError(f"padding_strategy must be 'row' or 'row_col', got '{padding_strategy}'")

    def _roundup(a: int, b: int) -> int:
        # Round *a* up to the next multiple of *b*.
        return ((a + b - 1) // b) * b

    for layer in sorted(_find_nvfp4_layers(state_dict)):
        w_key = f"{layer}.weight"
        s_key = f"{layer}.weight_scale"

        weight = state_dict[w_key]
        scale = state_dict[s_key]

        rows = weight.shape[0]
        pad_r = _roundup(rows, 16) - rows
        pad_c_w = (_roundup(weight.shape[1], 16) - weight.shape[1]) if padding_strategy == "row_col" else 0
        pad_c_s = (_roundup(scale.shape[1], 16) - scale.shape[1]) if padding_strategy == "row_col" else 0

        # Pad weight and scale independently: the packed weight columns and the
        # per-block scale columns round up to 16 at different sizes, so a layer
        # can need scale-column padding even when the weight is already aligned
        # (previously the scale pad was skipped in that case).
        if pad_r or pad_c_w:
            state_dict[w_key] = torch.nn.functional.pad(weight, (0, pad_c_w, 0, pad_r))
        if pad_r or pad_c_s:
            state_dict[s_key] = torch.nn.functional.pad(scale, (0, pad_c_s, 0, pad_r))

    return state_dict
893+
894+
895+
def swizzle_nvfp4_scales(
    state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Swizzle NVFP4 block scales to cuBLAS 2-D tiled layout.

    Converts the flat ``weight_scale`` tensors ``[rows, cols // 16]`` produced by
    ModelOpt into the cuBLAS 2-D block-scaling-factors layout.

    Reference: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Note: Weights and scales should be padded *before* calling this function
    (see :func:`pad_nvfp4_weights`).

    Args:
        state_dict: The state dict to transform (modified in-place and returned).

    Returns:
        The same ``state_dict`` with each NVFP4 ``weight_scale`` rewritten in the
        tiled layout, cast to ``float8_e4m3fn``.
    """

    def _ceil_div(a: int, b: int) -> int:
        # Integer ceiling division.
        return (a + b - 1) // b

    def _to_blocked(input_matrix: torch.Tensor) -> torch.Tensor:
        """Rearrange scale matrix to cuBLAS 2-D block-scaling-factors layout."""
        rows, cols = input_matrix.shape
        # The layout is built from 128-row x 4-column tiles; round the matrix up
        # to whole tiles, zero-filling the remainder.
        n_row_blocks = _ceil_div(rows, 128)
        n_col_blocks = _ceil_div(cols, 4)
        padded_rows = n_row_blocks * 128
        padded_cols = n_col_blocks * 4
        padded = input_matrix
        if (rows, cols) != (padded_rows, padded_cols):
            padded = torch.zeros(
                (padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype,
            )
            padded[:rows, :cols] = input_matrix
        # Split into (row-block, col-block) 128x4 tiles, then regroup each tile's
        # 128 rows as 4 groups of 32 so each 32x16 chunk becomes contiguous —
        # presumably matching the cuBLAS layout linked above; verify against the
        # reference if the tiling constants change.
        blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
        rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
        # Flatten back to 2-D; the total element count is unchanged.
        return rearranged.reshape(padded_rows, padded_cols)

    nvfp4_layers = _find_nvfp4_layers(state_dict)

    for layer in sorted(nvfp4_layers):
        s_key = f"{layer}.weight_scale"
        # Cast to float8_e4m3fn before tiling — _find_nvfp4_layers only selects
        # layers whose weight_scale already has that dtype, so this is a no-op
        # guard in practice.
        state_dict[s_key] = _to_blocked(state_dict[s_key].to(torch.float8_e4m3fn))


    return state_dict
940+
941+
819942
def get_diffusion_model_type(pipe: Any) -> str:
820943
"""Detect the diffusion model type for merge function dispatch.
821944

0 commit comments

Comments
 (0)