Skip to content

Commit 6f0c6ae

Browse files
Optimize merge sort, quick sort, and knapsack algorithms
- sorts/merge_sort.py: Replace O(n) pop(0) calls in merge() with O(1) index-based traversal, fixing overall complexity from O(n² log n) to the correct O(n log n). Added type hints and comprehensive doctests. - sorts/quick_sort.py: Fix input mutation bug where collection.pop() destroyed the caller's original list. Use median-of-three pivot selection with three-way partitioning for better worst-case behavior on sorted/nearly-sorted inputs. Added type hints and doctests. - dynamic_programming/knapsack.py: Remove fragile global state from mf_knapsack() by passing the memoization table explicitly. Added type hints and doctests to knapsack() and mf_knapsack(). Added new knapsack_optimized() function with O(W) space complexity using a 1-D rolling array.
1 parent 68473af commit 6f0c6ae

File tree

3 files changed

+225
-43
lines changed

3 files changed

+225
-43
lines changed

dynamic_programming/knapsack.py

Lines changed: 133 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,91 @@
44
55
Note that only the integer weights 0-1 knapsack problem is solvable
66
using dynamic programming.
7+
8+
This module provides multiple approaches:
9+
- ``knapsack``: Bottom-up DP with O(n*W) space and solution reconstruction.
10+
- ``knapsack_with_example_solution``: Wrapper that returns optimal value and subset.
11+
- ``knapsack_optimized``: Space-optimized bottom-up DP using O(W) space (value only).
12+
- ``mf_knapsack``: Top-down memoized approach (memory function) with no global state.
713
"""
814

915

10-
def mf_knapsack(i, wt, val, j):
16+
def mf_knapsack(
17+
i: int,
18+
wt: list[int],
19+
val: list[int],
20+
j: int,
21+
memo: list[list[int]] | None = None,
22+
) -> int:
1123
"""
12-
This code involves the concept of memory functions. Here we solve the subproblems
13-
which are needed unlike the below example
14-
F is a 2D array with ``-1`` s filled up
24+
Solve the 0-1 knapsack problem using top-down memoization (memory function).
25+
26+
Unlike the previous implementation, this version does **not** rely on a global
27+
``f`` table. The memoization table is passed explicitly or created on first call.
28+
29+
:param i: Number of items to consider (1-indexed).
30+
:param wt: List of item weights.
31+
:param val: List of item values.
32+
:param j: Remaining knapsack capacity.
33+
:param memo: Optional pre-allocated memoization table of shape ``(i+1) x (j+1)``
34+
initialised with ``-1`` for unsolved sub-problems and ``0`` for base cases.
35+
When ``None`` a table is created automatically.
36+
:return: Maximum obtainable value considering items ``1..i`` with capacity ``j``.
37+
38+
Examples:
39+
>>> mf_knapsack(4, [4, 3, 2, 3], [3, 2, 4, 4], 6)
40+
8
41+
>>> mf_knapsack(0, [1, 2], [10, 20], 5)
42+
0
43+
>>> mf_knapsack(3, [1, 3, 5], [10, 20, 100], 10)
44+
130
45+
>>> mf_knapsack(1, [5], [50], 3)
46+
0
47+
>>> mf_knapsack(1, [5], [50], 5)
48+
50
1549
"""
16-
global f # a global dp table for knapsack
17-
if f[i][j] < 0:
18-
if j < wt[i - 1]:
19-
val = mf_knapsack(i - 1, wt, val, j)
20-
else:
21-
val = max(
22-
mf_knapsack(i - 1, wt, val, j),
23-
mf_knapsack(i - 1, wt, val, j - wt[i - 1]) + val[i - 1],
24-
)
25-
f[i][j] = val
26-
return f[i][j]
50+
if memo is None:
51+
memo = [[0] * (j + 1)] + [[0] + [-1] * j for _ in range(i)]
52+
53+
if i == 0 or j == 0:
54+
return 0
2755

56+
if memo[i][j] >= 0:
57+
return memo[i][j]
2858

29-
def knapsack(w, wt, val, n):
59+
if j < wt[i - 1]:
60+
memo[i][j] = mf_knapsack(i - 1, wt, val, j, memo)
61+
else:
62+
memo[i][j] = max(
63+
mf_knapsack(i - 1, wt, val, j, memo),
64+
mf_knapsack(i - 1, wt, val, j - wt[i - 1], memo) + val[i - 1],
65+
)
66+
return memo[i][j]
67+
68+
69+
def knapsack(
70+
w: int, wt: list[int], val: list[int], n: int
71+
) -> tuple[int, list[list[int]]]:
72+
"""
73+
Solve the 0-1 knapsack problem using bottom-up dynamic programming.
74+
75+
:param w: Maximum knapsack capacity.
76+
:param wt: List of item weights.
77+
:param val: List of item values.
78+
:param n: Number of items.
79+
:return: A tuple ``(optimal_value, dp_table)`` where ``dp_table`` can be used
80+
for solution reconstruction via ``_construct_solution``.
81+
82+
Examples:
83+
>>> knapsack(6, [4, 3, 2, 3], [3, 2, 4, 4], 4)[0]
84+
8
85+
>>> knapsack(10, [1, 3, 5, 2], [10, 20, 100, 22], 4)[0]
86+
142
87+
>>> knapsack(0, [1, 2], [10, 20], 2)[0]
88+
0
89+
>>> knapsack(5, [], [], 0)[0]
90+
0
91+
"""
3092
dp = [[0] * (w + 1) for _ in range(n + 1)]
3193

3294
for i in range(1, n + 1):
@@ -36,10 +98,51 @@ def knapsack(w, wt, val, n):
3698
else:
3799
dp[i][w_] = dp[i - 1][w_]
38100

39-
return dp[n][w_], dp
101+
return dp[n][w], dp
40102

41103

42-
def knapsack_with_example_solution(w: int, wt: list, val: list):
104+
def knapsack_optimized(w: int, wt: list[int], val: list[int], n: int) -> int:
105+
"""
106+
Solve the 0-1 knapsack problem using space-optimized bottom-up DP.
107+
108+
Uses a single 1-D array of size ``w + 1`` instead of a 2-D ``(n+1) x (w+1)``
109+
table, reducing space complexity from O(n*W) to O(W).
110+
111+
.. note::
112+
This variant returns only the optimal value; it does **not** support
113+
solution reconstruction (i.e. which items are included).
114+
115+
:param w: Maximum knapsack capacity.
116+
:param wt: List of item weights.
117+
:param val: List of item values.
118+
:param n: Number of items.
119+
:return: Maximum obtainable value.
120+
121+
Examples:
122+
>>> knapsack_optimized(6, [4, 3, 2, 3], [3, 2, 4, 4], 4)
123+
8
124+
>>> knapsack_optimized(10, [1, 3, 5, 2], [10, 20, 100, 22], 4)
125+
142
126+
>>> knapsack_optimized(0, [1, 2], [10, 20], 2)
127+
0
128+
>>> knapsack_optimized(5, [], [], 0)
129+
0
130+
>>> knapsack_optimized(50, [10, 20, 30], [60, 100, 120], 3)
131+
220
132+
"""
133+
dp = [0] * (w + 1)
134+
135+
for i in range(n):
136+
# Traverse capacity in reverse so each item is used at most once
137+
for capacity in range(w, wt[i] - 1, -1):
138+
dp[capacity] = max(dp[capacity], dp[capacity - wt[i]] + val[i])
139+
140+
return dp[w]
141+
142+
143+
def knapsack_with_example_solution(
144+
w: int, wt: list[int], val: list[int]
145+
) -> tuple[int, set[int]]:
43146
"""
44147
Solves the integer weights knapsack problem returns one of
45148
the several possible optimal subsets.
@@ -94,13 +197,15 @@ def knapsack_with_example_solution(w: int, wt: list, val: list):
94197
raise TypeError(msg)
95198

