-
Notifications
You must be signed in to change notification settings - Fork 2.6k
feat: add AMD ROCm 7.2.x support #1247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
+1,082
−857
Closed
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
a2f0df6
refactor: replace hardcoded .cuda() calls with device-agnostic altern…
gtherond 4b72575
feat: pass precision to VQ-GAN decoder, add VRAM_FRACTION cap
gtherond badc4c0
feat: add MAX_SEQ_LEN env var for configurable KV cache size
gtherond b5cbbaf
fix: resolve UnboundLocalError in torchaudio backend detection
gtherond 5b35ba7
feat: add AMD ROCm 7.2 support for Docker builds
gtherond 69a9ee7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8a27069
fix: runtime env overrides for checkpoint paths and broken quantize i…
gtherond 0e70dae
feat: add VRAM guidance and auto-detect ROCm gfx arch
gtherond 798d554
feat: add GTT weight offloading for VRAM-constrained AMD GPUs
gtherond 18c87b4
refactor: use PyTorch gcnArchName instead of subprocess for gfx detec…
gtherond 51fa6b0
feat: add GTT weight offloading for VRAM-constrained AMD GPUs
gtherond c64d574
refactor: rename OFFLOAD_WEIGHTS_TO_GTT → OFFLOAD_WEIGHTS_TO_CPU, add…
gtherond 8a6e55e
perf: keep fast layers on GPU, remove redundant synchronize
gtherond 99b370f
perf: CPU offload executor, MIOpen tuning, INT8 guard
gtherond 8ec8d21
refactor: remove CPU offload from ROCm branch, bump ROCm 7.2.1
gtherond f291880
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| # CUDA (NVIDIA GPU) overlay — use with: | ||
| # docker compose -f compose.yml -f compose.cuda.yml --profile webui up | ||
| services: | ||
| webui: | ||
| deploy: | ||
| resources: | ||
| reservations: | ||
| devices: | ||
| - driver: nvidia | ||
| count: all | ||
| capabilities: [gpu] | ||
|
|
||
| server: | ||
| deploy: | ||
| resources: | ||
| reservations: | ||
| devices: | ||
| - driver: nvidia | ||
| count: all | ||
| capabilities: [gpu] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| # ROCm (AMD GPU) overlay — use with: | ||
| # BACKEND=rocm UV_EXTRA=rocm72 docker compose -f compose.yml -f compose.rocm.yml --profile webui up | ||
| # | ||
| # Requires: /dev/kfd and /dev/dri on the host (ROCm kernel driver installed) | ||
| # RENDER_GID should match the host's render group GID (run: stat -c '%g' /dev/kfd) | ||
| # For the AMD Container Toolkit alternative (runtime: amd), see: | ||
| # https://instinct.docs.amd.com/projects/container-toolkit/en/latest/container-runtime/docker-compose.html | ||
| services: | ||
| webui: | ||
| devices: | ||
| - /dev/kfd:/dev/kfd | ||
| - /dev/dri:/dev/dri | ||
| group_add: | ||
| - video | ||
| - "${RENDER_GID:-993}" | ||
| security_opt: | ||
| - seccomp=unconfined | ||
| ipc: host | ||
| shm_size: 8G | ||
| ulimits: | ||
| nofile: | ||
| soft: 65536 | ||
| hard: 65536 | ||
| volumes: | ||
| - miopen-cache:/home/fish/.config/miopen | ||
| environment: | ||
| HSA_ENABLE_SDMA: "0" | ||
| GPU_MAX_HW_QUEUES: "1" | ||
| HSA_USE_SVM: "0" | ||
| PYTORCH_HIP_ALLOC_CONF: "garbage_collection_threshold:0.7,max_split_size_mb:4096" | ||
| MIOPEN_FIND_MODE: "${MIOPEN_FIND_MODE:-3}" | ||
| VRAM_FRACTION: "${VRAM_FRACTION:-0}" | ||
| MAX_SEQ_LEN: "${MAX_SEQ_LEN:-32768}" | ||
| OFFLOAD_WEIGHTS_TO_CPU: "${OFFLOAD_WEIGHTS_TO_CPU:-false}" | ||
|
|
||
| server: | ||
| devices: | ||
| - /dev/kfd:/dev/kfd | ||
| - /dev/dri:/dev/dri | ||
| group_add: | ||
| - video | ||
| - "${RENDER_GID:-993}" | ||
| security_opt: | ||
| - seccomp=unconfined | ||
| ipc: host | ||
| shm_size: 8G | ||
| ulimits: | ||
| nofile: | ||
| soft: 65536 | ||
| hard: 65536 | ||
| volumes: | ||
| - miopen-cache:/home/fish/.config/miopen | ||
| environment: | ||
| HSA_ENABLE_SDMA: "0" | ||
| GPU_MAX_HW_QUEUES: "1" | ||
| HSA_USE_SVM: "0" | ||
| PYTORCH_HIP_ALLOC_CONF: "garbage_collection_threshold:0.7,max_split_size_mb:4096" | ||
| MIOPEN_FIND_MODE: "${MIOPEN_FIND_MODE:-3}" | ||
| VRAM_FRACTION: "${VRAM_FRACTION:-0}" | ||
| MAX_SEQ_LEN: "${MAX_SEQ_LEN:-32768}" | ||
| OFFLOAD_WEIGHTS_TO_CPU: "${OFFLOAD_WEIGHTS_TO_CPU:-false}" | ||
|
|
||
| volumes: | ||
| miopen-cache: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| """GPU detection, VRAM guidance, and ROCm gfx arch auto-detection.""" | ||
|
|
||
| import os | ||
|
|
||
| import torch | ||
| from loguru import logger | ||
|
|
||
| # Known ROCm gfx arch overrides for GPUs not yet in PyTorch's HIP target list. | ||
| # Maps gcnArchName to the closest supported HSA_OVERRIDE_GFX_VERSION. | ||
| _ROCM_GFX_OVERRIDES = { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could be removed as 7.2.1 introduced full blown support of GFX1201 but still there as a safety net for users who would build using 7.2.0 It's your call to tell me if you want me to get ride of it. |
||
| "gfx1201": "12.0.0", # Navi 48 — RX 9070/9070 XT → fallback to gfx1200 | ||
| } | ||
|
|
||
| # Approximate model memory requirements (in GB) for VRAM guidance. | ||
| _MODEL_ESTIMATE_BF16 = 10.3 | ||
| _MODEL_ESTIMATE_INT8 = 5.1 | ||
| _DECODER_ESTIMATE_BF16 = 3.6 | ||
| _DECODER_ESTIMATE_INT8 = 1.8 | ||
|
|
||
|
|
||
| def _is_rocm() -> bool: | ||
| """Check if running on ROCm (AMD HIP backend).""" | ||
| return ( | ||
| torch.cuda.is_available() | ||
| and hasattr(torch.version, "hip") | ||
| and torch.version.hip is not None | ||
| ) | ||
|
|
||
|
|
||
| def auto_detect_rocm_gfx(): | ||
| """Set HSA_OVERRIDE_GFX_VERSION if running on an unrecognized AMD GPU. | ||
|
|
||
| Only acts when: | ||
| - Running on ROCm (HIP backend) | ||
| - HSA_OVERRIDE_GFX_VERSION is not already set | ||
| - The GPU's gcnArchName matches a known override | ||
| """ | ||
| if not _is_rocm(): | ||
| return | ||
| if os.environ.get("HSA_OVERRIDE_GFX_VERSION"): | ||
| return | ||
|
|
||
| props = torch.cuda.get_device_properties(0) | ||
| arch = getattr(props, "gcnArchName", None) | ||
| if arch is None: | ||
| return | ||
|
|
||
| gfx_ver = _ROCM_GFX_OVERRIDES.get(arch) | ||
| if gfx_ver is not None: | ||
| os.environ["HSA_OVERRIDE_GFX_VERSION"] = gfx_ver | ||
| logger.info( | ||
| f"Auto-detected AMD GPU arch {arch}, " | ||
| f"setting HSA_OVERRIDE_GFX_VERSION={gfx_ver}" | ||
| ) | ||
|
|
||
|
|
||
| def check_vram_and_advise(checkpoint_path: str): | ||
| """Log VRAM guidance if the model may not fit. | ||
|
|
||
| Estimates memory usage based on whether INT8 quantization is active | ||
| and the configured MAX_SEQ_LEN, then compares against available VRAM. | ||
| """ | ||
| if not torch.cuda.is_available(): | ||
| return | ||
|
|
||
| props = torch.cuda.get_device_properties(0) | ||
| total_gb = props.total_memory / 1e9 | ||
|
|
||
| is_int8 = "int8" in str(checkpoint_path) | ||
| max_seq_len = int(os.environ.get("MAX_SEQ_LEN", "32768")) | ||
|
|
||
| model_gb = _MODEL_ESTIMATE_INT8 if is_int8 else _MODEL_ESTIMATE_BF16 | ||
| decoder_gb = _DECODER_ESTIMATE_BF16 | ||
| # KV cache: ~1.2GB at 8192, scales linearly | ||
| kv_gb = (max_seq_len / 8192) * 1.2 | ||
| # Inference scratch/activations overhead | ||
| overhead_gb = 0.5 | ||
|
|
||
| estimated_gb = model_gb + decoder_gb + kv_gb + overhead_gb | ||
|
|
||
| logger.info( | ||
| f"GPU: {props.name}, VRAM: {total_gb:.1f}GB | " | ||
| f"Estimated usage: {estimated_gb:.1f}GB " | ||
| f"(model={'INT8' if is_int8 else 'bf16'}, " | ||
| f"seq_len={max_seq_len}, decoder=bf16)" | ||
| ) | ||
|
|
||
| if estimated_gb > total_gb: | ||
| shortfall = estimated_gb - total_gb | ||
| suggestions = [] | ||
| if not is_int8: | ||
| suggestions.append( | ||
| "quantize to INT8 (saves ~5GB): " | ||
| "python tools/llama/quantize.py --checkpoint-path <path> --mode int8" | ||
| ) | ||
| if max_seq_len > 4096: | ||
| suggestions.append( | ||
| f"reduce MAX_SEQ_LEN (current: {max_seq_len}, try 4096 to save ~{(max_seq_len - 4096) / 8192 * 1.2:.1f}GB)" | ||
| ) | ||
| suggestions.append("set VRAM_FRACTION=0.95 to prevent system freeze on OOM") | ||
|
|
||
| logger.warning( | ||
| f"Estimated VRAM ({estimated_gb:.1f}GB) exceeds available ({total_gb:.1f}GB) " | ||
| f"by {shortfall:.1f}GB. Suggestions:" | ||
| ) | ||
| for i, s in enumerate(suggestions, 1): | ||
| logger.warning(f" {i}. {s}") | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The purpose of this is to avoid having to load the weights twice on GPU and avoid waste of memory for restrained vram on GA cards.
As
instantiate(cfg)creates the model with random weights and thenload_state_dict()replace them with checkpoints weights, there is a short time where it use twice as much memory size as needed.By using map_location="cpu" we do perform the swap directly on system memory and THEN we load the definitive structure to the GPU device, this avoid having either the random init or the checkpoint to reach GPU before the final
.to(device)call.Using this method avoid GA restrained VRAM cards to OOM while matching how llama is loaded too.