@@ -766,6 +766,39 @@ def _merge_ltx2(
766766}
767767
768768
def build_layerwise_quant_metadata(
    state_dict: dict[str, torch.Tensor],
    hf_quant_config: dict,
) -> str:
    """Build per-layer ``_quantization_metadata`` JSON from scale keys in state dict.

    Walks the state-dict keys looking for ``weight_scale`` / ``weight_scale_2``
    suffixes; each matching key identifies a quantized layer, which is recorded
    with the quantization format taken from ``hf_quant_config``.

    Args:
        state_dict: The (possibly merged) state dict with quantization scale tensors.
        hf_quant_config: The quantization config dict (must contain ``quant_algo``).

    Returns:
        JSON string with ``format_version`` and ``layers`` mapping.
    """
    quant_format = hf_quant_config.get("quant_algo", "unknown").lower()
    layers: dict[str, dict[str, str]] = {}
    for key in state_dict:
        if not key.endswith((".weight_scale", ".weight_scale_2")):
            continue
        # Drop the scale suffix, then a trailing ".weight" if present, to get
        # the layer prefix (e.g. "blocks.0.attn.q.weight_scale" -> "blocks.0.attn.q").
        prefix, _, _ = key.rpartition(".")
        if prefix.endswith(".weight"):
            prefix = prefix[: -len(".weight")]
        layers.setdefault(prefix, {"format": quant_format})
    return json.dumps(
        {
            "format_version": "1.0",
            "layers": layers,
        }
    )
