Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.

Commit 378524d

Browse files
authored
Select default shells from available variants (#247)
Before this change, `nix develop` and `nix develop .#test` would use the `torch28-metal-aarch64-darwin` or `torch28-cxx11-cu126-${system}` build variants on macOS or Linux respectively. This had two downsides: (1) it resulted in an error for ROCm/XPU-only kernels, requiring to manually specify the variant; (2) we had to maintain the variants (e.g. bump up the Torch version when 2.8 is not supported anymore). This change chooses the shell depending on the variants that are available for a given kernel on a given system. We use the following ordering to sort the available variants: - Bundle variants before non-bundle variants. - CUDA variants before other frameworks. - Newer Torch versions before older versions. - Older frameworks before newer (best system compatibility). Then we choose the first variant in this ordering. This should select the best variant in most cases.
1 parent e832e87 commit 378524d

4 files changed

Lines changed: 71 additions & 48 deletions

File tree

flake.nix

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@
6363
buildVariants =
6464
(import ./lib/build-variants.nix {
6565
inherit (nixpkgs) lib;
66-
torchVersions = torchVersions';
67-
}).buildVariants;
66+
}).buildVariants
67+
torchVersions';
6868
in
6969
builtins.toJSON buildVariants;
7070
genFlakeOutputs =
@@ -104,7 +104,6 @@
104104
pythonNativeCheckInputs
105105
;
106106
build = buildPerSystem.${system};
107-
buildSet = buildSetPerSystem.${system};
108107
}
109108
);
110109
}

lib/build-variants.nix

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{ lib, torchVersions }:
1+
{ lib }:
22
let
33
inherit (import ./torch-version-utils.nix { inherit lib; })
44
flattenSystems
@@ -22,8 +22,7 @@ rec {
2222
else
2323
throw "Could not find compute framework: no CUDA, ROCm, XPU version specified and Metal is not enabled";
2424

25-
# Build variants included in bundle builds.
26-
buildVariants =
25+
buildName =
2726
let
2827
inherit (import ./version-utils.nix { inherit lib; }) abiString flattenVersion;
2928
computeString =
@@ -38,12 +37,17 @@ rec {
3837
"xpu${flattenVersion (lib.versions.majorMinor version.xpuVersion)}"
3938
else
4039
throw "No compute framework set in Torch version";
41-
buildName =
42-
version:
43-
if version.system == "aarch64-darwin" then
44-
"torch${flattenVersion version.torchVersion}-${computeString version}-${version.system}"
45-
else
46-
"torch${flattenVersion version.torchVersion}-${abiString version.cxx11Abi}-${computeString version}-${version.system}";
40+
in
41+
version:
42+
if version.system == "aarch64-darwin" then
43+
"torch${flattenVersion version.torchVersion}-${computeString version}-${version.system}"
44+
else
45+
"torch${flattenVersion version.torchVersion}-${abiString version.cxx11Abi}-${computeString version}-${version.system}";
46+
47+
# Build variants included in bundle builds.
48+
buildVariants =
49+
torchVersions:
50+
let
4751
bundleBuildVersions = lib.filter (version: version.bundleBuild or false);
4852
in
4953
lib.foldl' (

lib/build.nix

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ rec {
7171
lib.foldl (backends: kernel: backends // { ${kernelBackend kernel} = true; }) init kernels;
7272

7373
readBuildConfig = path: validateBuildConfig (readToml (path + "/build.toml"));
74-
tracedReadBuildConfig = path: readBuildConfig path;
7574

7675
srcFilter =
7776
src: name: type:
@@ -81,7 +80,7 @@ rec {
8180
mkSourceSet = import ./source-set.nix { inherit lib; };
8281

8382
# Filter buildsets that are applicable to a given kernel build config.
84-
applicableBuildSets =
83+
filterApplicableBuildSets =
8584
buildConfig: buildSets:
8685
let
8786
backends' = backends buildConfig;
@@ -107,6 +106,8 @@ rec {
107106
in
108107
builtins.filter supportedBuildSet buildSets;
109108

109+
applicableBuildSets = path: filterApplicableBuildSets (readBuildConfig path) buildSets;
110+
110111
# Build a single Torch extension.
111112
buildTorchExtension =
112113
{
@@ -180,17 +181,16 @@ rec {
180181
name = torchBuildVersion buildSet;
181182
value = buildTorchExtension buildSet { inherit path rev; };
182183
};
183-
filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
184184
in
185-
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) filteredBuildSets);
185+
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) (applicableBuildSets path));
186186

187187
# Build multiple Torch extensions.
188188
buildDistTorchExtensions =
189189
{
190-
buildSets,
191190
path,
192191
rev,
193192
doGetKernelCheck,
193+
bundleOnly,
194194
}:
195195
let
196196
extensionForTorch =
@@ -203,9 +203,13 @@ rec {
203203
oldLinuxCompat = true;
204204
};
205205
};
206-
filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
206+
applicableBuildSets' =
207+
if bundleOnly then
208+
builtins.filter (buildSet: buildSet.bundleBuild) (applicableBuildSets path)
209+
else
210+
(applicableBuildSets path);
207211
in
208-
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) filteredBuildSets);
212+
builtins.listToAttrs (lib.map (extensionForTorch { inherit path rev; }) applicableBuildSets');
209213

210214
buildTorchExtensionBundle =
211215
{
@@ -216,10 +220,9 @@ rec {
216220
let
217221
# We just need to get any nixpkgs for use by the path join.
218222
pkgs = (builtins.head buildSets).pkgs;
219-
bundleBuildSets = builtins.filter (buildSet: buildSet.bundleBuild) buildSets;
220223
extensions = buildDistTorchExtensions {
221224
inherit path rev doGetKernelCheck;
222-
buildSets = bundleBuildSets;
225+
bundleOnly = true;
223226
};
224227
buildConfig = readBuildConfig path;
225228
namePaths =
@@ -273,9 +276,8 @@ rec {
273276
'';
274277
};
275278
};
276-
filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
277279
in
278-
builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) filteredBuildSets);
280+
builtins.listToAttrs (lib.map (shellForBuildSet { inherit path rev; }) (applicableBuildSets path));
279281

