Skip to content

Commit eb9d8af

Browse files
committed
Add caller-side byte-range selection (sub-file reads) + EP-slice demonstrator
Prototype for #71: let a caller restrict which byte ranges of a shard are actually read, so an expert-parallel rank can read only the experts it owns instead of reading the whole file and discarding the rest. - SafeTensorsMetadata.select_byte_ranges(keep_tensor): merged [start,end) file byte-ranges covering exactly the kept tensors (adjacent kept tensors are coalesced within a page to limit the number of reads). - CopierInterface.set_byte_ranges(): read only those runs; the device buffer is still allocated for the full data section so tensor offsets are unchanged (byte_ranges=None reproduces the original full-file read). Implemented for the nogds and unified copiers; gds/threefs inherit the base no-op (full read). - SafeTensorsFileLoader.set_tensor_filter(keep_tensor): wire a keep(name) predicate through copy_files_to_device to the copier. - ParallelLoader/PipelineParallel: tensor_filter= passthrough, and all_local= on ParallelLoader (loads via a single-process group so a per-rank filter is not broken by get_tensor's cross-rank broadcast). - fastsafetensors.ep_slice: owned_expert_range / expert_parallel_filter (contiguous-block "linear" expert assignment; no external dependency) plus expert_parallel_filter_from_env. - tests/unit/test_ep_slice.py: range math, select_byte_ranges coverage, and byte-identical partial-read + full-read-unchanged regressions for both the nogds and unified copiers (the unified pair is GPU-gated). Motivation/measurements: on 2x DGX Spark (GB10), reading only owned experts cut DeepSeek-V4-Flash weight load ~17.6s -> ~13.0s under EP=2 (~half the bytes per rank). This is the I/O-side counterpart of the memory-side sub-file batching in this issue; both want the same byte-range read primitive. Verified byte-identical on stock 0.3.2 (GB10): copier partial read, a 2-rank gloo ParallelLoader all-local run, and a real DeepSeek-V4-Flash shard (256 experts, EP=2) where both the nogds and unified copiers load the rank's owned experts byte-identical while reading ~52% of the bytes (no broadcast). Signed-off-by: git bisector <gitbisector@gmail.com>
1 parent 7bae3af commit eb9d8af

10 files changed

Lines changed: 501 additions & 43 deletions

File tree

fastsafetensors/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
get_device_numa_node,
1313
)
1414
from .config import LoaderConfig, load_config
15+
from .ep_slice import (
16+
expert_parallel_filter,
17+
expert_parallel_filter_from_env,
18+
owned_expert_range,
19+
)
1520
from .file_buffer import FilesBufferOnDevice
1621
from .loader import SafeTensorsFileLoader, fastsafe_open
1722
from .parallel_loader import ParallelLoader

fastsafetensors/common.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import sys
77
from collections import OrderedDict
88
from dataclasses import dataclass
9-
from typing import Dict, List, Optional, Tuple
9+
from typing import Callable, Dict, List, Optional, Tuple
1010