801+
769802def merge_diffusion_checkpoint (
770803 state_dict : dict [str , torch .Tensor ],
771804 merged_base_safetensor_path : str ,
@@ -775,17 +808,17 @@ def merge_diffusion_checkpoint(
775808 """Merge transformer weights with a base checkpoint and build ComfyUI metadata.
776809
777810 Dispatches to the model-specific merge function in ``DIFFUSION_MERGE_FUNCTIONS``
778- and, when ``hf_quant_config`` is provided, embeds ``quantization_config`` and
779- per-layer ``_quantization_metadata`` in the safetensors metadata for ComfyUI.
811+ and, when ``hf_quant_config`` is provided, embeds ``quantization_config`` in the
812+ safetensors metadata for ComfyUI.
780813
781814 Args:
782815 state_dict: The transformer state dict (already on CPU).
783816 merged_base_safetensor_path: Path to the full base model ``.safetensors`` file
784817 containing all components (transformer, VAE, vocoder, etc.),
785818 e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
786819 model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge.
787- hf_quant_config: If provided, embed quantization config and per-layer
788- ``_quantization_metadata`` in the returned metadata dict.
820+ hf_quant_config: If provided, embed quantization config in the returned
821+ metadata dict.
789822
790823 Returns:
791824 Tuple of (merged_state_dict, metadata) where *metadata* is the base checkpoint's
@@ -797,25 +830,115 @@ def merge_diffusion_checkpoint(
797830 if hf_quant_config is not None :
798831 metadata ["quantization_config" ] = json .dumps (hf_quant_config )
799832
800- quant_algo = hf_quant_config .get ("quant_algo" , "unknown" ).lower ()
801- layer_metadata = {}
802- for k in merged_state_dict :
803- if k .endswith ((".weight_scale" , ".weight_scale_2" )):
804- layer_name = k .rsplit ("." , 1 )[0 ]
805- if layer_name .endswith (".weight" ):
806- layer_name = layer_name .rsplit ("." , 1 )[0 ]
807- if layer_name not in layer_metadata :
808- layer_metadata [layer_name ] = {"format" : quant_algo }
809- metadata ["_quantization_metadata" ] = json .dumps (
810- {
811- "format_version" : "1.0" ,
812- "layers" : layer_metadata ,
813- }
814- )
815-
816833 return merged_state_dict , metadata
817834
818835
836+ def _find_nvfp4_layers (state_dict : dict [str , torch .Tensor ]) -> set [str ]:
837+ """Find all NVFP4 layer prefixes in a state dict.
838+
839+ A layer is NVFP4 if it has ``weight`` (uint8), ``weight_scale`` (float8_e4m3fn),
840+ and ``weight_scale_2`` (float32) entries.
841+ """
842+ layers : set [str ] = set ()
843+ for key in state_dict :
844+ if key .endswith (".weight_scale_2" ):
845+ layer = key [: - len (".weight_scale_2" )]
846+ w_key = f"{ layer } .weight"
847+ s_key = f"{ layer } .weight_scale"
848+ if s_key not in state_dict or w_key not in state_dict :
849+ continue
850+ if state_dict [w_key ].dtype == torch .uint8 and state_dict [s_key ].dtype == torch .float8_e4m3fn :
851+ layers .add (layer )
852+ return layers
853+
854+
def pad_nvfp4_weights(
    state_dict: dict[str, torch.Tensor],
    padding_strategy: str = "row",
) -> dict[str, torch.Tensor]:
    """Pad NVFP4 weight and scale tensors so dimensions are multiples of 16.

    Only layers detected by :func:`_find_nvfp4_layers` are touched; padding is
    zero-fill on the bottom/right edges via ``torch.nn.functional.pad``.

    Args:
        state_dict: The state dict to pad (modified in-place and returned).
        padding_strategy: ``"row"`` (default) pads only rows to multiples of 16;
            ``"row_col"`` pads both rows and columns.

    Returns:
        The same ``state_dict``, with padded ``weight`` / ``weight_scale`` tensors.

    Raises:
        ValueError: If ``padding_strategy`` is not ``"row"`` or ``"row_col"``.
    """
    if padding_strategy not in ("row", "row_col"):
        raise ValueError(f"padding_strategy must be 'row' or 'row_col', got '{padding_strategy}'")

    def _roundup(a: int, b: int) -> int:
        # Smallest multiple of b that is >= a.
        return ((a + b - 1) // b) * b

    # Sorted for deterministic processing order.
    for layer in sorted(_find_nvfp4_layers(state_dict)):
        w_key = f"{layer}.weight"
        s_key = f"{layer}.weight_scale"

        weight = state_dict[w_key]
        scale = state_dict[s_key]

        rows, cols_w = weight.shape
        pad_r = _roundup(rows, 16) - rows
        if padding_strategy == "row_col":
            pad_c_w = _roundup(cols_w, 16) - cols_w
            pad_c_s = _roundup(scale.shape[1], 16) - scale.shape[1]
        else:
            pad_c_w = 0
            pad_c_s = 0

        # Fix: also pad when only the scale's columns need it. The weight's
        # packed column count can already be a multiple of 16 while the scale's
        # (cols / 16 per NVFP4 block size) is not; the old condition skipped
        # padding the scale in that case.
        if pad_r > 0 or pad_c_w > 0 or pad_c_s > 0:
            # F.pad's last-dim-first pairs: (left, right, top, bottom).
            state_dict[w_key] = torch.nn.functional.pad(weight, (0, pad_c_w, 0, pad_r))
            state_dict[s_key] = torch.nn.functional.pad(scale, (0, pad_c_s, 0, pad_r))

    return state_dict
893+
894+
def swizzle_nvfp4_scales(
    state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Swizzle NVFP4 block scales to cuBLAS 2-D tiled layout.

    Converts the flat ``weight_scale`` tensors ``[rows, cols // 16]`` produced by
    ModelOpt into the cuBLAS 2-D block-scaling-factors layout.

    Reference: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Note: Weights and scales should be padded *before* calling this function
    (see :func:`pad_nvfp4_weights`).

    Args:
        state_dict: The state dict to transform (modified in-place and returned).
    """

    def _ceil_div(a: int, b: int) -> int:
        # Integer ceiling division without floats.
        return (a + b - 1) // b

    def _to_blocked(input_matrix: torch.Tensor) -> torch.Tensor:
        """Rearrange scale matrix to cuBLAS 2-D block-scaling-factors layout."""
        rows, cols = input_matrix.shape
        # The layout tiles the matrix in 128-row x 4-col blocks; compute how
        # many whole tiles are needed to cover the input.
        n_row_blocks = _ceil_div(rows, 128)
        n_col_blocks = _ceil_div(cols, 4)
        padded_rows = n_row_blocks * 128
        padded_cols = n_col_blocks * 4
        padded = input_matrix
        if (rows, cols) != (padded_rows, padded_cols):
            # Zero-pad bottom/right so the view below divides evenly.
            padded = torch.zeros(
                (padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype,
            )
            padded[:rows, :cols] = input_matrix
        # Split into 128x4 tiles, then reorder the elements inside each tile
        # into the interleaved 32x16 arrangement cuBLAS expects. The final
        # reshape flattens the tiles back to the padded 2-D shape; the element
        # order (not the shape) carries the swizzle.
        blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
        rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
        return rearranged.reshape(padded_rows, padded_cols)

    nvfp4_layers = _find_nvfp4_layers(state_dict)

    # Sorted iteration keeps the transformation order deterministic.
    for layer in sorted(nvfp4_layers):
        s_key = f"{layer}.weight_scale"
        # `.to(float8_e4m3fn)` is a no-op for already-float8 scales (the dtype
        # _find_nvfp4_layers requires) but normalizes any stragglers.
        state_dict[s_key] = _to_blocked(state_dict[s_key].to(torch.float8_e4m3fn))

    return state_dict
940+
941+
819942def get_diffusion_model_type (pipe : Any ) -> str :
820943 """Detect the diffusion model type for merge function dispatch.
821944
0 commit comments