Skip to content

Commit 734ccfb

Browse files
authored
Merge pull request #196 from SasView/refactor_24_linalg
Adds linear algebra operations
2 parents 37e2d7a + b72f6fa commit 734ccfb

3 files changed

Lines changed: 455 additions & 84 deletions

File tree

sasdata/quantities/quantity.py

Lines changed: 284 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,72 @@ def transpose(a: Union["Quantity[ArrayLike]", ArrayLike], axes: tuple | None = N
3939
return np.transpose(a, axes=axes)
4040

4141

42+
def trace(a: Union["Quantity[ArrayLike]", ArrayLike], offset: int = 0, axis1: int = 0, axis2: int = 1):
43+
"""Find the trace of an array or an array based quantity."""
44+
if isinstance(a, Quantity):
45+
return DerivedQuantity(
46+
value=np.trace(a.value, offset, axis1, axis2),
47+
units=a.units,
48+
history=QuantityHistory.apply_operation(Trace, a.history, offset=offset, axis1=axis1, axis2=axis2),
49+
)
50+
else:
51+
return np.trace(a, offset, axis1, axis2)
52+
53+
54+
def matinv(a: Union["Quantity[ArrayLike]", ArrayLike]):
55+
"""Find the inverse of a matrix."""
56+
if isinstance(a, Quantity):
57+
return DerivedQuantity(
58+
value=np.linalg.inv(a.value),
59+
units=a.units,
60+
history=QuantityHistory.apply_operation(MatInv, a.history),
61+
)
62+
else:
63+
return np.linalg.inv(a)
64+
65+
66+
def norm_1(a: Union["Quantity[ArrayLike]", ArrayLike], axes: int | tuple[int] | None = None):
67+
"""Caculate the 1-norm of an array or an array based quantity."""
68+
if isinstance(a, Quantity):
69+
if axes is None:
70+
return DerivedQuantity(
71+
value=np.linalg.norm(a.value, ord=1, axes=axes),
72+
units=a.units,
73+
history=QuantityHistory.apply_operation(Norm_1, a.history),
74+
)
75+
76+
else:
77+
return DerivedQuantity(
78+
value=np.linalg.norm(a.value, ord=1, axes=axes),
79+
units=a.units,
80+
history=QuantityHistory.apply_operation(Norm_1, a.history, axes=axes),
81+
)
82+
83+
else:
84+
return np.linalg.norm(a.value, ord=1, axes=axes)
85+
86+
87+
def norm_2(a: Union["Quantity[ArrayLike]", ArrayLike], axes: int | tuple[int] | None = None):
88+
"""Caculate the 2-norm of an array or an array based quantity."""
89+
if isinstance(a, Quantity):
90+
if axes is None:
91+
return DerivedQuantity(
92+
value=np.linalg.norm(a.value, axes=axes),
93+
units=a.units,
94+
history=QuantityHistory.apply_operation(Norm_2, a.history),
95+
)
96+
97+
else:
98+
return DerivedQuantity(
99+
value=np.linalg.norm(a.value, axes=axes),
100+
units=a.units,
101+
history=QuantityHistory.apply_operation(Norm_2, a.history, axes=axes),
102+
)
103+
104+
else:
105+
return np.linalg.norm(a.value, axes=axes)
106+
107+
42108
def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike]", ArrayLike]):
43109
"""Dot product of two arrays or two array based quantities"""
44110
a_is_quantity = isinstance(a, Quantity)
@@ -294,6 +360,34 @@ def __eq__(self, other):
294360
return False
295361

296362

363+
class MatrixIdentity(ConstantBase):
364+
serialisation_name = "identity"
365+
366+
def __init__(self, size):
367+
self.size = size
368+
369+
def evaluate(self, variables: dict[int, T]) -> T:
370+
return np.eye(self.size)
371+
372+
def _derivative(self, hash_value: int):
373+
return Constant(np.zeros((self.size, self.size)))
374+
375+
@staticmethod
376+
def _deserialise(parameters: dict) -> "Operation":
377+
return MatrixIdentity(self.size)
378+
379+
def _serialise_parameters(self) -> dict[str, Any]:
380+
return {"size": numerical_encode(self.size)}
381+
382+
def summary(self, indent_amount: int = 0, indent=" "):
383+
return f"{indent_amount * indent}{self.size} [Matrix Id.]"
384+
385+
def __eq__(self, other):
386+
if isinstance(other, MatrixIdentity):
387+
return self.size == other.size
388+
return False
389+
390+
297391
class Constant(ConstantBase):
298392
serialisation_name = "constant"
299393

@@ -999,14 +1093,14 @@ def __init__(self, a: Operation, axes: tuple[int] | None = None):
9991093
self.axes = axes
10001094

10011095
def evaluate(self, variables: dict[int, T]) -> T:
1002-
return np.transpose(self.a.evaluate(variables))
1096+
return np.transpose(self.a.evaluate(variables), self.axes)
10031097

