Skip to content

Commit 933f582

Browse files
committed
Adds matrix inverse operation
1 parent 4ab3265 commit 933f582

3 files changed

Lines changed: 58 additions & 2 deletions

File tree

sasdata/quantities/quantity.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ def trace(a: Union["Quantity[ArrayLike]", ArrayLike], offset: int = 0, axis1: in
5252
return np.trace(a, offset, axis1, axis2)
5353

5454

55+
def matinv(a: Union["Quantity[ArrayLike]", ArrayLike]):
56+
"""Find the inverse of a matrix."""
57+
if isinstance(a, Quantity):
58+
return DerivedQuantity(
59+
value=np.linalg.inv(a.value),
60+
units=a.units,
61+
history=QuantityHistory.apply_operation(MatInv, a.history),
62+
)
63+
else:
64+
return np.linalg.inv(a)
65+
66+
5567
def determinant(a: Union["Quantity[ArrayLike]", ArrayLike]):
5668
"""Find the determinant of an array or an array based quantity."""
5769
if isinstance(a, Quantity):
@@ -1180,6 +1192,32 @@ def _clean_ab(self, a, b):
11801192
return MatMul(a, b)
11811193

11821194

1195+
class MatInv(UnaryOperation):
1196+
"""Matrix inversion, using numpy"""
1197+
1198+
serialisation_name = "matinv"
1199+
1200+
def evaluate(self, variables: dict[int, T]) -> T:
1201+
return np.linalg.inv(self.a.evaluate(variables))
1202+
1203+
def _derivative(self, hash_value: int) -> Operation:
1204+
return Neg(Matmul(MatMul(MatInv(self.a), self.a._derivative(hash_value)), MatInv(self.a)))
1205+
1206+
def _clean(self):
1207+
clean_a = self.a._clean()
1208+
1209+
if isinstance(clean_a, MatInv):
1210+
# Removes double inversions
1211+
return clean_a.a
1212+
1213+
elif isinstance(clean_a, Determinant):
1214+
# Inverse of determinant is determinant of inverse
1215+
return Determinant(MatInv(clean_a.a))
1216+
1217+
else:
1218+
return MatInv(clean_a)
1219+
1220+
11831221
class TensorDot(Operation):
11841222
serialisation_name = "tensor_product"
11851223

@@ -1236,6 +1274,7 @@ def _deserialise(parameters: dict) -> "Operation":
12361274
Determinant,
12371275
Dot,
12381276
MatMul,
1277+
MatInv,
12391278
TensorDot,
12401279
]
12411280

test/quantities/utest_math_operations.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
from sasdata.quantities import units
7-
from sasdata.quantities.quantity import NamedQuantity, tensordot, trace, transpose
7+
from sasdata.quantities.quantity import NamedQuantity, matinv, tensordot, trace, transpose
88

99
order_list = [[0, 1, 2, 3], [0, 2, 1], [1, 0], [0, 1], [2, 0, 1], [3, 1, 2, 0]]
1010

@@ -57,6 +57,21 @@ def test_trace_axes(matrix, axis1, axis2, expected_trace):
5757
).all()
5858

5959

60+
@pytest.mark.parametrize(
61+
"matrix, expected_inverse",
62+
[
63+
(np.array([[1]]), np.array([[1]])),
64+
(np.array([[-2.0, 1.0], [1.5, -0.5]]), np.array([[1, 2], [3, 4]])),
65+
],
66+
)
67+
def test_inverse(matrix, expected_inverse):
68+
"""Check that the matinv operation correctly inverse for raw data and quantities."""
69+
print(matinv(matrix))
70+
print(expected_inverse)
71+
assert (matinv(matrix) == expected_inverse).all()
72+
assert (matinv(NamedQuantity("testmat", matrix, units=units.none)).value == expected_inverse).all()
73+
74+
6075
rng_seed = 1979
6176
tensor_product_with_identity_sizes = (4, 6, 5)
6277

test/quantities/utest_operations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Inv,
1919
Ln,
2020
Log,
21+
MatInv,
2122
MatMul,
2223
Mul,
2324
MultiplicativeIdentity,
@@ -48,7 +49,7 @@
4849
)
4950

5051

51-
@pytest.fixture(params=[Determinant, Inv, Exp, Ln, Neg, Sin, ArcSin, Cos, ArcCos, Tan, ArcTan, Transpose])
52+
@pytest.fixture(params=[Determinant, Inv, Exp, Ln, MatInv, Neg, Sin, ArcSin, Cos, ArcCos, Tan, ArcTan, Transpose])
5253
def unary_operation(request):
5354
return request.param(x)
5455

@@ -232,6 +233,7 @@ def test_derivative(op, result):
232233
[
233234
(Neg(Neg(x))),
234235
(Inv(Inv(x))),
236+
(MatInv(MatInv(x))),
235237
],
236238
)
237239
def test_clean_double_applications(op):

0 commit comments

Comments
 (0)