Skip to content

Commit 5c7759e

Browse files
committed
refactor: enhance AMG solver structure with improved initialization and iteration handling
1 parent 43fd995 commit 5c7759e

2 files changed

Lines changed: 245 additions & 33 deletions

File tree

src/iteratives/amg/cycle.rs

Whitespace-only changes.

src/iteratives/amg/mod.rs

Lines changed: 245 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// https://edoc.unibas.ch/server/api/core/bitstreams/d56e8bdd-9b91-49ec-a8f5-04eff6db51ca/content
2-
pub mod cycle;
32
pub mod graph;
43
pub mod coarsen;
54
pub mod interpolate;
@@ -28,56 +27,220 @@ pub fn solve_with_initial_guess<T>(a: CsrMatrix<T>, b: &DVector<T>, x: &mut DVec
2827
where
2928
T: RealField + Copy,
3029
{
31-
use level::*;
32-
33-
// Pre-compute initial residual
34-
let residual_buffer = DVector::from(&a * &*x - b);
35-
let hierarchy = setup(a, theta, 100);
36-
37-
// Check if we're already converged
38-
let initial_residual_norm = residual_buffer.amax();
39-
if initial_residual_norm <= tol {
40-
return true;
41-
}
42-
43-
// Use adaptive tolerance for intermediate iterations
44-
let adaptive_tol = tol.max(initial_residual_norm * T::from_f64(1e-3).unwrap());
45-
46-
for i in 0..max_iter {
47-
let mut residual_buffer = DVector::zeros(b.len());
48-
hierarchy.vcycle(0, b, x, &mut residual_buffer, adaptive_tol, 1, 1);
49-
50-
// Use the residual buffer that was updated by vcycle
51-
let current_residual = residual_buffer;
30+
let mut solver = Amg::new(tol, theta, max_iter);
31+
solver.init(&a, b, Some(x));
32+
let converged = solver.solve_iterations(&a, b, max_iter);
33+
*x = solver.x.clone();
34+
converged
35+
}
36+
37+
pub fn solve<T>(a: CsrMatrix<T>, b: &DVector<T>, max_iter: usize, tol: T, theta: T) -> Option<DVector<T>>
38+
where
39+
T: RealField + Copy,
40+
{
41+
let mut x = DVector::<T>::zeros(a.nrows());
42+
if solve_with_initial_guess(a, b, &mut x, max_iter, tol, theta) {
43+
Some(x)
44+
} else {
45+
None
46+
}
47+
}
48+
49+
/// AMG (Algebraic Multigrid) solver struct that implements the IterativeSolver trait
50+
/// for customizable solving with configurable parameters.
51+
pub struct Amg<T> {
52+
pub x: DVector<T>,
53+
pub tol: T,
54+
pub theta: T,
55+
pub max_iter: usize,
56+
pub iter: usize,
57+
pub converged: bool,
58+
pub nu_pre: usize,
59+
pub nu_post: usize,
60+
pub n_min: usize,
61+
hierarchy: Option<level::Hierarchy<T>>,
62+
residual_buffer: DVector<T>,
63+
}
64+
65+
impl<T> Amg<T>
66+
where
67+
T: RealField + Copy,
68+
{
69+
/// Creates a new AMG solver with specified parameters
70+
pub fn new(tol: T, theta: T, max_iter: usize) -> Self {
71+
Self {
72+
x: DVector::zeros(0),
73+
tol,
74+
theta,
75+
max_iter,
76+
iter: 0,
77+
converged: false,
78+
nu_pre: 1,
79+
nu_post: 1,
80+
n_min: 100,
81+
hierarchy: None,
82+
residual_buffer: DVector::zeros(0),
83+
}
84+
}
85+
86+
/// Creates a new AMG solver with custom smoothing parameters
87+
pub fn with_smoothing(tol: T, theta: T, max_iter: usize, nu_pre: usize, nu_post: usize) -> Self {
88+
Self {
89+
x: DVector::zeros(0),
90+
tol,
91+
theta,
92+
max_iter,
93+
iter: 0,
94+
converged: false,
95+
nu_pre,
96+
nu_post,
97+
n_min: 100,
98+
hierarchy: None,
99+
residual_buffer: DVector::zeros(0),
100+
}
101+
}
102+
103+
/// Sets the minimum coarse level size
104+
pub fn with_coarse_size(mut self, n_min: usize) -> Self {
105+
self.n_min = n_min;
106+
self
107+
}
108+
109+
/// Sets the pre and post smoothing iterations
110+
pub fn set_smoothing(&mut self, nu_pre: usize, nu_post: usize) {
111+
self.nu_pre = nu_pre;
112+
self.nu_post = nu_post;
113+
}
114+
}
115+
116+
impl<T> IterativeSolver<CsrMatrix<T>, DVector<T>, T> for Amg<T>
117+
where
118+
T: RealField + Copy,
119+
{
120+
fn init(&mut self, a: &CsrMatrix<T>, _b: &DVector<T>, x0: Option<&DVector<T>>) {
121+
let n = a.nrows();
122+
self.x = match x0 {
123+
Some(x0) => x0.clone(),
124+
None => DVector::<T>::zeros(n),
125+
};
126+
self.residual_buffer = DVector::zeros(n);
127+
self.iter = 0;
128+
self.converged = false;
52129

53-
// Check convergence every few iterations to reduce overhead
54-
if i % 5 == 4 || i == max_iter - 1 {
130+
// Build the AMG hierarchy
131+
self.hierarchy = Some(level::setup(a.clone(), self.theta, self.n_min));
132+
}
133+
134+
fn step(&mut self, a: &CsrMatrix<T>, b: &DVector<T>) -> bool {
135+
if let Some(ref hierarchy) = self.hierarchy {
136+
// Check if we're already converged
137+
let current_residual = a * &self.x - b;
55138
let residual_norm = current_residual.amax();
56-
if residual_norm <= tol {
139+
140+
if residual_norm <= self.tol {
141+
self.converged = true;
57142
return true;
58143
}
59144

60-
// Optional: print progress less frequently
61-
if i % 20 == 19 {
62-
println!("Iteration {}: max residual norm = {}", i + 1, residual_norm);
145+
// Use adaptive tolerance for intermediate iterations
146+
let adaptive_tol = self.tol.max(residual_norm * T::from_f64(1e-3).unwrap());
147+
148+
// Perform one V-cycle
149+
hierarchy.vcycle(
150+
0,
151+
b,
152+
&mut self.x,
153+
&mut self.residual_buffer,
154+
adaptive_tol,
155+
self.nu_pre,
156+
self.nu_post
157+
);
158+
159+
self.iter += 1;
160+
161+
// Check convergence after the V-cycle
162+
let new_residual = a * &self.x - b;
163+
let new_residual_norm = new_residual.amax();
164+
165+
if new_residual_norm <= self.tol {
166+
self.converged = true;
167+
true
168+
} else {
169+
false
63170
}
171+
} else {
172+
false
64173
}
65174
}
66-
false
175+
176+
fn reset(&mut self) {
177+
self.x.fill(T::zero());
178+
self.residual_buffer.fill(T::zero());
179+
self.iter = 0;
180+
self.converged = false;
181+
// Keep hierarchy for reuse
182+
}
183+
184+
fn hard_reset(&mut self) {
185+
self.x = DVector::<T>::zeros(0);
186+
self.residual_buffer = DVector::<T>::zeros(0);
187+
self.iter = 0;
188+
self.converged = false;
189+
self.hierarchy = None;
190+
}
191+
192+
fn soft_reset(&mut self) {
193+
self.x.fill(T::zero());
194+
self.residual_buffer.fill(T::zero());
195+
self.iter = 0;
196+
self.converged = false;
197+
// Keep hierarchy for reuse
198+
}
199+
200+
fn solution(&self) -> &DVector<T> {
201+
&self.x
202+
}
203+
204+
fn iterations(&self) -> usize {
205+
self.iter
206+
}
67207
}
68208

69-
pub fn solve<T>(a: CsrMatrix<T>, b: &DVector<T>, max_iter: usize, tol: T, theta: T) -> Option<DVector<T>>
209+
/// Convenience function to create an AMG solver with default parameters
210+
/// and solve the linear system Ax = b.
211+
pub fn solve_amg<T>(a: &CsrMatrix<T>, b: &DVector<T>, max_iter: usize, tol: T, theta: T) -> Option<DVector<T>>
70212
where
71213
T: RealField + Copy,
72214
{
73-
let mut x = DVector::<T>::zeros(a.nrows());
74-
if solve_with_initial_guess(a, b, &mut x, max_iter, tol, theta) {
75-
Some(x)
215+
let mut solver = Amg::new(tol, theta, max_iter);
216+
solver.init(a, b, None);
217+
if solver.solve_iterations(a, b, max_iter) {
218+
Some(solver.x.clone())
76219
} else {
77220
None
78221
}
79222
}
80223

224+
/// Convenience function to create an AMG solver with default parameters
225+
/// and solve with an initial guess.
226+
pub fn solve_amg_with_initial_guess<T>(
227+
a: &CsrMatrix<T>,
228+
b: &DVector<T>,
229+
x: &mut DVector<T>,
230+
max_iter: usize,
231+
tol: T,
232+
theta: T
233+
) -> bool
234+
where
235+
T: RealField + Copy,
236+
{
237+
let mut solver = Amg::new(tol, theta, max_iter);
238+
solver.init(a, b, Some(x));
239+
let converged = solver.solve_iterations(a, b, max_iter);
240+
*x = solver.x.clone();
241+
converged
242+
}
243+
81244
#[cfg(test)]
82245
mod tests {
83246
use crate::iteratives::amg::{coarsen::Mark, graph::strength_graph, interpolate::build_p};
@@ -145,4 +308,53 @@ mod tests {
145308
}
146309
}
147310
}
311+
312+
#[test]
313+
fn amg_solver_struct() {
314+
let a = poisson_1d(20);
315+
let b = DVector::from_vec(vec![1.0; 20]);
316+
317+
// Test with struct-based solver
318+
let mut solver = Amg::new(1e-6, 0.25, 100);
319+
solver.init(&a, &b, None);
320+
321+
let converged = solver.solve_iterations(&a, &b, 100);
322+
assert!(converged, "AMG solver should converge");
323+
assert!(solver.iterations() > 0, "Should perform at least one iteration");
324+
325+
// Verify solution quality
326+
let residual = &a * solver.solution() - &b;
327+
let residual_norm = residual.amax();
328+
assert!(residual_norm <= 1e-6, "Residual should be below tolerance");
329+
}
330+
331+
#[test]
332+
fn amg_solver_with_custom_smoothing() {
333+
let a = poisson_1d(16);
334+
let b = DVector::from_vec(vec![1.0; 16]);
335+
336+
// Test with custom smoothing parameters
337+
let mut solver = Amg::with_smoothing(1e-6, 0.25, 50, 2, 2);
338+
assert_eq!(solver.nu_pre, 2);
339+
assert_eq!(solver.nu_post, 2);
340+
341+
solver.init(&a, &b, None);
342+
let converged = solver.solve_iterations(&a, &b, 50);
343+
assert!(converged, "AMG solver with custom smoothing should converge");
344+
}
345+
346+
#[test]
347+
fn amg_convenience_functions() {
348+
let a = poisson_1d(12);
349+
let b = DVector::from_vec(vec![1.0; 12]);
350+
351+
// Test convenience function
352+
let solution = solve_amg(&a, &b, 100, 1e-6, 0.25);
353+
assert!(solution.is_some(), "Convenience function should return a solution");
354+
355+
// Test with initial guess
356+
let mut x = DVector::zeros(12);
357+
let converged = solve_amg_with_initial_guess(&a, &b, &mut x, 100, 1e-6, 0.25);
358+
assert!(converged, "Convenience function with initial guess should converge");
359+
}
148360
}

0 commit comments

Comments
 (0)