Skip to content

Commit 9b30a83

Browse files
committed
more fixes regarding GPU RAM issues
1 parent 7daaa46 commit 9b30a83

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

text_preprocessing/preprocessor.py

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -55,12 +55,11 @@ def check_gpu_ram():
5555
total = torch.cuda.get_device_properties(device).total_memory
5656
allocated_percent = (allocated / total) * 100
5757

58-
if allocated_percent > 20: # This is is only a subset of RAM usage, but indicative of high usage
58+
if allocated_percent > 20: # This is is only a subset of GPU RAM usage, but indicative of high usage
5959
torch.cuda.empty_cache()
6060
torch.cuda.synchronize()
6161

6262
gc.collect()
63-
print("Attempting to free GPU memory...")
6463

6564

6665
def process_batch_texts(
@@ -71,16 +70,19 @@ def process_batch_texts(
7170
do_nlp,
7271
keep_all,
7372
using_gpu,
73+
count,
74+
progress,
75+
progress_prefix,
7476
):
7577
nlp = load_language_model(language_model, normalize_options)
7678
results = []
7779
text_fetcher = TextFetcher(nlp, **text_fetcher_args) # Initialize text_fetcher with required params
7880
for tokens, _ in text_fetcher(batch_texts, do_nlp=do_nlp, keep_all=keep_all, progress=False):
7981
if isinstance(tokens, PreparedDoc):
8082
spacy_doc = make_spacy_doc(nlp, tokens)
81-
if spacy_doc._.char_num > 100000 and using_gpu is True:
83+
if spacy_doc._.char_num > 10000 and using_gpu is True:
8284
split_doc = split_spacy_docs(nlp, spacy_doc)
83-
doc = Doc.from_docs(list(nlp.pipe(split_doc, batch_size=128)))
85+
doc = Doc.from_docs(list(nlp.pipe(split_doc, batch_size=64)))
8486
doc._.metadata = spacy_doc._.metadata
8587
results.append(Tokens(doc, keep_all=keep_all))
8688
else:
@@ -89,6 +91,11 @@ def process_batch_texts(
8991
results.append(Tokens(tokens, keep_all=keep_all))
9092
else:
9193
results.append(tokens)
94+
if using_gpu:
95+
check_gpu_ram()
96+
if progress:
97+
count += 1
98+
print(f"\r{progress_prefix} {count} texts processed...", end="", flush=True)
9299
return results
93100

94101

@@ -209,7 +216,7 @@ def __init__(
209216
else:
210217
self.do_nlp = False
211218

212-
def __process_batch(self, pool, batch, keep_all):
219+
def __process_batch(self, pool, batch, keep_all, count, progress, progress_prefix):
213220
for tokens in pool.apply(
214221
process_batch_texts,
215222
(
@@ -220,6 +227,9 @@ def __process_batch(self, pool, batch, keep_all):
220227
self.do_nlp,
221228
keep_all,
222229
self.using_gpu,
230+
count,
231+
progress,
232+
progress_prefix,
223233
),
224234
):
225235
if self.ngram_config is not None:
@@ -242,21 +252,16 @@ def process_texts(
242252
print(f"\r{progress_prefix} {count} texts processed...", end="", flush=True)
243253
for text in texts:
244254
current_batch.append(text)
245-
if len(current_batch) >= 20:
255+
if len(current_batch) >= 100:
246256
with mp.Pool(1) as pool:
247-
yield from self.__process_batch(pool, current_batch, keep_all)
257+
yield from self.__process_batch(pool, current_batch, keep_all, count, progress, progress_prefix)
248258
count += len(current_batch)
249-
if progress:
250-
print(f"\r{progress_prefix} {count} texts processed...", end="", flush=True)
251259
current_batch = []
252260

253261
# Process the remaining texts
254262
if current_batch:
255263
with mp.Pool(1) as pool:
256-
yield from self.__process_batch(pool, current_batch, keep_all)
257-
count += len(current_batch)
258-
if progress:
259-
print(f"\r{progress_prefix} {count} texts processed...", end="", flush=True)
264+
yield from self.__process_batch(pool, current_batch, keep_all, count, progress, progress_prefix)
260265

261266
def process_string(self, text: str, keep_all: bool = True) -> Tokens:
262267
"""Take a string and return a list of preprocessed tokens"""

0 commit comments

Comments (0)