Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.

Commit d7aa270

Browse files
authored
Overlay Torch in the package set (#323)
Before this change, we wouldn't overlay Torch in the package set. The main reason was that evaluation was too expensive -- even more so when we also had C++98 and C++11 variants (which doubled the number of Torch versions). However, this does not really work well with dependencies (some dependencies like `einops` will rely on Torch) or extra test dependencies that a user adds to a flake. So overlay Torch so that all of nixpkgs sees the version that is defined in the build set. In this change, the `cxx11Abi` is also removed and the C++11 ABI is set as the default. Upstream does not build for the C++98 ABI anymore, so it is time to remove the support before it bitrots.
1 parent 2ffd75b commit d7aa270

11 files changed

Lines changed: 71 additions & 173 deletions

File tree

examples/relu-specific-torch/flake.nix

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
{
1818
torchVersion = "2.9";
1919
cudaVersion = "12.8";
20-
cxx11Abi = true;
2120
systems = [
2221
"x86_64-linux"
2322
"aarch64-linux"

lib/build-sets.nix

Lines changed: 66 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,22 @@ let
2525
in
2626
builtins.map (buildConfig: buildConfig // { backend = backend buildConfig; }) systemBuildConfigs;
2727

28-
cudaVersions =
29-
let
30-
withCuda = builtins.filter (torchVersion: torchVersion ? cudaVersion) torchVersions;
31-
in
32-
lib.unique (builtins.map (torchVersion: torchVersion.cudaVersion) withCuda);
33-
34-
rocmVersions =
35-
let
36-
withRocm = builtins.filter (torchVersion: torchVersion ? rocmVersion) torchVersions;
37-
in
38-
lib.unique (builtins.map (torchVersion: torchVersion.rocmVersion) withRocm);
39-
40-
xpuVersions =
41-
let
42-
withXpu = builtins.filter (torchVersion: torchVersion ? xpuVersion) torchVersions;
43-
in
44-
lib.unique (builtins.map (torchVersion: torchVersion.xpuVersion) withXpu);
45-
4628
flattenVersion = version: lib.replaceStrings [ "." ] [ "_" ] (lib.versions.pad 2 version);
4729

30+
overlayForTorchVersion = torchVersion: sourceBuild: self: super: {
31+
pythonPackagesExtensions = super.pythonPackagesExtensions ++ [
32+
(
33+
python-self: python-super: with python-self; {
34+
torch =
35+
if sourceBuild then
36+
python-self."torch_${flattenVersion torchVersion}"
37+
else
38+
python-self."torch-bin_${flattenVersion torchVersion}";
39+
}
40+
)
41+
];
42+
};
43+
4844
# An overlay that overides CUDA to the given version.
4945
overlayForCudaVersion = cudaVersion: self: super: {
5046
cudaPackages = super."cudaPackages_${flattenVersion cudaVersion}";
@@ -57,6 +53,38 @@ let
5753
overlayForXpuVersion = xpuVersion: self: super: {
5854
xpuPackages = super."xpuPackages_${flattenVersion xpuVersion}";
5955
};
56+
57+
backendConfig = {
58+
cpu = {
59+
allowUnfree = true;
60+
};
61+
62+
cuda = {
63+
allowUnfree = true;
64+
cudaSupport = true;
65+
};
66+
67+
metal = {
68+
allowUnfree = true;
69+
metalSupport = true;
70+
};
71+
72+
rocm = {
73+
allowUnfree = true;
74+
rocmSupport = true;
75+
};
76+
77+
xpu = {
78+
allowUnfree = true;
79+
xpuSupport = true;
80+
};
81+
};
82+
83+
xpuConfig = {
84+
allowUnfree = true;
85+
xpuSupport = true;
86+
};
87+
6088
# Construct the nixpkgs package set for the given versions.
6189
mkBuildSet =
6290
buildConfig@{
@@ -67,34 +95,38 @@ let
6795
rocmVersion ? null,
6896
xpuVersion ? null,
6997
torchVersion,
70-
cxx11Abi,
7198
system,
7299
bundleBuild ? false,
73100
sourceBuild ? false,
74101
}:
75102
let
76-
pkgs =
103+
backendOverlay =
77104
if buildConfig.backend == "cpu" then
78-
pkgsForCpu
105+
[ ]
79106
else if buildConfig.backend == "cuda" then
80-
pkgsByCudaVer.${cudaVersion}
107+
[ (overlayForCudaVersion buildConfig.cudaVersion) ]
81108
else if buildConfig.backend == "rocm" then
82-
pkgsByRocmVer.${rocmVersion}
109+
[ (overlayForRocmVersion buildConfig.rocmVersion) ]
83110
else if buildConfig.backend == "metal" then
84-
pkgsForMetal
111+
[ ]
85112
else if buildConfig.backend == "xpu" then
86-
pkgsByXpuVer.${xpuVersion}
113+
[ (overlayForXpuVersion buildConfig.xpuVersion) ]
87114
else
88115
throw "No compute framework set in Torch version";
89-
torch =
90-
if sourceBuild then
91-
pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override {
92-
inherit cxx11Abi;
93-
}
94-
else
95-
pkgs.python3.pkgs."torch-bin_${flattenVersion torchVersion}".override {
96-
inherit cxx11Abi;
97-
};
116+
config =
117+
backendConfig.${buildConfig.backend} or (throw "No backend config for ${buildConfig.backend}");
118+
119+
pkgs = import nixpkgs {
120+
inherit config system;
121+
overlays = [
122+
overlay
123+
]
124+
++ backendOverlay
125+
++ [ (overlayForTorchVersion torchVersion sourceBuild) ];
126+
};
127+
128+
torch = pkgs.python3.pkgs.torch;
129+
98130
extension = pkgs.callPackage ./torch-extension { inherit torch; };
99131
in
100132
{
@@ -106,90 +138,5 @@ let
106138
bundleBuild
107139
;
108140
};
109-
pkgsForXpuVersions =
110-
xpuVersions:
111-
builtins.listToAttrs (
112-
map (xpuVersion: {
113-
name = xpuVersion;
114-
value = import nixpkgs {
115-
inherit system;
116-
config = {
117-
allowUnfree = true;
118-
xpuSupport = true;
119-
};
120-
overlays = [
121-
overlay
122-
(overlayForXpuVersion xpuVersion)
123-
];
124-
};
125-
}) xpuVersions
126-
);
127-
pkgsByXpuVer = pkgsForXpuVersions xpuVersions;
128-
129-
pkgsForMetal = import nixpkgs {
130-
inherit system;
131-
config = {
132-
allowUnfree = true;
133-
metalSupport = true;
134-
};
135-
overlays = [
136-
overlay
137-
];
138-
};
139-
140-
pkgsForCpu = import nixpkgs {
141-
inherit system;
142-
config = {
143-
allowUnfree = true;
144-
};
145-
overlays = [
146-
overlay
147-
];
148-
};
149-
150-
# Instantiate nixpkgs for the given CUDA versions. Returns
151-
# an attribute set like `{ "12.4" = <nixpkgs with 12.4>; ... }`.
152-
pkgsForCudaVersions =
153-
cudaVersions:
154-
builtins.listToAttrs (
155-
map (cudaVersion: {
156-
name = cudaVersion;
157-
value = import nixpkgs {
158-
inherit system;
159-
config = {
160-
allowUnfree = true;
161-
cudaSupport = true;
162-
};
163-
overlays = [
164-
overlay
165-
(overlayForCudaVersion cudaVersion)
166-
];
167-
};
168-
}) cudaVersions
169-
);
170-
171-
pkgsByCudaVer = pkgsForCudaVersions cudaVersions;
172-
173-
pkgsForRocmVersions =
174-
rocmVersions:
175-
builtins.listToAttrs (
176-
map (rocmVersion: {
177-
name = rocmVersion;
178-
value = import nixpkgs {
179-
inherit system;
180-
config = {
181-
allowUnfree = true;
182-
rocmSupport = true;
183-
};
184-
overlays = [
185-
overlay
186-
(overlayForRocmVersion rocmVersion)
187-
];
188-
};
189-
}) rocmVersions
190-
);
191-
192-
pkgsByRocmVer = pkgsForRocmVersions rocmVersions;
193-
194141
in
195142
map mkBuildSet (buildConfigs system)

lib/build-variants.nix

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ rec {
99

1010
buildName =
1111
let
12-
inherit (import ./version-utils.nix { inherit lib; }) abiString flattenVersion;
12+
inherit (import ./version-utils.nix { inherit lib; }) flattenVersion;
1313
computeString =
1414
version:
1515
if backend version == "cpu" then
@@ -29,7 +29,7 @@ rec {
2929
if version.system == "aarch64-darwin" then
3030
"torch${flattenVersion version.torchVersion}-${computeString version}-${version.system}"
3131
else
32-
"torch${flattenVersion version.torchVersion}-${abiString version.cxx11Abi}-${computeString version}-${version.system}";
32+
"torch${flattenVersion version.torchVersion}-cxx11-${computeString version}-${version.system}";
3333

3434
# Build variants included in bundle builds.
3535
buildVariants =

lib/build.nix

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
}:
1313

1414
let
15-
abi = torch: if torch.passthru.cxx11Abi then "cxx11" else "cxx98";
1615
supportedCudaCapabilities = builtins.fromJSON (
1716
builtins.readFile ../build2cmake/src/cuda_supported_archs.json
1817
);

lib/version-utils.nix

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,4 @@ in
66
{
77
flattenVersion =
88
version: lib.replaceStrings [ "." ] [ "" ] (versions.majorMinor (versions.pad 2 version));
9-
abiString = cxx11Abi: if cxx11Abi then "cxx11" else "cxx98";
109
}

pkgs/python-modules/torch/binary/generic.nix

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@
4343
url,
4444
hash,
4545
version,
46-
# Remove, needed for compat.
47-
cxx11Abi ? true,
4846

4947
effectiveStdenv ? if cudaSupport then cudaPackages.backendStdenv else stdenv,
5048
}:
@@ -322,7 +320,6 @@ buildPythonPackage {
322320
inherit
323321
cudaSupport
324322
cudaPackages
325-
cxx11Abi
326323
rocmSupport
327324
rocmPackages
328325
xpuSupport
@@ -333,7 +330,6 @@ buildPythonPackage {
333330
rocmArchs = if rocmSupport then supportedTorchRocmArchs else [ ];
334331
}
335332
// (callPackage ../variant.nix {
336-
inherit cxx11Abi;
337333
torchVersion = version;
338334
});
339335

pkgs/python-modules/torch/binary/torch-versions.json

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,98 +2,82 @@
22
{
33
"torchVersion": "2.8.0",
44
"cudaVersion": "12.6",
5-
"cxx11Abi": true,
65
"systems": ["x86_64-linux"]
76
},
87
{
98
"torchVersion": "2.8.0",
109
"cudaVersion": "12.8",
11-
"cxx11Abi": true,
1210
"systems": ["x86_64-linux"]
1311
},
1412
{
1513
"torchVersion": "2.8.0",
1614
"cudaVersion": "12.9",
17-
"cxx11Abi": true,
1815
"systems": ["x86_64-linux", "aarch64-linux"]
1916
},
2017
{
2118
"torchVersion": "2.8.0",
2219
"rocmVersion": "6.3",
23-
"cxx11Abi": true,
2420
"systems": ["x86_64-linux"]
2521
},
2622
{
2723
"torchVersion": "2.8.0",
2824
"rocmVersion": "6.4",
29-
"cxx11Abi": true,
3025
"systems": ["x86_64-linux"]
3126
},
3227
{
3328
"torchVersion": "2.8.0",
34-
"cxx11Abi": true,
3529
"metal": true,
3630
"systems": ["aarch64-darwin"]
3731
},
3832
{
3933
"torchVersion": "2.8.0",
40-
"cxx11Abi": true,
4134
"cpu": true,
4235
"systems": ["aarch64-linux", "x86_64-linux"]
4336
},
4437
{
4538
"torchVersion": "2.8.0",
4639
"xpuVersion": "2025.1.3",
47-
"cxx11Abi": true,
4840
"systems": ["x86_64-linux"]
4941
},
5042

5143
{
5244
"torchVersion": "2.9.0",
5345
"cudaVersion": "12.6",
54-
"cxx11Abi": true,
5546
"systems": ["x86_64-linux", "aarch64-linux"]
5647
},
5748
{
5849
"torchVersion": "2.9.0",
5950
"cudaVersion": "12.8",
60-
"cxx11Abi": true,
6151
"systems": ["x86_64-linux", "aarch64-linux"]
6252
},
6353
{
6454
"torchVersion": "2.9.0",
6555
"cudaVersion": "13.0",
66-
"cxx11Abi": true,
6756
"systems": ["x86_64-linux", "aarch64-linux"]
6857
},
6958
{
7059
"torchVersion": "2.9.0",
7160
"rocmVersion": "6.3",
72-
"cxx11Abi": true,
7361
"systems": ["x86_64-linux"]
7462
},
7563
{
7664
"torchVersion": "2.9.0",
7765
"rocmVersion": "6.4",
78-
"cxx11Abi": true,
7966
"systems": ["x86_64-linux"]
8067
},
8168
{
8269
"torchVersion": "2.9.0",
83-
"cxx11Abi": true,
8470
"metal": true,
8571
"systems": ["aarch64-darwin"]
8672
},
8773
{
8874
"torchVersion": "2.9.0",
89-
"cxx11Abi": true,
9075
"cpu": true,
9176
"systems": ["aarch64-linux", "x86_64-linux"]
9277
},
9378
{
9479
"torchVersion": "2.9.0",
9580
"xpuVersion": "2025.2.1",
96-
"cxx11Abi": true,
9781
"systems": ["x86_64-linux"]
9882
}
9983
]

0 commit comments

Comments
 (0)