Skip to content

Commit 8218862

Browse files
committed
fix a crash when using context retrievers
1 parent 44c84f2 commit 8218862

3 files changed

Lines changed: 22 additions & 4 deletions

File tree

renard/pipeline/ner/ner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def batch_labels(
281281
batch_labels = ["O"] * len(tokens)
282282

283283
try:
284-
inference_start = ctxmask[batch_i].tolist().index(1)
284+
inference_start = ctxmask[batch_i].tolist().index(0)
285285
except ValueError:
286286
inference_start = 0
287287

@@ -290,7 +290,7 @@ def batch_labels(
290290
if token_i is None:
291291
continue
292292

293-
if ctxmask[batch_i][token_i] == 0:
293+
if ctxmask[batch_i][token_i] == 1:
294294
continue
295295

296296
batch_labels[token_i - inference_start] = wp_label

renard/pipeline/ner/retrieval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,15 @@ def __call__(self, dataset: NERDataset) -> NERDataset:
6666

6767
elements_with_context.append((lctx, elt, rctx))
6868

69-
return NERDataset(
69+
ner_dataset = NERDataset(
7070
[lctx + element + rctx for lctx, element, rctx in elements_with_context],
7171
dataset.tokenizer,
7272
[
7373
[1] * len(lctx) + [0] * len(element) + [1] * len(rctx)
7474
for lctx, element, rctx in elements_with_context
7575
],
7676
)
77+
return ner_dataset
7778

7879

7980
class NERSamenounContextRetriever(NERContextRetriever):

tests/test_ner.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from hypothesis.control import assume
66
from hypothesis.strategies import lists, sampled_from
77
from transformers import BertTokenizerFast
8+
from renard.pipeline.progress import get_progress_reporter
89
from renard.ner_utils import NERDataset
9-
from renard.pipeline.ner import ner_entities, score_ner
10+
from renard.pipeline.ner import ner_entities, score_ner, BertNamedEntityRecognizer
1011
from renard.pipeline.ner.retrieval import (
1112
NERBM25ContextRetriever,
1213
NERContextRetriever,
@@ -34,6 +35,22 @@ def test_has_correct_number_of_entities(tokens: List[str]):
3435
assert len(entities) == len(tokens)
3536

3637

38+
@pytest.mark.skipif(os.getenv("RENARD_TEST_SLOW") != "1", reason="performance")
39+
def test_run_with_context_retriever():
40+
ner_step = BertNamedEntityRecognizer(
41+
context_retriever=NERNeighborsContextRetriever(k=2)
42+
)
43+
ner_step._pipeline_init_(lang="eng", progress_reporter=get_progress_reporter(None))
44+
# known crash in Renard==0.7.1
45+
sentences = [
46+
"Whether i shall turn out to be the hero of my own life , or whether that station will be held by anybody else , these pages must show .".split(),
47+
"To begin my life with the beginning of my life , i record that i was born ( as i have been informed and believe ) on a friday , at twelve o'clock at night .".split(),
48+
"This was the fault of Dr. Strange .".split(),
49+
]
50+
tokens = [token for tokens in sentences for token in tokens]
51+
_ = ner_step(tokens, sentences)
52+
53+
3754
@pytest.mark.skipif(os.getenv("RENARD_TEST_SLOW") != "1", reason="performance")
3855
@pytest.mark.parametrize(
3956
"retriever_class", [NERSamenounContextRetriever, NERBM25ContextRetriever]

0 commit comments

Comments
 (0)