Skip to content

Commit 260af04

Browse files
committed
Apply tensor filter to public API surface, not just I/O
set_tensor_filter() previously only narrowed which bytes the nogds/unified copiers read; the public API still advertised filtered-out tensors and get_tensor / get_sharded returned views onto uninitialized device memory rather than raising. Reported by @takeshi-yoshimura in #81 review.

* SafeTensorsFileLoader.get_keys() honors _tensor_filter at read time so set_tensor_filter() may still be called before or after add_filenames().
* FilesBufferOnDevice gains an optional keep_tensor kwarg and excludes filtered keys from key_to_rank_lidx at construction time. SafeTensorsFileLoader.copy_files_to_device() threads its own _tensor_filter through. get_tensor / get_sharded therefore hit the existing _get_rank_lidx ValueError instead of returning uninitialized data; fastsafe_open.keys() inherits this because it iterates key_to_rank_lidx.
* ParallelLoader.iterate_weights() already iterates list(fb.key_to_rank_lidx.keys()), so it skips filtered-out tensors with no further change.

Adds two tests to tests/unit/test_ep_slice.py covering the new contract: the exact test from the PR review (get_keys / key_to_rank_lidx / get_tensor) plus a ParallelLoader.iterate_weights variant.
1 parent eb9d8af commit 260af04

3 files changed

Lines changed: 92 additions & 6 deletions

File tree

fastsafetensors/file_buffer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
from collections import OrderedDict
4-
from typing import Any, Dict, List, Optional, Tuple
4+
from typing import Any, Callable, Dict, List, Optional, Tuple
55

