-
-
Notifications
You must be signed in to change notification settings - Fork 50.3k
Expand file tree
/
Copy pathmatrix_trace.py
More file actions
143 lines (109 loc) · 3.9 KB
/
matrix_trace.py
File metadata and controls
143 lines (109 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Matrix trace calculation.

The trace of a square matrix is the sum of the elements on its main diagonal,
tr(A) = a_11 + a_22 + ... + a_nn. The trace is unchanged by transposition
(tr(A^T) = tr(A)) and is linear (tr(c * A) = c * tr(A));
``trace_properties_demo`` checks both of these numerically.

Reference: https://en.wikipedia.org/wiki/Trace_(linear_algebra)
"""
import math

import numpy as np
from numpy import float64
from numpy.typing import NDArray
def trace(matrix: NDArray[float64]) -> float:
    """
    Calculate the trace of a square matrix.

    The trace is the sum of the elements on the main diagonal.

    Parameters:
        matrix (NDArray[float64]): A square 2-D matrix. Array-likes such as
            nested lists are accepted and converted with ``np.asarray``.

    Returns:
        float: The trace of the matrix (0.0 for an empty 0x0 matrix)

    Raises:
        ValueError: If the input is not two-dimensional or is not square

    Examples:
    >>> import numpy as np
    >>> matrix = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float)
    >>> trace(matrix)
    5.0
    >>> matrix = np.array(
    ...     [[2.0, -1.0, 3.0], [4.0, 5.0, -2.0], [1.0, 0.0, 7.0]], dtype=float
    ... )
    >>> trace(matrix)
    14.0
    >>> matrix = np.array([[5.0]], dtype=float)
    >>> trace(matrix)
    5.0
    >>> trace([[1, 2], [3, 4]])
    5.0
    >>> trace(np.array([1.0, 2.0]))
    Traceback (most recent call last):
        ...
    ValueError: Matrix must be two-dimensional
    >>> trace(np.ones((2, 3)))
    Traceback (most recent call last):
        ...
    ValueError: Matrix must be square
    """
    matrix = np.asarray(matrix)
    # Check the rank before touching shape[1]: a 1-D array has no second
    # axis (previously an IndexError), and only a 2-D array has a single
    # well-defined main diagonal.
    if matrix.ndim != 2:
        raise ValueError("Matrix must be two-dimensional")
    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError("Matrix must be square")
    return float(np.trace(matrix))
def trace_properties_demo(matrix: NDArray[float64], scalar: float = 2.0) -> dict:
    """
    Demonstrate various properties of the trace operation.

    Computes the trace of ``matrix``, of its transpose, of ``scalar * matrix``
    and of the identity matrix of the same size. It then checks numerically
    that tr(A) == tr(A^T) and tr(c * A) == c * tr(A).

    Parameters:
        matrix (NDArray[float64]): A square 2-D matrix. Array-likes such as
            nested lists are converted with ``np.asarray``.
        scalar (float): Factor used for the scalar-multiple check
            (default 2.0).

    Returns:
        dict: Dictionary with the float entries ``original_trace``,
        ``transpose_trace``, ``scalar_multiple_trace``, ``scalar_factor`` and
        ``identity_trace``, plus the bool entries ``trace_equals_transpose``
        and ``scalar_property_check``.

    Raises:
        ValueError: If ``matrix`` is not two-dimensional or is not square

    Examples:
    >>> import numpy as np
    >>> props = trace_properties_demo(np.array([[1.0, 2.0], [3.0, 4.0]]))
    >>> props["original_trace"], props["scalar_multiple_trace"]
    (5.0, 10.0)
    >>> props["trace_equals_transpose"], props["scalar_property_check"]
    (True, True)
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError("Matrix must be two-dimensional")
    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError("Matrix must be square")
    n = matrix.shape[0]

    def _close(a: float, b: float) -> bool:
        # Relative tolerance so large-magnitude traces (or scalars whose
        # product rounds) are not flagged as failures; the absolute floor
        # keeps the previous 1e-10 behaviour for values near zero.
        return math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-10)

    tr = trace(matrix)
    # The transpose has the same main diagonal, so its trace must match.
    tr_transpose = trace(matrix.T)
    # Trace is linear: tr(c * A) == c * tr(A).
    tr_scalar = trace(scalar * matrix)
    # The n x n identity has trace n.
    tr_identity = trace(np.eye(n, dtype=float64))
    return {
        "original_trace": tr,
        "transpose_trace": tr_transpose,
        "scalar_multiple_trace": tr_scalar,
        "scalar_factor": scalar,
        "identity_trace": tr_identity,
        "trace_equals_transpose": _close(tr, tr_transpose),
        "scalar_property_check": _close(tr_scalar, scalar * tr),
    }
def test_trace() -> None:
    """
    Self-test for ``trace`` and ``trace_properties_demo``.

    Each numeric comparison uses an absolute tolerance of 1e-10, and every
    assertion carries a message naming the case that failed.

    >>> test_trace()  # self running tests
    """

    def expect_trace(matrix: NDArray[float64], expected: float, message: str) -> None:
        # Compare trace(matrix) with a known value within an absolute 1e-10.
        assert abs(trace(matrix) - expected) < 1e-10, message

    # Cases 1-4: small dense, identity and zero matrices with known traces.
    expect_trace(
        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float),
        5.0,
        "2x2 trace calculation failed",
    )
    expect_trace(
        np.array(
            [[2.0, -1.0, 3.0], [4.0, 5.0, -2.0], [1.0, 0.0, 7.0]], dtype=float
        ),
        14.0,
        "3x3 trace calculation failed",
    )
    expect_trace(
        np.eye(4, dtype=float),
        4.0,
        "Identity matrix trace should equal dimension",
    )
    expect_trace(
        np.zeros((3, 3), dtype=float),
        0.0,
        "Zero matrix should have zero trace",
    )

    # Case 5: property flags reported for the 3x3 matrix holding 1.0 .. 9.0.
    properties = trace_properties_demo(np.arange(1.0, 10.0).reshape(3, 3))
    assert properties["trace_equals_transpose"], "Trace should equal transpose trace"
    assert properties["scalar_property_check"], "Scalar multiplication property failed"

    # Case 6: a diagonal matrix's trace is the plain sum of its diagonal.
    diagonal_entries = [1.0, 2.0, 3.0, 4.0]
    expect_trace(
        np.diag(diagonal_entries),
        sum(diagonal_entries),
        "Diagonal matrix trace should equal sum of diagonal elements",
    )
if __name__ == "__main__":
    from doctest import testmod

    # Run the docstring examples first, then the explicit self-test.
    testmod()
    test_trace()