Skip to content

Commit 760fc97

Browse files
ChenhanYuclaude
andcommitted
fix: export review feedback - dtype mismatch, default, validation
- Update torch_dtype in config.json when export(dtype=...) overrides it - Remove hardcoded num_target_layers=36 default (use base_config directly) - Add state dict validation (assert non-empty after extraction) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
1 parent 8ab1b25 commit 760fc97

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def _export_config(self):
340340
"torch_dtype": str(getattr(base_config, "torch_dtype", torch.bfloat16)).replace(
341341
"torch.", ""
342342
),
343-
"num_target_layers": getattr(base_config, "num_hidden_layers", 36),
343+
"num_target_layers": base_config.num_hidden_layers,
344344
}
345345

346346
# Add layer_types if present (Qwen3-style)
@@ -366,12 +366,16 @@ def export(self, export_dir: Path | str, dtype: torch.dtype | None = None):
366366

367367
# Export state dict
368368
drafter_sd = self._extract_state_dict(full_sd)
369+
assert drafter_sd, "No dflash_module weights found in state dict"
369370
if dtype is not None and hf_quant_config is None:
370371
drafter_sd = {k: v.to(dtype) for k, v in drafter_sd.items()}
371372
save_file(drafter_sd, f"{export_dir}/model.safetensors")
372373

373374
# Export config
374375
drafter_config = self._export_config()
376+
# Update torch_dtype to match actual exported weights
377+
if dtype is not None:
378+
drafter_config["torch_dtype"] = str(dtype).replace("torch.", "")
375379
if hf_quant_config is not None:
376380
drafter_config["quantization_config"] = hf_quant_config
377381
with open(f"{export_dir}/config.json", "w") as f:

0 commit comments

Comments
 (0)