Skip to content

Commit 51bffaf

Browse files
committed
Adds norm operations
1 parent 933f582 commit 51bffaf

2 files changed

Lines changed: 248 additions & 3 deletions

File tree

sasdata/quantities/quantity.py

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,48 @@ def determinant(a: Union["Quantity[ArrayLike]", ArrayLike]):
7676
return np.linalg.det(a)
7777

7878

79+
def norm_1(a: Union["Quantity[ArrayLike]", ArrayLike], axes: int | tuple[int] | None = None):
80+
"""Caculate the 1-norm of an array or an array based quantity."""
81+
if isinstance(a, Quantity):
82+
if axes is None:
83+
return DerivedQuantity(
84+
value=np.linalg.norm(a.value, ord=1, axes=axes),
85+
units=a.units,
86+
history=QuantityHistory.apply_operation(Norm_1, a.history),
87+
)
88+
89+
else:
90+
return DerivedQuantity(
91+
value=np.linalg.norm(a.value, ord=1, axes=axes),
92+
units=a.units,
93+
history=QuantityHistory.apply_operation(Norm_1, a.history, axes=axes),
94+
)
95+
96+
else:
97+
return np.linalg.norm(a.value, ord=1, axes=axes)
98+
99+
100+
def norm_2(a: Union["Quantity[ArrayLike]", ArrayLike], axes: int | tuple[int] | None = None):
101+
"""Caculate the 2-norm of an array or an array based quantity."""
102+
if isinstance(a, Quantity):
103+
if axes is None:
104+
return DerivedQuantity(
105+
value=np.linalg.norm(a.value, axes=axes),
106+
units=a.units,
107+
history=QuantityHistory.apply_operation(Norm_2, a.history),
108+
)
109+
110+
else:
111+
return DerivedQuantity(
112+
value=np.linalg.norm(a.value, axes=axes),
113+
units=a.units,
114+
history=QuantityHistory.apply_operation(Norm_2, a.history, axes=axes),
115+
)
116+
117+
else:
118+
return np.linalg.norm(a.value, axes=axes)
119+
120+
79121
def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike]", ArrayLike]):
80122
"""Dot product of two arrays or two array based quantities"""
81123
a_is_quantity = isinstance(a, Quantity)
@@ -331,6 +373,34 @@ def __eq__(self, other):
331373
return False
332374

333375

376+
class MatrixIdentity(ConstantBase):
377+
serialisation_name = "identity"
378+
379+
def __init__(self, size):
380+
self.size = size
381+
382+
def evaluate(self, variables: dict[int, T]) -> T:
383+
return np.eye(self.size)
384+
385+
def _derivative(self, hash_value: int):
386+
return Constant(np.zeros((self.size, self.size)))
387+
388+
@staticmethod
389+
def _deserialise(parameters: dict) -> "Operation":
390+
return MatrixIdentity(self.size)
391+
392+
def _serialise_parameters(self) -> dict[str, Any]:
393+
return {"size": numerical_encode(self.size)}
394+
395+
def summary(self, indent_amount: int = 0, indent=" "):
396+
return f"{indent_amount * indent}{self.size} [Matrix Id.]"
397+
398+
def __eq__(self, other):
399+
if isinstance(other, MatrixIdentity):
400+
return self.size == other.size
401+
return False
402+
403+
334404
class Constant(ConstantBase):
335405
serialisation_name = "constant"
336406

@@ -1071,7 +1141,7 @@ def summary(self, indent_amount: int = 0, indent=" "):
10711141
f"{indent_amount * indent}Transpose(\n"
10721142
+ self.a.summary(indent_amount + 1, indent)
10731143
+ "\n"
1074-
+ f"{(indent_amount + 1) * indent}{self.axes}\n"
1144+
+ f"{(indent_amount + 1) * indent}{list(self.axes)}\n"
10751145
+ f"{indent_amount * indent})"
10761146
)
10771147

@@ -1248,6 +1318,116 @@ def _deserialise(parameters: dict) -> "Operation":
12481318
)
12491319

12501320

