Skip to content

Commit d0b0c71

Browse files
committed
MPI benchmarks and matmul correctness: Float32, 10k, per-block check
- benchmarks/run_matmul.jl: Float32, N=10k, relative error + per-block report - benchmarks/run_distribute_fetch.jl, run_qr.jl, check_comm_asymmetry (jl/py) - src: alloc, darray, mul, mpi, options, Sch, submission, thunk, tochunk, dagdebug Made-with: Cursor
1 parent bc2f7d7 commit d0b0c71

15 files changed

Lines changed: 677 additions & 70 deletions

benchmarks/check_comm_asymmetry.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#!/usr/bin/env julia
2+
# Parse MPI+Dagger logs and report communication decision asymmetry per tag.
3+
# Asymmetry: for the same tag, one rank decides to send (local+bcast, sender+communicated, etc.)
4+
# and another rank decides to infer (inferred, uninvolved) and never recv → deadlock.
5+
#
6+
# Usage: julia check_comm_asymmetry.jl < logfile
7+
# Or: mpiexec -n 10 julia ... run_matmul.jl 2>&1 | tee matmul.log; julia check_comm_asymmetry.jl < matmul.log
8+
9+
# NOTE(review): these three constants are not referenced by main(), which
# builds its own local key sets; they document the decision vocabulary — TODO
# confirm whether main() should use them instead.
# Labels logged when a rank takes the send side for a tag.
const SEND_DECISIONS = Set{String}([
    "sender+communicated", "sender+inferred", "receiver+bcast", "local+bcast",
    "aliasing", # when followed by local+bcast we already capture local+bcast
])
# Labels logged when a rank actually received data for a tag.
const RECV_DECISIONS = Set{String}([
    "communicated", "receiver", "sender+communicated", # received data
])
# Labels logged when a rank skipped the recv and used an inferred value.
const INFER_DECISIONS = Set{String}([
    "inferred", "uninvolved", # did not recv (uses inferred type)
])
19+
20+
"""
    parse_line(line) -> (rank, tag, category, decision)

Extract the MPI rank, communication tag, log category, and normalized
communication decision from one log line. Any field not present is `nothing`.

Rank/tag come from the last `[rank N]` / `[tag N]` markers on the line;
`category` is one of `execute!`, `aliasing`, or `remotecall_endpoint`;
`decision` is one of "inferred", "communicated", "local+bcast",
"sender+inferred", "sender+communicated", "receiver", "receiver+bcast",
or "inplace_move".
"""
function parse_line(line)
    rank = nothing
    tag = nothing
    decision = nothing
    category = nothing # aliasing, execute!, remotecall_endpoint
    # Keep the LAST occurrence of each marker on the line.
    for m in eachmatch(r"\[rank\s+(\d+)\]", line)
        rank = parse(Int, m.captures[1])
    end
    for m in eachmatch(r"\[tag\s+(\d+)\]", line)
        tag = parse(Int, m.captures[1])
    end
    for m in eachmatch(r"\[(execute!|aliasing|remotecall_endpoint)\]", line)
        category = m.captures[1]
    end
    # The decision sits in a bracket pair that directly follows another bracket
    # pair, e.g. "...[execute!][sender+inferred]".
    for m in eachmatch(r"\]\[([^\]]+)\]", line)
        candidate = m.captures[1]
        # BUGFIX: test the "sender+..." labels BEFORE the generic substring
        # tests. Previously "sender+inferred" matched occursin("inferred") and
        # "sender+communicated" matched occursin("communicated") first, so the
        # sender-specific branches were unreachable and sender ranks were
        # misclassified (dropped from the send-side sets downstream).
        if startswith(candidate, "sender+")
            decision = occursin("inferred", candidate) ? "sender+inferred" : "sender+communicated"
            break
        elseif candidate == "receiver+bcast"
            decision = "receiver+bcast"
            break
        elseif occursin("inferred", candidate) && !occursin("communicated", candidate)
            decision = "inferred"
            break
        elseif occursin("communicated", candidate)
            decision = "communicated"
            break
        elseif occursin("local+bcast", candidate)
            decision = "local+bcast"
            break
        elseif candidate == "receiver"
            decision = "receiver"
            break
        elseif candidate == "inplace_move"
            decision = "inplace_move"
            break
        end
    end
    return rank, tag, category, decision
