Skip to content

Commit 8ec8d21

Browse files
committed
refactor: remove CPU offload from ROCm branch, bump ROCm 7.2.1
Separate CPU offload work into its own branch (feat/cpu-offload) to keep the ROCm support PR focused on infrastructure. This branch now contains only: - ROCm 7.2 Docker/compose infrastructure - Device-agnostic code replacements - VQ-GAN precision pass-through and VRAM_FRACTION cap - MAX_SEQ_LEN configurable KV cache - ROCm gfx arch auto-detection and VRAM guidance - MIOpen exhaustive kernel tuning (MIOPEN_FIND_MODE=3) - Persistent MIOpen cache volume Tested on ROCm 7.2.1 with RX 9070 XT (16GB): - Full GPU INT8 + COMPILE=1 + MIOpen tuning: 34.9s (was 85s on 7.2.0) - 2.4x performance improvement from ROCm 7.2.1 HIP/MIOpen fixes
1 parent 99b370f commit 8ec8d21

3 files changed

Lines changed: 10 additions & 172 deletions

File tree

fish_speech/models/text2semantic/inference.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -768,13 +768,6 @@ def worker():
768768
dtype=next(model.parameters()).dtype,
769769
)
770770

771-
# Offload weights to pinned CPU memory if requested.
772-
# Runs after setup_caches so KV caches exist and can be
773-
# preserved on GPU while layer weights move to CPU.
774-
from fish_speech.utils.gpu import setup_cpu_offload
775-
776-
setup_cpu_offload(model, torch.device(device))
777-
778771
init_event.set()
779772

780773
while True:

fish_speech/models/text2semantic/llama.py

Lines changed: 9 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -396,28 +396,22 @@ def forward_generate(
396396
return_all: bool = False,
397397
) -> BaseTransformerForwardResult:
398398

399-
# When CPU offload is active, embeddings are on CPU — move input there
400-
# Capture original device before any moves (for returning results to GPU)
401-
_orig_device = inp.device
402-
embed_device = self.embeddings.weight.device
403-
inp_e = inp.to(embed_device) if inp.device != embed_device else inp
404-
405399
# Embedding logic replicated from embed() for compilation compatibility
406400
embeds = []
407401
for i in range(self.config.num_codebooks):
408402
emb = self.codebook_embeddings(
409-
inp_e[:, i + 1] + i * self.config.codebook_size
403+
inp[:, i + 1] + i * self.config.codebook_size
410404
)
411405
embeds.append(emb)
412406

413407
vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
414408

