22
33import platform
44import warnings
5- from typing import Dict , Optional
5+ from typing import Dict , List , Optional
66
77from .. import cpp as fstcpp
88from ..common import SafeTensorsMetadata , init_logger , is_gpu_found
1414
1515logger = init_logger (__name__ )
1616
17+ _warned_gds_fallback = False
18+
1719
1820class GdsFileCopier (CopierInterface ):
1921 def __init__ (
@@ -22,6 +24,7 @@ def __init__(
2224 device : Device ,
2325 reader : fstcpp .gds_file_reader ,
2426 framework : FrameworkOpBase ,
27+ fallback_cache : Optional [List [CopierConstructFunc ]] = None ,
2528 ):
2629 self .framework = framework
2730 self .metadata = metadata
@@ -31,6 +34,11 @@ def __init__(
3134 self .fh : Optional [fstcpp .gds_file_handle ] = None
3235 self .copy_reqs : Dict [int , int ] = {}
3336 self .aligned_length = 0
37+ self ._fallback : Optional [CopierInterface ] = None
38+ # One-slot cell shared by all copiers from the same factory, so a
39+ # broken-GDS host builds a single nogds fallback reader (and its
40+ # pinned bounce buffer) per loader instead of one per file.
41+ self ._fallback_cache = fallback_cache
3442 cuda_ver = framework .get_cuda_ver ()
3543 if cuda_ver and cuda_ver != "0.0" :
3644 # Parse version string (e.g., "cuda-12.1" or "hip-5.7.0")
@@ -65,7 +73,45 @@ def submit_io(
6573 self .device .type == DeviceType .CUDA or self .device .type == DeviceType .GPU
6674 )
6775 ALIGN : int = fstcpp .get_alignment_size ()
68- self .fh = fstcpp .gds_file_handle (self .metadata .src , self .o_direct , dev_is_cuda )
76+ try :
77+ self .fh = fstcpp .gds_file_handle (
78+ self .metadata .src , self .o_direct , dev_is_cuda
79+ )
80+ except RuntimeError as e :
81+ # cuFile can probe as available yet fail at I/O time: handle
82+ # registration errors on compat-mode hosts or unsupported
83+ # filesystems (e.g. overlayfs), or open(O_DIRECT) rejections.
84+ # Downgrade this copier to the nogds bounce path instead of
85+ # failing, so consumers don't each need their own gds->nogds
86+ # retry. Deliberately limited to file-handle setup: failures in
87+ # already-submitted reads stay fatal (falling back mid-cycle
88+ # would re-read earlier data).
89+ global _warned_gds_fallback
90+ if not _warned_gds_fallback :
91+ _warned_gds_fallback = True
92+ # str(e): keeping the exception object in the log record would
93+ # retain its traceback (and this frame's locals) via any
94+ # record-capturing handler.
95+ logger .warning (
96+ "GDS file-handle setup failed (%s); "
97+ "falling back to the nogds copier" ,
98+ str (e ),
99+ )
100+ if self ._fallback_cache is not None :
101+ if not self ._fallback_cache :
102+ self ._fallback_cache .append (
103+ new_nogds_file_copier (self .device , framework = self .framework )
104+ )
105+ self ._fallback = self ._fallback_cache [0 ](
106+ self .metadata , self .device , self .framework
107+ )
108+ else :
109+ # direct construction (no factory): reader lives only for this
110+ # file's submit/wait cycle and is released in wait_io
111+ self ._fallback = new_nogds_file_copier (
112+ self .device , framework = self .framework
113+ )(self .metadata , self .device , self .framework )
114+ return self ._fallback .submit_io (use_buf_register , max_copy_block_size )
69115 offset = self .metadata .header_length
70116 length = self .metadata .size_bytes - self .metadata .header_length
71117 head_bytes = offset % ALIGN
@@ -120,6 +166,11 @@ def wait_io(
120166 dtype : DType = DType .AUTO ,
121167 noalign : bool = False ,
122168 ) -> Dict [str , TensorBase ]:
169+ if self ._fallback is not None :
170+ tensors = self ._fallback .wait_io (gbuf , dtype = dtype , noalign = noalign )
171+ # Drop the fallback copier so its bounce-buffer reader is freed.
172+ self ._fallback = None
173+ return tensors
123174 failed = []
124175 for req , c in sorted (self .copy_reqs .items (), key = lambda x : x [0 ]):
125176 count = self .reader .wait_read (req )
@@ -222,11 +273,15 @@ def new_gds_file_copier(
222273
223274 reader = fstcpp .gds_file_reader (max_threads , device_is_not_cpu , device_id )
224275
276+ fallback_cache : List [CopierConstructFunc ] = []
277+
225278 def construct_copier (
226279 metadata : SafeTensorsMetadata ,
227280 device : Device ,
228281 framework : FrameworkOpBase ,
229282 ) -> CopierInterface :
230- return GdsFileCopier (metadata , device , reader , framework )
283+ return GdsFileCopier (
284+ metadata , device , reader , framework , fallback_cache = fallback_cache
285+ )
231286
232287 return construct_copier
0 commit comments