Skip to content

Commit 7775b5e

Browse files
gitbisector and claude committed
gds copier: fall back to nogds when cuFile handle registration fails
cuFile can probe as available yet fail cuFileHandleRegister at I/O time (compat-mode hosts without nvidia-fs, checkpoints on overlayfs, CI runners). Catch the failure in submit_io, warn once, and transparently delegate the copier to the nogds bounce path, so that consumers no longer each carry their own gds->nogds retry wrapper. The fallback copier is dropped at the end of the file's submit/wait cycle: for a directly constructed copier its bounce-buffer reader is released with it, while copiers built by the same factory share a single cached nogds constructor.

Signed-off-by: git bisector <gitbisector@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
1 parent aed9416 commit 7775b5e

3 files changed

Lines changed: 137 additions & 3 deletions

File tree

.github/workflows/test-torch.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ jobs:
4646
mkdir -p /tmp/pytest-log
4747
export TEST_FASTSAFETENSORS_FRAMEWORK=pytorch
4848
COVERAGE_FILE=.coverage_0 pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/0.log 2>&1
49+
COVERAGE_FILE=.coverage_rob pytest -s --cov=${LIBDIR} test_robustness.py > /tmp/pytest-log/robustness.log 2>&1
4950
COVERAGE_FILE=.coverage_config pytest -s --cov=${LIBDIR} test_config.py > /tmp/pytest-log/config.log 2>&1
5051
COVERAGE_FILE=.coverage_auto pytest -s --cov=${LIBDIR} test_auto_loader.py > /tmp/pytest-log/auto_loader.log 2>&1
5152
COVERAGE_FILE=.coverage_3fs pytest -s --cov=${LIBDIR} threefs/ > /tmp/pytest-log/threefs.log 2>&1

fastsafetensors/copier/gds.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import platform
44
import warnings
5-
from typing import Dict, Optional
5+
from typing import Dict, List, Optional
66

77
from .. import cpp as fstcpp
88
from ..common import SafeTensorsMetadata, init_logger, is_gpu_found
@@ -14,6 +14,8 @@
1414

1515
logger = init_logger(__name__)
1616

17+
_warned_gds_fallback = False
18+
1719

