Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions compose.base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,15 @@ services:
context: .
dockerfile: docker/Dockerfile
args:
BACKEND: ${BACKEND:-cuda} # or cpu
BACKEND: ${BACKEND:-cuda} # cuda, cpu, or rocm
CUDA_VER: ${CUDA_VER:-12.9.0}
UV_EXTRA: ${UV_EXTRA:-cu129}
ROCM_VER: ${ROCM_VER:-7.2}
UV_EXTRA: ${UV_EXTRA:-cu129} # cu126, cu128, cu129, rocm72, or cpu
UV_VERSION: ${UV_VERSION:-0.8.15}
volumes:
- ./checkpoints:/app/checkpoints
- ./references:/app/references
environment:
COMPILE: ${COMPILE:-0}
# GPU (remove this block if CPU-only):
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
tty: true
stdin_open: true
20 changes: 20 additions & 0 deletions compose.cuda.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# CUDA (NVIDIA GPU) overlay.
#
# Layer this file on top of the main compose file, e.g.:
#   docker compose -f compose.yml -f compose.cuda.yml --profile webui up
#
# Both services get the same reservation: every NVIDIA device on the host,
# exposed with the "gpu" capability.
services:
  webui:
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities:
                - gpu

  server:
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities:
                - gpu
64 changes: 64 additions & 0 deletions compose.rocm.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# ROCm (AMD GPU) overlay — use with:
#   BACKEND=rocm UV_EXTRA=rocm72 docker compose -f compose.yml -f compose.rocm.yml --profile webui up
#
# Requires: /dev/kfd and /dev/dri on the host (ROCm kernel driver installed)
# RENDER_GID should match the host's render group GID (run: stat -c '%g' /dev/kfd)
# For the AMD Container Toolkit alternative (runtime: amd), see:
#   https://instinct.docs.amd.com/projects/container-toolkit/en/latest/container-runtime/docker-compose.html

# Settings shared by every ROCm-backed service. They are declared once as an
# extension field (top-level "x-" keys are ignored by Compose) and aliased by
# both services below, so webui and server cannot drift apart.
x-rocm-service: &rocm-service
  # Pass the ROCm compute (kfd) and render (dri) device nodes through.
  devices:
    - "/dev/kfd:/dev/kfd"
    - "/dev/dri:/dev/dri"
  # Supplementary groups so the container user may open the device nodes.
  group_add:
    - video
    - "${RENDER_GID:-993}"
  security_opt:
    - seccomp=unconfined
  # NOTE(review): with `ipc: host` the container presumably shares the host's
  # /dev/shm, in which case `shm_size` has no effect — confirm and drop one.
  ipc: host
  shm_size: 8G
  ulimits:
    nofile:
      soft: 65536
      hard: 65536
  # Named volume mounted at the MIOpen config directory of the `fish` user.
  volumes:
    - "miopen-cache:/home/fish/.config/miopen"
  # All values are quoted so they reach the container as strings.
  environment:
    HSA_ENABLE_SDMA: "0"
    GPU_MAX_HW_QUEUES: "1"
    HSA_USE_SVM: "0"
    PYTORCH_HIP_ALLOC_CONF: "garbage_collection_threshold:0.7,max_split_size_mb:4096"
    MIOPEN_FIND_MODE: "${MIOPEN_FIND_MODE:-3}"
    VRAM_FRACTION: "${VRAM_FRACTION:-0}"
    MAX_SEQ_LEN: "${MAX_SEQ_LEN:-32768}"
    OFFLOAD_WEIGHTS_TO_CPU: "${OFFLOAD_WEIGHTS_TO_CPU:-false}"

services:
  webui: *rocm-service
  server: *rocm-service

volumes:
  # Empty mapping = named volume with default driver and options.
  miopen-cache: {}
6 changes: 6 additions & 0 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ services:
target: webui
environment:
COMPILE: ${COMPILE:-0}
LLAMA_CHECKPOINT_PATH: ${LLAMA_CHECKPOINT_PATH:-checkpoints/s2-pro}
DECODER_CHECKPOINT_PATH: ${DECODER_CHECKPOINT_PATH:-checkpoints/s2-pro/codec.pth}
DECODER_CONFIG_NAME: ${DECODER_CONFIG_NAME:-modded_dac_vq}
profiles: ["webui"]
ports:
- "${GRADIO_PORT:-7860}:7860"
Expand All @@ -21,6 +24,9 @@ services:
target: server
environment:
COMPILE: ${COMPILE:-0}
LLAMA_CHECKPOINT_PATH: ${LLAMA_CHECKPOINT_PATH:-checkpoints/s2-pro}
DECODER_CHECKPOINT_PATH: ${DECODER_CHECKPOINT_PATH:-checkpoints/s2-pro/codec.pth}
DECODER_CONFIG_NAME: ${DECODER_CONFIG_NAME:-modded_dac_vq}
profiles: ["server"]
ports:
- "${API_PORT:-8080}:8080"
32 changes: 27 additions & 5 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
# docker build \
# --platform linux/amd64 \
# -f docker/Dockerfile \
# --build-arg BACKEND=[cuda, cpu] \
# --build-arg BACKEND=[cuda, cpu, rocm] \
# --target [webui, server] \
# -t fish-speech-[webui, server]:[cuda, cpu] .
# -t fish-speech-[webui, server]:[cuda, cpu, rocm] .

