@@ -58,7 +58,7 @@ using PrecompileTools
5858# ============================================================================
5959
6060# Import the functions we need to extend
61- import MultiGridBarrier: amgb_zeros, amgb_all_isfinite, amgb_assert_uniform, amgb_diag, amgb_blockdiag, map_rows, map_rows_gpu, vertex_indices, _raw_array, _to_cpu_array, _rows_to_svectors
61+ import MultiGridBarrier: amgb_zeros, amgb_all_isfinite, amgb_diag, amgb_blockdiag, map_rows, map_rows_gpu, vertex_indices, _raw_array, _to_cpu_array, _rows_to_svectors
6262
6363# amgb_zeros: Create distributed zero matrices/vectors using Base.zeros from LinearAlgebraMPI
6464MultiGridBarrier. amgb_zeros (:: SparseMatrixMPI{T,Ti,AV} , m, n) where {T,Ti,AV} =
@@ -87,112 +87,6 @@ function MultiGridBarrier.amgb_all_isfinite(z::VectorMPI{T,AV}) where {T,AV}
8787 MPI. Allreduce (local_all_finite, & , MPI. COMM_WORLD)
8888end
8989
90- # amgb_assert_uniform: Assert that a scalar value is identical on all MPI ranks
91- # Gathers all values to rank 0, checks uniformity, and aborts if not uniform
92- function MultiGridBarrier. amgb_assert_uniform (x:: T , msg:: String = " " ) where T<: Number
93- comm = MPI. COMM_WORLD
94- rank = MPI. Comm_rank (comm)
95- nranks = MPI. Comm_size (comm)
96-
97- # Gather all values to rank 0
98- all_values = MPI. Gather (x, 0 , comm)
99-
100- # Check uniformity on rank 0
101- is_uniform = true
102- if rank == 0
103- ref_val = all_values[1 ]
104- for i in 2 : nranks
105- # Use isequal for exact equality (handles NaN correctly: isequal(NaN,NaN)=true)
106- if ! isequal (all_values[i], ref_val)
107- is_uniform = false
108- break
109- end
110- end
111- end
112-
113- # Broadcast uniformity result to all ranks
114- is_uniform = MPI. Bcast (is_uniform, 0 , comm)
115-
116- if ! is_uniform
117- # Print error info on rank 0 only (use stdout for visibility)
118- if rank == 0
119- println (" \n " * " =" ^ 60 )
120- println (" MPI UNIFORMITY ASSERTION FAILED: $msg " )
121- println (" =" ^ 60 )
122- println (" Values across ranks:" )
123- for i in 1 : nranks
124- println (" Rank $(i- 1 ) : $(all_values[i]) " )
125- end
126- println (" \n Stack trace:" )
127- for frame in stacktrace ()
128- println (" " , frame)
129- end
130- println (" =" ^ 60 )
131- flush (stdout )
132- end
133-
134- # Small delay to ensure output is flushed before abort
135- sleep (0.1 )
136-
137- # Abort all ranks
138- MPI. Abort (comm, 1 )
139- end
140-
141- return nothing
142- end
143-
144- # Also handle boolean specifically for converged flags
145- function MultiGridBarrier. amgb_assert_uniform (x:: Bool , msg:: String = " " )
146- comm = MPI. COMM_WORLD
147- rank = MPI. Comm_rank (comm)
148- nranks = MPI. Comm_size (comm)
149-
150- # Convert to Int for MPI (some MPI implementations don't handle Bool well)
151- x_int = Int32 (x)
152- all_values = MPI. Gather (x_int, 0 , comm)
153-
154- # Check uniformity on rank 0
155- is_uniform = true
156- if rank == 0
157- ref_val = all_values[1 ]
158- for i in 2 : nranks
159- if all_values[i] != ref_val
160- is_uniform = false
161- break
162- end
163- end
164- end
165-
166- # Broadcast uniformity result to all ranks
167- is_uniform = MPI. Bcast (is_uniform, 0 , comm)
168-
169- if ! is_uniform
170- # Print error info on rank 0 only (use stdout for visibility)
171- if rank == 0
172- println (" \n " * " =" ^ 60 )
173- println (" MPI UNIFORMITY ASSERTION FAILED: $msg " )
174- println (" =" ^ 60 )
175- println (" Boolean values across ranks:" )
176- for i in 1 : nranks
177- println (" Rank $(i- 1 ) : $(Bool (all_values[i])) " )
178- end
179- println (" \n Stack trace:" )
180- for frame in stacktrace ()
181- println (" " , frame)
182- end
183- println (" =" ^ 60 )
184- flush (stdout )
185- end
186-
187- # Small delay to ensure output is flushed before abort
188- sleep (0.1 )
189-
190- MPI. Abort (comm, 1 )
191- end
192-
193- return nothing
194- end
195-
19690# amgb_diag: Create diagonal matrix from vector
19791# SparseMatrixMPI with VectorMPI - preserves vector's array type in nzval
19892MultiGridBarrier. amgb_diag (:: SparseMatrixMPI{T,Ti,AV} , z:: VectorMPI{T,AV2} , m= length (z), n= length (z)) where {T,Ti,AV,AV2} =
0 commit comments