11{
22 lib ,
33
4- # List of build sets. Each build set is a attrset of the form
5- #
6- # { pkgs = <nixpkgs>, torch = <torch drv> }
7- #
8- # The Torch derivation is built as-is. So e.g. the ABI version should
9- # already be set.
10- buildSets ,
4+ # Every `buildSets` argument is a list of build sets. Each build set is
5+ # a attrset of the form
6+ #
7+ # { pkgs = <nixpkgs>, torch = <torch drv> }
8+ #
9+ # The Torch derivation is built as-is. So e.g. the ABI version should
10+ # already be set.
1111} :
1212
1313let
@@ -106,10 +106,11 @@ rec {
106106 in
107107 builtins . filter supportedBuildSet buildSets ;
108108
109- applicableBuildSets = path : filterApplicableBuildSets ( readBuildConfig path ) buildSets ;
109+ applicableBuildSets =
110+ { path , buildSets } : filterApplicableBuildSets ( readBuildConfig path ) buildSets ;
110111
111112 # Build a single Torch extension.
112- buildTorchExtension =
113+ mkTorchExtension =
113114 {
114115 buildConfig ,
115116 pkgs ,
@@ -172,56 +173,47 @@ rec {
172173 } ) ;
173174
174175 # Build multiple Torch extensions.
175- buildNixTorchExtensions =
176- { path , rev } :
177- let
178- extensionForTorch =
179- { path , rev } :
180- buildSet : {
181- name = torchBuildVersion buildSet ;
182- value = buildTorchExtension buildSet { inherit path rev ; } ;
183- } ;
184- in
185- builtins . listToAttrs ( lib . map ( extensionForTorch { inherit path rev ; } ) ( applicableBuildSets path ) ) ;
186-
187- # Build multiple Torch extensions.
188- buildDistTorchExtensions =
176+ mkDistTorchExtensions =
189177 {
190178 path ,
191179 rev ,
192180 doGetKernelCheck ,
193181 bundleOnly ,
182+ buildSets ,
194183 } :
195184 let
196185 extensionForTorch =
197186 { path , rev } :
198187 buildSet : {
199188 name = torchBuildVersion buildSet ;
200- value = buildTorchExtension buildSet {
189+ value = mkTorchExtension buildSet {
201190 inherit path rev doGetKernelCheck ;
202191 stripRPath = true ;
203192 oldLinuxCompat = true ;
204193 } ;
205194 } ;
206195 applicableBuildSets' =
207- if bundleOnly then
208- builtins . filter ( buildSet : buildSet . bundleBuild ) ( applicableBuildSets path )
209- else
210- ( applicableBuildSets path ) ;
196+ if bundleOnly then builtins . filter ( buildSet : buildSet . bundleBuild ) buildSets else buildSets ;
211197 in
212198 builtins . listToAttrs ( lib . map ( extensionForTorch { inherit path rev ; } ) applicableBuildSets' ) ;
213199
214- buildTorchExtensionBundle =
200+ mkTorchExtensionBundle =
215201 {
216202 path ,
217203 rev ,
218204 doGetKernelCheck ,
205+ buildSets ,
219206 } :
220207 let
221208 # We just need to get any nixpkgs for use by the path join.
222209 pkgs = ( builtins . head buildSets ) . pkgs ;
223- extensions = buildDistTorchExtensions {
224- inherit path rev doGetKernelCheck ;
210+ extensions = mkDistTorchExtensions {
211+ inherit
212+ buildSets
213+ path
214+ rev
215+ doGetKernelCheck
216+ ;
225217 bundleOnly = true ;
226218 } ;
227219 buildConfig = readBuildConfig path ;
@@ -243,6 +235,7 @@ rec {
243235 {
244236 path ,
245237 rev ,
238+ buildSets ,
246239 doGetKernelCheck ,
247240 pythonCheckInputs ,
248241 pythonNativeCheckInputs ,
@@ -271,18 +264,19 @@ rec {
271264 ++ ( pythonCheckInputs python3 . pkgs ) ;
272265 shellHook = ''
273266 export PYTHONPATH='' ${PYTHONPATH}:${
274- buildTorchExtension buildSet { inherit path rev doGetKernelCheck ; }
267+ mkTorchExtension buildSet { inherit path rev doGetKernelCheck ; }
275268 }
276269 '' ;
277270 } ;
278271 } ;
279272 in
280- builtins . listToAttrs ( lib . map ( shellForBuildSet { inherit path rev ; } ) ( applicableBuildSets path ) ) ;
273+ builtins . listToAttrs ( lib . map ( shellForBuildSet { inherit path rev ; } ) buildSets ) ;
281274
282- torchDevShells =
275+ mkTorchDevShells =
283276 {
284277 path ,
285278 rev ,
279+ buildSets ,
286280 doGetKernelCheck ,
287281 pythonCheckInputs ,
288282 pythonNativeCheckInputs ,
@@ -309,7 +303,7 @@ rec {
309303 ]
310304 ++ ( pythonNativeCheckInputs python3 . pkgs ) ;
311305 buildInputs = with pkgs ; [ python3 . pkgs . pytest ] ++ ( pythonCheckInputs python3 . pkgs ) ;
312- inputsFrom = [ ( buildTorchExtension buildSet { inherit path rev doGetKernelCheck ; } ) ] ;
306+ inputsFrom = [ ( mkTorchExtension buildSet { inherit path rev doGetKernelCheck ; } ) ] ;
313307 env = lib . optionalAttrs rocmSupport {
314308 PYTORCH_ROCM_ARCH = lib . concatStringsSep ";" buildSet . torch . rocmArchs ;
315309 HIP_PATH = pkgs . rocmPackages . clr ;
@@ -318,5 +312,5 @@ rec {
318312 } ;
319313 } ;
320314 in
321- builtins . listToAttrs ( lib . map shellForBuildSet ( applicableBuildSets path ) ) ;
315+ builtins . listToAttrs ( lib . map shellForBuildSet buildSets ) ;
322316}
0 commit comments