@@ -14,7 +14,7 @@ if !isdefined(Dagger, :accelerate!)
1414end
1515Dagger. accelerate! (:mpi )
1616
17- const N = 10_000
17+ const N = 2_000
1818const comm = MPI. COMM_WORLD
1919const rank = MPI. Comm_rank (comm)
2020const nranks = MPI. Comm_size (comm)
@@ -54,7 +54,6 @@ if CHECK_CORRECTNESS
5454 t_upload = @elapsed begin
5555 A_g = CUDA. cu (A_full)
5656 B_g = CUDA. cu (B_full)
57- C_dagger_g = CUDA. cu (C_dagger)
5857 end
5958 println (" Collect + upload time: " , round (t_collect + t_upload; digits= 4 ), " s" )
6059
@@ -63,44 +62,43 @@ if CHECK_CORRECTNESS
6362 end
6463 println (" Baseline (GPU/CUDA) time: " , round (t_baseline; digits= 4 ), " s" )
6564
66- rtol = 1f-5
67- atol = 1f-6
68- err = norm (C_dagger_g - C_ref_g)
69- ref_norm = norm (C_ref_g)
70- rel_err = ref_norm > 0 ? err / ref_norm : err
71- ok = err <= atol + rtol * ref_norm
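65+ # Require every element's componentwise relative error to be at most rtol (set below to 50× Float32 machine epsilon; NOTE(review): the printed messages say 100×eps — confirm which multiple is intended)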
66+ C_dagger_cpu = C_dagger
67+ C_ref_cpu = Array (C_ref_g)
68+ eps_f = eps (Float32)
69+ rtol = 50.0f0 * eps_f
70+ diff = C_dagger_cpu .- C_ref_cpu
71+ # rel_ij = |diff|/|C_ref|, denominator at least eps to avoid div by zero
72+ denom = max .(abs .(C_ref_cpu), eps_f)
73+ rel_err = abs .(diff) ./ denom
74+ max_rel_err = Float32 (maximum (rel_err))
75+ ok = max_rel_err <= rtol
7276 if ok
73- println (" Correctness: OK (rel_err = " , Float32 (rel_err) , " , abs_err = " , Float32 (err) , " )" )
77+ println (" Correctness: OK (max rel_err = " , max_rel_err , " <= 100×eps = " , rtol , " )" )
7478 else
75- println (" Correctness: FAIL (rel_err = " , Float32 (rel_err) , " , abs_err = " , Float32 (err), " , rtol= $rtol , atol= $atol )" )
79+ println (" Correctness: FAIL (max rel_err = " , max_rel_err , " > 100×eps = " , rtol, " )" )
7680 end
7781
78- # Per-block analysis: which sections exceed tolerance (same block size as Dagger layout)
79- C_dagger_cpu = Array (C_dagger_g)
80- C_ref_cpu = Array (C_ref_g)
82+ # Per-block: which blocks contain any element with rel_err > rtol (NOTE(review): rtol is 50×eps in the code, although the messages say 100×eps)
8183 n_bi = ceil (Int, N / BLOCK)
8284 n_bj = ceil (Int, N / BLOCK)
83- bad_blocks = Tuple{Int,Int,Float32,Float32 }[]
85+ bad_blocks = Tuple{Int,Int,Float32}[]
8486 for bi in 1 : n_bi, bj in 1 : n_bj
8587 ri = (bi - 1 ) * BLOCK + 1 : min (bi * BLOCK, N)
8688 rj = (bj - 1 ) * BLOCK + 1 : min (bj * BLOCK, N)
87- diff_block = @view (C_dagger_cpu[ri, rj]) .- @view (C_ref_cpu[ri, rj])
88- ref_block = @view (C_ref_cpu[ri, rj])
89- block_err = norm (diff_block)
90- block_ref = norm (ref_block)
91- block_rel = block_ref > 0 ? block_err / block_ref : block_err
92- if block_err > atol + rtol * block_ref
93- push! (bad_blocks, (bi, bj, Float32 (block_rel), Float32 (block_err)))
89+ block_rel = Float32 (maximum (@view (rel_err[ri, rj])))
90+ if block_rel > rtol
91+ push! (bad_blocks, (bi, bj, block_rel))
9492 end
9593 end
9694 if isempty (bad_blocks)
97- println (" Per-block: all " , n_bi * n_bj, " blocks within tolerance ." )
95+ println (" Per-block: all " , n_bi * n_bj, " blocks within 100×eps rel_err ." )
9896 else
99- println (" Per-block: " , length (bad_blocks), " block(s) exceed tolerance (block size " , BLOCK, " ×" , BLOCK, " ):" )
97+ println (" Per-block: " , length (bad_blocks), " block(s) exceed 100×eps rel_err (block size " , BLOCK, " ×" , BLOCK, " ):" )
10098 sort! (bad_blocks; by = x -> - x[3 ])
101- for (bi, bj, brel, babs ) in bad_blocks
99+ for (bi, bj, block_rel ) in bad_blocks
102100 println (" block [" , bi, " ," , bj, " ] rows " , (bi - 1 ) * BLOCK + 1 , " :" , min (bi * BLOCK, N),
103- " , cols " , (bj - 1 ) * BLOCK + 1 , " :" , min (bj * BLOCK, N), " rel_err = " , brel, " abs_err = " , babs )
101+ " , cols " , (bj - 1 ) * BLOCK + 1 , " :" , min (bj * BLOCK, N), " max rel_err = " , block_rel )
104102 end
105103 end
106104 end
0 commit comments