Skip to content

Commit aed9416

Browse files
Add framework-hinted runtime library discovery to address runtime mismatch on dual-vendor systems. (#86)
Signed-off-by: Petr Kurapov <petr.kurapov@gmail.com>
1 parent cf6c65f commit aed9416

9 files changed

Lines changed: 163 additions & 28 deletions

File tree

fastsafetensors/common.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,34 @@ def _normalize_windows_dll_path(path: str, source: str) -> str:
6969
return normalized
7070

7171

72-
def resolve_cudart_lib_name() -> str:
73-
"""Resolve the CUDA runtime library name for the current platform.
72+
def resolve_runtime_lib_name(framework=None) -> str:
73+
"""Resolve the GPU runtime library to dlopen for the current platform.
7474
75-
Returns:
76-
On Windows, an absolute DLL path string. On other platforms, "" to use
77-
the compiled-in default.
75+
On Windows, returns an absolute cudart DLL path. On other platforms, maps the framework's declared GPU
76+
vendor to a runtime library so the dlopen'd vendor stays in sync with the
77+
framework's GPU build. Returns "" when there is no usable hint so the caller falls
78+
back to auto-detection.
7879
"""
79-
if sys.platform != "win32":
80-
return "" # Non-Windows: use auto-detection (CUDA first, then ROCm)
80+
if sys.platform == "win32":
81+
return _resolve_windows_cudart_lib_name()
82+
if framework is None:
83+
return ""
84+
try:
85+
ver = framework.get_cuda_ver()
86+
except Exception:
87+
return ""
88+
if not ver or "-" not in ver:
89+
return ""
90+
vendor = ver.split("-", 1)[0]
91+
if vendor == "hip":
92+
return "libamdhip64.so"
93+
if vendor == "cuda":
94+
return "libcudart.so"
95+
return ""
96+
8197

98+
def _resolve_windows_cudart_lib_name() -> str:
99+
"""Resolve the absolute cudart DLL path on Windows, "" for the default."""
82100
# Allow explicit override via environment variable
83101
override = os.environ.get("FASTSAFETENSORS_CUDART_LIB", "").strip()
84102
if override:

fastsafetensors/copier/dstorage.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
SafeTensorsMetadata,
1212
init_logger,
1313
is_gpu_found,
14-
resolve_cudart_lib_name,
14+
resolve_runtime_lib_name,
1515
)
1616
from ..frameworks import FrameworkOpBase, TensorBase
1717
from ..st_types import Device, DeviceType, DType
@@ -114,7 +114,8 @@ def init_dstorage(device_id: int = 0) -> None:
114114
load_library_func()
115115
if not is_gpu_found():
116116
raise RuntimeError("CUDA runtime not found")
117-
cudart_dll = resolve_cudart_lib_name()
117+
# Windows-only; resolve_runtime_lib_name() returns the cudart DLL path here
118+
cudart_dll = resolve_runtime_lib_name()
118119
if not cudart_dll:
119120
raise RuntimeError("Could not find CUDA runtime DLL")
120121
if _dstorage_dll_dir is None:

fastsafetensors/copier/gds.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ def wait_io(
165165
_inited_gds = False
166166

167167

168-
def init_gds():
169-
load_library_func()
168+
def init_gds(framework: Optional[FrameworkOpBase] = None):
169+
load_library_func(framework)
170170
global _inited_gds
171171
if not _inited_gds:
172172
if fstcpp.init_gds() != 0:
@@ -181,7 +181,7 @@ def new_gds_file_copier(
181181
max_threads: int = 16,
182182
**kwargs,
183183
) -> CopierConstructFunc:
184-
init_gds()
184+
init_gds(kwargs.get("framework"))
185185
device_is_not_cpu = device.type != DeviceType.CPU
186186
if device_is_not_cpu and not is_gpu_found():
187187
raise Exception(
@@ -215,8 +215,10 @@ def new_gds_file_copier(
215215
from .unified import is_unified_memory_system, new_unified_copier
216216

217217
if device_is_not_cpu and is_unified_memory_system(kwargs.get("framework")):
218-
return new_unified_copier(device)
219-
return new_nogds_file_copier(device, bbuf_size_kb, max_threads)
218+
return new_unified_copier(device, framework=kwargs.get("framework"))
219+
return new_nogds_file_copier(
220+
device, bbuf_size_kb, max_threads, framework=kwargs.get("framework")
221+
)
220222

221223
reader = fstcpp.gds_file_reader(max_threads, device_is_not_cpu, device_id)
222224

fastsafetensors/copier/nogds.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
from typing import Dict, List
66

77
from .. import cpp as fstcpp
8-
from ..common import SafeTensorsMetadata, is_gpu_found, resolve_cudart_lib_name
8+
from ..common import (
9+
SafeTensorsMetadata,
10+
is_gpu_found,
11+
resolve_runtime_lib_name,
12+
)
913
from ..frameworks import FrameworkOpBase, TensorBase
1014
from ..st_types import Device, DeviceType, DType
1115
from .base import CopierInterface
@@ -81,12 +85,23 @@ def wait_io(
8185
_loaded_library = False
8286

8387

84-
def load_library_func():
88+
def load_library_func(framework=None):
8589
global _loaded_library
86-
if not _loaded_library:
87-
cudart_lib = resolve_cudart_lib_name()
88-
fstcpp.load_library_functions(cudart_lib)
89-
_loaded_library = True
90+
if _loaded_library:
91+
return
92+
93+
lib = resolve_runtime_lib_name(framework)
94+
fstcpp.load_library_functions(lib)
95+
if lib and not is_gpu_found():
96+
# The framework hinted a specific vendor's runtime but loading it found
97+
# no GPU. A GPU-built framework only reports a vendor when it sees a
98+
# device, so this is a real mismatch (wrong/missing runtime for that
99+
# vendor).
100+
raise Exception(
101+
f"[FAIL] framework hinted GPU runtime '{lib}' but no GPU was found "
102+
"after loading it (runtime/devices for that vendor not present)"
103+
)
104+
_loaded_library = True
90105

91106

92107
@register_copier_constructor("nogds")
@@ -96,7 +111,7 @@ def new_nogds_file_copier(
96111
max_threads: int = 16,
97112
**kwargs,
98113
) -> CopierConstructFunc:
99-
load_library_func()
114+
load_library_func(kwargs.get("framework"))
100115
device_is_not_cpu = device.type != DeviceType.CPU
101116
if device_is_not_cpu and not is_gpu_found():
102117
raise Exception(

fastsafetensors/copier/unified.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def new_unified_copier(device: Device, **kwargs) -> CopierConstructFunc:
122122
"""
123123
from .nogds import load_library_func
124124

125-
load_library_func()
125+
load_library_func(kwargs.get("framework"))
126126

127127
def construct_unified_copier(
128128
metadata: SafeTensorsMetadata,

fastsafetensors/threefs_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Dict, List, Mapping, Optional
44

55
from . import cpp as fstcpp
6-
from .common import init_logger, resolve_cudart_lib_name
6+
from .common import init_logger, resolve_runtime_lib_name
77
from .frameworks import get_framework_op
88
from .loader import BaseSafeTensorsFileLoader, loaded_library
99
from .parallel_loader import PipelineParallel
@@ -54,7 +54,7 @@ def __init__(
5454

5555
global loaded_library
5656
if not loaded_library:
57-
fstcpp.load_library_functions(resolve_cudart_lib_name())
57+
fstcpp.load_library_functions(resolve_runtime_lib_name())
5858
loaded_library = True
5959
fstcpp.set_debug_log(debug_log)
6060
super().__init__(

tests/unit/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from fastsafetensors import SingleGroup
99
from fastsafetensors import cpp as fstcpp
10-
from fastsafetensors.common import is_gpu_found, resolve_cudart_lib_name
10+
from fastsafetensors.common import is_gpu_found, resolve_runtime_lib_name
1111
from fastsafetensors.cpp import load_library_functions
1212
from fastsafetensors.frameworks import FrameworkOpBase, get_framework_op
1313
from fastsafetensors.st_types import Device
@@ -25,7 +25,7 @@
2525
os.makedirs(TMP_DIR, 0o777, True)
2626
os.makedirs(GENERATED_DIR, 0o777, True)
2727

28-
load_library_functions(resolve_cudart_lib_name())
28+
load_library_functions(resolve_runtime_lib_name())
2929
FRAMEWORK = get_framework_op(os.getenv("TEST_FASTSAFETENSORS_FRAMEWORK", "please set"))
3030

3131
# Print platform information at test startup

tests/unit/test_runtime_hint.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
"""Unit tests for the framework-hinted GPU runtime selection."""
4+
5+
import sys
6+
7+
import pytest
8+
9+
from fastsafetensors import common
10+
11+
12+
class _FakeFramework:
13+
def __init__(self, cuda_ver):
14+
self._ver = cuda_ver
15+
16+
def get_cuda_ver(self):
17+
if isinstance(self._ver, Exception):
18+
raise self._ver
19+
return self._ver
20+
21+
22+
@pytest.fixture(autouse=True)
23+
def _force_non_windows(monkeypatch):
24+
# The hint is intentionally a no-op on Windows (cudart resolver owns it).
25+
monkeypatch.setattr(sys, "platform", "linux")
26+
27+
28+
def test_none_framework_uses_autodetect():
29+
assert common.resolve_runtime_lib_name(None) == ""
30+
31+
32+
def test_hip_framework_selects_amdhip():
33+
assert (
34+
common.resolve_runtime_lib_name(_FakeFramework("hip-7.2.0")) == "libamdhip64.so"
35+
)
36+
37+
38+
def test_cuda_framework_selects_cudart():
39+
assert (
40+
common.resolve_runtime_lib_name(_FakeFramework("cuda-12.1")) == "libcudart.so"
41+
)
42+
43+
44+
@pytest.mark.parametrize("ver", ["", "weird", "rocm-7.0"])
45+
def test_unknown_vendor_uses_autodetect(ver):
46+
assert common.resolve_runtime_lib_name(_FakeFramework(ver)) == ""
47+
48+
49+
def test_get_cuda_ver_raises_uses_autodetect():
50+
assert common.resolve_runtime_lib_name(_FakeFramework(RuntimeError("boom"))) == ""
51+
52+
53+
def test_windows_is_noop(monkeypatch):
54+
monkeypatch.setattr(sys, "platform", "win32")
55+
assert common.resolve_runtime_lib_name(_FakeFramework("hip-7.2.0")) == ""
56+
57+
58+
def test_load_library_func_hint_with_no_gpu_raises(monkeypatch):
59+
"""A hint that finds no GPU is a hard failure"""
60+
from fastsafetensors.copier import nogds
61+
62+
calls = []
63+
64+
def fake_load(lib):
65+
calls.append(lib)
66+
67+
monkeypatch.setattr(nogds.fstcpp, "load_library_functions", fake_load)
68+
monkeypatch.setattr(
69+
nogds, "resolve_runtime_lib_name", lambda fw=None: "libamdhip64.so"
70+
)
71+
72+
monkeypatch.setattr(nogds, "is_gpu_found", lambda: False)
73+
monkeypatch.setattr(nogds, "_loaded_library", False)
74+
75+
with pytest.raises(Exception, match="libamdhip64.so"):
76+
nogds.load_library_func(_FakeFramework("hip-7.2.0"))
77+
78+
assert calls == ["libamdhip64.so"]
79+
assert nogds._loaded_library is False
80+
81+
82+
def test_load_library_func_hint_succeeds_no_fallback(monkeypatch):
83+
from fastsafetensors.copier import nogds
84+
85+
calls = []
86+
monkeypatch.setattr(
87+
nogds.fstcpp, "load_library_functions", lambda lib: calls.append(lib)
88+
)
89+
monkeypatch.setattr(
90+
nogds,
91+
"resolve_runtime_lib_name",
92+
lambda fw=None: "libamdhip64.so" if fw is not None else "",
93+
)
94+
monkeypatch.setattr(nogds, "is_gpu_found", lambda: True)
95+
monkeypatch.setattr(nogds, "_loaded_library", False)
96+
97+
nogds.load_library_func(_FakeFramework("hip-7.2.0"))
98+
99+
assert calls == ["libamdhip64.so"]

tests/unit/threefs/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from fastsafetensors import SingleGroup
1616
from fastsafetensors import cpp as fstcpp
17-
from fastsafetensors.common import is_gpu_found, resolve_cudart_lib_name
17+
from fastsafetensors.common import is_gpu_found, resolve_runtime_lib_name
1818
from fastsafetensors.cpp import load_library_functions
1919
from fastsafetensors.frameworks import FrameworkOpBase, get_framework_op
2020
from fastsafetensors.st_types import Device
@@ -34,7 +34,7 @@ def mock_3fs_reader():
3434
yield
3535

3636

37-
load_library_functions(resolve_cudart_lib_name())
37+
load_library_functions(resolve_runtime_lib_name())
3838
FRAMEWORK = get_framework_op(os.getenv("TEST_FASTSAFETENSORS_FRAMEWORK", "please set"))
3939

4040

0 commit comments

Comments
 (0)