Skip to content

Commit 4a8b4a4

Browse files
gitbisectorclaude
andcommitted
Apply tensor filter to public API surface, not just I/O
set_tensor_filter() previously only narrowed which bytes the nogds/unified copiers read; the public API still advertised filtered tensors and get_tensor / get_sharded returned views onto uninitialized device memory. Reported by @takeshi-yoshimura in #81 review. * BaseSafeTensorsFileLoader.get_keys() and get_shape() consult _tensor_filter at read time (set_tensor_filter can be called before or after add_filenames). * FilesBufferOnDevice takes an optional keep_tensor kwarg that excludes filtered keys from key_to_rank_lidx. copy_files_to_device() threads self._tensor_filter through. get_tensor / get_sharded then hit the existing _get_rank_lidx ValueError; fastsafe_open.keys() inherits the same narrowing. * FilesBufferOnDevice.get_filename now raises ValueError via _get_rank_lidx for unreachable tensors (previously returned ""). The behavior change also affects unknown names; the tgis_weight example is updated to probe key_to_rank_lidx directly. * ParallelLoader.iterate_weights() iterates fb.key_to_rank_lidx.keys(), so it skips filtered tensors automatically. Adds two tests in test_fastsafetensors.py covering the API-surface contract (so they run under the existing test-torch CI job). Signed-off-by: git bisector <gitbisector@gmail.com> Co-authored-by: Claude <noreply@anthropic.com>
1 parent eb9d8af commit 4a8b4a4

4 files changed

Lines changed: 104 additions & 25 deletions

File tree

examples/tgis_weight.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,13 @@ def close(self):
151151
torch.cuda.empty_cache()
152152

153153
def _get_alias(self, tensor_name: str) -> str:
154-
if self._fb.get_filename(tensor_name) is None:
155-
if tensor_name in self.aliases:
156-
for alias in self.aliases[tensor_name]:
157-
if self._fb.get_filename(alias) is not None:
158-
return alias
159-
raise RuntimeError(f"weight {tensor_name} does not exist")
160-
return tensor_name
154+
if tensor_name in self._fb.key_to_rank_lidx:
155+
return tensor_name
156+
if tensor_name in self.aliases:
157+
for alias in self.aliases[tensor_name]:
158+
if alias in self._fb.key_to_rank_lidx:
159+
return alias
160+
raise RuntimeError(f"weight {tensor_name} does not exist")
161161

162162
def get_shape(self, tensor_name: str) -> torch.Size:
163163
return torch.Size(self._fb.get_shape(self._get_alias(tensor_name)))

fastsafetensors/file_buffer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
from collections import OrderedDict
4-
from typing import Any, Dict, List, Optional, Tuple
4+
from typing import Any, Callable, Dict, List, Optional, Tuple
55

