22#
33# SPDX-License-Identifier: Apache-2.0
44from distutils import file_util
5- import shutil
6-
75from setuptools import setup
86from setuptools .command .build_ext import build_ext
9- from setuptools .command .bdist_wheel import bdist_wheel
107from setuptools .extension import Extension
118import os
129import sys
@@ -76,51 +73,6 @@ def run(self):
7673 dry_run = self .dry_run )
7774
7875
79- class BdistWheelWithDeps (bdist_wheel ):
80- user_options = bdist_wheel .user_options + [
81- ('include-tileir' , None , 'Include tileir dependency in the package' ),
82- ('tileir-deps=' , None , 'Path to extra tileir dependency to package' ),
83- ]
84-
85- def initialize_options (self ):
86- super ().initialize_options ()
87- self .include_tileir = False
88- self .tileir_deps = None
89-
90- def finalize_options (self ):
91- super ().finalize_options ()
92-
93- def _copy_tileir_deps (self , src , dst ):
94- bin_dir = os .path .join (dst , 'bin' )
95- lib_dir = os .path .join (dst , 'lib' )
96- file_util .copy_file (os .path .join (src , 'tileiras' ), bin_dir ,
97- dry_run = self .dry_run )
98- file_util .copy_file (os .path .join (src , 'ptxas' ), bin_dir ,
99- dry_run = self .dry_run )
100- file_util .copy_file (os .path .join (src , 'libnvvm.so' ), lib_dir ,
101- dry_run = self .dry_run )
102-
103- def run (self ):
104- build_py = self .get_finalized_command ('build_py' )
105- build_lib = build_py .build_lib
106- dst_dir = os .path .join (build_lib , 'cuda' , 'tile' , '_deps' )
107- dst_bin_dir = os .path .join (dst_dir , 'bin' )
108- dst_lib_dir = os .path .join (dst_dir , 'lib' )
109- if self .include_tileir :
110- os .makedirs (dst_bin_dir , exist_ok = True )
111- os .makedirs (dst_lib_dir , exist_ok = True )
112- if self .tileir_deps :
113- src_dir = os .path .abspath (self .tileir_deps )
114- self ._copy_tileir_deps (src_dir , dst_dir )
115- else :
116- compiler_src = shutil .which ('tileiras' )
117- if compiler_src is None :
118- raise FileNotFoundError ("Cannot find `tileiras`. Make sure it is in PATH." )
119- file_util .copy_file (compiler_src , dst_bin_dir , update = True , dry_run = self .dry_run )
120-
121- super ().run ()
122-
123-
12476def _get_csrc_dir (ext_name : str ):
12577 prefix = "cuda.tile._"
12678 assert ext_name .startswith (prefix )
@@ -141,6 +93,5 @@ def _get_build_lib_filename(ext_name: str):
14193 ],
14294 cmdclass = dict (
14395 build_ext = BuildExtWithCmake ,
144- bdist_wheel = BdistWheelWithDeps
14596 )
14697)
0 commit comments