10041098
def _derivative(self, hash_value: int) -> Operation:
1005-
return Transpose(self.a.derivative(hash_value)) # TODO: Check!
1099+
return Transpose(self.a.derivative(hash_value), self.axes) # TODO: Check!
10061100

10071101
def _clean(self):
10081102
clean_a = self.a._clean()
1009-
return Transpose(clean_a)
1103+
return Transpose(clean_a, self.axes)
10101104

10111105
def _serialise_parameters(self) -> dict[str, Any]:
10121106
if self.axes is None:
@@ -1034,7 +1128,7 @@ def summary(self, indent_amount: int = 0, indent=" "):
10341128
f"{indent_amount * indent}Transpose(\n"
10351129
+ self.a.summary(indent_amount + 1, indent)
10361130
+ "\n"
1037-
+ f"{(indent_amount + 1) * indent}{self.axes}\n"
1131+
+ f"{(indent_amount + 1) * indent}{list(self.axes)}\n"
10381132
+ f"{indent_amount * indent})"
10391133
)
10401134

@@ -1044,6 +1138,56 @@ def __eq__(self, other):
10441138
return False
10451139

10461140

1141+
class Trace(Operation):
1142+
"""Trace operation - as per numpy"""
1143+
1144+
serialisation_name = "trace"
1145+
1146+
def __init__(self, a: Operation, offset: int = 0, axis1: int = 0, axis2: int = 1):
1147+
self.a = a
1148+
self.offset = offset
1149+
self.axis1 = axis1
1150+
self.axis2 = axis2
1151+
1152+
def evaluate(self, variables: dict[int, T]) -> T:
1153+
return np.trace(self.a.evaluate(variables), self.offset, self.axis1, self.axis2)
1154+
1155+
def _derivative(self, hash_value: int) -> Operation:
1156+
return Trace(self.a.derivative(hash_value), self.offset, self.axis1, self.axis2)
1157+
1158+
def _clean(self):
1159+
clean_a = self.a._clean()
1160+
return Trace(clean_a, self.offset, self.axis1, self.axis2)
1161+
1162+
def _serialise_parameters(self) -> dict[str, Any]:
1163+
return {"a": self.a._serialise_json(), "offset": self.offset, "axis1": self.axis1, "axis2": self.axis2}
1164+
1165+
@staticmethod
1166+
def _deserialise(parameters: dict) -> "Operation":
1167+
return Trace(
1168+
a=Operation.deserialise_json(parameters["a"]),
1169+
offset=parameters["offset"],
1170+
axis1=parameters["axis1"],
1171+
axis2=parameters["axis2"],
1172+
)
1173+
1174+
def summary(self, indent_amount: int = 0, indent=" "):
1175+
return (
1176+
f"{indent_amount * indent}Trace(\n"
1177+
+ self.a.summary(indent_amount + 1, indent)
1178+
+ "\n"
1179+
+ f"{(indent_amount + 1) * indent}{self.offset}\n"
1180+
+ f"{(indent_amount + 1) * indent}{self.axis1}\n"
1181+
+ f"{(indent_amount + 1) * indent}{self.axis2}\n"
1182+
+ f"{indent_amount * indent})"
1183+
)
1184+
1185+
def __eq__(self, other):
1186+
if isinstance(other, Trace):
1187+
return other.a == self.a
1188+
return False
1189+
1190+
10471191
class Dot(BinaryOperation):
10481192
"""Dot product - backed by numpy's dot method"""
10491193

@@ -1089,6 +1233,28 @@ def _clean_ab(self, a, b):
10891233
return MatMul(a, b)
10901234

10911235

1236+
class MatInv(UnaryOperation):
1237+
"""Matrix inversion, using numpy"""
1238+
1239+
serialisation_name = "matinv"
1240+
1241+
def evaluate(self, variables: dict[int, T]) -> T:
1242+
return np.linalg.inv(self.a.evaluate(variables))
1243+
1244+
def _derivative(self, hash_value: int) -> Operation:
1245+
return Neg(Matmul(MatMul(MatInv(self.a), self.a._derivative(hash_value)), MatInv(self.a)))
1246+
1247+
def _clean(self):
1248+
clean_a = self.a._clean()
1249+
1250+
if isinstance(clean_a, MatInv):
1251+
# Removes double inversions
1252+
return clean_a.a
1253+
1254+
else:
1255+
return MatInv(clean_a)
1256+
1257+
10921258
class TensorDot(Operation):
10931259
serialisation_name = "tensor_product"
10941260

@@ -1119,6 +1285,116 @@ def _deserialise(parameters: dict) -> "Operation":
11191285
)
11201286

11211287

