diff --git a/angelslim/compressor/quant/modules/awq/awq.py b/angelslim/compressor/quant/modules/awq/awq.py index f8407536..3ec56f9c 100644 --- a/angelslim/compressor/quant/modules/awq/awq.py +++ b/angelslim/compressor/quant/modules/awq/awq.py @@ -32,6 +32,20 @@ __all__ = ["AWQ"] +def _remove_accelerate_hooks(module): + for submodule in module.modules(): + if hasattr(submodule, "_hf_hook"): + try: + from accelerate.hooks import remove_hook_from_module + remove_hook_from_module(submodule) + except ImportError: + # Should not happen if _hf_hook is present + delattr(submodule, "_hf_hook") + if hasattr(submodule, "_old_forward"): + submodule.forward = submodule._old_forward + delattr(submodule, "_old_forward") + + class AWQ: def __init__( self, @@ -156,6 +170,7 @@ def run(self, dataloader): f"GPU Memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" ) + _remove_accelerate_hooks(layers[i]) layer = layers[i].to(dev) if not self.low_memory: outs = outs.to(dev) @@ -232,6 +247,10 @@ def deduplicate_tensors(tensor_list): ) self.scale_function.apply_scale(layer, scales_list, input_feat) + + # Fix: Ensure all submodules are on the same device after apply_scale + # In low_memory mode, apply_scale may move weights to different devices + layer = layer.to(dev) for scales in scales_list: name = "language_model.encoder.layers.{}.{}.scale".format(i, scales[0]) self.scales_dict[name] = scales[2] @@ -240,6 +259,9 @@ def deduplicate_tensors(tensor_list): clip_list = self.clip_function.auto_clip(layer, input_feat) self.clip_function.apply_clip(layer, clip_list) + # Fix: Ensure all submodules are on the same device after apply_clip + layer = layer.to(dev) + for j in range(min(self.inps.shape[1], nsamples)): with torch.no_grad(): outs[j, :, :] = layer(