|
| 1 | +import pandas as pd |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import seaborn as sns |
| 4 | +from matplotlib import rcParams |
| 5 | + |
| 6 | +# Font settings for consistency |
| 7 | +rcParams['font.family'] = 'DejaVu Sans' |
| 8 | +rcParams['font.size'] = 14 |
| 9 | + |
| 10 | +# Colors from your existing palette |
| 11 | +ours_color = "#10a37f" |
| 12 | +vila_fewshot_color = "#d8b88f" |
| 13 | + |
| 14 | +# Data from the table |
| 15 | +data = { |
| 16 | + ("Cleanup", "Ours"): [100, 100, 100, 100, 100], |
| 17 | + ("Cleanup", "ViLa-fewshot"): [100, 100, 66, 0, 33], |
| 18 | + ("Juice", "Ours"): [66, 66, 100, 66, 33], |
| 19 | + ("Juice", "ViLa-fewshot"): [0, 66, 100, 0, 0], |
| 20 | +} |
| 21 | + |
| 22 | +axes_labels = ["New obj.", "New vis.", "More obj.", "Novel goal 1", "Novel goal 2"] |
| 23 | + |
| 24 | +# Convert to DataFrame |
| 25 | +rows = [] |
| 26 | +for (task, approach), vals in data.items(): |
| 27 | + for axis, v in zip(axes_labels, vals): |
| 28 | + rows.append({ |
| 29 | + "Task": task, |
| 30 | + "Approach": approach, |
| 31 | + "Axis": axis, |
| 32 | + "Success %": v |
| 33 | + }) |
| 34 | +df = pd.DataFrame(rows) |
| 35 | +df["Axis"] = pd.Categorical(df["Axis"], categories=axes_labels, ordered=True) |
| 36 | + |
| 37 | +# Create vertical stack of horizontal bar plots |
| 38 | +fig, axes = plt.subplots(2, 1, figsize=(4, 4.8), sharex=True) |
| 39 | + |
| 40 | +tasks = ["Cleanup", "Juice"] |
| 41 | +for ax, task in zip(axes, tasks): |
| 42 | + d = df[df["Task"] == task] |
| 43 | + sns.barplot( |
| 44 | + data=d, y="Axis", x="Success %", |
| 45 | + hue="Approach", |
| 46 | + palette=[ours_color, vila_fewshot_color], |
| 47 | + ax=ax, capsize=0.1, orient='h' |
| 48 | + ) |
| 49 | + |
| 50 | + ax.set_title(task, fontsize=14) |
| 51 | + ax.set_xlabel("% success (3 trials)", fontsize=11) |
| 52 | + ax.set_ylabel("") |
| 53 | + ax.set_xlim(0, 110) |
| 54 | + ax.grid(axis='x', linestyle='--', alpha=0.6) |
| 55 | + ax.tick_params(axis='y', labelsize=11) |
| 56 | + |
| 57 | +# Legend above the top plot |
| 58 | +handles, labels = axes[0].get_legend_handles_labels() |
| 59 | +fig.legend(handles, labels, loc='upper center', ncol=2, frameon=False, fontsize=11, bbox_to_anchor=(0.5, 1.02)) |
| 60 | +axes[0].get_legend().remove() |
| 61 | +axes[1].get_legend().remove() |
| 62 | + |
| 63 | +fig.tight_layout(rect=[0, 0, 1, 0.96]) |
| 64 | +plt.show() |
0 commit comments