3232__all__ = ["AWQ" ]
3333
3434
35+ def _remove_accelerate_hooks (module ):
36+ for submodule in module .modules ():
37+ if hasattr (submodule , "_hf_hook" ):
38+ try :
39+ from accelerate .hooks import remove_hook_from_module
40+ remove_hook_from_module (submodule )
41+ except ImportError :
42+ # Should not happen if _hf_hook is present
43+ delattr (submodule , "_hf_hook" )
44+ if hasattr (submodule , "_old_forward" ):
45+ submodule .forward = submodule ._old_forward
46+ delattr (submodule , "_old_forward" )
47+
48+
3549class AWQ :
3650 def __init__ (
3751 self ,
@@ -156,6 +170,7 @@ def run(self, dataloader):
156170 f"GPU Memory: { torch .cuda .memory_allocated () / 1024 ** 2 :.2f} MB"
157171 )
158172
173+ _remove_accelerate_hooks (layers [i ])
159174 layer = layers [i ].to (dev )
160175 if not self .low_memory :
161176 outs = outs .to (dev )
@@ -232,6 +247,10 @@ def deduplicate_tensors(tensor_list):
232247 )
233248
234249 self .scale_function .apply_scale (layer , scales_list , input_feat )
250+
251+ # Fix: Ensure all submodules are on the same device after apply_scale
252+ # In low_memory mode, apply_scale may move weights to different devices
253+ layer = layer .to (dev )
235254 for scales in scales_list :
236255 name = "language_model.encoder.layers.{}.{}.scale" .format (i , scales [0 ])
237256 self .scales_dict [name ] = scales [2 ]
@@ -240,6 +259,9 @@ def deduplicate_tensors(tensor_list):
240259 clip_list = self .clip_function .auto_clip (layer , input_feat )
241260 self .clip_function .apply_clip (layer , clip_list )
242261
262+ # Fix: Ensure all submodules are on the same device after apply_clip
263+ layer = layer .to (dev )
264+
243265 for j in range (min (self .inps .shape [1 ], nsamples )):
244266 with torch .no_grad ():
245267 outs [j , :, :] = layer (
0 commit comments