Skip to content

Commit badc31f

Browse files
Remove NNMF problem and update Benchmark.jl accordingly (#254)
* Remove NNMF problem and update Benchmark.jl accordingly * Refactor SVM benchmark output formatting and improve LaTeX table generation
1 parent 3fe3b3d commit badc31f

File tree

4 files changed

+22
-185
lines changed

4 files changed

+22
-185
lines changed

paper/examples/Benchmark.jl

Lines changed: 20 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# ======== IMPORTS ======== #
44
#############################
55
using Random, LinearAlgebra
6-
using ProximalOperators, ProximalCore, ProximalAlgorithms
76
using ShiftedProximalOperators
87
using NLPModels, NLPModelsModifiers
98
using RegularizedOptimization, RegularizedProblems
@@ -14,7 +13,7 @@ using LaTeXStrings
1413

1514
# Local includes
1615
include("comparison-config.jl")
17-
using .ComparisonConfig: CFG, CFG2
16+
using .ComparisonConfig: CFG
1817

1918
#############################
2019
# ===== Helper utils ====== #
@@ -58,7 +57,7 @@ function run_tr_svm!(model, x0; λ = 1.0, qn = :LSR1, atol = 1e-3, rtol = 1e-3,
5857
t = @elapsed RegularizedOptimization.solve!(solver, reg_nlp, stats;
5958
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
6059
return (
61-
name = "TR ($(String(qn)), SVM)",
60+
name = "TR",
6261
status = string(stats.status),
6362
time = t,
6463
iters = get(stats.solver_specific, :outer_iter, missing),
@@ -84,7 +83,7 @@ function run_r2n_svm!(model, x0; λ = 1.0, qn = :LBFGS, atol = 1e-3, rtol = 1e-3
8483
t = @elapsed RegularizedOptimization.solve!(solver, reg_nlp, stats;
8584
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
8685
return (
87-
name = "R2N ($(String(qn)), SVM)",
86+
name = "R2N",
8887
status = string(stats.status),
8988
time = t,
9089
iters = get(stats.solver_specific, :outer_iter, missing),
@@ -108,7 +107,7 @@ function run_LM_svm!(nls_model, x0; λ = 1.0, atol = 1e-3, rtol = 1e-3, verbose
108107
t = @elapsed RegularizedOptimization.solve!(solver, reg_nls, stats;
109108
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
110109
return (
111-
name = "LM (SVM)",
110+
name = "LM",
112111
status = string(stats.status),
113112
time = t,
114113
iters = get(stats.solver_specific, :outer_iter, missing),
@@ -132,7 +131,7 @@ function run_LMTR_svm!(nls_model, x0; λ = 1.0, atol = 1e-3, rtol = 1e-3, verbos
132131
t = @elapsed RegularizedOptimization.solve!(solver, reg_nls, stats;
133132
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
134133
return (
135-
name = "LMTR (SVM)",
134+
name = "LMTR",
136135
status = string(stats.status),
137136
time = t,
138137
iters = get(stats.solver_specific, :outer_iter, missing),
@@ -159,18 +158,14 @@ function bench_svm!(cfg = CFG)
159158
println("\n=== SVM: solver comparison ===")
160159
for m in results
161160
println("\n", m.name)
162-
println(" status = ", m.status)
163-
println(" time (s) = ", round(m.time, digits = 4))
164-
m.iters !== missing && println(" outer iters = ", m.iters)
165-
println(" # f eval = ", m.fevals)
166-
println(" # ∇f eval = ", m.gevals)
167-
m.proxcalls !== missing && println(" # prox calls = ", Int(m.proxcalls))
168-
println(" final objective= ", round(obj(model, m.solution), digits = 4))
169-
println(" accuracy (%) = ", round(acc(residual(nls_train, m.solution)), digits = 1))
161+
println(" status = ", m.status)
162+
println(" time (s) = ", round(m.time, digits = 4))
163+
println(" # f eval = ", m.fevals)
164+
println(" # ∇f eval = ", m.gevals)
165+
m.proxcalls !== missing && println(" # prox calls = ", Int(m.proxcalls))
166+
println(" final objective = ", round(obj(model, m.solution), digits = 4))
170167
end
171168

172-
println("\nSVM Config:"); print_config(cfg)
173-
174169
data_svm = [
175170
(; name=m.name,
176171
status=string(m.status),
@@ -185,182 +180,28 @@ function bench_svm!(cfg = CFG)
185180
return data_svm
186181
end
187182

188-
#############################
189-
# ======= NNMF bench ====== #
190-
#############################
191-
192-
function run_tr_nnmf!(model, x0; λ = 1.0, qn = :LSR1, atol = 1e-3, rtol = 1e-3, verbose = 0, sub_kwargs = (;), selected = nothing)
193-
qn_model = ensure_qn(model, qn)
194-
reset!(qn_model)
195-
reg_nlp = RegularizedNLPModel(qn_model, NormL0(λ), selected)
196-
solver = TRSolver(reg_nlp)
197-
stats = RegularizedExecutionStats(reg_nlp)
198-
RegularizedOptimization.solve!(solver, reg_nlp, stats;
199-
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
200-
reset!(qn_model) # Reset counters before timing
201-
reg_nlp = RegularizedNLPModel(qn_model, NormL0(λ), selected) # Re-create to reset prox eval count
202-
solver = TRSolver(reg_nlp)
203-
t = @elapsed RegularizedOptimization.solve!(solver, reg_nlp, stats;
204-
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
205-
return (
206-
name = "TR ($(String(qn)), NNMF)",
207-
status = string(stats.status),
208-
time = t,
209-
iters = get(stats.solver_specific, :outer_iter, missing),
210-
fevals = neval_obj(qn_model),
211-
gevals = neval_grad(qn_model),
212-
proxcalls = get(stats.solver_specific, :prox_evals, missing),
213-
solution = stats.solution,
214-
final_obj = obj(model, stats.solution)
215-
)
216-
end
217-
218-
function run_r2n_nnmf!(model, x0; λ = 1.0, qn = :LBFGS, atol = 1e-3, rtol = 1e-3, verbose = 0, sub_kwargs = (;), σk = 1e5, selected = nothing)
219-
qn_model = ensure_qn(model, qn)
220-
reset!(qn_model)
221-
reg_nlp = RegularizedNLPModel(qn_model, NormL0(λ), selected)
222-
solver = R2NSolver(reg_nlp)
223-
stats = RegularizedExecutionStats(reg_nlp)
224-
RegularizedOptimization.solve!(solver, reg_nlp, stats;
225-
x = x0, atol = atol, rtol = rtol, verbose = verbose,
226-
sub_kwargs = sub_kwargs)
227-
228-
reset!(qn_model) # Reset counters before timing
229-
reg_nlp = RegularizedNLPModel(qn_model, NormL0(λ), selected) # Re-create to reset prox eval count
230-
solver = R2NSolver(reg_nlp)
231-
t = @elapsed RegularizedOptimization.solve!(solver, reg_nlp, stats;
232-
x = x0, atol = atol, rtol = rtol, verbose = verbose,
233-
sub_kwargs = sub_kwargs)
234-
return (
235-
name = "R2N ($(String(qn)), NNMF)",
236-
status = string(stats.status),
237-
time = t,
238-
iters = get(stats.solver_specific, :outer_iter, missing),
239-
fevals = neval_obj(qn_model),
240-
gevals = neval_grad(qn_model),
241-
proxcalls = get(stats.solver_specific, :prox_evals, missing),
242-
solution = stats.solution,
243-
final_obj = obj(model, stats.solution)
244-
)
245-
end
246-
247-
function run_LM_nnmf!(nls_model, x0; λ = 1.0, atol = 1e-3, rtol = 1e-3, verbose = 0, selected = nothing, sub_kwargs = (;))
248-
reg_nls = RegularizedNLSModel(nls_model, NormL0(λ), selected)
249-
solver = LMSolver(reg_nls)
250-
stats = RegularizedExecutionStats(reg_nls)
251-
RegularizedOptimization.solve!(solver, reg_nls, stats;
252-
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
253-
reset!(nls_model) # Reset counters before timing
254-
reg_nls = RegularizedNLSModel(nls_model, NormL0(λ), selected)
255-
solver = LMSolver(reg_nls)
256-
t = @elapsed RegularizedOptimization.solve!(solver, reg_nls, stats;
257-
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
258-
return (
259-
name = "LM (NNMF)",
260-
status = string(stats.status),
261-
time = t,
262-
iters = get(stats.solver_specific, :outer_iter, missing),
263-
fevals = neval_residual(nls_model),
264-
gevals = neval_jtprod_residual(nls_model) + neval_jprod_residual(nls_model),
265-
proxcalls = get(stats.solver_specific, :prox_evals, missing),
266-
solution = stats.solution,
267-
final_obj = obj(nls_model, stats.solution)
268-
)
269-
end
270-
271-
function run_LMTR_nnmf!(nls_model, x0; λ = 1.0, atol = 1e-3, rtol = 1e-3, verbose = 0, selected = nothing, sub_kwargs = (;))
272-
reg_nls = RegularizedNLSModel(nls_model, NormL0(λ), selected)
273-
solver = LMTRSolver(reg_nls)
274-
stats = RegularizedExecutionStats(reg_nls)
275-
RegularizedOptimization.solve!(solver, reg_nls, stats;
276-
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
277-
reset!(nls_model) # Reset counters before timing
278-
reg_nls = RegularizedNLSModel(nls_model, NormL0(λ), selected)
279-
solver = LMTRSolver(reg_nls)
280-
t = @elapsed RegularizedOptimization.solve!(solver, reg_nls, stats;
281-
x = x0, atol = atol, rtol = rtol, verbose = verbose, sub_kwargs = sub_kwargs)
282-
return (
283-
name = "LMTR (NNMF)",
284-
status = string(stats.status),
285-
time = t,
286-
iters = get(stats.solver_specific, :outer_iter, missing),
287-
fevals = neval_residual(nls_model),
288-
gevals = neval_jtprod_residual(nls_model) + neval_jprod_residual(nls_model),
289-
proxcalls = get(stats.solver_specific, :prox_evals, missing),
290-
solution = stats.solution,
291-
final_obj = obj(nls_model, stats.solution)
292-
)
293-
end
294-
295-
function bench_nnmf!(cfg = CFG2; m = 100, n = 50, k = 5)
296-
Random.seed!(cfg.SEED)
297-
298-
model, nls_model, _, selected = nnmf_model(m, n, k)
299-
300-
# build x0 on positive orthant as original
301-
x0 = max.(rand(model.meta.nvar), 0.0)
302-
303-
# heuristic lambda (copied logic)
304-
cfg.LAMBDA_L0 = norm(grad(model, rand(model.meta.nvar)), Inf) / 200
305-
306-
results = NamedTuple[]
307-
(:TR in cfg.RUN_SOLVERS) && push!(results, run_tr_nnmf!(model, x0; λ = cfg.LAMBDA_L0, qn = cfg.QN_FOR_TR, atol = cfg.TOL, rtol = cfg.RTOL, verbose = cfg.VERBOSE_RO, sub_kwargs = cfg.SUB_KWARGS_R2N, selected = selected))
308-
(:R2N in cfg.RUN_SOLVERS) && push!(results, run_r2n_nnmf!(model, x0; λ = cfg.LAMBDA_L0, qn = cfg.QN_FOR_R2N, atol = cfg.TOL, rtol = cfg.RTOL, verbose = cfg.VERBOSE_RO, sub_kwargs = cfg.SUB_KWARGS_R2N, selected = selected))
309-
(:LM in cfg.RUN_SOLVERS) && push!(results, run_LM_nnmf!(nls_model, x0; λ = cfg.LAMBDA_L0, atol = cfg.TOL, rtol = cfg.RTOL, verbose = cfg.VERBOSE_RO, selected = selected, sub_kwargs = cfg.SUB_KWARGS_R2N))
310-
(:LMTR in cfg.RUN_SOLVERS) && push!(results, run_LMTR_nnmf!(nls_model, x0; λ = cfg.LAMBDA_L0, atol = cfg.TOL, rtol = cfg.RTOL, verbose = cfg.VERBOSE_RO, selected = selected, sub_kwargs = cfg.SUB_KWARGS_R2N))
311-
312-
println("\n=== NNMF: solver comparison ===")
313-
for m in results
314-
println("\n", m.name)
315-
println(" status = ", m.status)
316-
println(" time (s) = ", round(m.time, digits = 4))
317-
m.iters !== missing && println(" outer iters = ", m.iters)
318-
println(" # f eval = ", m.fevals)
319-
println(" # ∇f eval = ", m.gevals)
320-
m.proxcalls !== missing && println(" # prox calls = ", Int(m.proxcalls))
321-
println(" final objective= ", round(obj(model, m.solution), digits = 4))
322-
end
323-
324-
println("\nNNMF Config:"); print_config(cfg)
325-
326-
data_nnmf = [
327-
(; name=m.name,
328-
status=string(m.status),
329-
time=round(m.time, digits=4),
330-
fe=m.fevals,
331-
ge=m.gevals,
332-
prox = m.proxcalls === missing ? missing : Int(m.proxcalls),
333-
obj = round(m.final_obj, digits=4))
334-
for m in results
335-
]
336-
337-
return data_nnmf
338-
end
339-
340183
# #############################
341184
# # ========= Main ========== #
342185
# #############################
343186

344-
function main(latex_out = false)
187+
function main(;latex_out = false)
345188
data_svm = bench_svm!(CFG)
346-
data_nnmf = bench_nnmf!(CFG2)
347-
348-
all_data = vcat(data_svm, data_nnmf)
349189

350190
println("\n=== Full Benchmark Table ===")
351191
# what is inside the table
352-
for row in all_data
192+
for row in data_svm
353193
println(row)
354194
end
355195

356196
# save as latex format
357197
if latex_out
358-
359-
table_str = pretty_table(String, all_data;
360-
header = ["Method", "Status", L"$t$($s$)", L"$\#f$", L"$\#\nabla f$", L"$\#prox$", "Objective"],
361-
backend = Val(:latex),
362-
alignment = [:l, :c, :r, :r, :r, :r, :r],
363-
)
198+
table_str = pretty_table(String,
199+
data_svm;
200+
backend = :latex,
201+
column_labels = ["Method", "Status", L"$t$($s$)", L"$\#f$", L"$\#\nabla f$", L"$\#prox$", "Objective"],
202+
style = LatexTableStyle(column_label = String[]),
203+
table_format = latex_table_format__booktabs
204+
)
364205

365206
open("Benchmark.tex", "w") do io
366207
write(io, table_str)

paper/examples/Project.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,5 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
33
NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
44
NLPModelsModifiers = "e01155f1-5c6f-4375-a9d8-616dd036575f"
55
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
6-
ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
7-
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
8-
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
96
RegularizedProblems = "ea076b23-609f-44d2-bb12-a4ae45328278"
107
ShiftedProximalOperators = "d4fd37fa-580c-4e43-9b30-361c21aae263"

paper/examples/comparison-config.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,5 @@ end
1717

1818
# One global, constant *binding* to a mutable object = type stable & editable
1919
const CFG = Config(QN_FOR_R2N=:LSR1)
20-
const CFG2 = Config(QN_FOR_TR = :LBFGS)
2120

2221
end # module
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using LinearAlgebra, Random, ProximalOperators
2-
using NLPModels, RegularizedProblems, RegularizedOptimization
1+
using LinearAlgebra, Random, ShiftedProximalOperators
2+
using NLPModels, NLPModelsModifiers, RegularizedProblems, RegularizedOptimization
33
using MLDatasets
44

55
Random.seed!(1234)

0 commit comments

Comments
 (0)