66
from .common import init_logger
77
from .frameworks import FrameworkOpBase, ProcessGroupBase, TensorBase
@@ -39,6 +39,7 @@ def __init__(
3939
pg: ProcessGroupBase,
4040
framework: FrameworkOpBase,
4141
auto_mem_delete: bool = True,
42+
keep_tensor: Optional[Callable[[str], bool]] = None,
4243
):
4344
self.framework = framework
4445
self.rank_loaders: Dict[int, List[LazyTensorFactory]] = rank_loaders
@@ -48,6 +49,8 @@ def __init__(
4849
self.instantiated[rank] = {}
4950
for lidx, loader in enumerate(loaders):
5051
for key in loader.metadata.tensors.keys():
52+
if keep_tensor is not None and not keep_tensor(key):
53+
continue
5154
if key in self.key_to_rank_lidx:
5255
raise Exception(
5356
f"FilesBufferOnDevice: key {key} must be unique among files"

fastsafetensors/loader.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ def close(self):
9999
del self.copier_constructor
100100

101101
def get_keys(self) -> List[str]:
102-
return list(self.frames.keys())
102+
if self._tensor_filter is None:
103+
return list(self.frames.keys())
104+
keep = self._tensor_filter
105+
return [k for k in self.frames.keys() if keep(k)]
103106

104107
def get_shape(self, tensor_name: str) -> List[int]:
105108
return self.frames[tensor_name].shape
@@ -109,13 +112,20 @@ def set_tensor_filter(self, keep_tensor: Optional[Callable[[str], bool]]) -> Non
109112
110113
When set, each owned file is read partially: bytes belonging to
111114
filtered-out tensors are skipped and their device-buffer regions are
112-
left uninitialized, so those tensors must not be requested afterwards.
115+
left uninitialized. Filtered-out tensors are also hidden from the
116+
public API: ``get_keys()`` omits them, ``FilesBufferOnDevice`` does
117+
not register them in ``key_to_rank_lidx``, and any attempt to fetch
118+
one via ``get_tensor`` / ``get_sharded`` raises ``ValueError`` rather
119+
than returning uninitialized data. ``ParallelLoader.iterate_weights``
120+
likewise skips them.
121+
113122
Useful e.g. under expert parallelism to read only this rank's owned
114123
experts -- see ``fastsafetensors.ep_slice.expert_parallel_filter``.
115124
``None`` (the default) loads every tensor.
116125
117126
Partial reads are implemented by the ``nogds`` and ``unified`` copiers;
118-
other copiers ignore the filter and load the full file.
127+
other copiers ignore the filter and load the full file, but the API
128+
surface is still narrowed for consistency regardless of backend.
119129
"""
120130
self._tensor_filter = keep_tensor
121131

@@ -194,7 +204,12 @@ def copy_files_to_device(
194204
lidx += 1
195205
for factory in need_wait:
196206
factory.wait_io(dtype=dtype, noalign=False)
197-
return FilesBufferOnDevice(factories, pg=self.pg, framework=self.framework)
207+
return FilesBufferOnDevice(
208+
factories,
209+
pg=self.pg,
210+
framework=self.framework,
211+
keep_tensor=self._tensor_filter,
212+
)
198213

199214

200215
class SafeTensorsFileLoader(BaseSafeTensorsFileLoader):

tests/unit/test_ep_slice.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
import pytest
1010
import torch
1111

12-
from fastsafetensors import SafeTensorsMetadata
12+
from fastsafetensors import (
13+
ParallelLoader,
14+
SafeTensorsFileLoader,
15+
SafeTensorsMetadata,
16+
SingleGroup,
17+
)
1318
from fastsafetensors import cpp as fstcpp
1419
from fastsafetensors.copier.nogds import NoGdsFileCopier
1520
from fastsafetensors.copier.unified import new_unified_copier
@@ -196,3 +201,66 @@ def test_unified_full_read_unchanged(fstcpp_log, input_files, framework):
196201
framework.free_tensor_memory(gbuf, device)
197202
del copier
198203
assert framework.get_mem_used() == 0
204+
205+
206+
# ---- filter must narrow the public API, not just I/O ----
207+
208+
209+
def test_tensor_filter_hides_skipped_tensors(fstcpp_log, input_files, framework):
210+
"""A skipped tensor must not appear in get_keys / key_to_rank_lidx, and
211+
asking for it via get_tensor must raise ValueError rather than returning
212+
uninitialized device memory."""
213+
device, _ = get_and_check_device(framework)
214+
meta = SafeTensorsMetadata.from_file(input_files[0], framework)
215+
216+
kept = set(sorted(meta.tensors.keys())[::2])
217+
keep = lambda name: name in kept # noqa: E731
218+
skipped = next(name for name in meta.tensors if name not in kept)
219+
220+
loader = SafeTensorsFileLoader(
221+
pg=SingleGroup(),
222+
device=device.as_str(),
223+
framework=framework.get_name(),
224+
nogds=True,
225+
)
226+
loader.set_tensor_filter(keep)
227+
loader.add_filenames({0: [input_files[0]]})
228+
bufs = loader.copy_files_to_device()
229+
230+
assert set(loader.get_keys()) == kept
231+
assert skipped not in bufs.key_to_rank_lidx
232+
with pytest.raises(ValueError):
233+
bufs.get_tensor(skipped)
234+
235+
bufs.close()
236+
loader.close()
237+
238+
239+
def test_tensor_filter_iterate_weights_hides_skipped(
240+
fstcpp_log, input_files, framework
241+
):
242+
"""ParallelLoader.iterate_weights must yield only filter-kept tensors;
243+
skipped ones (which would return uninitialized data) must never appear."""
244+
device, _ = get_and_check_device(framework)
245+
meta = SafeTensorsMetadata.from_file(input_files[0], framework)
246+
247+
kept = set(sorted(meta.tensors.keys())[::2])
248+
keep = lambda name: name in kept # noqa: E731
249+
assert kept and any(
250+
n not in kept for n in meta.tensors
251+
), "fixture should split into kept and skipped"
252+
253+
loader = ParallelLoader(
254+
pg=SingleGroup(),
255+
hf_weights_files=[input_files[0]],
256+
device=device.as_str(),
257+
nogds=True,
258+
framework=framework.get_name(),
259+
tensor_filter=keep,
260+
all_local=True, # required when tensor_filter drops tensors
261+
)
262+
yielded = {key for key, _tensor in loader.iterate_weights()}
263+
assert yielded == kept
264+
265+
loader.close()
266+
assert framework.get_mem_used() == 0

0 commit comments

Comments
 (0)