Skip to content

feat: add AMD ROCm 7.2.x support#1247

Closed
gtherond wants to merge 16 commits into
fishaudio:mainfrom
imagilux:feat/rocm-support
Closed

feat: add AMD ROCm 7.2.x support#1247
gtherond wants to merge 16 commits into
fishaudio:mainfrom
imagilux:feat/rocm-support

Conversation

@gtherond
Copy link
Copy Markdown
Contributor

@gtherond gtherond commented Mar 25, 2026

Summary

Adds full AMD ROCm 7.2 support for Fish Speech inference (Docker and bare-metal), tested on an RX 9070 (16GB VRAM, RDNA 4 / gfx1201).

Changes:

  • Device-agnostic code: Replace hardcoded .cuda() calls with .to(device) and replace @torch.autocast(device_type="cuda") with a runtime helper that selects the correct device type
  • VQ-GAN precision: Pass user-requested precision (bfloat16) through to the VQ-GAN decoder, halving its VRAM footprint (~1.87GB → ~1.39GB in bfloat16)
  • Configurable KV cache (MAX_SEQ_LEN env var): Allow users with constrained VRAM to reduce KV cache allocation (defaults to 32768, unchanged behavior)
  • Optional VRAM cap (VRAM_FRACTION env var): Calls torch.cuda.set_per_process_memory_fraction() to prevent driver-level freezes on memory-constrained GPUs (no-op when unset)
  • Dockerfile: Add base-rocm stage using rocm/dev-ubuntu-24.04:7.2
  • Docker Compose: Extract NVIDIA GPU reservation into compose.cuda.yml overlay; add compose.rocm.yml overlay with /dev/kfd + /dev/dri passthrough and ROCm-tuned env vars
  • pyproject.toml: Add rocm72 optional dependency extra (PyTorch 2.11.0 + triton-rocm from ROCm 7.2 wheel index)
  • Bugfix: Fix UnboundLocalError in torchaudio backend detection (reference_loader.py)

Usage:

# ROCm (AMD GPU)
BACKEND=rocm UV_EXTRA=rocm72 VRAM_FRACTION=0.95 MAX_SEQ_LEN=4096 \
  docker compose -f compose.yml -f compose.rocm.yml --profile webui up

# CUDA (NVIDIA GPU) — same as before, just explicit overlay
docker compose -f compose.yml -f compose.cuda.yml --profile webui up

# CPU — unchanged
BACKEND=cpu UV_EXTRA=cpu docker compose --profile webui up

Test plan

  • Tested on RX 9070 16GB (gfx1201/RDNA4), Ubuntu 24.04, kernel 6.17 + amdgpu-dkms 6.18.4, ROCm 7.2.0 AND ROCm 7.2.1
  • WebUI loads both Llama (10.27GB bf16) and VQ-GAN (1.39GB bf16) models within 16GB VRAM
  • TTS inference produces correct audio output
  • VRAM_FRACTION and MAX_SEQ_LEN env vars work as expected
  • Verify CUDA path is not regressed (compose.cuda.yml overlay)
  • Verify CPU path is not regressed

Closes #1241
Closes #1243
Closes #1246
Closes #1249
Closes #1250

gtherond and others added 15 commits March 25, 2026 18:20
…atives

Replace .cuda() with .to(device) in modded_dac.py and extract_vq.py.
Replace @torch.autocast(device_type="cuda") decorator with a runtime
_autocast() helper in model_utils.py that selects the correct device type.

This enables non-CUDA backends (ROCm/HIP, MPS, XPU) to work without
code changes.
- Add precision parameter to dac/inference.py load_model() so the
  VQ-GAN decoder loads in bfloat16 instead of float32, halving its
  VRAM footprint (~1.87GB → ~1.39GB).
- Thread precision through from run_webui.py and model_manager.py.
- Load weights to CPU before moving to device to avoid double VRAM
  allocation.
- Add optional VRAM_FRACTION env var in run_webui.py that calls
  torch.cuda.set_per_process_memory_fraction() to prevent driver-level
  freezes on memory-constrained GPUs. No-op when unset (default).
Respect MAX_SEQ_LEN environment variable when allocating KV caches in
both the worker thread and the generate() code path. Defaults to the
model's configured max_seq_len (32768) when unset.

This allows users with constrained VRAM (e.g. 16GB) to reduce KV cache
allocation and fit both models in memory.
The bare 'import torchaudio.io._load_audio_fileobj' statement was being
optimized away, causing the ffmpeg backend detection to silently fail.
Use importlib.import_module() to force the runtime import check.
- Add rocm72 optional dependency extra in pyproject.toml pointing to
  PyTorch 2.11.0 from the ROCm 7.2 wheel index, with triton-rocm.
