|
15 | 15 | inherit (kernel-builder.inputs.nixpkgs) lib; |
16 | 16 |
|
17 | 17 | cudaVersion = "cu126"; |
| 18 | + rocmVersion = "rocm71"; |
18 | 19 | torchVersion = "211"; |
19 | 20 | tvmFfiVersion = "01"; |
20 | 21 |
|
|
136 | 137 | } |
137 | 138 | ]; |
138 | 139 |
|
| 140 | + # ROCm kernels to build in CI. |
| 141 | + ciRocmKernels = [ |
| 142 | + { |
| 143 | + name = "relu-invalid-capability"; |
| 144 | + path = ./relu-invalid-capability; |
| 145 | + drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"}; |
| 146 | + assertFail = true; |
| 147 | + assertFailLogs = [ "empty set of architectures" ]; |
| 148 | + } |
| 149 | + { |
| 150 | + name = "relu-kernel"; |
| 151 | + path = ./relu; |
| 152 | + drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"}; |
| 153 | + } |
| 154 | + { |
| 155 | + name = "relu-compiler-flags"; |
| 156 | + path = ./relu-compiler-flags; |
| 157 | + drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"}; |
| 158 | + } |
| 159 | + ]; |
| 160 | + |
139 | 161 | mkKernelOutputs = |
140 | 162 | { |
141 | 163 | path, |
|
148 | 170 | // lib.optionalAttrs (torchVersions != null) { inherit torchVersions; } |
149 | 171 | ); |
150 | 172 |
|
151 | | - ciKernelOutputs = map ( |
152 | | - kernel: |
153 | | - kernel |
154 | | - // { |
155 | | - outputs = mkKernelOutputs { |
156 | | - inherit (kernel) path; |
157 | | - torchVersions = kernel.torchVersions or null; |
158 | | - }; |
159 | | - } |
160 | | - ) ciKernels; |
| 173 | + mkKernelOutputs' = |
| 174 | + kernels: |
| 175 | + map ( |
| 176 | + kernel: |
| 177 | + kernel |
| 178 | + // { |
| 179 | + outputs = mkKernelOutputs { |
| 180 | + inherit (kernel) path; |
| 181 | + torchVersions = kernel.torchVersions or null; |
| 182 | + }; |
| 183 | + } |
| 184 | + ) kernels; |
| 185 | + |
| 186 | + ciKernelOutputs = mkKernelOutputs' ciKernels; |
| 187 | + ciRocmKernelOutputs = mkKernelOutputs' ciRocmKernels; |
161 | 188 | in |
162 | 189 | flake-utils.lib.eachSystem |
163 | 190 | [ |
|
169 | 196 | let |
170 | 197 | pkgs = nixpkgs.legacyPackages.${system}; |
171 | 198 |
|
172 | | - resolvedKernels = map (kernel: { |
173 | | - inherit (kernel) name; |
174 | | - drv = |
175 | | - let |
176 | | - baseDrv = kernel.drv system kernel.outputs; |
177 | | - in |
178 | | - if kernel.assertFail or false then |
179 | | - pkgs.testers.testBuildFailure' { |
180 | | - drv = baseDrv; |
181 | | - expectedBuilderLogEntries = kernel.assertFailLogs or [ ]; |
182 | | - } |
183 | | - else |
184 | | - baseDrv; |
185 | | - }) ciKernelOutputs; |
186 | | - |
187 | | - ci-build = pkgs.linkFarm "ci-kernels" ( |
| 199 | + resolveKernels = |
| 200 | + kernelOutputsList: |
188 | 201 | map (kernel: { |
189 | 202 | inherit (kernel) name; |
190 | | - path = kernel.drv; |
191 | | - }) resolvedKernels |
192 | | - ); |
| 203 | + drv = |
| 204 | + let |
| 205 | + baseDrv = kernel.drv system kernel.outputs; |
| 206 | + in |
| 207 | + if kernel.assertFail or false then |
| 208 | + pkgs.testers.testBuildFailure' { |
| 209 | + drv = baseDrv; |
| 210 | + expectedBuilderLogEntries = kernel.assertFailLogs or [ ]; |
| 211 | + } |
| 212 | + else |
| 213 | + baseDrv; |
| 214 | + }) kernelOutputsList; |
| 215 | + |
| 216 | + mkCiBuild = |
| 217 | + name: kernelOutputsList: |
| 218 | + pkgs.linkFarm name ( |
| 219 | + map (kernel: { |
| 220 | + inherit (kernel) name; |
| 221 | + path = kernel.drv; |
| 222 | + }) (resolveKernels kernelOutputsList) |
| 223 | + ); |
| 224 | + |
| 225 | + ci-build-cuda = mkCiBuild "ci-kernels-cuda" ciKernelOutputs; |
| 226 | + ci-build-rocm = mkCiBuild "ci-kernels-rocm" ciRocmKernelOutputs; |
193 | 227 | in |
194 | 228 | { |
195 | 229 | packages = { |
196 | | - inherit ci-build; |
197 | | - default = ci-build; |
| 230 | + inherit ci-build-cuda ci-build-rocm; |
| 231 | + default = ci-build-cuda; |
198 | 232 | }; |
199 | 233 | } |
200 | 234 | ); |
|
0 commit comments