@@ -55,6 +55,14 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
5555
5656
5757def _deepcopy_warp (model ):
58+ """Create a deep copy of the model while preserving the `_exported` and `dynamic_shapes` attributes.
59+
60+ Args:
61+ model (torch.nn.Module): The model to deep copy.
62+
63+ Returns:
64+ torch.nn.Module: A deep copy of the model with preserved attributes.
65+ """
5866 additional_attr_lst = ["_exported" , "dynamic_shapes" ]
5967 original_attr = {key : getattr (model , key , None ) for key in additional_attr_lst }
6068 new_model = deepcopy (model )
@@ -64,7 +72,15 @@ def _deepcopy_warp(model):
6472
6573
6674def _preprocess_model_quant_config (model , quant_config ):
67- """Preprocess model and quant config before quantization."""
75+ """Preprocess model and quant config before quantization.
76+
77+ Args:
78+ model (torch.nn.Module): The model to be quantized.
79+ quant_config (TuningConfig): The quantization configuration to preprocess.
80+
81+ Returns:
82+ Tuple[torch.nn.Module, TuningConfig]: The preprocessed model and quantization configuration.
83+ """
6884 for config in quant_config .config_set :
6985 # handle tokenizer attribute in AutoRoundConfig
7086 if isinstance (config , AutoRoundConfig ):
@@ -88,8 +104,8 @@ def autotune(
88104 """The main entry of auto-tune.
89105
90106 Args:
91- model (torch.nn.Module): _description_
92- tune_config (TuningConfig): _description_
107+ model (torch.nn.Module): The model to be quantized.
108+ tune_config (TuningConfig): The configuration for the auto-tuning process.
93109 eval_fn (Callable): for evaluation of quantized models.
94110 eval_args (tuple, optional): arguments used by eval_fn. Defaults to None.
95111 run_fn (Callable, optional): for calibration to quantize model. Defaults to None.
0 commit comments