|
| 1 | +# docker/Dockerfile.rocm |
| 2 | +# |
| 3 | +# Fish Speech on AMD ROCm (RDNA3 / RDNA4). |
| 4 | +# The checkpoints are NOT bundled — mount them at /app/checkpoints. |
| 5 | +# |
| 6 | +# Build: |
| 7 | +# docker build -f docker/Dockerfile.rocm --target webui -t fish-speech-webui:rocm . |
| 8 | +# docker build -f docker/Dockerfile.rocm --target server -t fish-speech-server:rocm . |
| 9 | +# |
| 10 | +# Run (webui): |
| 11 | +# docker run --device=/dev/kfd --device=/dev/dri \ |
| 12 | +# --group-add video --group-add render \ |
| 13 | +# -e ROCBLAS_USE_HIPBLASLT=0 \ |
| 14 | +# -v ./checkpoints:/app/checkpoints \ |
| 15 | +# -p 7860:7860 fish-speech-webui:rocm |
| 16 | + |
| 17 | +ARG ROCM_VERSION=7.2.3 |
| 18 | +ARG BASE_IMAGE=rocm/pytorch:rocm${ROCM_VERSION}_ubuntu24.04_py3.12_pytorch_release_2.9.1 |
| 19 | + |
| 20 | +FROM ${BASE_IMAGE} AS app-base |
| 21 | + |
| 22 | +ENV DEBIAN_FRONTEND=noninteractive \ |
| 23 | + PYTHONDONTWRITEBYTECODE=1 \ |
| 24 | + PYTHONUNBUFFERED=1 \ |
| 25 | + ROCBLAS_USE_HIPBLASLT=0 |
| 26 | + |
| 27 | +RUN apt-get update \ |
| 28 | + && apt-get install -y --no-install-recommends \ |
| 29 | + git ffmpeg libsox-dev build-essential cmake \ |
| 30 | + libasound-dev portaudio19-dev libportaudio2 libportaudiocpp0 \ |
| 31 | + && apt-get clean \ |
| 32 | + && rm -rf /var/lib/apt/lists/* |
| 33 | + |
| 34 | +WORKDIR /app |
| 35 | + |
| 36 | +COPY . /app |
| 37 | + |
| 38 | +# Install runtime dependencies WITHOUT torch/torchaudio — the ROCm base image |
| 39 | +# already ships a gfx-tuned torch (2.9.1+rocm7.2.3). Then install the package |
| 40 | +# itself with --no-deps so pip does not try to pull a CUDA/CPU torch. |
| 41 | +RUN pip install --no-cache-dir --upgrade pip setuptools wheel \ |
| 42 | + && pip install --no-cache-dir \ |
| 43 | + numpy "transformers<=4.57.3" datasets lightning pytorch_lightning \ |
| 44 | + hydra-core natsort einops librosa rich "gradio>5.0.0" wandb grpcio kui \ |
| 45 | + uvicorn loguru loralib pyrootutils resampy "einx[torch]==0.2.2" zstandard \ |
| 46 | + pydub "modelscope==1.17.1" "opencc-python-reimplemented==0.1.7" \ |
| 47 | + silero-vad ormsgpack tiktoken "pydantic==2.9.2" cachetools \ |
| 48 | + descript-audio-codec safetensors soundfile vector_quantize_pytorch \ |
| 49 | + && pip install --no-cache-dir --no-build-isolation pyaudio \ |
| 50 | + && pip install --no-cache-dir --no-deps -e . \ |
| 51 | + # descript-audiotools pins protobuf<3.20, but fish-speech's generated proto |
| 52 | + # code needs >=3.20. Override after install (mirrors pyproject's uv override). |
| 53 | + && pip install --no-cache-dir --no-deps --upgrade "protobuf>=4.25,<6.0" |
| 54 | + |
| 55 | +EXPOSE 7860 8080 |
| 56 | + |
| 57 | +# torch.compile is enabled by default (verified working on gfx1201/RDNA4). |
| 58 | +# Set COMPILE=0 to disable. |
| 59 | +ENV COMPILE=1 |
| 60 | + |
| 61 | +############################################################## |
| 62 | +# Gradio WebUI |
| 63 | +############################################################## |
| 64 | +FROM app-base AS webui |
| 65 | + |
| 66 | +ARG GRADIO_SERVER_NAME="0.0.0.0" |
| 67 | +ARG GRADIO_SERVER_PORT=7860 |
| 68 | +ENV GRADIO_SERVER_NAME=${GRADIO_SERVER_NAME} \ |
| 69 | + GRADIO_SERVER_PORT=${GRADIO_SERVER_PORT} |
| 70 | + |
| 71 | +RUN printf '%s\n' \ |
| 72 | + '#!/bin/bash' \ |
| 73 | + 'set -e' \ |
| 74 | + 'ARGS=()' \ |
| 75 | + 'if [ "${COMPILE:-0}" = "1" ] || [ "${COMPILE:-}" = "true" ]; then ARGS+=(--compile); fi' \ |
| 76 | + 'exec python tools/run_webui.py \' \ |
| 77 | + ' --llama-checkpoint-path checkpoints/s2-pro \' \ |
| 78 | + ' --decoder-checkpoint-path checkpoints/s2-pro/codec.pth \' \ |
| 79 | + ' --decoder-config-name modded_dac_vq "${ARGS[@]}"' \ |
| 80 | + > /app/start_webui.sh && chmod +x /app/start_webui.sh |
| 81 | + |
| 82 | +ENTRYPOINT ["/app/start_webui.sh"] |
| 83 | + |
| 84 | +############################################################## |
| 85 | +# API Server |
| 86 | +############################################################## |
| 87 | +FROM app-base AS server |
| 88 | + |
| 89 | +ARG API_SERVER_NAME="0.0.0.0" |
| 90 | +ARG API_SERVER_PORT=8080 |
| 91 | +ENV API_SERVER_NAME=${API_SERVER_NAME} \ |
| 92 | + API_SERVER_PORT=${API_SERVER_PORT} |
| 93 | + |
| 94 | +RUN printf '%s\n' \ |
| 95 | + '#!/bin/bash' \ |
| 96 | + 'set -e' \ |
| 97 | + 'ARGS=()' \ |
| 98 | + 'if [ "${COMPILE:-0}" = "1" ] || [ "${COMPILE:-}" = "true" ]; then ARGS+=(--compile); fi' \ |
| 99 | + 'exec python tools/api_server.py \' \ |
| 100 | + ' --listen 0.0.0.0:8080 \' \ |
| 101 | + ' --llama-checkpoint-path checkpoints/s2-pro \' \ |
| 102 | + ' --decoder-checkpoint-path checkpoints/s2-pro/codec.pth \' \ |
| 103 | + ' --decoder-config-name modded_dac_vq "${ARGS[@]}"' \ |
| 104 | + > /app/start_server.sh && chmod +x /app/start_server.sh |
| 105 | + |
| 106 | +ENTRYPOINT ["/app/start_server.sh"] |
0 commit comments