@@ -55,12 +55,11 @@ def check_gpu_ram():
5555 total = torch .cuda .get_device_properties (device ).total_memory
5656 allocated_percent = (allocated / total ) * 100
5757
58- if allocated_percent > 20 : # This is is only a subset of RAM usage, but indicative of high usage
58+ if allocated_percent > 20 : # This is only a subset of GPU RAM usage, but indicative of high usage
5959 torch .cuda .empty_cache ()
6060 torch .cuda .synchronize ()
6161
6262 gc .collect ()
63- print ("Attempting to free GPU memory..." )
6463
6564
6665def process_batch_texts (
@@ -71,16 +70,19 @@ def process_batch_texts(
7170 do_nlp ,
7271 keep_all ,
7372 using_gpu ,
73+ count ,
74+ progress ,
75+ progress_prefix ,
7476):
7577 nlp = load_language_model (language_model , normalize_options )
7678 results = []
7779 text_fetcher = TextFetcher (nlp , ** text_fetcher_args ) # Initialize text_fetcher with required params
7880 for tokens , _ in text_fetcher (batch_texts , do_nlp = do_nlp , keep_all = keep_all , progress = False ):
7981 if isinstance (tokens , PreparedDoc ):
8082 spacy_doc = make_spacy_doc (nlp , tokens )
81- if spacy_doc ._ .char_num > 100000 and using_gpu is True :
83+ if spacy_doc ._ .char_num > 10000 and using_gpu is True :
8284 split_doc = split_spacy_docs (nlp , spacy_doc )
83- doc = Doc .from_docs (list (nlp .pipe (split_doc , batch_size = 128 )))
85+ doc = Doc .from_docs (list (nlp .pipe (split_doc , batch_size = 64 )))
8486 doc ._ .metadata = spacy_doc ._ .metadata
8587 results .append (Tokens (doc , keep_all = keep_all ))
8688 else :
@@ -89,6 +91,11 @@ def process_batch_texts(
8991 results .append (Tokens (tokens , keep_all = keep_all ))
9092 else :
9193 results .append (tokens )
94+ if using_gpu :
95+ check_gpu_ram ()
96+ if progress :
97+ count += 1
98+ print (f"\r { progress_prefix } { count } texts processed..." , end = "" , flush = True )
9299 return results
93100
94101
@@ -209,7 +216,7 @@ def __init__(
209216 else :
210217 self .do_nlp = False
211218
212- def __process_batch (self , pool , batch , keep_all ):
219+ def __process_batch (self , pool , batch , keep_all , count , progress , progress_prefix ):
213220 for tokens in pool .apply (
214221 process_batch_texts ,
215222 (
@@ -220,6 +227,9 @@ def __process_batch(self, pool, batch, keep_all):
220227 self .do_nlp ,
221228 keep_all ,
222229 self .using_gpu ,
230+ count ,
231+ progress ,
232+ progress_prefix ,
223233 ),
224234 ):
225235 if self .ngram_config is not None :
@@ -242,21 +252,16 @@ def process_texts(
242252 print (f"\r { progress_prefix } { count } texts processed..." , end = "" , flush = True )
243253 for text in texts :
244254 current_batch .append (text )
245- if len (current_batch ) >= 20 :
255+ if len (current_batch ) >= 100 :
246256 with mp .Pool (1 ) as pool :
247- yield from self .__process_batch (pool , current_batch , keep_all )
257+ yield from self .__process_batch (pool , current_batch , keep_all , count , progress , progress_prefix )
248258 count += len (current_batch )
249- if progress :
250- print (f"\r { progress_prefix } { count } texts processed..." , end = "" , flush = True )
251259 current_batch = []
252260
253261 # Process the remaining texts
254262 if current_batch :
255263 with mp .Pool (1 ) as pool :
256- yield from self .__process_batch (pool , current_batch , keep_all )
257- count += len (current_batch )
258- if progress :
259- print (f"\r { progress_prefix } { count } texts processed..." , end = "" , flush = True )
264+ yield from self .__process_batch (pool , current_batch , keep_all , count , progress , progress_prefix )
260265
261266 def process_string (self , text : str , keep_all : bool = True ) -> Tokens :
262267 """Take a string and return a list of preprocessed tokens"""
0 commit comments