end
64+
65+
"""
    main(io::IO=stdin)

Read log lines from `io`, record each rank's communication decision per tag,
and report tags where at least one rank decided to send while another decided
to infer (and therefore never posts a recv) — the pattern that can deadlock.

The `io` parameter defaults to `stdin`, preserving the original CLI behavior
while making the function testable on an `IOBuffer`.
"""
function main(io::IO=stdin)
    # tag => Dict(rank => decision); a later line for the same (tag, rank)
    # overwrites the earlier decision.
    by_tag = Dict{Int, Dict{Int, String}}()
    for line in eachline(io)
        rank, tag, category, decision = parse_line(line)
        (isnothing(rank) || isnothing(tag) || isnothing(decision)) && continue
        ranks = get!(() -> Dict{Int, String}(), by_tag, tag)
        ranks[rank] = decision
    end

    # For each tag, check: is there at least one sender and one inferrer (non-receiver)?
    send_keys = Set(["local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"])
    infer_keys = Set(["inferred", "sender+inferred"]) # sender+inferred means sender didn't need to recv
    recv_keys = Set(["communicated", "receiver", "sender+communicated"])

    # Concretely typed (the original used an untyped `[]`, i.e. Vector{Any}).
    asymmetries = Tuple{Int, Vector{Int}, Vector{Int}, Vector{Int}, Dict{Int, String}}[]
    for (tag, ranks) in sort(collect(by_tag), by = first)
        senders = [r for (r, d) in ranks if d in send_keys]
        inferrers = [r for (r, d) in ranks if d in infer_keys || d == "uninvolved"]
        receivers = [r for (r, d) in ranks if d in recv_keys]
        # Asymmetry: someone sends (bcast) so will send to ALL other ranks; someone chose infer and won't recv.
        if !isempty(senders) && !isempty(inferrers)
            push!(asymmetries, (tag, senders, inferrers, receivers, ranks))
        end
    end

    if isempty(asymmetries)
        println("No communication decision asymmetry found (no tag has both sender and inferrer).")
        return
    end

    println("=== Communication decision asymmetry (can cause deadlock) ===\n")
    for (tag, senders, inferrers, receivers, ranks) in asymmetries
        println("Tag $tag:")
        println(" Senders (will bcast to all others): $senders")
        println(" Inferrers (did not recv): $inferrers")
        println(" Receivers: $receivers")
        println(" All decisions: $ranks")
        println()
    end
end

main()

