diff --git a/.github/workflows/test-torch.yaml b/.github/workflows/test-torch.yaml index 6e7254c..251580c 100644 --- a/.github/workflows/test-torch.yaml +++ b/.github/workflows/test-torch.yaml @@ -46,6 +46,7 @@ jobs: mkdir -p /tmp/pytest-log export TEST_FASTSAFETENSORS_FRAMEWORK=pytorch COVERAGE_FILE=.coverage_0 pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/0.log 2>&1 + COVERAGE_FILE=.coverage_rob pytest -s --cov=${LIBDIR} test_robustness.py > /tmp/pytest-log/robustness.log 2>&1 COVERAGE_FILE=.coverage_config pytest -s --cov=${LIBDIR} test_config.py > /tmp/pytest-log/config.log 2>&1 COVERAGE_FILE=.coverage_auto pytest -s --cov=${LIBDIR} test_auto_loader.py > /tmp/pytest-log/auto_loader.log 2>&1 COVERAGE_FILE=.coverage_3fs pytest -s --cov=${LIBDIR} threefs/ > /tmp/pytest-log/threefs.log 2>&1 diff --git a/fastsafetensors/copier/gds.py b/fastsafetensors/copier/gds.py index 246211b..c9164e5 100644 --- a/fastsafetensors/copier/gds.py +++ b/fastsafetensors/copier/gds.py @@ -2,7 +2,7 @@ import platform import warnings -from typing import Dict, Optional +from typing import Dict, List, Optional from .. import cpp as fstcpp from ..common import SafeTensorsMetadata, init_logger, is_gpu_found @@ -14,6 +14,8 @@ logger = init_logger(__name__) +_warned_gds_fallback = False + class GdsFileCopier(CopierInterface): def __init__( @@ -22,6 +24,7 @@ def __init__( device: Device, reader: fstcpp.gds_file_reader, framework: FrameworkOpBase, + fallback_cache: Optional[List[CopierConstructFunc]] = None, ): self.framework = framework self.metadata = metadata @@ -31,6 +34,11 @@ def __init__( self.fh: Optional[fstcpp.gds_file_handle] = None self.copy_reqs: Dict[int, int] = {} self.aligned_length = 0 + self._fallback: Optional[CopierInterface] = None + # One-slot cell shared by all copiers from the same factory, so a + # broken-GDS host builds a single nogds fallback reader (and its + # pinned bounce buffer) per loader instead of one per file. + self._fallback_cache = fallback_cache cuda_ver = framework.get_cuda_ver() if cuda_ver and cuda_ver != "0.0": # Parse version string (e.g., "cuda-12.1" or "hip-5.7.0") @@ -65,7 +73,45 @@ def submit_io( self.device.type == DeviceType.CUDA or self.device.type == DeviceType.GPU ) ALIGN: int = fstcpp.get_alignment_size() - self.fh = fstcpp.gds_file_handle(self.metadata.src, self.o_direct, dev_is_cuda) + try: + self.fh = fstcpp.gds_file_handle( + self.metadata.src, self.o_direct, dev_is_cuda + ) + except RuntimeError as e: + # cuFile can probe as available yet fail at I/O time: handle + # registration errors on compat-mode hosts or unsupported + # filesystems (e.g. overlayfs), or open(O_DIRECT) rejections. + # Downgrade this copier to the nogds bounce path instead of + # failing, so consumers don't each need their own gds->nogds + # retry. Deliberately limited to file-handle setup: failures in + # already-submitted reads stay fatal (falling back mid-cycle + # would re-read earlier data). + global _warned_gds_fallback + if not _warned_gds_fallback: + _warned_gds_fallback = True + # str(e): keeping the exception object in the log record would + # retain its traceback (and this frame's locals) via any + # record-capturing handler. + logger.warning( + "GDS file-handle setup failed (%s); " + "falling back to the nogds copier", + str(e), + ) + if self._fallback_cache is not None: + if not self._fallback_cache: + self._fallback_cache.append( + new_nogds_file_copier(self.device, framework=self.framework) + ) + self._fallback = self._fallback_cache[0]( + self.metadata, self.device, self.framework + ) + else: + # direct construction (no factory): reader lives only for this + # file's submit/wait cycle and is released in wait_io + self._fallback = new_nogds_file_copier( + self.device, framework=self.framework + )(self.metadata, self.device, self.framework) + return self._fallback.submit_io(use_buf_register, max_copy_block_size) offset = self.metadata.header_length length = self.metadata.size_bytes - self.metadata.header_length head_bytes = offset % ALIGN @@ -120,6 +166,11 @@ def wait_io( dtype: DType = DType.AUTO, noalign: bool = False, ) -> Dict[str, TensorBase]: + if self._fallback is not None: + tensors = self._fallback.wait_io(gbuf, dtype=dtype, noalign=noalign) + # Drop the fallback copier so its bounce-buffer reader is freed. + self._fallback = None + return tensors failed = [] for req, c in sorted(self.copy_reqs.items(), key=lambda x: x[0]): count = self.reader.wait_read(req) @@ -222,11 +273,15 @@ def new_gds_file_copier( reader = fstcpp.gds_file_reader(max_threads, device_is_not_cpu, device_id) + fallback_cache: List[CopierConstructFunc] = [] + def construct_copier( metadata: SafeTensorsMetadata, device: Device, framework: FrameworkOpBase, ) -> CopierInterface: - return GdsFileCopier(metadata, device, reader, framework) + return GdsFileCopier( + metadata, device, reader, framework, fallback_cache=fallback_cache + ) return construct_copier diff --git a/tests/unit/test_robustness.py b/tests/unit/test_robustness.py new file mode 100644 index 0000000..76572ce --- /dev/null +++ b/tests/unit/test_robustness.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Robustness tests: graceful degradation of I/O paths.""" + +import pytest + +# ---- runtime GDS -> nogds fallback ---- + + +def test_gds_copier_falls_back_to_nogds(input_files, framework, monkeypatch): + if framework.get_name() != "pytorch": + pytest.skip("pytorch-only") + from test_fastsafetensors import load_safetensors_file + + from fastsafetensors import SafeTensorsMetadata + from fastsafetensors import cpp as fstcpp + from fastsafetensors.copier.gds import GdsFileCopier + from fastsafetensors.st_types import Device + + def _boom(*a, **k): + raise RuntimeError( + "raw_gds_file_handle: cuFileHandleRegister returned an error = 5027" + ) + + monkeypatch.setattr(fstcpp, "gds_file_handle", _boom) + + device = Device.from_str("cpu") + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + reader = fstcpp.gds_file_reader(4, False, 0) + copier = GdsFileCopier(meta, device, reader, framework) + gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024) + tensors = copier.wait_io(gbuf) + expected = load_safetensors_file(input_files[0], device, framework) + assert set(tensors.keys()) == set(expected.keys()) + for k, exp in expected.items(): + assert framework.is_equal(tensors[k], exp), k + framework.free_tensor_memory(gbuf, device) + # the fallback's bounce-buffer reader must not outlive the copy cycle + assert fstcpp.get_cpp_metrics().bounce_buffer_bytes == 0 + + +def test_gds_fallback_warns_once_and_shares_reader( + input_files, framework, monkeypatch, caplog +): + if framework.get_name() != "pytorch": + pytest.skip("pytorch-only") + import logging + + from fastsafetensors import SafeTensorsMetadata + from fastsafetensors import cpp as fstcpp + from fastsafetensors.copier import gds as gds_mod + from fastsafetensors.st_types import Device + + monkeypatch.setattr( + fstcpp, + "gds_file_handle", + lambda *a, **k: (_ for _ in ()).throw(RuntimeError("error = 5027")), + ) + monkeypatch.setattr(gds_mod, "_warned_gds_fallback", False) + + device = Device.from_str("cpu") + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + reader = fstcpp.gds_file_reader(4, False, 0) + cache = [] + with caplog.at_level(logging.WARNING, logger="fastsafetensors.copier.gds"): + c1 = gds_mod.GdsFileCopier( + meta, device, reader, framework, fallback_cache=cache + ) + g1 = c1.submit_io(False, 1 << 30) + c1.wait_io(g1) + c2 = gds_mod.GdsFileCopier( + meta, device, reader, framework, fallback_cache=cache + ) + g2 = c2.submit_io(False, 1 << 30) + c2.wait_io(g2) + assert caplog.text.count("falling back to the nogds copier") == 1 # warn once + assert len(cache) == 1 # one shared nogds constructor for the whole loader + framework.free_tensor_memory(g1, device) + framework.free_tensor_memory(g2, device)