- Add base-rocm Dockerfile stage using rocm/dev-ubuntu-24.04:7.2.
- Extract NVIDIA GPU reservation from compose.base.yml into a new
  compose.cuda.yml overlay so the base is backend-agnostic.
- Add compose.rocm.yml overlay with /dev/kfd + /dev/dri passthrough,
  render group, ROCm-specific env vars (HSA_ENABLE_SDMA, GPU_MAX_HW_QUEUES,
  PYTORCH_HIP_ALLOC_CONF), and VRAM_FRACTION/MAX_SEQ_LEN passthrough.
- Use --no-sync in entrypoint scripts to skip unnecessary uv resolution.
- Update uv.lock with ROCm wheel resolution.

Usage:
  BACKEND=rocm UV_EXTRA=rocm72 \
    docker compose -f compose.yml -f compose.rocm.yml --profile webui up

Tested on: RX 9070 (16GB), Ubuntu 24.04, kernel 6.17, ROCm 7.2.

Closes #1246
…mport

- Add LLAMA_CHECKPOINT_PATH, DECODER_CHECKPOINT_PATH, DECODER_CONFIG_NAME
  as runtime environment variables in compose.yml (enables switching
  checkpoints without rebuilding, e.g. INT8 quantized models)
- Fix broken import in quantize tool: load_model was renamed to init_model
- New fish_speech/utils/gpu.py with two helpers:
  - auto_detect_rocm_gfx(): sets HSA_OVERRIDE_GFX_VERSION for GPUs not
    yet in PyTorch's HIP target list (e.g. RDNA 4 / Navi 48)
  - check_vram_and_advise(): estimates memory usage at startup and logs
    actionable suggestions (INT8 quantization, reduce MAX_SEQ_LEN, set
    VRAM_FRACTION) when the model likely won't fit
- Wired into both entry points (run_webui.py, model_manager.py)
Stream transformer layer weights from CPU pinned memory (GTT) to GPU
on demand, keeping KV caches on VRAM for low-latency attention.

Enable with OFFLOAD_WEIGHTS_TO_GTT=true (ROCm only).

How it works:
- After model loads to GPU, weights are moved to pinned CPU memory
- Forward hooks prefetch the next layer via async HIP stream while
  the current layer computes
- KV caches (k_cache, v_cache) stay on GPU
- Fast transformer layers are also offloaded

This enables running models that exceed physical VRAM by trading
PCIe bandwidth for memory capacity.
…tion

Replace subprocess cat of sysfs PCI device ID with native
torch.cuda.get_device_properties().gcnArchName for cleaner,
more portable ROCm gfx arch auto-detection.
Stream transformer layer weights from CPU to GPU on demand using a
LayerStreamer that prefetches the next layer via an async HIP stream
while the current layer computes. KV caches stay on GPU.

Enable with OFFLOAD_WEIGHTS_TO_GTT=true (ROCm only).

Architecture:
- LayerStreamer.run() replaces the layer iteration loop
- _layer_to_gpu/_layer_to_cpu preserve KV caches during moves
- torch.inference_mode(False) wrapper for version counter compat
- Patched forward_generate and forward_generate_fast in llama.py
- Falls through to normal loop when streamer is not attached
… pin_memory

- Renamed env var to accurately reflect CPU offload (not true GTT)
- Pin memory on initial offload for faster DMA transfers
- Skip re-pinning during streaming loop (overhead exceeds benefit)
- Updated compose.rocm.yml and inference.py references
Fast layers are only ~200MB but called 10x per token (once per codebook).
Keeping them on GPU eliminates 40 PCIe round-trips per token.
Also removed torch.cuda.synchronize() as layer.to("cpu") is implicitly sync.
Replace PCIe layer-streaming approach with full CPU execution for
slow transformer layers — eliminates 72 PCIe round-trips per token.
AVX-512 BF16 matmuls on CPU with DDR5 bandwidth (~80-100 GB/s)
handle the compute while only the final hidden state (~10KB)
transfers to GPU for fast layers + decoder.

Key changes:
- CPUOffloadExecutor replaces LayerStreamer (cpu.py)
- Slow layers + shared modules (embeddings, norm, output) move to CPU
- Fast layers stay on GPU (small footprint, called 10x per token)
- INT8 models rejected with guidance (dequant overhead ~30% slower)
- Thread count set to physical cores (HT causes 37% cache contention)
- MIOpen exhaustive kernel search (MIOPEN_FIND_MODE=3) with cache volume
- ulimits nofile 65536 for Triton/MIOpen file descriptor needs
- llama.py forward_generate handles CPU/GPU device transitions inline

