Skip to content

Commit a2d3d21

Browse files
committed
Fixing lint errors
Signed-off-by: ynankani <ynankani@nvidia.com>
1 parent 3312d91 commit a2d3d21

3 files changed

Lines changed: 22 additions & 8 deletions

File tree

modelopt/torch/export/diffusers_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,10 @@ def _find_nvfp4_layers(state_dict: dict[str, torch.Tensor]) -> set[str]:
847847
s_key = f"{layer}.weight_scale"
848848
if s_key not in state_dict or w_key not in state_dict:
849849
continue
850-
if state_dict[w_key].dtype == torch.uint8 and state_dict[s_key].dtype == torch.float8_e4m3fn:
850+
if (
851+
state_dict[w_key].dtype == torch.uint8
852+
and state_dict[s_key].dtype == torch.float8_e4m3fn
853+
):
851854
layers.add(layer)
852855
return layers
853856

@@ -882,7 +885,9 @@ def _roundup(a: int, b: int) -> int:
882885
rows, cols_w = weight.shape
883886
pad_r = _roundup(rows, 16) - rows
884887
pad_c_w = (_roundup(cols_w, 16) - cols_w) if padding_strategy == "row_col" else 0
885-
pad_c_s = (_roundup(scale.shape[1], 16) - scale.shape[1]) if padding_strategy == "row_col" else 0
888+
pad_c_s = (
889+
(_roundup(scale.shape[1], 16) - scale.shape[1]) if padding_strategy == "row_col" else 0
890+
)
886891

887892
if pad_r > 0 or pad_c_w > 0:
888893
state_dict[w_key] = torch.nn.functional.pad(weight, (0, pad_c_w, 0, pad_r))
@@ -922,7 +927,9 @@ def _to_blocked(input_matrix: torch.Tensor) -> torch.Tensor:
922927
padded = input_matrix
923928
if (rows, cols) != (padded_rows, padded_cols):
924929
padded = torch.zeros(
925-
(padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype,
930+
(padded_rows, padded_cols),
931+
device=input_matrix.device,
932+
dtype=input_matrix.dtype,
926933
)
927934
padded[:rows, :cols] = input_matrix
928935
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
@@ -935,7 +942,6 @@ def _to_blocked(input_matrix: torch.Tensor) -> torch.Tensor:
935942
s_key = f"{layer}.weight_scale"
936943
state_dict[s_key] = _to_blocked(state_dict[s_key].to(torch.float8_e4m3fn))
937944

938-
939945
return state_dict
940946

941947

modelopt/torch/export/unified_export_hf.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@
3434
import diffusers
3535

3636
from .diffusers_utils import (
37+
build_layerwise_quant_metadata,
3738
generate_diffusion_dummy_forward_fn,
3839
get_diffusion_components,
3940
get_diffusion_model_type,
4041
get_qkv_group_key,
4142
hide_quantizers_from_state_dict,
4243
infer_dtype_from_model,
43-
build_layerwise_quant_metadata,
4444
is_diffusers_object,
4545
is_qkv_projection,
4646
merge_diffusion_checkpoint,
@@ -196,6 +196,10 @@ def _postprocess_safetensors(
196196
header = json.loads(f.read(header_size))
197197
metadata = header.get("__metadata__", None) or {}
198198

199+
# Clone tensors so the memory-mapped file handle from load_file is
200+
# released before we overwrite the same path (required on Windows).
201+
sd = {k: v.clone() for k, v in sd.items()}
202+
199203
if merged_base_safetensor_path is not None and model_type is not None:
200204
sd, base_metadata = merge_diffusion_checkpoint(
201205
sd, merged_base_safetensor_path, model_type, hf_quant_config
@@ -1237,7 +1241,7 @@ def export_hf_checkpoint(
12371241
merged_base_safetensor_path: str | None = kwargs.get("merged_base_safetensor_path")
12381242
enable_layerwise_quant_metadata: bool = kwargs.get("enable_layerwise_quant_metadata", True)
12391243
enable_swizzle_layout: bool = kwargs.get("enable_swizzle_layout", False)
1240-
padding_strategy: str | None = kwargs.get("padding_strategy", None)
1244+
padding_strategy: str | None = kwargs.get("padding_strategy")
12411245
export_dir = Path(export_dir)
12421246
export_dir.mkdir(parents=True, exist_ok=True)
12431247

modelopt/torch/quantization/conversion.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,14 +503,18 @@ def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuan
503503
if orig_type is SequentialQuantizer and not isinstance(module, SequentialQuantizer):
504504
saved = original_attributes[name]
505505
parent_name, _, attr_name = name.rpartition(".")
506-
parent_module = quant_model.get_submodule(parent_name) if parent_name else quant_model
506+
parent_module = (
507+
quant_model.get_submodule(parent_name) if parent_name else quant_model
508+
)
507509
module = SequentialQuantizer(*(TensorQuantizer() for _ in saved["sub_states"]))
508510
setattr(parent_module, attr_name, module)
509511
for tq, sub_state in zip(module, saved["sub_states"]):
510512
tq.set_from_modelopt_state(sub_state, properties_only=True)
511513
elif orig_type is TensorQuantizer and not isinstance(module, TensorQuantizer):
512514
parent_name, _, attr_name = name.rpartition(".")
513-
parent_module = quant_model.get_submodule(parent_name) if parent_name else quant_model
515+
parent_module = (
516+
quant_model.get_submodule(parent_name) if parent_name else quant_model
517+
)
514518
module = TensorQuantizer()
515519
setattr(parent_module, attr_name, module)
516520
module.set_from_modelopt_state(original_attributes[name], properties_only=True)

0 commit comments

Comments (0)