diff --git a/examples/tgis_weight.py b/examples/tgis_weight.py index 6b1538f..157031b 100644 --- a/examples/tgis_weight.py +++ b/examples/tgis_weight.py @@ -151,13 +151,13 @@ def close(self): torch.cuda.empty_cache() def _get_alias(self, tensor_name: str) -> str: - if self._fb.get_filename(tensor_name) is None: - if tensor_name in self.aliases: - for alias in self.aliases[tensor_name]: - if self._fb.get_filename(alias) is not None: - return alias - raise RuntimeError(f"weight {tensor_name} does not exist") - return tensor_name + if tensor_name in self._fb.key_to_rank_lidx: + return tensor_name + if tensor_name in self.aliases: + for alias in self.aliases[tensor_name]: + if alias in self._fb.key_to_rank_lidx: + return alias + raise RuntimeError(f"weight {tensor_name} does not exist") def get_shape(self, tensor_name: str) -> torch.Size: return torch.Size(self._fb.get_shape(self._get_alias(tensor_name))) diff --git a/fastsafetensors/__init__.py b/fastsafetensors/__init__.py index 2d009b7..72d9ecd 100644 --- a/fastsafetensors/__init__.py +++ b/fastsafetensors/__init__.py @@ -12,6 +12,11 @@ get_device_numa_node, ) from .config import LoaderConfig, load_config +from .ep_slice import ( + expert_parallel_filter, + expert_parallel_filter_from_env, + owned_expert_range, +) from .file_buffer import FilesBufferOnDevice from .loader import SafeTensorsFileLoader, fastsafe_open from .parallel_loader import ParallelLoader diff --git a/fastsafetensors/common.py b/fastsafetensors/common.py index e2d2b02..0f86829 100644 --- a/fastsafetensors/common.py +++ b/fastsafetensors/common.py @@ -6,7 +6,7 @@ import sys from collections import OrderedDict from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple from . import cpp as fstcpp from .dlpack import from_cuda_buffer @@ -359,6 +359,38 @@ def get_tensors( ret[tensor_name] = t2 return ret + def select_byte_ranges( + self, keep_tensor: Callable[[str], bool], merge_gap: int = 4096 + ) -> List[Tuple[int, int]]: + """Compute the file byte-ranges covering only the kept tensors. + + Returns a sorted list of ``[start, end)`` absolute file offsets spanning + exactly the tensors for which ``keep_tensor(name)`` is True. Kept tensors + separated by a gap of at most ``merge_gap`` bytes are coalesced into one + range to reduce the number of reads; the few non-kept bytes inside a + coalesced range are read but never instantiated as tensors. + + Pass the result to a partial-read-capable copier (see + ``NoGdsFileCopier.set_byte_ranges``) to load only a subset of a shard -- + e.g. only the experts an expert-parallel rank owns. Tensor data offsets + are unchanged, so unread regions of the device buffer simply stay + uninitialized and their tensors must not be requested. + """ + ranges: List[Tuple[int, int]] = [] + for name, frame in self.tensors.items(): + if not keep_tensor(name): + continue + s, e = frame.data_offsets[0], frame.data_offsets[1] + ranges.append((self.header_length + s, self.header_length + e)) + ranges.sort() + merged: List[List[int]] = [] + for s, e in ranges: + if merged and s - merged[-1][1] <= merge_gap: + merged[-1][1] = max(merged[-1][1], e) + else: + merged.append([s, e]) + return [(s, e) for s, e in merged] + def __repr__(self) -> str: return str({"__metadata__": self.metadata, "tensors": self.tensors}) diff --git a/fastsafetensors/copier/base.py b/fastsafetensors/copier/base.py index 1af8b22..29aa8b3 100644 --- a/fastsafetensors/copier/base.py +++ b/fastsafetensors/copier/base.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Dict +from typing import Dict, List, Optional, Tuple from .. import cpp as fstcpp from ..frameworks import TensorBase @@ -9,6 +9,18 @@ class CopierInterface(ABC): + def set_byte_ranges(self, byte_ranges: Optional[List[Tuple[int, int]]]) -> None: + """Restrict reads to these ``[start, end)`` absolute file-offset runs. + + The default implementation ignores the runs and reads the whole file, so + the byte-range filter is a correct no-op on copiers that don't implement + partial reads. Range-capable copiers (``nogds``, ``unified``) override + this to read only the given runs, leaving the rest of the device buffer + uninitialized (so skipped tensors must not be requested). Build runs with + ``SafeTensorsMetadata.select_byte_ranges``; ``None`` means full read. + """ + return + @abstractmethod def submit_io( self, use_buf_register: bool, max_copy_block_size: int diff --git a/fastsafetensors/copier/nogds.py b/fastsafetensors/copier/nogds.py index ccdf941..fbecedc 100644 --- a/fastsafetensors/copier/nogds.py +++ b/fastsafetensors/copier/nogds.py @@ -2,7 +2,7 @@ import os import sys -from typing import Dict, List +from typing import Dict, List, Optional, Tuple from .. import cpp as fstcpp from ..common import SafeTensorsMetadata, is_gpu_found, resolve_cudart_lib_name @@ -35,24 +35,42 @@ def __init__( ) self.device = device self.reqs: List[int] = [] + self.byte_ranges: Optional[List[Tuple[int, int]]] = None + + def set_byte_ranges(self, byte_ranges: Optional[List[Tuple[int, int]]]) -> None: + """Restrict reads to these ``[start, end)`` absolute file-offset runs. + + Bytes outside the given runs are not read; their regions of the device + buffer are left uninitialized, so the corresponding tensors must not be + requested. ``None`` (the default) reads the whole data section. Build + runs with ``SafeTensorsMetadata.select_byte_ranges``. + """ + self.byte_ranges = byte_ranges def submit_io( self, use_buf_register: bool, max_copy_block_size: int ) -> fstcpp.gds_device_buffer: - total_length = self.metadata.size_bytes - self.metadata.header_length + header_length = self.metadata.header_length + total_length = self.metadata.size_bytes - header_length gbuf = self.framework.alloc_tensor_memory(total_length, self.device) - count = 0 - while count < total_length: - l = total_length - count - if max_copy_block_size < l: - l = max_copy_block_size - req = self.reader.submit_read( - self.fd, gbuf, self.metadata.header_length + count, l, count - ) - if req < 0: - raise Exception(f"submit_io: submit_nogds_read failed, err={req}") - self.reqs.append(req) - count += l + # Default to a single run spanning the whole data section, which + # reproduces the original full-file read. + runs = self.byte_ranges + if runs is None: + runs = [(header_length, self.metadata.size_bytes)] + for start, end in runs: + count = start + while count < end: + l = end - count + if max_copy_block_size < l: + l = max_copy_block_size + req = self.reader.submit_read( + self.fd, gbuf, count, l, count - header_length + ) + if req < 0: + raise Exception(f"submit_io: submit_nogds_read failed, err={req}") + self.reqs.append(req) + count += l return gbuf def wait_io( diff --git a/fastsafetensors/copier/unified.py b/fastsafetensors/copier/unified.py index 9af2568..f5366f5 100644 --- a/fastsafetensors/copier/unified.py +++ b/fastsafetensors/copier/unified.py @@ -9,7 +9,7 @@ """ import os -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple import torch @@ -40,12 +40,25 @@ def __init__( self.device = device self.framework = framework self._file_tensor: Optional[torch.Tensor] = None - self._pinned: Optional[torch.Tensor] = None + self._pinned: List[torch.Tensor] = [] + self.byte_ranges: Optional[List[Tuple[int, int]]] = None + + def set_byte_ranges(self, byte_ranges: Optional[List[Tuple[int, int]]]) -> None: + """Restrict reads to these ``[start, end)`` absolute file-offset runs. + + Only the bytes in the given runs are mmap-faulted, pinned, and copied; + the rest of the device buffer is left uninitialized (so the corresponding + tensors must not be requested). Tensor offsets are unchanged. ``None`` + reads the whole data section. Build runs with + ``SafeTensorsMetadata.select_byte_ranges``. + """ + self.byte_ranges = byte_ranges def submit_io( self, use_buf_register: bool, max_copy_block_size: int ) -> fstcpp.gds_device_buffer: - data_length = self.metadata.size_bytes - self.metadata.header_length + header_length = self.metadata.header_length + data_length = self.metadata.size_bytes - header_length # Allocate CUDA buffer via framework's allocator (proper lifecycle) gbuf = self.framework.alloc_tensor_memory(data_length, self.device) @@ -55,25 +68,32 @@ def submit_io( self.metadata.src, size=self.metadata.size_bytes, dtype=torch.uint8 ) self._file_tensor = file_tensor - data_tensor = file_tensor[self.metadata.header_length :] - # pin_memory triggers kernel readahead + pins pages for DMA - pinned = data_tensor.pin_memory() - self._pinned = pinned - - # Async DMA from pinned CPU → framework-allocated CUDA buffer - ret = fstcpp.memcpy_h2d_async( # type: ignore[attr-defined] - gbuf.get_base_address(), - pinned.data_ptr(), - data_length, - ) - if ret != 0: - self.framework.free_tensor_memory(gbuf, self.device) - self._pinned = None - self._file_tensor = None - raise RuntimeError( - f"cudaMemcpyAsync failed with error {ret} " f"for {self.metadata.src}" + # Default to the whole data section, reproducing the full-file read. + # An empty list (vs None) reads nothing — same semantics as nogds. + runs = self.byte_ranges + if runs is None: + runs = [(header_length, self.metadata.size_bytes)] + + base_address = gbuf.get_base_address() + self._pinned = [] + for start, end in runs: + # pin_memory faults in + pins only this run's pages, then DMA to the + # matching offset in gbuf (data section starts at header_length). + pinned = file_tensor[start:end].pin_memory() + self._pinned.append(pinned) + ret = fstcpp.memcpy_h2d_async( # type: ignore[attr-defined] + base_address + (start - header_length), + pinned.data_ptr(), + end - start, ) + if ret != 0: + self.framework.free_tensor_memory(gbuf, self.device) + self._pinned = [] + self._file_tensor = None + raise RuntimeError( + f"cudaMemcpyAsync failed with error {ret} for {self.metadata.src}" + ) return gbuf @@ -94,7 +114,7 @@ def wait_io( ) # Release mmap and pinned memory - self._pinned = None + self._pinned = [] self._file_tensor = None return tensors diff --git a/fastsafetensors/ep_slice.py b/fastsafetensors/ep_slice.py new file mode 100644 index 0000000..4e56cff --- /dev/null +++ b/fastsafetensors/ep_slice.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Expert-parallel (EP) slice helpers. + +Under expert parallelism each rank only *uses* the routed experts it owns, yet +file-granular loading makes every rank read the whole shard -- the unowned +experts' bytes are read and then discarded. These helpers build a tensor-name +predicate selecting just this rank's owned experts (plus every non-expert +tensor), so a partial-read-capable loader can skip the unowned bytes: + + from fastsafetensors import SafeTensorsFileLoader + from fastsafetensors.ep_slice import expert_parallel_filter + + loader = SafeTensorsFileLoader(pg, device, nogds=True) + loader.set_tensor_filter(expert_parallel_filter(num_experts=256, + ep_size=2, ep_rank=rank)) + loader.add_filenames(...) + bufs = loader.copy_files_to_device() + +Owned experts use contiguous-block ("linear") assignment: each rank owns +``num_experts // ep_size`` consecutive experts, with any remainder given to the +lowest-numbered ranks. This is a common expert-to-rank convention; the caller is +responsible for ensuring it matches the assignment its runtime expects. No +external dependency is required. +""" +import os +import re +from typing import Callable, Optional, Pattern, Tuple + +# Matches the per-expert index in routed-MoE tensor names, e.g. +# "model.layers.3.mlp.experts.42.w1.weight" or DeepSeek's +# "...ffn.experts.42.gate_proj.weight". Override for a different convention. +DEFAULT_EXPERT_PATTERN: Pattern[str] = re.compile(r"\.experts\.(\d+)\.") + + +def owned_expert_range(num_experts: int, ep_size: int, ep_rank: int) -> Tuple[int, int]: + """Return the ``[lo, hi)`` routed-expert indices owned by ``ep_rank``. + + Contiguous-block ("linear") assignment: each rank owns a consecutive block + of experts, with the remainder distributed to the lowest-numbered ranks. + """ + if ep_size <= 0: + raise ValueError(f"ep_size must be positive, got {ep_size}") + if not 0 <= ep_rank < ep_size: + raise ValueError(f"ep_rank {ep_rank} out of range for ep_size {ep_size}") + base = num_experts // ep_size + rem = num_experts % ep_size + local = base + (1 if ep_rank < rem else 0) + start = ep_rank * base + min(ep_rank, rem) + return (start, start + local) + + +def expert_parallel_filter( + num_experts: int, + ep_size: int, + ep_rank: int, + pattern: Pattern[str] = DEFAULT_EXPERT_PATTERN, +) -> Callable[[str], bool]: + """Build a ``keep(name) -> bool`` predicate for this EP rank. + + Non-expert tensors (names not matching ``pattern``) are kept on every rank; + routed-expert tensors are kept only when their index is in this rank's owned + range. Pass the predicate to ``SafeTensorsFileLoader.set_tensor_filter`` or + ``SafeTensorsMetadata.select_byte_ranges``. + """ + lo, hi = owned_expert_range(num_experts, ep_size, ep_rank) + + def keep(name: str) -> bool: + m = pattern.search(name) + if m is None: + return True + return lo <= int(m.group(1)) < hi + + return keep + + +def expert_parallel_filter_from_env() -> Optional[Callable[[str], bool]]: + """Build an EP filter from environment variables, or ``None`` if disabled. + + Recognized variables (kept compatible with the DGX Spark overlay this + prototype generalizes): + + ``FASTSAFETENSORS_EP_SLICE=1`` enable EP-slice reading + ``FASTSAFETENSORS_EP_NUM_EXPERTS=N`` global routed-expert count (required) + ``FASTSAFETENSORS_EP_SIZE`` / ``_RANK`` override EP size/rank; otherwise + taken from the initialized + torch.distributed group, else from + ``WORLD_SIZE`` / ``RANK``. + + Returns ``None`` (load everything) unless EP-slice is enabled, the expert + count is known, and ``ep_size > 1``. + """ + if os.getenv("FASTSAFETENSORS_EP_SLICE", "0") != "1": + return None + num_experts = int(os.getenv("FASTSAFETENSORS_EP_NUM_EXPERTS", "0")) + if num_experts <= 0: + return None + ep_size = int(os.getenv("FASTSAFETENSORS_EP_SIZE", "0")) + ep_rank = int(os.getenv("FASTSAFETENSORS_EP_RANK", "-1")) + if ep_size <= 0 or ep_rank < 0: + try: + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + ep_size = dist.get_world_size() + ep_rank = dist.get_rank() + except Exception: + pass + if ep_size <= 0: + ep_size = int(os.getenv("WORLD_SIZE", "1")) + if ep_rank < 0: + ep_rank = int(os.getenv("RANK", "0")) + if ep_size <= 1: + return None + return expert_parallel_filter(num_experts, ep_size, ep_rank) diff --git a/fastsafetensors/file_buffer.py b/fastsafetensors/file_buffer.py index 519e4b4..fb94804 100644 --- a/fastsafetensors/file_buffer.py +++ b/fastsafetensors/file_buffer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from .common import init_logger from .frameworks import FrameworkOpBase, ProcessGroupBase, TensorBase @@ -28,6 +28,11 @@ class FilesBufferOnDevice: rank_loaders (Dict): Tensor factories per rank, which hold device pointers for buffers. pg (ProcessGroupBase): process group for calling distributed ops. auto_mem_delete (bool): automatically release device buffers when all the tensors are shuffled. + keep_tensor (Callable[[str], bool], optional): If set, only tensors for + which ``keep_tensor(name)`` is True are registered in ``key_to_rank_lidx``; + others raise ``ValueError`` from ``get_tensor`` / ``get_filename`` / + ``get_shape``. Subclasses that reimplement the registration loop must + honor this. Examples: See examples/run_single.py and examples/run_parallel.py. @@ -39,6 +44,7 @@ def __init__( pg: ProcessGroupBase, framework: FrameworkOpBase, auto_mem_delete: bool = True, + keep_tensor: Optional[Callable[[str], bool]] = None, ): self.framework = framework self.rank_loaders: Dict[int, List[LazyTensorFactory]] = rank_loaders @@ -48,6 +54,8 @@ def __init__( self.instantiated[rank] = {} for lidx, loader in enumerate(loaders): for key in loader.metadata.tensors.keys(): + if keep_tensor is not None and not keep_tensor(key): + continue if key in self.key_to_rank_lidx: raise Exception( f"FilesBufferOnDevice: key {key} must be unique among files" @@ -69,9 +77,7 @@ def close(self): self.rank_loaders = {} def get_filename(self, tensor_name: str) -> str: - if tensor_name not in self.key_to_rank_lidx: - return "" - rank, lidx = self.key_to_rank_lidx[tensor_name] + rank, lidx = self._get_rank_lidx(tensor_name) return self.rank_loaders[rank][lidx].metadata.src def get_shape(self, tensor_name: str) -> List[int]: diff --git a/fastsafetensors/loader.py b/fastsafetensors/loader.py index 1e8575e..20b7427 100644 --- a/fastsafetensors/loader.py +++ b/fastsafetensors/loader.py @@ -2,7 +2,17 @@ import math import platform -from typing import Any, Dict, List, Mapping, Optional, OrderedDict, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + OrderedDict, + Tuple, + Union, +) from . import cpp as fstcpp from .common import ( @@ -64,6 +74,7 @@ def __init__( self.meta: Dict[str, Tuple[SafeTensorsMetadata, int]] = {} self.frames = OrderedDict[str, TensorFrame]() self.disable_cache = disable_cache + self._tensor_filter: Optional[Callable[[str], bool]] = None self.init_numa(set_numa) self.copier_constructor: CopierConstructFunc = create_copier_constructor( copier_type=copier_type, @@ -88,11 +99,30 @@ def close(self): del self.copier_constructor def get_keys(self) -> List[str]: - return list(self.frames.keys()) + if self._tensor_filter is None: + return list(self.frames.keys()) + keep = self._tensor_filter + return [k for k in self.frames.keys() if keep(k)] def get_shape(self, tensor_name: str) -> List[int]: + if self._tensor_filter is not None and not self._tensor_filter(tensor_name): + raise ValueError(f"get_shape: key {tensor_name} is filtered out") return self.frames[tensor_name].shape + def set_tensor_filter(self, keep_tensor: Optional[Callable[[str], bool]]) -> None: + """Load only the tensors for which ``keep_tensor(name)`` is True. + + The ``nogds`` and ``unified`` copiers skip reading bytes for filtered + tensors; other copiers load the full file. The filter narrows the + public API on every backend: ``get_keys()`` omits filtered tensors, + ``FilesBufferOnDevice`` does not register them, and ``get_tensor``, + ``get_filename``, and ``get_shape`` raise ``ValueError`` for them. + ``ParallelLoader.iterate_weights()`` skips them. ``None`` (the + default) loads every tensor. See + ``fastsafetensors.ep_slice.expert_parallel_filter``. + """ + self._tensor_filter = keep_tensor + def add_filenames(self, filenames: Dict[int, List[str]]): """ Register files to ranks to be copied at copy_file_to_device(). @@ -144,6 +174,10 @@ def copy_files_to_device( self_rank = self.pg.rank() == rank if self_rank: copier = self.copier_constructor(meta, self.device, self.framework) + if self._tensor_filter is not None and hasattr( + copier, "set_byte_ranges" + ): + copier.set_byte_ranges(meta.select_byte_ranges(self._tensor_filter)) else: copier = None factory = LazyTensorFactory( @@ -164,7 +198,12 @@ def copy_files_to_device( lidx += 1 for factory in need_wait: factory.wait_io(dtype=dtype, noalign=False) - return FilesBufferOnDevice(factories, pg=self.pg, framework=self.framework) + return FilesBufferOnDevice( + factories, + pg=self.pg, + framework=self.framework, + keep_tensor=self._tensor_filter, + ) class SafeTensorsFileLoader(BaseSafeTensorsFileLoader): diff --git a/fastsafetensors/parallel_loader.py b/fastsafetensors/parallel_loader.py index ed49606..19f7306 100644 --- a/fastsafetensors/parallel_loader.py +++ b/fastsafetensors/parallel_loader.py @@ -4,7 +4,7 @@ import queue import threading import time -from typing import Any, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Generator, List, Optional, Tuple, Union import torch @@ -20,6 +20,7 @@ def tqdm(iterable, *args, **kwargs): from . import cpp as fstcpp +from .common import SingleGroup from .loader import BaseSafeTensorsFileLoader, SafeTensorsFileLoader @@ -139,6 +140,7 @@ def __init__( # >0 : buffered pipeline — up to (queue_size+1) batches in GPU mem queue_size: int = 0, use_tqdm_on_load: bool = True, + tensor_filter: Optional[Callable[[str], bool]] = None, **kwargs, ): @@ -149,6 +151,13 @@ def __init__( "batches must be processed in strict order across all ranks." ) self.loader = loader + # Read only the tensors this rank keeps (e.g. its owned experts); see + # SafeTensorsFileLoader.set_tensor_filter. NOTE: get_tensor broadcasts + # across the loader's process group, so a per-rank filter is only + # correct when the loader uses a single-process group (no broadcast) -- + # use ParallelLoader(all_local=True), which sets that up. + if tensor_filter is not None: + loader.set_tensor_filter(tensor_filter) self.hf_weights_files = hf_weights_files self.max_concurrent_producers = max_concurrent_producers self.queue_size = queue_size @@ -436,6 +445,18 @@ class ParallelLoader(PipelineParallel): set_numa (bool): If True, set NUMA node for optimal memory allocation. debug_log (bool): Enable debug logs. framework (str): Framework to use for tensor operations, e.g., "pytorch". + tensor_filter (Optional[Callable[[str], bool]]): If set, read only the + tensors for which the predicate is True (see + SafeTensorsFileLoader.set_tensor_filter and + fastsafetensors.ep_slice.expert_parallel_filter). Must + be paired with all_local=True. + all_local (bool): If True, the loader uses a single-process group so + every rank loads its files independently with no + cross-rank broadcast. Required when tensor_filter drops + tensors that other ranks would otherwise receive via + broadcast (e.g. expert-parallel slicing). The EP + rank/size for the filter come from the real + distributed world, independent of the loader's group. Additional GPU memory consumption: (max_concurrent_producers + queue_size) * file_size To reduce GPU memory consumption, re-accessing tensors that have already been accessed is prohibited. @@ -462,6 +483,8 @@ def __init__( set_numa: bool = True, debug_log: bool = False, framework="pytorch", + tensor_filter: Optional[Callable[[str], bool]] = None, + all_local: bool = False, **kwargs, ): """Initialize PipelineParallelLoader with a pre-configured SafeTensorsFileLoader. @@ -480,8 +503,13 @@ def __init__( debug_log (bool): Enable debug logs. framework (str): Framework to use for tensor operations. """ + # all_local: load with a single-process group so each rank reads its + # files independently (no cross-rank broadcast in get_tensor). This is + # what makes a per-rank tensor_filter correct -- otherwise get_tensor + # would broadcast tensors this rank never read. + loader_pg = SingleGroup() if all_local else pg loader = SafeTensorsFileLoader( - pg, + loader_pg, device, bbuf_size_kb=bbuf_size_kb, max_threads=max_threads, @@ -493,11 +521,12 @@ def __init__( **kwargs, ) super().__init__( - pg, + loader_pg, loader, hf_weights_files, max_concurrent_producers, queue_size, use_tqdm_on_load, + tensor_filter=tensor_filter, **kwargs, ) diff --git a/tests/unit/test_ep_slice.py b/tests/unit/test_ep_slice.py new file mode 100644 index 0000000..edb7a4e --- /dev/null +++ b/tests/unit/test_ep_slice.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the sub-file byte-range read primitive and the EP-slice demonstrator. + +The expert-range math is pure Python (no GPU / C extension needed). The +partial-read tests reuse the gpt2 fixture and the nogds copier to prove that +loading only a selected subset of tensors yields byte-identical data for the +kept tensors while skipping the rest. +""" +import pytest +import torch + +from fastsafetensors import SafeTensorsMetadata +from fastsafetensors import cpp as fstcpp +from fastsafetensors.copier.nogds import NoGdsFileCopier +from fastsafetensors.copier.unified import new_unified_copier +from fastsafetensors.ep_slice import ( + expert_parallel_filter, + owned_expert_range, +) + +# The unified copier (mmap → pin_memory → cudaMemcpyAsync) needs a CUDA device; +# skip its partial-read tests on CPU-only runners. +_requires_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="unified copier requires a CUDA device" +) + +# Reuse helpers from the main test module (tests/unit is on sys.path via conftest). +from test_fastsafetensors import get_and_check_device, load_safetensors_file + +# ---- pure-Python EP range math (contiguous-block "linear" assignment) ---- + + +def test_owned_expert_range_even(): + assert owned_expert_range(256, 2, 0) == (0, 128) + assert owned_expert_range(256, 2, 1) == (128, 256) + + +def test_owned_expert_range_remainder(): + # remainder goes to the lowest-numbered ranks + assert owned_expert_range(10, 3, 0) == (0, 4) + assert owned_expert_range(10, 3, 1) == (4, 7) + assert owned_expert_range(10, 3, 2) == (7, 10) + # whole owned set tiles the expert space with no gaps/overlaps + covered = [] + for r in range(4): + lo, hi = owned_expert_range(13, 4, r) + covered.extend(range(lo, hi)) + assert covered == list(range(13)) + + +def test_owned_expert_range_invalid(): + with pytest.raises(ValueError): + owned_expert_range(8, 0, 0) + with pytest.raises(ValueError): + owned_expert_range(8, 2, 2) + + +def test_expert_parallel_filter_keeps_nonexpert_and_owned(): + keep = expert_parallel_filter(num_experts=256, ep_size=2, ep_rank=0) + # non-expert tensors are kept on every rank + assert keep("model.embed_tokens.weight") is True + assert keep("model.layers.0.self_attn.q_proj.weight") is True + # owned vs unowned routed experts + assert keep("model.layers.0.mlp.experts.5.w1.weight") is True + assert keep("model.layers.0.mlp.experts.200.w1.weight") is False + # DeepSeek-style "ffn.experts" naming also matches the default pattern + assert keep("model.layers.3.ffn.experts.10.gate_proj.weight") is True + assert keep("model.layers.3.ffn.experts.130.gate_proj.weight") is False + + +# ---- byte-range selection + partial read (uses the gpt2 fixture) ---- + + +def _keep_every_other(meta: SafeTensorsMetadata): + """A non-EP predicate exercising the primitive on a model without experts: + keep every other tensor by sorted name (so kept tensors are non-adjacent and + produce multiple, non-mergeable runs).""" + kept = set(sorted(meta.tensors.keys())[::2]) + return lambda name: name in kept + + +def test_select_byte_ranges_all_equals_full(input_files, framework): + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + ranges = meta.select_byte_ranges(lambda name: True) + # contiguous tensors with no large gaps coalesce into one run that begins at + # the data section and never exceeds the file size + assert len(ranges) == 1 + assert ranges[0][0] == meta.header_length + assert ranges[0][1] <= meta.size_bytes + + +def test_select_byte_ranges_covers_only_kept(input_files, framework): + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + keep = _keep_every_other(meta) + ranges = meta.select_byte_ranges(keep) + # sorted, non-overlapping + for (a_lo, a_hi), (b_lo, b_hi) in zip(ranges, ranges[1:]): + assert a_hi <= b_lo + # every kept tensor is fully covered by some run + for name, fr in meta.tensors.items(): + if not keep(name): + continue + s = meta.header_length + fr.data_offsets[0] + e = meta.header_length + fr.data_offsets[1] + assert any(lo <= s and e <= hi for lo, hi in ranges), name + + +def test_nogds_partial_read_byte_identical(fstcpp_log, input_files, framework): + device, dev_is_gpu = get_and_check_device(framework) + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + keep = _keep_every_other(meta) + ranges = meta.select_byte_ranges(keep) + + reader = fstcpp.nogds_file_reader( + False, 256 * 1024, 4, dev_is_gpu, device.index or 0 + ) + copier = NoGdsFileCopier(meta, device, reader, framework) + copier.set_byte_ranges(ranges) + gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024) + tensors = copier.wait_io(gbuf) + + ref = load_safetensors_file(input_files[0], device, framework) + kept_names = [n for n in meta.tensors if keep(n)] + assert kept_names, "fixture should have at least one kept tensor" + for name in kept_names: + assert framework.is_equal(tensors[name], ref[name]), name + + framework.free_tensor_memory(gbuf, device) + del copier + del reader + assert framework.get_mem_used() == 0 + + +def test_nogds_full_read_unchanged(fstcpp_log, input_files, framework): + """set_byte_ranges(None) must reproduce the original full-file load exactly.""" + device, dev_is_gpu = get_and_check_device(framework) + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + reader = fstcpp.nogds_file_reader( + False, 256 * 1024, 4, dev_is_gpu, device.index or 0 + ) + copier = NoGdsFileCopier(meta, device, reader, framework) + copier.set_byte_ranges(None) # explicit default + gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024) + tensors = copier.wait_io(gbuf) + for key, exp in load_safetensors_file(input_files[0], device, framework).items(): + assert framework.is_equal(tensors[key], exp), key + framework.free_tensor_memory(gbuf, device) + del copier + del reader + assert framework.get_mem_used() == 0 + + +# ---- same partial-read guarantees for the unified-memory copier ---- + + +@_requires_cuda +def test_unified_partial_read_byte_identical(fstcpp_log, input_files, framework): + device, dev_is_gpu = get_and_check_device(framework) + if not dev_is_gpu: + pytest.skip("unified copier targets a GPU device") + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + keep = _keep_every_other(meta) + ranges = meta.select_byte_ranges(keep) + + # factory path loads the CUDA fn pointers (load_library_func); constructing + # UnifiedMemCopier directly would leave memcpy_h2d_async unbound. + copier = new_unified_copier(device)(meta, device, framework) + copier.set_byte_ranges(ranges) + gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024) + tensors = copier.wait_io(gbuf) + + ref = load_safetensors_file(input_files[0], device, framework) + kept_names = [n for n in meta.tensors if keep(n)] + assert kept_names, "fixture should have at least one kept tensor" + for name in kept_names: + assert framework.is_equal(tensors[name], ref[name]), name + + framework.free_tensor_memory(gbuf, device) + del copier + assert framework.get_mem_used() == 0 + + +@_requires_cuda +def test_unified_full_read_unchanged(fstcpp_log, input_files, framework): + """unified set_byte_ranges(None) must reproduce the full-file load exactly.""" + device, dev_is_gpu = get_and_check_device(framework) + if not dev_is_gpu: + pytest.skip("unified copier targets a GPU device") + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + copier = new_unified_copier(device)(meta, device, framework) + copier.set_byte_ranges(None) # explicit default + gbuf = copier.submit_io(False, 10 * 1024 * 1024 * 1024) + tensors = copier.wait_io(gbuf) + for key, exp in load_safetensors_file(input_files[0], device, framework).items(): + assert framework.is_equal(tensors[key], exp), key + framework.free_tensor_memory(gbuf, device) + del copier + assert framework.get_mem_used() == 0 diff --git a/tests/unit/test_fastsafetensors.py b/tests/unit/test_fastsafetensors.py index 4135220..ea022ab 100644 --- a/tests/unit/test_fastsafetensors.py +++ b/tests/unit/test_fastsafetensors.py @@ -7,9 +7,16 @@ import pytest -from fastsafetensors import SafeTensorsFileLoader, SafeTensorsMetadata, SingleGroup +from fastsafetensors import ( + ParallelLoader, + SafeTensorsFileLoader, + SafeTensorsMetadata, + SingleGroup, +) from fastsafetensors import cpp as fstcpp -from fastsafetensors import fastsafe_open +from fastsafetensors import ( + fastsafe_open, +) from fastsafetensors.common import get_device_numa_node, is_gpu_found from fastsafetensors.copier.gds import GdsFileCopier from fastsafetensors.copier.nogds import NoGdsFileCopier @@ -402,7 +409,7 @@ def fake_memcpy_h2d_async(dst, src, size): assert framework.is_equal(actual, exp) # Lifecycle: mmap + pinned references released in wait_io assert copier._file_tensor is None - assert copier._pinned is None + assert copier._pinned == [] framework.free_tensor_memory(gbuf, device) assert framework.get_mem_used() == 0 assert fstcpp.get_cpp_metrics().bounce_buffer_bytes == 0 @@ -427,7 +434,7 @@ def test_UnifiedMemCopier_cuda_error( # gbuf must be freed and mmap/pin refs released on error assert framework.get_mem_used() == 0 assert copier._file_tensor is None - assert copier._pinned is None + assert copier._pinned == [] @pytest.mark.parametrize( @@ -531,7 +538,8 @@ def test_SafeTensorsFileLoader(fstcpp_log, input_files, framework) -> None: assert bufs.get_filename(last_key) == input_files[0] assert bufs.get_shape(last_key) == last_shape assert loader.get_shape(last_key) == last_shape - assert bufs.get_filename("aaaaaaaaaaaaa") == "" + with pytest.raises(ValueError): + bufs.get_filename("aaaaaaaaaaaaa") bufs.close() loader.close() assert framework.get_mem_used() == 0 @@ -560,6 +568,62 @@ def test_SafeTensorsFileLoaderNoGds(fstcpp_log, input_files, framework) -> None: assert fstcpp.get_cpp_metrics().bounce_buffer_bytes == 0 +def test_tensor_filter_hides_skipped_tensors(fstcpp_log, input_files, framework): + device, _ = get_and_check_device(framework) + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + + kept = set(sorted(meta.tensors.keys())[::2]) + keep = lambda name: name in kept # noqa: E731 + skipped = next(name for name in meta.tensors if name not in kept) + + loader = SafeTensorsFileLoader( + pg=SingleGroup(), + device=device.as_str(), + framework=framework.get_name(), + nogds=True, + ) + loader.set_tensor_filter(keep) + loader.add_filenames({0: [input_files[0]]}) + bufs = loader.copy_files_to_device() + + assert set(loader.get_keys()) == kept + assert skipped not in bufs.key_to_rank_lidx + with pytest.raises(ValueError): + bufs.get_tensor(skipped) + with pytest.raises(ValueError): + bufs.get_filename(skipped) + with pytest.raises(ValueError): + loader.get_shape(skipped) + + bufs.close() + loader.close() + + +def test_tensor_filter_iterate_weights_hides_skipped( + fstcpp_log, input_files, framework +): + device, _ = get_and_check_device(framework) + meta = SafeTensorsMetadata.from_file(input_files[0], framework) + + kept = set(sorted(meta.tensors.keys())[::2]) + keep = lambda name: name in kept # noqa: E731 + + loader = ParallelLoader( + pg=SingleGroup(), + hf_weights_files=[input_files[0]], + device=device.as_str(), + nogds=True, + framework=framework.get_name(), + tensor_filter=keep, + all_local=True, + ) + yielded = {key for key, _t in loader.iterate_weights()} + assert yielded == kept + + loader.close() + assert framework.get_mem_used() == 0 + + def test_fastsafe_open(fstcpp_log, input_files, framework) -> None: device, _ = get_and_check_device(framework)