Skip to content

Commit 989349d

Browse files
committed
WIP: MPI inference
1 parent d0b0c71 commit 989349d

3 files changed

Lines changed: 314 additions & 46 deletions

File tree

benchmarks/run_matmul.jl

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ if !isdefined(Dagger, :accelerate!)
1414
end
1515
Dagger.accelerate!(:mpi)
1616

17-
const N = 10_000
17+
const N = 2_000
1818
const comm = MPI.COMM_WORLD
1919
const rank = MPI.Comm_rank(comm)
2020
const nranks = MPI.Comm_size(comm)
@@ -54,7 +54,6 @@ if CHECK_CORRECTNESS
5454
t_upload = @elapsed begin
5555
A_g = CUDA.cu(A_full)
5656
B_g = CUDA.cu(B_full)
57-
C_dagger_g = CUDA.cu(C_dagger)
5857
end
5958
println("Collect + upload time: ", round(t_collect + t_upload; digits=4), " s")
6059

@@ -63,44 +62,43 @@ if CHECK_CORRECTNESS
6362
end
6463
println("Baseline (GPU/CUDA) time: ", round(t_baseline; digits=4), " s")
6564

66-
rtol = 1f-5
67-
atol = 1f-6
68-
err = norm(C_dagger_g - C_ref_g)
69-
ref_norm = norm(C_ref_g)
70-
rel_err = ref_norm > 0 ? err / ref_norm : err
71-
ok = err <= atol + rtol * ref_norm
65+
# Require all elements within 100× machine epsilon relative error (componentwise)
66+
C_dagger_cpu = C_dagger
67+
C_ref_cpu = Array(C_ref_g)
68+
eps_f = eps(Float32)
69+
rtol = 100.0f0 * eps_f  # componentwise tolerance: 100× Float32 machine epsilon, matching the reported "100×eps" threshold
70+
diff = C_dagger_cpu .- C_ref_cpu
71+
# rel_ij = |diff|/|C_ref|, denominator at least eps to avoid div by zero
72+
denom = max.(abs.(C_ref_cpu), eps_f)
73+
rel_err = abs.(diff) ./ denom
74+
max_rel_err = Float32(maximum(rel_err))
75+
ok = max_rel_err <= rtol
7276
if ok
73-
println("Correctness: OK (rel_err = ", Float32(rel_err), ", abs_err = ", Float32(err), ")")
77+
println("Correctness: OK (max rel_err = ", max_rel_err, " <= 100×eps = ", rtol, ")")
7478
else
75-
println("Correctness: FAIL (rel_err = ", Float32(rel_err), ", abs_err = ", Float32(err), ", rtol=$rtol, atol=$atol)")
79+
println("Correctness: FAIL (max rel_err = ", max_rel_err, " > 100×eps = ", rtol, ")")
7680
end
7781

78-
# Per-block analysis: which sections exceed tolerance (same block size as Dagger layout)
79-
C_dagger_cpu = Array(C_dagger_g)
80-
C_ref_cpu = Array(C_ref_g)
82+
# Per-block: which blocks have any element with rel_err > 100×eps
8183
n_bi = ceil(Int, N / BLOCK)
8284
n_bj = ceil(Int, N / BLOCK)
83-
bad_blocks = Tuple{Int,Int,Float32,Float32}[]
85+
bad_blocks = Tuple{Int,Int,Float32}[]
8486
for bi in 1:n_bi, bj in 1:n_bj
8587
ri = (bi - 1) * BLOCK + 1 : min(bi * BLOCK, N)
8688
rj = (bj - 1) * BLOCK + 1 : min(bj * BLOCK, N)
87-
diff_block = @view(C_dagger_cpu[ri, rj]) .- @view(C_ref_cpu[ri, rj])
88-
ref_block = @view(C_ref_cpu[ri, rj])
89-
block_err = norm(diff_block)
90-
block_ref = norm(ref_block)
91-
block_rel = block_ref > 0 ? block_err / block_ref : block_err
92-
if block_err > atol + rtol * block_ref
93-
push!(bad_blocks, (bi, bj, Float32(block_rel), Float32(block_err)))
89+
block_rel = Float32(maximum(@view(rel_err[ri, rj])))
90+
if block_rel > rtol
91+
push!(bad_blocks, (bi, bj, block_rel))
9492
end
9593
end
9694
if isempty(bad_blocks)
97-
println("Per-block: all ", n_bi * n_bj, " blocks within tolerance.")
95+
println("Per-block: all ", n_bi * n_bj, " blocks within 100×eps rel_err.")
9896
else
99-
println("Per-block: ", length(bad_blocks), " block(s) exceed tolerance (block size ", BLOCK, "×", BLOCK, "):")
97+
println("Per-block: ", length(bad_blocks), " block(s) exceed 100×eps rel_err (block size ", BLOCK, "×", BLOCK, "):")
10098
sort!(bad_blocks; by = x -> -x[3])
101-
for (bi, bj, brel, babs) in bad_blocks
99+
for (bi, bj, block_rel) in bad_blocks
102100
println(" block [", bi, ",", bj, "] rows ", (bi - 1) * BLOCK + 1, ":", min(bi * BLOCK, N),
103-
", cols ", (bj - 1) * BLOCK + 1, ":", min(bj * BLOCK, N), " rel_err = ", brel, " abs_err = ", babs)
101+
", cols ", (bj - 1) * BLOCK + 1, ":", min(bj * BLOCK, N), " max rel_err = ", block_rel)
104102
end
105103
end
106104
end

0 commit comments

Comments
 (0)