96199
optimal_val, dp_table = knapsack(w, wt, val, num_items)
97-
example_optional_set: set = set()
200+
example_optional_set: set[int] = set()
98201
_construct_solution(dp_table, wt, num_items, w, example_optional_set)
99202

100203
return optimal_val, example_optional_set
101204

102205

103-
def _construct_solution(dp: list, wt: list, i: int, j: int, optimal_set: set):
206+
def _construct_solution(
207+
dp: list[list[int]], wt: list[int], i: int, j: int, optimal_set: set[int]
208+
) -> None:
104209
"""
105210
Recursively reconstructs one of the optimal subsets given
106211
a filled DP table and the vector of weights
@@ -135,14 +240,20 @@ def _construct_solution(dp: list, wt: list, i: int, j: int, optimal_set: set):
135240
"""
136241
Adding test case for knapsack
137242
"""
243+
import doctest
244+
245+
doctest.testmod()
246+
138247
val = [3, 2, 4, 4]
139248
wt = [4, 3, 2, 3]
140249
n = 4
141250
w = 6
142-
f = [[0] * (w + 1)] + [[0] + [-1] * (w + 1) for _ in range(n + 1)]
143251
optimal_solution, _ = knapsack(w, wt, val, n)
144252
print(optimal_solution)
145-
print(mf_knapsack(n, wt, val, w)) # switched the n and w
253+
print(mf_knapsack(n, wt, val, w))
254+
255+
# Space-optimized knapsack
256+
print(f"Optimized: {knapsack_optimized(w, wt, val, n)}")
146257

147258
# testing the dynamic programming problem with example
148259
# the optimal subset for the above example are items 3 and 4

sorts/merge_sort.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
python merge_sort.py
1010
"""
1111

