@@ -82,7 +82,7 @@ def __init__(self):
8282
8383 def _pipeline_init_ (
8484 self , lang : str , progress_reporter : ProgressReporter , ** kwargs
85- ) -> Optional [Dict [Pipeline . PipelineParameter , Any ]]:
85+ ) -> Optional [Dict [str , Any ]]:
8686 """Set the step configuration that is common to the whole
8787 pipeline.
8888
@@ -576,16 +576,13 @@ def update(slider_value):
576576class Pipeline :
577577 """A flexible NLP pipeline"""
578578
579- #: all the possible parameters of the whole pipeline, that are
580- #: shared between steps
581- PipelineParameter = Literal ["lang" , "progress_reporter" , "character_ner_tag" ]
582-
583579 def __init__ (
584580 self ,
585581 steps : List [PipelineStep ],
586582 lang : str = "eng" ,
587583 progress_report : Optional [Literal ["tqdm" ]] = "tqdm" ,
588584 warn : bool = True ,
585+ ** step_additional_params ,
589586 ) -> None :
590587 """
591588 :param steps: a ``tuple`` of :class:``PipelineStep``, that
@@ -595,16 +592,28 @@ def __init__(
595592 progress.
596593 :param lang: ISO 639-3 language code
597594 :param warn:
595+
596+ :param step_additional_params: additional parameters passed to
597+ :meth:`._pipeline_init_` when
598+ :meth:`_pipeline_init_steps_` is called. The following
599+ values are currently used:
600+
601+ - ``'character_ner_tag'``: the NER tag corresponding
602+ to characters (default: ``PER``)
598603 """
599604 self .steps = steps
600605
601606 self .progress_report : Optional [Literal ["tqdm" ]] = progress_report
602607 self .progress_reporter = get_progress_reporter (progress_report )
603608
604609 self .lang = lang
605- self .character_ner_tag = "PER"
606610 self .warn = warn
607611
612+ self .step_additional_params = step_additional_params
613+ self .step_additional_params ["character_ner_tag" ] = (
614+ self .step_additional_params .get ("character_ner_tag" , "PER" )
615+ )
616+
608617 def _pipeline_init_steps_ (self , ignored_steps : Optional [List [str ]] = None ):
609618 """Initialise steps with global pipeline parameters.
610619
@@ -613,16 +622,18 @@ def _pipeline_init_steps_(self, ignored_steps: Optional[List[str]] = None):
613622 """
614623 steps_progress_reporter = self .progress_reporter .get_subreporter ()
615624 steps = self ._non_ignored_steps (ignored_steps )
616- pipeline_params = {
617- "progress_reporter" : steps_progress_reporter ,
618- "character_ner_tag" : self .character_ner_tag ,
619- }
625+ pipeline_params = self .step_additional_params .copy ()
620626 for step in steps :
621- step_additional_params = step ._pipeline_init_ (self .lang , ** pipeline_params )
627+ step_additional_params = step ._pipeline_init_ (
628+ self .lang , progress_reporter = steps_progress_reporter , ** pipeline_params
629+ )
622630 if not step_additional_params is None :
623631 for key , value in step_additional_params .items ():
624632 setattr (self , key , value )
625- pipeline_params [key ] = value
633+ # parameters set by the user have precedence over
634+ # step mandated parameters
635+ if not key in self .step_additional_params :
636+ pipeline_params [key ] = value
626637
627638 def _non_ignored_steps (
628639 self , ignored_steps : Optional [List [str ]]
0 commit comments