66
from .common import init_logger
77
from .frameworks import FrameworkOpBase, ProcessGroupBase, TensorBase
@@ -28,6 +28,11 @@ class FilesBufferOnDevice:
2828
rank_loaders (Dict<rank, list(LazyTensorFacotry)>): Tensor factories per rank, which hold device pointers for buffers.
2929
pg (ProcessGroupBase): process group for calling distributed ops.
3030
auto_mem_delete (bool): automatically release device buffers when all the tensors are shuffled.
31+
keep_tensor (Callable[[str], bool], optional): If set, only tensors for
32+
which ``keep_tensor(name)`` is True are registered in ``key_to_rank_lidx``;
33+
others raise ``ValueError`` from ``get_tensor`` / ``get_filename`` /
34+
``get_shape``. Subclasses that reimplement the registration loop must
35+
honor this.
3136
3237
Examples:
3338
See examples/run_single.py and examples/run_parallel.py.
@@ -39,6 +44,7 @@ def __init__(
3944
pg: ProcessGroupBase,
4045
framework: FrameworkOpBase,
4146
auto_mem_delete: bool = True,
47+
keep_tensor: Optional[Callable[[str], bool]] = None,
4248
):
4349
self.framework = framework
4450
self.rank_loaders: Dict[int, List[LazyTensorFactory]] = rank_loaders
@@ -48,6 +54,8 @@ def __init__(
4854
self.instantiated[rank] = {}
4955
for lidx, loader in enumerate(loaders):
5056
for key in loader.metadata.tensors.keys():
57+
if keep_tensor is not None and not keep_tensor(key):
58+
continue
5159
if key in self.key_to_rank_lidx:
5260
raise Exception(
5361
f"FilesBufferOnDevice: key {key} must be unique among files"
@@ -69,9 +77,7 @@ def close(self):
6977
self.rank_loaders = {}
7078

7179
def get_filename(self, tensor_name: str) -> str:
72-
if tensor_name not in self.key_to_rank_lidx:
73-
return ""
74-
rank, lidx = self.key_to_rank_lidx[tensor_name]
80+
rank, lidx = self._get_rank_lidx(tensor_name)
7581
return self.rank_loaders[rank][lidx].metadata.src
7682

7783
def get_shape(self, tensor_name: str) -> List[int]:

fastsafetensors/loader.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,23 +99,27 @@ def close(self):
9999
del self.copier_constructor
100100

101101
def get_keys(self) -> List[str]:
102-
return list(self.frames.keys())
102+
if self._tensor_filter is None:
103+
return list(self.frames.keys())
104+
keep = self._tensor_filter
105+
return [k for k in self.frames.keys() if keep(k)]
103106

104107
def get_shape(self, tensor_name: str) -> List[int]:
108+
if self._tensor_filter is not None and not self._tensor_filter(tensor_name):
109+
raise ValueError(f"get_shape: key {tensor_name} is filtered out")
105110
return self.frames[tensor_name].shape
106111

107112
def set_tensor_filter(self, keep_tensor: Optional[Callable[[str], bool]]) -> None:
108113
"""Load only the tensors for which ``keep_tensor(name)`` is True.
109114
110-
When set, each owned file is read partially: bytes belonging to
111-
filtered-out tensors are skipped and their device-buffer regions are
112-
left uninitialized, so those tensors must not be requested afterwards.
113-
Useful e.g. under expert parallelism to read only this rank's owned
114-
experts -- see ``fastsafetensors.ep_slice.expert_parallel_filter``.
115-
``None`` (the default) loads every tensor.
116-
117-
Partial reads are implemented by the ``nogds`` and ``unified`` copiers;
118-
other copiers ignore the filter and load the full file.
115+
The ``nogds`` and ``unified`` copiers skip reading bytes for filtered
116+
tensors; other copiers load the full file. The filter narrows the
117+
public API on every backend: ``get_keys()`` omits filtered tensors,
118+
``FilesBufferOnDevice`` does not register them, and ``get_tensor``,
119+
``get_filename``, and ``get_shape`` raise ``ValueError`` for them.
120+
``ParallelLoader.iterate_weights()`` skips them. ``None`` (the
121+
default) loads every tensor. See
122+
``fastsafetensors.ep_slice.expert_parallel_filter``.
119123
"""
120124
self._tensor_filter = keep_tensor
121125

@@ -194,7 +198,12 @@ def copy_files_to_device(
194198
lidx += 1
195199
for factory in need_wait:
196200
factory.wait_io(dtype=dtype, noalign=False)
197-
return FilesBufferOnDevice(factories, pg=self.pg, framework=self.framework)
201+
return FilesBufferOnDevice(
202+
factories,
203+
pg=self.pg,
204+
framework=self.framework,
205+
keep_tensor=self._tensor_filter,
206+
)
198207

199208

200209
class SafeTensorsFileLoader(BaseSafeTensorsFileLoader):

tests/unit/test_fastsafetensors.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,16 @@
77

88
import pytest
99

10-
from fastsafetensors import SafeTensorsFileLoader, SafeTensorsMetadata, SingleGroup
10+
from fastsafetensors import (
11+
ParallelLoader,
12+
SafeTensorsFileLoader,
13+
SafeTensorsMetadata,
14+
SingleGroup,
15+
)
1116
from fastsafetensors import cpp as fstcpp
12-
from fastsafetensors import fastsafe_open
17+
from fastsafetensors import (
18+
fastsafe_open,
19+
)
1320
from fastsafetensors.common import get_device_numa_node, is_gpu_found
1421
from fastsafetensors.copier.gds import GdsFileCopier
1522
from fastsafetensors.copier.nogds import NoGdsFileCopier
@@ -531,7 +538,8 @@ def test_SafeTensorsFileLoader(fstcpp_log, input_files, framework) -> None:
531538
assert bufs.get_filename(last_key) == input_files[0]
532539
assert bufs.get_shape(last_key) == last_shape
533540
assert loader.get_shape(last_key) == last_shape
534-
assert bufs.get_filename("aaaaaaaaaaaaa") == ""
541+
with pytest.raises(ValueError):
542+
bufs.get_filename("aaaaaaaaaaaaa")
535543
bufs.close()
536544
loader.close()
537545
assert framework.get_mem_used() == 0
@@ -560,6 +568,62 @@ def test_SafeTensorsFileLoaderNoGds(fstcpp_log, input_files, framework) -> None:
560568
assert fstcpp.get_cpp_metrics().bounce_buffer_bytes == 0
561569

562570

571+
def test_tensor_filter_hides_skipped_tensors(fstcpp_log, input_files, framework):
572+
device, _ = get_and_check_device(framework)
573+
meta = SafeTensorsMetadata.from_file(input_files[0], framework)
574+
575+
kept = set(sorted(meta.tensors.keys())[::2])
576+
keep = lambda name: name in kept # noqa: E731
577+
skipped = next(name for name in meta.tensors if name not in kept)
578+
579+
loader = SafeTensorsFileLoader(
580+
pg=SingleGroup(),
581+
device=device.as_str(),
582+
framework=framework.get_name(),
583+
nogds=True,
584+
)
585+
loader.set_tensor_filter(keep)
586+
loader.add_filenames({0: [input_files[0]]})
587+
bufs = loader.copy_files_to_device()
588+
589+
assert set(loader.get_keys()) == kept
590+
assert skipped not in bufs.key_to_rank_lidx
591+
with pytest.raises(ValueError):
592+
bufs.get_tensor(skipped)
593+
with pytest.raises(ValueError):
594+
bufs.get_filename(skipped)
595+
with pytest.raises(ValueError):
596+
loader.get_shape(skipped)
597+
598+
bufs.close()
599+
loader.close()
600+
601+
602+
def test_tensor_filter_iterate_weights_hides_skipped(
603+
fstcpp_log, input_files, framework
604+
):
605+
device, _ = get_and_check_device(framework)
606+
meta = SafeTensorsMetadata.from_file(input_files[0], framework)
607+
608+
kept = set(sorted(meta.tensors.keys())[::2])
609+
keep = lambda name: name in kept # noqa: E731
610+
611+
loader = ParallelLoader(
612+
pg=SingleGroup(),
613+
hf_weights_files=[input_files[0]],
614+
device=device.as_str(),
615+
nogds=True,
616+
framework=framework.get_name(),
617+
tensor_filter=keep,
618+
all_local=True,
619+
)
620+
yielded = {key for key, _t in loader.iterate_weights()}
621+
assert yielded == kept
622+
623+
loader.close()
624+
assert framework.get_mem_used() == 0
625+
626+
563627
def test_fastsafe_open(fstcpp_log, input_files, framework) -> None:
564628
device, _ = get_and_check_device(framework)
565629

0 commit comments

Comments
 (0)