12+
from typing import Any
1213

13-
def merge_sort(collection: list) -> list:
14+
15+
def merge_sort(collection: list[Any]) -> list[Any]:
1416
"""
1517
Sorts a list using the merge sort algorithm.
1618
@@ -27,21 +29,56 @@ def merge_sort(collection: list) -> list:
2729
[]
2830
>>> merge_sort([-2, -5, -45])
2931
[-45, -5, -2]
32+
>>> merge_sort([1])
33+
[1]
34+
>>> merge_sort([1, 2, 3, 4, 5])
35+
[1, 2, 3, 4, 5]
36+
>>> merge_sort([5, 4, 3, 2, 1])
37+
[1, 2, 3, 4, 5]
38+
>>> merge_sort([3, 3, 3, 3])
39+
[3, 3, 3, 3]
40+
>>> merge_sort(['d', 'a', 'b', 'e', 'c'])
41+
['a', 'b', 'c', 'd', 'e']
42+
>>> merge_sort([1.1, 0.5, 3.3, 2.2])
43+
[0.5, 1.1, 2.2, 3.3]
44+
>>> import random
45+
>>> collection_arg = random.sample(range(-50, 50), 100)
46+
>>> merge_sort(collection_arg) == sorted(collection_arg)
47+
True
3048
"""
3149

32-
def merge(left: list, right: list) -> list:
50+
def merge(left: list[Any], right: list[Any]) -> list[Any]:
3351
"""
34-
Merge two sorted lists into a single sorted list.
52+
Merge two sorted lists into a single sorted list using index-based
53+
traversal instead of pop(0) to achieve O(n) merge performance.
54+
55+
:param left: Left sorted collection
56+
:param right: Right sorted collection
57+
:return: Merged sorted result
3558
36-
:param left: Left collection
37-
:param right: Right collection
38-
:return: Merged result
59+
>>> merge([1, 3, 5], [2, 4, 6])
60+
[1, 2, 3, 4, 5, 6]
61+
>>> merge([], [1, 2])
62+
[1, 2]
63+
>>> merge([1], [])
64+
[1]
65+
>>> merge([], [])
66+
[]
3967
"""
40-
result = []
41-
while left and right:
42-
result.append(left.pop(0) if left[0] <= right[0] else right.pop(0))
43-
result.extend(left)
44-
result.extend(right)
68+
result: list[Any] = []
69+
left_index, right_index = 0, 0
70+
71+
while left_index < len(left) and right_index < len(right):
72+
if left[left_index] <= right[right_index]:
73+
result.append(left[left_index])
74+
left_index += 1
75+
else:
76+
result.append(right[right_index])
77+
right_index += 1
78+
79+
# Append any remaining elements from either list
80+
result.extend(left[left_index:])
81+
result.extend(right[right_index:])
4582
return result
4683

