1818STRICTNESS = os .environ .get ("CUDA_PATHFINDER_TEST_FIND_NVIDIA_BITCODE_LIB_STRICTNESS" , "see_what_works" )
1919assert STRICTNESS in ("see_what_works" , "all_must_work" )
2020
21- BCL_FILENAME = find_bitcode_lib_module ._SUPPORTED_BITCODE_LIBS_INFO ["device" ]["filename" ]
21+
22+ def _bitcode_lib_info (libname : str ):
23+ return find_bitcode_lib_module ._SUPPORTED_BITCODE_LIBS_INFO [libname ]
24+
25+
26+ def _bitcode_lib_filename (libname : str ) -> str :
27+ return _bitcode_lib_info (libname )["filename" ]
2228
2329
2430@pytest .fixture
@@ -30,15 +36,20 @@ def clear_find_bitcode_lib_cache():
3036 get_cuda_path_or_home .cache_clear ()
3137
3238
33- def _make_bitcode_lib_file (dir_path : Path ) -> str :
39+ def _make_bitcode_lib_file (dir_path : Path , libname : str ) -> str :
3440 dir_path .mkdir (parents = True , exist_ok = True )
35- file_path = dir_path / BCL_FILENAME
41+ file_path = dir_path / _bitcode_lib_filename ( libname )
3642 file_path .touch ()
3743 return str (file_path )
3844
3945
40- def _bitcode_lib_dir_under (anchor_dir : Path ) -> Path :
41- return anchor_dir / "nvvm" / "libdevice"
46+ def _bitcode_lib_dir_under (anchor_dir : Path , libname : str ) -> Path :
47+ return anchor_dir / _bitcode_lib_info (libname )["rel_path" ]
48+
49+
50+ def _site_packages_bitcode_lib_dir_under (anchor_dir : Path , libname : str ) -> Path :
51+ rel_dir = _bitcode_lib_info (libname )["site_packages_dirs" ][0 ]
52+ return anchor_dir .joinpath (* rel_dir .split ("/" ))
4253
4354
4455def _conda_anchor (conda_prefix : Path ) -> Path :
@@ -79,44 +90,55 @@ def test_locate_bitcode_lib(info_summary_append, libname):
7990
8091
8192@pytest .mark .usefixtures ("clear_find_bitcode_lib_cache" )
82- def test_locate_bitcode_lib_search_order (monkeypatch , tmp_path ):
83- site_packages_lib_dir = tmp_path / "site-packages" / "nvidia" / "cu13" / "nvvm" / "libdevice"
84- site_packages_path = _make_bitcode_lib_file (site_packages_lib_dir )
93+ @pytest .mark .parametrize ("libname" , SUPPORTED_BITCODE_LIBS )
94+ def test_locate_bitcode_lib_search_order (monkeypatch , tmp_path , libname ):
95+ site_packages_lib_dir = _site_packages_bitcode_lib_dir_under (tmp_path / "site-packages" , libname )
96+ site_packages_path = _make_bitcode_lib_file (site_packages_lib_dir , libname )
8597
8698 conda_prefix = tmp_path / "conda-prefix"
87- conda_path = _make_bitcode_lib_file (_bitcode_lib_dir_under (_conda_anchor (conda_prefix )) )
99+ conda_path = _make_bitcode_lib_file (_bitcode_lib_dir_under (_conda_anchor (conda_prefix ), libname ), libname )
88100
89101 cuda_home = tmp_path / "cuda-home"
90- cuda_home_path = _make_bitcode_lib_file (_bitcode_lib_dir_under (cuda_home ))
102+ cuda_home_path = _make_bitcode_lib_file (_bitcode_lib_dir_under (cuda_home , libname ), libname )
103+
104+ site_packages_sub_dirs = tuple (
105+ tuple (rel_dir .split ("/" )) for rel_dir in _bitcode_lib_info (libname )["site_packages_dirs" ]
106+ )
107+
108+ def find_expected_sub_dir (sub_dir ):
109+ assert sub_dir in site_packages_sub_dirs
110+ if sub_dir == site_packages_sub_dirs [0 ]:
111+ return [str (site_packages_lib_dir )]
112+ return []
91113
92114 monkeypatch .setattr (
93115 find_bitcode_lib_module ,
94116 "find_sub_dirs_all_sitepackages" ,
95- lambda _sub_dir : [ str ( site_packages_lib_dir )] ,
117+ find_expected_sub_dir ,
96118 )
97119 monkeypatch .setenv ("CONDA_PREFIX" , str (conda_prefix ))
98120 monkeypatch .setenv ("CUDA_HOME" , str (cuda_home ))
99121 monkeypatch .delenv ("CUDA_PATH" , raising = False )
100122
101- located_lib = locate_bitcode_lib ("device" )
123+ located_lib = locate_bitcode_lib (libname )
102124 assert located_lib .abs_path == site_packages_path
103125 assert located_lib .found_via == "site-packages"
104126 os .remove (site_packages_path )
105127
106- located_lib = locate_bitcode_lib ("device" )
128+ located_lib = locate_bitcode_lib (libname )
107129 assert located_lib .abs_path == conda_path
108130 assert located_lib .found_via == "conda"
109131 os .remove (conda_path )
110132
111- located_lib = locate_bitcode_lib ("device" )
133+ located_lib = locate_bitcode_lib (libname )
112134 assert located_lib .abs_path == cuda_home_path
113135 assert located_lib .found_via == "CUDA_PATH"
114136
115137
116138@pytest .mark .usefixtures ("clear_find_bitcode_lib_cache" )
117139def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing (monkeypatch , tmp_path ):
118140 cuda_home = tmp_path / "cuda-home"
119- lib_dir = _bitcode_lib_dir_under (cuda_home )
141+ lib_dir = _bitcode_lib_dir_under (cuda_home , "device" )
120142 lib_dir .mkdir (parents = True , exist_ok = True )
121143 extra_file = lib_dir / "README.txt"
122144 extra_file .write_text ("placeholder" , encoding = "utf-8" )
@@ -134,7 +156,7 @@ def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(m
134156 find_bitcode_lib ("device" )
135157
136158 message = str (exc_info .value )
137- expected_missing_file = os .path .join (str (lib_dir ), BCL_FILENAME )
159+ expected_missing_file = os .path .join (str (lib_dir ), _bitcode_lib_filename ( "device" ) )
138160 assert f"No such file: { expected_missing_file } " in message
139161 assert f'listdir("{ lib_dir } "):' in message
140162 assert "README.txt" in message
0 commit comments