1288+
class Norm_1(Operation):
1289+
"""1-norm of a matrix from numpy"""
1290+
1291+
serialisation_name = "norm_1"
1292+
1293+
def __init__(self, a: Operation, axes: int | tuple[int] | None = None):
1294+
self.a = a
1295+
self.axes = axes
1296+
1297+
def evaluate(self, variables: dict[int, T]) -> T:
1298+
return np.linalg.norm(self.a.evaluate(variables), ord=1, axis=self.axes)
1299+
1300+
def _derivative(self, hash_value: int) -> Operation:
1301+
return np.sign(self.a)
1302+
1303+
def _clean(self):
1304+
clean_a = self.a._clean()
1305+
return Norm_1(clean_a, self.axes)
1306+
1307+
def _serialise_parameters(self) -> dict[str, Any]:
1308+
if self.axes is None:
1309+
return {"a": self.a._serialise_json()}
1310+
else:
1311+
return {"a": self.a._serialise_json(), "axes": list(self.axes)}
1312+
1313+
@staticmethod
1314+
def _deserialise(parameters: dict) -> "Operation":
1315+
if "axes" in parameters:
1316+
return Norm_1(a=Operation.deserialise_json(parameters["a"]), axes=tuple(parameters["axes"]))
1317+
else:
1318+
return Norm_1(a=Operation.deserialise_json(parameters["a"]))
1319+
1320+
def summary(self, indent_amount: int = 0, indent=" "):
1321+
if self.axes is None:
1322+
return (
1323+
f"{indent_amount * indent}Norm_1(\n"
1324+
+ self.a.summary(indent_amount + 1, indent)
1325+
+ "\n"
1326+
+ f"{indent_amount * indent})"
1327+
)
1328+
else:
1329+
return (
1330+
f"{indent_amount * indent}Norm_1(\n"
1331+
+ self.a.summary(indent_amount + 1, indent)
1332+
+ "\n"
1333+
+ f"{(indent_amount + 1) * indent}{list(self.axes)}\n"
1334+
+ f"{indent_amount * indent})"
1335+
)
1336+
1337+
def __eq__(self, other):
1338+
if isinstance(other, Norm_1):
1339+
return other.a == self.a
1340+
return False
1341+
1342+
1343+
class Norm_2(Operation):
1344+
"""2-norm of a matrix from numpy"""
1345+
1346+
serialisation_name = "norm_2"
1347+
1348+
def __init__(self, a: Operation, axes: int | tuple[int] | None = None):
1349+
self.a = a
1350+
self.axes = axes
1351+
1352+
def evaluate(self, variables: dict[int, T]) -> T:
1353+
return np.linalg.norm(self.a.evaluate(variables), axis=self.axes)
1354+
1355+
def _derivative(self, hash_value: int) -> Operation:
1356+
return Transpose(Div(self.a, Norm_2(self.a, self.axes)))
1357+
1358+
def _clean(self):
1359+
clean_a = self.a._clean()
1360+
return Norm_2(clean_a, self.axes)
1361+
1362+
def _serialise_parameters(self) -> dict[str, Any]:
1363+
if self.axes is None:
1364+
return {"a": self.a._serialise_json()}
1365+
else:
1366+
return {"a": self.a._serialise_json(), "axes": list(self.axes)}
1367+
1368+
@staticmethod
1369+
def _deserialise(parameters: dict) -> "Operation":
1370+
if "axes" in parameters:
1371+
return Norm_2(a=Operation.deserialise_json(parameters["a"]), axes=tuple(parameters["axes"]))
1372+
else:
1373+
return Norm_2(a=Operation.deserialise_json(parameters["a"]))
1374+
1375+
def summary(self, indent_amount: int = 0, indent=" "):
1376+
if self.axes is None:
1377+
return (
1378+
f"{indent_amount * indent}Norm_2(\n"
1379+
+ self.a.summary(indent_amount + 1, indent)
1380+
+ "\n"
1381+
+ f"{indent_amount * indent})"
1382+
)
1383+
else:
1384+
return (
1385+
f"{indent_amount * indent}Norm_2(\n"
1386+
+ self.a.summary(indent_amount + 1, indent)
1387+
+ "\n"
1388+
+ f"{(indent_amount + 1) * indent}{list(self.axes)}\n"
1389+
+ f"{indent_amount * indent})"
1390+
)
1391+
1392+
def __eq__(self, other):
1393+
if isinstance(other, Norm_2):
1394+
return other.a == self.a
1395+
return False
1396+
1397+
11221398
_serialisable_classes = [
11231399
AdditiveIdentity,
11241400
MultiplicativeIdentity,
@@ -1141,9 +1417,13 @@ def _deserialise(parameters: dict) -> "Operation":
11411417
Pow,
11421418
Log,
11431419
Transpose,
1420+
Trace,
11441421
Dot,
11451422
MatMul,
1423+
MatInv,
11461424
TensorDot,
1425+
Norm_1,
1426+
Norm_2,
11471427
]
11481428

11491429
_serialisation_lookup = {class_.serialisation_name: class_ for class_ in _serialisable_classes}

0 commit comments

Comments
 (0)