Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions examples/tgis_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ def close(self):
torch.cuda.empty_cache()

def _get_alias(self, tensor_name: str) -> str:
if self._fb.get_filename(tensor_name) is None:
if tensor_name in self.aliases:
for alias in self.aliases[tensor_name]:
if self._fb.get_filename(alias) is not None:
return alias
raise RuntimeError(f"weight {tensor_name} does not exist")
return tensor_name
if tensor_name in self._fb.key_to_rank_lidx:
return tensor_name
if tensor_name in self.aliases:
for alias in self.aliases[tensor_name]:
if alias in self._fb.key_to_rank_lidx:
return alias
raise RuntimeError(f"weight {tensor_name} does not exist")

def get_shape(self, tensor_name: str) -> torch.Size:
    """Return the shape of ``tensor_name`` as a ``torch.Size``.

    The name is first resolved through ``_get_alias`` and the resolved key
    is looked up in the file buffer.
    """
    resolved_name = self._get_alias(tensor_name)
    dims = self._fb.get_shape(resolved_name)
    return torch.Size(dims)
Expand Down
5 changes: 5 additions & 0 deletions fastsafetensors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
get_device_numa_node,
)
from .config import LoaderConfig, load_config
from .ep_slice import (
expert_parallel_filter,
expert_parallel_filter_from_env,
owned_expert_range,
)
from .file_buffer import FilesBufferOnDevice
from .loader import SafeTensorsFileLoader, fastsafe_open
from .parallel_loader import ParallelLoader
34 changes: 33 additions & 1 deletion fastsafetensors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple

