@@ -502,6 +502,14 @@ def __normalize_token(self, orig_token: Token | PreprocessorToken) -> str:
502502 return token
503503
504504
@Language.component("clear_trf_data")
def clear_trf_data(doc):
    """Release a doc's cached transformer output so GPU memory can be freed.

    Registered as a spaCy pipeline component; appended after the transformer
    stages so each processed doc drops its `trf_data` cache. Returns the doc,
    as every pipeline component must.
    """
    # Only clear when the extension attribute actually carries data; the
    # end state (trf_data is None) is identical either way.
    if getattr(doc._, "trf_data", None) is not None:
        doc._.trf_data = None
    return doc
511+
512+
505513def load_language_model (language , normalize_options : dict [str , Any ]) -> Language :
506514 """Load language model based on name"""
507515 nlp = None
@@ -521,17 +529,17 @@ def load_language_model(language, normalize_options: dict[str, Any]) -> Language
521529 normalize_options ["ents_to_keep" ],
522530 )
523531 ):
524- diabled_pipelines = ["tokenizer" , "textcat" ]
532+ disabled_pipelines = ["tokenizer" , "textcat" ]
525533 if not normalize_options ["pos_to_keep" ]:
526- diabled_pipelines .append ("tagger" )
534+ disabled_pipelines .append ("tagger" )
527535 if not normalize_options ["ents_to_keep" ]:
528- diabled_pipelines .append ("ner" )
536+ disabled_pipelines .append ("ner" )
529537 model_loaded = ""
530538 set_gpu_allocator ("pytorch" )
531- prefer_gpu ()
539+ use_gpu = prefer_gpu ()
532540 for model in possible_models :
533541 try :
534- nlp = spacy .load (model , exclude = diabled_pipelines )
542+ nlp = spacy .load (model , exclude = disabled_pipelines )
535543 print ("Using Spacy model" , model )
536544 except OSError :
537545 pass
@@ -541,6 +549,8 @@ def load_language_model(language, normalize_options: dict[str, Any]) -> Language
541549 if nlp is None :
542550 print (f"No Spacy model installed for the { language } language. Stopping..." )
543551 exit (- 1 )
552+ if use_gpu is True :
553+ nlp .add_pipe ("clear_trf_data" , last = True )
544554 nlp .add_pipe ("postprocessor" , config = normalize_options , last = True )
545555 if normalize_options ["ents_to_keep" ] and "ner" not in nlp .pipe_names :
546556 print (f"There is no NER pipeline for model { model_loaded } . Exiting..." )
0 commit comments