Skip to content

Commit f5932f7

Browse files
Support SolverBenchmarks (#7)
Thank you @MaxenceGollier! We can think about reducing duplicate code in a follow-up pr.
1 parent ffdaff7 commit f5932f7

3 files changed

Lines changed: 247 additions & 2 deletions

File tree

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Git = "d7ba0133-e1db-5d97-8f8c-041e4b3a1eb2"
1111
GitHub = "bc5e4493-9b4d-5f90-b8aa-2b2bcaad7a26"
1212
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1313
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
14+
LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433"
1415
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1516
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
1617
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
@@ -25,6 +26,7 @@ Git = "1.3"
2526
GitHub = "5.9"
2627
JLD2 = "0.4,0.5"
2728
JSON = "0.21"
29+
LibGit2 = "1.11.0"
2830
Pkg = "1.9"
2931
PkgBenchmark = "0.2"
3032
Plots = "1.39"

src/JSOBenchmarks.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,21 @@ using Git
1010
using GitHub
1111
using JLD2
1212
using JSON
13+
using LibGit2
1314
using PkgBenchmark
1415
using Plots
1516
using StatsPlots
1617

1718
# JSO modules
1819
using SolverBenchmark
1920

20-
export run_benchmarks
21+
include("solver_benchmarks.jl")
22+
23+
export run_benchmarks, run_solver_benchmarks
2124
export profile_solvers_from_pkgbmark
2225
export create_gist_from_json_dict, create_gist_from_json_file
2326
export update_gist_from_json_dict, update_gist_from_json_file
27+
export solver_benchmark_profile_values, solver_benchmark_table_values
2428
export write_md
2529

2630
const git = Git.git()
@@ -199,7 +203,7 @@ function run_benchmarks(
199203
)
200204

201205
@info "finished"
202-
return nothing
206+
return update_gist ? gist_url : new_gist_url
203207
end
204208

205209
# Utility functions

src/solver_benchmarks.jl

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
"""
2+
run_solver_benchmarks(repo_name, bmark_dir; reference_branch="main", gist_url=nothing, script="benchmarks.jl")
3+
4+
Run a benchmark script, based on the SolverBenchmarks.jl package, for a Julia repository.
5+
6+
This function executes a benchmark script (`script`) in the specified benchmark directory (`bmark_dir`) for
7+
the current state of the repository containing `repo_name`.
8+
The output of the script should be a result of `BenchmarkSolver.bmark_solvers`. If the repository is a Git repository, the
9+
benchmarks are run on the current commit and optionally compared to a reference branch (default `"main"`).
10+
The results are saved as `.jld2` files, performance profile plots and summary tables are generated.
11+
Optionally, results can be uploaded or updated in a GitHub Gist (`gist_url`).
12+
13+
# Arguments
14+
15+
- `repo_name::AbstractString`
16+
The name of the Julia package repository being benchmarked.
17+
18+
- `bmark_dir::AbstractString`
19+
Path to the directory containing the benchmark scripts. This is usually a `benchmarks/` folder
20+
inside the repository.
21+
22+
# Keyword Arguments
23+
24+
- `reference_branch::AbstractString = "main"`
25+
The Git branch used as a reference for comparison in plots and tables.
26+
27+
- `gist_url::Union{AbstractString, Nothing} = nothing`
28+
If provided, the function updates the existing Gist at this URL. Otherwise, a new Gist is created.
29+
30+
- `script::AbstractString = "benchmarks.jl"`
31+
The Julia script in `bmark_dir` that runs the benchmark suite. Must return a `Dict{Symbol, DataFrame}`
32+
as produced by `BenchmarkSolver.bmark_solvers`.
33+
34+
# Output
35+
36+
Returns a `String` containing the URL of the Gist with benchmark results. If `gist_url` was provided,
37+
the existing Gist is updated; otherwise, a new Gist URL is returned.
38+
39+
# Plots and Tables values
40+
41+
In order to compare specific outputs from the benchmark results, the `script` can override the functions
42+
JSOBenchmarks.solver_benchmark_profile_values()
43+
JSOBenchmarks.solver_benchmark_table_values()
44+
to specify which columns from the DataFrames should be used for the performance profiles and summary tables, respectively.
45+
Both should return an array of pairs, where the first element is a `Symbol` representing the column name in the DataFrame
46+
and the second element is a `String` representing the label to be used in the plots and tables.
47+
48+
# Notes
49+
50+
- This function is mostly expected to be called from a GitHub workflow.
51+
- Please refer to `SolverBenchmarks.bmark_solvers` for more information on how to write the benchmark script.
52+
"""
53+
function run_solver_benchmarks(
54+
repo_name::AbstractString,
55+
bmark_dir::AbstractString;
56+
reference_branch::AbstractString = "main",
57+
gist_url::Union{AbstractString, Nothing} = nothing,
58+
script = "benchmarks.jl",
59+
)
60+
61+
update_gist = gist_url !== nothing
62+
is_git = isdir(joinpath(bmark_dir, "..", ".git"))
63+
@info "" is_git update_gist
64+
65+
local gist_id
66+
if update_gist
67+
gist_id = split(gist_url, "/")[end]
68+
@info "" gist_id
69+
end
70+
71+
# if we are running these benchmarks from the git repository
72+
# we want to develop the package instead of using the release
73+
if is_git
74+
Pkg.develop(PackageSpec(path = joinpath(bmark_dir, "..")))
75+
else
76+
Pkg.activate(bmark_dir)
77+
end
78+
Pkg.instantiate()
79+
80+
# name the benchmark after the repo or the sha of HEAD
81+
bmarkname = is_git ? readchomp(`$git rev-parse HEAD`) : lowercase(repo_name)
82+
@info "" bmarkname
83+
84+
# Run the benchmark script on this commit
85+
this_commit = Base.include(Main, joinpath(bmark_dir, script))
86+
@assert this_commit isa Dict{Symbol, DataFrame} "Expected the benchmark script to return a Dict{Symbol, DataFrame}, but got $(typeof(this_commit)). Make sure your benchmark script returns a dict resulting from BenchmarkSolver.bmark_solvers function"
87+
@save "$(bmarkname)_solver_benchmarks_this_commit.jld2" this_commit
88+
89+
# Run the benchmark script on the reference branch
90+
local reference
91+
if is_git
92+
repo_dir = joinpath(bmark_dir, "..")
93+
repo = LibGit2.GitRepo(repo_dir)
94+
reference = _withcommit(joinpath(bmark_dir, script), repo, reference_branch, bmarkname = bmarkname)
95+
end
96+
97+
# Plotting and tables
98+
local profile_values, table_values
99+
100+
profile_values = Base.invokelatest(solver_benchmark_profile_values)
101+
table_values = Base.invokelatest(solver_benchmark_table_values)
102+
103+
files_dict = Dict{String, Any}()
104+
svgs = String[]
105+
106+
solved(df) = (df.status .== :first_order)
107+
costs = [df -> .!solved(df) * Inf + getproperty(df, value[1]) for value in profile_values]
108+
costnames = [value[2] for value in profile_values]
109+
110+
stats_columns = [value[1] for value in table_values]
111+
112+
tables = "# Solver Benchmarks Tables \n\n"
113+
if is_git
114+
for key in keys(this_commit)
115+
if haskey(reference, key)
116+
@info "Plotting $key"
117+
stats_subset = Dict(:this_commit => this_commit[key], :reference => reference[key])
118+
p = profile_solvers(stats_subset, costs, costnames, xlabel = "", ylabel = "")
119+
fname = "this_commit_vs_reference_$(key)"
120+
savefig("$(fname).svg")
121+
savefig("profiles_$(fname).pdf")
122+
push!(svgs, "$(fname).svg")
123+
content = read("$(fname).svg", String)
124+
files_dict["$(fname).svg"] = Dict("content" => content)
125+
126+
@info "Creating tables for $key"
127+
tables *= "\n## This commit vs reference: $(key)\n\n"
128+
tables *= "### This commit\n\n\n"
129+
tables *= sprint(io -> pretty_stats(io, this_commit[key][!, stats_columns], hdr_override = Dict(table_values), tf=tf_markdown))
130+
open("this_commit_$(key).tex", "w") do io
131+
pretty_latex_stats(io, this_commit[key][!, stats_columns], hdr_override = Dict(table_values))
132+
end
133+
tables *= "\n\n### Reference\n\n\n"
134+
tables *= sprint(io -> pretty_stats(io, reference[key][!, stats_columns], hdr_override = Dict(table_values), tf=tf_markdown))
135+
open("reference_$(key).tex", "w") do io
136+
pretty_latex_stats(io, reference[key][!, stats_columns], hdr_override = Dict(table_values))
137+
end
138+
else
139+
@warn "$(reference_branch) branch benchmarks do not run the solver $key. Please update the benchmark solver list in a separate PR and rebase."
140+
end
141+
end
142+
end
143+
144+
files_dict["tables.md"] = Dict("content" => tables)
145+
146+
@info "creating or updating gist"
147+
# json description of gist
148+
json_dict = Dict{String, Any}(
149+
"description" => "$(repo_name) repository benchmark",
150+
"public" => true,
151+
"files" => files_dict,
152+
)
153+
154+
if update_gist
155+
json_dict["gist_id"] = gist_id
156+
end
157+
158+
gist_json = "$(bmarkname).json"
159+
open(gist_json, "w") do f
160+
JSON.print(f, json_dict)
161+
end
162+
163+
local new_gist_url
164+
if update_gist
165+
update_gist_from_json_dict(gist_id, json_dict)
166+
else
167+
new_gist = create_gist_from_json_dict(json_dict)
168+
new_gist_url = string(new_gist.html_url)
169+
end
170+
171+
# Update markdown report
172+
if is_git
173+
fname = "bmark_$(bmarkname).md"
174+
open(fname, "a") do f
175+
write_md_svgs(f, "SolverBenchmark Profiles", gist_url, svgs)
176+
end
177+
end
178+
179+
@info "finished"
180+
return update_gist ? gist_url : new_gist_url
181+
end
182+
183+
# Default (column, label) pairs used for performance profiles; benchmark
# scripts may override this function to profile other DataFrame columns.
function solver_benchmark_profile_values()
  columns = (:elapsed_time, :neval_obj, :neval_grad)
  labels = ("CPU Time", "# Objective Evals", "# Gradient Evals")
  return collect(zip(columns, labels))
end
186+
187+
# Default (column, label) pairs used for summary tables; benchmark scripts
# may override this function to tabulate other DataFrame columns.
function solver_benchmark_table_values()
  columns = (:name, :objective, :elapsed_time)
  labels = ("Name", "f(x)", "Time")
  return collect(zip(columns, labels))
end
190+
191+
# Runs a script at a commit on a repo and afterwards goes back
# to the original commit / branch.
# This code is based on https://github.com/JuliaCI/PkgBenchmark.jl/blob/master/src/util.jl
#
# Arguments:
# - `script`: path of the benchmark script to run at `commit`.
# - `repo`: an open `LibGit2.GitRepo`.
# - `commit`: branch name resolved to a SHA via `_shastring`.
# - `bmarkname`: prefix for the intermediate `.jld2` result file.
#
# Returns the `Dict{Symbol, DataFrame}` produced by the script at `commit`.
function _withcommit(script, repo, commit; bmarkname = "")
  # Remember HEAD's SHA so a detached-HEAD state can be restored in `finally`.
  original_commit = string(LibGit2.GitHash(LibGit2.GitObject(repo, "HEAD")))
  local result
  # `transact` restores the repository state if anything below throws.
  LibGit2.transact(repo) do r
    # `LibGit2.branch` throws on detached HEAD; treat that as "no branch".
    branch = try LibGit2.branch(r) catch err; nothing end
    try
      LibGit2.checkout!(r, _shastring(r, commit))

      # Run the script in a fresh Julia process with the current project so
      # the code as checked out at `commit` is what gets loaded and timed.
      env_to_use = dirname(Pkg.Types.Context().env.project_file)
      save_file_name = "$(bmarkname)_solver_benchmarks_reference"
      exec_str =
        """
        using JSOBenchmarks
        JSOBenchmarks._run_local($(repr(script)), "$(save_file_name)")
        """
      run(`$(Base.julia_cmd()) --project=$env_to_use --depwarn=no -e $exec_str`)

      # The subprocess stored the script's return value under the key "result"
      # (see `_run_local`).
      result = load("$(save_file_name).jld2")["result"]

      @assert result isa Dict{Symbol, DataFrame} "Expected the benchmark script to return a Dict{Symbol, DataFrame}, but got $(typeof(result)). Make sure your benchmark script returns a dict resulting from BenchmarkSolver.bmark_solvers function"
    catch err
      rethrow(err)
    finally
      # Go back to the original branch, or to the original commit when HEAD
      # was detached.
      if branch !== nothing
        LibGit2.branch!(r, branch)
      else
        LibGit2.checkout!(r, original_commit)
      end
    end
  end
  return result
end
226+
227+
# Include `script` in `Main` and persist its return value to
# "<save_file_name>.jld2". Invoked in a subprocess by `_withcommit`.
# NOTE: the local variable MUST be named `result` — JLD2's `@save` stores a
# variable under its own name, and `_withcommit` reloads it via `load(...)["result"]`.
function _run_local(script, save_file_name)
  result = Base.include(Main, script)
  @save "$(save_file_name).jld2" result
end
231+
232+
# Resolve branch `targetname` to a commit SHA string.
# Lookup order: local branch, remote-tracking branch, then the same two
# lookups with an "origin/" prefix.
# Throws an error when the branch cannot be found.
function _shastring(r::LibGit2.GitRepo, targetname)
  # (name, remote) lookup candidates, tried in order until one resolves.
  candidates = (
    (targetname, false),
    (targetname, true),               # search remote as well if not found locally
    ("origin/$(targetname)", false),
    ("origin/$(targetname)", true),
  )
  branch = nothing
  for (name, remote) in candidates
    branch = LibGit2.lookup_branch(r, name, remote)
    branch === nothing || break
  end
  # Explicit error instead of `@assert`: assertions may be disabled, which
  # would otherwise surface as a confusing MethodError on `LibGit2.name(nothing)`.
  branch === nothing && error("Branch $(targetname) not found in repository.")
  return string(LibGit2.GitHash(LibGit2.GitObject(r, LibGit2.name(branch))))
end

0 commit comments

Comments
 (0)