Skip to content

Commit 49df693

Browse files
committed
fix: improve AMG implementation with better neighbor collection and residual handling
1 parent 0b19bfd commit 49df693

5 files changed

Lines changed: 69 additions & 58 deletions

File tree

benches/benchmark_iteratives.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ fn bench_methods(c: &mut Criterion) {
185185
// Only run AMG if BENCH_METHOD is unset or matches
186186
#[cfg(feature = "amg")]
187187
if bench_method.as_deref().is_none_or(|m| m.eq_ignore_ascii_case("amg")) {
188+
use nalgebra_sparse_linalg::iteratives::amg;
188189
group.bench_with_input(BenchmarkId::new("AMG", n), &n, |be, &_n| {
189190
be.iter_batched(
190191
|| {

src/iteratives/amg/graph.rs

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,11 @@ where
1515
let values = row.values();
1616

1717
// Find the maximum absolute off-diagonal value in the current row i
18-
let mut max_abs_off_diag_val = T::zero();
19-
for (col_idx, val) in col_indices.iter().zip(values.iter()) {
20-
if *col_idx != i { // Exclude diagonal
21-
let abs_val = val.abs();
22-
if abs_val > max_abs_off_diag_val {
23-
max_abs_off_diag_val = abs_val;
24-
}
25-
}
26-
}
18+
let max_abs_off_diag_val = col_indices.iter()
19+
.zip(values.iter())
20+
.filter(|(col_idx, _)| **col_idx != i)
21+
.map(|(_, val)| val.abs())
22+
.fold(T::zero(), |acc, val| if val > acc { val } else { acc });
2723

2824
// If there are no off-diagonal elements, no strong connections can be formed.
2925
if max_abs_off_diag_val == T::zero() {
@@ -33,17 +29,13 @@ where
3329

3430
let threshold_val = max_abs_off_diag_val * theta;
3531

36-
// Pre-allocate vector for strong neighbors (estimate based on typical sparsity)
37-
let mut strong_neighbors = Vec::with_capacity(col_indices.len().min(10));
38-
39-
// Filter connections based on the threshold
40-
for (j_idx, val) in col_indices.iter().zip(values.iter()) {
41-
// Point j is a strong neighbor of i if |a_ij| >= threshold_val
42-
// And j must not be i itself.
43-
if *j_idx != i && val.abs() >= threshold_val {
44-
strong_neighbors.push(*j_idx);
45-
}
46-
}
32+
// Collect strong neighbors directly using iterator chains
33+
let strong_neighbors: Vec<usize> = col_indices.iter()
34+
.zip(values.iter())
35+
.filter_map(|(j_idx, val)| {
36+
(*j_idx != i && val.abs() >= threshold_val).then_some(*j_idx)
37+
})
38+
.collect();
4739

4840
result.push(strong_neighbors);
4941
},

src/iteratives/amg/interpolate.rs

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,59 +13,73 @@ where
1313
let n = a.nrows();
1414
let n_coarse = coarse_of.iter().filter(|&&c| c != usize::MAX).count();
1515

16+
// Better capacity estimation based on connectivity
17+
let f_points = marks.iter().filter(|&&m| matches!(m, Mark::F)).count();
18+
let estimated_nnz = n_coarse + f_points * 4; // More accurate estimate
19+
1620
let mut trip = CooMatrix::new(n, n_coarse);
17-
// Pre-estimate capacity to reduce reallocations
18-
let estimated_nnz = n + (n - n_coarse) * 3; // Rough estimate
1921
trip.reserve(estimated_nnz);
2022

23+
// Pre-allocate reusable vectors
24+
let mut c_neighbors = Vec::with_capacity(8);
25+
let mut weights = Vec::with_capacity(8);
26+
2127
for i in 0..n {
2228
match marks[i] {
2329
Mark::C => {
24-
let j = coarse_of[i];
25-
trip.push(i, j, N::one());
30+
trip.push(i, coarse_of[i], N::one());
2631
}
2732
Mark::F => {
28-
// Find C-point neighbors in the strength graph
29-
let c_neighbors: Vec<usize> = s[i]
30-
.iter()
31-
.copied()
32-
.filter(|&nbr| matches!(marks[nbr], Mark::C))
33-
.collect();
33+
c_neighbors.clear();
34+
// Collect coarse neighbors
35+
for &nbr in &s[i] {
36+
if matches!(marks[nbr], Mark::C) {
37+
c_neighbors.push(nbr);
38+
}
39+
}
3440

41+
// If no coarse neighbors, add a zero entry
3542
if c_neighbors.is_empty() {
36-
// Fallback: connect to first coarse point if available
3743
if n_coarse > 0 {
3844
trip.push(i, 0, N::one());
3945
}
4046
continue;
4147
}
4248

43-
// Compute interpolation weights
44-
if let Some(diag_entry) = a.get_entry(i, i) {
45-
let diag = diag_entry.into_value();
46-
let mut weight_sum = N::zero();
47-
48-
// First pass: compute weights
49-
let mut weights = Vec::with_capacity(c_neighbors.len());
49+
// Early exit if no diagonal entry
50+
let Some(diag_entry) = a.get_entry(i, i) else {
51+
let equal_weight = N::one() / N::from_usize(c_neighbors.len()).unwrap();
5052
for &nbr in &c_neighbors {
51-
if let Some(a_ij) = a.get_entry(i, nbr) {
52-
let w = -a_ij.into_value() / diag;
53-
weights.push((coarse_of[nbr], w));
54-
weight_sum += w;
55-
}
53+
trip.push(i, coarse_of[nbr], equal_weight);
54+
}
55+
continue;
56+
};
57+
58+
let diag = diag_entry.into_value();
59+
let mut weight_sum = N::zero();
60+
61+
weights.clear();
62+
weights.reserve(c_neighbors.len());
63+
64+
// Compute weights in single pass
65+
for &nbr in &c_neighbors {
66+
if let Some(a_ij) = a.get_entry(i, nbr) {
67+
let w = -a_ij.into_value() / diag;
68+
weights.push((coarse_of[nbr], w));
69+
weight_sum += w;
5670
}
57-
58-
// Normalize weights to sum to 1 for better stability
59-
if weight_sum != N::zero() {
60-
for (col, w) in weights {
61-
trip.push(i, col, w / weight_sum);
62-
}
63-
} else {
64-
// Fallback: equal weights
65-
let equal_weight = N::one() / N::from_usize(c_neighbors.len()).unwrap();
66-
for &nbr in &c_neighbors {
67-
trip.push(i, coarse_of[nbr], equal_weight);
68-
}
71+
}
72+
73+
// Add entries with normalized weights
74+
if weight_sum != N::zero() {
75+
let inv_sum = N::one() / weight_sum;
76+
for (col, w) in weights.drain(..) {
77+
trip.push(i, col, w * inv_sum);
78+
}
79+
} else {
80+
let equal_weight = N::one() / N::from_usize(c_neighbors.len()).unwrap();
81+
for &nbr in &c_neighbors {
82+
trip.push(i, coarse_of[nbr], equal_weight);
6983
}
7084
}
7185
}

src/iteratives/amg/level.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ impl<N: RealField + Copy> Hierarchy<N> {
136136
// 5. Post-smooth: x_l gets updated by nu_post smoothing steps
137137
gauss_seidel::solve_with_initial_guess(&lev.a, b, x, nu_post, tol);
138138
*residual_buffer = &lev.a * &*x - b;
139-
println!("Level {} residual norm: {}", l, residual_buffer.norm());
139+
//println!("Level {} residual norm: {}", l, residual_buffer.norm());
140140
}
141141
}
142142

src/iteratives/amg/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ where
3131
use level::*;
3232

3333
// Pre-compute initial residual
34-
let mut residual_buffer = DVector::from(&a * &*x - b);
34+
let residual_buffer = DVector::from(&a * &*x - b);
3535
let hierarchy = setup(a, theta, 100);
3636

3737
// Check if we're already converged
@@ -44,11 +44,15 @@ where
4444
let adaptive_tol = tol.max(initial_residual_norm * T::from_f64(1e-3).unwrap());
4545

4646
for i in 0..max_iter {
47+
let mut residual_buffer = DVector::zeros(b.len());
4748
hierarchy.vcycle(0, b, x, &mut residual_buffer, adaptive_tol, 1, 1);
4849

50+
// Use the residual buffer that was updated by vcycle
51+
let current_residual = residual_buffer;
52+
4953
// Check convergence every few iterations to reduce overhead
5054
if i % 5 == 4 || i == max_iter - 1 {
51-
let residual_norm = residual_buffer.amax();
55+
let residual_norm = current_residual.amax();
5256
if residual_norm <= tol {
5357
return true;
5458
}

0 commit comments

Comments
 (0)