Skip to content

Commit 5721a0d

Browse files
authored
fix(awq): apply_scale and apply_clip device and add _remove_accelerate_hooks (#184)
1 parent c1d01ca commit 5721a0d

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

  • angelslim/compressor/quant/modules/awq

angelslim/compressor/quant/modules/awq/awq.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@
3232
__all__ = ["AWQ"]
3333

3434

35+
def _remove_accelerate_hooks(module):
36+
for submodule in module.modules():
37+
if hasattr(submodule, "_hf_hook"):
38+
try:
39+
from accelerate.hooks import remove_hook_from_module
40+
remove_hook_from_module(submodule)
41+
except ImportError:
42+
# Should not happen if _hf_hook is present
43+
delattr(submodule, "_hf_hook")
44+
if hasattr(submodule, "_old_forward"):
45+
submodule.forward = submodule._old_forward
46+
delattr(submodule, "_old_forward")
47+
48+
3549
class AWQ:
3650
def __init__(
3751
self,
@@ -156,6 +170,7 @@ def run(self, dataloader):
156170
f"GPU Memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
157171
)
158172

173+
_remove_accelerate_hooks(layers[i])
159174
layer = layers[i].to(dev)
160175
if not self.low_memory:
161176
outs = outs.to(dev)
@@ -232,6 +247,10 @@ def deduplicate_tensors(tensor_list):
232247
)
233248

234249
self.scale_function.apply_scale(layer, scales_list, input_feat)
250+
251+
# Fix: Ensure all submodules are on the same device after apply_scale
252+
# In low_memory mode, apply_scale may move weights to different devices
253+
layer = layer.to(dev)
235254
for scales in scales_list:
236255
name = "language_model.encoder.layers.{}.{}.scale".format(i, scales[0])
237256
self.scales_dict[name] = scales[2]
@@ -240,6 +259,9 @@ def deduplicate_tensors(tensor_list):
240259
clip_list = self.clip_function.auto_clip(layer, input_feat)
241260
self.clip_function.apply_clip(layer, clip_list)
242261

262+
# Fix: Ensure all submodules are on the same device after apply_clip
263+
layer = layer.to(dev)
264+
243265
for j in range(min(self.inps.shape[1], nsamples)):
244266
with torch.no_grad():
245267
outs[j, :, :] = layer(

0 commit comments

Comments
 (0)