@@ -56,6 +56,10 @@ class KbitLoraModel(nn.Module):
5656 Set False for non-first pipeline stages.
5757 include_lm_head: Whether to quantize and keep the LM head. Default True.
5858 Set False for non-last pipeline stages.
59+ target_device: Device for quantized weights and LoRA params. If None,
60+ uses the model's device. Set this when loading the HF model on
61+ CPU to stream weights to GPU one layer at a time (minimizes
62+ peak GPU memory). Example: torch.device("cuda:0").
5963 """
6064
6165 def __init__ (
@@ -73,6 +77,7 @@ def __init__(
7377 layer_range : Optional [tuple [int , int ]] = None ,
7478 include_embed : bool = True ,
7579 include_lm_head : bool = True ,
80+ target_device : Optional [torch .device ] = None ,
7681 ):
7782 super ().__init__ ()
7883
@@ -123,10 +128,20 @@ def __init__(
123128 self ._layer_start , self ._layer_end = 0 , total_layers
124129 self ._num_loaded_layers = self ._layer_end - self ._layer_start
125130
131+ # Determine target device for quantized weights.
132+ # When target_device is explicitly set (streaming mode), we free each
133+ # layer from the source model after quantization to save memory.
134+ self ._streaming = target_device is not None
135+ if target_device is not None :
136+ self ._target_device = target_device
137+ else :
138+ self ._target_device = next (model .parameters ()).device
139+
126140 # Keep reference to original model for embeddings
127141 self .model = model
128142 if include_embed :
129- self .embed_tokens = model .model .embed_tokens
143+ # Move embedding to target device (may be CPU->GPU transfer)
144+ self .embed_tokens = model .model .embed_tokens .to (self ._target_device )
130145 else :
131146 self .embed_tokens = None
132147 self .lm_head_tied = hasattr (model , "lm_head" ) and (
@@ -140,7 +155,7 @@ def __init__(
140155
141156 self ._quantize_and_create_lora (model )
142157
143- # Freeze all base model parameters
158+ # Freeze all base model parameters (any that remain)
144159 for p in model .parameters ():
145160 p .requires_grad_ (False )
146161
@@ -151,19 +166,27 @@ def __init__(
151166 p .requires_grad_ (True )
152167
153168 def _quantize_weight (self , weight : torch .Tensor , name : str , k : int | None = None ):
154- """Quantize a weight matrix and store packed data."""
169+ """Quantize a weight matrix and store packed data.
170+
171+ The weight is moved to _target_device for quantization (CUDA kernel),
172+ then the original weight reference is no longer needed.
173+ """
155174 if k is None :
156175 k = self .k
176+ # Move to target device for quantization (CPU -> GPU transfer if needed)
177+ weight = weight .to (self ._target_device )
157178 N , K = weight .shape
158179 N_padded = ((N + 127 ) // 128 ) * 128
159180 if N_padded != N :
160181 w_padded = torch .nn .functional .pad (weight .float (), (0 , 0 , 0 , N_padded - N ))
161182 else :
162183 w_padded = weight .float ()
184+ del weight # Free the fp16 copy on GPU
163185
164186 packed , absmax , codebook = F .quantize_kbit (
165187 w_padded .reshape (- 1 ), k = k , absmax_format = "fp32" ,
166188 )
189+ del w_padded # Free the fp32 padded copy
167190
168191 # Store as non-trainable buffers
169192 safe_name = name .replace ("." , "_" )
@@ -173,9 +196,10 @@ def _quantize_weight(self, weight: torch.Tensor, name: str, k: int | None = None
173196
174197 return packed , absmax , codebook , N_padded , N , K
175198
176- def _create_lora (self , name : str , N : int , K : int , device : torch . device ):
177- """Create LoRA A and B parameters for a weight matrix."""
199+ def _create_lora (self , name : str , N : int , K : int ):
200+ """Create LoRA A and B parameters for a weight matrix on _target_device ."""
178201 safe_name = name .replace ("." , "_" )
202+ device = self ._target_device
179203 # A: [r, K] initialized with Kaiming uniform
180204 A = nn .Parameter (torch .empty (self .lora_r , K , dtype = self .compute_dtype , device = device ))
181205 nn .init .kaiming_uniform_ (A , a = math .sqrt (5 ))
@@ -190,8 +214,13 @@ def _quantize_and_create_lora(self, model: nn.Module):
190214
191215 Only processes layers in [_layer_start, _layer_end) and optionally
192216 skips embedding and LM head for pipeline parallelism.
217+
218+ Streams weights one layer at a time: each layer's weights are moved
219+ from the model's device (often CPU) to _target_device (GPU), quantized,
220+ then the original layer is deleted. This keeps peak GPU memory at
221+ ~1 layer of fp16 weights plus the growing quantized data.
193222 """
194- device = next ( model . parameters ()). device
223+ device = self . _target_device
195224
196225 # Process only the decoder layers in our range
197226 layers = model .model .layers
@@ -207,12 +236,12 @@ def _quantize_and_create_lora(self, model: nn.Module):
207236
208237 # Attention projections (use k_attention)
209238 for proj_name in ["q_proj" , "k_proj" , "v_proj" , "o_proj" ]:
210- weight = getattr (attn , proj_name ).weight .data . to ( device )
239+ weight = getattr (attn , proj_name ).weight .data
211240 name = f"{ prefix } _attn_{ proj_name } "
212241 packed , absmax , codebook , N_padded , N , K = self ._quantize_weight (
213242 weight , name , k = self .k_attention ,
214243 )
215- A , B = self ._create_lora (name , N , K , device )
244+ A , B = self ._create_lora (name , N , K )
216245 layer_info [proj_name ] = {
217246 "packed" : packed , "absmax" : absmax , "codebook" : codebook ,
218247 "N_padded" : N_padded , "N" : N , "K" : K , "A" : A , "B" : B ,
@@ -221,24 +250,24 @@ def _quantize_and_create_lora(self, model: nn.Module):
221250
222251 # MLP projections (use k_mlp)
223252 for proj_name in ["gate_proj" , "up_proj" , "down_proj" ]:
224- weight = getattr (mlp , proj_name ).weight .data . to ( device )
253+ weight = getattr (mlp , proj_name ).weight .data
225254 name = f"{ prefix } _mlp_{ proj_name } "
226255 packed , absmax , codebook , N_padded , N , K = self ._quantize_weight (
227256 weight , name , k = self .k_mlp ,
228257 )
229- A , B = self ._create_lora (name , N , K , device )
258+ A , B = self ._create_lora (name , N , K )
230259 layer_info [proj_name ] = {
231260 "packed" : packed , "absmax" : absmax , "codebook" : codebook ,
232261 "N_padded" : N_padded , "N" : N , "K" : K , "A" : A , "B" : B ,
233262 "k" : self .k_mlp ,
234263 }
235264
236- # Norm weights (trainable, not quantized)
265+ # Norm weights (trainable, not quantized) — move to target device
237266 for norm_name in ["input_layernorm" , "post_attention_layernorm" ]:
238267 norm = getattr (layer , norm_name )
239268 safe = f"{ prefix } _{ norm_name } _weight"
240269 self ._norm_weights [safe ] = nn .Parameter (
241- norm .weight .data .to (self .compute_dtype ).clone ()
270+ norm .weight .data .to (device = device , dtype = self .compute_dtype ).clone ()
242271 )
243272 layer_info [norm_name ] = self ._norm_weights [safe ]
244273
@@ -248,23 +277,31 @@ def _quantize_and_create_lora(self, model: nn.Module):
248277 norm = getattr (attn , norm_name )
249278 safe = f"{ prefix } _attn_{ norm_name } _weight"
250279 self ._norm_weights [safe ] = nn .Parameter (
251- norm .weight .data .to (self .compute_dtype ).clone ()
280+ norm .weight .data .to (device = device , dtype = self .compute_dtype ).clone ()
252281 )
253282 layer_info [norm_name ] = self ._norm_weights [safe ]
254283
255284 self ._layer_data .append (layer_info )
256285
286+ # In streaming mode, free each layer from the source model
287+ # after quantization to release memory (typically CPU RAM).
288+ if self ._streaming :
289+ layers [i ] = nn .Module ()
290+ del layer
291+ if device .type == "cuda" :
292+ torch .cuda .empty_cache ()
293+
257294 # Final norm (only needed by last stage or full model)
258295 if self .include_lm_head :
259296 final_norm = model .model .norm
260297 self ._norm_weights ["final_norm_weight" ] = nn .Parameter (
261- final_norm .weight .data .to (self .compute_dtype ).clone ()
298+ final_norm .weight .data .to (device = device , dtype = self .compute_dtype ).clone ()
262299 )
263300
264301 # LM head (only needed by last stage or full model)
265302 self ._lm_head_info = None
266303 if self .include_lm_head :
267- lm_weight = model .lm_head .weight .data . to ( device )
304+ lm_weight = model .lm_head .weight .data
268305 name = "lm_head"
269306 packed , absmax , codebook , N_padded , N , K = self ._quantize_weight (
270307 lm_weight , name , k = self .k_lm_head ,
0 commit comments