@@ -149,12 +149,9 @@ def _save_component_state_dict_safetensors(
149149
150150def _postprocess_safetensors (
151151 export_dir : Path ,
152- merged_base_safetensor_path : str | None = None ,
153- model_type : str | None = None ,
152+ pipe : Any | None = None ,
154153 hf_quant_config : dict | None = None ,
155- enable_layerwise_quant_metadata : bool = True ,
156- padding_strategy : str | None = None ,
157- enable_swizzle_layout : bool = False ,
154+ ** kwargs ,
158155) -> None :
159156 """Post-process saved safetensors files for deployment compatibility.
160157
@@ -174,13 +171,38 @@ def _postprocess_safetensors(
174171
175172 Args:
176173 export_dir: Directory containing the saved ``.safetensors`` file(s).
177- merged_base_safetensor_path: Path to base model safetensors for merge.
178- model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` (e.g., ``"ltx2"``).
174+ pipe: The diffusion pipeline / model. Required when
175+ ``merged_base_safetensor_path`` is set (a ``ValueError`` is raised otherwise);
176+ the model type is then inferred via :func:`get_diffusion_model_type`.
179177 hf_quant_config: Quantization config dict to embed in metadata.
180- enable_layerwise_quant_metadata: Whether to build per-layer metadata.
181- padding_strategy: ``"row"``, ``"row_col"``, or None.
182- enable_swizzle_layout: Whether to swizzle block scales.
178+ **kwargs: Runtime-specific keyword arguments:
179+ merged_base_safetensor_path (str, optional): When provided, merges
180+ the exported transformer weights with non-transformer components
181+ (VAE, vocoder, text encoders, etc.) from this base safetensors
182+ file to produce a single-file checkpoint compatible with ComfyUI.
183+ Value should be the path to a full base model ``.safetensors``
184+ file (e.g. ``"path/to/ltx-2-19b-dev.safetensors"``).
185+ enable_layerwise_quant_metadata (bool, optional): When True
186+ (default), includes per-layer ``_quantization_metadata`` in the
187+ checkpoint metadata so that inference runtimes (e.g., ComfyUI)
188+ can identify which layers are quantized and in what format. Set
189+ to False to skip.
190+ enable_swizzle_layout (bool, optional): When True, rearranges NVFP4
191+ block scales from ModelOpt's flat layout to cuBLAS 2-D tiled
192+ layout. Required for runtimes that consume cuBLAS block-scaled
193+ GEMM (e.g., comfy_kitchen). Defaults to False.
194+ padding_strategy (str | None, optional): Padding strategy for NVFP4
195+ weight and scale tensors. ``"row"`` pads rows to multiples of
196+ 16 (columns assumed already aligned). ``"row_col"`` pads both
197+ dimensions. ``None`` (default) disables padding. Independent of
198+ ``enable_swizzle_layout``.
199+
183200 """
201+ merged_base_safetensor_path : str | None = kwargs .get ("merged_base_safetensor_path" )
202+ enable_layerwise_quant_metadata : bool = kwargs .get ("enable_layerwise_quant_metadata" , True )
203+ enable_swizzle_layout : bool = kwargs .get ("enable_swizzle_layout" , False )
204+ padding_strategy : str | None = kwargs .get ("padding_strategy" )
205+
184206 safetensor_files = sorted (export_dir .glob ("*.safetensors" ))
185207 if not safetensor_files :
186208 return
@@ -193,6 +215,14 @@ def _postprocess_safetensors(
193215 "Export with a larger max_shard_size or disable merge/metadata options."
194216 )
195217
218+ model_type : str | None = None
219+ if merged_base_safetensor_path is not None :
220+ if pipe is None :
221+ raise ValueError (
222+ "`pipe` must be provided when `merged_base_safetensor_path` is set."
223+ )
224+ model_type = get_diffusion_model_type (pipe )
225+
196226 for sf_path in safetensor_files :
197227 with safe_open (str (sf_path ), framework = "pt" ) as f :
198228 metadata = dict (f .metadata () or {})
@@ -948,11 +978,8 @@ def _export_diffusers_checkpoint(
948978 dtype : torch .dtype | None ,
949979 export_dir : Path ,
950980 components : list [str ] | None ,
951- merged_base_safetensor_path : str | None = None ,
952981 max_shard_size : int | str = "10GB" ,
953- enable_layerwise_quant_metadata : bool = True ,
954- enable_swizzle_layout : bool = False ,
955- padding_strategy : str | None = None ,
982+ ** kwargs ,
956983) -> None :
957984 """Internal: Export diffusion(-like) model/pipeline checkpoint.
958985
@@ -966,19 +993,11 @@ def _export_diffusers_checkpoint(
966993 export_dir: The directory to save the exported checkpoint.
967994 components: Optional list of component names to export. Only used for pipelines.
968995 If None, all components are exported.
969- merged_base_safetensor_path: If provided, merge the exported transformer weights
970- with non-transformer components (VAE, vocoder, text encoders, etc.) from this
971- base safetensors file and add quantization metadata to produce a single-file
972- checkpoint compatible with ComfyUI. This should be the path to a full base
973- model ``.safetensors`` file, e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
974996 max_shard_size: Maximum size of each shard file. If the model exceeds this size,
975997 it will be sharded into multiple files and a .safetensors.index.json will be
976998 created. Use smaller values like "5GB" or "2GB" to force sharding.
977- enable_layerwise_quant_metadata: If True (default), include per-layer
978- ``_quantization_metadata`` in the merged checkpoint metadata.
979- enable_swizzle_layout: If True, swizzle NVFP4 block scales to cuBLAS tiled layout.
980- padding_strategy: ``"row"``, ``"row_col"``, or None. Pads NVFP4 weight/scale
981- tensors independently of swizzle.
999+ **kwargs: Runtime-specific post-processing options forwarded to
1000+ :func:`_postprocess_safetensors`. See its docstring for details.
9821001 """
9831002 export_dir = Path (export_dir )
9841003
@@ -989,9 +1008,6 @@ def _export_diffusers_checkpoint(
9891008 warnings .warn ("No exportable components found in the model." )
9901009 return
9911010
992- # Resolve model type once (only needed when merging with a base checkpoint)
993- model_type = get_diffusion_model_type (pipe ) if merged_base_safetensor_path else None
994-
9951011 # Separate nn.Module components for quantization-aware export
9961012 module_components = {
9971013 name : comp for name , comp in all_components .items () if isinstance (comp , nn .Module )
@@ -1052,12 +1068,9 @@ def _export_diffusers_checkpoint(
10521068 # Step 7: Post-process — merge, metadata, padding, swizzle
10531069 _postprocess_safetensors (
10541070 component_export_dir ,
1055- merged_base_safetensor_path = merged_base_safetensor_path ,
1056- model_type = model_type ,
1071+ pipe ,
10571072 hf_quant_config = hf_quant_config ,
1058- enable_layerwise_quant_metadata = enable_layerwise_quant_metadata ,
1059- padding_strategy = padding_strategy ,
1060- enable_swizzle_layout = enable_swizzle_layout ,
1073+ ** kwargs ,
10611074 )
10621075
10631076 # Step 8: Update config.json with quantization info
@@ -1229,31 +1242,10 @@ def export_hf_checkpoint(
12291242 to export. If None, all quantized components are exported.
12301243 extra_state_dict: Extra state dictionary to add to the exported model.
12311244 max_shard_size: Maximum size of each safetensors shard file. Defaults to "10GB".
1232- **kwargs: Internal-only keyword arguments. Supported keys:
1233- merged_base_safetensor_path (str, optional). When provided, merges the
1234- exported diffusion transformer weights with non-transformer components
1235- (VAE, vocoder, text encoders, etc.) from this base safetensors file to
1236- produce a single-file checkpoint compatible with ComfyUI. Value should be
1237- the path to a full base model ``.safetensors`` file
1238- (e.g. ``"path/to/ltx-2-19b-dev.safetensors"``).
1239- Only used for diffusion model exports.
1240- enable_layerwise_quant_metadata (bool, optional). When True (default),
1241- includes per-layer ``_quantization_metadata`` in the checkpoint metadata
1242- so that inference runtimes (e.g., ComfyUI) can identify which layers are
1243- quantized and in what format. Set to False to skip.
1244- enable_swizzle_layout (bool, optional). When True, rearranges NVFP4 block
1245- scales from ModelOpt's flat layout to cuBLAS 2-D tiled layout. Required
1246- for runtimes that consume cuBLAS block-scaled GEMM (e.g., comfy_kitchen).
1247- Defaults to False.
1248- padding_strategy (str | None, optional). Padding strategy for NVFP4 weight
1249- and scale tensors. ``"row"`` pads rows to multiples of 16 (columns assumed
1250- already aligned). ``"row_col"`` pads both dimensions. ``None`` (default)
1251- disables padding. Independent of ``enable_swizzle_layout``.
1245+ **kwargs: Runtime-specific post-processing options forwarded to
1246+ :func:`_postprocess_safetensors` for diffusion model exports.
1247+ See its docstring for supported keys.
12521248 """
1253- merged_base_safetensor_path : str | None = kwargs .get ("merged_base_safetensor_path" )
1254- enable_layerwise_quant_metadata : bool = kwargs .get ("enable_layerwise_quant_metadata" , True )
1255- enable_swizzle_layout : bool = kwargs .get ("enable_swizzle_layout" , False )
1256- padding_strategy : str | None = kwargs .get ("padding_strategy" )
12571249 export_dir = Path (export_dir )
12581250 export_dir .mkdir (parents = True , exist_ok = True )
12591251
@@ -1266,11 +1258,8 @@ def export_hf_checkpoint(
12661258 dtype ,
12671259 export_dir ,
12681260 components ,
1269- merged_base_safetensor_path ,
12701261 max_shard_size ,
1271- enable_layerwise_quant_metadata ,
1272- enable_swizzle_layout ,
1273- padding_strategy ,
1262+ ** kwargs ,
12741263 )
12751264 return
12761265
0 commit comments