415-
vq_masks = (inp_e[:, 0] >= self.config.semantic_begin_id) & (
416-
inp_e[:, 0] <= self.config.semantic_end_id
409+
vq_masks = (inp[:, 0] >= self.config.semantic_begin_id) & (
410+
inp[:, 0] <= self.config.semantic_end_id
417411
)
418412

419413
vq_embeds_sum[~vq_masks] = 0
420-
x = self.embeddings(inp_e[:, 0]) + vq_embeds_sum
414+
x = self.embeddings(inp[:, 0]) + vq_embeds_sum
421415

422416
if self.config.scale_codebook_embeddings:
423417
vq_masks_expanded = vq_masks.unsqueeze(-1).expand_as(x)
@@ -427,12 +421,14 @@ def forward_generate(
427421

428422
# Audio embeddings
429423
if audio_parts is not None:
424+
# Note: This assumes self.audio_projector exists if audio_parts is used
425+
# It seems missing in init, but we keep existing logic
430426
if hasattr(self, "audio_projector"):
431-
audio_embeds = self.audio_projector(audio_parts.to(embed_device))
427+
audio_embeds = self.audio_projector(audio_parts)
432428
if self.config.scale_codebook_embeddings:
433-
x[audio_masks.to(embed_device)] = audio_embeds / math.sqrt(2)
429+
x[audio_masks] = audio_embeds / math.sqrt(2)
434430
else:
435-
x[audio_masks.to(embed_device)] = audio_embeds
431+
x[audio_masks] = audio_embeds
436432
else:
437433
logger.warning("audio_parts provided but model has no audio_projector")
438434

@@ -445,39 +441,6 @@ def forward_generate(
445441
mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
446442
freqs_cis = self.freqs_cis[input_pos]
447443

448-
if getattr(self, "_layer_streamer", None) is not None:
449-
# CPU offload: run slow layers + norm + logits on CPU,
450-
# then transfer only the small results to GPU
451-
gpu_device = _orig_device
452-
x_cpu = x.to("cpu")
453-
freqs_cpu = freqs_cis.to("cpu")
454-
mask_cpu = mask.to("cpu")
455-
ipos_cpu = input_pos.to("cpu") if input_pos is not None else None
456-
457-
for layer in self.layers:
458-
x_cpu = layer(x_cpu, freqs_cpu, mask_cpu, input_pos=ipos_cpu)
459-
460-
if x_cpu.size(1) > 1 and not return_all:
461-
x_cpu = x_cpu[:, -1:]
462-
463-
slow_out_cpu = self.norm(x_cpu)
464-
465-
if self.config.is_reward_model:
466-
token_logits_cpu = self.score_output(slow_out_cpu)
467-
elif self.config.tie_word_embeddings:
468-
token_logits_cpu = F.linear(slow_out_cpu, self.embeddings.weight)
469-
else:
470-
token_logits_cpu = self.output(slow_out_cpu)
471-
472-
hidden_cpu = (
473-
slow_out_cpu if getattr(self.config, "norm_fastlayer_input", False) else x_cpu
474-
)
475-
476-
return BaseTransformerForwardResult(
477-
logits=token_logits_cpu.to(gpu_device),
478-
hidden_states=hidden_cpu.to(gpu_device),
479-
)
480-
481444
for layer in self.layers:
482445
x = layer(x, freqs_cis, mask, input_pos=input_pos)
483446

fish_speech/utils/gpu.py

Lines changed: 1 addition & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
"""GPU detection, VRAM guidance, ROCm gfx arch auto-detection, and CPU weight offloading."""
1+
"""GPU detection, VRAM guidance, and ROCm gfx arch auto-detection."""
22

33
import os
44

55
import torch
6-
import torch.nn as nn
76
from loguru import logger
87

98
# Known ROCm gfx arch overrides for GPUs not yet in PyTorch's HIP target list.
@@ -94,7 +93,6 @@ def check_vram_and_advise(checkpoint_path: str):
9493
suggestions.append(
9594
f"reduce MAX_SEQ_LEN (current: {max_seq_len}, try 4096 to save ~{(max_seq_len - 4096) / 8192 * 1.2:.1f}GB)"
9695
)
97-
suggestions.append("set OFFLOAD_WEIGHTS_TO_CPU=true to run slow layers on CPU")
9896
suggestions.append("set VRAM_FRACTION=0.95 to prevent system freeze on OOM")
9997

10098
logger.warning(
@@ -103,119 +101,3 @@ def check_vram_and_advise(checkpoint_path: str):
103101
)
104102
for i, s in enumerate(suggestions, 1):
105103
logger.warning(f" {i}. {s}")
106-
107-
108-
class CPUOffloadExecutor:
109-
"""Runs slow transformer layers on CPU (using AVX-512/VNNI), keeps fast path on GPU.
110-
111-
Instead of streaming layers GPU↔CPU (72 PCIe round-trips per token),
112-
this executes the slow transformer entirely on CPU and only transfers
113-
the final hidden state (~10KB) to GPU for the fast transformer + decoder.
114-
115-
For batch=1 single-token inference, CPU execution with DDR5 bandwidth
116-
(~80-100 GB/s) and AVX-512 is competitive with the PCIe streaming approach
117-
while eliminating all allocation overhead.
118-
"""
119-
120-
def __init__(self, gpu_device: torch.device):
121-
self.gpu_device = gpu_device
122-
123-
def run(self, layers: nn.ModuleList, x, *args, **kwargs):
124-
"""Execute layers on CPU, return result on pinned memory for fast GPU transfer."""
125-
# Move hidden state and all positional args to CPU
126-
x_cpu = x.to("cpu")
127-
args_cpu = tuple(a.to("cpu") if isinstance(a, torch.Tensor) else a for a in args)
128-
kwargs_cpu = {
129-
k: v.to("cpu") if isinstance(v, torch.Tensor) else v
130-
for k, v in kwargs.items()
131-
}
132-
133-
# Run all layers on CPU — weights are already here, no PCIe needed
134-
for layer in layers:
135-
x_cpu = layer(x_cpu, *args_cpu, **kwargs_cpu)
136-
137-
# Pin the result for faster DMA to GPU, then transfer non-blocking
138-
return x_cpu.pin_memory().to(self.gpu_device, non_blocking=True)
139-
140-
141-
def _has_int8_weights(module: nn.Module) -> bool:
142-
"""Check if any submodule uses INT8 quantized weights."""
143-
for child in module.modules():
144-
if hasattr(child, "weight") and hasattr(child, "scales") and child.weight.dtype == torch.int8:
145-
return True
146-
return False
147-
148-
149-
def setup_cpu_offload(model: nn.Module, device: torch.device):
150-
"""Offload slow transformer layers to CPU execution.
151-
152-
Moves slow layer weights + KV caches to CPU. The slow transformer runs
153-
entirely on CPU using AVX-512, and only the final hidden state is
154-
transferred to GPU for the fast transformer and decoder.
155-
156-
Fast layers stay on GPU (small footprint, called 10x per token).
157-
158-
Enable with OFFLOAD_WEIGHTS_TO_CPU=true.
159-
Requires native bf16 weights — INT8 quantized models are not supported
160-
because autoregressive decode (M=1) cannot use VNNI _int_mm (requires M>16),
161-
and the dequant+bf16 fallback is ~30% slower than native bf16 matmuls.
162-
"""
163-
if not os.environ.get("OFFLOAD_WEIGHTS_TO_CPU", "").lower() in ("true", "1"):
164-
return False
165-
166-
if not hasattr(model, "layers"):
167-
logger.warning("Model has no 'layers' attribute, cannot offload weights.")
168-
return False
169-
170-
if _has_int8_weights(model):
171-
logger.warning(
172-
"CPU offload requires native bf16 weights. INT8 quantized models are not supported "
173-
"because autoregressive decode (batch=1) cannot use VNNI INT8 matmuls (requires M>16), "
174-
"and the dequant+bf16 fallback is ~30% slower than native bf16. "
175-
"Please use the original (non-quantized) checkpoint with OFFLOAD_WEIGHTS_TO_CPU=true."
176-
)
177-
return False
178-
179-
# Use physical cores only — HyperThreading causes cache contention
180-
# on Zen 4 and hurts bf16 matmul throughput (~37% slower with HT).
181-
physical_cores = os.cpu_count() // 2 if os.cpu_count() else 8
182-
torch.set_num_threads(physical_cores)
183-
logger.info(f"CPU offload: set torch threads to {physical_cores} (physical cores only)")
184-
185-
layers = model.layers
186-
n_layers = len(layers)
187-
188-
# Move slow layers entirely to CPU (including KV caches)
189-
gpu_mem_before = torch.cuda.memory_allocated()
190-
with torch.inference_mode(False):
191-
for layer in layers:
192-
layer.to("cpu")
193-
gpu_mem_after = torch.cuda.memory_allocated()
194-
saved_gb = (gpu_mem_before - gpu_mem_after) / 1e9
195-
196-
# Move shared slow-path modules to CPU.
197-
# Keep causal_mask and fast_freqs_cis on GPU (shared with fast path).
198-
for name in ("norm", "embeddings", "codebook_embeddings", "output"):
199-
module = getattr(model, name, None)
200-
if module is not None:
201-
with torch.inference_mode(False):
202-
if isinstance(module, nn.Module):
203-
module.to("cpu")
204-
205-
gpu_mem_final = torch.cuda.memory_allocated()
206-
total_saved_gb = (gpu_mem_before - gpu_mem_final) / 1e9
207-
208-
logger.info(
209-
f"CPU offload: moved {n_layers} slow layers + shared modules to CPU, "
210-
f"freed {total_saved_gb:.1f}GB VRAM. Fast layers + decoder remain on GPU."
211-
)
212-
213-
# Keep fast_layers on GPU — small footprint, called 10x per token
214-
fast_layers = getattr(model, "fast_layers", None)
215-
if fast_layers is not None:
216-
logger.info(f"CPU offload: keeping {len(fast_layers)} fast layers on GPU.")
217-
218-
# Attach executor — forward_generate will use it
219-
model._layer_streamer = CPUOffloadExecutor(device)
220-
221-
return True

0 commit comments

Comments
 (0)