Skip to content

Commit 1c82004

Browse files
committed
limit workers to 1 when using GPU to avoid running out of RAM when the buffer is full
1 parent 053412b commit 1c82004

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

text_preprocessing/preprocessor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,14 @@ def __init__(
9898
if nlp_model is not None:
9999
self.nlp = nlp_model
100100
else:
101-
self.nlp = load_language_model(self.language, self.normalize_options)
101+
self.nlp, using_gpu = load_language_model(self.language, self.normalize_options)
102102
if workers is None:
103103
cpu_count = os.cpu_count() or 2
104104
self.workers = cpu_count - 1
105105
else:
106106
self.workers = workers
107+
if using_gpu is True:
108+
self.workers = 1
107109
ngrams = ngrams or 0
108110
if ngrams:
109111
self.ngram_config = {
@@ -124,7 +126,7 @@ def __init__(
124126
is_philo_db=is_philo_db,
125127
text_object_type=text_object_type,
126128
ngram_config=self.ngram_config,
127-
workers=workers,
129+
workers=self.workers,
128130
)
129131
if self.normalize_options["pos_to_keep"] or self.normalize_options["ents_to_keep"] or lemmatizer == "spacy":
130132
self.do_nlp = True

text_preprocessing/spacy_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Helper functions for Spacy"""
22

33
import os
4+
import pickle
45
import re
56
import sys
67
import unicodedata
@@ -9,7 +10,6 @@
910
from typing import Any, Dict, Iterable, List, Optional, Union
1011
from xml.sax.saxutils import unescape as unescape_xml
1112

12-
import pickle
1313
import spacy
1414
from spacy.language import Language
1515
from spacy.tokens import Doc, Token
@@ -528,7 +528,7 @@ def clear_trf_data(doc):
528528
return doc
529529

530530

531-
def load_language_model(language, normalize_options: dict[str, Any]) -> Language:
531+
def load_language_model(language, normalize_options: dict[str, Any]) -> tuple[Language, bool]:
532532
"""Load language model based on name"""
533533
nlp = None
534534
language = language.lower()
@@ -573,7 +573,7 @@ def load_language_model(language, normalize_options: dict[str, Any]) -> Language
573573
if normalize_options["ents_to_keep"] and "ner" not in nlp.pipe_names:
574574
print(f"There is no NER pipeline for model {model_loaded}. Exiting...")
575575
exit(-1)
576-
return nlp
576+
return nlp, use_gpu
577577
nlp = spacy.blank("en")
578578
nlp.add_pipe("postprocessor", config=normalize_options, last=True)
579-
return nlp
579+
return nlp, False

0 commit comments

Comments
 (0)