Skip to content

Commit a01c8af

Browse files
committed
change the type of Matrix entries to float
1 parent 945d25f commit a01c8af

4 files changed

Lines changed: 24 additions & 21 deletions

File tree

src/book/data_structures.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import math
34
from builtins import len
45
from collections.abc import Callable
56
from typing import Any
@@ -70,10 +71,10 @@ class Matrix:
7071
__end_row: int
7172
__start_col: int
7273
__end_col: int
73-
__elements: list[list[int]]
74+
__elements: list[list[float]]
7475

7576
def __init__(self, end_row: int, end_col: int, start_row: int = 1, start_col: int = 1,
76-
elements: list[list[int]] | None = None) -> None:
77+
elements: list[list[float]] | None = None) -> None:
7778
if elements is None:
7879
assert start_row == 1
7980
assert start_col == 1
@@ -97,13 +98,13 @@ def even_rows_submatrix(self) -> Matrix:
9798
return Matrix(start_row=self.__start_row, end_row=submatrix_end_row, start_col=self.__start_col,
9899
end_col=self.__end_col, elements=even_rows)
99100

100-
def __getitem__(self, indices: tuple[int, int]) -> int:
101+
def __getitem__(self, indices: tuple[int, int]) -> float:
101102
row, col = indices[0], indices[1]
102103
assert 1 <= row <= self.__end_row - self.__start_row + 1
103104
assert 1 <= col <= self.__end_col - self.__start_col + 1
104105
return self.__elements[self.__start_row - 1 + row - 1][self.__start_col - 1 + col - 1]
105106

106-
def __setitem__(self, indices: tuple[int, int], value: int) -> None:
107+
def __setitem__(self, indices: tuple[int, int], value: float) -> None:
107108
row, col = indices[0], indices[1]
108109
assert 1 <= row <= self.__end_row - self.__start_row + 1
109110
assert 1 <= col <= self.__end_col - self.__start_col + 1
@@ -114,7 +115,7 @@ def __eq__(self, other: Any) -> bool:
114115
return NotImplemented
115116
for i, row in enumerate(self.__elements, start=1):
116117
for j, element in enumerate(row, start=1):
117-
if element != other[i, j]:
118+
if not math.isclose(element, other[i, j], abs_tol=1e-7):
118119
return False
119120
try:
120121
_ = other[i, len(row) + 1]

test/test_book/test_chapter4.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy
22
from hypothesis import given
33
from hypothesis import strategies as st
4+
from hypothesis.strategies import floats
45
from hypothesis.strategies import integers
56
from hypothesis.strategies import lists
67

@@ -17,10 +18,10 @@ class TestChapter4(ClrsTestCase):
1718
def test_matrix_multiply(self, data):
1819
n = data.draw(integers(min_value=1, max_value=15), label="Matrices dimension")
1920
elements1 = data.draw(
20-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
21+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
2122
label="First matrix elements")
2223
elements2 = data.draw(
23-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
24+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
2425
label="Second matrix elements")
2526
A = create_matrix(elements1)
2627
B = create_matrix(elements2)
@@ -36,10 +37,10 @@ def test_matrix_multiply_recursive(self, data):
3637
k = data.draw(integers(min_value=0, max_value=4), label="Matrices dimension exponent")
3738
n = 2 ** k
3839
elements1 = data.draw(
39-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
40+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
4041
label="First matrix elements")
4142
elements2 = data.draw(
42-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
43+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
4344
label="Second matrix elements")
4445
A = create_matrix(elements1)
4546
B = create_matrix(elements2)

test/test_solutions/test_chapter4.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from hypothesis import given
44
from hypothesis import strategies as st
55
from hypothesis.strategies import complex_numbers
6+
from hypothesis.strategies import floats
67
from hypothesis.strategies import integers
78
from hypothesis.strategies import lists
89
from hypothesis.strategies import sampled_from
@@ -40,10 +41,10 @@ class TestChapter4(ClrsTestCase):
4041
def test_matrix_multiply_recursive_general(self, data):
4142
n = data.draw(integers(min_value=1, max_value=15), label="Matrices dimension")
4243
elements1 = data.draw(
43-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
44+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
4445
label="First matrix elements")
4546
elements2 = data.draw(
46-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
47+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
4748
label="Second matrix elements")
4849
A = create_matrix(elements1)
4950
B = create_matrix(elements2)
@@ -59,10 +60,10 @@ def test_matrix_multiply_recursive_by_copying(self, data):
5960
k = data.draw(integers(min_value=0, max_value=4), label="Matrices dimension exponent")
6061
n = 2 ** k
6162
elements1 = data.draw(
62-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
63+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
6364
label="First matrix elements")
6465
elements2 = data.draw(
65-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
66+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
6667
label="Second matrix elements")
6768
A = create_matrix(elements1)
6869
B = create_matrix(elements2)
@@ -78,10 +79,10 @@ def test_matrix_add_recursive(self, data):
7879
k = data.draw(integers(min_value=0, max_value=4), label="Matrices dimension exponent")
7980
n = 2 ** k
8081
elements1 = data.draw(
81-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
82+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
8283
label="First matrix elements")
8384
elements2 = data.draw(
84-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
85+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
8586
label="Second matrix elements")
8687
A = create_matrix(elements1)
8788
B = create_matrix(elements2)
@@ -97,10 +98,10 @@ def test_strassen(self, data):
9798
k = data.draw(integers(min_value=0, max_value=4), label="Matrices dimension exponent")
9899
n = 2 ** k
99100
elements1 = data.draw(
100-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
101+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
101102
label="First matrix elements")
102103
elements2 = data.draw(
103-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
104+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
104105
label="Second matrix elements")
105106
A = create_matrix(elements1)
106107
B = create_matrix(elements2)
@@ -126,10 +127,10 @@ def test_complex_multiply(self, data):
126127
def test_matrix_multiply_by_squaring(self, data):
127128
n = data.draw(integers(min_value=1, max_value=15), label="Matrices dimension")
128129
elements1 = data.draw(
129-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
130+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
130131
label="First matrix elements")
131132
elements2 = data.draw(
132-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
133+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
133134
label="Second matrix elements")
134135
A = create_matrix(elements1)
135136
B = create_matrix(elements2)
@@ -162,7 +163,7 @@ def test_monge_leftmost_minimums(self, data):
162163
m = data.draw(integers(min_value=1, max_value=15), label="Monge array row dimension")
163164
n = data.draw(integers(min_value=1, max_value=15), label="Monge array column dimension")
164165
elements = data.draw(
165-
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=m, max_size=m),
166+
lists(lists(floats(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=m, max_size=m),
166167
label="Monge array elements")
167168
A = create_matrix(elements)
168169
assume(is_monge_array(A, m, n))

test/test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def create_array(elements: list[T], start: int = 1) -> Array[T]:
2121
return array
2222

2323

24-
def create_matrix(elements: list[list[int]]) -> Matrix:
24+
def create_matrix(elements: list[list[float]]) -> Matrix:
2525
rows = len(elements)
2626
cols = len(elements[0])
2727
matrix = Matrix(rows, cols)

0 commit comments

Comments
 (0)