Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class _BitcodeLibInfo(TypedDict):
),
"available_on_windows": True,
},
"nccl_device": {
"filename": "libnccl_device.bc",
"rel_path": "lib",
"site_packages_dirs": ("nvidia/nccl/lib",),
"available_on_windows": False,
},
"nvshmem_device": {
"filename": "libnvshmem_device.bc",
"rel_path": "lib",
Expand Down
54 changes: 38 additions & 16 deletions cuda_pathfinder/tests/test_find_bitcode_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
STRICTNESS = os.environ.get("CUDA_PATHFINDER_TEST_FIND_NVIDIA_BITCODE_LIB_STRICTNESS", "see_what_works")
assert STRICTNESS in ("see_what_works", "all_must_work")

BCL_FILENAME = find_bitcode_lib_module._SUPPORTED_BITCODE_LIBS_INFO["device"]["filename"]

def _bitcode_lib_info(libname: str):
return find_bitcode_lib_module._SUPPORTED_BITCODE_LIBS_INFO[libname]


def _bitcode_lib_filename(libname: str) -> str:
return _bitcode_lib_info(libname)["filename"]


@pytest.fixture
Expand All @@ -30,15 +36,20 @@ def clear_find_bitcode_lib_cache():
get_cuda_path_or_home.cache_clear()


def _make_bitcode_lib_file(dir_path: Path) -> str:
def _make_bitcode_lib_file(dir_path: Path, libname: str) -> str:
dir_path.mkdir(parents=True, exist_ok=True)
file_path = dir_path / BCL_FILENAME
file_path = dir_path / _bitcode_lib_filename(libname)
file_path.touch()
return str(file_path)


def _bitcode_lib_dir_under(anchor_dir: Path) -> Path:
return anchor_dir / "nvvm" / "libdevice"
def _bitcode_lib_dir_under(anchor_dir: Path, libname: str) -> Path:
return anchor_dir / _bitcode_lib_info(libname)["rel_path"]


def _site_packages_bitcode_lib_dir_under(anchor_dir: Path, libname: str) -> Path:
rel_dir = _bitcode_lib_info(libname)["site_packages_dirs"][0]
return anchor_dir.joinpath(*rel_dir.split("/"))


def _conda_anchor(conda_prefix: Path) -> Path:
Expand Down Expand Up @@ -79,44 +90,55 @@ def test_locate_bitcode_lib(info_summary_append, libname):


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
def test_locate_bitcode_lib_search_order(monkeypatch, tmp_path):
site_packages_lib_dir = tmp_path / "site-packages" / "nvidia" / "cu13" / "nvvm" / "libdevice"
site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir)
@pytest.mark.parametrize("libname", SUPPORTED_BITCODE_LIBS)
def test_locate_bitcode_lib_search_order(monkeypatch, tmp_path, libname):
site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, libname)

conda_prefix = tmp_path / "conda-prefix"
conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix)))
conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), libname)

cuda_home = tmp_path / "cuda-home"
cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home))
cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), libname)

site_packages_sub_dirs = tuple(
tuple(rel_dir.split("/")) for rel_dir in _bitcode_lib_info(libname)["site_packages_dirs"]
)

def find_expected_sub_dir(sub_dir):
assert sub_dir in site_packages_sub_dirs
if sub_dir == site_packages_sub_dirs[0]:
return [str(site_packages_lib_dir)]
return []

monkeypatch.setattr(
find_bitcode_lib_module,
"find_sub_dirs_all_sitepackages",
lambda _sub_dir: [str(site_packages_lib_dir)],
find_expected_sub_dir,
)
monkeypatch.setenv("CONDA_PREFIX", str(conda_prefix))
monkeypatch.setenv("CUDA_HOME", str(cuda_home))
monkeypatch.delenv("CUDA_PATH", raising=False)

located_lib = locate_bitcode_lib("device")
located_lib = locate_bitcode_lib(libname)
assert located_lib.abs_path == site_packages_path
assert located_lib.found_via == "site-packages"
os.remove(site_packages_path)

located_lib = locate_bitcode_lib("device")
located_lib = locate_bitcode_lib(libname)
assert located_lib.abs_path == conda_path
assert located_lib.found_via == "conda"
os.remove(conda_path)

located_lib = locate_bitcode_lib("device")
located_lib = locate_bitcode_lib(libname)
assert located_lib.abs_path == cuda_home_path
assert located_lib.found_via == "CUDA_PATH"


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(monkeypatch, tmp_path):
cuda_home = tmp_path / "cuda-home"
lib_dir = _bitcode_lib_dir_under(cuda_home)
lib_dir = _bitcode_lib_dir_under(cuda_home, "device")
lib_dir.mkdir(parents=True, exist_ok=True)
extra_file = lib_dir / "README.txt"
extra_file.write_text("placeholder", encoding="utf-8")
Expand All @@ -134,7 +156,7 @@ def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(m
find_bitcode_lib("device")

message = str(exc_info.value)
expected_missing_file = os.path.join(str(lib_dir), BCL_FILENAME)
expected_missing_file = os.path.join(str(lib_dir), _bitcode_lib_filename("device"))
assert f"No such file: {expected_missing_file}" in message
assert f'listdir("{lib_dir}"):' in message
assert "README.txt" in message
Expand Down
Loading