33from typing import Callable
44
55import numpy as np
6- from sympy import symbols , Matrix , lambdify , derive_by_array , ImmutableDenseNDimArray
6+ from sympy import ImmutableDenseNDimArray , Matrix , derive_by_array , lambdify , symbols
77
88
99@dataclass
@@ -26,9 +26,7 @@ def __validation(self):
2626 assert len (self .var_dims ) == vars_num
2727 self .__inr_dim = sum (self .var_dims )
2828 assert type (self .__inr_dim ) == int # ensure input dimensions are integers
29- self .__inr_vars = symbols (
30- [vars [i ] + ":" + str (self .var_dims [i ]) for i in range (vars_num )]
31- )
29+ self .__inr_vars = symbols ([vars [i ] + ":" + str (self .var_dims [i ]) for i in range (vars_num )])
3230 self .__inr_x = symbols ("inr_x:" + str (self .__inr_dim ))
3331 self .__inr_f = self .f (* self .__inr_vars )
3432 self .__inr_f = - 1 * self .__inr_f if self .__reversed else self .__inr_f
@@ -45,9 +43,7 @@ def __validation(self):
4543 )
4644 )
4745 self .__inr_f = self .__inr_f .subs (d )
48- self .__inr_series [0 ] = {
49- "sym" : {v : np .asarray (self .__inr_f ) for v in range (vars_num )}
50- }
46+ self .__inr_series [0 ] = {"sym" : {v : np .asarray (self .__inr_f ) for v in range (vars_num )}}
5147
5248 def __post_init__ (self ):
5349 self .__validation ()
@@ -56,10 +52,7 @@ def __series(self, order: int, mod: str, v: int):
5652 return self .__inr_series [order ][mod ][v ]
5753
5854 def __take_derivative (self , order : int , v : int ):
59- if (
60- order - 1 not in self .__inr_series
61- or v not in self .__inr_series [order - 1 ]["sym" ]
62- ):
55+ if order - 1 not in self .__inr_series or v not in self .__inr_series [order - 1 ]["sym" ]:
6356 self .__take_derivative (order - 1 , v )
6457 start , end = self .__inr_idx [v ]
6558 x = self .__inr_x [start :end ]
@@ -77,30 +70,22 @@ def evaluate(self, xs: tuple, mod: str, order: int, v: int):
7770 self .__take_derivative (order , v )
7871
7972 def _eval_numpy ():
80- if (
81- mod not in self .__inr_series [order ]
82- or v not in self .__inr_series [order ][mod ]
83- ):
73+ if mod not in self .__inr_series [order ] or v not in self .__inr_series [order ][mod ]:
8474 if mod not in self .__inr_series :
8575 self .__inr_series [order ][mod ] = {}
8676 d = self .__series (order , "sym" , v )
8777 d = d if order == 0 else d .squeeze (axis = - 1 )
8878 d = ImmutableDenseNDimArray (d )
8979 if v not in self .__inr_series [order ][mod ]:
90- self .__inr_series [order ][mod ][v ] = lambdify (
91- self .__inr_x , d , "numpy"
92- )
80+ self .__inr_series [order ][mod ][v ] = lambdify (self .__inr_x , d , "numpy" )
9381 # self.__inr_series[order][mod] = {v: lambdify(self.__inr_x, d, "numpy")}
9482 r = np .asarray (self .__series (order , mod , v )(* np .concatenate (xs , axis = - 1 )))
9583 return r .squeeze (axis = - 1 ) if order == 0 else r
9684
9785 def _eval_interval ():
9886 from pybdr .geometry import Interval
9987
100- if (
101- mod not in self .__inr_series [order ]
102- or v not in self .__inr_series [order ][mod ]
103- ):
88+ if mod not in self .__inr_series [order ] or v not in self .__inr_series [order ][mod ]:
10489 if mod not in self .__inr_series :
10590 self .__inr_series [order ][mod ] = {}
10691 d = self .__series (order , "sym" , v )
@@ -113,7 +98,9 @@ def _eval_interval():
11398 # self.__inr_series[order][mod] = {v: [None, mask]}
11499 else :
115100 sym_d = ImmutableDenseNDimArray (d [mask ])
116- vf = lambdify (self .__inr_x , sym_d , Interval .functional ())
101+ # Convert ImmutableDenseNDimArray to list for lambdify
102+ sym_d_list = list (sym_d ) if hasattr (sym_d , "__iter__" ) else [sym_d ]
103+ vf = lambdify (self .__inr_x , sym_d_list , Interval .functional ())
117104 if v not in self .__inr_series [order ][mod ]:
118105 self .__inr_series [order ][mod ][v ] = [vf , mask ]
119106 # self.__inr_series[order][mod] = {v: [vf, mask]}
@@ -124,15 +111,7 @@ def _eval_interval():
124111 ub = np .zeros_like (d , dtype = float )
125112 # calculate interval expressions
126113 if vm [0 ] is not None :
127- vx = np .asarray (
128- vm [0 ](
129- * [
130- xs [i ][j ]
131- for i in range (len (self .var_dims ))
132- for j in range (self .var_dims [i ])
133- ]
134- )
135- )
114+ vx = np .asarray (vm [0 ](* [xs [i ][j ] for i in range (len (self .var_dims )) for j in range (self .var_dims [i ])]))
136115 inff = np .frompyfunc (lambda x : x .inf , 1 , 1 )
137116 supf = np .frompyfunc (lambda x : x .sup , 1 , 1 )
138117 lb [vm [1 ]] = inff (vx )
0 commit comments