@@ -63,50 +63,45 @@ def check_gpu_ram():
6363
6464
def process_batch_texts(
    text_fetcher_args,
    batch_texts,
    language_model,
    normalize_options,
    do_nlp,
    keep_all,
    using_gpu,
    progress_info,
):
    """Tokenize one batch of texts and return one ``Tokens`` object per text.

    Intended to run in a worker process: it loads the spaCy pipeline itself,
    streams the batch through a ``TextFetcher``, converts each yielded
    document into a ``Tokens`` object, and prints progress to stdout.

    Args:
        text_fetcher_args: keyword arguments forwarded to ``TextFetcher``;
            must contain a ``"text_object_type"`` key (``"doc"`` selects the
            per-document progress message, anything else the per-chunk one).
        batch_texts: the texts to process in this batch.
        language_model: model identifier passed to ``load_language_model``.
        normalize_options: normalization options passed to
            ``load_language_model``.
        do_nlp: whether the fetcher runs full NLP on each text.
        keep_all: forwarded unchanged to every ``Tokens`` constructor.
        using_gpu: when True, oversized docs are split before piping and GPU
            RAM is checked after each processed document.
        progress_info: mutable dict with ``"count"``, ``"doc_count"``,
            ``"progress"`` and ``"progress_prefix"`` keys used for progress
            reporting.
            NOTE(review): when this function is dispatched via
            ``multiprocessing`` (``pool.apply``), these in-place updates stay
            in the worker's pickled copy and are not visible to the caller —
            confirm this is intended.

    Returns:
        list[Tokens]: the processed texts, in the order they were yielded.
    """
    nlp = load_language_model(language_model, normalize_options)
    results = []
    text_fetcher = TextFetcher(nlp, **text_fetcher_args)  # Initialize text_fetcher with required params
    previous_philo_id = None
    for tokens, _ in text_fetcher(batch_texts, do_nlp=do_nlp, keep_all=keep_all, progress=False):
        if isinstance(tokens, PreparedDoc):
            spacy_doc = make_spacy_doc(nlp, tokens)
            if spacy_doc._.char_num > 10000 and using_gpu is True:
                # Very long docs can exhaust GPU memory: split the doc, pipe
                # the pieces in batches, then stitch them back together and
                # restore the original metadata.
                split_doc = split_spacy_docs(nlp, spacy_doc)
                doc = Doc.from_docs(list(nlp.pipe(split_doc, batch_size=64)))
                doc._.metadata = spacy_doc._.metadata
                tokens = Tokens(doc, keep_all=keep_all)
            else:
                tokens = Tokens(nlp(spacy_doc), keep_all=keep_all)
        elif isinstance(tokens, Doc):
            tokens = Tokens(tokens, keep_all=keep_all)
        # Any other yielded type is assumed to already be a Tokens object.
        if using_gpu:
            check_gpu_ram()
        # The first field of philo_id identifies the parent document; use it
        # to detect document boundaries when iterating sub-document chunks.
        # NOTE(review): assumes "philo_id" is always present in metadata —
        # .get() returns None otherwise and .split() would raise.
        current_doc_id = tokens.metadata.get("philo_id").split()[0]
        if previous_philo_id != current_doc_id:
            progress_info["doc_count"] += 1
        if progress_info["progress"] is True:
            progress_info["count"] += 1
            if text_fetcher_args["text_object_type"] == "doc":
                print(
                    f"\r{progress_info['progress_prefix']} {progress_info['count']} texts processed...",
                    end="",
                    flush=True,
                )
            else:
                print(
                    f"\r{progress_info['progress_prefix']} {progress_info['count']} text chunks of {progress_info['doc_count']} documents processed...",
                    end="",
                    flush=True,
                )
        previous_philo_id = current_doc_id
        results.append(tokens)
    return results
111106
112107
@@ -227,7 +222,7 @@ def __init__(
227222 else :
228223 self .do_nlp = False
229224
230- def __process_batch (self , pool , batch , keep_all , count , progress , progress_prefix ):
225+ def __process_batch (self , pool , batch , keep_all , progress_info ):
231226 for tokens in pool .apply (
232227 process_batch_texts ,
233228 (
@@ -238,9 +233,7 @@ def __process_batch(self, pool, batch, keep_all, count, progress, progress_prefi
238233 self .do_nlp ,
239234 keep_all ,
240235 self .using_gpu ,
241- count ,
242- progress ,
243- progress_prefix ,
236+ progress_info ,
244237 ),
245238 ):
246239 if self .ngram_config is not None :
@@ -257,22 +250,25 @@ def process_texts(
257250 progress_prefix = "Processing texts..." ,
258251 ) -> Iterable [Tokens ]:
259252 """Process all documents. Returns an iterator of documents"""
260-
261- count = 0
253+ progress_info = {"count" : 0 , "doc_count" : 0 , "progress" : progress , "progress_prefix" : progress_prefix }
262254 current_batch = []
263- print (f"\r { progress_prefix } { count } texts processed..." , end = "" , flush = True )
255+ if progress is True :
256+ if self .text_fetcher_args ["text_object_type" ] == "doc" :
257+ print (f"\r { progress_prefix } 0 documents processed..." , end = "" , flush = True )
258+ else :
259+ print (f"\r { progress_prefix } 0 text chunks of 0 documents processed..." , end = "" , flush = True )
264260 for text in texts :
265261 current_batch .append (text )
262+ progress_info ["doc_count" ] += 1
266263 if len (current_batch ) >= 100 :
267264 with mp .Pool (1 ) as pool :
268- yield from self .__process_batch (pool , current_batch , keep_all , count , progress , progress_prefix )
269- count += len (current_batch )
265+ yield from self .__process_batch (pool , current_batch , keep_all , progress_info )
270266 current_batch = []
271267
272268 # Process the remaining texts
273269 if current_batch :
274270 with mp .Pool (1 ) as pool :
275- yield from self .__process_batch (pool , current_batch , keep_all , count , progress , progress_prefix )
271+ yield from self .__process_batch (pool , current_batch , keep_all , progress_info )
276272
277273 def process_string (self , text : str , keep_all : bool = True ) -> Tokens :
278274 """Take a string and return a list of preprocessed tokens"""
0 commit comments