Skip to content

Merge branch 'ml-explore:main' into rocm-support #16

Merge branch 'ml-explore:main' into rocm-support

Merge branch 'ml-explore:main' into rocm-support #16

Workflow file for this run

# CI for the rocm-support branch: builds MLX from this checkout with the ROCm
# backend enabled, installs the wheel into a fresh venv, smoke-tests a few GPU
# ops, and runs two short mlx-lm generations under lldb so that native crashes
# leave a backtrace behind as an artifact.
name: Build ROCm and Test
on:
  push:
    branches: [rocm-support]
  workflow_dispatch:
jobs:
  build-and-test:
    # Self-hosted runner label; the build below targets gfx1151 only.
    runs-on: strix-halo
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      # Creates ./venv with uv (assumed to be preinstalled on the runner) and
      # installs mlx-lm. Any mlx pulled in here is overwritten by the
      # --force-reinstall of the locally built wheel in the next step.
      - name: Set up Python
        run: |
          uv venv venv
          source venv/bin/activate
          uv pip install --upgrade mlx-lm

      # Builds the ROCm-enabled wheel (gfx1151, OpenBLAS) and installs it into
      # the venv, replacing whatever mlx is already there.
      - name: Build and install MLX ROCm wheel
        run: |
          source venv/bin/activate
          # CMAKE_ARGS is presumably read by the package's build backend
          # (confirm in setup.py); RelWithDebInfo presumably keeps symbols for
          # the lldb backtraces collected below.
          export CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151 -DBLA_VENDOR=OpenBLAS -DCMAKE_BUILD_TYPE=RelWithDebInfo"
          # Start from an empty wheelhouse so the glob below only matches this build.
          rm -rf wheelhouse
          mkdir -p wheelhouse
          uv build --wheel --out-dir wheelhouse .
          uv pip install --force-reinstall wheelhouse/mlx-*.whl

      # Sanity check that the GPU device is usable: array creation, a 256x256
      # matmul and a softmax, each forced with mx.eval.
      - name: Basic MLX GPU test
        run: |
          source venv/bin/activate
          python3 -c "
          import mlx.core as mx
          print('MLX version:', mx.__version__)
          print('Default device:', mx.default_device())
          mx.set_default_device(mx.gpu)
          print('GPU device set')
          # Test basic operations
          a = mx.ones((10, 10))
          mx.eval(a)
          print('Basic array creation: OK')
          # Test matmul
          b = mx.random.normal((256, 256))
          c = mx.matmul(b, b)
          mx.eval(c)
          print('Matmul test: OK')
          # Test softmax
          d = mx.softmax(b, axis=-1)
          mx.eval(d)
          print('Softmax test: OK')
          print('All basic tests passed!')
          "

      # Runs two short generations under lldb. On a native crash lldb prints a
      # backtrace and exits 1. Each run's full output is kept in
      # rocm-stacktraces/<name>.log for the artifact upload below.
      - name: Run inference tests
        run: |
          # Without pipefail, the `| tee` below would hide lldb's exit status.
          set -o pipefail
          source venv/bin/activate
          export HIP_LAUNCH_BLOCKING=1
          export PYTHONFAULTHANDLER=1
          trace_dir="${GITHUB_WORKSPACE}/rocm-stacktraces"
          mkdir -p "${trace_dir}"
          # Resolve the console script once; under -e this fails the step
          # early if mlx-lm is not installed.
          generate_script="$(command -v mlx_lm.generate)"

          # run_and_trace NAME ARGS... : run `mlx_lm.generate ARGS` under lldb
          # and tee all output to ${trace_dir}/NAME.log.
          run_and_trace() {
            local name="$1"
            shift
            local log="${trace_dir}/${name}.log"
            # -b: batch mode. The -k commands run only if the target crashes,
            # so a crash prints a backtrace and makes lldb exit 1.
            lldb -Q -b \
              -o "run" \
              -k "bt" \
              -k "quit 1" \
              -- python3 "${generate_script}" "$@" 2>&1 | tee "${log}"
            # lldb's own exit code is not relied on to reflect a normal
            # non-zero exit (e.g. an uncaught Python exception), so also
            # require lldb's "Process N exited with status = 0" report.
            if ! grep -q 'exited with status = 0 (' "${log}"; then
              echo "::error::${name}: generation did not exit with status 0 (see ${log})"
              return 1
            fi
          }

          run_and_trace qwen3_bf16 --model mlx-community/Qwen3-0.6B-bf16 --prompt "Hi" --max-tokens 5
          run_and_trace qwen3_8bit --model mlx-community/Qwen3-0.6B-8bit --prompt "How tall is Mt Everest?" --max-tokens 128

      # Both uploads run even when earlier steps fail (always()), so partial
      # results are still collected. Missing files only produce a warning.
      # The run_attempt suffix keeps artifact names distinct across re-runs.
      - name: Upload ROCm wheel artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v6
        with:
          name: rocm-wheel-${{ github.run_attempt }}
          path: wheelhouse/mlx-*.whl
          if-no-files-found: warn
          retention-days: 14

      - name: Upload ROCm stacktrace artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v6
        with:
          name: rocm-stacktraces-${{ github.run_attempt }}
          path: ${{ github.workspace }}/rocm-stacktraces/*
          if-no-files-found: warn
          retention-days: 14