Skip to content

Commit 09b4a3f

Browse files
author
Ruslan Shaiakhmetov
committed
fix: Plot 2 only range
1 parent a16ccdc commit 09b4a3f

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

benchmark_smart_plot.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def rosenbrock(x, y, b=100):
266266
"optim.QHAdam",
267267
]
268268

269-
if True:
269+
if False:
270270

271271
fig, ax = plt.subplots(1, 4, figsize=(12, 6))
272272

@@ -340,7 +340,7 @@ def parallel_acc(file_name):
340340
plt.savefig("img/Smart_plot_1.pdf", format="pdf", dpi=300)
341341
plt.close()
342342

343-
if True:
343+
if False:
344344

345345
fig, ax = plt.subplots(1, 4, figsize=(12, 6))
346346

@@ -436,9 +436,9 @@ def parallel_acc(file_name):
436436
plt.savefig("img/Smart_plot_11.pdf", format="pdf", dpi=300)
437437
plt.close()
438438

439-
if False:
439+
if True:
440440

441-
fig, ax = plt.subplots(1, 2, figsize=(6, 6))
441+
fig, ax = plt.subplots(figsize=(4, 6))
442442

443443
# Start to generate charts
444444
rosen_list = ['1000.0', '100.0', '10.0', '1.0']
@@ -469,18 +469,18 @@ def parallel_acc(file_name):
469469
max_dict[key] = max_val
470470
relative_dict[key] = max_val/min_val
471471

472-
plot_horizontal_bar_chart(plot_1000, ax[0],
473-
xlabel="Optimized learning rate for \n Rosenbrock function b=1000.0")
474-
plot_horizontal_bar_chart(relative_dict, ax[1],
475-
xlabel="Ratio of maximum and \nminimum values of \nlearning rate")
472+
#plot_horizontal_bar_chart(plot_1000, ax[0],
473+
# xlabel="Optimized learning rate for \n Rosenbrock function b=1000.0")
474+
plot_horizontal_bar_chart(relative_dict, ax,
475+
xlabel="Ratio between maximum \nand minimum global LRs", red=False)
476476

477477
# Adjust layout and show
478478
plt.tight_layout()
479479
plt.savefig("img/Smart_plot_2.pdf", format="pdf", dpi=300)
480480
plt.close()
481481

482482
if False:
483-
fig, ax = plt.subplots(4, 1, figsize=(6, 12))
483+
fig, ax = plt.subplots(4, 1, figsize=(7, 12))
484484
# n_traj = 100
485485
plot_loss("SGD_1.0_optimtrack", ax[0], name="Asymp. convergence: 100 SGD real., b=1", n_trajectories=100, linthresh=10e-16)
486486
plot_loss("Adagrad_1.0_optimtrack", ax[1], name="Init. cond. dep.: 500 AdaGrad real., b=1", n_trajectories=500, linthresh=10e-25)

0 commit comments

Comments
 (0)