@@ -48,10 +48,15 @@ def _efficient_sample(
     n: int,
     columns: Optional[Union[List[str], Dict[str, str]]],
     batch_size: int,
-    max_takes: int,
 ) -> Generator[pa.RecordBatch, None, None]:
     """Sample n records from the dataset.
 
+    Mirrors the Rust ``sample_fsl_uniform`` strategy: generate n uniformly
+    random indices, sort them, and take in large contiguous chunks (default
+    8192 rows per take). Sorting allows the underlying object store to merge
+    adjacent row reads into fewer, larger range requests, which drastically
+    reduces I/O latency on remote storage (e.g. S3).
+
     Parameters
     ----------
     dataset : lance.LanceDataset
@@ -61,55 +66,47 @@ def _efficient_sample(
     columns : list[str]
         The columns to load.
     batch_size : int
-        The batch size to use when loading the data.
-    max_takes : int
-        The maximum number of takes to perform. This is used to limit the number of
-        random reads. Large enough value can give a good random sample without
-        having to issue too many random reads.
+        The batch size to use when yielding output RecordBatches.
 
     Returns
     -------
     Generator of a RecordBatch.
     """
-    buf: list[pa.RecordBatch] = []
     total_records = len(dataset)
     assert total_records > n
-    chunk_size = total_records // max_takes
-    chunk_sample_size = n // max_takes
-
-    num_sampled = 0
 
-    for idx, i in enumerate(range(0, total_records, chunk_size)):
-        # If we have already sampled enough, break. This can happen if there
-        # is a remainder in the division.
-        if num_sampled >= n:
-            break
-        num_sampled += chunk_sample_size
+    indices = np.random.choice(total_records, n, replace=False)
+    indices.sort()
 
-        # If we are at the last chunk, we may not have enough records to sample.
-        local_size = min(chunk_size, total_records - i)
-        local_sample_size = min(chunk_sample_size, local_size)
+    LOGGER.info(
+        "Sampling %d rows from %d total (sorted random indices, chunk take)",
+        n,
+        total_records,
+    )
 
-        if local_sample_size < local_size:
-            # Add more randomness within each chunk, if there is room.
-            offset = i + np.random.randint(0, local_size - local_sample_size)
-        else:
-            offset = i
+    take_chunk_size = 8192
+    buf: list[pa.RecordBatch] = []
 
-        buf.extend(
-            dataset.take(
-                list(range(offset, offset + local_sample_size)),
-                columns=columns,
-            ).to_batches()
+    for start in range(0, len(indices), take_chunk_size):
+        chunk = indices[start : start + take_chunk_size].tolist()
+        buf.extend(dataset.take(chunk, columns=columns).to_batches())
+        LOGGER.info(
+            "Sampled chunk %d/%d, rows %d-%d",
+            start // take_chunk_size + 1,
+            math.ceil(len(indices) / take_chunk_size),
+            chunk[0],
+            chunk[-1],
         )
-        if idx % 50 == 0:
-            LOGGER.info("Sampled at offset=%s, len=%s", offset, chunk_sample_size)
-        if sum(len(b) for b in buf) >= batch_size:
+        while sum(len(b) for b in buf) >= batch_size:
             tbl = pa.Table.from_batches(buf)
             buf.clear()
-            tbl = tbl.combine_chunks()
-            yield tbl.to_batches()[0]
-            del tbl
+            batch_tbl = tbl.slice(0, batch_size).combine_chunks()
+            rest_tbl = tbl.slice(batch_size)
+            yield batch_tbl.to_batches()[0]
+            del batch_tbl
+            if rest_tbl.num_rows > 0:
+                buf.extend(rest_tbl.to_batches())
+            del rest_tbl, tbl
     if buf:
         tbl = pa.Table.from_batches(buf).combine_chunks()
         yield tbl.to_batches()[0]
@@ -121,10 +118,10 @@ def _filtered_efficient_sample(
     n: int,
     columns: List[str],
     batch_size: int,
-    target_takes: int,
     filter: str,
 ) -> Generator[pa.RecordBatch, None, None]:
     total_records = len(dataset)
+    target_takes = max(1, n // 32)
     shard_size = math.ceil(n / target_takes)
     num_shards = math.ceil(total_records / shard_size)
 
@@ -189,10 +186,7 @@ def maybe_sample(
     batch_size : int, optional
         The batch size to use when loading the data, by default 10240.
     max_takes : int, optional
-        The maximum number of takes to perform, by default 2048.
-        This is employed to minimize the number of random reads necessary for sampling.
-        A sufficiently large value can provide an effective random sample without
-        the need for excessive random reads.
+        Deprecated and ignored. Kept for API compatibility only.
     filter : str, optional
         The filter to apply to the dataset, by default None. If a filter is provided,
         then we will first load all row ids in memory and then batch through the ids
@@ -215,19 +209,9 @@ def maybe_sample(
             columns=columns, batch_size=batch_size, filter=filt
         )
     elif filt is not None:
-        yield from _filtered_efficient_sample(
-            dataset, n, columns, batch_size, max_takes, filt
-        )
-    elif n > max_takes:
-        yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
+        yield from _filtered_efficient_sample(dataset, n, columns, batch_size, filt)
     else:
-        choices = np.random.choice(len(dataset), n, replace=False)
-        idx = 0
-        while idx < len(choices):
-            end = min(idx + batch_size, len(choices))
-            tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
-            yield tbl.to_batches()[0]
-            idx += batch_size
+        yield from _efficient_sample(dataset, n, columns, batch_size)
 
 
 T = TypeVar("T")
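
For context, a minimal sketch of driving the revised sampler from user code. It assumes `maybe_sample` is importable from `lance.sampler`, that `./data.lance` is an existing dataset, and that it contains a `vector` column; the path and column name are illustrative only, not part of this change.

import lance
import pyarrow as pa

from lance.sampler import maybe_sample

# Illustrative dataset path and column name; substitute your own.
ds = lance.dataset("./data.lance")

# Sample ~10k random rows in batches of 1024. With the change above, the
# sampler draws sorted random indices and takes them in 8192-row chunks,
# so adjacent reads can be coalesced into fewer range requests.
batches = list(maybe_sample(ds, n=10_000, columns=["vector"], batch_size=1024))
sampled = pa.Table.from_batches(batches)
print(sampled.num_rows)  # 10_000, assuming the dataset has more rows than that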