Skip to content

Commit 729a7cb

Browse files
committed
plot_mean_test_f1: add --restricted plot
1 parent 9db6e1d commit 729a7cb

1 file changed

Lines changed: 25 additions & 12 deletions

File tree

plot_mean_test_f1.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
parser = argparse.ArgumentParser()
88
parser.add_argument("-o", "--output", type=str, default=None)
99
parser.add_argument("-r", "--oracle", action="store_true")
10+
parser.add_argument("-e", "--restricted", action="store_true")
11+
parser.add_argument("--no-baseline", action="store_true")
1012
args = parser.parse_args()
1113

1214
# from https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html
@@ -29,7 +31,19 @@
2931
]
3032

3133

32-
runs = ["random", "bm25", "samenoun", "left", "right", "neighbors"]
34+
# runs is of form {run_dir_name => name}
35+
if args.oracle:
36+
runs = {
37+
f"oracle_{run}": run
38+
for run in ["random", "bm25", "samenoun", "before", "after", "surrounding"]
39+
}
40+
elif args.restricted:
41+
runs = {"oracle_bm25": "bm25", "bm25_restricted": "restricted bm25"}
42+
else:
43+
runs = {
44+
run: run
45+
for run in ["random", "bm25", "samenoun", "before", "after", "surrounding"]
46+
}
3347

3448
with open(f"./runs/bare/metrics.json") as f:
3549
bare_metrics = json.load(f)
@@ -42,9 +56,7 @@
4256

4357
fig.set_size_inches(16, 12)
4458

45-
for run_i, run in enumerate(runs):
46-
if args.oracle:
47-
run = f"oracle_{run}"
59+
for run_i, (run, run_name) in enumerate(runs.items()):
4860
with open(f"./runs/short/{run}/metrics.json") as f:
4961
metrics = json.load(f)
5062
ax.plot(
@@ -54,19 +66,20 @@
5466
linewidth=4,
5567
)
5668

57-
# bare baseline
58-
ax.plot(
59-
[1, 6],
60-
[bare_metrics["mean_test_f1"]["values"][0]] * 2,
61-
linestyle=linestyle_tuple[len(runs)][1],
62-
linewidth=4,
63-
)
69+
# no retrieval baseline
70+
if not args.no_baseline:
71+
ax.plot(
72+
[1, 6],
73+
[bare_metrics["mean_test_f1"]["values"][0]] * 2,
74+
linestyle=linestyle_tuple[len(runs)][1],
75+
linewidth=4,
76+
)
6477

6578
ax.grid()
6679
ax.set_ylabel("F1", fontsize=40)
6780
ax.set_xlabel("Number of retrieved sentences", fontsize=40)
6881
ax.legend(
69-
runs + ["no retrieval"],
82+
list(runs.values()) + ["no retrieval"] if not args.no_baseline else [],
7083
loc="lower center",
7184
bbox_to_anchor=(0.5, 1),
7285
fontsize=40,

0 commit comments

Comments
 (0)