280282
torchDevShells =
281283
{
@@ -315,7 +317,6 @@ rec {
315317
venvDir = "./.venv";
316318
};
317319
};
318-
filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
319320
in
320-
builtins.listToAttrs (lib.map shellForBuildSet filteredBuildSets);
321+
builtins.listToAttrs (lib.map shellForBuildSet (applicableBuildSets path));
321322
}

lib/gen-flake-outputs.nix

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
2+
lib,
23
build,
3-
buildSet,
44
system,
55

66
writeScriptBin,
@@ -16,6 +16,8 @@
1616
}:
1717

1818
let
19+
inherit (import ./build-variants.nix { inherit lib; }) buildName;
20+
1921
supportedFormat = ''
2022
kernel-builder.lib.genFlakeOutputs {
2123
inherit self;
@@ -34,8 +36,45 @@ let
3436
throw "Flake's `self` must be passed to `genFlakeOutputs` as follows:\n\n${supportedFormat}";
3537

3638
revUnderscored = builtins.replaceStrings [ "-" ] [ "_" ] flakeRev;
39+
40+
# For picking a default shell, etc. we want to use the following logic:
41+
#
42+
# - Prefer bundle builds over non-bundle builds.
43+
# - Prefer CUDA over other frameworks.
44+
# - Prefer newer Torch versions over older.
45+
# - Prefer older frameworks over newer (best compatibility).
46+
47+
# Enrich the build configs with generic attributes for framework
48+
# order/version. Also make bundleBuild attr explicit.
49+
buildConfigs = map (
50+
set:
51+
let
52+
inherit (set) buildConfig;
53+
in
54+
buildConfig
55+
// {
56+
bundleBuild = buildConfig.bundleBuild or false;
57+
frameworkOrder = if buildConfig ? cudaVersion then 0 else 1;
58+
frameworkVersion =
59+
buildConfig.cudaVersion or buildConfig.rocmVersion or buildConfig.xpuVersion or "0.0";
60+
}
61+
) (build.applicableBuildSets path);
62+
configCompare =
63+
a: b:
64+
if a.bundleBuild != b.bundleBuild then
65+
a.bundleBuild
66+
else if a.frameworkOrder != b.frameworkOrder then
67+
a.frameworkOrder < b.frameworkOrder
68+
else if a.torchVersion != b.torchVersion then
69+
builtins.compareVersions a.torchVersion b.torchVersion > 0
70+
else
71+
builtins.compareVersions a.frameworkVersion b.frameworkVersion < 0;
72+
buildConfigsSorted = lib.sort configCompare buildConfigs;
3773
shellTorch =
38-
if system == "aarch64-darwin" then "torch28-metal-${system}" else "torch28-cxx11-cu126-${system}";
74+
if buildConfigsSorted == [ ] then
75+
throw "No build variant is compatible with this system"
76+
else
77+
buildName (builtins.head buildConfigsSorted);
3978
in
4079

4180
{
@@ -90,28 +129,8 @@ in
90129
};
91130
redistributable = build.buildDistTorchExtensions {
92131
inherit path doGetKernelCheck;
93-
buildSets = buildSet;
132+
bundleOnly = false;
94133
rev = revUnderscored;
95134
};
96-
buildTree =
97-
let
98-
src = build.mkSourceSet path;
99-
in
100-
runCommand "torch-extension-build-tree"
101-
{
102-
nativeBuildInputs = [ buildSet.pkgs.build2cmake ];
103-
inherit src;
104-
meta = {
105-
description = "Build tree for torch extension with source files and CMake configuration";
106-
};
107-
}
108-
''
109-
# Copy sources
110-
install -dm755 $out/src
111-
cp -r $src/. $out/src/
112-
113-
# Generate cmake files
114-
build2cmake generate-torch --ops-id "${revUnderscored}" $src/build.toml $out --force
115-
'';
116135
};
117136
}

0 commit comments

Comments
 (0)