4784
if len(collection) <= 1:

sorts/quick_sort.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,18 @@
1010

1111
from __future__ import annotations
1212

13-
from random import randrange
13+
from typing import Any
1414

1515

16-
def quick_sort(collection: list) -> list:
16+
def quick_sort(collection: list[Any]) -> list[Any]:
1717
"""A pure Python implementation of quicksort algorithm.
1818
19+
This implementation does not mutate the original collection. It uses
20+
median-of-three pivot selection for improved performance on sorted or
21+
nearly-sorted inputs.
22+
1923
:param collection: a mutable collection of comparable items
20-
:return: the same collection ordered in ascending order
24+
:return: a new list with the same elements ordered in ascending order
2125
2226
Examples:
2327
>>> quick_sort([0, 5, 3, 2, 2])
@@ -26,24 +30,54 @@ def quick_sort(collection: list) -> list:
2630
[]
2731
>>> quick_sort([-2, 5, 0, -45])
2832
[-45, -2, 0, 5]
33+
>>> quick_sort([1])
34+
[1]
35+
>>> quick_sort([1, 2, 3, 4, 5])
36+
[1, 2, 3, 4, 5]
37+
>>> quick_sort([5, 4, 3, 2, 1])
38+
[1, 2, 3, 4, 5]
39+
>>> quick_sort([3, 3, 3, 3])
40+
[3, 3, 3, 3]
41+
>>> quick_sort(['d', 'a', 'b', 'e', 'c'])
42+
['a', 'b', 'c', 'd', 'e']
43+
>>> quick_sort([1.1, 0.5, 3.3, 2.2])
44+
[0.5, 1.1, 2.2, 3.3]
45+
>>> original = [3, 1, 2]
46+
>>> sorted_list = quick_sort(original)
47+
>>> original
48+
[3, 1, 2]
49+
>>> sorted_list
50+
[1, 2, 3]
51+
>>> import random
52+
>>> collection_arg = random.sample(range(-50, 50), 100)
53+
>>> quick_sort(collection_arg) == sorted(collection_arg)
54+
True
2955
"""
3056
# Base case: if the collection has 0 or 1 elements, it is already sorted
3157
if len(collection) < 2:
3258
return collection
3359

34-
# Randomly select a pivot index and remove the pivot element from the collection
35-
pivot_index = randrange(len(collection))
36-
pivot = collection.pop(pivot_index)
60+
# Use median-of-three pivot selection for better worst-case performance.
61+
# Compare the first, middle, and last elements and pick the median value.
62+
first = collection[0]
63+
middle = collection[len(collection) // 2]
64+
last = collection[-1]
65+
pivot = sorted((first, middle, last))[1]
3766

38-
# Partition the remaining elements into two groups: lesser or equal, and greater
39-
lesser = [item for item in collection if item <= pivot]
67+
# Partition elements into three groups without mutating the original list
68+
lesser = [item for item in collection if item < pivot]
69+
equal = [item for item in collection if item == pivot]
4070
greater = [item for item in collection if item > pivot]
4171

42-
# Recursively sort the lesser and greater groups, and combine with the pivot
43-
return [*quick_sort(lesser), pivot, *quick_sort(greater)]
72+
# Recursively sort the lesser and greater groups, and combine with equal
73+
return [*quick_sort(lesser), *equal, *quick_sort(greater)]
4474

4575

4676
if __name__ == "__main__":
77+
import doctest
78+
79+
doctest.testmod()
80+
4781
# Get user input and convert it into a list of integers
4882
user_input = input("Enter numbers separated by a comma:\n").strip()
4983
unsorted = [int(item) for item in user_input.split(",")]

0 commit comments

Comments
 (0)