Skip to content

Commit 7196692

Browse files
committed
removed cleanup_for_torch_save
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent fa9b770 commit 7196692

1 file changed

Lines changed: 1 addition & 21 deletions

File tree

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,6 @@
2727
__all__ = ["export_hf_vllm_fq_checkpoint"]
2828

2929

30-
def cleanup_for_torch_save(x: Any) -> Any:
31-
"""Drop callables / local closures (e.g. `<locals>.new_forward`) before torch.save.
32-
33-
ModelOpt stored state dict may contain local closures like `<locals>.new_forward`
34-
which are not picklable. So we need to cleanup the state dict before saving.
35-
"""
36-
if isinstance(x, dict):
37-
return {
38-
k: cleanup_for_torch_save(v)
39-
for k, v in x.items()
40-
if not callable(v) and "<locals>" not in str(getattr(v, "__qualname__", ""))
41-
}
42-
if isinstance(x, list):
43-
return [cleanup_for_torch_save(v) for v in x]
44-
if isinstance(x, tuple):
45-
return tuple(cleanup_for_torch_save(v) for v in x)
46-
return x
47-
48-
4930
def export_hf_vllm_fq_checkpoint(
5031
model: nn.Module,
5132
export_dir: Path | str,
@@ -68,8 +49,7 @@ def export_hf_vllm_fq_checkpoint(
6849
quantizer_state_dict = get_quantizer_state_dict(model)
6950

7051
modelopt_state = mto.modelopt_state(model)
71-
modelopt_state = cleanup_for_torch_save(modelopt_state)
72-
modelopt_state["modelopt_state_weights"] = cleanup_for_torch_save(quantizer_state_dict)
52+
modelopt_state["modelopt_state_weights"] = quantizer_state_dict
7353
torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
7454
# remove quantizer from model
7555
for _, module in model.named_modules():

0 commit comments

Comments
 (0)