Skip to content

Commit c66f314

Browse files
author
Ruslan Shaiakhmetov
committed
fix: no percents
1 parent 09b4a3f commit c66f314

1 file changed

Lines changed: 52 additions & 21 deletions

File tree

benchmark_smart_plot.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
viridis = plt.get_cmap()
1414

15-
def plot_horizontal_bar_chart(data, ax, compare_data = None, xlabel='X label', show_increase=True, low_labels = False, reverse=True, red=True):
15+
def plot_horizontal_bar_chart(data, ax, xlabel='X label', reverse=True, red=True):
1616
"""
1717
Plots a horizontal bar chart on the provided AxesSubplot (ax) with a log-scaled x-axis.
1818
@@ -39,25 +39,56 @@ def plot_horizontal_bar_chart(data, ax, compare_data = None, xlabel='X label', s
3939
error_color_smart.append(error_colors[i])
4040
error_color_smart.append(error_colors[i+half])
4141

42+
ax.barh(categories, values, color=error_color_smart if red else color_smart, height=1.0)
43+
44+
# Set log scale for the x-axis
45+
ax.set_xscale("log")
46+
ax.set_ylim((-0.5, len(categories)-0.5))
47+
ax.set_xlabel(xlabel)
48+
ax.grid(True, linestyle='--', linewidth=0.7, color='gray', alpha=0.7)
49+
50+
def plot_horizontal_bar_chart_compare(data, ax, compare_data, xlabel='X label', reverse=True):
51+
"""
52+
Plots a horizontal bar chart on the provided AxesSubplot (ax) with a log-scaled x-axis.
4253
43-
if compare_data:
44-
# Plot horizontal bars
45-
ax.barh(categories, values, color=error_color_smart, height=1.0)
46-
47-
error_values = []
48-
for ii, each in enumerate(categories):
49-
error_values.append(compare_data[each])
50-
diff = (data[each]-compare_data[each])/compare_data[each]
51-
if show_increase:
52-
if diff < 1.0 and low_labels:
53-
ax.text(values[ii]*2 if diff > 0 else error_values[ii]*2, ii-0.25, f'{diff*100:.0f}%', ha='center')
54-
else:
55-
ax.text(values[ii]*0.5 if diff < 11 else values[ii]*0.35, ii-0.25, f'×{diff:.2f}', ha='center')
56-
57-
ax.barh(categories, error_values, color=color_smart, height=1.0)
58-
else:
59-
ax.barh(categories, values, color=error_color_smart if red else color_smart, height=1.0)
54+
Parameters:
55+
data (dict): Dictionary where keys are categories (str) and values are corresponding numeric values.
56+
ax (AxesSubplot): Matplotlib AxesSubplot object to draw the chart on.
57+
"""
58+
# Sort the data by max values
59+
max_dict = {key: max(data[key], compare_data[key]) for key in data}
60+
categories = sorted(max_dict, key=max_dict.get, reverse=reverse) # Sort by values
6061

62+
# Generate positions and colors
63+
values = [data[key] for key in categories]
64+
error_values = [compare_data[key] for key in categories]
65+
max_values = [max_dict[key] for key in categories]
66+
num_categories = len(categories)+2
67+
colors = cm.viridis(np.linspace(0.0, 1.0, num_categories))
68+
error_colors = cm.magma(np.linspace(0.5, 1.0, num_categories))
69+
70+
color_smart = []
71+
error_color_smart = []
72+
half = int((len(categories)+1)/2)
73+
for i in range(half):
74+
color_smart.append(colors[i])
75+
color_smart.append(colors[i+half])
76+
error_color_smart.append(error_colors[i])
77+
error_color_smart.append(error_colors[i+half])
78+
79+
ax.barh(categories, values, color=error_color_smart, height=1.0)
80+
ax.barh(categories, error_values, color=color_smart, height=1.0)
81+
82+
max_value = max(max_values)
83+
84+
for ii, each in enumerate(categories):
85+
diff = data[each]-compare_data[each]
86+
ratio = data[each]/compare_data[each]
87+
if diff/max_value>0.02:
88+
ax.text(max_values[ii]*0.9, ii-0.25, f'×{ratio:.2f}', ha='right')
89+
else:
90+
ax.text(max_values[ii], ii-0.25, f'×{ratio:.2f}', ha='left')
91+
6192

6293
# Set log scale for the x-axis
6394
ax.set_xscale("log")
@@ -316,7 +347,7 @@ def parallel_acc(file_name):
316347
data_compare_with = plot_data.copy()
317348
else:
318349

319-
plot_horizontal_bar_chart(plot_data, ax[n_plot], compare_data=data_compare_with, xlabel=f'Area under the curve')
350+
plot_horizontal_bar_chart_compare(plot_data, ax[n_plot], compare_data=data_compare_with, xlabel=f'Area under the curve')
320351
text = ax[n_plot].text(
321352
0.95, 0.98, # Position (x, y) in axes coordinates
322353
f'b={b[:-2]}', # The text content
@@ -419,7 +450,7 @@ def parallel_acc(file_name):
419450
nice_label = each.replace("optim.", "").replace("torch.", "")
420451
plot_data_without[nice_label] = results[i]
421452

422-
plot_horizontal_bar_chart(plot_data_without, ax[n_plot], compare_data=plot_data, xlabel=f'Area under the curve', low_labels = True)
453+
plot_horizontal_bar_chart_compare(plot_data_without, ax[n_plot], compare_data=plot_data, xlabel="Area under the curve")
423454
text = ax[n_plot].text(
424455
0.95, 0.98, # Position (x, y) in axes coordinates
425456
f'b={b[:-2]}', # The text content
@@ -436,7 +467,7 @@ def parallel_acc(file_name):
436467
plt.savefig("img/Smart_plot_11.pdf", format="pdf", dpi=300)
437468
plt.close()
438469

439-
if True:
470+
if False:
440471

441472
fig, ax = plt.subplots(figsize=(4, 6))
442473

0 commit comments

Comments
 (0)