Skip to content

Commit 3bf0e98

Browse files
committed
Merge branch 'main' into pathfinder_compatibility_guard_rails
2 parents b622613 + 8a83a4f commit 3bf0e98

2 files changed

Lines changed: 44 additions & 16 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ class _BitcodeLibInfo(TypedDict):
4242
),
4343
"available_on_windows": True,
4444
},
45+
"nccl_device": {
46+
"filename": "libnccl_device.bc",
47+
"rel_path": "lib",
48+
"site_packages_dirs": ("nvidia/nccl/lib",),
49+
"available_on_windows": False,
50+
},
4551
"nvshmem_device": {
4652
"filename": "libnvshmem_device.bc",
4753
"rel_path": "lib",

cuda_pathfinder/tests/test_find_bitcode_lib.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818
STRICTNESS = os.environ.get("CUDA_PATHFINDER_TEST_FIND_NVIDIA_BITCODE_LIB_STRICTNESS", "see_what_works")
1919
assert STRICTNESS in ("see_what_works", "all_must_work")
2020

21-
BCL_FILENAME = find_bitcode_lib_module._SUPPORTED_BITCODE_LIBS_INFO["device"]["filename"]
21+
22+
def _bitcode_lib_info(libname: str):
23+
return find_bitcode_lib_module._SUPPORTED_BITCODE_LIBS_INFO[libname]
24+
25+
26+
def _bitcode_lib_filename(libname: str) -> str:
27+
return _bitcode_lib_info(libname)["filename"]
2228

2329

2430
@pytest.fixture
@@ -30,15 +36,20 @@ def clear_find_bitcode_lib_cache():
3036
get_cuda_path_or_home.cache_clear()
3137

3238

33-
def _make_bitcode_lib_file(dir_path: Path) -> str:
39+
def _make_bitcode_lib_file(dir_path: Path, libname: str) -> str:
3440
dir_path.mkdir(parents=True, exist_ok=True)
35-
file_path = dir_path / BCL_FILENAME
41+
file_path = dir_path / _bitcode_lib_filename(libname)
3642
file_path.touch()
3743
return str(file_path)
3844

3945

40-
def _bitcode_lib_dir_under(anchor_dir: Path) -> Path:
41-
return anchor_dir / "nvvm" / "libdevice"
46+
def _bitcode_lib_dir_under(anchor_dir: Path, libname: str) -> Path:
47+
return anchor_dir / _bitcode_lib_info(libname)["rel_path"]
48+
49+
50+
def _site_packages_bitcode_lib_dir_under(anchor_dir: Path, libname: str) -> Path:
51+
rel_dir = _bitcode_lib_info(libname)["site_packages_dirs"][0]
52+
return anchor_dir.joinpath(*rel_dir.split("/"))
4253

4354

4455
def _conda_anchor(conda_prefix: Path) -> Path:
@@ -79,44 +90,55 @@ def test_locate_bitcode_lib(info_summary_append, libname):
7990

8091

8192
@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
82-
def test_locate_bitcode_lib_search_order(monkeypatch, tmp_path):
83-
site_packages_lib_dir = tmp_path / "site-packages" / "nvidia" / "cu13" / "nvvm" / "libdevice"
84-
site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir)
93+
@pytest.mark.parametrize("libname", SUPPORTED_BITCODE_LIBS)
94+
def test_locate_bitcode_lib_search_order(monkeypatch, tmp_path, libname):
95+
site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
96+
site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, libname)
8597

8698
conda_prefix = tmp_path / "conda-prefix"
87-
conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix)))
99+
conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), libname)
88100

89101
cuda_home = tmp_path / "cuda-home"
90-
cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home))
102+
cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), libname)
103+
104+
site_packages_sub_dirs = tuple(
105+
tuple(rel_dir.split("/")) for rel_dir in _bitcode_lib_info(libname)["site_packages_dirs"]
106+
)
107+
108+
def find_expected_sub_dir(sub_dir):
109+
assert sub_dir in site_packages_sub_dirs
110+
if sub_dir == site_packages_sub_dirs[0]:
111+
return [str(site_packages_lib_dir)]
112+
return []
91113

92114
monkeypatch.setattr(
93115
find_bitcode_lib_module,
94116
"find_sub_dirs_all_sitepackages",
95-
lambda _sub_dir: [str(site_packages_lib_dir)],
117+
find_expected_sub_dir,
96118
)
97119
monkeypatch.setenv("CONDA_PREFIX", str(conda_prefix))
98120
monkeypatch.setenv("CUDA_HOME", str(cuda_home))
99121
monkeypatch.delenv("CUDA_PATH", raising=False)
100122

101-
located_lib = locate_bitcode_lib("device")
123+
located_lib = locate_bitcode_lib(libname)
102124
assert located_lib.abs_path == site_packages_path
103125
assert located_lib.found_via == "site-packages"
104126
os.remove(site_packages_path)
105127

106-
located_lib = locate_bitcode_lib("device")
128+
located_lib = locate_bitcode_lib(libname)
107129
assert located_lib.abs_path == conda_path
108130
assert located_lib.found_via == "conda"
109131
os.remove(conda_path)
110132

111-
located_lib = locate_bitcode_lib("device")
133+
located_lib = locate_bitcode_lib(libname)
112134
assert located_lib.abs_path == cuda_home_path
113135
assert located_lib.found_via == "CUDA_PATH"
114136

115137

116138
@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
117139
def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(monkeypatch, tmp_path):
118140
cuda_home = tmp_path / "cuda-home"
119-
lib_dir = _bitcode_lib_dir_under(cuda_home)
141+
lib_dir = _bitcode_lib_dir_under(cuda_home, "device")
120142
lib_dir.mkdir(parents=True, exist_ok=True)
121143
extra_file = lib_dir / "README.txt"
122144
extra_file.write_text("placeholder", encoding="utf-8")
@@ -134,7 +156,7 @@ def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(m
134156
find_bitcode_lib("device")
135157

136158
message = str(exc_info.value)
137-
expected_missing_file = os.path.join(str(lib_dir), BCL_FILENAME)
159+
expected_missing_file = os.path.join(str(lib_dir), _bitcode_lib_filename("device"))
138160
assert f"No such file: {expected_missing_file}" in message
139161
assert f'listdir("{lib_dir}"):' in message
140162
assert "README.txt" in message

0 commit comments

Comments
 (0)