1111
from . import cpp as fstcpp
1212
from .dlpack import from_cuda_buffer
@@ -359,6 +359,38 @@ def get_tensors(
359359
ret[tensor_name] = t2
360360
return ret
361361

362+
def select_byte_ranges(
363+
self, keep_tensor: Callable[[str], bool], merge_gap: int = 4096
364+
) -> List[Tuple[int, int]]:
365+
"""Compute the file byte-ranges covering only the kept tensors.
366+
367+
Returns a sorted list of ``[start, end)`` absolute file offsets spanning
368+
exactly the tensors for which ``keep_tensor(name)`` is True. Kept tensors
369+
separated by a gap of at most ``merge_gap`` bytes are coalesced into one
370+
range to reduce the number of reads; the few non-kept bytes inside a
371+
coalesced range are read but never instantiated as tensors.
372+
373+
Pass the result to a partial-read-capable copier (see
374+
``NoGdsFileCopier.set_byte_ranges``) to load only a subset of a shard --
375+
e.g. only the experts an expert-parallel rank owns. Tensor data offsets
376+
are unchanged, so unread regions of the device buffer simply stay
377+
uninitialized and their tensors must not be requested.
378+
"""
379+
ranges: List[Tuple[int, int]] = []
380+
for name, frame in self.tensors.items():
381+
if not keep_tensor(name):
382+
continue
383+
s, e = frame.data_offsets[0], frame.data_offsets[1]
384+
ranges.append((self.header_length + s, self.header_length + e))
385+
ranges.sort()
386+
merged: List[List[int]] = []
387+
for s, e in ranges:
388+
if merged and s - merged[-1][1] <= merge_gap:
389+
merged[-1][1] = max(merged[-1][1], e)
390+
else:
391+
merged.append([s, e])
392+
return [(s, e) for s, e in merged]
393+
362394
def __repr__(self) -> str:
363395
return str({"__metadata__": self.metadata, "tensors": self.tensors})
364396

fastsafetensors/copier/base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
from abc import ABC, abstractmethod
4-
from typing import Dict
4+
from typing import Dict, List, Optional, Tuple
55

66
from .. import cpp as fstcpp
77
from ..frameworks import TensorBase
88
from ..st_types import DType
99

1010

1111
class CopierInterface(ABC):
12+
def set_byte_ranges(self, byte_ranges: Optional[List[Tuple[int, int]]]) -> None:
13+
"""Restrict reads to these ``[start, end)`` absolute file-offset runs.
14+
15+
The default implementation ignores the runs and reads the whole file, so
16+
the byte-range filter is a correct no-op on copiers that don't implement
17+
partial reads. Range-capable copiers (``nogds``, ``unified``) override
18+
this to read only the given runs, leaving the rest of the device buffer
19+
uninitialized (so skipped tensors must not be requested). Build runs with
20+
``SafeTensorsMetadata.select_byte_ranges``; ``None`` means full read.
21+
"""
22+
return
23+
1224
@abstractmethod
1325
def submit_io(
1426
self, use_buf_register: bool, max_copy_block_size: int

fastsafetensors/copier/nogds.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
import sys
5-
from typing import Dict, List
5+
from typing import Dict, List, Optional, Tuple
66

77
from .. import cpp as fstcpp
88
from ..common import SafeTensorsMetadata, is_gpu_found, resolve_cudart_lib_name
@@ -35,24 +35,42 @@ def __init__(
3535
)
3636
self.device = device
3737
self.reqs: List[int] = []
38+
self.byte_ranges: Optional[List[Tuple[int, int]]] = None
39+
40+
def set_byte_ranges(self, byte_ranges: Optional[List[Tuple[int, int]]]) -> None:
41+
"""Restrict reads to these ``[start, end)`` absolute file-offset runs.
42+
43+
Bytes outside the given runs are not read; their regions of the device
44+
buffer are left uninitialized, so the corresponding tensors must not be
45+
requested. ``None`` (the default) reads the whole data section. Build
46+
runs with ``SafeTensorsMetadata.select_byte_ranges``.
47+
"""
48+
self.byte_ranges = byte_ranges
3849

3950
def submit_io(
4051
self, use_buf_register: bool, max_copy_block_size: int
4152
) -> fstcpp.gds_device_buffer:
42-
total_length = self.metadata.size_bytes - self.metadata.header_length
53+
header_length = self.metadata.header_length
54+
total_length = self.metadata.size_bytes - header_length
4355
gbuf = self.framework.alloc_tensor_memory(total_length, self.device)
44-
count = 0
45-
while count < total_length:
46-
l = total_length - count
47-
if max_copy_block_size < l:
48-
l = max_copy_block_size
49-
req = self.reader.submit_read(
50-
self.fd, gbuf, self.metadata.header_length + count, l, count
51-
)
52-
if req < 0:
53-
raise Exception(f"submit_io: submit_nogds_read failed, err={req}")
54-
self.reqs.append(req)
55-
count += l
56+
# Default to a single run spanning the whole data section, which
57+
# reproduces the original full-file read.
58+
runs = self.byte_ranges
59+
if runs is None:
60+
runs = [(header_length, self.metadata.size_bytes)]
61+
for start, end in runs:
62+
count = start
63+
while count < end:
64+
l = end - count
65+
if max_copy_block_size < l:
66+
l = max_copy_block_size
67+
req = self.reader.submit_read(
68+
self.fd, gbuf, count, l, count - header_length
69+
)
70+
if req < 0:
71+
raise Exception(f"submit_io: submit_nogds_read failed, err={req}")
72+
self.reqs.append(req)
73+
count += l
5674
return gbuf
5775

5876
def wait_io(

fastsafetensors/copier/unified.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
import os
12-
from typing import Dict, Optional
12+
from typing import Dict, List, Optional, Tuple
1313

1414
import torch
1515

@@ -40,12 +40,25 @@ def __init__(
4040
self.device = device
4141
self.framework = framework
4242
self._file_tensor: Optional[torch.Tensor] = None
43-
self._pinned: Optional[torch.Tensor] = None
43+
self._pinned: List[torch.Tensor] = []
44+
self.byte_ranges: Optional[List[Tuple[int, int]]] = None
45+
46+
def set_byte_ranges(self, byte_ranges: Optional[List[Tuple[int, int]]]) -> None:
47+
"""Restrict reads to these ``[start, end)`` absolute file-offset runs.
48+
49+
Only the bytes in the given runs are mmap-faulted, pinned, and copied;
50+
the rest of the device buffer is left uninitialized (so the corresponding
51+
tensors must not be requested). Tensor offsets are unchanged. ``None``
52+
reads the whole data section. Build runs with
53+
``SafeTensorsMetadata.select_byte_ranges``.
54+
"""
55+
self.byte_ranges = byte_ranges
4456

4557
def submit_io(
4658
self, use_buf_register: bool, max_copy_block_size: int
4759
) -> fstcpp.gds_device_buffer:
48-
data_length = self.metadata.size_bytes - self.metadata.header_length
60+
header_length = self.metadata.header_length
61+
data_length = self.metadata.size_bytes - header_length
4962

5063
# Allocate CUDA buffer via framework's allocator (proper lifecycle)
5164
gbuf = self.framework.alloc_tensor_memory(data_length, self.device)
@@ -55,25 +68,32 @@ def submit_io(
5568
self.metadata.src, size=self.metadata.size_bytes, dtype=torch.uint8
5669
)
5770
self._file_tensor = file_tensor
58-
data_tensor = file_tensor[self.metadata.header_length :]
5971

60-
# pin_memory triggers kernel readahead + pins pages for DMA
61-
pinned = data_tensor.pin_memory()
62-
self._pinned = pinned
63-
64-
# Async DMA from pinned CPU → framework-allocated CUDA buffer
65-
ret = fstcpp.memcpy_h2d_async( # type: ignore[attr-defined]
66-
gbuf.get_base_address(),
67-
pinned.data_ptr(),
68-
data_length,
69-
)
70-
if ret != 0:
71-
self.framework.free_tensor_memory(gbuf, self.device)
72-
self._pinned = None
73-
self._file_tensor = None
74-
raise RuntimeError(
75-
f"cudaMemcpyAsync failed with error {ret} " f"for {self.metadata.src}"
72+
# Default to the whole data section, reproducing the full-file read.
73+
# An empty list (vs None) reads nothing — same semantics as nogds.
74+
runs = self.byte_ranges
75+
if runs is None:
76+
runs = [(header_length, self.metadata.size_bytes)]
77+
78+
base_address = gbuf.get_base_address()
79+
self._pinned = []
80+
for start, end in runs:
81+
# pin_memory faults in + pins only this run's pages, then DMA to the
82+
# matching offset in gbuf (data section starts at header_length).
83+
pinned = file_tensor[start:end].pin_memory()
84+
self._pinned.append(pinned)
85+
ret = fstcpp.memcpy_h2d_async( # type: ignore[attr-defined]
86+
base_address + (start - header_length),
87+
pinned.data_ptr(),
88+
end - start,
7689
)
90+
if ret != 0:
91+
self.framework.free_tensor_memory(gbuf, self.device)
92+
self._pinned = []
93+
self._file_tensor = None
94+
raise RuntimeError(
95+
f"cudaMemcpyAsync failed with error {ret} for {self.metadata.src}"
96+
)
7797

7898
return gbuf
7999

@@ -94,7 +114,7 @@ def wait_io(
94114
)
95115

96116
# Release mmap and pinned memory
97-
self._pinned = None
117+
self._pinned = []
98118
self._file_tensor = None
99119

100120
return tensors

fastsafetensors/ep_slice.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Expert-parallel (EP) slice helpers.
3+
4+
Under expert parallelism each rank only *uses* the routed experts it owns, yet
5+
file-granular loading makes every rank read the whole shard -- the unowned
6+
experts' bytes are read and then discarded. These helpers build a tensor-name
7+
predicate selecting just this rank's owned experts (plus every non-expert
8+
tensor), so a partial-read-capable loader can skip the unowned bytes:
9+
10+
from fastsafetensors import SafeTensorsFileLoader
11+
from fastsafetensors.ep_slice import expert_parallel_filter
12+
13+
loader = SafeTensorsFileLoader(pg, device, nogds=True)
14+
loader.set_tensor_filter(expert_parallel_filter(num_experts=256,
15+
ep_size=2, ep_rank=rank))
16+
loader.add_filenames(...)
17+
bufs = loader.copy_files_to_device()
18+
19+
Owned experts use contiguous-block ("linear") assignment: each rank owns
20+
``num_experts // ep_size`` consecutive experts, with any remainder given to the
21+
lowest-numbered ranks. This is a common expert-to-rank convention; the caller is
22+
responsible for ensuring it matches the assignment its runtime expects. No
23+
external dependency is required.
24+
"""
25+
import os
26+
import re
27+
from typing import Callable, Optional, Pattern, Tuple
28+
29+
# Matches the per-expert index in routed-MoE tensor names, e.g.
30+
# "model.layers.3.mlp.experts.42.w1.weight" or DeepSeek's
31+
# "...ffn.experts.42.gate_proj.weight". Override for a different convention.
32+
DEFAULT_EXPERT_PATTERN: Pattern[str] = re.compile(r"\.experts\.(\d+)\.")
33+
34+
35+
def owned_expert_range(num_experts: int, ep_size: int, ep_rank: int) -> Tuple[int, int]:
36+
"""Return the ``[lo, hi)`` routed-expert indices owned by ``ep_rank``.
37+
38+
Contiguous-block ("linear") assignment: each rank owns a consecutive block
39+
of experts, with the remainder distributed to the lowest-numbered ranks.
40+
"""
41+
if ep_size <= 0:
42+
raise ValueError(f"ep_size must be positive, got {ep_size}")
43+
if not 0 <= ep_rank < ep_size:
44+
raise ValueError(f"ep_rank {ep_rank} out of range for ep_size {ep_size}")
45+
base = num_experts // ep_size
46+
rem = num_experts % ep_size
47+
local = base + (1 if ep_rank < rem else 0)
48+
start = ep_rank * base + min(ep_rank, rem)
49+
return (start, start + local)
50+
51+
52+
def expert_parallel_filter(
53+
num_experts: int,
54+
ep_size: int,
55+
ep_rank: int,
56+
pattern: Pattern[str] = DEFAULT_EXPERT_PATTERN,
57+
) -> Callable[[str], bool]:
58+
"""Build a ``keep(name) -> bool`` predicate for this EP rank.
59+
60+
Non-expert tensors (names not matching ``pattern``) are kept on every rank;
61+
routed-expert tensors are kept only when their index is in this rank's owned
62+
range. Pass the predicate to ``SafeTensorsFileLoader.set_tensor_filter`` or
63+
``SafeTensorsMetadata.select_byte_ranges``.
64+
"""
65+
lo, hi = owned_expert_range(num_experts, ep_size, ep_rank)
66+
67+
def keep(name: str) -> bool:
68+
m = pattern.search(name)
69+
if m is None:
70+
return True
71+
return lo <= int(m.group(1)) < hi
72+
73+
return keep
74+
75+
76+
def expert_parallel_filter_from_env() -> Optional[Callable[[str], bool]]:
77+
"""Build an EP filter from environment variables, or ``None`` if disabled.
78+
79+
Recognized variables (kept compatible with the DGX Spark overlay this
80+
prototype generalizes):
81+
82+
``FASTSAFETENSORS_EP_SLICE=1`` enable EP-slice reading
83+
``FASTSAFETENSORS_EP_NUM_EXPERTS=N`` global routed-expert count (required)
84+
``FASTSAFETENSORS_EP_SIZE`` / ``_RANK`` override EP size/rank; otherwise
85+
taken from the initialized
86+
torch.distributed group, else from
87+
``WORLD_SIZE`` / ``RANK``.
88+
89+
Returns ``None`` (load everything) unless EP-slice is enabled, the expert
90+
count is known, and ``ep_size > 1``.
91+
"""
92+
if os.getenv("FASTSAFETENSORS_EP_SLICE", "0") != "1":
93+
return None
94+
num_experts = int(os.getenv("FASTSAFETENSORS_EP_NUM_EXPERTS", "0"))
95+
if num_experts <= 0:
96+
return None
97+
ep_size = int(os.getenv("FASTSAFETENSORS_EP_SIZE", "0"))
98+
ep_rank = int(os.getenv("FASTSAFETENSORS_EP_RANK", "-1"))
99+
if ep_size <= 0 or ep_rank < 0:
100+
try:
101+
import torch.distributed as dist
102+
103+
if dist.is_available() and dist.is_initialized():
104+
ep_size = dist.get_world_size()
105+
ep_rank = dist.get_rank()
106+
except Exception:
107+
pass
108+
if ep_size <= 0:
109+
ep_size = int(os.getenv("WORLD_SIZE", "1"))
110+
if ep_rank < 0:
111+
ep_rank = int(os.getenv("RANK", "0"))
112+
if ep_size <= 1:
113+
return None
114+
return expert_parallel_filter(num_experts, ep_size, ep_rank)

0 commit comments

Comments
 (0)