Skip to content

Commit f55b213

Browse files
authored
Fix aarch64-linux and add it to CI (#286)
* CI: run builder tests on AArch64 * builder: fix aarch64 support
1 parent 3d4bc4c commit f55b213

4 files changed

Lines changed: 36 additions & 15 deletions

File tree

.github/workflows/build_kernel.yaml

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,16 @@ on:
99

1010
jobs:
1111
build:
12-
name: Build kernels
12+
name: Build kernels (${{ matrix.arch }})
13+
strategy:
14+
matrix:
15+
include:
16+
- arch: x86_64-linux
17+
runner: aws-highmemory-32-plus-nix
18+
- arch: aarch64-linux
19+
runner: aws-r8g-8xl-plus-nix
1320
runs-on:
14-
group: aws-highmemory-32-plus-nix
21+
group: ${{ matrix.runner }}
1522
steps:
1623
- uses: actions/checkout@v6
1724
- uses: DeterminateSystems/nix-installer-action@main
@@ -30,27 +37,27 @@ jobs:
3037
run: nix-shell -p nix-info --run "nix-info -m"
3138

3239
- name: Build relu kernel
33-
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
40+
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )
3441
- name: Copy relu kernel
3542
run: cp -rL builder/examples/relu/result relu-kernel
3643

3744
- name: Build extra-data kernel
38-
run: ( cd builder/examples/extra-data && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
45+
run: ( cd builder/examples/extra-data && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )
3946
- name: Copy extra-data kernel
4047
run: cp -rL builder/examples/extra-data/result extra-data
4148

4249
- name: Build relu kernel (CPU)
43-
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-cxx11-cpu-x86_64-linux )
50+
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-cxx11-cpu-${{ matrix.arch }} )
4451
- name: Copy relu kernel (CPU)
4552
run: cp -rL builder/examples/relu/result relu-kernel-cpu
4653

4754
- name: Build cutlass GEMM kernel
48-
run: ( cd builder/examples/cutlass-gemm && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
55+
run: ( cd builder/examples/cutlass-gemm && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )
4956
- name: Copy cutlass GEMM kernel
5057
run: cp -rL builder/examples/cutlass-gemm/result cutlass-gemm-kernel
5158

5259
- name: Build relu-backprop-compile kernel
53-
run: ( cd builder/examples/relu-backprop-compile && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
60+
run: ( cd builder/examples/relu-backprop-compile && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )
5461
- name: Copy relu-backprop-compile kernel
5562
run: cp -rL builder/examples/relu-backprop-compile/result relu-backprop-compile-kernel
5663

@@ -59,10 +66,10 @@ jobs:
5966
run: ( cd builder/examples/relu-specific-torch && nix build . )
6067

6168
- name: Build relu kernel (compiler flags)
62-
run: ( cd builder/examples/relu-compiler-flags && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
69+
run: ( cd builder/examples/relu-compiler-flags && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )
6370

6471
- name: Test that we can build a test shell (e.g. that gcc corresponds to CUDA-required)
65-
run: ( cd builder/examples/relu && nix build .#devShells.x86_64-linux.test )
72+
run: ( cd builder/examples/relu && nix build .#devShells.${{ matrix.arch }}.test )
6673

6774
- name: Build silu-and-mul kernel
6875
run: ( cd builder/examples/silu-and-mul && nix build .\#redistributable.torch-cuda )
@@ -72,7 +79,7 @@ jobs:
7279
- name: Upload kernel artifacts
7380
uses: actions/upload-artifact@v6
7481
with:
75-
name: built-kernels
82+
name: built-kernels-${{ matrix.arch }}
7683
path: |
7784
activation-kernel
7885
cutlass-gemm-kernel
@@ -93,7 +100,7 @@ jobs:
93100
- name: Download kernel artifacts
94101
uses: actions/download-artifact@v7
95102
with:
96-
name: built-kernels
103+
name: built-kernels-x86_64-linux
97104
path: .
98105

99106
- name: Set up Docker Buildx

nix/pkgs/python-modules/cuda-bindings/default.nix

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ let
1818
let
1919
cuda_12 = {
2020
version = "12.9.4";
21-
hash = "sha256-Mr3Fp2kGvkxh65j1RqZ4bFdzqIHzsWZIZEm10UHko58=";
21+
hash = {
22+
x86_64-linux = "sha256-Mr3Fp2kGvkxh65j1RqZ4bFdzqIHzsWZIZEm10UHko58=";
23+
aarch64-linux = "sha256-z4v67cI487EV2VfR/WVit+hDW6V/bQ4vh9DnFJzLLaU=";
24+
};
2225
};
2326
in
2427
{
@@ -27,14 +30,20 @@ let
2730
"12.9" = cuda_12;
2831
"13.0" = {
2932
version = "13.0.3";
30-
hash = "sha256-US0NgDpeR6ikLVo0zgkygCv3L+lS/bEax5hxWjXG5cs=";
33+
hash = {
34+
x86_64-linux = "sha256-US0NgDpeR6ikLVo0zgkygCv3L+lS/bEax5hxWjXG5cs=";
35+
aarch64-linux = "sha256-+xan92nJxnRprdeh2fbBTdRGN/aSHLa564LLUBWzXD0=";
36+
};
3137
};
3238
};
3339

3440
versionHash =
3541
versionHashes.${cudaPackages.cudaMajorMinorVersion}
3642
or (throw "Unsupported CUDA version: ${cudaPackages.cudaMajorMinorVersion}");
37-
inherit (versionHash) hash version;
43+
inherit (versionHash) version;
44+
hash =
45+
versionHash.hash.${stdenv.hostPlatform.system}
46+
or (throw "No hash defined for system: ${stdenv.hostPlatform.system}");
3847

3948
format = "wheel";
4049
pyShortVersion = "cp" + builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;

nix/pkgs/python-modules/torch/binary/torch-versions-hash.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@
116116
"hash": "sha256-vbzHAzgvlI6VHAY0SMlAa/OM5mxB3WmNnicz/PlsA3o=",
117117
"version": "2.10.0"
118118
},
119+
"cu130": {
120+
"url": "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313-manylinux_2_28_aarch64.whl",
121+
"hash": "sha256-dXgCgzCN+f7eNx7toB6WB8iGKhgDovLzGgiiwN6u00I=",
122+
"version": "2.10.0"
123+
},
119124
"cpu": {
120125
"url": "https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl",
121126
"hash": "sha256-5RmUSSzbdu3OKdqI3jZyowIvnvD/2QNFQ2lI1Jkr4sc=",

nix/pkgs/python-modules/torch/binary/torch-versions.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
{
5454
"torchVersion": "2.10.0",
5555
"cudaVersion": "13.0",
56-
"systems": ["x86_64-linux"]
56+
"systems": ["x86_64-linux", "aarch64-linux"]
5757
},
5858
{
5959
"torchVersion": "2.10.0",

0 commit comments

Comments
 (0)