Skip to content

Commit e209087

Browse files
new realspot plots
1 parent 7587330 commit e209087

1 file changed

Lines changed: 64 additions & 0 deletions

File tree

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import pandas as pd
2+
import matplotlib.pyplot as plt
3+
import seaborn as sns
4+
from matplotlib import rcParams
5+
6+
# Font settings for consistency
7+
rcParams['font.family'] = 'DejaVu Sans'
8+
rcParams['font.size'] = 14
9+
10+
# Colors from your existing palette
11+
ours_color = "#10a37f"
12+
vila_fewshot_color = "#d8b88f"
13+
14+
# Data from the table
15+
data = {
16+
("Cleanup", "Ours"): [100, 100, 100, 100, 100],
17+
("Cleanup", "ViLa-fewshot"): [100, 100, 66, 0, 33],
18+
("Juice", "Ours"): [66, 66, 100, 66, 33],
19+
("Juice", "ViLa-fewshot"): [0, 66, 100, 0, 0],
20+
}
21+
22+
axes_labels = ["New obj.", "New vis.", "More obj.", "Novel goal 1", "Novel goal 2"]
23+
24+
# Convert to DataFrame
25+
rows = []
26+
for (task, approach), vals in data.items():
27+
for axis, v in zip(axes_labels, vals):
28+
rows.append({
29+
"Task": task,
30+
"Approach": approach,
31+
"Axis": axis,
32+
"Success %": v
33+
})
34+
df = pd.DataFrame(rows)
35+
df["Axis"] = pd.Categorical(df["Axis"], categories=axes_labels, ordered=True)
36+
37+
# Create vertical stack of horizontal bar plots
38+
fig, axes = plt.subplots(2, 1, figsize=(4, 4.8), sharex=True)
39+
40+
tasks = ["Cleanup", "Juice"]
41+
for ax, task in zip(axes, tasks):
42+
d = df[df["Task"] == task]
43+
sns.barplot(
44+
data=d, y="Axis", x="Success %",
45+
hue="Approach",
46+
palette=[ours_color, vila_fewshot_color],
47+
ax=ax, capsize=0.1, orient='h'
48+
)
49+
50+
ax.set_title(task, fontsize=14)
51+
ax.set_xlabel("% success (3 trials)", fontsize=11)
52+
ax.set_ylabel("")
53+
ax.set_xlim(0, 110)
54+
ax.grid(axis='x', linestyle='--', alpha=0.6)
55+
ax.tick_params(axis='y', labelsize=11)
56+
57+
# Legend above the top plot
58+
handles, labels = axes[0].get_legend_handles_labels()
59+
fig.legend(handles, labels, loc='upper center', ncol=2, frameon=False, fontsize=11, bbox_to_anchor=(0.5, 1.02))
60+
axes[0].get_legend().remove()
61+
axes[1].get_legend().remove()
62+
63+
fig.tight_layout(rect=[0, 0, 1, 0.96])
64+
plt.show()

0 commit comments

Comments
 (0)