1+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+ #
3+ # SPDX-License-Identifier: Apache-2.0
4+
15#!/usr/bin/env python3
26"""
37Script to merge CUDA-specific wheels into a single multi-CUDA wheel.
1014that supports both CUDA versions, i.e., containing both `cuda/core/experimental/cu12`
1115and `cuda/core/experimental/cu13`. At runtime, the code in `cuda/core/experimental/__init__.py`
1216is used to import the appropriate CUDA-specific bindings.
17+
18+ This script is based on the one in NVIDIA/CCCL.
1319"""
1420
1521import argparse
1622import os
1723import shutil
18- import subprocess
24+ import subprocess # nosec: B404
1925import sys
2026import tempfile
2127from pathlib import Path
2228from typing import List
2329
2430
25- def run_command (
26- cmd : List [str ], cwd : Path = None , env : dict = None
27- ) -> subprocess .CompletedProcess :
31+ def run_command (cmd : List [str ], cwd : Path = None , env : dict = os .environ ) -> subprocess .CompletedProcess :
2832 """Run a command with error handling."""
2933 print (f"Running: { ' ' .join (cmd )} " )
3034 if cwd :
3135 print (f" Working directory: { cwd } " )
3236
33- result = subprocess .run (cmd , cwd = cwd , env = env , capture_output = True , text = True )
37+ result = subprocess .run (cmd , cwd = cwd , env = env , capture_output = True , text = True ) # nosec: B603
3438
3539 if result .returncode != 0 :
3640 print (f"Command failed with return code { result .returncode } " )
@@ -77,9 +81,7 @@ def merge_wheels(wheels: List[Path], output_dir: Path) -> Path:
7781 break
7882
7983 if not extract_dir :
80- raise RuntimeError (
81- f"Could not find extracted wheel directory for { wheel .name } "
82- )
84+ raise RuntimeError (f"Could not find extracted wheel directory for { wheel .name } " )
8385
8486 # Rename to our expected name
8587 expected_name = temp_path / f"wheel_{ i } "
@@ -95,11 +97,7 @@ def merge_wheels(wheels: List[Path], output_dir: Path) -> Path:
9597 # into the appropriate place in the base wheel
9698 for i , wheel_dir in enumerate (extracted_wheels ):
9799 cuda_version = wheels [i ].name .split (".cu" )[1 ].split ("." )[0 ]
98- base_dir = (
99- Path ("cuda" )
100- / "core"
101- / "experimental"
102- )
100+ base_dir = Path ("cuda" ) / "core" / "experimental"
103101 # Copy from other wheels
104102 print (f" Copying { wheel_dir } to { base_wheel } " )
105103 shutil .copytree (wheel_dir / base_dir , base_wheel / base_dir / f"cu{ cuda_version } " )
@@ -151,15 +149,9 @@ def merge_wheels(wheels: List[Path], output_dir: Path) -> Path:
151149
152150def main ():
153151 """Main merge script."""
154- parser = argparse .ArgumentParser (
155- description = "Merge CUDA-specific wheels into a single multi-CUDA wheel"
156- )
157- parser .add_argument (
158- "wheels" , nargs = "+" , help = "Paths to the CUDA-specific wheels to merge"
159- )
160- parser .add_argument (
161- "--output-dir" , "-o" , default = "dist" , help = "Output directory for merged wheel"
162- )
152+ parser = argparse .ArgumentParser (description = "Merge CUDA-specific wheels into a single multi-CUDA wheel" )
153+ parser .add_argument ("wheels" , nargs = "+" , help = "Paths to the CUDA-specific wheels to merge" )
154+ parser .add_argument ("--output-dir" , "-o" , default = "dist" , help = "Output directory for merged wheel" )
163155
164156 args = parser .parse_args ()
165157
0 commit comments