1- """GPU detection, VRAM guidance, ROCm gfx arch auto-detection, and CPU weight offloading ."""
1+ """GPU detection, VRAM guidance, and ROCm gfx arch auto-detection."""
22
33import os
44
55import torch
6- import torch .nn as nn
76from loguru import logger
87
98# Known ROCm gfx arch overrides for GPUs not yet in PyTorch's HIP target list.
@@ -94,7 +93,6 @@ def check_vram_and_advise(checkpoint_path: str):
9493 suggestions .append (
9594 f"reduce MAX_SEQ_LEN (current: { max_seq_len } , try 4096 to save ~{ (max_seq_len - 4096 ) / 8192 * 1.2 :.1f} GB)"
9695 )
97- suggestions .append ("set OFFLOAD_WEIGHTS_TO_CPU=true to run slow layers on CPU" )
9896 suggestions .append ("set VRAM_FRACTION=0.95 to prevent system freeze on OOM" )
9997
10098 logger .warning (
@@ -103,119 +101,3 @@ def check_vram_and_advise(checkpoint_path: str):
103101 )
104102 for i , s in enumerate (suggestions , 1 ):
105103 logger .warning (f" { i } . { s } " )
106-
107-
108- class CPUOffloadExecutor :
109- """Runs slow transformer layers on CPU (using AVX-512/VNNI), keeps fast path on GPU.
110-
111- Instead of streaming layers GPU↔CPU (72 PCIe round-trips per token),
112- this executes the slow transformer entirely on CPU and only transfers
113- the final hidden state (~10KB) to GPU for the fast transformer + decoder.
114-
115- For batch=1 single-token inference, CPU execution with DDR5 bandwidth
116- (~80-100 GB/s) and AVX-512 is competitive with the PCIe streaming approach
117- while eliminating all allocation overhead.
118- """
119-
120- def __init__ (self , gpu_device : torch .device ):
121- self .gpu_device = gpu_device
122-
123- def run (self , layers : nn .ModuleList , x , * args , ** kwargs ):
124- """Execute layers on CPU, return result on pinned memory for fast GPU transfer."""
125- # Move hidden state and all positional args to CPU
126- x_cpu = x .to ("cpu" )
127- args_cpu = tuple (a .to ("cpu" ) if isinstance (a , torch .Tensor ) else a for a in args )
128- kwargs_cpu = {
129- k : v .to ("cpu" ) if isinstance (v , torch .Tensor ) else v
130- for k , v in kwargs .items ()
131- }
132-
133- # Run all layers on CPU — weights are already here, no PCIe needed
134- for layer in layers :
135- x_cpu = layer (x_cpu , * args_cpu , ** kwargs_cpu )
136-
137- # Pin the result for faster DMA to GPU, then transfer non-blocking
138- return x_cpu .pin_memory ().to (self .gpu_device , non_blocking = True )
139-
140-
141- def _has_int8_weights (module : nn .Module ) -> bool :
142- """Check if any submodule uses INT8 quantized weights."""
143- for child in module .modules ():
144- if hasattr (child , "weight" ) and hasattr (child , "scales" ) and child .weight .dtype == torch .int8 :
145- return True
146- return False
147-
148-
149- def setup_cpu_offload (model : nn .Module , device : torch .device ):
150- """Offload slow transformer layers to CPU execution.
151-
152- Moves slow layer weights + KV caches to CPU. The slow transformer runs
153- entirely on CPU using AVX-512, and only the final hidden state is
154- transferred to GPU for the fast transformer and decoder.
155-
156- Fast layers stay on GPU (small footprint, called 10x per token).
157-
158- Enable with OFFLOAD_WEIGHTS_TO_CPU=true.
159- Requires native bf16 weights — INT8 quantized models are not supported
160- because autoregressive decode (M=1) cannot use VNNI _int_mm (requires M>16),
161- and the dequant+bf16 fallback is ~30% slower than native bf16 matmuls.
162- """
163- if not os .environ .get ("OFFLOAD_WEIGHTS_TO_CPU" , "" ).lower () in ("true" , "1" ):
164- return False
165-
166- if not hasattr (model , "layers" ):
167- logger .warning ("Model has no 'layers' attribute, cannot offload weights." )
168- return False
169-
170- if _has_int8_weights (model ):
171- logger .warning (
172- "CPU offload requires native bf16 weights. INT8 quantized models are not supported "
173- "because autoregressive decode (batch=1) cannot use VNNI INT8 matmuls (requires M>16), "
174- "and the dequant+bf16 fallback is ~30% slower than native bf16. "
175- "Please use the original (non-quantized) checkpoint with OFFLOAD_WEIGHTS_TO_CPU=true."
176- )
177- return False
178-
179- # Use physical cores only — HyperThreading causes cache contention
180- # on Zen 4 and hurts bf16 matmul throughput (~37% slower with HT).
181- physical_cores = os .cpu_count () // 2 if os .cpu_count () else 8
182- torch .set_num_threads (physical_cores )
183- logger .info (f"CPU offload: set torch threads to { physical_cores } (physical cores only)" )
184-
185- layers = model .layers
186- n_layers = len (layers )
187-
188- # Move slow layers entirely to CPU (including KV caches)
189- gpu_mem_before = torch .cuda .memory_allocated ()
190- with torch .inference_mode (False ):
191- for layer in layers :
192- layer .to ("cpu" )
193- gpu_mem_after = torch .cuda .memory_allocated ()
194- saved_gb = (gpu_mem_before - gpu_mem_after ) / 1e9
195-
196- # Move shared slow-path modules to CPU.
197- # Keep causal_mask and fast_freqs_cis on GPU (shared with fast path).
198- for name in ("norm" , "embeddings" , "codebook_embeddings" , "output" ):
199- module = getattr (model , name , None )
200- if module is not None :
201- with torch .inference_mode (False ):
202- if isinstance (module , nn .Module ):
203- module .to ("cpu" )
204-
205- gpu_mem_final = torch .cuda .memory_allocated ()
206- total_saved_gb = (gpu_mem_before - gpu_mem_final ) / 1e9
207-
208- logger .info (
209- f"CPU offload: moved { n_layers } slow layers + shared modules to CPU, "
210- f"freed { total_saved_gb :.1f} GB VRAM. Fast layers + decoder remain on GPU."
211- )
212-
213- # Keep fast_layers on GPU — small footprint, called 10x per token
214- fast_layers = getattr (model , "fast_layers" , None )
215- if fast_layers is not None :
216- logger .info (f"CPU offload: keeping { len (fast_layers )} fast layers on GPU." )
217-
218- # Attach executor — forward_generate will use it
219- model ._layer_streamer = CPUOffloadExecutor (device )
220-
221- return True
0 commit comments