# e.g. for building the webui:
# docker build \
Expand Down Expand Up @@ -62,9 +62,11 @@

# Select the specific cuda version (see https://hub.docker.com/r/nvidia/cuda/)
ARG CUDA_VER=12.9.0
# Adapt the uv extra to fit the cuda version (one of [cu126, cu128, cu129])
# Adapt the uv extra to fit the backend (one of [cu126, cu128, cu129, rocm72, cpu])
ARG UV_EXTRA=cu129
ARG BACKEND=cuda
# ROCm version (see https://hub.docker.com/r/rocm/dev-ubuntu-24.04)
ARG ROCM_VER=7.2

ARG UBUNTU_VER=24.04
ARG PY_VER=3.12
Expand All @@ -84,6 +86,26 @@ FROM nvidia/cuda:${CUDA_VER}-cudnn-runtime-ubuntu${UBUNTU_VER} AS base-cuda
ENV DEBIAN_FRONTEND=noninteractive

# Install system dependencies in a single layer with cleanup
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
set -eux \
&& rm -f /etc/apt/apt.conf.d/docker-clean \
&& echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
python3-pip \
python3-dev \
git \
ca-certificates \
curl \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

# --- ROCm (AMD GPU, x86_64) ---
FROM rocm/dev-ubuntu-${UBUNTU_VER}:${ROCM_VER} AS base-rocm
ENV DEBIAN_FRONTEND=noninteractive
ENV UV_EXTRA=rocm72

RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
set -eux \
Expand Down Expand Up @@ -311,7 +333,7 @@ RUN printf '%s\n' \
'log "Compile args: ${COMPILE_ARGS}"' \
'log "Server: ${GRADIO_SERVER_NAME}:${GRADIO_SERVER_PORT}"' \
'' \
'exec uv run tools/run_webui.py \' \
'exec uv run --no-sync tools/run_webui.py \' \
' --llama-checkpoint-path "${LLAMA_CHECKPOINT_PATH}" \' \
' --decoder-checkpoint-path "${DECODER_CHECKPOINT_PATH}" \' \
' --decoder-config-name "${DECODER_CONFIG_NAME}" \' \
Expand Down Expand Up @@ -359,7 +381,7 @@ RUN printf '%s\n' \
'log "Compile args: ${COMPILE_ARGS}"' \
'log "Server: ${API_SERVER_NAME}:${API_SERVER_PORT}"' \
'' \
'exec uv run tools/api_server.py \' \
'exec uv run --no-sync tools/api_server.py \' \
' --listen "${API_SERVER_NAME}:${API_SERVER_PORT}" \' \
' --llama-checkpoint-path "${LLAMA_CHECKPOINT_PATH}" \' \
' --decoder-checkpoint-path "${DECODER_CHECKPOINT_PATH}" \' \
Expand Down
4 changes: 3 additions & 1 deletion fish_speech/inference_engine/reference_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def __init__(self) -> None:
# torchaudio 2.9+ removed list_audio_backends()
# Try ffmpeg first, fallback to soundfile
try:
import torchaudio.io._load_audio_fileobj # noqa: F401
from importlib import import_module

import_module("torchaudio.io._load_audio_fileobj")

self.backend = "ffmpeg"
except (ImportError, ModuleNotFoundError):
Expand Down
9 changes: 6 additions & 3 deletions fish_speech/models/dac/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
OmegaConf.register_new_resolver("eval", eval)


def load_model(config_name, checkpoint_path, device="cuda"):
def load_model(config_name, checkpoint_path, device="cuda", precision=None):
hydra.core.global_hydra.GlobalHydra.instance().clear()
with initialize(version_base="1.3", config_path="../../configs"):
cfg = compose(config_name=config_name)

