Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test-torch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ jobs:
mkdir -p /tmp/pytest-log
export TEST_FASTSAFETENSORS_FRAMEWORK=pytorch
COVERAGE_FILE=.coverage_0 pytest -s --cov=${LIBDIR} test_fastsafetensors.py > /tmp/pytest-log/0.log 2>&1
COVERAGE_FILE=.coverage_rob pytest -s --cov=${LIBDIR} test_robustness.py > /tmp/pytest-log/robustness.log 2>&1
COVERAGE_FILE=.coverage_config pytest -s --cov=${LIBDIR} test_config.py > /tmp/pytest-log/config.log 2>&1
COVERAGE_FILE=.coverage_auto pytest -s --cov=${LIBDIR} test_auto_loader.py > /tmp/pytest-log/auto_loader.log 2>&1
COVERAGE_FILE=.coverage_3fs pytest -s --cov=${LIBDIR} threefs/ > /tmp/pytest-log/threefs.log 2>&1
Expand Down
61 changes: 58 additions & 3 deletions fastsafetensors/copier/gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import platform
import warnings
from typing import Dict, Optional
from typing import Dict, List, Optional

from .. import cpp as fstcpp
from ..common import SafeTensorsMetadata, init_logger, is_gpu_found
Expand All @@ -14,6 +14,8 @@

logger = init_logger(__name__)

_warned_gds_fallback = False


