Skip to content

Commit 15acac0

Browse files
committed
Adds trig and exp operations (#186)
* Adds exp operation and corrects log derivatives * Adjusts math library imports * Adds trig operations * Adds evaluation tests * Corrects unary clean signatures * Adds tests * Adds clean for exp and ln * Adds parameterised fixtures * Addresses review comments
1 parent 2813fca commit 15acac0

2 files changed

Lines changed: 334 additions & 49 deletions

File tree

sasdata/quantities/quantity.py

Lines changed: 220 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import hashlib
22
import json
3-
from math import e, log
3+
import math
44
from typing import Any, Self, TypeVar, Union
55

66
import h5py
@@ -451,25 +451,207 @@ def _clean(self):
451451
class Ln(UnaryOperation):
452452
serialisation_name = "ln"
453453

454-
def evaluate(self, variables: dict[int, T]) -> Operation:
455-
return log(self.a.evaluate(variables))
454+
def evaluate(self, variables: dict[int, T]) -> T:
455+
return math.log(self.a.evaluate(variables))
456456

457457
def _derivative(self, hash_value: int) -> Operation:
458-
return Inv(self.a)
458+
return Div(self.a._derivative(hash_value), self.a)
459459

460-
def _clean(self, a):
460+
def _clean(self):
461461
clean_a = self.a._clean()
462462

463-
if isinstance(a, MultiplicativeIdentity):
463+
if isinstance(clean_a, Exp):
464+
# Convert ln(exp(x)) to x
465+
return clean_a.a
466+
467+
elif isinstance(clean_a, MultiplicativeIdentity):
464468
# Convert ln(1) to 0
465469
return AdditiveIdentity()
466470

467-
elif a == e:
471+
elif clean_a == math.e:
468472
# Convert ln(e) to 1
469473
return MultiplicativeIdentity()
470474

471475
else:
472-
return Log(clean_a)
476+
return Ln(clean_a)
477+
478+
479+
class Exp(UnaryOperation):
480+
serialisation_name = "exp"
481+
482+
def evaluate(self, variables: dict[int, T]) -> T:
483+
return math.exp(self.a.evaluate(variables))
484+
485+
def _derivative(self, hash_value: int) -> Operation:
486+
return Mul(self.a._derivative(hash_value), Exp(self.a))
487+
488+
def _clean(self):
489+
clean_a = self.a._clean()
490+
491+
if isinstance(clean_a, Ln):
492+
# Convert exp(ln(x)) to x
493+
return clean_a.a
494+
495+
elif isinstance(clean_a, MultiplicativeIdentity):
496+
# Convert e**1 to e
497+
return math.e
498+
499+
elif isinstance(clean_a, AdditiveIdentity):
500+
# Convert e**0 to 1
501+
return 1
502+
503+
else:
504+
return Exp(clean_a)
505+
506+
507+
class Sin(UnaryOperation):
508+
serialisation_name = "sin"
509+
510+
def evaluate(self, variables: dict[int, T]) -> T:
511+
return np.sin(self.a.evaluate(variables))
512+
513+
def _derivative(self, hash_value: int) -> Operation:
514+
return Mul(self.a._derivative(hash_value), Cos(self.a))
515+
516+
def _clean(self):
517+
clean_a = self.a._clean()
518+
519+
if isinstance(clean_a, ArcSin):
520+
return clean_a.a
521+
522+
elif isinstance(clean_a, AdditiveIdentity):
523+
# Convert sin(0) to 0
524+
return AdditiveIdentity()
525+
526+
else:
527+
return Sin(clean_a)
528+
529+
530+
class ArcSin(UnaryOperation):
531+
serialisation_name = "arcsin"
532+
533+
def evaluate(self, variables: dict[int, T]) -> T:
534+
return np.arcsin(self.a.evaluate(variables))
535+
536+
def _derivative(self, hash_value: int) -> Operation:
537+
return Div(self.a._derivative(hash_value), Sqrt(Sub(MultiplicativeIdentity(), Mul(self.a, self.a))))
538+
539+
def _clean(self):
540+
clean_a = self.a._clean()
541+
542+
if isinstance(clean_a, Sin):
543+
return clean_a.a
544+
545+
elif isinstance(clean_a, AdditiveIdentity):
546+
# Convert arcsin(0) to 0
547+
return AdditiveIdentity()
548+
549+
elif isinstance(clean_a, MultiplicativeIdentity):
550+
# Convert arcsin(1) to pi/2
551+
return Constant(0.5 * math.pi)
552+
553+
else:
554+
return ArcSin(clean_a)
555+
556+
557+
class Cos(UnaryOperation):
558+
serialisation_name = "cos"
559+
560+
def evaluate(self, variables: dict[int, T]) -> T:
561+
return np.cos(self.a.evaluate(variables))
562+
563+
def _derivative(self, hash_value: int) -> Operation:
564+
return Mul(self.a._derivative(hash_value), Neg(Sin(self.a)))
565+
566+
def _clean(self):
567+
clean_a = self.a._clean()
568+
569+
if isinstance(clean_a, ArcCos):
570+
return clean_a.a
571+
572+
elif isinstance(clean_a, AdditiveIdentity):
573+
# Convert cos(0) to 1
574+
return MultiplicativeIdentity()
575+
576+
else:
577+
return Cos(clean_a)
578+
579+
580+
class ArcCos(UnaryOperation):
581+
serialisation_name = "arccos"
582+
583+
def evaluate(self, variables: dict[int, T]) -> T:
584+
return np.arccos(self.a.evaluate(variables))
585+
586+
def _derivative(self, hash_value: int) -> Operation:
587+
return Neg(Div(self.a._derivative(hash_value), Sqrt(Sub(MultiplicativeIdentity(), Mul(self.a, self.a)))))
588+
589+
def _clean(self):
590+
clean_a = self.a._clean()
591+
592+
if isinstance(clean_a, Cos):
593+
return clean_a.a
594+
595+
elif isinstance(clean_a, AdditiveIdentity):
596+
# Convert arccos(0) to pi/2
597+
return Constant(0.5 * math.pi)
598+
599+
elif isinstance(clean_a, MultiplicativeIdentity):
600+
# Convert arccos(1) to 0
601+
return AdditiveIdentity()
602+
603+
else:
604+
return ArcCos(clean_a)
605+
606+
607+
class Tan(UnaryOperation):
608+
serialisation_name = "tan"
609+
610+
def evaluate(self, variables: dict[int, T]) -> T:
611+
return np.tan(self.a.evaluate(variables))
612+
613+
def _derivative(self, hash_value: int) -> Operation:
614+
return Div(self.a._derivative(hash_value), Mul(Cos(self.a), Cos(self.a)))
615+
616+
def _clean(self):
617+
clean_a = self.a._clean()
618+
619+
if isinstance(clean_a, ArcTan):
620+
return clean_a.a
621+
622+
elif isinstance(clean_a, AdditiveIdentity):
623+
# Convert tan(0) to 0
624+
return AdditiveIdentity()
625+
626+
else:
627+
return Tan(clean_a)
628+
629+
630+
class ArcTan(UnaryOperation):
631+
serialisation_name = "arctan"
632+
633+
def evaluate(self, variables: dict[int, T]) -> T:
634+
return np.arctan(self.a.evaluate(variables))
635+
636+
def _derivative(self, hash_value: int) -> Operation:
637+
return Div(self.a._derivative(hash_value), Add(MultiplicativeIdentity(), Mul(self.a, self.a)))
638+
639+
def _clean(self):
640+
clean_a = self.a._clean()
641+
642+
if isinstance(clean_a, Tan):
643+
return clean_a.a
644+
645+
elif isinstance(clean_a, AdditiveIdentity):
646+
# Convert arctan(0) to 0
647+
return AdditiveIdentity()
648+
649+
elif isinstance(clean_a, MultiplicativeIdentity):
650+
# Convert arctan(1) to pi/4
651+
return Constant(0.25 * math.pi)
652+
653+
else:
654+
return ArcTan(clean_a)
473655

474656

475657
class BinaryOperation(Operation):
@@ -703,11 +885,11 @@ def __init__(self, a: Operation, base: float):
703885
self.a = a
704886
self.base = base
705887

706-
def evaluate(self, variables: dict[int, T]) -> Operation:
707-
return log(self.a.evaluate(variables), self.base)
888+
def evaluate(self, variables: dict[int, T]) -> T:
889+
return math.log(self.a.evaluate(variables), self.base)
708890

709891
def _derivative(self, hash_value: int) -> Operation:
710-
return Inv(Mul(self.a, Ln(Constant(self.base))))
892+
return Div(self.a.derivative(hash_value), Mul(self.a, Ln(Constant(self.base))))
711893

712894
def _clean_ab(self) -> Operation:
713895
a = self.a._clean()
@@ -765,7 +947,7 @@ def _derivative(self, hash_value: int) -> Operation:
765947
else:
766948
return Mul(Constant(self.power), Mul(Pow(self.a, self.power - 1), self.a._derivative(hash_value)))
767949

768-
def _clean(self) -> Operation:
950+
def _clean(self):
769951
a = self.a._clean()
770952

771953
if self.power == 1:
@@ -945,6 +1127,13 @@ def _deserialise(parameters: dict) -> "Operation":
9451127
Neg,
9461128
Inv,
9471129
Ln,
1130+
Exp,
1131+
Sin,
1132+
ArcSin,
1133+
Cos,
1134+
ArcCos,
1135+
Tan,
1136+
ArcTan,
9481137
Add,
9491138
Sub,
9501139
Mul,
@@ -1080,13 +1269,15 @@ def summary(self):
10801269

10811270

10821271
class Quantity[QuantityType]:
1083-
def __init__(self,
1084-
value: QuantityType,
1085-
units: Unit,
1086-
standard_error: QuantityType | None = None,
1087-
hash_seed="",
1088-
name="",
1089-
id_header=""):
1272+
def __init__(
1273+
self,
1274+
value: QuantityType,
1275+
units: Unit,
1276+
standard_error: QuantityType | None = None,
1277+
hash_seed="",
1278+
name="",
1279+
id_header="",
1280+
):
10901281
self.value = value
10911282
""" Numerical value of this data, in the specified units"""
10921283

@@ -1173,11 +1364,13 @@ def in_units_of(self, units: Unit) -> QuantityType:
11731364

11741365
def to_units_of(self, new_units: Unit) -> "Quantity[QuantityType]":
11751366
new_value, new_error = self.in_units_of_with_standard_error(new_units)
1176-
return Quantity(value=new_value,
1177-
units=new_units,
1178-
standard_error=new_error,
1179-
hash_seed=self._hash_seed,
1180-
id_header=self._id_header)
1367+
return Quantity(
1368+
value=new_value,
1369+
units=new_units,
1370+
standard_error=new_error,
1371+
hash_seed=self._hash_seed,
1372+
id_header=self._id_header,
1373+
)
11811374

11821375
def variance_in_units_of(self, units: Unit) -> QuantityType:
11831376
"""Get the variance of quantity in other units"""
@@ -1397,12 +1590,9 @@ def as_h5(self, group: h5py.Group, name: str):
13971590

13981591

13991592
class NamedQuantity[QuantityType](Quantity[QuantityType]):
1400-
def __init__(self,
1401-
name: str,
1402-
value: QuantityType,
1403-
units: Unit,
1404-
standard_error: QuantityType | None = None,
1405-
id_header=""):
1593+
def __init__(
1594+
self, name: str, value: QuantityType, units: Unit, standard_error: QuantityType | None = None, id_header=""
1595+
):
14061596
super().__init__(value, units, standard_error=standard_error, hash_seed=name, name=name, id_header=id_header)
14071597

14081598
def __repr__(self):

0 commit comments

Comments
 (0)