|
7 | 7 | parser = argparse.ArgumentParser() |
8 | 8 | parser.add_argument("-o", "--output", type=str, default=None) |
9 | 9 | parser.add_argument("-r", "--oracle", action="store_true") |
| 10 | +parser.add_argument("-e", "--restricted", action="store_true") |
| 11 | +parser.add_argument("--no-baseline", action="store_true") |
10 | 12 | args = parser.parse_args() |
11 | 13 |
|
12 | 14 | # from https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html |
|
29 | 31 | ] |
30 | 32 |
|
31 | 33 |
|
32 | | -runs = ["random", "bm25", "samenoun", "left", "right", "neighbors"] |
| 34 | +# runs is of form {run_dir_name => name} |
| 35 | +if args.oracle: |
| 36 | + runs = { |
| 37 | + f"oracle_{run}": run |
| 38 | + for run in ["random", "bm25", "samenoun", "before", "after", "surrounding"] |
| 39 | + } |
| 40 | +elif args.restricted: |
| 41 | + runs = {"oracle_bm25": "bm25", "bm25_restricted": "restricted bm25"} |
| 42 | +else: |
| 43 | + runs = { |
| 44 | + run: run |
| 45 | + for run in ["random", "bm25", "samenoun", "before", "after", "surrounding"] |
| 46 | + } |
33 | 47 |
|
34 | 48 | with open(f"./runs/bare/metrics.json") as f: |
35 | 49 | bare_metrics = json.load(f) |
|
42 | 56 |
|
43 | 57 | fig.set_size_inches(16, 12) |
44 | 58 |
|
45 | | -for run_i, run in enumerate(runs): |
46 | | - if args.oracle: |
47 | | - run = f"oracle_{run}" |
| 59 | +for run_i, (run, run_name) in enumerate(runs.items()): |
48 | 60 | with open(f"./runs/short/{run}/metrics.json") as f: |
49 | 61 | metrics = json.load(f) |
50 | 62 | ax.plot( |
|
54 | 66 | linewidth=4, |
55 | 67 | ) |
56 | 68 |
|
57 | | -# bare baseline |
58 | | -ax.plot( |
59 | | - [1, 6], |
60 | | - [bare_metrics["mean_test_f1"]["values"][0]] * 2, |
61 | | - linestyle=linestyle_tuple[len(runs)][1], |
62 | | - linewidth=4, |
63 | | -) |
| 69 | +# no retrieval baseline |
| 70 | +if not args.no_baseline: |
| 71 | + ax.plot( |
| 72 | + [1, 6], |
| 73 | + [bare_metrics["mean_test_f1"]["values"][0]] * 2, |
| 74 | + linestyle=linestyle_tuple[len(runs)][1], |
| 75 | + linewidth=4, |
| 76 | + ) |
64 | 77 |
|
65 | 78 | ax.grid() |
66 | 79 | ax.set_ylabel("F1", fontsize=40) |
67 | 80 | ax.set_xlabel("Number of retrieved sentences", fontsize=40) |
68 | 81 | ax.legend( |
69 | | - runs + ["no retrieval"], |
| 82 | + list(runs.values()) + ["no retrieval"] if not args.no_baseline else [], |
70 | 83 | loc="lower center", |
71 | 84 | bbox_to_anchor=(0.5, 1), |
72 | 85 | fontsize=40, |
|
0 commit comments