1820
class GdsFileCopier(CopierInterface):
1921
def __init__(
@@ -22,6 +24,7 @@ def __init__(
2224
device: Device,
2325
reader: fstcpp.gds_file_reader,
2426
framework: FrameworkOpBase,
27+
fallback_cache: Optional[List[CopierConstructFunc]] = None,
2528
):
2629
self.framework = framework
2730
self.metadata = metadata
@@ -31,6 +34,11 @@ def __init__(
3134
self.fh: Optional[fstcpp.gds_file_handle] = None
3235
self.copy_reqs: Dict[int, int] = {}
3336
self.aligned_length = 0
37+
self._fallback: Optional[CopierInterface] = None
38+
# One-slot cell shared by all copiers from the same factory, so a
39+
# broken-GDS host builds a single nogds fallback reader (and its
40+
# pinned bounce buffer) per loader instead of one per file.
41+
self._fallback_cache = fallback_cache
3442
cuda_ver = framework.get_cuda_ver()
3543
if cuda_ver and cuda_ver != "0.0":
3644
# Parse version string (e.g., "cuda-12.1" or "hip-5.7.0")
@@ -65,7 +73,45 @@ def submit_io(
6573
self.device.type == DeviceType.CUDA or self.device.type == DeviceType.GPU
6674
)
6775
ALIGN: int = fstcpp.get_alignment_size()
68-
self.fh = fstcpp.gds_file_handle(self.metadata.src, self.o_direct, dev_is_cuda)
76+
try:
77+
self.fh = fstcpp.gds_file_handle(
78+
self.metadata.src, self.o_direct, dev_is_cuda
79+
)
80+
except RuntimeError as e:
81+
# cuFile can probe as available yet fail at I/O time: handle
82+
# registration errors on compat-mode hosts or unsupported
83+
# filesystems (e.g. overlayfs), or open(O_DIRECT) rejections.
84+
# Downgrade this copier to the nogds bounce path instead of
85+
# failing, so consumers don't each need their own gds->nogds
86+
# retry. Deliberately limited to file-handle setup: failures in
87+
# already-submitted reads stay fatal (falling back mid-cycle
88+
# would re-read earlier data).
89+
global _warned_gds_fallback
90+
if not _warned_gds_fallback:
91+
_warned_gds_fallback = True
92+
# str(e): keeping the exception object in the log record would
93+
# retain its traceback (and this frame's locals) via any
94+
# record-capturing handler.
95+
logger.warning(
96+
"GDS file-handle setup failed (%s); "
97+
"falling back to the nogds copier",
98+
str(e),
99+
)
100+
if self._fallback_cache is not None:
101+
if not self._fallback_cache:
102+
self._fallback_cache.append(
103+
new_nogds_file_copier(self.device, framework=self.framework)
104+
)
105+
self._fallback = self._fallback_cache[0](
106+
self.metadata, self.device, self.framework
107+
)
108+
else:
109+
# direct construction (no factory): reader lives only for this
110+
# file's submit/wait cycle and is released in wait_io
111+
self._fallback = new_nogds_file_copier(
112+
self.device, framework=self.framework
113+
)(self.metadata, self.device, self.framework)
114+
return self._fallback.submit_io(use_buf_register, max_copy_block_size)
69115
offset = self.metadata.header_length
70116
length = self.metadata.size_bytes - self.metadata.header_length
71117
head_bytes = offset % ALIGN
@@ -120,6 +166,11 @@ def wait_io(
120166
dtype: DType = DType.AUTO,
121167
noalign: bool = False,
122168
) -> Dict[str, TensorBase]:
169+
if self._fallback is not None:
170+
tensors = self._fallback.wait_io(gbuf, dtype=dtype, noalign=noalign)
171+
# Drop the fallback copier so its bounce-buffer reader is freed.
172+
self._fallback = None
173+
return tensors
123174
failed = []
124175
for req, c in sorted(self.copy_reqs.items(), key=lambda x: x[0]):
125176
count = self.reader.wait_read(req)
@@ -222,11 +273,15 @@ def new_gds_file_copier(
222273

223274
reader = fstcpp.gds_file_reader(max_threads, device_is_not_cpu, device_id)
224275

276+
fallback_cache: List[CopierConstructFunc] = []
277+
225278
def construct_copier(
226279
metadata: SafeTensorsMetadata,
227280
device: Device,
228281
framework: FrameworkOpBase,
229282
) -> CopierInterface:
230-
return GdsFileCopier(metadata, device, reader, framework)
283+
return GdsFileCopier(
284+
metadata, device, reader, framework, fallback_cache=fallback_cache
285+
)
231286

232287
return construct_copier

tests/unit/test_robustness.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Robustness tests: graceful degradation of I/O paths."""
3+
4+
import pytest
5+
6+
# ---- runtime GDS -> nogds fallback ----
7+
8+
9+
def test_gds_copier_falls_back_to_nogds(input_files, framework, monkeypatch):
    """A directly constructed GdsFileCopier survives a file-handle failure.

    ``fstcpp.gds_file_handle`` is patched to raise ``RuntimeError``. The copier
    must still complete a submit/wait cycle, return the same tensors as
    ``load_safetensors_file``, and leave ``bounce_buffer_bytes`` at zero once
    the cycle is over.
    """
    if framework.get_name() != "pytorch":
        pytest.skip("pytorch-only")
    from test_fastsafetensors import load_safetensors_file

    from fastsafetensors import SafeTensorsMetadata
    from fastsafetensors import cpp as fstcpp
    from fastsafetensors.copier import gds as gds_mod
    from fastsafetensors.st_types import Device

    def _boom(*a, **k):
        raise RuntimeError(
            "raw_gds_file_handle: cuFileHandleRegister returned an error = 5027"
        )

    monkeypatch.setattr(fstcpp, "gds_file_handle", _boom)
    # The fallback path flips this module-level flag; patching it through
    # monkeypatch restores the previous value at teardown so the warn-once
    # state does not leak into other tests.
    monkeypatch.setattr(gds_mod, "_warned_gds_fallback", False)

    device = Device.from_str("cpu")
    meta = SafeTensorsMetadata.from_file(input_files[0], framework)
    reader = fstcpp.gds_file_reader(4, False, 0)
    copier = gds_mod.GdsFileCopier(meta, device, reader, framework)
    gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024)
    try:
        tensors = copier.wait_io(gbuf)
        expected = load_safetensors_file(input_files[0], device, framework)
        assert set(tensors.keys()) == set(expected.keys())
        for k, exp in expected.items():
            assert framework.is_equal(tensors[k], exp), k
    finally:
        # Free the destination buffer even when a comparison above fails.
        framework.free_tensor_memory(gbuf, device)
    # the fallback's bounce-buffer reader must not outlive the copy cycle
    assert fstcpp.get_cpp_metrics().bounce_buffer_bytes == 0
39+
40+
41+
def test_gds_fallback_warns_once_and_shares_reader(
    input_files, framework, monkeypatch, caplog
):
    """Two copiers sharing a ``fallback_cache`` warn once and fill one slot.

    ``fstcpp.gds_file_handle`` is patched to raise ``RuntimeError`` and the
    module's warn-once flag is reset. After two full submit/wait cycles the
    fallback warning must appear exactly once in the captured log, and the
    shared cache list must hold exactly one entry.
    """
    if framework.get_name() != "pytorch":
        pytest.skip("pytorch-only")
    import logging

    from fastsafetensors import SafeTensorsMetadata
    from fastsafetensors import cpp as fstcpp
    from fastsafetensors.copier import gds as gds_mod
    from fastsafetensors.st_types import Device

    def _fail_handle_setup(*args, **kwargs):
        raise RuntimeError("error = 5027")

    monkeypatch.setattr(fstcpp, "gds_file_handle", _fail_handle_setup)
    monkeypatch.setattr(gds_mod, "_warned_gds_fallback", False)

    device = Device.from_str("cpu")
    meta = SafeTensorsMetadata.from_file(input_files[0], framework)
    reader = fstcpp.gds_file_reader(4, False, 0)
    cache = []
    # (copier, buffer) pairs: both copiers stay alive until the end of the
    # test, and every buffer that was handed out is freed in ``finally``.
    cycles = []
    try:
        with caplog.at_level(logging.WARNING, logger="fastsafetensors.copier.gds"):
            for _ in range(2):
                copier = gds_mod.GdsFileCopier(
                    meta, device, reader, framework, fallback_cache=cache
                )
                gbuf = copier.submit_io(False, 1 << 30)
                cycles.append((copier, gbuf))
                copier.wait_io(gbuf)
        assert caplog.text.count("falling back to the nogds copier") == 1  # warn once
        assert len(cache) == 1  # one shared nogds constructor for the whole loader
    finally:
        for _, gbuf in cycles:
            framework.free_tensor_memory(gbuf, device)

0 commit comments

Comments
 (0)