@@ -63,14 +63,13 @@ def check_gpu_ram():
6363
6464
6565def process_batch_texts (
66- text_fetcher_args , batch_texts , language_model , normalize_options , do_nlp , keep_all , using_gpu , progress_info
66+ queue , text_fetcher_args , batch_texts , language_model , normalize_options , do_nlp , keep_all , progress_info
6767):
68- nlp = load_language_model (language_model , normalize_options )
69- results = []
68+ nlp , using_gpu = load_language_model (language_model , normalize_options )
7069 text_fetcher = TextFetcher (nlp , ** text_fetcher_args ) # Initialize text_fetcher with required params
7170 previous_philo_id = None
7271 for tokens , _ in text_fetcher (batch_texts , do_nlp = do_nlp , keep_all = keep_all , progress = False ):
73- if isinstance (tokens , PreparedDoc ):
72+ if isinstance (tokens , PreparedDoc ) and using_gpu is True :
7473 spacy_doc = make_spacy_doc (nlp , tokens )
7574 if spacy_doc ._ .char_num > 10000 and using_gpu is True :
7675 split_doc = split_spacy_docs (nlp , spacy_doc )
@@ -101,8 +100,8 @@ def process_batch_texts(
101100 flush = True ,
102101 )
103102 previous_philo_id = current_doc_id
104- results . append (tokens )
105- return results
103+ queue . put (tokens )
104+ queue . put ( None )
106105
107106
108107def split_spacy_docs (nlp , doc : Doc ) -> list [Doc ]:
@@ -168,8 +167,6 @@ def __init__(
168167 pos_to_keep : list [str ] | bool = False ,
169168 ents_to_keep : list [str ] | bool = False ,
170169 post_processing_function : Callable | None = None ,
171- nlp_model : Language | None = None ,
172- using_gpu : bool = False ,
173170 ** _ , # this is meant to make the constructor accept invalid keywords
174171 ):
175172 self .normalize_options = {
@@ -188,14 +185,6 @@ def __init__(
188185 }
189186 self .language = language
190187 self .language_model = language_model
191- self .using_gpu = spacy .prefer_gpu ()
192- if workers is None :
193- cpu_count = os .cpu_count () or 2
194- self .workers = cpu_count - 1
195- else :
196- self .workers = workers
197- if self .using_gpu is True :
198- self .workers = 1
199188 ngrams = ngrams or 0
200189 if ngrams :
201190 self .ngram_config = {
@@ -206,6 +195,21 @@ def __init__(
206195 else :
207196 self .ngram_config = None
208197 self .post_func = post_processing_function
198+ if workers is None :
199+ cpu_count = os .cpu_count () or 2
200+ self .workers = cpu_count - 1
201+ else :
202+ self .workers = workers
203+ if self .normalize_options ["pos_to_keep" ] or self .normalize_options ["ents_to_keep" ] or lemmatizer == "spacy" :
204+ self .do_nlp = True
205+ else :
206+ self .do_nlp = False
207+ if self .do_nlp is False :
208+ using_gpu = False
209+ else :
210+ using_gpu = spacy .prefer_gpu ()
211+ if using_gpu is True :
212+ self .workers = 1
209213 self .text_fetcher_args = {
210214 "word_regex" : word_regex ,
211215 "sentence_boundaries" : sentence_boundaries ,
@@ -217,31 +221,36 @@ def __init__(
217221 "workers" : self .workers ,
218222 "ngram_config" : self .ngram_config ,
219223 }
220- if self .normalize_options ["pos_to_keep" ] or self .normalize_options ["ents_to_keep" ] or lemmatizer == "spacy" :
221- self .do_nlp = True
222- else :
223- self .do_nlp = False
224224
225- def __process_batch (self , pool , batch , keep_all , progress_info ):
226- for tokens in pool .apply (
227- process_batch_texts ,
228- (
225+ def __process_batch (self , batch , keep_all , progress_info ):
226+ queue = mp .Queue ()
227+ process = mp .Process (
228+ target = process_batch_texts ,
229+ args = (
230+ queue ,
229231 self .text_fetcher_args ,
230232 batch ,
231233 self .language_model ,
232234 self .normalize_options ,
233235 self .do_nlp ,
234236 keep_all ,
235- self .using_gpu ,
236237 progress_info ,
237238 ),
238- ):
239+ )
240+ process .start ()
241+
242+ while True :
243+ tokens = queue .get () # This blocks until data is available
244+ if tokens is None : # End signal
245+ break
239246 if self .ngram_config is not None :
240247 tokens = generate_ngrams (** self .ngram_config , tokens = tokens )
241248 if self .post_func is not None :
242249 tokens = self .post_func (tokens )
243250 yield tokens
244251
252+ process .join ()
253+
245254 def process_texts (
246255 self ,
247256 texts : Iterable [str ],
@@ -259,17 +268,15 @@ def process_texts(
259268 print (f"\r { progress_prefix } 0 text chunks of 0 documents processed..." , end = "" , flush = True )
260269 for text in texts :
261270 current_batch .append (text )
262- progress_info ["doc_count" ] += 1
263271 if len (current_batch ) >= 100 :
264- with mp .Pool (1 ) as pool :
265- yield from self .__process_batch (pool , current_batch , keep_all , progress_info )
266- progress_info ["count" ] += 1
272+ yield from self .__process_batch (current_batch , keep_all , progress_info )
273+ progress_info ["count" ] += 1
267274 current_batch = []
275+ progress_info ["doc_count" ] += 100
268276
269277 # Process the remaining texts
270278 if current_batch :
271- with mp .Pool (1 ) as pool :
272- yield from self .__process_batch (pool , current_batch , keep_all , progress_info )
279+ yield from self .__process_batch (current_batch , keep_all , progress_info )
273280
274281 def process_string (self , text : str , keep_all : bool = True ) -> Tokens :
275282 """Take a string and return a list of preprocessed tokens"""
0 commit comments