|
| 1 | +#!/usr/bin/env julia |
| 2 | +# N×N matmul benchmark (Float32); block size scales with number of ranks. |
| 3 | +# Usage (use the full path to Dagger.jl, not "..."): |
| 4 | +# mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl |
| 5 | +# Set CHECK_CORRECTNESS=true to collect and compare against GPU baseline: |
| 6 | +# CHECK_CORRECTNESS=true mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl |
| 7 | + |
| 8 | +using MPI |
| 9 | +using Dagger |
| 10 | +using LinearAlgebra |
| 11 | + |
| 12 | +if !isdefined(Dagger, :accelerate!) |
| 13 | + error("Dagger.accelerate! not found. Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...") |
| 14 | +end |
| 15 | +Dagger.accelerate!(:mpi) |
| 16 | + |
| 17 | +const N = 2_000 |
| 18 | +const comm = MPI.COMM_WORLD |
| 19 | +const rank = MPI.Comm_rank(comm) |
| 20 | +const nranks = MPI.Comm_size(comm) |
| 21 | +# Block size proportional to ranks: ~nranks blocks in 2D => side blocks ≈ √nranks |
| 22 | +const BLOCK = max(1, ceil(Int, N / ceil(Int, sqrt(nranks)))) |
| 23 | + |
| 24 | +const CHECK_CORRECTNESS = parse(Bool, get(ENV, "CHECK_CORRECTNESS", "false")) |
| 25 | + |
| 26 | +if rank == 0 |
| 27 | + println("Benchmark: ", nranks, " ranks, N=", N, ", block size ", BLOCK, "×", BLOCK, " (matmul)") |
| 28 | +end |
| 29 | + |
| 30 | +# Allocate and fill matrices in blocks (Float32) |
| 31 | +A = rand(Blocks(BLOCK, BLOCK), Float32, N, N) |
| 32 | +B = rand(Blocks(BLOCK, BLOCK), Float32, N, N) |
| 33 | + |
| 34 | +# Matrix multiply C = A * B |
| 35 | +t_matmul = @elapsed begin |
| 36 | + C = A * B |
| 37 | +end |
| 38 | + |
| 39 | +if rank == 0 |
| 40 | + println("Matmul time: ", round(t_matmul; digits=4), " s") |
| 41 | +end |
| 42 | + |
| 43 | +# Optional: collect via datadeps (root=0). All ranks participate in the datadeps region. |
| 44 | +if CHECK_CORRECTNESS |
| 45 | + t_collect = @elapsed begin |
| 46 | + A_full = Dagger.collect_datadeps(A; root=0) |
| 47 | + B_full = Dagger.collect_datadeps(B; root=0) |
| 48 | + C_dagger = Dagger.collect_datadeps(C; root=0) |
| 49 | + end |
| 50 | + if rank == 0 |
| 51 | + println("Collecting result and computing baseline for correctness check (GPU)...") |
| 52 | + using CUDA |
| 53 | + CUDA.functional() || error("CUDA not functional; cannot compute GPU baseline. Check CUDA driver and device.") |
| 54 | + t_upload = @elapsed begin |
| 55 | + A_g = CUDA.cu(A_full) |
| 56 | + B_g = CUDA.cu(B_full) |
| 57 | + end |
| 58 | + println("Collect + upload time: ", round(t_collect + t_upload; digits=4), " s") |
| 59 | + |
| 60 | + t_baseline = @elapsed begin |
| 61 | + C_ref_g = A_g * B_g |
| 62 | + end |
| 63 | + println("Baseline (GPU/CUDA) time: ", round(t_baseline; digits=4), " s") |
| 64 | + |
| 65 | + # Require all elements within 100× machine epsilon relative error (componentwise) |
| 66 | + C_dagger_cpu = C_dagger |
| 67 | + C_ref_cpu = Array(C_ref_g) |
| 68 | + eps_f = eps(Float32) |
| 69 | + rtol = 50.0f0 * eps_f |
| 70 | + diff = C_dagger_cpu .- C_ref_cpu |
| 71 | + # rel_ij = |diff|/|C_ref|, denominator at least eps to avoid div by zero |
| 72 | + denom = max.(abs.(C_ref_cpu), eps_f) |
| 73 | + rel_err = abs.(diff) ./ denom |
| 74 | + max_rel_err = Float32(maximum(rel_err)) |
| 75 | + ok = max_rel_err <= rtol |
| 76 | + if ok |
| 77 | + println("Correctness: OK (max rel_err = ", max_rel_err, " <= 100×eps = ", rtol, ")") |
| 78 | + else |
| 79 | + println("Correctness: FAIL (max rel_err = ", max_rel_err, " > 100×eps = ", rtol, ")") |
| 80 | + end |
| 81 | + |
| 82 | + # Per-block: which blocks have any element with rel_err > 100×eps |
| 83 | + n_bi = ceil(Int, N / BLOCK) |
| 84 | + n_bj = ceil(Int, N / BLOCK) |
| 85 | + bad_blocks = Tuple{Int,Int,Float32}[] |
| 86 | + for bi in 1:n_bi, bj in 1:n_bj |
| 87 | + ri = (bi - 1) * BLOCK + 1 : min(bi * BLOCK, N) |
| 88 | + rj = (bj - 1) * BLOCK + 1 : min(bj * BLOCK, N) |
| 89 | + block_rel = Float32(maximum(@view(rel_err[ri, rj]))) |
| 90 | + if block_rel > rtol |
| 91 | + push!(bad_blocks, (bi, bj, block_rel)) |
| 92 | + end |
| 93 | + end |
| 94 | + if isempty(bad_blocks) |
| 95 | + println("Per-block: all ", n_bi * n_bj, " blocks within 100×eps rel_err.") |
| 96 | + else |
| 97 | + println("Per-block: ", length(bad_blocks), " block(s) exceed 100×eps rel_err (block size ", BLOCK, "×", BLOCK, "):") |
| 98 | + sort!(bad_blocks; by = x -> -x[3]) |
| 99 | + for (bi, bj, block_rel) in bad_blocks |
| 100 | + println(" block [", bi, ",", bj, "] rows ", (bi - 1) * BLOCK + 1, ":", min(bi * BLOCK, N), |
| 101 | + ", cols ", (bj - 1) * BLOCK + 1, ":", min(bj * BLOCK, N), " max rel_err = ", block_rel) |
| 102 | + end |
| 103 | + end |
| 104 | + end |
| 105 | +end |
0 commit comments