Skip to content

Commit 7daaa46

Browse files
committed
run SpaCy operations inside a separate process to avoid GPU OOM issues
1 parent a8b572f commit 7daaa46

File tree

3 files changed

+152
-103
lines changed

3 files changed

+152
-103
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
install_requires=[
1313
"unidecode",
1414
"PyStemmer",
15-
"spacy>=3.7,<3.8",
15+
"spacy>=3.8",
1616
"orjson",
1717
"requests",
1818
"lz4",

text_preprocessing/preprocessor.py

Lines changed: 147 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,34 @@
11
#!/usr/bin/env python3
22
"""Text Preprocessor"""
33

4+
import gc
5+
import multiprocessing as mp
46
import os
57
import sqlite3
68
import sys
9+
import warnings
710
from collections import defaultdict, deque
811
from dataclasses import dataclass
912
from itertools import combinations
1013
from typing import Any, Callable, DefaultDict, Deque, Iterable
1114

15+
import cupy as cp
1216
import lz4.frame
1317
import orjson
1418
import regex as re
19+
import spacy
20+
import torch
1521
from multiprocess.pool import Pool
1622
from spacy.language import Language
1723
from spacy.tokens import Doc, Token
1824

1925
from .modernizer import Modernizer
2026
from .spacy_helpers import PreprocessorToken, Tokens, load_language_model
2127

28+
# Suppress all UserWarning messages
29+
warnings.filterwarnings("ignore", category=UserWarning)
30+
mp.set_start_method("spawn", force=True)
31+
2232
Doc.set_extension("metadata", default={})
2333
Doc.set_extension("char_num", default=0)
2434
Token.set_extension("ext", default={})
@@ -38,6 +48,73 @@
3848
PHILO_OBJECT_LEVEL: dict[int, str] = {1: "doc", 2: "div1", 3: "div2", 4: "div3", 5: "para", 6: "sent", 7: "word"}
3949

4050

51+
def check_gpu_ram():
    """Check GPU memory pressure and attempt to free cached GPU memory.

    Computes the share of total device memory currently allocated by
    PyTorch and, when it exceeds 20%, empties the CUDA cache and
    synchronizes the device. (The original docstring claimed a return
    value; the function returns None.)

    NOTE: torch.cuda.memory_allocated only counts PyTorch tensor
    allocations, so this is only a subset of real GPU RAM usage — but it
    is indicative of high usage.
    """
    device = torch.cuda.current_device()
    allocated = torch.cuda.memory_allocated(device)
    total = torch.cuda.get_device_properties(device).total_memory
    allocated_percent = (allocated / total) * 100

    if allocated_percent > 20:  # only a subset of RAM usage, but indicative of high usage
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # Collect Python garbage regardless, so dead tensors lose their last refs.
    gc.collect()
    print("Attempting to free GPU memory...")
66+
def process_batch_texts(
    text_fetcher_args,
    batch_texts,
    language_model,
    normalize_options,
    do_nlp,
    keep_all,
    using_gpu,
):
    """Process one batch of texts inside a worker process.

    Runs as the target of a single-worker multiprocessing pool so that the
    spaCy pipeline (and any GPU memory it holds) is released when the
    worker process exits — see the commit intent of avoiding GPU OOM.
    Loads its own language model rather than receiving one, since spaCy
    pipelines are not shared across spawned processes.

    Returns a list of results: Tokens objects for fully processed docs, or
    the fetcher's own output passed through for anything else.
    """
    nlp = load_language_model(language_model, normalize_options)
    results = []
    text_fetcher = TextFetcher(nlp, **text_fetcher_args)  # Initialize text_fetcher with required params
    for tokens, _ in text_fetcher(batch_texts, do_nlp=do_nlp, keep_all=keep_all, progress=False):
        if isinstance(tokens, PreparedDoc):
            spacy_doc = make_spacy_doc(nlp, tokens)
            # Very large docs on GPU are split into sentence chunks and piped
            # in batches to cap peak GPU memory, then reassembled.
            if spacy_doc._.char_num > 100000 and using_gpu is True:
                split_doc = split_spacy_docs(nlp, spacy_doc)
                doc = Doc.from_docs(list(nlp.pipe(split_doc, batch_size=128)))
                doc._.metadata = spacy_doc._.metadata  # Doc.from_docs does not carry custom extensions
                results.append(Tokens(doc, keep_all=keep_all))
            else:
                results.append(Tokens(nlp(spacy_doc), keep_all=keep_all))
        elif isinstance(tokens, Doc):
            # Already a spaCy Doc: wrap directly without re-running the pipeline.
            results.append(Tokens(tokens, keep_all=keep_all))
        else:
            # Presumably already a Tokens object — pass through unchanged.
            results.append(tokens)
    return results
94+
95+
def split_spacy_docs(nlp, doc: Doc) -> list[Doc]:
    """Split a spaCy doc into smaller docs of up to 10 sentences each.

    Each sentence is rebuilt as a fresh Doc (preserving sentence-start
    flags and per-token `_.ext` extension data) so the chunks can be piped
    through the pipeline independently.

    Bug fix: the original `if/else` flushed a full group *instead of*
    converting the current sentence, so the sentence that triggered each
    flush (every 11th) was silently dropped. Flushing and converting are
    now independent steps; no sentence is lost.
    """
    sentence_group: list[Doc] = []
    docs: list[Doc] = []
    for sent in doc.sents:
        # Flush a full group first, then still process the current sentence.
        if len(sentence_group) == 10:
            docs.append(Doc.from_docs(sentence_group))
            sentence_group = []
        sent_starts = []
        words = []
        for token in sent:
            sent_starts.append(token.is_sent_start)
            words.append(token.text)
        sent_doc = Doc(nlp.vocab, words, sent_starts=sent_starts)
        # Carry over per-token extension data set during text fetching.
        for pos, token in enumerate(sent):
            sent_doc[pos]._.ext = token._.ext
        sentence_group.append(sent_doc)
    if sentence_group:
        docs.append(Doc.from_docs(sentence_group))
    return docs
116+
117+
41118
@dataclass(slots=True)
42119
class PreparedDoc:
43120
"""Prepared doc for conversion to Spacy Doc object"""
@@ -97,12 +174,8 @@ def __init__(
97174
"ents_to_keep": ents_to_keep or [],
98175
}
99176
self.language = language
100-
self.using_gpu = using_gpu
101-
if nlp_model is not None:
102-
self.nlp = nlp_model
103-
else:
104-
self.nlp, using_gpu = load_language_model(language_model, self.normalize_options)
105-
self.using_gpu = using_gpu
177+
self.language_model = language_model
178+
self.using_gpu = spacy.prefer_gpu()
106179
if workers is None:
107180
cpu_count = os.cpu_count() or 2
108181
self.workers = cpu_count - 1
@@ -120,23 +193,41 @@ def __init__(
120193
else:
121194
self.ngram_config = None
122195
self.post_func = post_processing_function
123-
self.text_fetcher = TextFetcher(
124-
self.nlp,
125-
word_regex=word_regex,
126-
sentence_boundaries=sentence_boundaries,
127-
language=language,
128-
modernize=modernize,
129-
strip_tags=strip_tags,
130-
is_philo_db=is_philo_db,
131-
text_object_type=text_object_type,
132-
ngram_config=self.ngram_config,
133-
workers=self.workers,
134-
)
196+
self.text_fetcher_args = {
197+
"word_regex": word_regex,
198+
"sentence_boundaries": sentence_boundaries,
199+
"language": language,
200+
"modernize": modernize,
201+
"strip_tags": strip_tags,
202+
"is_philo_db": is_philo_db,
203+
"text_object_type": text_object_type,
204+
"workers": self.workers,
205+
"ngram_config": self.ngram_config,
206+
}
135207
if self.normalize_options["pos_to_keep"] or self.normalize_options["ents_to_keep"] or lemmatizer == "spacy":
136208
self.do_nlp = True
137209
else:
138210
self.do_nlp = False
139211

212+
def __process_batch(self, pool, batch, keep_all):
    """Run one batch of texts through a worker process and yield Tokens.

    pool.apply blocks until process_batch_texts returns the full list of
    results for the batch; spaCy work happens in the worker (isolating GPU
    memory), while ngram generation and the optional post-processing
    function are applied here in the parent process.
    """
    for tokens in pool.apply(
        process_batch_texts,
        (
            self.text_fetcher_args,
            batch,
            self.language_model,
            self.normalize_options,
            self.do_nlp,
            keep_all,
            self.using_gpu,
        ),
    ):
        if self.ngram_config is not None:
            tokens = generate_ngrams(**self.ngram_config, tokens=tokens)
        if self.post_func is not None:
            tokens = self.post_func(tokens)
        yield tokens
230+
140231
def process_texts(
141232
self,
142233
texts: Iterable[str],
@@ -145,82 +236,48 @@ def process_texts(
145236
progress_prefix="Processing texts...",
146237
) -> Iterable[Tokens]:
147238
"""Process all documents. Returns an iterator of documents"""
239+
148240
count = 0
149-
fetched_texts = self.text_fetcher(
150-
texts, do_nlp=self.do_nlp, keep_all=keep_all, progress=progress, post_func=self.post_func
151-
)
152-
if self.text_fetcher.text_object_type == "sent" and self.do_nlp is True:
153-
fetched_texts = self.nlp.pipe(
154-
((make_spacy_doc(self.nlp, tokens), c) for tokens, c in fetched_texts),
155-
as_tuples=True,
156-
batch_size=250,
157-
)
158-
for tokens, doc_count in fetched_texts:
159-
count += 1
160-
if progress is True:
161-
if doc_count is not None: # workaround for sent and para since nlp.pipe does not return context...
162-
print(
163-
f"\r{progress_prefix} {doc_count} done: {count} text objects extracted... ",
164-
end="",
165-
flush=True,
166-
)
167-
else:
168-
print(
169-
f"\r{progress_prefix} {count} text objects extracted... ",
170-
end="",
171-
flush=True,
172-
)
173-
if isinstance(tokens, PreparedDoc):
174-
spacy_doc = make_spacy_doc(self.nlp, tokens)
175-
if spacy_doc._.char_num > 100000 and self.using_gpu is True: # being conservative to preserve GPU RAM
176-
split_doc = self.__split_spacy_docs(spacy_doc)
177-
rebuilt_doc = Doc.from_docs(list(self.nlp.pipe(split_doc, batch_size=128)))
178-
rebuilt_doc._.metadata = spacy_doc._.metadata
179-
tokens = Tokens(rebuilt_doc, keep_all=keep_all)
180-
else:
181-
tokens = Tokens(self.nlp(spacy_doc), keep_all=keep_all)
182-
if self.ngram_config is not None:
183-
tokens = generate_ngrams(**self.ngram_config, tokens=tokens)
184-
if self.post_func is not None:
185-
tokens = self.post_func(tokens)
186-
yield tokens
187-
elif isinstance(tokens, Doc):
188-
tokens = Tokens(tokens, keep_all=keep_all)
189-
if self.ngram_config is not None:
190-
tokens = generate_ngrams(**self.ngram_config, tokens=tokens)
191-
if self.post_func is not None:
192-
tokens = self.post_func(tokens)
193-
yield tokens
194-
else:
195-
yield tokens
241+
current_batch = []
242+
print(f"\r{progress_prefix} {count} texts processed...", end="", flush=True)
243+
for text in texts:
244+
current_batch.append(text)
245+
if len(current_batch) >= 20:
246+
with mp.Pool(1) as pool:
247+
yield from self.__process_batch(pool, current_batch, keep_all)
248+
count += len(current_batch)
249+
if progress:
250+
print(f"\r{progress_prefix} {count} texts processed...", end="", flush=True)
251+
current_batch = []
252+
253+
# Process the remaining texts
254+
if current_batch:
255+
with mp.Pool(1) as pool:
256+
yield from self.__process_batch(pool, current_batch, keep_all)
257+
count += len(current_batch)
258+
if progress:
259+
print(f"\r{progress_prefix} {count} texts processed...", end="", flush=True)
196260

197261
def process_string(self, text: str, keep_all: bool = True) -> Tokens:
    """Take a string and return a Tokens object of preprocessed tokens.

    The text is processed in a single-worker spawned process (same path as
    batch processing) so spaCy/GPU resources are released when the worker
    exits.

    Fixes over the committed version:
    - Dropped the second mp.set_start_method("spawn") call: the start
      method is already set (with force=True) at import time, and calling
      it again without force raises RuntimeError.
    - The argument tuple now matches process_batch_texts' 7-parameter
      signature; previously ngram_config was passed as using_gpu and
      post_func was an extra positional, a guaranteed TypeError.
    - The worker already returns Tokens objects, so they are no longer
      re-wrapped in Tokens().
    """
    with mp.Pool(1) as pool:
        results = pool.apply(
            process_batch_texts,
            (
                self.text_fetcher_args,
                [text],
                self.language_model,
                self.normalize_options,
                self.do_nlp,
                keep_all,
                self.using_gpu,
            ),
        )
    # A single input string yields one result; keep only the first, as the
    # original break-on-first loop did.
    tokens = results[0]
    # Apply the same parent-process post-processing as the batch path
    # (the broken call tried to pass these into the worker).
    if self.ngram_config is not None:
        tokens = generate_ngrams(**self.ngram_config, tokens=tokens)
    if self.post_func is not None:
        tokens = self.post_func(tokens)
    return tokens
224281

225282

226283
class TextFetcher:

text_preprocessing/spacy_helpers.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -473,15 +473,7 @@ def __normalize_token(self, orig_token: Token | PreprocessorToken) -> str:
473473
return token
474474

475475

476-
@Language.component("clear_trf_data")
477-
def clear_trf_data(doc):
478-
"""Clear the cache of a doc to free GPU memory"""
479-
if hasattr(doc._, "trf_data"):
480-
doc._.trf_data = None
481-
return doc
482-
483-
484-
def load_language_model(language_model, normalize_options: dict[str, Any]) -> tuple[Language, bool]:
476+
def load_language_model(language_model, normalize_options: dict[str, Any]) -> Language:
485477
"""Load language model based on name"""
486478
nlp = None
487479
if language_model is not None and any(
@@ -508,12 +500,12 @@ def load_language_model(language_model, normalize_options: dict[str, Any]) -> tu
508500
)
509501
exit(-1)
510502
if use_gpu is True:
511-
nlp.add_pipe("clear_trf_data", last=True)
503+
nlp.add_pipe("doc_cleaner", last=True, config={"attrs": {"tensor": None}})
512504
nlp.add_pipe("postprocessor", config=normalize_options, last=True)
513505
if normalize_options["ents_to_keep"] and "ner" not in nlp.pipe_names:
514506
print(f"There is no NER pipeline for model {language_model}. Exiting...")
515507
exit(-1)
516-
return nlp, use_gpu
508+
return nlp
517509
nlp = spacy.blank("en")
518510
nlp.add_pipe("postprocessor", config=normalize_options, last=True)
519-
return nlp, False
511+
return nlp

0 commit comments

Comments (0)