Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.

Commit 0b7e009

Browse files
committed
Restructure extension handling
Before this change the `torch-extension` directory was just a derivation that built an extension for a given configuration + build set. The downside of this approach was that we could not easily get things like - The standard environment to be used by dev shells. - Overrides for caching. For instance, we override ROCm's `clr` and XPU's `oneapi-torch-dev` and `onednn-xpu` to use stdenv with an old glibc. This change modifies `torch-extension` so that we can instantiate it given a build set. A set is returned that can give access to the things mentioned above, as well as mkExtension and mkNoArchExtension functions. We assign this set to the corresponding build set, so that it is directly accessible from the build set.
1 parent b1e93c3 commit 0b7e009

6 files changed

Lines changed: 268 additions & 237 deletions

File tree

lib/build-sets.nix

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,12 @@ let
8787
torch = pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override {
8888
inherit cxx11Abi;
8989
};
90+
extension = pkgs.callPackage ./torch-extension { inherit torch; };
9091
in
9192
{
9293
inherit
9394
buildConfig
95+
extension
9496
pkgs
9597
torch
9698
bundleBuild

lib/build-version.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
buildConfig,
3+
extension,
34
pkgs,
45
torch,
56
bundleBuild,

lib/build.nix

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,6 @@ let
2222
isRocm
2323
isXpu
2424
;
25-
mkStdenv =
26-
buildSet: oldLinuxCompat:
27-
let
28-
inherit (buildSet) pkgs torch;
29-
in
30-
if pkgs.stdenv.hostPlatform.isDarwin then
31-
pkgs.stdenv
32-
else if oldLinuxCompat then
33-
# Uses CUDA stdenv when we are building for CUDA.
34-
pkgs.stdenvGlibc_2_27
35-
else if torch.cudaSupport then
36-
torch.cudaPackages.backendStdenv
37-
else
38-
pkgs.stdenv;
39-
4025
in
4126
rec {
4227
resolveDeps = import ./deps.nix { inherit lib; };
@@ -113,6 +98,7 @@ rec {
11398
mkTorchExtension =
11499
{
115100
buildConfig,
101+
extension,
116102
pkgs,
117103
torch,
118104
bundleBuild,
@@ -122,7 +108,6 @@ rec {
122108
rev,
123109
doGetKernelCheck,
124110
stripRPath ? false,
125-
oldLinuxCompat ? false,
126111
}:
127112
let
128113
inherit (lib) fileset;
@@ -143,34 +128,32 @@ rec {
143128
_: buildConfig: builtins.length (buildConfig.cuda-capabilities or supportedCudaCapabilities)
144129
) buildConfig.kernel
145130
);
146-
stdenv = mkStdenv { inherit pkgs torch; } oldLinuxCompat;
147131
in
148132
if buildConfig.general.universal then
149133
# No torch extension sources? Treat it as a noarch package.
150-
pkgs.callPackage ./torch-extension-noarch ({
134+
135+
extension.mkNoArchExtension {
151136
inherit
152137
src
153138
rev
154-
torch
155139
doGetKernelCheck
156140
;
157141
extensionName = buildConfig.general.name;
158-
})
142+
}
159143
else
160-
pkgs.callPackage ./torch-extension ({
144+
extension.mkExtension {
161145
inherit
162146
doGetKernelCheck
163147
extraDeps
164148
nvccThreads
165149
src
166-
stdenv
167150
stripRPath
168-
torch
169151
rev
170152
;
153+
171154
extensionName = buildConfig.general.name;
172-
doAbiCheck = oldLinuxCompat;
173-
});
155+
doAbiCheck = true;
156+
};
174157

175158
# Build multiple Torch extensions.
176159
mkDistTorchExtensions =
@@ -189,7 +172,6 @@ rec {
189172
value = mkTorchExtension buildSet {
190173
inherit path rev doGetKernelCheck;
191174
stripRPath = true;
192-
oldLinuxCompat = true;
193175
};
194176
};
195177
applicableBuildSets' =
@@ -247,8 +229,7 @@ rec {
247229
let
248230
pkgs = buildSet.pkgs;
249231
rocmSupport = pkgs.config.rocmSupport or false;
250-
stdenv = mkStdenv buildSet false;
251-
mkShell = pkgs.mkShell.override { inherit stdenv; };
232+
mkShell = pkgs.mkShell.override { inherit (buildSet.extension) stdenv; };
252233
in
253234
{
254235
name = torchBuildVersion buildSet;
@@ -288,8 +269,7 @@ rec {
288269
pkgs = buildSet.pkgs;
289270
rocmSupport = pkgs.config.rocmSupport or false;
290271
xpuSupport = pkgs.config.xpuSupport or false;
291-
stdenv = mkStdenv buildSet false;
292-
mkShell = pkgs.mkShell.override { inherit stdenv; };
272+
mkShell = pkgs.mkShell.override { inherit (buildSet.extension) stdenv; };
293273
in
294274
{
295275
name = torchBuildVersion buildSet;

lib/torch-extension/arch.nix

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
{
2+
cudaSupport ? torch.cudaSupport,
3+
rocmSupport ? torch.rocmSupport,
4+
xpuSupport ? torch.xpuSupport,
5+
6+
lib,
7+
stdenv,
8+
cudaPackages,
9+
cmake,
10+
cmakeNvccThreadsHook,
11+
ninja,
12+
build2cmake,
13+
get-kernel-check,
14+
kernel-abi-check,
15+
python3,
16+
rewrite-nix-paths-macho,
17+
rocmPackages,
18+
writeScriptBin,
19+
xpuPackages,
20+
21+
apple-sdk_15,
22+
clr,
23+
oneapi-torch-dev,
24+
onednn-xpu,
25+
torch,
26+
}:
27+
28+
{
29+
# Whether to do ABI checks.
30+
doAbiCheck ? true,
31+
32+
# Whether to run get-kernel-check.
33+
doGetKernelCheck ? true,
34+
35+
extensionName,
36+
37+
# Extra dependencies (such as CUTLASS).
38+
extraDeps ? [ ],
39+
40+
nvccThreads,
41+
42+
# Wheter to strip rpath for non-nix use.
43+
stripRPath ? false,
44+
45+
# Revision to bake into the ops name.
46+
rev,
47+
48+
src,
49+
}:
50+
51+
let
52+
# On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders.
53+
# It's not supported by the nixpkgs shim.
54+
xcrunHost = writeScriptBin "xcrunHost" ''
55+
# Use system SDK for Metal files.
56+
unset DEVELOPER_DIR
57+
/usr/bin/xcrun $@
58+
'';
59+
60+
in
61+
62+
stdenv.mkDerivation (prevAttrs: {
63+
name = "${extensionName}-torch-ext";
64+
65+
inherit doAbiCheck nvccThreads src;
66+
67+
# Generate build files.
68+
postPatch = ''
69+
build2cmake generate-torch --backend ${
70+
if cudaSupport then
71+
"cuda"
72+
else if rocmSupport then
73+
"rocm"
74+
else if xpuSupport then
75+
"xpu"
76+
else
77+
"metal"
78+
} --ops-id ${rev} build.toml
79+
'';
80+
81+
# hipify copies files, but its target is run in the CMake build and install
82+
# phases. Since some of the files come from the Nix store, this fails the
83+
# second time around.
84+
preInstall = ''
85+
chmod -R u+w .
86+
'';
87+
88+
nativeBuildInputs = [
89+
kernel-abi-check
90+
cmake
91+
ninja
92+
build2cmake
93+
]
94+
++ lib.optionals doGetKernelCheck [
95+
get-kernel-check
96+
]
97+
++ lib.optionals cudaSupport [
98+
cmakeNvccThreadsHook
99+
cudaPackages.cuda_nvcc
100+
]
101+
++ lib.optionals rocmSupport [
102+
clr
103+
]
104+
++ lib.optionals xpuSupport ([
105+
xpuPackages.ocloc
106+
oneapi-torch-dev
107+
])
108+
++ lib.optionals stdenv.hostPlatform.isDarwin [
109+
rewrite-nix-paths-macho
110+
];
111+
112+
buildInputs = [
113+
torch
114+
torch.cxxdev
115+
]
116+
++ lib.optionals cudaSupport (
117+
with cudaPackages;
118+
[
119+
cuda_cudart
120+
121+
# Make dependent on build configuration dependencies once
122+
# the Torch dependency is gone.
123+
cuda_cccl
124+
libcublas
125+
libcusolver
126+
libcusparse
127+
]
128+
)
129+
++ lib.optionals rocmSupport (
130+
with rocmPackages;
131+
[
132+
hipsparselt
133+
rocwmma-devel
134+
]
135+
)
136+
++ lib.optionals xpuSupport ([
137+
oneapi-torch-dev
138+
onednn-xpu
139+
])
140+
++ lib.optionals stdenv.hostPlatform.isDarwin [
141+
apple-sdk_15
142+
]
143+
++ extraDeps;
144+
145+
env =
146+
lib.optionalAttrs cudaSupport {
147+
CUDAToolkit_ROOT = "${lib.getDev cudaPackages.cuda_nvcc}";
148+
TORCH_CUDA_ARCH_LIST =
149+
if cudaPackages.cudaOlder "12.8" then
150+
"7.0;7.5;8.0;8.6;8.9;9.0"
151+
else if cudaPackages.cudaOlder "13.0" then
152+
"7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0"
153+
else
154+
# sm_101 has been renamed to sm_110 in CUDA 13.
155+
"7.5;8.0;8.6;8.9;9.0;10.0;11.0;12.0";
156+
}
157+
// lib.optionalAttrs rocmSupport {
158+
PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" torch.rocmArchs;
159+
}
160+
// lib.optionalAttrs xpuSupport {
161+
MKLROOT = oneapi-torch-dev;
162+
SYCL_ROOT = oneapi-torch-dev;
163+
};
164+
165+
# If we use the default setup, CMAKE_CUDA_HOST_COMPILER gets set to nixpkgs g++.
166+
dontSetupCUDAToolkitCompilers = true;
167+
168+
cmakeFlags = [
169+
(lib.cmakeFeature "Python_EXECUTABLE" "${python3.withPackages (ps: [ torch ])}/bin/python")
170+
]
171+
++ lib.optionals cudaSupport [
172+
(lib.cmakeFeature "CMAKE_CUDA_HOST_COMPILER" "${stdenv.cc}/bin/g++")
173+
]
174+
++ lib.optionals rocmSupport [
175+
# Ensure sure that we use HIP from our CLR override and not HIP from
176+
# the symlink-joined ROCm toolkit.
177+
(lib.cmakeFeature "CMAKE_HIP_COMPILER_ROCM_ROOT" "${clr}")
178+
(lib.cmakeFeature "HIP_ROOT_DIR" "${clr}")
179+
]
180+
++ lib.optionals xpuSupport [
181+
(lib.cmakeFeature "ONEDNN_XPU_INCLUDE_DIR" "${onednn-xpu}/include")
182+
]
183+
++ lib.optionals stdenv.hostPlatform.isDarwin [
184+
# Use host compiler for Metal. Not included in the redistributable SDK.
185+
(lib.cmakeFeature "METAL_COMPILER" "${xcrunHost}/bin/xcrunHost")
186+
];
187+
188+
postInstall = ''
189+
(
190+
cd ..
191+
cp -r torch-ext/${extensionName} $out/
192+
)
193+
cp $out/_${extensionName}_*/* $out/${extensionName}
194+
rm -rf $out/_${extensionName}_*
195+
''
196+
+ (lib.optionalString (stripRPath && stdenv.hostPlatform.isLinux)) ''
197+
find $out/${extensionName} -name '*.so' \
198+
-exec patchelf --set-rpath "" {} \;
199+
''
200+
+ (lib.optionalString (stripRPath && stdenv.hostPlatform.isDarwin)) ''
201+
find $out/${extensionName} -name '*.so' \
202+
-exec rewrite-nix-paths-macho {} \;
203+
204+
# Stub some rpath.
205+
find $out/${extensionName} -name '*.so' \
206+
-exec install_name_tool -add_rpath "@loader_path/lib" {} \;
207+
'';
208+
209+
doInstallCheck = true;
210+
211+
getKernelCheck = extensionName;
212+
213+
# We need access to the host system on Darwin for the Metal compiler.
214+
__noChroot = stdenv.hostPlatform.isDarwin;
215+
216+
passthru = {
217+
inherit torch;
218+
};
219+
})

0 commit comments

Comments
 (0)