from . import cpp as fstcpp
from .dlpack import from_cuda_buffer
Expand Down Expand Up @@ -359,6 +359,38 @@ def get_tensors(
ret[tensor_name] = t2
return ret

def select_byte_ranges(
    self, keep_tensor: Callable[[str], bool], merge_gap: int = 4096
) -> List[Tuple[int, int]]:
    """Compute the file byte-ranges covering only the kept tensors.

    Returns a sorted list of ``[start, end)`` absolute file offsets covering
    every non-empty tensor for which ``keep_tensor(name)`` is True. Kept
    tensors separated by a gap of at most ``merge_gap`` bytes are coalesced
    into one range to reduce the number of reads; the few non-kept bytes
    inside a coalesced range are read but never instantiated as tensors.

    Zero-length tensors contribute no bytes, so they never produce a range
    of their own and never bridge two otherwise distant ranges.

    Pass the result to a partial-read-capable copier (see
    ``NoGdsFileCopier.set_byte_ranges``) to load only a subset of a shard --
    e.g. only the experts an expert-parallel rank owns. Tensor data offsets
    are unchanged, so unread regions of the device buffer simply stay
    uninitialized and their tensors must not be requested.

    Args:
        keep_tensor: predicate called once per tensor name.
        merge_gap: largest gap, in bytes, that is still merged.
    """
    # data_offsets are relative to the data section; shift to file offsets.
    base = self.header_length
    spans = sorted(
        (base + frame.data_offsets[0], base + frame.data_offsets[1])
        for name, frame in self.tensors.items()
        if keep_tensor(name)
    )
    merged: List[Tuple[int, int]] = []
    for start, end in spans:
        if end <= start:
            # Nothing to read for an empty tensor.
            continue
        if merged and start - merged[-1][1] <= merge_gap:
            prev_start, prev_end = merged[-1]
            # max() keeps the range intact when spans overlap or nest.
            merged[-1] = (prev_start, max(prev_end, end))
        else:
            merged.append((start, end))
    return merged

def __repr__(self) -> str:
return str({"__metadata__": self.metadata, "tensors": self.tensors})

Expand Down
14 changes: 13 additions & 1 deletion fastsafetensors/copier/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Dict
from typing import Dict, List, Optional, Tuple

from .. import cpp as fstcpp
from ..frameworks import TensorBase
from ..st_types import DType


class CopierInterface(ABC):
def set_byte_ranges(self, byte_ranges: Optional[List[Tuple[int, int]]]) -> None:
    """Accept ``[start, end)`` absolute file-offset runs that limit reads.

    This base implementation deliberately does nothing with the runs: a
    copier that cannot do partial reads keeps reading the whole file, which
    makes the byte-range filter a correct no-op for it. Range-capable
    copiers (``nogds``, ``unified``) override this method so that only the
    given runs are read; the rest of the device buffer then stays
    uninitialized and the skipped tensors must not be requested. Runs are
    built with ``SafeTensorsMetadata.select_byte_ranges``; ``None`` means
    full read.
    """
    return None

@abstractmethod
def submit_io(
self, use_buf_register: bool, max_copy_block_size: int
Expand Down
46 changes: 32 additions & 14 deletions fastsafetensors/copier/nogds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import sys
from typing import Dict, List
from typing import Dict, List, Optional, Tuple

from .. import cpp as fstcpp
from ..common import SafeTensorsMetadata, is_gpu_found, resolve_cudart_lib_name
Expand Down Expand Up @@ -35,24 +35,42 @@ def __init__(
)
self.device = device
self.reqs: List[int] = []
self.byte_ranges: Optional[List[Tuple[int, int]]] = None

def set_byte_ranges(self, byte_ranges: Optional[List[Tuple[int, int]]]) -> None:
    """Remember which ``[start, end)`` absolute file-offset runs to read.

    ``submit_io`` reads only these runs. Everything outside them is never
    read, so those regions of the device buffer stay uninitialized and the
    tensors living there must not be requested. Passing ``None`` (the
    initial state) restores reading the whole data section. Runs are
    normally built with ``SafeTensorsMetadata.select_byte_ranges``.
    """
    self.byte_ranges = byte_ranges

def submit_io(
    self, use_buf_register: bool, max_copy_block_size: int
) -> fstcpp.gds_device_buffer:
    """Allocate the device buffer and submit the reads that fill it.

    The buffer always spans the whole data section (``size_bytes -
    header_length`` bytes), so a byte at file offset ``off`` lands at device
    offset ``off - header_length`` whether or not byte ranges are set. With
    ``self.byte_ranges`` unset, one run covering the whole data section is
    used, which reproduces the full-file read.

    Args:
        use_buf_register: not used by this method.
        max_copy_block_size: upper bound, in bytes, for a single read
            request; each run is split into chunks of at most this size.

    Returns:
        The allocated device buffer. The submitted request ids are appended
        to ``self.reqs``.

    Raises:
        ValueError: if a non-empty run lies outside the data section, or if
            ``max_copy_block_size`` is not positive while there is data to
            read.
        Exception: if the reader rejects a read request.
    """
    header_length = self.metadata.header_length
    size_bytes = self.metadata.size_bytes
    runs = self.byte_ranges
    if runs is None:
        runs = [(header_length, size_bytes)]

    # Validate before allocating so a bad run cannot leak the buffer or be
    # turned into an out-of-bounds device offset.
    for start, end in runs:
        if start >= end:
            continue  # empty run: nothing will be read for it
        if start < header_length or end > size_bytes:
            raise ValueError(
                f"submit_io: byte range [{start}, {end}) is outside the data "
                f"section [{header_length}, {size_bytes})"
            )
        if max_copy_block_size <= 0:
            # A non-positive chunk size would never advance the offset.
            raise ValueError(
                f"submit_io: max_copy_block_size must be positive, "
                f"got {max_copy_block_size}"
            )

    gbuf = self.framework.alloc_tensor_memory(size_bytes - header_length, self.device)
    for start, end in runs:
        offset = start
        while offset < end:
            length = min(max_copy_block_size, end - offset)
            req = self.reader.submit_read(
                self.fd, gbuf, offset, length, offset - header_length
            )
            if req < 0:
                raise Exception(f"submit_io: submit_nogds_read failed, err={req}")
            self.reqs.append(req)
            offset += length
    return gbuf

def wait_io(
Expand Down
62 changes: 41 additions & 21 deletions fastsafetensors/copier/unified.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

import os
from typing import Dict, Optional
from typing import Dict, List, Optional, Tuple

import torch

Expand Down Expand Up @@ -40,12 +40,25 @@ def __init__(
self.device = device
self.framework = framework
self._file_tensor: Optional[torch.Tensor] = None
self._pinned: Optional[torch.Tensor] = None
self._pinned: List[torch.Tensor] = []
self.byte_ranges: Optional[List[Tuple[int, int]]] = None

def set_byte_ranges(self, byte_ranges: Optional[List[Tuple[int, int]]]) -> None:
    """Remember which ``[start, end)`` absolute file-offset runs to copy.

    ``submit_io`` pins and copies only the bytes inside these runs, so the
    remainder of the device buffer stays uninitialized and the tensors
    located there must not be requested. Tensor offsets are unchanged.
    Passing ``None`` restores reading the whole data section. Runs are
    normally built with ``SafeTensorsMetadata.select_byte_ranges``.
    """
    self.byte_ranges = byte_ranges

def submit_io(
self, use_buf_register: bool, max_copy_block_size: int
) -> fstcpp.gds_device_buffer:
data_length = self.metadata.size_bytes - self.metadata.header_length
header_length = self.metadata.header_length
data_length = self.metadata.size_bytes - header_length

# Allocate CUDA buffer via framework's allocator (proper lifecycle)
gbuf = self.framework.alloc_tensor_memory(data_length, self.device)
Expand All @@ -55,25 +68,32 @@ def submit_io(
self.metadata.src, size=self.metadata.size_bytes, dtype=torch.uint8
)
self._file_tensor = file_tensor
data_tensor = file_tensor[self.metadata.header_length :]

# pin_memory triggers kernel readahead + pins pages for DMA
pinned = data_tensor.pin_memory()
self._pinned = pinned

# Async DMA from pinned CPU → framework-allocated CUDA buffer
ret = fstcpp.memcpy_h2d_async( # type: ignore[attr-defined]
gbuf.get_base_address(),
pinned.data_ptr(),
data_length,
)
if ret != 0:
self.framework.free_tensor_memory(gbuf, self.device)
self._pinned = None
self._file_tensor = None
raise RuntimeError(
f"cudaMemcpyAsync failed with error {ret} " f"for {self.metadata.src}"
# Default to the whole data section, reproducing the full-file read.
# An empty list (vs None) reads nothing — same semantics as nogds.
runs = self.byte_ranges
if runs is None:
runs = [(header_length, self.metadata.size_bytes)]

base_address = gbuf.get_base_address()
self._pinned = []
for start, end in runs:
# pin_memory faults in + pins only this run's pages, then DMA to the
# matching offset in gbuf (data section starts at header_length).
pinned = file_tensor[start:end].pin_memory()
self._pinned.append(pinned)
ret = fstcpp.memcpy_h2d_async( # type: ignore[attr-defined]
base_address + (start - header_length),
pinned.data_ptr(),
end - start,
)
if ret != 0:
self.framework.free_tensor_memory(gbuf, self.device)
self._pinned = []
self._file_tensor = None
raise RuntimeError(
f"cudaMemcpyAsync failed with error {ret} for {self.metadata.src}"
)

return gbuf

Expand All @@ -94,7 +114,7 @@ def wait_io(
)

# Release mmap and pinned memory
self._pinned = None
self._pinned = []
self._file_tensor = None

return tensors
Expand Down
114 changes: 114 additions & 0 deletions fastsafetensors/ep_slice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
"""Expert-parallel (EP) slice helpers.

Under expert parallelism each rank only *uses* the routed experts it owns, yet
file-granular loading makes every rank read the whole shard -- the unowned
experts' bytes are read and then discarded. These helpers build a tensor-name
predicate selecting just this rank's owned experts (plus every non-expert
tensor), so a partial-read-capable loader can skip the unowned bytes:

from fastsafetensors import SafeTensorsFileLoader
from fastsafetensors.ep_slice import expert_parallel_filter

loader = SafeTensorsFileLoader(pg, device, nogds=True)
loader.set_tensor_filter(expert_parallel_filter(num_experts=256,
ep_size=2, ep_rank=rank))
loader.add_filenames(...)
bufs = loader.copy_files_to_device()

Owned experts use contiguous-block ("linear") assignment: each rank owns
``num_experts // ep_size`` consecutive experts, with any remainder given to the
lowest-numbered ranks. This is a common expert-to-rank convention; the caller is
responsible for ensuring it matches the assignment its runtime expects. No
external dependency is required.
"""
import os
import re
from typing import Callable, Optional, Pattern, Tuple

# Default pattern locating the per-expert index in routed-MoE tensor names,
# e.g. "model.layers.3.mlp.experts.42.w1.weight" or DeepSeek's
# "...ffn.experts.42.gate_proj.weight". It looks for the dotted segment
# ".experts.<digits>." anywhere in the name and captures the digits as group 1;
# names without that segment do not match. Override for a different convention.
DEFAULT_EXPERT_PATTERN: Pattern[str] = re.compile(r"\.experts\.(\d+)\.")


def owned_expert_range(num_experts: int, ep_size: int, ep_rank: int) -> Tuple[int, int]:
    """Return the ``[lo, hi)`` routed-expert indices owned by ``ep_rank``.

    Contiguous-block ("linear") assignment: each rank owns a consecutive block
    of ``num_experts // ep_size`` experts, and the ``num_experts % ep_size``
    leftover experts go one each to the lowest-numbered ranks. A rank may own
    an empty range when there are fewer experts than ranks.

    Raises:
        ValueError: if ``num_experts`` is negative, ``ep_size`` is not
            positive, or ``ep_rank`` is outside ``[0, ep_size)``.
    """
    if num_experts < 0:
        # Floor division on a negative count would yield a reversed range.
        raise ValueError(f"num_experts must be non-negative, got {num_experts}")
    if ep_size <= 0:
        raise ValueError(f"ep_size must be positive, got {ep_size}")
    if not 0 <= ep_rank < ep_size:
        raise ValueError(f"ep_rank {ep_rank} out of range for ep_size {ep_size}")
    base, remainder = divmod(num_experts, ep_size)
    # Every rank below this one owns ``base`` experts, plus one extra for each
    # of those ranks that received a leftover expert.
    start = ep_rank * base + min(ep_rank, remainder)
    owned = base + (1 if ep_rank < remainder else 0)
    return (start, start + owned)


def expert_parallel_filter(
    num_experts: int,
    ep_size: int,
    ep_rank: int,
    pattern: Pattern[str] = DEFAULT_EXPERT_PATTERN,
) -> Callable[[str], bool]:
    """Build a ``keep(name) -> bool`` predicate for this EP rank.

    A tensor whose name does not match ``pattern`` is treated as a non-expert
    tensor and kept on every rank. A name that matches is kept only when the
    expert index captured by group 1 of ``pattern`` falls inside the range
    ``owned_expert_range(num_experts, ep_size, ep_rank)``. The predicate can be
    passed to ``SafeTensorsFileLoader.set_tensor_filter`` or
    ``SafeTensorsMetadata.select_byte_ranges``.
    """
    first, stop = owned_expert_range(num_experts, ep_size, ep_rank)
    owned = range(first, stop)

    def keep(name: str) -> bool:
        match = pattern.search(name)
        return match is None or int(match.group(1)) in owned

    return keep


def expert_parallel_filter_from_env() -> Optional[Callable[[str], bool]]:
    """Build an EP filter from environment variables, or ``None`` if disabled.

    Recognized variables (kept compatible with the DGX Spark overlay this
    prototype generalizes):

      ``FASTSAFETENSORS_EP_SLICE=1``        enable EP-slice reading
      ``FASTSAFETENSORS_EP_NUM_EXPERTS=N``  global routed-expert count (required)
      ``FASTSAFETENSORS_EP_SIZE`` / ``_RANK``  override EP size/rank

    Size and rank are resolved independently of each other: an explicit
    override always wins; a value that is not overridden is taken from the
    initialized torch.distributed group, and failing that from
    ``WORLD_SIZE`` / ``RANK``.

    Returns ``None`` (load everything) unless EP-slice is enabled, the expert
    count is known, and ``ep_size > 1``.

    Raises:
        ValueError: if a variable holds a non-integer value, or if the
            resolved rank is not below the resolved size.
    """
    if os.getenv("FASTSAFETENSORS_EP_SLICE", "0") != "1":
        return None
    num_experts = int(os.getenv("FASTSAFETENSORS_EP_NUM_EXPERTS", "0"))
    if num_experts <= 0:
        return None

    # Non-positive size / negative rank mean "not overridden".
    ep_size = int(os.getenv("FASTSAFETENSORS_EP_SIZE", "0"))
    ep_rank = int(os.getenv("FASTSAFETENSORS_EP_RANK", "-1"))
    if ep_size <= 0 or ep_rank < 0:
        try:
            import torch.distributed as dist

            if dist.is_available() and dist.is_initialized():
                # Fill in only what is missing so that a single explicit
                # override is not replaced by the process-group value.
                if ep_size <= 0:
                    ep_size = dist.get_world_size()
                if ep_rank < 0:
                    ep_rank = dist.get_rank()
        except Exception:
            # Best effort: torch may be missing or unusable here; the
            # WORLD_SIZE / RANK fallback below still applies.
            pass
        if ep_size <= 0:
            ep_size = int(os.getenv("WORLD_SIZE", "1"))
        if ep_rank < 0:
            ep_rank = int(os.getenv("RANK", "0"))
    if ep_size <= 1:
        return None
    return expert_parallel_filter(num_experts, ep_size, ep_rank)
Loading
Loading