2727__all__ = ["export_hf_vllm_fq_checkpoint" ]
2828
2929
30- def cleanup_for_torch_save (x : Any ) -> Any :
31- """Drop callables / local closures (e.g. `<locals>.new_forward`) before torch.save.
32-
33- ModelOpt stored state dict may contain local closures like `<locals>.new_forward`
34- which are not picklable. So we need to cleanup the state dict before saving.
35- """
36- if isinstance (x , dict ):
37- return {
38- k : cleanup_for_torch_save (v )
39- for k , v in x .items ()
40- if not callable (v ) and "<locals>" not in str (getattr (v , "__qualname__" , "" ))
41- }
42- if isinstance (x , list ):
43- return [cleanup_for_torch_save (v ) for v in x ]
44- if isinstance (x , tuple ):
45- return tuple (cleanup_for_torch_save (v ) for v in x )
46- return x
47-
48-
4930def export_hf_vllm_fq_checkpoint (
5031 model : nn .Module ,
5132 export_dir : Path | str ,
@@ -68,8 +49,7 @@ def export_hf_vllm_fq_checkpoint(
6849 quantizer_state_dict = get_quantizer_state_dict (model )
6950
7051 modelopt_state = mto .modelopt_state (model )
71- modelopt_state = cleanup_for_torch_save (modelopt_state )
72- modelopt_state ["modelopt_state_weights" ] = cleanup_for_torch_save (quantizer_state_dict )
52+ modelopt_state ["modelopt_state_weights" ] = quantizer_state_dict
7353 torch .save (modelopt_state , export_dir / "vllm_fq_modelopt_state.pth" )
7454 # remove quantizer from model
7555 for _ , module in model .named_modules ():
0 commit comments