1212
1313viridis = plt .get_cmap ()
1414
15- def plot_horizontal_bar_chart (data , ax , compare_data = None , xlabel = 'X label' , show_increase = True , low_labels = False , reverse = True , red = True ):
15+ def plot_horizontal_bar_chart (data , ax , xlabel = 'X label' , reverse = True , red = True ):
1616 """
1717 Plots a horizontal bar chart on the provided AxesSubplot (ax) with a log-scaled x-axis.
1818
@@ -39,25 +39,56 @@ def plot_horizontal_bar_chart(data, ax, compare_data = None, xlabel='X label', s
3939 error_color_smart .append (error_colors [i ])
4040 error_color_smart .append (error_colors [i + half ])
4141
42+ ax .barh (categories , values , color = error_color_smart if red else color_smart , height = 1.0 )
43+
44+ # Set log scale for the x-axis
45+ ax .set_xscale ("log" )
46+ ax .set_ylim ((- 0.5 , len (categories )- 0.5 ))
47+ ax .set_xlabel (xlabel )
48+ ax .grid (True , linestyle = '--' , linewidth = 0.7 , color = 'gray' , alpha = 0.7 )
49+
50+ def plot_horizontal_bar_chart_compare (data , ax , compare_data , xlabel = 'X label' , reverse = True ):
51+ """
52+ Plots a horizontal bar chart on the provided AxesSubplot (ax) with a log-scaled x-axis.
4253
43- if compare_data :
44- # Plot horizontal bars
45- ax .barh (categories , values , color = error_color_smart , height = 1.0 )
46-
47- error_values = []
48- for ii , each in enumerate (categories ):
49- error_values .append (compare_data [each ])
50- diff = (data [each ]- compare_data [each ])/ compare_data [each ]
51- if show_increase :
52- if diff < 1.0 and low_labels :
53- ax .text (values [ii ]* 2 if diff > 0 else error_values [ii ]* 2 , ii - 0.25 , f'{ diff * 100 :.0f} %' , ha = 'center' )
54- else :
55- ax .text (values [ii ]* 0.5 if diff < 11 else values [ii ]* 0.35 , ii - 0.25 , f'×{ diff :.2f} ' , ha = 'center' )
56-
57- ax .barh (categories , error_values , color = color_smart , height = 1.0 )
58- else :
59- ax .barh (categories , values , color = error_color_smart if red else color_smart , height = 1.0 )
54+ Parameters:
55+ data (dict): Dictionary where keys are categories (str) and values are corresponding numeric values.
56+ ax (AxesSubplot): Matplotlib AxesSubplot object to draw the chart on.
57+ """
58+ # Sort the data by max values
59+ max_dict = {key : max (data [key ], compare_data [key ]) for key in data }
60+ categories = sorted (max_dict , key = max_dict .get , reverse = reverse ) # Sort by values
6061
62+ # Generate positions and colors
63+ values = [data [key ] for key in categories ]
64+ error_values = [compare_data [key ] for key in categories ]
65+ max_values = [max_dict [key ] for key in categories ]
66+ num_categories = len (categories )+ 2
67+ colors = cm .viridis (np .linspace (0.0 , 1.0 , num_categories ))
68+ error_colors = cm .magma (np .linspace (0.5 , 1.0 , num_categories ))
69+
70+ color_smart = []
71+ error_color_smart = []
72+ half = int ((len (categories )+ 1 )/ 2 )
73+ for i in range (half ):
74+ color_smart .append (colors [i ])
75+ color_smart .append (colors [i + half ])
76+ error_color_smart .append (error_colors [i ])
77+ error_color_smart .append (error_colors [i + half ])
78+
79+ ax .barh (categories , values , color = error_color_smart , height = 1.0 )
80+ ax .barh (categories , error_values , color = color_smart , height = 1.0 )
81+
82+ max_value = max (max_values )
83+
84+ for ii , each in enumerate (categories ):
85+ diff = data [each ]- compare_data [each ]
86+ ratio = data [each ]/ compare_data [each ]
87+ if diff / max_value > 0.02 :
88+ ax .text (max_values [ii ]* 0.9 , ii - 0.25 , f'×{ ratio :.2f} ' , ha = 'right' )
89+ else :
90+ ax .text (max_values [ii ], ii - 0.25 , f'×{ ratio :.2f} ' , ha = 'left' )
91+
6192
6293 # Set log scale for the x-axis
6394 ax .set_xscale ("log" )
@@ -316,7 +347,7 @@ def parallel_acc(file_name):
316347 data_compare_with = plot_data .copy ()
317348 else :
318349
319- plot_horizontal_bar_chart (plot_data , ax [n_plot ], compare_data = data_compare_with , xlabel = f'Area under the curve' )
350+ plot_horizontal_bar_chart_compare (plot_data , ax [n_plot ], compare_data = data_compare_with , xlabel = f'Area under the curve' )
320351 text = ax [n_plot ].text (
321352 0.95 , 0.98 , # Position (x, y) in axes coordinates
322353 f'b={ b [:- 2 ]} ' , # The text content
@@ -419,7 +450,7 @@ def parallel_acc(file_name):
419450 nice_label = each .replace ("optim." , "" ).replace ("torch." , "" )
420451 plot_data_without [nice_label ] = results [i ]
421452
422- plot_horizontal_bar_chart (plot_data_without , ax [n_plot ], compare_data = plot_data , xlabel = f' Area under the curve' , low_labels = True )
453+ plot_horizontal_bar_chart_compare (plot_data_without , ax [n_plot ], compare_data = plot_data , xlabel = " Area under the curve" )
423454 text = ax [n_plot ].text (
424455 0.95 , 0.98 , # Position (x, y) in axes coordinates
425456 f'b={ b [:- 2 ]} ' , # The text content
@@ -436,7 +467,7 @@ def parallel_acc(file_name):
436467 plt .savefig ("img/Smart_plot_11.pdf" , format = "pdf" , dpi = 300 )
437468 plt .close ()
438469
439- if True :
470+ if False :
440471
441472 fig , ax = plt .subplots (figsize = (4 , 6 ))
442473
0 commit comments