Skip to content

Commit 61b081e

Browse files
committed
Refactor model.py for improved readability and performance; add batch Mahalanobis distance test in test_interval.py
1 parent e35e58c commit 61b081e

File tree

3 files changed

+74
-36
lines changed

3 files changed

+74
-36
lines changed

pybdr/model/model.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Callable
44

55
import numpy as np
6-
from sympy import symbols, Matrix, lambdify, derive_by_array, ImmutableDenseNDimArray
6+
from sympy import ImmutableDenseNDimArray, Matrix, derive_by_array, lambdify, symbols
77

88

99
@dataclass
@@ -26,9 +26,7 @@ def __validation(self):
2626
assert len(self.var_dims) == vars_num
2727
self.__inr_dim = sum(self.var_dims)
2828
assert type(self.__inr_dim) == int # ensure input dimensions are integers
29-
self.__inr_vars = symbols(
30-
[vars[i] + ":" + str(self.var_dims[i]) for i in range(vars_num)]
31-
)
29+
self.__inr_vars = symbols([vars[i] + ":" + str(self.var_dims[i]) for i in range(vars_num)])
3230
self.__inr_x = symbols("inr_x:" + str(self.__inr_dim))
3331
self.__inr_f = self.f(*self.__inr_vars)
3432
self.__inr_f = -1 * self.__inr_f if self.__reversed else self.__inr_f
@@ -45,9 +43,7 @@ def __validation(self):
4543
)
4644
)
4745
self.__inr_f = self.__inr_f.subs(d)
48-
self.__inr_series[0] = {
49-
"sym": {v: np.asarray(self.__inr_f) for v in range(vars_num)}
50-
}
46+
self.__inr_series[0] = {"sym": {v: np.asarray(self.__inr_f) for v in range(vars_num)}}
5147

5248
def __post_init__(self):
5349
self.__validation()
@@ -56,10 +52,7 @@ def __series(self, order: int, mod: str, v: int):
5652
return self.__inr_series[order][mod][v]
5753

5854
def __take_derivative(self, order: int, v: int):
59-
if (
60-
order - 1 not in self.__inr_series
61-
or v not in self.__inr_series[order - 1]["sym"]
62-
):
55+
if order - 1 not in self.__inr_series or v not in self.__inr_series[order - 1]["sym"]:
6356
self.__take_derivative(order - 1, v)
6457
start, end = self.__inr_idx[v]
6558
x = self.__inr_x[start:end]
@@ -77,30 +70,22 @@ def evaluate(self, xs: tuple, mod: str, order: int, v: int):
7770
self.__take_derivative(order, v)
7871

7972
def _eval_numpy():
80-
if (
81-
mod not in self.__inr_series[order]
82-
or v not in self.__inr_series[order][mod]
83-
):
73+
if mod not in self.__inr_series[order] or v not in self.__inr_series[order][mod]:
8474
if mod not in self.__inr_series:
8575
self.__inr_series[order][mod] = {}
8676
d = self.__series(order, "sym", v)
8777
d = d if order == 0 else d.squeeze(axis=-1)
8878
d = ImmutableDenseNDimArray(d)
8979
if v not in self.__inr_series[order][mod]:
90-
self.__inr_series[order][mod][v] = lambdify(
91-
self.__inr_x, d, "numpy"
92-
)
80+
self.__inr_series[order][mod][v] = lambdify(self.__inr_x, d, "numpy")
9381
# self.__inr_series[order][mod] = {v: lambdify(self.__inr_x, d, "numpy")}
9482
r = np.asarray(self.__series(order, mod, v)(*np.concatenate(xs, axis=-1)))
9583
return r.squeeze(axis=-1) if order == 0 else r
9684

9785
def _eval_interval():
9886
from pybdr.geometry import Interval
9987

