|
25 | 25 | in |
26 | 26 | builtins.map (buildConfig: buildConfig // { backend = backend buildConfig; }) systemBuildConfigs; |
27 | 27 |
|
28 | | - cudaVersions = |
29 | | - let |
30 | | - withCuda = builtins.filter (torchVersion: torchVersion ? cudaVersion) torchVersions; |
31 | | - in |
32 | | - lib.unique (builtins.map (torchVersion: torchVersion.cudaVersion) withCuda); |
33 | | - |
34 | | - rocmVersions = |
35 | | - let |
36 | | - withRocm = builtins.filter (torchVersion: torchVersion ? rocmVersion) torchVersions; |
37 | | - in |
38 | | - lib.unique (builtins.map (torchVersion: torchVersion.rocmVersion) withRocm); |
39 | | - |
40 | | - xpuVersions = |
41 | | - let |
42 | | - withXpu = builtins.filter (torchVersion: torchVersion ? xpuVersion) torchVersions; |
43 | | - in |
44 | | - lib.unique (builtins.map (torchVersion: torchVersion.xpuVersion) withXpu); |
45 | | - |
46 | 28 | flattenVersion = version: lib.replaceStrings [ "." ] [ "_" ] (lib.versions.pad 2 version); |
47 | 29 |
|
| 30 | + overlayForTorchVersion = torchVersion: sourceBuild: self: super: { |
| 31 | + pythonPackagesExtensions = super.pythonPackagesExtensions ++ [ |
| 32 | + ( |
| 33 | + python-self: python-super: with python-self; { |
| 34 | + torch = |
| 35 | + if sourceBuild then |
| 36 | + python-self."torch_${flattenVersion torchVersion}" |
| 37 | + else |
| 38 | + python-self."torch-bin_${flattenVersion torchVersion}"; |
| 39 | + } |
| 40 | + ) |
| 41 | + ]; |
| 42 | + }; |
| 43 | + |
48 | 44 | # An overlay that overides CUDA to the given version. |
49 | 45 | overlayForCudaVersion = cudaVersion: self: super: { |
50 | 46 | cudaPackages = super."cudaPackages_${flattenVersion cudaVersion}"; |
|
57 | 53 | overlayForXpuVersion = xpuVersion: self: super: { |
58 | 54 | xpuPackages = super."xpuPackages_${flattenVersion xpuVersion}"; |
59 | 55 | }; |
| 56 | + |
| 57 | + backendConfig = { |
| 58 | + cpu = { |
| 59 | + allowUnfree = true; |
| 60 | + }; |
| 61 | + |
| 62 | + cuda = { |
| 63 | + allowUnfree = true; |
| 64 | + cudaSupport = true; |
| 65 | + }; |
| 66 | + |
| 67 | + metal = { |
| 68 | + allowUnfree = true; |
| 69 | + metalSupport = true; |
| 70 | + }; |
| 71 | + |
| 72 | + rocm = { |
| 73 | + allowUnfree = true; |
| 74 | + rocmSupport = true; |
| 75 | + }; |
| 76 | + |
| 77 | + xpu = { |
| 78 | + allowUnfree = true; |
| 79 | + xpuSupport = true; |
| 80 | + }; |
| 81 | + }; |
| 82 | + |
| 83 | + xpuConfig = { |
| 84 | + allowUnfree = true; |
| 85 | + xpuSupport = true; |
| 86 | + }; |
| 87 | + |
60 | 88 | # Construct the nixpkgs package set for the given versions. |
61 | 89 | mkBuildSet = |
62 | 90 | buildConfig@{ |
|
67 | 95 | rocmVersion ? null, |
68 | 96 | xpuVersion ? null, |
69 | 97 | torchVersion, |
70 | | - cxx11Abi, |
71 | 98 | system, |
72 | 99 | bundleBuild ? false, |
73 | 100 | sourceBuild ? false, |
74 | 101 | }: |
75 | 102 | let |
76 | | - pkgs = |
| 103 | + backendOverlay = |
77 | 104 | if buildConfig.backend == "cpu" then |
78 | | - pkgsForCpu |
| 105 | + [ ] |
79 | 106 | else if buildConfig.backend == "cuda" then |
80 | | - pkgsByCudaVer.${cudaVersion} |
| 107 | + [ (overlayForCudaVersion buildConfig.cudaVersion) ] |
81 | 108 | else if buildConfig.backend == "rocm" then |
82 | | - pkgsByRocmVer.${rocmVersion} |
| 109 | + [ (overlayForRocmVersion buildConfig.rocmVersion) ] |
83 | 110 | else if buildConfig.backend == "metal" then |
84 | | - pkgsForMetal |
| 111 | + [ ] |
85 | 112 | else if buildConfig.backend == "xpu" then |
86 | | - pkgsByXpuVer.${xpuVersion} |
| 113 | + [ (overlayForXpuVersion buildConfig.xpuVersion) ] |
87 | 114 | else |
88 | 115 | throw "No compute framework set in Torch version"; |
89 | | - torch = |
90 | | - if sourceBuild then |
91 | | - pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override { |
92 | | - inherit cxx11Abi; |
93 | | - } |
94 | | - else |
95 | | - pkgs.python3.pkgs."torch-bin_${flattenVersion torchVersion}".override { |
96 | | - inherit cxx11Abi; |
97 | | - }; |
| 116 | + config = |
| 117 | + backendConfig.${buildConfig.backend} or (throw "No backend config for ${buildConfig.backend}"); |
| 118 | + |
| 119 | + pkgs = import nixpkgs { |
| 120 | + inherit config system; |
| 121 | + overlays = [ |
| 122 | + overlay |
| 123 | + ] |
| 124 | + ++ backendOverlay |
| 125 | + ++ [ (overlayForTorchVersion torchVersion sourceBuild) ]; |
| 126 | + }; |
| 127 | + |
| 128 | + torch = pkgs.python3.pkgs.torch; |
| 129 | + |
98 | 130 | extension = pkgs.callPackage ./torch-extension { inherit torch; }; |
99 | 131 | in |
100 | 132 | { |
|
106 | 138 | bundleBuild |
107 | 139 | ; |
108 | 140 | }; |
109 | | - pkgsForXpuVersions = |
110 | | - xpuVersions: |
111 | | - builtins.listToAttrs ( |
112 | | - map (xpuVersion: { |
113 | | - name = xpuVersion; |
114 | | - value = import nixpkgs { |
115 | | - inherit system; |
116 | | - config = { |
117 | | - allowUnfree = true; |
118 | | - xpuSupport = true; |
119 | | - }; |
120 | | - overlays = [ |
121 | | - overlay |
122 | | - (overlayForXpuVersion xpuVersion) |
123 | | - ]; |
124 | | - }; |
125 | | - }) xpuVersions |
126 | | - ); |
127 | | - pkgsByXpuVer = pkgsForXpuVersions xpuVersions; |
128 | | - |
129 | | - pkgsForMetal = import nixpkgs { |
130 | | - inherit system; |
131 | | - config = { |
132 | | - allowUnfree = true; |
133 | | - metalSupport = true; |
134 | | - }; |
135 | | - overlays = [ |
136 | | - overlay |
137 | | - ]; |
138 | | - }; |
139 | | - |
140 | | - pkgsForCpu = import nixpkgs { |
141 | | - inherit system; |
142 | | - config = { |
143 | | - allowUnfree = true; |
144 | | - }; |
145 | | - overlays = [ |
146 | | - overlay |
147 | | - ]; |
148 | | - }; |
149 | | - |
150 | | - # Instantiate nixpkgs for the given CUDA versions. Returns |
151 | | - # an attribute set like `{ "12.4" = <nixpkgs with 12.4>; ... }`. |
152 | | - pkgsForCudaVersions = |
153 | | - cudaVersions: |
154 | | - builtins.listToAttrs ( |
155 | | - map (cudaVersion: { |
156 | | - name = cudaVersion; |
157 | | - value = import nixpkgs { |
158 | | - inherit system; |
159 | | - config = { |
160 | | - allowUnfree = true; |
161 | | - cudaSupport = true; |
162 | | - }; |
163 | | - overlays = [ |
164 | | - overlay |
165 | | - (overlayForCudaVersion cudaVersion) |
166 | | - ]; |
167 | | - }; |
168 | | - }) cudaVersions |
169 | | - ); |
170 | | - |
171 | | - pkgsByCudaVer = pkgsForCudaVersions cudaVersions; |
172 | | - |
173 | | - pkgsForRocmVersions = |
174 | | - rocmVersions: |
175 | | - builtins.listToAttrs ( |
176 | | - map (rocmVersion: { |
177 | | - name = rocmVersion; |
178 | | - value = import nixpkgs { |
179 | | - inherit system; |
180 | | - config = { |
181 | | - allowUnfree = true; |
182 | | - rocmSupport = true; |
183 | | - }; |
184 | | - overlays = [ |
185 | | - overlay |
186 | | - (overlayForRocmVersion rocmVersion) |
187 | | - ]; |
188 | | - }; |
189 | | - }) rocmVersions |
190 | | - ); |
191 | | - |
192 | | - pkgsByRocmVer = pkgsForRocmVersions rocmVersions; |
193 | | - |
194 | 141 | in |
195 | 142 | map mkBuildSet (buildConfigs system) |
0 commit comments