Skip to content

Commit 71d9bef

Browse files
committed
Add Reservoir Sampling algorithm for random selection from data streams
1 parent 788d95b commit 71d9bef

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

other/reservoir_sampling.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
"""
reservoir_sampling.py

An implementation of Reservoir Sampling — a randomized algorithm that
selects `k` items from a stream of unknown or very large size, giving every
item an equal probability of being chosen.

Reference:
https://en.wikipedia.org/wiki/Reservoir_sampling

Example:
>>> data_stream = [1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> len(reservoir_sampling(data_stream, 3))
3
>>> all(isinstance(i, int) for i in reservoir_sampling(data_stream, 3))
True
"""
17+
18+
import random
19+
from typing import Iterable, List, TypeVar
20+
21+
T = TypeVar("T")


def reservoir_sampling(stream: Iterable[T], k: int) -> List[T]:
    """
    Return a uniform random sample of size `k` from the given data stream.

    The stream is consumed exactly once and only `k` items are held in
    memory at any time, so it may be a generator or any other one-shot
    iterable whose length is not known in advance.

    :param stream: An iterable data stream (e.g., list, generator)
    :param k: Number of elements to sample; must be greater than zero
    :return: A list of `k` randomly selected items
    :raises ValueError: If `k` is not positive, or if the stream yields
        fewer than `k` items

    >>> data = [10, 20, 30, 40, 50]
    >>> len(reservoir_sampling(data, 2))
    2
    >>> isinstance(reservoir_sampling(data, 3), list)
    True
    >>> try:
    ...     reservoir_sampling([], 1)
    ... except ValueError:
    ...     print("Error")
    Error
    """
    if k <= 0:
        raise ValueError("Sample size k must be greater than zero")

    reservoir: List[T] = []
    for i, item in enumerate(stream):
        if i < k:
            # The first k items fill the reservoir unconditionally.
            reservoir.append(item)
        else:
            # randint is inclusive on both ends, so j is uniform over the
            # i + 1 items seen so far; the new item therefore replaces a
            # reservoir slot with probability k / (i + 1).
            j = random.randint(0, i)
            if j < k:
                reservoir[j] = item

    # Only detectable after the stream is exhausted, because its length is
    # not known up front.
    if len(reservoir) < k:
        raise ValueError("Stream has fewer elements than the requested sample size")

    return reservoir
59+
60+
61+
if __name__ == "__main__":
    # Example usage: draw 5 items from the integers 1..99 and print them.
    data_stream = range(1, 100)
    sample = reservoir_sampling(data_stream, 5)
    print("Random sample from stream:", sample)

0 commit comments

Comments
 (0)