@@ -39,6 +39,72 @@ def transpose(a: Union["Quantity[ArrayLike]", ArrayLike], axes: tuple | None = N
3939 return np .transpose (a , axes = axes )
4040
4141
42+ def trace (a : Union ["Quantity[ArrayLike]" , ArrayLike ], offset : int = 0 , axis1 : int = 0 , axis2 : int = 1 ):
43+ """Find the trace of an array or an array based quantity."""
44+ if isinstance (a , Quantity ):
45+ return DerivedQuantity (
46+ value = np .trace (a .value , offset , axis1 , axis2 ),
47+ units = a .units ,
48+ history = QuantityHistory .apply_operation (Trace , a .history , offset = offset , axis1 = axis1 , axis2 = axis2 ),
49+ )
50+ else :
51+ return np .trace (a , offset , axis1 , axis2 )
52+
53+
54+ def matinv (a : Union ["Quantity[ArrayLike]" , ArrayLike ]):
55+ """Find the inverse of a matrix."""
56+ if isinstance (a , Quantity ):
57+ return DerivedQuantity (
58+ value = np .linalg .inv (a .value ),
59+ units = a .units ,
60+ history = QuantityHistory .apply_operation (MatInv , a .history ),
61+ )
62+ else :
63+ return np .linalg .inv (a )
64+
65+
66+ def norm_1 (a : Union ["Quantity[ArrayLike]" , ArrayLike ], axes : int | tuple [int ] | None = None ):
67+ """Caculate the 1-norm of an array or an array based quantity."""
68+ if isinstance (a , Quantity ):
69+ if axes is None :
70+ return DerivedQuantity (
71+ value = np .linalg .norm (a .value , ord = 1 , axes = axes ),
72+ units = a .units ,
73+ history = QuantityHistory .apply_operation (Norm_1 , a .history ),
74+ )
75+
76+ else :
77+ return DerivedQuantity (
78+ value = np .linalg .norm (a .value , ord = 1 , axes = axes ),
79+ units = a .units ,
80+ history = QuantityHistory .apply_operation (Norm_1 , a .history , axes = axes ),
81+ )
82+
83+ else :
84+ return np .linalg .norm (a .value , ord = 1 , axes = axes )
85+
86+
87+ def norm_2 (a : Union ["Quantity[ArrayLike]" , ArrayLike ], axes : int | tuple [int ] | None = None ):
88+ """Caculate the 2-norm of an array or an array based quantity."""
89+ if isinstance (a , Quantity ):
90+ if axes is None :
91+ return DerivedQuantity (
92+ value = np .linalg .norm (a .value , axes = axes ),
93+ units = a .units ,
94+ history = QuantityHistory .apply_operation (Norm_2 , a .history ),
95+ )
96+
97+ else :
98+ return DerivedQuantity (
99+ value = np .linalg .norm (a .value , axes = axes ),
100+ units = a .units ,
101+ history = QuantityHistory .apply_operation (Norm_2 , a .history , axes = axes ),
102+ )
103+
104+ else :
105+ return np .linalg .norm (a .value , axes = axes )
106+
107+
42108def dot (a : Union ["Quantity[ArrayLike]" , ArrayLike ], b : Union ["Quantity[ArrayLike]" , ArrayLike ]):
43109 """Dot product of two arrays or two array based quantities"""
44110 a_is_quantity = isinstance (a , Quantity )
@@ -294,6 +360,34 @@ def __eq__(self, other):
294360 return False
295361
296362
363+ class MatrixIdentity (ConstantBase ):
364+ serialisation_name = "identity"
365+
366+ def __init__ (self , size ):
367+ self .size = size
368+
369+ def evaluate (self , variables : dict [int , T ]) -> T :
370+ return np .eye (self .size )
371+
372+ def _derivative (self , hash_value : int ):
373+ return Constant (np .zeros ((self .size , self .size )))
374+
375+ @staticmethod
376+ def _deserialise (parameters : dict ) -> "Operation" :
377+ return MatrixIdentity (self .size )
378+
379+ def _serialise_parameters (self ) -> dict [str , Any ]:
380+ return {"size" : numerical_encode (self .size )}
381+
382+ def summary (self , indent_amount : int = 0 , indent = " " ):
383+ return f"{ indent_amount * indent } { self .size } [Matrix Id.]"
384+
385+ def __eq__ (self , other ):
386+ if isinstance (other , MatrixIdentity ):
387+ return self .size == other .size
388+ return False
389+
390+
297391class Constant (ConstantBase ):
298392 serialisation_name = "constant"
299393
@@ -999,14 +1093,14 @@ def __init__(self, a: Operation, axes: tuple[int] | None = None):
9991093 self .axes = axes
10001094
10011095 def evaluate (self , variables : dict [int , T ]) -> T :
1002- return np .transpose (self .a .evaluate (variables ))
1096+ return np .transpose (self .a .evaluate (variables ), self . axes )
10031097
10041098 def _derivative (self , hash_value : int ) -> Operation :
1005- return Transpose (self .a .derivative (hash_value )) # TODO: Check!
1099+ return Transpose (self .a .derivative (hash_value ), self . axes ) # TODO: Check!
10061100
10071101 def _clean (self ):
10081102 clean_a = self .a ._clean ()
1009- return Transpose (clean_a )
1103+ return Transpose (clean_a , self . axes )
10101104
10111105 def _serialise_parameters (self ) -> dict [str , Any ]:
10121106 if self .axes is None :
@@ -1034,7 +1128,7 @@ def summary(self, indent_amount: int = 0, indent=" "):
10341128 f"{ indent_amount * indent } Transpose(\n "
10351129 + self .a .summary (indent_amount + 1 , indent )
10361130 + "\n "
1037- + f"{ (indent_amount + 1 ) * indent } { self .axes } \n "
1131+ + f"{ (indent_amount + 1 ) * indent } { list ( self .axes ) } \n "
10381132 + f"{ indent_amount * indent } )"
10391133 )
10401134
@@ -1044,6 +1138,56 @@ def __eq__(self, other):
10441138 return False
10451139
10461140
1141+ class Trace (Operation ):
1142+ """Trace operation - as per numpy"""
1143+
1144+ serialisation_name = "trace"
1145+
1146+ def __init__ (self , a : Operation , offset : int = 0 , axis1 : int = 0 , axis2 : int = 1 ):
1147+ self .a = a
1148+ self .offset = offset
1149+ self .axis1 = axis1
1150+ self .axis2 = axis2
1151+
1152+ def evaluate (self , variables : dict [int , T ]) -> T :
1153+ return np .trace (self .a .evaluate (variables ), self .offset , self .axis1 , self .axis2 )
1154+
1155+ def _derivative (self , hash_value : int ) -> Operation :
1156+ return Trace (self .a .derivative (hash_value ), self .offset , self .axis1 , self .axis2 )
1157+
1158+ def _clean (self ):
1159+ clean_a = self .a ._clean ()
1160+ return Trace (clean_a , self .offset , self .axis1 , self .axis2 )
1161+
1162+ def _serialise_parameters (self ) -> dict [str , Any ]:
1163+ return {"a" : self .a ._serialise_json (), "offset" : self .offset , "axis1" : self .axis1 , "axis2" : self .axis2 }
1164+
1165+ @staticmethod
1166+ def _deserialise (parameters : dict ) -> "Operation" :
1167+ return Trace (
1168+ a = Operation .deserialise_json (parameters ["a" ]),
1169+ offset = parameters ["offset" ],
1170+ axis1 = parameters ["axis1" ],
1171+ axis2 = parameters ["axis2" ],
1172+ )
1173+
1174+ def summary (self , indent_amount : int = 0 , indent = " " ):
1175+ return (
1176+ f"{ indent_amount * indent } Trace(\n "
1177+ + self .a .summary (indent_amount + 1 , indent )
1178+ + "\n "
1179+ + f"{ (indent_amount + 1 ) * indent } { self .offset } \n "
1180+ + f"{ (indent_amount + 1 ) * indent } { self .axis1 } \n "
1181+ + f"{ (indent_amount + 1 ) * indent } { self .axis2 } \n "
1182+ + f"{ indent_amount * indent } )"
1183+ )
1184+
1185+ def __eq__ (self , other ):
1186+ if isinstance (other , Trace ):
1187+ return other .a == self .a
1188+ return False
1189+
1190+
10471191class Dot (BinaryOperation ):
10481192 """Dot product - backed by numpy's dot method"""
10491193
@@ -1089,6 +1233,28 @@ def _clean_ab(self, a, b):
10891233 return MatMul (a , b )
10901234
10911235
1236+ class MatInv (UnaryOperation ):
1237+ """Matrix inversion, using numpy"""
1238+
1239+ serialisation_name = "matinv"
1240+
1241+ def evaluate (self , variables : dict [int , T ]) -> T :
1242+ return np .linalg .inv (self .a .evaluate (variables ))
1243+
1244+ def _derivative (self , hash_value : int ) -> Operation :
1245+ return Neg (Matmul (MatMul (MatInv (self .a ), self .a ._derivative (hash_value )), MatInv (self .a )))
1246+
1247+ def _clean (self ):
1248+ clean_a = self .a ._clean ()
1249+
1250+ if isinstance (clean_a , MatInv ):
1251+ # Removes double inversions
1252+ return clean_a .a
1253+
1254+ else :
1255+ return MatInv (clean_a )
1256+
1257+
10921258class TensorDot (Operation ):
10931259 serialisation_name = "tensor_product"
10941260
@@ -1119,6 +1285,116 @@ def _deserialise(parameters: dict) -> "Operation":
11191285 )
11201286
11211287
1288+ class Norm_1 (Operation ):
1289+ """1-norm of a matrix from numpy"""
1290+
1291+ serialisation_name = "norm_1"
1292+
1293+ def __init__ (self , a : Operation , axes : int | tuple [int ] | None = None ):
1294+ self .a = a
1295+ self .axes = axes
1296+
1297+ def evaluate (self , variables : dict [int , T ]) -> T :
1298+ return np .linalg .norm (self .a .evaluate (variables ), ord = 1 , axis = self .axes )
1299+
1300+ def _derivative (self , hash_value : int ) -> Operation :
1301+ return np .sign (self .a )
1302+
1303+ def _clean (self ):
1304+ clean_a = self .a ._clean ()
1305+ return Norm_1 (clean_a , self .axes )
1306+
1307+ def _serialise_parameters (self ) -> dict [str , Any ]:
1308+ if self .axes is None :
1309+ return {"a" : self .a ._serialise_json ()}
1310+ else :
1311+ return {"a" : self .a ._serialise_json (), "axes" : list (self .axes )}
1312+
1313+ @staticmethod
1314+ def _deserialise (parameters : dict ) -> "Operation" :
1315+ if "axes" in parameters :
1316+ return Norm_1 (a = Operation .deserialise_json (parameters ["a" ]), axes = tuple (parameters ["axes" ]))
1317+ else :
1318+ return Norm_1 (a = Operation .deserialise_json (parameters ["a" ]))
1319+
1320+ def summary (self , indent_amount : int = 0 , indent = " " ):
1321+ if self .axes is None :
1322+ return (
1323+ f"{ indent_amount * indent } Norm_1(\n "
1324+ + self .a .summary (indent_amount + 1 , indent )
1325+ + "\n "
1326+ + f"{ indent_amount * indent } )"
1327+ )
1328+ else :
1329+ return (
1330+ f"{ indent_amount * indent } Norm_1(\n "
1331+ + self .a .summary (indent_amount + 1 , indent )
1332+ + "\n "
1333+ + f"{ (indent_amount + 1 ) * indent } { list (self .axes )} \n "
1334+ + f"{ indent_amount * indent } )"
1335+ )
1336+
1337+ def __eq__ (self , other ):
1338+ if isinstance (other , Norm_1 ):
1339+ return other .a == self .a
1340+ return False
1341+
1342+
1343+ class Norm_2 (Operation ):
1344+ """2-norm of a matrix from numpy"""
1345+
1346+ serialisation_name = "norm_2"
1347+
1348+ def __init__ (self , a : Operation , axes : int | tuple [int ] | None = None ):
1349+ self .a = a
1350+ self .axes = axes
1351+
1352+ def evaluate (self , variables : dict [int , T ]) -> T :
1353+ return np .linalg .norm (self .a .evaluate (variables ), axis = self .axes )
1354+
1355+ def _derivative (self , hash_value : int ) -> Operation :
1356+ return Transpose (Div (self .a , Norm_2 (self .a , self .axes )))
1357+
1358+ def _clean (self ):
1359+ clean_a = self .a ._clean ()
1360+ return Norm_2 (clean_a , self .axes )
1361+
1362+ def _serialise_parameters (self ) -> dict [str , Any ]:
1363+ if self .axes is None :
1364+ return {"a" : self .a ._serialise_json ()}
1365+ else :
1366+ return {"a" : self .a ._serialise_json (), "axes" : list (self .axes )}
1367+
1368+ @staticmethod
1369+ def _deserialise (parameters : dict ) -> "Operation" :
1370+ if "axes" in parameters :
1371+ return Norm_2 (a = Operation .deserialise_json (parameters ["a" ]), axes = tuple (parameters ["axes" ]))
1372+ else :
1373+ return Norm_2 (a = Operation .deserialise_json (parameters ["a" ]))
1374+
1375+ def summary (self , indent_amount : int = 0 , indent = " " ):
1376+ if self .axes is None :
1377+ return (
1378+ f"{ indent_amount * indent } Norm_2(\n "
1379+ + self .a .summary (indent_amount + 1 , indent )
1380+ + "\n "
1381+ + f"{ indent_amount * indent } )"
1382+ )
1383+ else :
1384+ return (
1385+ f"{ indent_amount * indent } Norm_2(\n "
1386+ + self .a .summary (indent_amount + 1 , indent )
1387+ + "\n "
1388+ + f"{ (indent_amount + 1 ) * indent } { list (self .axes )} \n "
1389+ + f"{ indent_amount * indent } )"
1390+ )
1391+
1392+ def __eq__ (self , other ):
1393+ if isinstance (other , Norm_2 ):
1394+ return other .a == self .a
1395+ return False
1396+
1397+
11221398_serialisable_classes = [
11231399 AdditiveIdentity ,
11241400 MultiplicativeIdentity ,
@@ -1141,9 +1417,13 @@ def _deserialise(parameters: dict) -> "Operation":
11411417 Pow ,
11421418 Log ,
11431419 Transpose ,
1420+ Trace ,
11441421 Dot ,
11451422 MatMul ,
1423+ MatInv ,
11461424 TensorDot ,
1425+ Norm_1 ,
1426+ Norm_2 ,
11471427]
11481428
11491429_serialisation_lookup = {class_ .serialisation_name : class_ for class_ in _serialisable_classes }
0 commit comments