1+ import numpy as np
2+
3+ from LoopStructural .interpolators ._discrete_interpolator import DiscreteInterpolator
4+ from LoopStructural .interpolators ._finite_difference_interpolator import FiniteDifferenceInterpolator
5+ from ._p1interpolator import P1Interpolator
6+ from typing import Optional , Union , Callable
7+ from scipy import sparse
8+ from LoopStructural .utils import rng
9+
10+ class ConstantNormInterpolator :
11+ """Adds a non linear constraint to an interpolator to constrain
12+ the norm of the gradient to be a set value.
13+
14+ Returns
15+ -------
16+ _type_
17+ _description_
18+ """
19+ def __init__ (self , interpolator : DiscreteInterpolator ,basetype ):
20+ """Initialise the constant norm inteprolator
21+ with a discrete interpolator.
22+
23+ Parameters
24+ ----------
25+ interpolator : DiscreteInterpolator
26+ The discrete interpolator to add constant norm to.
27+ """
28+ self .basetype = basetype
29+ self .interpolator = interpolator
30+ self .support = interpolator .support
31+ self .random_subset = False
32+ self .norm_length = 1.0
33+ self .n_iterations = 20
34+ self .store_solution_history = False
35+ self .solution_history = []#np.zeros((self.n_iterations, self.support.n_nodes))
36+ self .gradient_constraint_store = []
37+ def add_constant_norm (self , w :float ):
38+ """Add a constraint to the interpolator to constrain the norm of the gradient
39+ to be a set value
40+
41+ Parameters
42+ ----------
43+ w : float
44+ weighting of the constraint
45+ """
46+ if "constant norm" in self .interpolator .constraints :
47+ _ = self .interpolator .constraints .pop ("constant norm" )
48+
49+ element_indices = np .arange (self .support .elements .shape [0 ])
50+ if self .random_subset :
51+ rng .shuffle (element_indices )
52+ element_indices = element_indices [: int (0.1 * self .support .elements .shape [0 ])]
53+ vertices , gradient , elements , inside = self .support .get_element_gradient_for_location (
54+ self .support .barycentre [element_indices ]
55+ )
56+
57+ t_g = gradient [:, :, :]
58+ # t_n = gradient[self.support.shared_element_relationships[:, 1], :, :]
59+ v_t = np .einsum (
60+ "ijk,ik->ij" ,
61+ t_g ,
62+ self .interpolator .c [self .support .elements [elements ]],
63+ )
64+
65+ v_t = v_t / np .linalg .norm (v_t , axis = 1 )[:, np .newaxis ]
66+ self .gradient_constraint_store .append (np .hstack ([self .support .barycentre [element_indices ],v_t ]))
67+ A1 = np .einsum ("ij,ijk->ik" , v_t , t_g )
68+ volume = self .support .element_size [element_indices ]
69+ A1 = A1 / volume [:, np .newaxis ] # normalise by element size
70+
71+ b = np .zeros (A1 .shape [0 ]) + self .norm_length
72+ b = b / volume # normalise by element size
73+ idc = np .hstack (
74+ [
75+ self .support .elements [elements ],
76+ ]
77+ )
78+ self .interpolator .add_constraints_to_least_squares (A1 , b , idc , w = w , name = "constant norm" )
79+
80+ def solve_system (
81+ self ,
82+ solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
83+ tol : Optional [float ] = None ,
84+ solver_kwargs : dict = {},
85+ ) -> bool :
86+ """Solve the system of equations iteratively for the constant norm interpolator.
87+
88+ Parameters
89+ ----------
90+ solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
91+ Solver function or name, by default None
92+ tol : Optional[float], optional
93+ Tolerance for the solver, by default None
94+ solver_kwargs : dict, optional
95+ Additional arguments for the solver, by default {}
96+
97+ Returns
98+ -------
99+ bool
100+ Success status of the solver
101+ """
102+ success = True
103+ for i in range (self .n_iterations ):
104+ if i > 0 :
105+ self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
106+ # Ensure the interpolator is cast to P1Interpolator before calling solve_system
107+ if isinstance (self .interpolator , self .basetype ):
108+ success = self .basetype .solve_system (self .interpolator , solver = solver , tol = tol , solver_kwargs = solver_kwargs )
109+ if self .store_solution_history :
110+
111+ self .solution_history .append (self .interpolator .c )
112+ else :
113+ raise TypeError ("self.interpolator is not an instance of P1Interpolator" )
114+ if not success :
115+ break
116+ return success
117+
118+ class ConstantNormP1Interpolator (P1Interpolator , ConstantNormInterpolator ):
119+ """Constant norm interpolator using P1 base interpolator
120+
121+ Parameters
122+ ----------
123+ P1Interpolator : class
124+ The P1Interpolator class.
125+ ConstantNormInterpolator : class
126+ The ConstantNormInterpolator class.
127+ """
128+ def __init__ (self , support ):
129+ """Initialise the constant norm P1 interpolator.
130+
131+ Parameters
132+ ----------
133+ support : _type_
134+ _description_
135+ """
136+ P1Interpolator .__init__ (self , support )
137+ ConstantNormInterpolator .__init__ (self , self , P1Interpolator )
138+
139+ def solve_system (
140+ self ,
141+ solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
142+ tol : Optional [float ] = None ,
143+ solver_kwargs : dict = {},
144+ ) -> bool :
145+ """Solve the system of equations for the constant norm P1 interpolator.
146+
147+ Parameters
148+ ----------
149+ solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
150+ Solver function or name, by default None
151+ tol : Optional[float], optional
152+ Tolerance for the solver, by default None
153+ solver_kwargs : dict, optional
154+ Additional arguments for the solver, by default {}
155+
156+ Returns
157+ -------
158+ bool
159+ Success status of the solver
160+ """
161+ return ConstantNormInterpolator .solve_system (self , solver = solver , tol = tol , solver_kwargs = solver_kwargs )
162+
163+ class ConstantNormFDIInterpolator (FiniteDifferenceInterpolator , ConstantNormInterpolator ):
164+ """Constant norm interpolator using finite difference base interpolator
165+
166+ Parameters
167+ ----------
168+ FiniteDifferenceInterpolator : class
169+ The FiniteDifferenceInterpolator class.
170+ ConstantNormInterpolator : class
171+ The ConstantNormInterpolator class.
172+ """
173+ def __init__ (self , support ):
174+ """Initialise the constant norm finite difference interpolator.
175+
176+ Parameters
177+ ----------
178+ support : _type_
179+ _description_
180+ """
181+ FiniteDifferenceInterpolator .__init__ (self , support )
182+ ConstantNormInterpolator .__init__ (self , self , FiniteDifferenceInterpolator )
183+ def solve_system (
184+ self ,
185+ solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
186+ tol : Optional [float ] = None ,
187+ solver_kwargs : dict = {},
188+ ) -> bool :
189+ """Solve the system of equations for the constant norm finite difference interpolator.
190+
191+ Parameters
192+ ----------
193+ solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
194+ Solver function or name, by default None
195+ tol : Optional[float], optional
196+ Tolerance for the solver, by default None
197+ solver_kwargs : dict, optional
198+ Additional arguments for the solver, by default {}
199+
200+ Returns
201+ -------
202+ bool
203+ Success status of the solver
204+ """
205+ return ConstantNormInterpolator .solve_system (self , solver = solver , tol = tol , solver_kwargs = solver_kwargs )
0 commit comments