Commit 2a8bc0f (parent 1a1094d)

perf(python): optimize _efficient_sample to use sorted random indices with chunked take

3 files changed: 41 additions & 57 deletions

python/python/lance/sampler.py (37 additions & 53 deletions)
@@ -48,10 +48,15 @@ def _efficient_sample(
     n: int,
     columns: Optional[Union[List[str], Dict[str, str]]],
     batch_size: int,
-    max_takes: int,
 ) -> Generator[pa.RecordBatch, None, None]:
     """Sample n records from the dataset.
 
+    Mirrors the Rust ``sample_fsl_uniform`` strategy: generate n uniformly
+    random indices, sort them, and take in large contiguous chunks (default
+    8192 rows per take). Sorting allows the underlying object store to merge
+    adjacent row reads into fewer, larger range requests, which drastically
+    reduces I/O latency on remote storage (e.g. S3).
+
     Parameters
     ----------
     dataset : lance.LanceDataset
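
A quick illustration of why sorting helps, separate from the commit itself: once indices are sorted, nearby reads can be merged into contiguous ranges. The sketch below is a toy model; the max_gap threshold stands in for whatever merge heuristic the real reader applies and is not taken from Lance.

import numpy as np

def coalesce(sorted_indices: np.ndarray, max_gap: int = 16) -> list:
    # Merge sorted row indices into [start, end) ranges whenever the gap to
    # the next index is small, modeling a reader that widens nearby reads
    # into one range request instead of many point reads.
    ranges = []
    start = prev = int(sorted_indices[0])
    for i in sorted_indices[1:]:
        i = int(i)
        if i - prev <= max_gap:
            prev = i
            continue
        ranges.append((start, prev + 1))
        start = prev = i
    ranges.append((start, prev + 1))
    return ranges

rng = np.random.default_rng(42)
idx = np.sort(rng.choice(100_000, 10_000, replace=False))
print(len(coalesce(idx)))  # far fewer merged ranges than 10_000 point reads
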
@@ -61,55 +66,47 @@ def _efficient_sample(
     columns : list[str]
         The columns to load.
     batch_size : int
-        The batch size to use when loading the data.
-    max_takes : int
-        The maximum number of takes to perform. This is used to limit the number of
-        random reads. Large enough value can give a good random sample without
-        having to issue too many random reads.
+        The batch size to use when yielding output RecordBatches.
 
     Returns
     -------
     Generator of a RecordBatch.
     """
-    buf: list[pa.RecordBatch] = []
     total_records = len(dataset)
     assert total_records > n
-    chunk_size = total_records // max_takes
-    chunk_sample_size = n // max_takes
-
-    num_sampled = 0
 
-    for idx, i in enumerate(range(0, total_records, chunk_size)):
-        # If we have already sampled enough, break. This can happen if there
-        # is a remainder in the division.
-        if num_sampled >= n:
-            break
-        num_sampled += chunk_sample_size
+    indices = np.random.choice(total_records, n, replace=False)
+    indices.sort()
 
-        # If we are at the last chunk, we may not have enough records to sample.
-        local_size = min(chunk_size, total_records - i)
-        local_sample_size = min(chunk_sample_size, local_size)
+    LOGGER.info(
+        "Sampling %d rows from %d total (sorted random indices, chunk take)",
+        n,
+        total_records,
+    )
 
-        if local_sample_size < local_size:
-            # Add more randomness within each chunk, if there is room.
-            offset = i + np.random.randint(0, local_size - local_sample_size)
-        else:
-            offset = i
+    take_chunk_size = 8192
+    buf: list[pa.RecordBatch] = []
 
-        buf.extend(
-            dataset.take(
-                list(range(offset, offset + local_sample_size)),
-                columns=columns,
-            ).to_batches()
+    for start in range(0, len(indices), take_chunk_size):
+        chunk = indices[start : start + take_chunk_size].tolist()
+        buf.extend(dataset.take(chunk, columns=columns).to_batches())
+        LOGGER.info(
+            "Sampled chunk %d/%d, rows %d-%d",
+            start // take_chunk_size + 1,
+            math.ceil(len(indices) / take_chunk_size),
+            chunk[0],
+            chunk[-1],
         )
-        if idx % 50 == 0:
-            LOGGER.info("Sampled at offset=%s, len=%s", offset, chunk_sample_size)
-        if sum(len(b) for b in buf) >= batch_size:
+        while sum(len(b) for b in buf) >= batch_size:
             tbl = pa.Table.from_batches(buf)
             buf.clear()
-            tbl = tbl.combine_chunks()
-            yield tbl.to_batches()[0]
-            del tbl
+            batch_tbl = tbl.slice(0, batch_size).combine_chunks()
+            rest_tbl = tbl.slice(batch_size)
+            yield batch_tbl.to_batches()[0]
+            del batch_tbl
+            if rest_tbl.num_rows > 0:
+                buf.extend(rest_tbl.to_batches())
+            del rest_tbl, tbl
     if buf:
         tbl = pa.Table.from_batches(buf).combine_chunks()
         yield tbl.to_batches()[0]
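
The rewritten loop above no longer yields whatever the first buffered chunk happens to contain: it slices exactly batch_size rows off the accumulated table and pushes the remainder back for the next round, so every yielded batch except the last has a uniform size. A self-contained sketch of the same pattern, with illustrative names and plain pyarrow:

import pyarrow as pa

def rebatch(batches, batch_size):
    # Buffer incoming RecordBatches; emit exactly batch_size rows at a time
    # and carry the remainder forward, as in _efficient_sample above.
    buf = []
    for b in batches:
        buf.append(b)
        while sum(len(x) for x in buf) >= batch_size:
            tbl = pa.Table.from_batches(buf)
            buf.clear()
            head = tbl.slice(0, batch_size).combine_chunks()
            rest = tbl.slice(batch_size)
            yield head.to_batches()[0]
            if rest.num_rows > 0:
                buf.extend(rest.to_batches())
    if buf:  # final partial batch, mirroring the trailing `if buf:` branch
        yield pa.Table.from_batches(buf).combine_chunks().to_batches()[0]

parts = [pa.RecordBatch.from_pydict({"x": list(range(i, i + 5))}) for i in range(0, 20, 5)]
print([len(b) for b in rebatch(parts, 8)])  # [8, 8, 4]
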
@@ -121,10 +118,10 @@ def _filtered_efficient_sample(
     n: int,
     columns: List[str],
     batch_size: int,
-    target_takes: int,
     filter: str,
 ) -> Generator[pa.RecordBatch, None, None]:
     total_records = len(dataset)
+    target_takes = max(1, n // 32)
     shard_size = math.ceil(n / target_takes)
     num_shards = math.ceil(total_records / shard_size)
 
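Worked numbers for the new shard sizing, following directly from the two formulas above (the constant of roughly 32 rows per take is what this commit introduces):

import math

n = 1000                                    # rows requested
target_takes = max(1, n // 32)              # = 31 takes
shard_size = math.ceil(n / target_takes)    # = ceil(1000 / 31) = 33 rows per take
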
@@ -189,10 +186,7 @@ def maybe_sample(
     batch_size : int, optional
         The batch size to use when loading the data, by default 10240.
     max_takes : int, optional
-        The maximum number of takes to perform, by default 2048.
-        This is employed to minimize the number of random reads necessary for sampling.
-        A sufficiently large value can provide an effective random sample without
-        the need for excessive random reads.
+        Deprecated and ignored. Kept for API compatibility only.
     filter : str, optional
         The filter to apply to the dataset, by default None. If a filter is provided,
         then we will first load all row ids in memory and then batch through the ids
@@ -215,19 +209,9 @@ def maybe_sample(
             columns=columns, batch_size=batch_size, filter=filt
         )
     elif filt is not None:
-        yield from _filtered_efficient_sample(
-            dataset, n, columns, batch_size, max_takes, filt
-        )
-    elif n > max_takes:
-        yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
+        yield from _filtered_efficient_sample(dataset, n, columns, batch_size, filt)
     else:
-        choices = np.random.choice(len(dataset), n, replace=False)
-        idx = 0
-        while idx < len(choices):
-            end = min(idx + batch_size, len(choices))
-            tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
-            yield tbl.to_batches()[0]
-            idx += batch_size
+        yield from _efficient_sample(dataset, n, columns, batch_size)
 
 
 T = TypeVar("T")
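
With this dispatch, callers never branch on max_takes: any unfiltered sample request below the dataset size goes through _efficient_sample. A hypothetical usage sketch (the dataset URI and column name are invented; the import path follows this file's location):

import lance
from lance.sampler import maybe_sample

ds = lance.dataset("s3://bucket/vectors.lance")  # any Lance dataset
for batch in maybe_sample(ds, 4096, ["vec"], batch_size=1024):
    # each item is a pyarrow.RecordBatch of up to batch_size rows
    ...
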

python/python/tests/test_sampler.py (3 additions & 3 deletions)
@@ -18,7 +18,7 @@
     "master_addr": "127.0.0.1",
     "seed": 42,
     "test_shard_ratio": 0.5,
-    "max_takes_factor": 0.1,
+
 }
 
 
@@ -270,8 +270,8 @@ def test_sample_dataset(tmp_path: Path, nrows: int):
     assert simple_scan[0].schema == pa.schema([pa.field("vec", fsl.type)])
     assert simple_scan[0].num_rows == min(nrows, 128)
 
-    # Random path.
-    large_scan = list(maybe_sample(ds, 128, ["vec"], max_takes=32))
+    # Sorted-index take path (n < len(dataset)).
+    large_scan = list(maybe_sample(ds, 128, ["vec"]))
     assert len(large_scan) == 1
     assert isinstance(large_scan[0], pa.RecordBatch)
     assert large_scan[0].schema == pa.schema([pa.field("vec", fsl.type)])

python/python/tests/torch_tests/test_data.py (1 addition & 1 deletion)
@@ -62,7 +62,7 @@ def iter_over_dataset(tmp_path):
         assert batch["vec"].shape[1] == 32
     assert total_rows == 1024
 
-    # test when sample size is greater than max_takes
+    # test larger sample size
     torch_ds = LanceDataset(
         ds,
         batch_size=256,
