Skip to content
This repository was archived by the owner on Apr 6, 2026. It is now read-only.

Commit fb5ecc9

Browse files
authored
Add a CUDA 12.9 build variant for Torch 2.7 (#136)
This change adds a CUDA 12.9 build variant. Since this is not a variant supported/distributed by upstream, we add an `upstreamVariant` attribute to the Torch variants. A variant is only a candidate for a kernel's `bundle` package when `upstreamVariant = true`. Todo: also use `upstreamVariant` in generating variant JSON.
1 parent 9402e64 commit fb5ecc9

10 files changed

Lines changed: 167 additions & 95 deletions

File tree

.github/workflows/check_variants.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,22 @@ jobs:
1717
with:
1818
nix_path: nixpkgs=channel:nixos-unstable
1919
- name: Generate variants JSON
20-
run: nix eval --raw .#lib.allBuildVariantsJSON | nix run nixpkgs#jq > build-variants.json
20+
run: nix eval --raw .#lib.allBuildVariantsJSON | nix run nixpkgs#jq 'walk(if type == "array" then sort else . end)' > build-variants.json
2121
- name: Check if variants JSON is up-to-date
2222
run: |
2323
if git diff --exit-code build-variants.json; then
2424
echo "✅ variants.json is up-to-date"
2525
else
26-
echo "🛑 regenerate variants.json: nix eval --raw .#lib.allBuildVariantsJSON | nix run nixpkgs#jq > build-variants.json"
26+
echo "🛑 regenerate variants.json: nix eval --raw .#lib.allBuildVariantsJSON | nix run nixpkgs#jq 'walk(if type == "array" then sort else . end)' > build-variants.json"
2727
exit 1
2828
fi
2929
- name: Generate variants Markdown
30-
run: scripts/gen_variants_markdown.py
30+
run: nix run nixpkgs#python3 scripts/gen_variants_markdown.py
3131
- name: Check if variants Markdown is up-to-date
3232
run: |
3333
if git diff --exit-code docs/build-variants.md; then
3434
echo "✅ docs/build-variants.md is up-to-date"
3535
else
36-
echo "🛑 regenerate docs/build-variants: scripts/gen_variants_markdown.py"
36+
echo "🛑 regenerate docs/build-variants: nix run nixpkgs#python3 scripts/gen_variants_markdown.py"
3737
exit 1
3838
fi

build-variants.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
"x86_64-linux": {
1111
"cuda": [
1212
"torch26-cxx11-cu118-x86_64-linux",
13-
"torch26-cxx98-cu118-x86_64-linux",
1413
"torch26-cxx11-cu124-x86_64-linux",
15-
"torch26-cxx98-cu124-x86_64-linux",
1614
"torch26-cxx11-cu126-x86_64-linux",
15+
"torch26-cxx98-cu118-x86_64-linux",
16+
"torch26-cxx98-cu124-x86_64-linux",
1717
"torch26-cxx98-cu126-x86_64-linux",
1818
"torch27-cxx11-cu118-x86_64-linux",
1919
"torch27-cxx11-cu126-x86_64-linux",

docs/build-variants.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ available. This list will be updated as new PyTorch versions are released.
1515
## CUDA x86_64-linux
1616

1717
- `torch26-cxx11-cu118-x86_64-linux`
18-
- `torch26-cxx98-cu118-x86_64-linux`
1918
- `torch26-cxx11-cu124-x86_64-linux`
20-
- `torch26-cxx98-cu124-x86_64-linux`
2119
- `torch26-cxx11-cu126-x86_64-linux`
20+
- `torch26-cxx98-cu118-x86_64-linux`
21+
- `torch26-cxx98-cu124-x86_64-linux`
2222
- `torch26-cxx98-cu126-x86_64-linux`
2323
- `torch27-cxx11-cu118-x86_64-linux`
2424
- `torch27-cxx11-cu126-x86_64-linux`

flake.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

flake.nix

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
inputs = {
55
flake-utils.url = "github:numtide/flake-utils";
6-
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable-small";
6+
#nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable-small";
7+
nixpkgs.url = "github:danieldk/nixpkgs/kernel-builder-cuda-12.9.0";
78
flake-compat.url = "github:edolstra/flake-compat";
89
rocm-nix = {
910
url = "github:huggingface/rocm-nix";
@@ -85,6 +86,7 @@
8586
};
8687
redistributable = build.buildDistTorchExtensions {
8788
inherit path;
89+
buildSets = buildSetPerSystem.${system};
8890
rev = revUnderscored;
8991
};
9092
buildTree =

lib/build-version.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
gpu,
33
pkgs,
44
torch,
5+
upstreamVariant,
56
}:
67
let
78
inherit (pkgs) lib;

lib/build.nix

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ rec {
4343
in
4444
lib.foldl (langs: kernel: langs // { ${kernelLang kernel} = true; }) init kernels;
4545

46+
# Filter buildsets that are applicable to a given kernel build config.
4647
applicableBuildSets =
4748
buildConfig: buildSets:
4849
let
@@ -61,6 +62,7 @@ rec {
6162
gpu,
6263
pkgs,
6364
torch,
65+
upstreamVariant,
6466
}:
6567
{
6668
path,
@@ -126,7 +128,11 @@ rec {
126128

127129
# Build multiple Torch extensions.
128130
buildDistTorchExtensions =
129-
{ path, rev }:
131+
{
132+
buildSets,
133+
path,
134+
rev,
135+
}:
130136
let
131137
extensionForTorch =
132138
{ path, rev }:
@@ -147,7 +153,11 @@ rec {
147153
let
148154
# We just need to get any nixpkgs for use by the path join.
149155
pkgs = (builtins.head buildSets).pkgs;
150-
extensions = buildDistTorchExtensions { inherit path rev; };
156+
upstreamBuildSets = builtins.filter (buildSet: buildSet.upstreamVariant) buildSets;
157+
extensions = buildDistTorchExtensions {
158+
inherit path rev;
159+
buildSets = upstreamBuildSets;
160+
};
151161
buildConfig = readBuildConfig path;
152162
namePaths =
153163
if buildConfig.torch.universal or false then

lib/buildsets.nix

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ let
3232
rocmVersion ? "",
3333
torchVersion,
3434
cxx11Abi,
35+
upstreamVariant ? false,
3536
}:
3637
let
3738
pkgs = if gpu == "cuda" then pkgsByCudaVer.${cudaVersion} else pkgsByRocmVer.${rocmVersion};
@@ -40,7 +41,12 @@ let
4041
};
4142
in
4243
{
43-
inherit gpu pkgs torch;
44+
inherit
45+
gpu
46+
pkgs
47+
torch
48+
upstreamVariant
49+
;
4450
};
4551

4652
pkgsForRocm = import nixpkgs {

pkgs/python-modules/torch_2_7/default.nix

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ let
126126
supportedTorchCudaCapabilities =
127127
let
128128
# https://github.com/pytorch/pytorch/blob/release/2.7/.ci/manywheel/build_cuda.sh
129-
capsPerCudaVersion = {
129+
capsPerCudaVersion = rec {
130130
"12.8" = [
131131
"7.5"
132132
"8.0"
@@ -163,6 +163,16 @@ let
163163
"8.6"
164164
"9.0"
165165
];
166+
167+
# Not in upstream yet, so use same capabilities as 12.8.
168+
"12.9" = [
169+
"7.5"
170+
"8.0"
171+
"8.6"
172+
"9.0"
173+
"10.0"
174+
"12.0"
175+
];
166176
};
167177
real = capsPerCudaVersion."${lib.versions.majorMinor cudaPackages.cudaVersion}";
168178
ptx = lists.map (x: "${x}+PTX") real;

0 commit comments

Comments
 (0)