100-
if (
101-
mod not in self.__inr_series[order]
102-
or v not in self.__inr_series[order][mod]
103-
):
88+
if mod not in self.__inr_series[order] or v not in self.__inr_series[order][mod]:
10489
if mod not in self.__inr_series:
10590
self.__inr_series[order][mod] = {}
10691
d = self.__series(order, "sym", v)
@@ -113,7 +98,9 @@ def _eval_interval():
11398
# self.__inr_series[order][mod] = {v: [None, mask]}
11499
else:
115100
sym_d = ImmutableDenseNDimArray(d[mask])
116-
vf = lambdify(self.__inr_x, sym_d, Interval.functional())
101+
# Convert ImmutableDenseNDimArray to list for lambdify
102+
sym_d_list = list(sym_d) if hasattr(sym_d, "__iter__") else [sym_d]
103+
vf = lambdify(self.__inr_x, sym_d_list, Interval.functional())
117104
if v not in self.__inr_series[order][mod]:
118105
self.__inr_series[order][mod][v] = [vf, mask]
119106
# self.__inr_series[order][mod] = {v: [vf, mask]}
@@ -124,15 +111,7 @@ def _eval_interval():
124111
ub = np.zeros_like(d, dtype=float)
125112
# calculate interval expressions
126113
if vm[0] is not None:
127-
vx = np.asarray(
128-
vm[0](
129-
*[
130-
xs[i][j]
131-
for i in range(len(self.var_dims))
132-
for j in range(self.var_dims[i])
133-
]
134-
)
135-
)
114+
vx = np.asarray(vm[0](*[xs[i][j] for i in range(len(self.var_dims)) for j in range(self.var_dims[i])]))
136115
inff = np.frompyfunc(lambda x: x.inf, 1, 1)
137116
supf = np.frompyfunc(lambda x: x.sup, 1, 1)
138117
lb[vm[1]] = inff(vx)

test/geometry/test_interval.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,5 +448,20 @@ def test_contains():
448448
print(b.contains(np.array([0, 3])))
449449

450450

451+
def test_Batch_MDC():
452+
"""
453+
Test the batch Mahalanobis distance computation
454+
"""
455+
pts = Interval.rand(100, 3, 1) # points with uncertainty
456+
mean = np.random.rand(3, 1) # mean of each point
457+
cov = np.random.rand(100, 3, 3) # covariance of each point
458+
diff = pts - mean[None, :, :]
459+
460+
mdc = diff.transpose(0, 2, 1) @ np.linalg.inv(cov) @ diff
461+
# mdc is of shape (100, 1, 1)
462+
mdc = Interval.squeeze(mdc) # reshape to (100,)
463+
print(mdc.shape)
464+
465+
451466
if __name__ == "__main__":
452-
pass
467+
test_Batch_MDC()

test/util/test_model.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import inspect
2-
31
import numpy as np
4-
from pybdr.model import Model
52
from sympy import *
63

4+
from pybdr.model import Model
5+
76

87
def test_case_00():
98
import pybdr.util.functional.auxiliary as aux
@@ -50,3 +49,48 @@ def f(x, u):
5049
print(temp1.shape)
5150
print(temp2.shape)
5251
print(temp3.shape)
52+
53+
54+
def test_case_01():
55+
import numpy as np
56+
57+
from pybdr.geometry import Interval
58+
from pybdr.model import Model, tank6eq
59+
from pybdr.util.functional import performance_counter, performance_counter_start
60+
61+
m = Model(tank6eq, [6, 1])
62+
63+
time_start = performance_counter_start()
64+
x, u = np.random.random(6), np.random.rand(1)
65+
66+
np_derivative_0 = m.evaluate((x, u), "numpy", 3, 0)
67+
np_derivative_1 = m.evaluate((x, u), "numpy", 3, 1)
68+
np_derivative_2 = m.evaluate((x, u), "numpy", 0, 0)
69+
70+
x, u = Interval.rand(6), Interval.rand(1)
71+
int_derivative_0 = m.evaluate((x, u), "interval", 3, 0)
72+
int_derivative_1 = m.evaluate((x, u), "interval", 2, 0)
73+
int_derivative_2 = m.evaluate((x, u), "interval", 2, 0)
74+
int_derivative_3 = m.evaluate((x, u), "interval", 0, 1)
75+
76+
performance_counter(time_start, "sym_derivative")
77+
78+
79+
def test_case_02():
80+
from pybdr.geometry import Interval
81+
from pybdr.model import Model, tank6eq
82+
83+
m = Model(tank6eq, [6, 1])
84+
x, u = Interval.rand(6), Interval.rand(1)
85+
Jx = m.evaluate((x, u), "interval", 1, 0)
86+
pts_iv = np.random.rand(100, 6, 1)
87+
88+
quad_iv = pts_iv.transpose(0, 2, 1) @ Jx[None, ...] @ pts_iv
89+
quad_iv = Interval.squeeze(quad_iv)
90+
print(quad_iv.shape)
91+
92+
93+
if __name__ == "__main__":
94+
# test_case_00()
95+
# test_case_01()
96+
test_case_02()

0 commit comments

Comments
 (0)