class GdsFileCopier(CopierInterface):
def __init__(
Expand All @@ -22,6 +24,7 @@ def __init__(
device: Device,
reader: fstcpp.gds_file_reader,
framework: FrameworkOpBase,
fallback_cache: Optional[List[CopierConstructFunc]] = None,
):
self.framework = framework
self.metadata = metadata
Expand All @@ -31,6 +34,11 @@ def __init__(
self.fh: Optional[fstcpp.gds_file_handle] = None
self.copy_reqs: Dict[int, int] = {}
self.aligned_length = 0
self._fallback: Optional[CopierInterface] = None
# One-slot cell shared by all copiers from the same factory, so a
# broken-GDS host builds a single nogds fallback reader (and its
# pinned bounce buffer) per loader instead of one per file.
self._fallback_cache = fallback_cache
cuda_ver = framework.get_cuda_ver()
if cuda_ver and cuda_ver != "0.0":
# Parse version string (e.g., "cuda-12.1" or "hip-5.7.0")
Expand Down Expand Up @@ -65,7 +73,45 @@ def submit_io(
self.device.type == DeviceType.CUDA or self.device.type == DeviceType.GPU
)
ALIGN: int = fstcpp.get_alignment_size()
self.fh = fstcpp.gds_file_handle(self.metadata.src, self.o_direct, dev_is_cuda)
try:
self.fh = fstcpp.gds_file_handle(
self.metadata.src, self.o_direct, dev_is_cuda
)
except RuntimeError as e:
# cuFile can probe as available yet fail at I/O time: handle
# registration errors on compat-mode hosts or unsupported
# filesystems (e.g. overlayfs), or open(O_DIRECT) rejections.
# Downgrade this copier to the nogds bounce path instead of
# failing, so consumers don't each need their own gds->nogds
# retry. Deliberately limited to file-handle setup: failures in
# already-submitted reads stay fatal (falling back mid-cycle
# would re-read earlier data).
global _warned_gds_fallback
if not _warned_gds_fallback:
_warned_gds_fallback = True
# str(e): keeping the exception object in the log record would
# retain its traceback (and this frame's locals) via any
# record-capturing handler.
logger.warning(
"GDS file-handle setup failed (%s); "
"falling back to the nogds copier",
str(e),
)
if self._fallback_cache is not None:
if not self._fallback_cache:
self._fallback_cache.append(
new_nogds_file_copier(self.device, framework=self.framework)
)
self._fallback = self._fallback_cache[0](
self.metadata, self.device, self.framework
)
else:
# direct construction (no factory): reader lives only for this
# file's submit/wait cycle and is released in wait_io
self._fallback = new_nogds_file_copier(
self.device, framework=self.framework
)(self.metadata, self.device, self.framework)
return self._fallback.submit_io(use_buf_register, max_copy_block_size)
offset = self.metadata.header_length
length = self.metadata.size_bytes - self.metadata.header_length
head_bytes = offset % ALIGN
Expand Down Expand Up @@ -120,6 +166,11 @@ def wait_io(
dtype: DType = DType.AUTO,
noalign: bool = False,
) -> Dict[str, TensorBase]:
if self._fallback is not None:
tensors = self._fallback.wait_io(gbuf, dtype=dtype, noalign=noalign)
# Drop the fallback copier so its bounce-buffer reader is freed.
self._fallback = None
return tensors
failed = []
for req, c in sorted(self.copy_reqs.items(), key=lambda x: x[0]):
count = self.reader.wait_read(req)
Expand Down Expand Up @@ -222,11 +273,15 @@ def new_gds_file_copier(

reader = fstcpp.gds_file_reader(max_threads, device_is_not_cpu, device_id)

fallback_cache: List[CopierConstructFunc] = []

def construct_copier(
metadata: SafeTensorsMetadata,
device: Device,
framework: FrameworkOpBase,
) -> CopierInterface:
return GdsFileCopier(metadata, device, reader, framework)
return GdsFileCopier(
metadata, device, reader, framework, fallback_cache=fallback_cache
)

return construct_copier
78 changes: 78 additions & 0 deletions tests/unit/test_robustness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
"""Robustness tests: graceful degradation of I/O paths."""

import pytest

# ---- runtime GDS -> nogds fallback ----


def test_gds_copier_falls_back_to_nogds(input_files, framework, monkeypatch):
if framework.get_name() != "pytorch":
pytest.skip("pytorch-only")
from test_fastsafetensors import load_safetensors_file

from fastsafetensors import SafeTensorsMetadata
from fastsafetensors import cpp as fstcpp
from fastsafetensors.copier.gds import GdsFileCopier
from fastsafetensors.st_types import Device

def _boom(*a, **k):
raise RuntimeError(
"raw_gds_file_handle: cuFileHandleRegister returned an error = 5027"
)

monkeypatch.setattr(fstcpp, "gds_file_handle", _boom)

device = Device.from_str("cpu")
meta = SafeTensorsMetadata.from_file(input_files[0], framework)
reader = fstcpp.gds_file_reader(4, False, 0)
copier = GdsFileCopier(meta, device, reader, framework)
gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024)
tensors = copier.wait_io(gbuf)
expected = load_safetensors_file(input_files[0], device, framework)
assert set(tensors.keys()) == set(expected.keys())
for k, exp in expected.items():
assert framework.is_equal(tensors[k], exp), k
framework.free_tensor_memory(gbuf, device)
# the fallback's bounce-buffer reader must not outlive the copy cycle
assert fstcpp.get_cpp_metrics().bounce_buffer_bytes == 0


def test_gds_fallback_warns_once_and_shares_reader(
input_files, framework, monkeypatch, caplog
):
if framework.get_name() != "pytorch":
pytest.skip("pytorch-only")
import logging

from fastsafetensors import SafeTensorsMetadata
from fastsafetensors import cpp as fstcpp
from fastsafetensors.copier import gds as gds_mod
from fastsafetensors.st_types import Device

monkeypatch.setattr(
fstcpp,
"gds_file_handle",
lambda *a, **k: (_ for _ in ()).throw(RuntimeError("error = 5027")),
)
monkeypatch.setattr(gds_mod, "_warned_gds_fallback", False)

device = Device.from_str("cpu")
meta = SafeTensorsMetadata.from_file(input_files[0], framework)
reader = fstcpp.gds_file_reader(4, False, 0)
cache = []
with caplog.at_level(logging.WARNING, logger="fastsafetensors.copier.gds"):
c1 = gds_mod.GdsFileCopier(
meta, device, reader, framework, fallback_cache=cache
)
g1 = c1.submit_io(False, 1 << 30)
c1.wait_io(g1)
c2 = gds_mod.GdsFileCopier(
meta, device, reader, framework, fallback_cache=cache
)
g2 = c2.submit_io(False, 1 << 30)
c2.wait_io(g2)
assert caplog.text.count("falling back to the nogds copier") == 1 # warn once
assert len(cache) == 1 # one shared nogds constructor for the whole loader
framework.free_tensor_memory(g1, device)
framework.free_tensor_memory(g2, device)
Loading