model = instantiate(cfg)
state_dict = torch.load(
checkpoint_path, map_location=device, mmap=True, weights_only=True
checkpoint_path, map_location="cpu", mmap=True, weights_only=True
Copy link
Copy Markdown
Contributor Author

@gtherond gtherond Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of this change is to avoid holding two copies of the weights on the GPU at the same time, which wastes memory on GA cards with constrained VRAM.

Because instantiate(cfg) creates the model with random weights and load_state_dict() then replaces them with the checkpoint weights, there is a short window during which twice as much memory as needed is in use.

By using map_location="cpu", we perform the swap directly in system memory and only THEN move the final model to the GPU device. This keeps both the random initialization and the checkpoint off the GPU until the final .to(device) call.

This prevents OOM on GA cards with constrained VRAM, and it also matches how the llama model is loaded.

)
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
Expand All @@ -41,7 +41,10 @@ def load_model(config_name, checkpoint_path, device="cuda"):

result = model.load_state_dict(state_dict, strict=False, assign=True)
model.eval()
model.to(device)
if precision is not None:
model.to(device=device, dtype=precision)
else:
model.to(device)

logger.info(f"Loaded model: {result}")
return model
Expand Down
5 changes: 3 additions & 2 deletions fish_speech/models/dac/modded_dac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,13 +1015,14 @@ def forward(
model = hydra.utils.instantiate(OmegaConf.load(config_path))
new_sd = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(new_sd, strict=False)
model.cuda()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# 2. 加载外部 codes (.npy)
# 预期 shape 通常为 [num_codebooks, seq_len] 或 [1, num_codebooks, seq_len]
codes_np = np.load(codes_path)
codes_tensor = torch.from_numpy(codes_np).to(torch.long).cuda()
codes_tensor = torch.from_numpy(codes_np).to(torch.long).to(device)

# 如果 codes 没有 batch 维度,增加一个维度 [1, num_codebooks, seq_len]
if len(codes_tensor.shape) == 2:
Expand Down
8 changes: 6 additions & 2 deletions fish_speech/models/text2semantic/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,11 @@ def generate(

# Critical fix: Only set up cache on first run or when necessary
if not hasattr(model, "_cache_setup_done") or not model._cache_setup_done:
max_seq_len = int(os.environ.get("MAX_SEQ_LEN", model.config.max_seq_len))
with torch.device(device):
model.setup_caches(
max_batch_size=1, # Fixed to 1, avoid dynamic changes
max_seq_len=model.config.max_seq_len,
max_seq_len=max_seq_len,
dtype=next(model.parameters()).dtype,
)
model._cache_setup_done = True
Expand Down Expand Up @@ -758,12 +759,15 @@ def worker():
model, decode_one_token = init_model(
checkpoint_path, device, precision, compile=compile
)

max_seq_len = int(os.environ.get("MAX_SEQ_LEN", model.config.max_seq_len))
with torch.device(device):
model.setup_caches(
max_batch_size=1,
max_seq_len=model.config.max_seq_len,
max_seq_len=max_seq_len,
dtype=next(model.parameters()).dtype,
)

init_event.set()

while True:
Expand Down
107 changes: 107 additions & 0 deletions fish_speech/utils/gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""GPU detection, VRAM guidance, and ROCm gfx arch auto-detection."""

import os

import torch
from loguru import logger

# Known ROCm gfx arch overrides for GPUs not yet in PyTorch's HIP target list.
# Maps gcnArchName to the closest supported HSA_OVERRIDE_GFX_VERSION.
_ROCM_GFX_OVERRIDES = {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be removed, since 7.2.1 introduced full-blown support for GFX1201, but it is still here as a safety net for users who build with 7.2.0.

It's your call: tell me if you want me to get rid of it.

"gfx1201": "12.0.0", # Navi 48 — RX 9070/9070 XT → fallback to gfx1200
}

# Approximate model memory requirements (in GB) for VRAM guidance.
# These are rough figures that only feed the estimate logged by
# check_vram_and_advise(); they are not limits enforced anywhere in this module.
# Main model weights, selected by whether "int8" appears in the checkpoint path.
_MODEL_ESTIMATE_BF16 = 10.3
_MODEL_ESTIMATE_INT8 = 5.1
# Decoder weights. check_vram_and_advise() always uses the bf16 figure.
_DECODER_ESTIMATE_BF16 = 3.6
# NOTE(review): not referenced in the visible part of this module — confirm it
# is used elsewhere or intended for a future INT8 decoder path.
_DECODER_ESTIMATE_INT8 = 1.8


def _is_rocm() -> bool:
"""Check if running on ROCm (AMD HIP backend)."""
return (
torch.cuda.is_available()
and hasattr(torch.version, "hip")
and torch.version.hip is not None
)


def auto_detect_rocm_gfx():
    """Set HSA_OVERRIDE_GFX_VERSION if running on an unrecognized AMD GPU.

    Only acts when:
    - Running on ROCm (HIP backend)
    - HSA_OVERRIDE_GFX_VERSION is not already set (an empty value counts as
      unset)
    - The GPU's gcnArchName matches a known override

    Only device 0 is inspected. The reported arch name is reduced to its bare
    form (everything before the first ``:``) before the lookup, so a name
    carrying target-feature suffixes such as ``gfx90a:sramecc+:xnack-`` still
    matches an override keyed by the plain arch.

    Returns:
        None. The only effects are the environment variable and one log line.
    """
    if not _is_rocm():
        return
    # Respect an explicit user choice; never overwrite it.
    if os.environ.get("HSA_OVERRIDE_GFX_VERSION"):
        return

    props = torch.cuda.get_device_properties(0)
    arch = getattr(props, "gcnArchName", None)
    if not arch:
        return

    # The override table is keyed by the bare arch ("gfx1201"); strip any
    # ":feature+/-" suffixes the runtime appends to the name.
    base_arch = str(arch).split(":", 1)[0].strip()

    gfx_ver = _ROCM_GFX_OVERRIDES.get(base_arch)
    if gfx_ver is None:
        return

    # NOTE(review): the device has already been queried at this point; confirm
    # that the HSA runtime still honours the variable when it is set this late
    # (it may only affect child processes or later initialisation).
    os.environ["HSA_OVERRIDE_GFX_VERSION"] = gfx_ver
    logger.info(
        f"Auto-detected AMD GPU arch {arch}, "
        f"setting HSA_OVERRIDE_GFX_VERSION={gfx_ver}"
    )


def check_vram_and_advise(checkpoint_path: str):
    """Log an estimate of VRAM usage and warn when it exceeds the GPU's total.

    The estimate is the sum of the model weights (INT8 figure when "int8"
    appears in ``checkpoint_path``, bf16 figure otherwise), the bf16 decoder
    figure, a KV-cache term proportional to the ``MAX_SEQ_LEN`` environment
    variable (default "32768"), and a fixed 0.5 GB overhead. Only device 0 is
    inspected. Does nothing when ``torch.cuda`` reports no available device.

    Args:
        checkpoint_path: Checkpoint location; only its string form is examined
            for the substring "int8".

    Raises:
        ValueError: If ``MAX_SEQ_LEN`` is set to something ``int()`` rejects.
    """
    if not torch.cuda.is_available():
        return

    device_props = torch.cuda.get_device_properties(0)
    # Decimal gigabytes (1e9 bytes), not GiB.
    vram_total_gb = device_props.total_memory / 1e9

    uses_int8 = "int8" in str(checkpoint_path)
    seq_len = int(os.environ.get("MAX_SEQ_LEN", "32768"))

    if uses_int8:
        weights_gb, weights_label = _MODEL_ESTIMATE_INT8, "INT8"
    else:
        weights_gb, weights_label = _MODEL_ESTIMATE_BF16, "bf16"
    codec_gb = _DECODER_ESTIMATE_BF16
    # KV cache: ~1.2GB at 8192 tokens, scaled linearly with sequence length.
    kv_cache_gb = (seq_len / 8192) * 1.2
    # Flat allowance for inference scratch space / activations.
    scratch_gb = 0.5

    needed_gb = weights_gb + codec_gb + kv_cache_gb + scratch_gb

    logger.info(
        f"GPU: {device_props.name}, VRAM: {vram_total_gb:.1f}GB | "
        f"Estimated usage: {needed_gb:.1f}GB "
        f"(model={weights_label}, "
        f"seq_len={seq_len}, decoder=bf16)"
    )

    # Everything fits: the info line above is all we emit.
    if needed_gb <= vram_total_gb:
        return

    tips = []
    if not uses_int8:
        tips.append(
            "quantize to INT8 (saves ~5GB): "
            "python tools/llama/quantize.py --checkpoint-path <path> --mode int8"
        )
    if seq_len > 4096:
        # KV-cache saving from dropping to 4096 tokens, same linear model.
        kv_saving_gb = (seq_len - 4096) / 8192 * 1.2
        tips.append(
            f"reduce MAX_SEQ_LEN (current: {seq_len}, try 4096 to save ~{kv_saving_gb:.1f}GB)"
        )
    tips.append("set VRAM_FRACTION=0.95 to prevent system freeze on OOM")

    deficit_gb = needed_gb - vram_total_gb
    logger.warning(
        f"Estimated VRAM ({needed_gb:.1f}GB) exceeds available ({vram_total_gb:.1f}GB) "
        f"by {deficit_gb:.1f}GB. Suggestions:"
    )
    for number, tip in enumerate(tips, 1):
        logger.warning(f" {number}. {tip}")
Loading