1321+
class Norm_1(Operation):
1322+
"""1-norm of a matrix from numpy"""
1323+
1324+
serialisation_name = "norm_1"
1325+
1326+
def __init__(self, a: Operation, axes: int | tuple[int] | None = None):
1327+
self.a = a
1328+
self.axes = axes
1329+
1330+
def evaluate(self, variables: dict[int, T]) -> T:
1331+
return np.linalg.norm(self.a.evaluate(variables), ord=1, axis=self.axes)
1332+
1333+
def _derivative(self, hash_value: int) -> Operation:
1334+
return np.sign(self.a)
1335+
1336+
def _clean(self):
1337+
clean_a = self.a._clean()
1338+
return Norm_1(clean_a, self.axes)
1339+
1340+
def _serialise_parameters(self) -> dict[str, Any]:
1341+
if self.axes is None:
1342+
return {"a": self.a._serialise_json()}
1343+
else:
1344+
return {"a": self.a._serialise_json(), "axes": list(self.axes)}
1345+
1346+
@staticmethod
1347+
def _deserialise(parameters: dict) -> "Operation":
1348+
if "axes" in parameters:
1349+
return Norm_1(a=Operation.deserialise_json(parameters["a"]), axes=tuple(parameters["axes"]))
1350+
else:
1351+
return Norm_1(a=Operation.deserialise_json(parameters["a"]))
1352+
1353+
def summary(self, indent_amount: int = 0, indent=" "):
1354+
if self.axes is None:
1355+
return (
1356+
f"{indent_amount * indent}Norm_1(\n"
1357+
+ self.a.summary(indent_amount + 1, indent)
1358+
+ "\n"
1359+
+ f"{indent_amount * indent})"
1360+
)
1361+
else:
1362+
return (
1363+
f"{indent_amount * indent}Norm_1(\n"
1364+
+ self.a.summary(indent_amount + 1, indent)
1365+
+ "\n"
1366+
+ f"{(indent_amount + 1) * indent}{list(self.axes)}\n"
1367+
+ f"{indent_amount * indent})"
1368+
)
1369+
1370+
def __eq__(self, other):
1371+
if isinstance(other, Norm_1):
1372+
return other.a == self.a
1373+
return False
1374+
1375+
1376+
class Norm_2(Operation):
1377+
"""2-norm of a matrix from numpy"""
1378+
1379+
serialisation_name = "norm_2"
1380+
1381+
def __init__(self, a: Operation, axes: int | tuple[int] | None = None):
1382+
self.a = a
1383+
self.axes = axes
1384+
1385+
def evaluate(self, variables: dict[int, T]) -> T:
1386+
return np.linalg.norm(self.a.evaluate(variables), axis=self.axes)
1387+
1388+
def _derivative(self, hash_value: int) -> Operation:
1389+
return Transpose(Div(self.a, Norm_2(self.a, self.axes)))
1390+
1391+
def _clean(self):
1392+
clean_a = self.a._clean()
1393+
return Norm_2(clean_a, self.axes)
1394+
1395+
def _serialise_parameters(self) -> dict[str, Any]:
1396+
if self.axes is None:
1397+
return {"a": self.a._serialise_json()}
1398+
else:
1399+
return {"a": self.a._serialise_json(), "axes": list(self.axes)}
1400+
1401+
@staticmethod
1402+
def _deserialise(parameters: dict) -> "Operation":
1403+
if "axes" in parameters:
1404+
return Norm_2(a=Operation.deserialise_json(parameters["a"]), axes=tuple(parameters["axes"]))
1405+
else:
1406+
return Norm_2(a=Operation.deserialise_json(parameters["a"]))
1407+
1408+
def summary(self, indent_amount: int = 0, indent=" "):
1409+
if self.axes is None:
1410+
return (
1411+
f"{indent_amount * indent}Norm_2(\n"
1412+
+ self.a.summary(indent_amount + 1, indent)
1413+
+ "\n"
1414+
+ f"{indent_amount * indent})"
1415+
)
1416+
else:
1417+
return (
1418+
f"{indent_amount * indent}Norm_2(\n"
1419+
+ self.a.summary(indent_amount + 1, indent)
1420+
+ "\n"
1421+
+ f"{(indent_amount + 1) * indent}{list(self.axes)}\n"
1422+
+ f"{indent_amount * indent})"
1423+
)
1424+
1425+
def __eq__(self, other):
1426+
if isinstance(other, Norm_2):
1427+
return other.a == self.a
1428+
return False
1429+
1430+
12511431
_serialisable_classes = [
12521432
AdditiveIdentity,
12531433
MultiplicativeIdentity,
@@ -1276,6 +1456,8 @@ def _deserialise(parameters: dict) -> "Operation":
12761456
MatMul,
12771457
MatInv,
12781458
TensorDot,
1459+
Norm_1,
1460+
Norm_2,
12791461
]
12801462

12811463
_serialisation_lookup = {class_.serialisation_name: class_ for class_ in _serialisable_classes}

