Skip to content

Commit e30a8f1

Browse files
YASH Nankani
authored and committed
Refactor to extract kwargs in postprocess call
Signed-off-by: YASH Nankani <ynankani@dl325g11-0771.ipp4a1.colossus.nvidia.com>
1 parent ceec092 commit e30a8f1

2 files changed

Lines changed: 68 additions & 71 deletions

File tree

modelopt/torch/export/unified_export_hf.py

Lines changed: 49 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,9 @@ def _save_component_state_dict_safetensors(
149149

150150
def _postprocess_safetensors(
151151
export_dir: Path,
152-
merged_base_safetensor_path: str | None = None,
153-
model_type: str | None = None,
152+
pipe: Any | None = None,
154153
hf_quant_config: dict | None = None,
155-
enable_layerwise_quant_metadata: bool = True,
156-
padding_strategy: str | None = None,
157-
enable_swizzle_layout: bool = False,
154+
**kwargs,
158155
) -> None:
159156
"""Post-process saved safetensors files for deployment compatibility.
160157
@@ -174,13 +171,38 @@ def _postprocess_safetensors(
174171
175172
Args:
176173
export_dir: Directory containing the saved ``.safetensors`` file(s).
177-
merged_base_safetensor_path: Path to base model safetensors for merge.
178-
model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` (e.g., ``"ltx2"``).
174+
pipe: The diffusion pipeline / model. Used to infer the model type
175+
(via :func:`get_diffusion_model_type`) when
176+
``merged_base_safetensor_path`` is set.
179177
hf_quant_config: Quantization config dict to embed in metadata.
180-
enable_layerwise_quant_metadata: Whether to build per-layer metadata.
181-
padding_strategy: ``"row"``, ``"row_col"``, or None.
182-
enable_swizzle_layout: Whether to swizzle block scales.
178+
**kwargs: Runtime-specific keyword arguments:
179+
merged_base_safetensor_path (str, optional): When provided, merges
180+
the exported transformer weights with non-transformer components
181+
(VAE, vocoder, text encoders, etc.) from this base safetensors
182+
file to produce a single-file checkpoint compatible with ComfyUI.
183+
Value should be the path to a full base model ``.safetensors``
184+
file (e.g. ``"path/to/ltx-2-19b-dev.safetensors"``).
185+
enable_layerwise_quant_metadata (bool, optional): When True
186+
(default), includes per-layer ``_quantization_metadata`` in the
187+
checkpoint metadata so that inference runtimes (e.g., ComfyUI)
188+
can identify which layers are quantized and in what format. Set
189+
to False to skip.
190+
enable_swizzle_layout (bool, optional): When True, rearranges NVFP4
191+
block scales from ModelOpt's flat layout to cuBLAS 2-D tiled
192+
layout. Required for runtimes that consume cuBLAS block-scaled
193+
GEMM (e.g., comfy_kitchen). Defaults to False.
194+
padding_strategy (str | None, optional): Padding strategy for NVFP4
195+
weight and scale tensors. ``"row"`` pads rows to multiples of
196+
16 (columns assumed already aligned). ``"row_col"`` pads both
197+
dimensions. ``None`` (default) disables padding. Independent of
198+
``enable_swizzle_layout``.
199+
183200
"""
201+
merged_base_safetensor_path: str | None = kwargs.get("merged_base_safetensor_path")
202+
enable_layerwise_quant_metadata: bool = kwargs.get("enable_layerwise_quant_metadata", True)
203+
enable_swizzle_layout: bool = kwargs.get("enable_swizzle_layout", False)
204+
padding_strategy: str | None = kwargs.get("padding_strategy")
205+
184206
safetensor_files = sorted(export_dir.glob("*.safetensors"))
185207
if not safetensor_files:
186208
return
@@ -193,6 +215,14 @@ def _postprocess_safetensors(
193215
"Export with a larger max_shard_size or disable merge/metadata options."
194216
)
195217

218+
model_type: str | None = None
219+
if merged_base_safetensor_path is not None:
220+
if pipe is None:
221+
raise ValueError(
222+
"`pipe` must be provided when `merged_base_safetensor_path` is set."
223+
)
224+
model_type = get_diffusion_model_type(pipe)
225+
196226
for sf_path in safetensor_files:
197227
with safe_open(str(sf_path), framework="pt") as f:
198228
metadata = dict(f.metadata() or {})
@@ -948,11 +978,8 @@ def _export_diffusers_checkpoint(
948978
dtype: torch.dtype | None,
949979
export_dir: Path,
950980
components: list[str] | None,
951-
merged_base_safetensor_path: str | None = None,
952981
max_shard_size: int | str = "10GB",
953-
enable_layerwise_quant_metadata: bool = True,
954-
enable_swizzle_layout: bool = False,
955-
padding_strategy: str | None = None,
982+
**kwargs,
956983
) -> None:
957984
"""Internal: Export diffusion(-like) model/pipeline checkpoint.
958985
@@ -966,19 +993,11 @@ def _export_diffusers_checkpoint(
966993
export_dir: The directory to save the exported checkpoint.
967994
components: Optional list of component names to export. Only used for pipelines.
968995
If None, all components are exported.
969-
merged_base_safetensor_path: If provided, merge the exported transformer weights
970-
with non-transformer components (VAE, vocoder, text encoders, etc.) from this
971-
base safetensors file and add quantization metadata to produce a single-file
972-
checkpoint compatible with ComfyUI. This should be the path to a full base
973-
model ``.safetensors`` file, e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
974996
max_shard_size: Maximum size of each shard file. If the model exceeds this size,
975997
it will be sharded into multiple files and a .safetensors.index.json will be
976998
created. Use smaller values like "5GB" or "2GB" to force sharding.
977-
enable_layerwise_quant_metadata: If True (default), include per-layer
978-
``_quantization_metadata`` in the merged checkpoint metadata.
979-
enable_swizzle_layout: If True, swizzle NVFP4 block scales to cuBLAS tiled layout.
980-
padding_strategy: ``"row"``, ``"row_col"``, or None. Pads NVFP4 weight/scale
981-
tensors independently of swizzle.
999+
**kwargs: Runtime-specific post-processing options forwarded to
1000+
:func:`_postprocess_safetensors`. See its docstring for details.
9821001
"""
9831002
export_dir = Path(export_dir)
9841003

@@ -989,9 +1008,6 @@ def _export_diffusers_checkpoint(
9891008
warnings.warn("No exportable components found in the model.")
9901009
return
9911010

992-
# Resolve model type once (only needed when merging with a base checkpoint)
993-
model_type = get_diffusion_model_type(pipe) if merged_base_safetensor_path else None
994-
9951011
# Separate nn.Module components for quantization-aware export
9961012
module_components = {
9971013
name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module)
@@ -1052,12 +1068,9 @@ def _export_diffusers_checkpoint(
10521068
# Step 7: Post-process — merge, metadata, padding, swizzle
10531069
_postprocess_safetensors(
10541070
component_export_dir,
1055-
merged_base_safetensor_path=merged_base_safetensor_path,
1056-
model_type=model_type,
1071+
pipe,
10571072
hf_quant_config=hf_quant_config,
1058-
enable_layerwise_quant_metadata=enable_layerwise_quant_metadata,
1059-
padding_strategy=padding_strategy,
1060-
enable_swizzle_layout=enable_swizzle_layout,
1073+
**kwargs,
10611074
)
10621075

10631076
# Step 8: Update config.json with quantization info
@@ -1229,31 +1242,10 @@ def export_hf_checkpoint(
12291242
to export. If None, all quantized components are exported.
12301243
extra_state_dict: Extra state dictionary to add to the exported model.
12311244
max_shard_size: Maximum size of each safetensors shard file. Defaults to "10GB".
1232-
**kwargs: Internal-only keyword arguments. Supported keys:
1233-
merged_base_safetensor_path (str, optional). When provided, merges the
1234-
exported diffusion transformer weights with non-transformer components
1235-
(VAE, vocoder, text encoders, etc.) from this base safetensors file to
1236-
produce a single-file checkpoint compatible with ComfyUI. Value should be
1237-
the path to a full base model ``.safetensors`` file
1238-
(e.g. ``"path/to/ltx-2-19b-dev.safetensors"``).
1239-
Only used for diffusion model exports.
1240-
enable_layerwise_quant_metadata (bool, optional). When True (default),
1241-
includes per-layer ``_quantization_metadata`` in the checkpoint metadata
1242-
so that inference runtimes (e.g., ComfyUI) can identify which layers are
1243-
quantized and in what format. Set to False to skip.
1244-
enable_swizzle_layout (bool, optional). When True, rearranges NVFP4 block
1245-
scales from ModelOpt's flat layout to cuBLAS 2-D tiled layout. Required
1246-
for runtimes that consume cuBLAS block-scaled GEMM (e.g., comfy_kitchen).
1247-
Defaults to False.
1248-
padding_strategy (str | None, optional). Padding strategy for NVFP4 weight
1249-
and scale tensors. ``"row"`` pads rows to multiples of 16 (columns assumed
1250-
already aligned). ``"row_col"`` pads both dimensions. ``None`` (default)
1251-
disables padding. Independent of ``enable_swizzle_layout``.
1245+
**kwargs: Runtime-specific post-processing options forwarded to
1246+
:func:`_postprocess_safetensors` for diffusion model exports.
1247+
See its docstring for supported keys.
12521248
"""
1253-
merged_base_safetensor_path: str | None = kwargs.get("merged_base_safetensor_path")
1254-
enable_layerwise_quant_metadata: bool = kwargs.get("enable_layerwise_quant_metadata", True)
1255-
enable_swizzle_layout: bool = kwargs.get("enable_swizzle_layout", False)
1256-
padding_strategy: str | None = kwargs.get("padding_strategy")
12571249
export_dir = Path(export_dir)
12581250
export_dir.mkdir(parents=True, exist_ok=True)
12591251

@@ -1266,11 +1258,8 @@ def export_hf_checkpoint(
12661258
dtype,
12671259
export_dir,
12681260
components,
1269-
merged_base_safetensor_path,
12701261
max_shard_size,
1271-
enable_layerwise_quant_metadata,
1272-
enable_swizzle_layout,
1273-
padding_strategy,
1262+
**kwargs,
12741263
)
12751264
return
12761265

tests/unit/torch/export/test_nvfp4_utils.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
pad_nvfp4_weights,
2828
swizzle_nvfp4_scales,
2929
)
30+
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors
3031

3132

3233
def _make_nvfp4_state_dict(rows=32, cols=64):
@@ -146,8 +147,6 @@ def test_small_scale_needs_internal_padding(self):
146147

147148
class TestPostprocessSafetensors:
148149
def test_metadata_injection(self, tmp_path):
149-
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors
150-
151150
sd = {"weight": torch.randn(4, 4)}
152151
save_file(sd, str(tmp_path / "model.safetensors"))
153152

@@ -169,8 +168,6 @@ def test_metadata_injection(self, tmp_path):
169168
}
170169

171170
def test_padding_and_swizzle(self, tmp_path):
172-
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors
173-
174171
sd = _make_nvfp4_state_dict(rows=20, cols=64)
175172
save_file(sd, str(tmp_path / "model.safetensors"))
176173

@@ -187,23 +184,18 @@ def test_padding_and_swizzle(self, tmp_path):
187184
assert reloaded["layer0.weight_scale"].shape == (128, 64 // 16)
188185

189186
def test_sharded_guard(self, tmp_path):
190-
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors
191-
192187
save_file({"w": torch.randn(2, 2)}, str(tmp_path / "model.safetensors"))
193188
(tmp_path / "model.safetensors.index.json").write_text("{}")
194189

195190
with pytest.raises(NotImplementedError, match="sharded"):
196191
_postprocess_safetensors(
197192
tmp_path,
198193
merged_base_safetensor_path="/fake/path.safetensors",
199-
model_type="ltx2",
200194
enable_layerwise_quant_metadata=True,
201195
)
202196

203197
def test_preserves_existing_metadata(self, tmp_path):
204198
"""Simulate save_pretrained output: safetensors with pre-existing metadata."""
205-
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors
206-
207199
sd = _make_nvfp4_state_dict(rows=20, cols=64)
208200
preexisting_metadata = {"format": "pt", "_class_name": "MyModel"}
209201
save_file(sd, str(tmp_path / "model.safetensors"), metadata=preexisting_metadata)
@@ -230,6 +222,22 @@ def test_preserves_existing_metadata(self, tmp_path):
230222
assert "layer0" in layer_meta["layers"]
231223

232224
def test_no_safetensor_files(self, tmp_path):
233-
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors
234-
235225
_postprocess_safetensors(tmp_path)
226+
227+
def test_unknown_kwargs_silently_ignored(self, tmp_path):
228+
sd = {"weight": torch.randn(4, 4)}
229+
save_file(sd, str(tmp_path / "model.safetensors"))
230+
231+
_postprocess_safetensors(tmp_path, bad_option=True)
232+
233+
reloaded = load_file(str(tmp_path / "model.safetensors"))
234+
assert torch.allclose(reloaded["weight"], sd["weight"])
235+
236+
def test_merge_requires_pipe(self, tmp_path):
237+
save_file({"w": torch.randn(2, 2)}, str(tmp_path / "model.safetensors"))
238+
239+
with pytest.raises(ValueError, match="`pipe` must be provided"):
240+
_postprocess_safetensors(
241+
tmp_path,
242+
merged_base_safetensor_path="/fake/path.safetensors",
243+
)

0 commit comments

Comments
 (0)