Skip to content

Commit d0751ed

Browse files
committed
#17 median_of_3_partition
1 parent b3895d9 commit d0751ed

2 files changed

Lines changed: 42 additions & 0 deletions

File tree

src/book/chapter7/section4.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from book.chapter7.section1 import partition
22
from book.data_structures import Array
33
from book.data_structures import CT
4+
from solutions.chapter5.section1.exercise2 import random
45
from util import range_of
56

67

@@ -26,3 +27,28 @@ def insertion_quicksort(A: Array[CT], p: int, r: int, k: int) -> None:
2627
A[j + 1] = A[j]
2728
j -= 1
2829
A[j + 1] = key
30+
31+
32+
def median_of_3_partition(A: Array[CT], p: int, r: int) -> int:
33+
"""Partitions an array into two subarrays, the low side and the high side, such that each element in the low side of
34+
the partition is less than or equal to the pivot value, which is, in turn, less than or equal to each element in the
35+
high side. Uses a median of randomly picked three elements as the pivot.
36+
37+
Args:
38+
A: an Array to partition
39+
p: the lower index of the subarray to partition
40+
r: the upper index of the subarray to partition
41+
42+
Returns:
43+
The index q, such that each element in A[p:q - 1] is less than or equal to A[q], and that A[q] is less than or
44+
equal to each element in A[q + 1:r].
45+
"""
46+
i1, i2, i3 = random(p, r), random(p, r), random(p, r)
47+
if A[i2] <= A[i1] <= A[i3] or A[i3] <= A[i1] <= A[i2]:
48+
m = i1
49+
elif A[i1] <= A[i2] <= A[i3] or A[i3] <= A[i2] <= A[i1]:
50+
m = i2
51+
else:
52+
m = i3
53+
A[m], A[r] = A[r], A[m]
54+
return partition(A, p, r)

test/test_book/test_chapter7.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from book.chapter7.section1 import quicksort
77
from book.chapter7.section3 import randomized_quicksort
88
from book.chapter7.section4 import insertion_quicksort
9+
from book.chapter7.section4 import median_of_3_partition
910
from test_case import ClrsTestCase
1011
from test_util import create_array
12+
from util import range_of
1113

1214

1315
class TestChapter7(ClrsTestCase):
@@ -45,3 +47,17 @@ def test_insertion_quicksort(self, data):
4547

4648
self.assertArraySorted(A, end=n)
4749
self.assertArrayPermuted(A, elements, end=n)
50+
51+
@given(st.data())
52+
def test_median_of_3_partition(self, data):
53+
elements = data.draw(lists(integers(), min_size=1))
54+
A = create_array(elements)
55+
n = len(elements)
56+
57+
actual_pivot_index = median_of_3_partition(A, 1, n)
58+
59+
for i in range_of(1, to=actual_pivot_index):
60+
self.assertLessEqual(A[i], A[actual_pivot_index])
61+
for i in range_of(actual_pivot_index + 1, to=n):
62+
self.assertGreaterEqual(A[i], A[actual_pivot_index])
63+
self.assertArrayPermuted(A, elements, end=n)

0 commit comments

Comments
 (0)