test/quantities/utest_operations.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
Log,
2121
MatInv,
2222
MatMul,
23+
MatrixIdentity,
2324
Mul,
2425
MultiplicativeIdentity,
2526
Neg,
27+
Norm_1,
28+
Norm_2,
2629
Operation,
2730
Pow,
2831
Sin,
@@ -64,6 +67,16 @@ def log_pow_operation(request):
6467
return request.param(x, 2)
6568

6669

70+
@pytest.fixture(params=[Transpose, Norm_1, Norm_2])
71+
def axis_operation(request):
72+
return request.param(x, (0,))
73+
74+
75+
@pytest.fixture(params=[Transpose, Norm_1, Norm_2])
76+
def axis_none_operation(request):
77+
return request.param(x, None)
78+
79+
6780
def test_serialise_deserialise():
6881
serialised = operation_with_everything.serialise()
6982
deserialised = Operation.deserialise(serialised)
@@ -96,6 +109,22 @@ def test_log_pow_serialise_deserialise(log_pow_operation):
96109
assert serialised == reserialised
97110

98111

112+
def test_axis_serialise_deserialise(axis_operation):
113+
serialised = axis_operation.serialise()
114+
deserialised = Operation.deserialise(serialised)
115+
reserialised = deserialised.serialise()
116+
117+
assert serialised == reserialised
118+
119+
120+
def test_axis_none_serialise_deserialise(axis_none_operation):
121+
serialised = axis_none_operation.serialise()
122+
deserialised = Operation.deserialise(serialised)
123+
reserialised = deserialised.serialise()
124+
125+
assert serialised == reserialised
126+
127+
99128
def test_trace_serialise_deserialise():
100129
serialised = Trace(x).serialise()
101130
deserialised = Operation.deserialise(serialised)
@@ -113,6 +142,11 @@ def test_summary(op, summary):
113142
assert f.summary() == summary
114143

115144

145+
def test_matrix_id_summary():
146+
f = MatrixIdentity(1)
147+
assert f.summary() == "1 [Matrix Id.]"
148+
149+
116150
def test_variable_summary():
117151
assert x.summary() == "x"
118152

@@ -129,6 +163,14 @@ def test_log_pow_summary(log_pow_operation):
129163
assert log_pow_operation.summary() == f"{log_pow_operation.__class__.__name__}(\n x\n 2\n)"
130164

131165

166+
def test_axis_summary(axis_operation):
167+
assert axis_operation.summary() == f"{axis_operation.__class__.__name__}(\n x\n [0]\n)"
168+
169+
170+
def test_axis_none_summary(axis_none_operation):
171+
assert axis_none_operation.summary() == f"{axis_none_operation.__class__.__name__}(\n x\n)"
172+
173+
132174
def test_trace_summary():
133175
op = Trace(x)
134176
assert op.summary() == f"{op.__class__.__name__}(\n x\n 0\n 0\n 1\n)"
@@ -212,10 +254,31 @@ def test_matmul_evaluation(op, a, b, result):
212254

213255
@pytest.mark.parametrize(
214256
"op, a, result",
215-
[(Transpose, np.array([[1, 2]]), np.array([[1], [2]])), (Transpose, [[1, 2], [3, 4]], [[1, 3], [2, 4]])],
257+
[
258+
(Transpose, np.array([[1, 2]]), np.array([[1], [2]])),
259+
(Transpose, [[1, 2], [3, 4]], [[1, 3], [2, 4]]),
260+
(Norm_1, [[1, 2], [3, 4]], np.float64(6.0)),
261+
(Norm_2, [[1, 2], [3, 4]], np.float64(np.sqrt(30.0))),
262+
(Norm_2, [[1.0, 2.5], [3.0, 4.0]], np.float64(np.sqrt(32.25))),
263+
],
216264
)
217-
def test_transpose_evaluation(op, a, result):
265+
def test_axis_none_evaluation(op, a, result):
218266
f = op(Constant(a))
267+
print(f.evaluate({}))
268+
print(result)
269+
assert (f.evaluate({}) == result).all()
270+
271+
272+
@pytest.mark.parametrize(
273+
"op, a, axes, result",
274+
[
275+
(Transpose, [[1, 2], [3, 4]], (1, 0), [[1, 3], [2, 4]]),
276+
(Norm_1, np.array([[1, 2], [3, 4]]), 1, np.array([3.0, 7.0])),
277+
(Norm_2, np.array([[1, 2], [3, 4]]), 1, np.array([np.sqrt(5.0), 5.0])),
278+
],
279+
)
280+
def test_axis_evaluation(op, a, axes, result):
281+
f = op(Constant(a), axes=axes)
219282
assert (f.evaluate({}) == result).all()
220283

221284

0 commit comments

Comments
 (0)