benchmarks/check_comm_asymmetry.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Parse MPI+Dagger logs and report communication decision asymmetry per tag.
4+
Asymmetry: for the same tag, one rank decides to send (local+bcast, etc.)
5+
and another decides to infer (inferred) and never recv → deadlock.
6+
7+
Usage:
8+
# Capture full log (all ranks' Core.println from mpi.jl go to stdout):
9+
mpiexec -n 10 julia --project=/path/to/Dagger.jl benchmarks/run_matmul.jl 2>&1 | tee matmul.log
10+
# Then look for asymmetry (same tag: one rank sends, another infers → deadlock):
11+
python3 check_comm_asymmetry.py < matmul.log
12+
"""
13+
14+
import re
15+
import sys
16+
from collections import defaultdict
17+
18+
# Labels logged when a rank takes the send side for a tag.
SEND_DECISIONS = {"sender+communicated", "sender+inferred", "receiver+bcast", "local+bcast"}
# Labels logged when a rank actually performed a recv for a tag.
RECV_DECISIONS = {"receiver", "communicated", "sender+communicated"}
# Labels logged when a rank skipped the recv and used an inferred value.
INFER_DECISIONS = {"uninvolved", "inferred", "sender+inferred"}
21+
22+
23+
def parse_line(line: str):
    """Extract ``(rank, tag, category, decision)`` from one log line.

    Fields default to ``None`` when absent. ``category`` is one of
    ``execute!``, ``aliasing`` or ``remotecall_endpoint``; ``decision`` is a
    normalized label: "inferred", "communicated", "local+bcast",
    "sender+inferred", "sender+communicated", "receiver", "receiver+bcast",
    or "inplace_move".
    """
    rank = tag = category = decision = None
    m = re.search(r"\[rank\s+(\d+)\]", line)
    if m:
        rank = int(m.group(1))
    m = re.search(r"\[tag\s+(\d+)\]", line)
    if m:
        tag = int(m.group(1))
    m = re.search(r"\[(execute!|aliasing|remotecall_endpoint)\]", line)
    if m:
        category = m.group(1)
    # Capture decision from back-to-back [...][...] blocks.
    for m in re.finditer(r"\]\[([^\]]+)\]", line):
        candidate = m.group(1)
        # BUGFIX: test the "sender+..." labels BEFORE the generic substring
        # tests. Previously "sender+inferred" hit the bare "inferred" branch
        # and "sender+communicated" hit the "communicated" branch, making the
        # sender-specific branches unreachable and dropping those ranks from
        # the send-side sets in downstream analysis.
        if candidate.startswith("sender+"):
            decision = "sender+inferred" if "inferred" in candidate else "sender+communicated"
            break
        if candidate == "receiver+bcast":
            decision = "receiver+bcast"
            break
        if "inferred" in candidate and "communicated" not in candidate:
            decision = "inferred"
            break
        if "communicated" in candidate:
            decision = "communicated"
            break
        if "local+bcast" in candidate:
            decision = "local+bcast"
            break
        if candidate == "receiver":
            decision = "receiver"
            break
        if candidate == "inplace_move":
            decision = "inplace_move"
            break
    return rank, tag, category, decision
59+
60+
61+
def main(stream=None):
    """Group each rank's communication decision by tag and report asymmetry.

    Reads log lines from ``stream`` (any iterable of strings; defaults to
    ``sys.stdin``, preserving the original CLI behavior), then flags tags
    where at least one rank decided to send while another decided to infer
    (and therefore never posts a recv) — the pattern that can deadlock.
    """
    if stream is None:
        stream = sys.stdin
    by_tag = defaultdict(dict)  # tag -> {rank: decision}
    for line in stream:
        # category is parsed but not needed for the asymmetry report
        rank, tag, _category, decision = parse_line(line)
        if rank is None or tag is None or decision is None:
            continue
        by_tag[tag][rank] = decision

    send_keys = {"local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"}
    infer_keys = {"inferred", "sender+inferred", "uninvolved"}
    recv_keys = {"communicated", "receiver", "sender+communicated"}

    asymmetries = []
    for tag in sorted(by_tag.keys()):
        ranks = by_tag[tag]
        senders = [r for r, d in ranks.items() if d in send_keys]
        inferrers = [r for r, d in ranks.items() if d in infer_keys]
        receivers = [r for r, d in ranks.items() if d in recv_keys]
        # A sender broadcasts to all other ranks; an inferrer never posts the
        # matching recv — both present on one tag is the deadlock signature.
        if senders and inferrers:
            asymmetries.append((tag, senders, inferrers, receivers, ranks))

    if not asymmetries:
        print("No communication decision asymmetry found (no tag has both sender and inferrer).")
        return

    print("=== Communication decision asymmetry (can cause deadlock) ===\n")
    for tag, senders, inferrers, receivers, ranks in asymmetries:
        print(f"Tag {tag}:")
        print(f" Senders (will bcast to all others): {senders}")
        print(f" Inferrers (did not recv): {inferrers}")
        print(f" Receivers: {receivers}")
        print(f" All decisions: {dict(ranks)}")
        print()
94+
95+
96+
if __name__ == "__main__":
    # Script entry point: analyze the log supplied on stdin.
    main()

benchmarks/run_distribute_fetch.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#!/usr/bin/env julia
# Build a matrix with a fixed, reproducible pattern, distribute it over an MPI
# procgrid, then have each rank fetch and println the chunk(s) it owns.
# Usage (from repo root, use full path to Dagger.jl):
#   mpiexec -n 4 julia --project=/path/to/Dagger.jl benchmarks/run_distribute_fetch.jl

using MPI
using Dagger

if !isdefined(Dagger, :accelerate!)
    error("Dagger.accelerate! not found. Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...")
end
Dagger.accelerate!(:mpi)

const comm = MPI.COMM_WORLD
const rank = MPI.Comm_rank(comm)
const nranks = MPI.Comm_size(comm)

# Fixed reproducible pattern: 6×6 matrix, M[i,j] = 10*i + j (same on all ranks)
const N = 6
const BLOCK = 2
A = [10 * i + j for i in 1:N, j in 1:N]

# Round-robin Dagger's compatible processors over the block grid so the
# procgrid passes validation.
availprocs = collect(Dagger.compatible_processors())
nblocks = (cld(N, BLOCK), cld(N, BLOCK))
procgrid = reshape(
    [availprocs[mod1(k, length(availprocs))] for k in 1:prod(nblocks)],
    nblocks,
)

# Distribute so chunk (i,j) is computed on procgrid[i,j]
D = distribute(A, Blocks(BLOCK, BLOCK), procgrid)
D_fetched = fetch(D)

# Each rank prints only the chunk(s) whose MPIRef lives on this rank.
for (idx, ch) in enumerate(D_fetched.chunks)
    ch isa Dagger.Chunk || continue
    ch.handle isa Dagger.MPIRef || continue
    ch.handle.rank == rank || continue
    println("rank $rank chunk $idx: ", fetch(ch))
end

benchmarks/run_matmul.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#!/usr/bin/env julia
# N×N matmul benchmark (Float32); block size scales with number of ranks.
# Usage (use the full path to Dagger.jl, not "..."):
#   mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl
# Set CHECK_CORRECTNESS=true to collect and compare against GPU baseline:
#   CHECK_CORRECTNESS=true mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl

using MPI
using Dagger
using LinearAlgebra

if !isdefined(Dagger, :accelerate!)
    error("Dagger.accelerate! not found. Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...")
end
Dagger.accelerate!(:mpi)

const N = 10_000
const comm = MPI.COMM_WORLD
const rank = MPI.Comm_rank(comm)
const nranks = MPI.Comm_size(comm)
# Roughly √nranks blocks per side, so the 2D block count tracks the rank count.
const BLOCK = max(1, ceil(Int, N / ceil(Int, sqrt(nranks))))

const CHECK_CORRECTNESS = parse(Bool, get(ENV, "CHECK_CORRECTNESS", "false"))

if rank == 0
    println("Benchmark: ", nranks, " ranks, N=", N, ", block size ", BLOCK, "×", BLOCK, " (matmul)")
end

# Block-partitioned random Float32 operands.
A = rand(Blocks(BLOCK, BLOCK), Float32, N, N)
B = rand(Blocks(BLOCK, BLOCK), Float32, N, N)

# Time the distributed multiply C = A * B.
t_matmul = @elapsed begin
    C = A * B
end

if rank == 0
    println("Matmul time: ", round(t_matmul; digits=4), " s")
end

# Optional correctness check. collect_datadeps is collective (root=0), so ALL
# ranks must enter this region; only rank 0 runs the CUDA comparison.
if CHECK_CORRECTNESS
    t_collect = @elapsed begin
        A_full = Dagger.collect_datadeps(A; root=0)
        B_full = Dagger.collect_datadeps(B; root=0)
        C_dagger = Dagger.collect_datadeps(C; root=0)
    end
    if rank == 0
        println("Collecting result and computing baseline for correctness check (GPU)...")
        using CUDA
        CUDA.functional() || error("CUDA not functional; cannot compute GPU baseline. Check CUDA driver and device.")
        t_upload = @elapsed begin
            A_g = CUDA.cu(A_full)
            B_g = CUDA.cu(B_full)
            C_dagger_g = CUDA.cu(C_dagger)
        end
        println("Collect + upload time: ", round(t_collect + t_upload; digits=4), " s")

        t_baseline = @elapsed begin
            C_ref_g = A_g * B_g
        end
        println("Baseline (GPU/CUDA) time: ", round(t_baseline; digits=4), " s")

        # Global Frobenius-norm comparison against the CUDA baseline.
        rtol = 1f-5
        atol = 1f-6
        err = norm(C_dagger_g - C_ref_g)
        ref_norm = norm(C_ref_g)
        rel_err = ref_norm > 0 ? err / ref_norm : err
        ok = err <= atol + rtol * ref_norm
        if ok
            println("Correctness: OK (rel_err = ", Float32(rel_err), ", abs_err = ", Float32(err), ")")
        else
            println("Correctness: FAIL (rel_err = ", Float32(rel_err), ", abs_err = ", Float32(err), ", rtol=$rtol, atol=$atol)")
        end

        # Per-block analysis: report which sections exceed tolerance, using the
        # same block size as the Dagger layout.
        C_host = Array(C_dagger_g)
        C_ref_host = Array(C_ref_g)
        nblk_i = cld(N, BLOCK)
        nblk_j = cld(N, BLOCK)
        bad_blocks = Tuple{Int,Int,Float32,Float32}[]
        for bi in 1:nblk_i
            for bj in 1:nblk_j
                rows = (bi - 1) * BLOCK + 1 : min(bi * BLOCK, N)
                cols = (bj - 1) * BLOCK + 1 : min(bj * BLOCK, N)
                delta = @view(C_host[rows, cols]) .- @view(C_ref_host[rows, cols])
                blk_err = norm(delta)
                blk_ref = norm(@view(C_ref_host[rows, cols]))
                blk_rel = blk_ref > 0 ? blk_err / blk_ref : blk_err
                if blk_err > atol + rtol * blk_ref
                    push!(bad_blocks, (bi, bj, Float32(blk_rel), Float32(blk_err)))
                end
            end
        end
        if isempty(bad_blocks)
            println("Per-block: all ", nblk_i * nblk_j, " blocks within tolerance.")
        else
            println("Per-block: ", length(bad_blocks), " block(s) exceed tolerance (block size ", BLOCK, "×", BLOCK, "):")
            # Worst blocks (largest relative error) first.
            sort!(bad_blocks; by = x -> x[3], rev = true)
            for (bi, bj, brel, babs) in bad_blocks
                println(" block [", bi, ",", bj, "] rows ", (bi - 1) * BLOCK + 1, ":", min(bi * BLOCK, N),
                    ", cols ", (bj - 1) * BLOCK + 1, ":", min(bj * BLOCK, N), " rel_err = ", brel, " abs_err = ", babs)
            end
        end
    end
end

0 commit comments

Comments
 (0)