Skip to content

Commit fe75135

Browse files
committed
Merge goniz/rocm-support-fixes with extensive ROCm optimizations
2 parents b48adae + 879a200 commit fe75135

42 files changed

Lines changed: 10285 additions & 1097 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/build_rocm.yml

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
---
# Builds the MLX wheel with ROCm enabled on the runner labelled `strix-halo`,
# smoke-tests a few GPU ops, runs two `mlx_lm.generate` inference checks under
# lldb, and uploads the wheel and the captured logs as artifacts.
name: Build ROCm and Test

on:
  push:
    branches: [rocm-support]
  workflow_dispatch:

# Least privilege: the job only needs to read the repository for checkout.
permissions:
  contents: read

jobs:
  build-and-test:
    runs-on: strix-halo

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        run: |
          uv venv venv
          source venv/bin/activate
          uv pip install --upgrade mlx-lm

      - name: Build and install MLX ROCm wheel
        run: |
          source venv/bin/activate
          export CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151 -DBLA_VENDOR=OpenBLAS -DCMAKE_BUILD_TYPE=RelWithDebInfo"
          # Start from an empty wheelhouse so the glob below matches only the
          # wheel produced by this run.
          rm -rf wheelhouse
          mkdir -p wheelhouse
          uv build --wheel --out-dir wheelhouse .
          # Replace any already-installed mlx with the freshly built wheel.
          uv pip install --force-reinstall wheelhouse/mlx-*.whl

      - name: Basic MLX GPU test
        run: |
          source venv/bin/activate
          python3 -c "
          import mlx.core as mx
          print('MLX version:', mx.__version__)
          print('Default device:', mx.default_device())
          mx.set_default_device(mx.gpu)
          print('GPU device set')

          # Test basic operations
          a = mx.ones((10, 10))
          mx.eval(a)
          print('Basic array creation: OK')

          # Test matmul
          b = mx.random.normal((256, 256))
          c = mx.matmul(b, b)
          mx.eval(c)
          print('Matmul test: OK')

          # Test softmax
          d = mx.softmax(b, axis=-1)
          mx.eval(d)
          print('Softmax test: OK')

          print('All basic tests passed!')
          "

      - name: Run inference tests
        run: |
          source venv/bin/activate
          # Fail a pipeline when any stage fails, so lldb's exit status is not
          # masked by `tee`.
          set -o pipefail
          export HIP_LAUNCH_BLOCKING=1
          export PYTHONFAULTHANDLER=1
          mkdir -p "${GITHUB_WORKSPACE}/rocm-stacktraces"

          # run_and_trace NAME ARGS...
          #   Runs `mlx_lm.generate ARGS...` under lldb, mirroring all output
          #   to rocm-stacktraces/NAME.log. If the process stops on a crash,
          #   lldb prints a backtrace and quits with status 1.
          run_and_trace() {
            local name="$1"
            local log="${GITHUB_WORKSPACE}/rocm-stacktraces/${name}.log"
            shift
            # A pipe (rather than process substitution) guarantees tee has
            # finished writing the log before the check below reads it.
            lldb -Q -b \
              -o "run" \
              -k "bt" \
              -k "quit 1" \
              -- python3 "$(which mlx_lm.generate)" "$@" \
              2>&1 | tee "${log}"
            # NOTE(review): lldb in batch mode appears to exit 0 whenever the
            # inferior exits on its own, even with a non-zero status, so the
            # reported exit status is checked explicitly. Relies on lldb's
            # "Process N exited with status = S (0x...)" message; confirm the
            # format with the lldb version installed on the runner.
            if ! grep -q "exited with status = 0 " "${log}"; then
              echo "::error::${name}: mlx_lm.generate did not exit cleanly (see ${name}.log)"
              return 1
            fi
          }

          run_and_trace qwen3_bf16 --model mlx-community/Qwen3-0.6B-bf16 --prompt "Hi" --max-tokens 5
          run_and_trace qwen3_8bit --model mlx-community/Qwen3-0.6B-8bit --prompt "How tall is Mt Everest?" --max-tokens 128

      # Uploads run even when earlier steps fail, so a partial build or a
      # crash log is still available for inspection.
      - name: Upload ROCm wheel artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v6
        with:
          name: rocm-wheel-${{ github.run_attempt }}
          path: wheelhouse/mlx-*.whl
          if-no-files-found: warn
          retention-days: 14

      - name: Upload ROCm stacktrace artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v6
        with:
          name: rocm-stacktraces-${{ github.run_attempt }}
          path: ${{ github.workspace }}/rocm-stacktraces/*
          if-no-files-found: warn
          retention-days: 14

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,8 @@ uv.lock
8181
*.swp
8282

8383
# keys
84-
*.pem
84+
*.pem
85+
86+
build.sh
87+
github-runner/
88+
sync_fork.sh

0 commit comments

Comments
 (0)