From 151af0f6113106bb45392f56c89e1d0494408c68 Mon Sep 17 00:00:00 2001 From: leoobai Date: Wed, 24 Dec 2025 17:58:00 +0800 Subject: [PATCH] =?UTF-8?q?fix(awq):=20=E4=BF=AE=E5=A4=8D=E4=BD=8E?= =?UTF-8?q?=E5=86=85=E5=AD=98=E6=A8=A1=E5=BC=8F=E4=B8=8B=E8=AE=BE=E5=A4=87?= =?UTF-8?q?=E4=B8=8D=E4=B8=80=E8=87=B4=E9=97=AE=E9=A2=98=E5=B9=B6=E7=A7=BB?= =?UTF-8?q?=E9=99=A4accelerate=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 确保在apply_scale和apply_clip后所有子模块在同一设备上 添加_remove_accelerate_hooks函数来清理accelerate的hooks --- angelslim/compressor/quant/modules/awq/awq.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/angelslim/compressor/quant/modules/awq/awq.py b/angelslim/compressor/quant/modules/awq/awq.py index f8407536..3ec56f9c 100644 --- a/angelslim/compressor/quant/modules/awq/awq.py +++ b/angelslim/compressor/quant/modules/awq/awq.py @@ -32,6 +32,20 @@ __all__ = ["AWQ"] +def _remove_accelerate_hooks(module): + for submodule in module.modules(): + if hasattr(submodule, "_hf_hook"): + try: + from accelerate.hooks import remove_hook_from_module + remove_hook_from_module(submodule) + except ImportError: + # Should not happen if _hf_hook is present + delattr(submodule, "_hf_hook") + if hasattr(submodule, "_old_forward"): + submodule.forward = submodule._old_forward + delattr(submodule, "_old_forward") + + class AWQ: def __init__( self, @@ -156,6 +170,7 @@ def run(self, dataloader): f"GPU Memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" ) + _remove_accelerate_hooks(layers[i]) layer = layers[i].to(dev) if not self.low_memory: outs = outs.to(dev) @@ -232,6 +247,10 @@ def deduplicate_tensors(tensor_list): ) self.scale_function.apply_scale(layer, scales_list, input_feat) + + # Fix: Ensure all submodules are on the same device after apply_scale + # In low_memory mode, apply_scale may move weights to different devices + layer = layer.to(dev) for scales in scales_list: name = "language_model.encoder.layers.{}.{}.scale".format(i, scales[0]) self.scales_dict[name] = scales[2] @@ -240,6 +259,9 @@ def deduplicate_tensors(tensor_list): clip_list = self.clip_function.auto_clip(layer, input_feat) self.clip_function.apply_clip(layer, clip_list) + # Fix: Ensure all submodules are on the same device after apply_clip + layer = layer.to(dev) + for j in range(min(self.inps.shape[1], nsamples)): with torch.no_grad(): outs[j, :, :] = layer(