3333
3434def _get_cpu_offload_hook (hook ):
3535 if isinstance (hook , AlignDevicesHook ) and hook .offload and hook .weights_map is not None :
36- assert "weight" in hook .weights_map
36+ assert len ( hook .weights_map ) > 0
3737 if (
3838 isinstance (hook .weights_map , PrefixedDataset )
3939 and hook .weights_map .prefix + "weight" not in hook .weights_map .dataset .state_dict
@@ -50,32 +50,79 @@ def _get_cpu_offload_hook(hook):
5050 return None
5151
5252
53+ def _writeback_params_to_weights_map (module , align_hook ):
54+ """Write all non-meta parameters back to the hook's CPU weights_map."""
55+ for name , param in module .named_parameters ():
56+ if param .device .type == "meta" :
57+ continue
58+ if isinstance (align_hook .weights_map , PrefixedDataset ):
59+ key = align_hook .weights_map .prefix + name
60+ w_map = align_hook .weights_map .dataset .state_dict
61+ else :
62+ w_map = align_hook .weights_map
63+ key = name
64+ if key in w_map :
65+ w_map [key ] = param .data .to (w_map [key ].device , dtype = w_map [key ].dtype )
66+
67+
5368@contextmanager
5469def weight_access_and_writeback_context (module ):
55- """Context manager for weight access and writeback for modules managed by accelerate."""
70+ """Context manager for weight access and writeback for modules managed by accelerate.
71+
72+ Handles two cases:
73+ 1. **Single-module**: the module's own ``_hf_hook`` is an offload hook.
74+ 2. **Sub-module**: the module's hook is non-offloading, but its children have
75+ offload hooks (common with ``SequentialHook`` on sub-modules placed by
76+ ``load_checkpoint_and_dispatch``).
77+
78+ For the sub-module case, ``pre_forward`` is skipped on sub-modules whose weights
79+ are already materialized (not on meta). This allows the context manager to be
80+ used as a pure writeback after weight-modifying algorithms.
81+ """
5682 assert hasattr (module , "_hf_hook" )
5783 align_hook = _get_cpu_offload_hook (module ._hf_hook )
5884
5985 if align_hook :
60- # Accelerate uses AlignDevicesHook to offload weights to CPU/Disk and then reload them in the forward pass
61- # The CPU/Disk offloaded weights are managed by PrefixDataset and OffloadedWeightsLoader
62- # See https://github.com/huggingface/accelerate/blame/f48d95c4939b281505a45b3d6e0bf554b65cc1ea/src/accelerate/utils/offload.py#L104-L141
63- # TODO: Add support for disk-offloaded models if needed (they will be really slow, hence low priority)
64-
65- # This will load the weights from CPU state_dict and move it to the GPU from meta device
86+ # Guard: the sub-module branch below is not reached when the parent has
87+ # an offload hook. Assert that no children also carry offload hooks,
88+ # which would require a combined writeback strategy.
89+ assert not any (
90+ _get_cpu_offload_hook (mod ._hf_hook )
91+ for mod in module .modules ()
92+ if mod is not module and hasattr (mod , "_hf_hook" )
93+ ), (
94+ "Both the module and one of its sub-modules have CPU-offload hooks. "
95+ "weight_access_and_writeback_context does not support this layout yet."
96+ )
6697 align_hook .pre_forward (module )
98+ try :
99+ yield
100+ finally :
101+ _writeback_params_to_weights_map (module , align_hook )
102+ align_hook .post_forward (module , None )
103+ return
104+
105+ materialized : list [tuple [torch .nn .Module , AlignDevicesHook , bool ]] = []
106+ for mod in module .modules ():
107+ if mod is module or not hasattr (mod , "_hf_hook" ):
108+ continue
109+ hook = _get_cpu_offload_hook (mod ._hf_hook )
110+ if hook is None :
111+ continue
112+ # Only call pre_forward if weights need materializing; already-materialized
113+ # weights would be overwritten with stale CPU state_dict values.
114+ needs_materialize = any (p .device .type == "meta" for p in mod .parameters ())
115+ if needs_materialize :
116+ hook .pre_forward (mod )
117+ materialized .append ((mod , hook , needs_materialize ))
118+
67119 try :
68120 yield
69121 finally :
70- if align_hook :
71- # Update the weight in the CPU state_dict
72- if isinstance (align_hook .weights_map , PrefixedDataset ):
73- key = align_hook .weights_map .prefix + "weight"
74- w_map = align_hook .weights_map .dataset .state_dict
75- else :
76- key , w_map = "weight" , align_hook .weights_map
77- w_map [key ] = module .weight .data .to (w_map [key ].device , dtype = w_map [key ].dtype )
78- align_hook .post_forward (module , None )
122+ for mod , hook , was_materialized in materialized :
123+ _writeback_params_to_weights_map (mod , hook )
124+ if was_materialized :
125+ hook .post_forward (mod , None )
79126
80127
81128@contextmanager
0 commit comments