Benchmarked on RX 9070 XT (16GB), ROCm 7.2.1:
- Full GPU path: 85s → 34.9s (2.4x faster with MIOpen + compile)
- CPU offload: 100s, VRAM 16GB → 1.94GB
Separate CPU offload work into its own branch (feat/cpu-offload)
to keep the ROCm support PR focused on infrastructure.

This branch now contains only:
- ROCm 7.2 Docker/compose infrastructure
- Device-agnostic code replacements
- VQ-GAN precision pass-through and VRAM_FRACTION cap
- MAX_SEQ_LEN configurable KV cache
- ROCm gfx arch auto-detection and VRAM guidance
- MIOpen exhaustive kernel tuning (MIOPEN_FIND_MODE=3)
- Persistent MIOpen cache volume

Tested on ROCm 7.2.1 with RX 9070 XT (16GB):
- Full GPU INT8 + COMPILE=1 + MIOpen tuning: 34.9s (was 85s on 7.2.0)
- 2.4x performance improvement from ROCm 7.2.1 HIP/MIOpen fixes
@gtherond gtherond force-pushed the feat/rocm-support branch from c758595 to 8ec8d21 Compare March 26, 2026 13:40
@gtherond
Copy link
Copy Markdown
Contributor Author

Alright, sorry for the commit noise, I forget to swap my branch in between two features improvement.

So, this request brings ROCm 7.2.1 support for fish through pytorch without breaking anything CUDA or CPU related.

I've pinned pytorch to 2.8.0 for CUDA/CPU related profiles but create a proper rocm72 profile for Pytorch 2.11.0 requirements that comes with ROCm 7.2.x support from official AMD repository.

Does it needs something more?

model = instantiate(cfg)
state_dict = torch.load(
checkpoint_path, map_location=device, mmap=True, weights_only=True
checkpoint_path, map_location="cpu", mmap=True, weights_only=True
Copy link
Copy Markdown
Contributor Author

@gtherond gtherond Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of this is to avoid having to load the weights twice on GPU and avoid waste of memory for restrained vram on GA cards.

As instantiate(cfg) creates the model with random weights and then load_state_dict() replace them with checkpoints weights, there is a short time where it use twice as much memory size as needed.

By using map_location="cpu" we do perform the swap directly on system memory and THEN we load the definitive structure to the GPU device, this avoid having either the random init or the checkpoint to reach GPU before the final .to(device) call.

Using this method avoid GA restrained VRAM cards to OOM while matching how llama is loaded too.

Comment thread fish_speech/utils/gpu.py

# Known ROCm gfx arch overrides for GPUs not yet in PyTorch's HIP target list.
# Maps gcnArchName to the closest supported HSA_OVERRIDE_GFX_VERSION.
_ROCM_GFX_OVERRIDES = {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be removed as 7.2.1 introduced full blown support of GFX1201 but still there as a safety net for users who would build using 7.2.0

It's your call to tell me if you want me to get ride of it.

@gtherond gtherond changed the title feat: add AMD ROCm 7.2 support feat: add AMD ROCm 7.2.x support Mar 26, 2026
@Whale-Dolphin
Copy link
Copy Markdown
Member

Unfortunately, we do not have any AMD GPUs available for testing. It may need further testing from the community before it can be merged.

@gtherond
Copy link
Copy Markdown
Contributor Author

Unfortunately, we do not have any AMD GPUs available for testing. It may need further testing from the community before it can be merged.

I've a RX6600 too if you need me to test it on a previous generation.

But yeah I do understand.

In the meantime, I've forked the repository in order to work deeper on the AMD compatibility layer (MIOpen specifically) as F.conv1d / F.conv_transpose1d via torch.nn.functional.scaled_dot_product_attention path get a known performance issue as pytorch workspace allocation with the MIOPen kernel is broken so far.

Let me know if I can help in any way!

@gtherond gtherond closed this Mar 31, 2026
@gtherond gtherond deleted the feat/rocm-support branch March 31, 2026 17:40
@pete-h
Copy link
Copy Markdown

pete-h commented Apr 14, 2026

I can't seem to find the fork anymore...
I have a 9070 XT and a 9060 XT and could help with testing...

@gtherond
Copy link
Copy Markdown
Contributor Author

I can't seem to find the fork anymore... I have a 9070 XT and a 9060 XT and could help with testing...

In here hard forked.
https://github.com/imagilux/fish-speech/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

3 participants