Skip to content

Commit b7fb848

Browse files
authored
feat: add consecutive batch shard sampler for pytorch (#3886)
Signed-off-by: jukejian <jukejian@bytedance.com>
1 parent 621c791 commit b7fb848

2 files changed

Lines changed: 349 additions & 2 deletions

File tree

python/python/lance/sampler.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,10 @@ def __init__(
356356
self._world_size = world_size
357357
self._randomize = randomize
358358
self._seed = seed
359+
self._epoch = 0
360+
361+
def set_epoch(self, epoch: int):
362+
self._epoch = epoch
359363

360364
@staticmethod
361365
def from_torch(randomize: bool = False, seed: int = 0) -> ShardedFragmentSampler:
@@ -399,6 +403,7 @@ class ShardedBatchSampler(Sampler):
399403
not assigned to it. The resulting stream is then randomized via a reservoir
400404
sampler. This does not perfectly randomize the stream but it should generate
401405
a stream that is random enough for many use cases.
406+
402407
"""
403408

404409
def __init__(
@@ -408,6 +413,13 @@ def __init__(
408413
self._world_size = world_size
409414
self._randomize = randomize
410415
self._seed = seed
416+
self._epoch = 0
417+
418+
def __len__(self):
419+
return self._len
420+
421+
def set_epoch(self, epoch: int):
422+
self._epoch = epoch
411423

412424
@staticmethod
413425
def from_torch(randomize: bool = False, seed: int = 0) -> ShardedBatchSampler:
@@ -488,7 +500,7 @@ def _sample_filtered(
488500
if not self._randomize:
489501
yield from shard_scan
490502

491-
random.seed(self._seed)
503+
random.seed(self._seed + self._epoch)
492504
heap = []
493505
# We want to randomize the incoming sequence. The normal approach
494506
# is to pull the whole thing in memory and run fisher-yates. We
@@ -563,3 +575,96 @@ def __call__(
563575
return self._sample_filtered(
564576
dataset, batch_size, columns, batch_readahead, filter
565577
)
578+
579+
580+
class ShardedFixedBatchSampler(ShardedBatchSampler):
581+
"""
582+
Sharded fixed batch sampler for distributed index-based batching.
583+
584+
This sampler is designed for static datasets with a known total number of rows.
585+
It divides the dataset into consecutive index ranges (batches) and assigns each
586+
process (rank) a unique subset of these batches for efficient distributed loading.
587+
588+
Features:
589+
- Requires `total_num_rows` and `batch_size` to be specified.
590+
- Each rank receives consecutive, non-overlapping index ranges.
591+
- Optionally randomizes the order of batches per epoch if `randomize=True`.
592+
- Suitable for integration with PyTorch DataLoader or similar frameworks.
593+
594+
Example (total_num_rows=1000, world_size=4, batch_size=100):
595+
- Rank 0: [0-99], [100-199], [200-299]
596+
- Rank 1: [250-349], [350-449], [450-549]
597+
- Rank 2: [500-599], [600-699], [700-799]
598+
- Rank 3: [750-849], [850-949], [950-999]
599+
600+
Parameters
601+
----------
602+
rank : int
603+
The rank (process index) in the distributed cluster.
604+
world_size : int
605+
The total number of processes in the distributed cluster.
606+
randomize : bool, default False
607+
Whether to randomize the order of batches for each epoch.
608+
seed : int, default 0
609+
Random seed for reproducibility when randomize is enabled.
610+
batch_size : int, default 0
611+
The number of rows per batch.
612+
total_num_rows : int, default 0
613+
The total number of rows in the dataset.
614+
"""
615+
616+
def __init__(
617+
self,
618+
rank: int,
619+
world_size: int,
620+
randomize: bool = False,
621+
seed: int = 0,
622+
batch_size: int = 0,
623+
total_num_rows: int = 0,
624+
):
625+
super().__init__(rank, world_size, randomize, seed)
626+
self._total_num_rows = total_num_rows
627+
self._batch_size = batch_size
628+
self._len = self._compute_length()
629+
630+
# The sampler here is mainly implemented with the hope that
631+
# the data of batch_size are all adjacent, so we don't want
632+
# to use filter to break this adjacent feature.
633+
def _compute_length(self):
634+
if self._batch_size == 0 and self._total_num_rows == 0:
635+
return 0
636+
per_rank = math.ceil(self._total_num_rows / self._world_size)
637+
return math.ceil(per_rank / self._batch_size)
638+
639+
def __len__(self):
640+
return self._len
641+
642+
def __iter__(self) -> Generator[List[int], None, None]:
643+
per_rank = math.ceil(self._total_num_rows / self._world_size)
644+
start = self._rank * per_rank
645+
end = min(start + per_rank, self._total_num_rows)
646+
647+
batches = []
648+
current = start
649+
while current < end:
650+
batch_end = min(current + self._batch_size, end)
651+
batches.append(list(range(current, batch_end)))
652+
current = batch_end
653+
654+
if self._randomize:
655+
random.seed(self._seed + self._epoch)
656+
random.shuffle(batches)
657+
658+
yield from batches
659+
660+
@staticmethod
661+
def from_torch(
662+
total_num_rows: int, batch_size: int, randomize: bool = False, seed: int = 0
663+
) -> ShardedFixedBatchSampler:
664+
import torch
665+
666+
rank = torch.distributed.get_rank()
667+
world_size = torch.distributed.get_world_size()
668+
return ShardedFixedBatchSampler(
669+
rank, world_size, total_num_rows, batch_size, randomize, seed
670+
)

python/python/tests/test_sampler.py

Lines changed: 243 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,249 @@
77
import numpy as np
88
import pyarrow as pa
99
import pytest
10-
from lance.sampler import maybe_sample
10+
from lance.sampler import ShardedBatchSampler, ShardedFixedBatchSampler, maybe_sample
11+
12+
TEST_CONFIG = {
13+
"total_rows": 1000,
14+
"batch_size": 250,
15+
"world_size": 4,
16+
"vec_dim": 32,
17+
"test_port": "29501",
18+
"master_addr": "127.0.0.1",
19+
"seed": 42,
20+
"test_shard_ratio": 0.5,
21+
"max_takes_factor": 0.1,
22+
}
23+
24+
25+
@pytest.fixture
26+
def sample_dataset_path(tmp_path):
27+
data = pa.Table.from_arrays(
28+
[
29+
pa.array(range(TEST_CONFIG["total_rows"])),
30+
pa.array(np.random.rand(TEST_CONFIG["total_rows"])),
31+
pa.array([f"text_{i}" for i in range(TEST_CONFIG["total_rows"])]),
32+
],
33+
names=["id", "value", "text"],
34+
)
35+
36+
dataset_path = tmp_path / "test_dataset.lance"
37+
lance.write_dataset(data, dataset_path)
38+
return dataset_path
39+
40+
41+
@pytest.fixture
42+
def sample_dataset(sample_dataset_path) -> lance.LanceDataset:
43+
return lance.dataset(sample_dataset_path)
44+
45+
46+
def test_consecutive_index_blocks():
47+
sampler = ShardedFixedBatchSampler(
48+
rank=0,
49+
world_size=TEST_CONFIG["world_size"],
50+
total_num_rows=TEST_CONFIG["total_rows"],
51+
batch_size=TEST_CONFIG["batch_size"],
52+
)
53+
54+
batches = list(sampler)
55+
expected_size = TEST_CONFIG["total_rows"] // (
56+
TEST_CONFIG["world_size"] * TEST_CONFIG["batch_size"]
57+
)
58+
assert len(batches) == expected_size
59+
assert batches[0] == list(range(TEST_CONFIG["batch_size"]))
60+
61+
62+
def _distributed_test_worker(rank, world_size, dataset_path):
63+
import os
64+
65+
import torch
66+
67+
os.environ.update(
68+
{
69+
"MASTER_ADDR": TEST_CONFIG["master_addr"],
70+
"MASTER_PORT": TEST_CONFIG["test_port"],
71+
"CUDA_VISIBLE_DEVICES": ",".join(
72+
map(str, range(torch.cuda.device_count()))
73+
),
74+
}
75+
)
76+
77+
try:
78+
if torch.cuda.is_available():
79+
torch.cuda.set_device(rank % torch.cuda.device_count())
80+
81+
backend = "nccl" if torch.cuda.is_available() else "gloo"
82+
torch.distributed.init_process_group(
83+
backend=backend, world_size=world_size, rank=rank
84+
)
85+
86+
dataset = lance.dataset(dataset_path)
87+
assert len(dataset) == TEST_CONFIG["total_rows"]
88+
89+
sampler = ShardedBatchSampler(
90+
rank=rank,
91+
world_size=world_size,
92+
total_num_rows=TEST_CONFIG["total_rows"],
93+
batch_size=TEST_CONFIG["batch_size"],
94+
)
95+
96+
class DatasetAdapter(torch.utils.data.Dataset):
97+
def __init__(self, dataset):
98+
self.dataset = dataset
99+
100+
def __getitem__(self, index):
101+
return self.dataset.take([index], ["id", "value"]).to_pylist()[0]
102+
103+
def __len__(self):
104+
return len(self.dataset)
105+
106+
def collate_fn(batch):
107+
return {
108+
"ids": torch.tensor([x["id"] for x in batch], dtype=torch.long),
109+
"values": torch.tensor(
110+
[x["value"] for x in batch], dtype=torch.float32
111+
),
112+
}
113+
114+
dataloader = torch.utils.data.DataLoader(
115+
DatasetAdapter(dataset),
116+
batch_sampler=sampler,
117+
collate_fn=collate_fn,
118+
num_workers=0,
119+
)
120+
121+
total = 0
122+
for batch_indices, batch_data in zip(sampler, dataloader):
123+
current_size = batch_data["ids"].size(0)
124+
assert current_size == TEST_CONFIG["batch_size"]
125+
assert batch_data["ids"].tolist() == list(batch_indices)
126+
total += current_size
127+
128+
expected_total = TEST_CONFIG["total_rows"] // world_size
129+
assert total == expected_total
130+
131+
finally:
132+
if torch.distributed.is_initialized():
133+
torch.distributed.destroy_process_group()
134+
135+
136+
@pytest.mark.cuda
137+
def test_pytorch_integration(sample_dataset_path):
138+
import torch
139+
140+
test_world_sizes = [1, 2] if torch.cuda.device_count() >= 2 else [1]
141+
for ws in test_world_sizes:
142+
torch.multiprocessing.spawn(
143+
_distributed_test_worker,
144+
args=(ws, str(sample_dataset_path)),
145+
nprocs=ws,
146+
join=True,
147+
)
148+
149+
150+
def test_data_stream_without_filter(sample_dataset):
151+
"""Validate direct data loading without filters."""
152+
sampler = ShardedFixedBatchSampler(0, 4)
153+
batches = list(sampler(sample_dataset, batch_size=250, columns=["id", "value"]))
154+
155+
# Data integrity checks
156+
batch = batches[0]
157+
assert batch.num_rows == 250, "Batch should contain 250 records"
158+
assert batch.column_names == ["id", "value"], "Should load specified columns"
159+
160+
# Consecutive ID validation
161+
ids = batch["id"].to_numpy()
162+
assert np.array_equal(ids, np.arange(0, 250)), "IDs should be sequential 0-249"
163+
164+
165+
def test_filtered_data_handling(sample_dataset):
166+
"""Test filtered data processing with sharding."""
167+
# Apply ID filter and load data
168+
sampler = ShardedFixedBatchSampler(0, 4)
169+
batches = list(
170+
sampler(sample_dataset, batch_size=100, filter="id < 500", columns=["id"])
171+
)
172+
173+
# Aggregated results validation
174+
all_ids = []
175+
for batch in batches:
176+
all_ids.extend(batch["id"].to_numpy().tolist())
177+
178+
# Filter and sharding assertions
179+
assert all(id_val < 500 for id_val in all_ids), "Should respect ID filter"
180+
assert all(id_val % 4 == 0 for id_val in all_ids), "Should keep rank 0 shard"
181+
182+
183+
def test_randomization_effect():
184+
"""Verify epoch-based randomization behavior."""
185+
# Initialize randomized sampler
186+
sampler = ShardedFixedBatchSampler(
187+
rank=0,
188+
world_size=4,
189+
total_num_rows=2000,
190+
batch_size=250,
191+
randomize=True,
192+
seed=42,
193+
)
194+
195+
assert len(list(sampler)) > 1
196+
197+
# Cross-epoch comparison
198+
sampler.set_epoch(1)
199+
epoch1 = list(sampler)
200+
sampler.set_epoch(2)
201+
epoch2 = list(sampler)
202+
203+
assert epoch1 != epoch2, "Different epochs should produce different orders"
204+
205+
206+
def test_edge_cases():
207+
"""Validate handling of partial batches and data boundaries."""
208+
209+
sampler = ShardedFixedBatchSampler(
210+
rank=3, world_size=4, batch_size=250, total_num_rows=1000
211+
)
212+
batches = list(sampler)
213+
assert len(batches) == 1, "Should handle partial batch"
214+
assert batches[0] == list(range(750, 1000)), "Last rank should get 750-999"
215+
216+
sampler = ShardedFixedBatchSampler(
217+
rank=0, world_size=2, batch_size=128, total_num_rows=500
218+
)
219+
batches = list(sampler)
220+
# rank 0: 0~249, rank 1: 250~499
221+
# rank 0: [0-127], [128-249]
222+
assert batches[0] == list(range(0, 128))
223+
assert batches[1] == list(range(128, 250))
224+
225+
# total_num_rows < batch_size
226+
sampler = ShardedFixedBatchSampler(
227+
rank=0, world_size=1, batch_size=250, total_num_rows=100
228+
)
229+
batches = list(sampler)
230+
assert len(batches) == 1
231+
assert batches[0] == list(range(0, 100))
232+
233+
# total_num_rows < world_size
234+
sampler = ShardedFixedBatchSampler(
235+
rank=2, world_size=4, batch_size=10, total_num_rows=2
236+
)
237+
batches = list(sampler)
238+
assert len(batches) == 0, "No data for this rank"
239+
240+
# batch_size=1
241+
sampler = ShardedFixedBatchSampler(
242+
rank=0, world_size=2, batch_size=1, total_num_rows=4
243+
)
244+
batches = list(sampler)
245+
assert batches == [[0], [1]]
246+
247+
# world_size=1
248+
sampler = ShardedFixedBatchSampler(
249+
rank=0, world_size=1, batch_size=3, total_num_rows=5
250+
)
251+
batches = list(sampler)
252+
assert batches == [list(range(0, 3)), list(range(3, 5))]
11253

12254

13255
# We use + 97 to test case where num_rows and chunk_size aren't exactly aligned.

0 commit comments

Comments
 (0)