Skip to content

Commit 28187fc

Browse files
committed
Add the NER retrieval module to reference documentation
1 parent f703acb commit 28187fc

2 files changed

Lines changed: 12 additions & 1 deletion

File tree

docs/reference.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ NLTKNamedEntityRecognizer
6464
:members:
6565

6666

67+
Retrieval
68+
---------
69+
70+
.. automodule:: renard.pipeline.ner.retrieval
71+
:members:
72+
73+
6774
Coreference Resolution
6875
======================
6976

renard/pipeline/ner/retrieval.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def __hash__(self) -> int:
2929

3030

3131
class NERContextRetriever:
32+
"""Base class for NER context retrievers."""
33+
3234
def __init__(self, k: int) -> None:
3335
self.k = k
3436

@@ -278,7 +280,9 @@ def predict(self, examples: List[NERNeuralContextRetrievalExample]) -> torch.Ten
278280
self.ctx_classifier = self.ctx_classifier.to(self.device)
279281

280282
data_collator = DataCollatorWithPadding(dataset.tokenizer) # type: ignore
281-
dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, collate_fn=data_collator) # type: ignore
283+
dataloader = DataLoader(
284+
dataset, batch_size=self.batch_size, shuffle=False, collate_fn=data_collator
285+
) # type: ignore
282286

283287
# inference using self.ctx_classifier
284288
self.ctx_classifier = self.ctx_classifier.eval()

0 commit comments

Comments
 (0)