Skip to content

Commit bb8f826

Browse files
committed
better handling of GPU out-of-memory issues
1 parent f9c7a7a commit bb8f826

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

text_preprocessing/preprocessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,11 @@ def process_texts(
143143
fetched_texts = self.text_fetcher(
144144
texts, do_nlp=self.do_nlp, keep_all=keep_all, progress=progress, post_func=self.post_func
145145
)
146-
if self.text_fetcher.text_object_type in ("para", "sent") and self.do_nlp is True:
146+
if self.text_fetcher.text_object_type == "sent" and self.do_nlp is True:
147147
fetched_texts = self.nlp.pipe(
148148
((make_spacy_doc(self.nlp, tokens), c) for tokens, c in fetched_texts),
149149
as_tuples=True,
150-
batch_size=500,
150+
batch_size=250,
151151
)
152152
for tokens, doc_count in fetched_texts:
153153
count += 1

text_preprocessing/spacy_helpers.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,14 @@ def __normalize_token(self, orig_token: Token | PreprocessorToken) -> str:
502502
return token
503503

504504

505+
@Language.component("clear_trf_data")
506+
def clear_trf_data(doc):
507+
"""Clear the cache of a doc to free GPU memory"""
508+
if hasattr(doc._, "trf_data"):
509+
doc._.trf_data = None
510+
return doc
511+
512+
505513
def load_language_model(language, normalize_options: dict[str, Any]) -> Language:
506514
"""Load language model based on name"""
507515
nlp = None
@@ -521,17 +529,17 @@ def load_language_model(language, normalize_options: dict[str, Any]) -> Language
521529
normalize_options["ents_to_keep"],
522530
)
523531
):
524-
diabled_pipelines = ["tokenizer", "textcat"]
532+
disabled_pipelines = ["tokenizer", "textcat"]
525533
if not normalize_options["pos_to_keep"]:
526-
diabled_pipelines.append("tagger")
534+
disabled_pipelines.append("tagger")
527535
if not normalize_options["ents_to_keep"]:
528-
diabled_pipelines.append("ner")
536+
disabled_pipelines.append("ner")
529537
model_loaded = ""
530538
set_gpu_allocator("pytorch")
531-
prefer_gpu()
539+
use_gpu = prefer_gpu()
532540
for model in possible_models:
533541
try:
534-
nlp = spacy.load(model, exclude=diabled_pipelines)
542+
nlp = spacy.load(model, exclude=disabled_pipelines)
535543
print("Using Spacy model", model)
536544
except OSError:
537545
pass
@@ -541,6 +549,8 @@ def load_language_model(language, normalize_options: dict[str, Any]) -> Language
541549
if nlp is None:
542550
print(f"No Spacy model installed for the {language} language. Stopping...")
543551
exit(-1)
552+
if use_gpu is True:
553+
nlp.add_pipe("clear_trf_data", last=True)
544554
nlp.add_pipe("postprocessor", config=normalize_options, last=True)
545555
if normalize_options["ents_to_keep"] and "ner" not in nlp.pipe_names:
546556
print(f"There is no NER pipeline for model {model_loaded}. Exiting...")

0 commit comments

Comments (0)