Skip to content

Commit 214f62f

Browse files
committed
xp_dist
1 parent 0c3fe9e commit 214f62f

1 file changed

Lines changed: 57 additions & 0 deletions

File tree

xp_dist.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import argparse, json
2+
import matplotlib.pyplot as plt
3+
from tqdm import tqdm
4+
from conivel.datas.context import (
5+
SameNounRetriever,
6+
BM25ContextRetriever,
7+
IdealNeuralContextRetriever,
8+
)
9+
from conivel.datas.dekker import DekkerDataset
10+
from conivel.utils import pretrained_bert_for_token_classification
11+
from conivel.train import train_ner_model
12+
13+
14+
parser = argparse.ArgumentParser()
15+
parser.add_argument("-o", "--output", type=str)
16+
parser.add_argument("-r", "--oracle", action="store_true")
17+
args = parser.parse_args()
18+
19+
20+
sn_dists = []
21+
bm25_dists = []
22+
23+
dataset = DekkerDataset()
24+
kfolds = dataset.kfolds(5, shuffle=True, shuffle_seed=0)
25+
26+
for train, test in kfolds:
27+
28+
# * retriever instantiation
29+
if args.oracle:
30+
ner_model = pretrained_bert_for_token_classification(
31+
"bert-base-cased", dataset.tag_to_id
32+
)
33+
ner_model = train_ner_model(
34+
ner_model, train, train, epochs_nb=2, learning_rate=2e-5
35+
)
36+
sn_retriever = IdealNeuralContextRetriever(
37+
1, SameNounRetriever(16), ner_model, 4, dataset.tags
38+
)
39+
bm25_retriever = IdealNeuralContextRetriever(
40+
1, BM25ContextRetriever(16), ner_model, 4, dataset.tags
41+
)
42+
else:
43+
sn_retriever = SameNounRetriever(1)
44+
bm25_retriever = BM25ContextRetriever(1)
45+
46+
# * retrieval
47+
for document in tqdm(test.documents): # TODO
48+
for sent_i, sent in enumerate(document):
49+
sn_matchs = sn_retriever.retrieve(sent_i, document)
50+
bm25_matchs = bm25_retriever.retrieve(sent_i, document)
51+
if len(sn_matchs) != 0:
52+
sn_dists.append(abs(sent_i - sn_matchs[0].sentence_idx))
53+
bm25_dists.append(abs(sent_i - bm25_matchs[0].sentence_idx))
54+
55+
56+
with open(args.output, "w") as f:
57+
json.dump({"samenoun_dists": sn_dists, "bm25_dists": bm25_dists}, f, indent=4)

0 commit comments

Comments
 (0)