Skip to content

Commit 68c59f3

Browse files
committed
adapt plot_mean_test_f1
1 parent fff38b2 commit 68c59f3

1 file changed

Lines changed: 35 additions & 51 deletions

File tree

plot_mean_test_f1.py

Lines changed: 35 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -6,68 +6,52 @@
66

77
parser = argparse.ArgumentParser()
88
parser.add_argument("-o", "--output", type=str, default=None)
9-
parser.add_argument(
10-
"-g", "--group", type=str, default="global", help="one of: 'global', 'local'"
11-
)
9+
parser.add_argument("-r", "--oracle", action="store_true")
1210
args = parser.parse_args()
1311

1412

15-
if args.group == "global":
16-
runs = [
17-
["random", "neural_random", "neural_ideal_random"],
18-
["bm25", "neural_bm25", "neural_ideal_bm25"],
19-
["samenoun", "neural_samenoun", "neural_ideal_samenoun"],
20-
]
21-
elif args.group == "local":
22-
runs = [
23-
["left", "neural_left", "neural_ideal_left"],
24-
["right", "neural_right", "neural_ideal_right"],
25-
["neighbors", "neural_neighbors", "neural_ideal_neighbors"],
26-
]
27-
else:
28-
raise ValueError(f"unknown group: {args.group}")
13+
runs = ["random", "bm25", "samenoun", "left", "right", "neighbors"]
2914

3015
with open(f"./runs/bare/metrics.json") as f:
3116
bare_metrics = json.load(f)
3217

3318

3419
plt.style.use("science")
35-
plt.rcParams.update({"xtick.labelsize": 18})
36-
plt.rcParams.update({"ytick.labelsize": 18})
37-
fig, axs = plt.subplots(1, 3)
38-
39-
fig.set_size_inches(24, 4)
40-
41-
for i, run_group in enumerate(runs):
42-
43-
min_steps = []
44-
max_steps = []
45-
for run in run_group:
46-
with open(f"./runs/{run}/metrics.json") as f:
47-
metrics = json.load(f)
48-
axs[i].plot(
49-
[int(step) for step in metrics["mean_test_f1"]["steps"]],
50-
metrics["mean_test_f1"]["values"],
51-
)
52-
min_steps.append(min(metrics["mean_test_f1"]["steps"]))
53-
max_steps.append(max(metrics["mean_test_f1"]["steps"]))
54-
55-
# bare baseline
56-
axs[i].plot(
57-
[min(min_steps), max(max_steps)],
58-
[bare_metrics["mean_test_f1"]["values"][0]] * 2,
59-
linestyle="--",
20+
# plt.rcParams.update({"xtick.labelsize": 18})
21+
# plt.rcParams.update({"ytick.labelsize": 18})
22+
plt.rc("xtick", labelsize=40) # fontsize of the tick labels
23+
plt.rc("ytick", labelsize=40) # fontsize of the tick labels
24+
fig, ax = plt.subplots()
25+
26+
fig.set_size_inches(16, 8)
27+
28+
for run in runs:
29+
if args.oracle:
30+
run = f"oracle_{run}"
31+
with open(f"./runs/short/{run}/metrics.json") as f:
32+
metrics = json.load(f)
33+
ax.plot(
34+
[int(step) for step in metrics["mean_test_f1"]["steps"]],
35+
metrics["mean_test_f1"]["values"],
6036
)
6137

62-
axs[i].grid()
63-
axs[i].set_ylabel("F1", fontsize=20)
64-
axs[i].set_xlabel("Number of retrieved sentences", fontsize=20)
65-
axs[i].legend(
66-
[r.replace("_", " ") for r in run_group] + ["no retrieval"],
67-
loc="lower center",
68-
bbox_to_anchor=(0.5, 1),
69-
fontsize=20,
70-
)
38+
# bare baseline
39+
ax.plot(
40+
[1, 6],
41+
[bare_metrics["mean_test_f1"]["values"][0]] * 2,
42+
linestyle="--",
43+
)
44+
45+
ax.grid()
46+
ax.set_ylabel("F1", fontsize=40)
47+
ax.set_xlabel("Number of retrieved sentences", fontsize=40)
48+
ax.legend(
49+
runs + ["no retrieval"],
50+
loc="lower center",
51+
bbox_to_anchor=(0.5, 1),
52+
fontsize=40,
53+
ncol=3,
54+
)
7155

7256

7357
if args.output:

0 commit comments

Comments
 (0)