Skip to content

Commit 021e0f3

Browse files
committed
check in a working merger script
1 parent 6be8e7d commit 021e0f3

1 file changed

Lines changed: 200 additions & 0 deletions

File tree

ci/tools/merge_cuda_core_wheels.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Script to merge CUDA-specific wheels into a single multi-CUDA wheel.
4+
5+
This script takes wheels built for different CUDA versions (cu12, cu13) and merges them
6+
into a single wheel that supports both CUDA versions.
7+
8+
In particular, each wheel contains a CUDA-specific build of the `cuda.core` library
9+
and the associated bindings. This script merges these directories into a single wheel
10+
that supports both CUDA versions, i.e., containing both `cuda/core/experimental/cu12`
11+
and `cuda/core/experimental/cu13`. At runtime, the code in `cuda/core/experimental/__init__.py`
12+
is used to import the appropriate CUDA-specific bindings.
13+
"""
14+
15+
import argparse
16+
import os
17+
import shutil
18+
import subprocess
19+
import sys
20+
import tempfile
21+
from pathlib import Path
22+
from typing import List
23+
24+
25+
def run_command(
26+
cmd: List[str], cwd: Path = None, env: dict = None
27+
) -> subprocess.CompletedProcess:
28+
"""Run a command with error handling."""
29+
print(f"Running: {' '.join(cmd)}")
30+
if cwd:
31+
print(f" Working directory: {cwd}")
32+
33+
result = subprocess.run(cmd, cwd=cwd, env=env, capture_output=True, text=True)
34+
35+
if result.returncode != 0:
36+
print(f"Command failed with return code {result.returncode}")
37+
print("STDOUT:", result.stdout)
38+
print("STDERR:", result.stderr)
39+
result.check_returncode()
40+
41+
return result
42+
43+
44+
def merge_wheels(wheels: List[Path], output_dir: Path) -> Path:
45+
"""Merge multiple wheels into a single wheel with version-specific binaries."""
46+
print("\n=== Merging wheels ===")
47+
print(f"Input wheels: {[w.name for w in wheels]}")
48+
49+
if len(wheels) == 1:
50+
raise RuntimeError("only one wheel is provided, nothing to merge")
51+
52+
# Extract all wheels to temporary directories
53+
with tempfile.TemporaryDirectory() as temp_dir:
54+
temp_path = Path(temp_dir)
55+
extracted_wheels = []
56+
57+
for i, wheel in enumerate(wheels):
58+
print(f"Extracting wheel {i + 1}/{len(wheels)}: {wheel.name}")
59+
# Extract wheel - wheel unpack creates the directory itself
60+
run_command(
61+
[
62+
"python",
63+
"-m",
64+
"wheel",
65+
"unpack",
66+
str(wheel),
67+
"--dest",
68+
str(temp_path),
69+
]
70+
)
71+
72+
# Find the extracted directory (wheel unpack creates a subdirectory)
73+
extract_dir = None
74+
for item in temp_path.iterdir():
75+
if item.is_dir() and item.name.startswith("cuda_core"):
76+
extract_dir = item
77+
break
78+
79+
if not extract_dir:
80+
raise RuntimeError(
81+
f"Could not find extracted wheel directory for {wheel.name}"
82+
)
83+
84+
# Rename to our expected name
85+
expected_name = temp_path / f"wheel_{i}"
86+
extract_dir.rename(expected_name)
87+
extract_dir = expected_name
88+
89+
extracted_wheels.append(extract_dir)
90+
91+
# Use the first wheel as the base and merge binaries from others
92+
base_wheel = extracted_wheels[0]
93+
94+
# now copy the version-specific directory from other wheels
95+
# into the appropriate place in the base wheel
96+
for i, wheel_dir in enumerate(extracted_wheels):
97+
cuda_version = wheels[i].name.split(".cu")[1].split(".")[0]
98+
base_dir = (
99+
Path("cuda")
100+
/ "core"
101+
/ "experimental"
102+
)
103+
# Copy from other wheels
104+
print(f" Copying {wheel_dir} to {base_wheel}")
105+
shutil.copytree(wheel_dir / base_dir, base_wheel / base_dir / f"cu{cuda_version}")
106+
107+
# Overwrite the __init__.py in versioned dirs
108+
open(base_wheel / base_dir / f"cu{cuda_version}" / "__init__.py", "w").close()
109+
110+
# The base dir should only contain __init__.py, the include dir, and the versioned dirs
111+
files_to_remove = os.listdir(base_wheel / base_dir)
112+
for f in files_to_remove:
113+
f_abspath = base_wheel / base_dir / f
114+
if f not in ("__init__.py", "cu12", "cu13", "include"):
115+
if os.path.isdir(f_abspath):
116+
shutil.rmtree(f_abspath)
117+
else:
118+
os.remove(f_abspath)
119+
120+
# Repack the merged wheel
121+
output_dir.mkdir(parents=True, exist_ok=True)
122+
123+
# Create a clean wheel name without CUDA version suffixes
124+
base_wheel_name = wheels[0].name
125+
# Remove any .cu* suffix from the wheel name
126+
if ".cu" in base_wheel_name:
127+
base_wheel_name = base_wheel_name.split(".cu")[0] + ".whl"
128+
129+
print(f"Repacking merged wheel as: {base_wheel_name}")
130+
run_command(
131+
[
132+
"python",
133+
"-m",
134+
"wheel",
135+
"pack",
136+
str(base_wheel),
137+
"--dest-dir",
138+
str(output_dir),
139+
]
140+
)
141+
142+
# Find the output wheel
143+
output_wheels = list(output_dir.glob("*.whl"))
144+
if not output_wheels:
145+
raise RuntimeError("Failed to create merged wheel")
146+
147+
merged_wheel = output_wheels[0]
148+
print(f"Successfully merged wheel: {merged_wheel}")
149+
return merged_wheel
150+
151+
152+
def main():
153+
"""Main merge script."""
154+
parser = argparse.ArgumentParser(
155+
description="Merge CUDA-specific wheels into a single multi-CUDA wheel"
156+
)
157+
parser.add_argument(
158+
"wheels", nargs="+", help="Paths to the CUDA-specific wheels to merge"
159+
)
160+
parser.add_argument(
161+
"--output-dir", "-o", default="dist", help="Output directory for merged wheel"
162+
)
163+
164+
args = parser.parse_args()
165+
166+
print("cuda.core Wheel Merger")
167+
print("======================")
168+
169+
# Convert wheel paths to Path objects and validate
170+
wheels = []
171+
for wheel_path in args.wheels:
172+
wheel = Path(wheel_path)
173+
if not wheel.exists():
174+
print(f"Error: Wheel not found: {wheel}")
175+
sys.exit(1)
176+
if not wheel.name.endswith(".whl"):
177+
print(f"Error: Not a wheel file: {wheel}")
178+
sys.exit(1)
179+
wheels.append(wheel)
180+
181+
if not wheels:
182+
print("Error: No wheels provided")
183+
sys.exit(1)
184+
185+
output_dir = Path(args.output_dir)
186+
187+
# Check that we have wheel tool available
188+
try:
189+
run_command(["python", "-m", "wheel", "--help"])
190+
except Exception:
191+
print("Error: wheel package not available. Install with: pip install wheel")
192+
sys.exit(1)
193+
194+
# Merge the wheels
195+
merged_wheel = merge_wheels(wheels, output_dir)
196+
print(f"\nMerge complete! Output: {merged_wheel}")
197+
198+
199+
if __name__ == "__main__":
200+
main()

0 commit comments

Comments
 (0)