Skip to content

Commit ae44554

Browse files
committed
Hook up ROCm test and make it concurrent
1 parent 7f7fca0 commit ae44554

3 files changed

Lines changed: 69 additions & 38 deletions

File tree

.github/workflows/build_kernel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
run: nix-shell -p nix-info --run "nix-info -m"
4141

4242
- name: Build all example kernels
43-
run: nix build -L ./examples/kernels#ci-build
43+
run: nix build -L ./examples/kernels#ci-build-cuda
4444
- name: Copy kernel artifacts
4545
run: cp -rL result/* .
4646

.github/workflows/build_kernel_rocm.yaml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ on:
1212

1313
jobs:
1414
build:
15-
name: Build kernel
15+
name: Build kernels (ROCm)
1616
runs-on:
1717
group: aws-highmemory-32-plus-nix
1818
steps:
@@ -33,8 +33,5 @@ jobs:
3333
run: nix-shell -p nix-info --run "nix-info -m"
3434
# For now we only test that there are no regressions in building ROCm
3535
# kernels. Also run tests once we have a ROCm runner.
36-
- name: Build relu kernel
37-
run: ( cd examples/kernels/relu && nix build .\#redistributable.torch211-rocm71-x86_64-linux -L )
38-
39-
- name: Build relu kernel (compiler flags)
40-
run: ( cd examples/kernels/relu-compiler-flags && nix build .\#redistributable.torch211-rocm71-x86_64-linux )
36+
- name: Build all ROCm example kernels
37+
run: nix build -L ./examples/kernels#ci-build-rocm

examples/kernels/flake.nix

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
inherit (kernel-builder.inputs.nixpkgs) lib;
1616

1717
cudaVersion = "cu126";
18+
rocmVersion = "rocm71";
1819
torchVersion = "211";
1920
tvmFfiVersion = "01";
2021

@@ -136,6 +137,27 @@
136137
}
137138
];
138139

140+
# ROCm kernels to build in CI.
141+
ciRocmKernels = [
142+
{
143+
name = "relu-invalid-capability";
144+
path = ./relu-invalid-capability;
145+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"};
146+
assertFail = true;
147+
assertFailLogs = [ "empty set of architectures" ];
148+
}
149+
{
150+
name = "relu-kernel";
151+
path = ./relu;
152+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"};
153+
}
154+
{
155+
name = "relu-compiler-flags";
156+
path = ./relu-compiler-flags;
157+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"};
158+
}
159+
];
160+
139161
mkKernelOutputs =
140162
{
141163
path,
@@ -148,16 +170,21 @@
148170
// lib.optionalAttrs (torchVersions != null) { inherit torchVersions; }
149171
);
150172

151-
ciKernelOutputs = map (
152-
kernel:
153-
kernel
154-
// {
155-
outputs = mkKernelOutputs {
156-
inherit (kernel) path;
157-
torchVersions = kernel.torchVersions or null;
158-
};
159-
}
160-
) ciKernels;
173+
mkKernelOutputs' =
174+
kernels:
175+
map (
176+
kernel:
177+
kernel
178+
// {
179+
outputs = mkKernelOutputs {
180+
inherit (kernel) path;
181+
torchVersions = kernel.torchVersions or null;
182+
};
183+
}
184+
) kernels;
185+
186+
ciKernelOutputs = mkKernelOutputs' ciKernels;
187+
ciRocmKernelOutputs = mkKernelOutputs' ciRocmKernels;
161188
in
162189
flake-utils.lib.eachSystem
163190
[
@@ -169,32 +196,39 @@
169196
let
170197
pkgs = nixpkgs.legacyPackages.${system};
171198

172-
resolvedKernels = map (kernel: {
173-
inherit (kernel) name;
174-
drv =
175-
let
176-
baseDrv = kernel.drv system kernel.outputs;
177-
in
178-
if kernel.assertFail or false then
179-
pkgs.testers.testBuildFailure' {
180-
drv = baseDrv;
181-
expectedBuilderLogEntries = kernel.assertFailLogs or [ ];
182-
}
183-
else
184-
baseDrv;
185-
}) ciKernelOutputs;
186-
187-
ci-build = pkgs.linkFarm "ci-kernels" (
199+
resolveKernels =
200+
kernelOutputsList:
188201
map (kernel: {
189202
inherit (kernel) name;
190-
path = kernel.drv;
191-
}) resolvedKernels
192-
);
203+
drv =
204+
let
205+
baseDrv = kernel.drv system kernel.outputs;
206+
in
207+
if kernel.assertFail or false then
208+
pkgs.testers.testBuildFailure' {
209+
drv = baseDrv;
210+
expectedBuilderLogEntries = kernel.assertFailLogs or [ ];
211+
}
212+
else
213+
baseDrv;
214+
}) kernelOutputsList;
215+
216+
mkCiBuild =
217+
name: kernelOutputsList:
218+
pkgs.linkFarm name (
219+
map (kernel: {
220+
inherit (kernel) name;
221+
path = kernel.drv;
222+
}) (resolveKernels kernelOutputsList)
223+
);
224+
225+
ci-build-cuda = mkCiBuild "ci-kernels-cuda" ciKernelOutputs;
226+
ci-build-rocm = mkCiBuild "ci-kernels-rocm" ciRocmKernelOutputs;
193227
in
194228
{
195229
packages = {
196-
inherit ci-build;
197-
default = ci-build;
230+
inherit ci-build-cuda ci-build-rocm;
231+
default = ci-build-cuda;
198232
};
199233
}
200234
);

0 commit comments

Comments
 (0)