Skip to content

Commit 4a3fa90

Browse files
committed
Users can now manually configure step parameters passed at pipeline_init time
1 parent 1f4f314 commit 4a3fa90

1 file changed

Lines changed: 23 additions & 12 deletions

File tree

renard/pipeline/core.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(self):
8282

8383
def _pipeline_init_(
8484
self, lang: str, progress_reporter: ProgressReporter, **kwargs
85-
) -> Optional[Dict[Pipeline.PipelineParameter, Any]]:
85+
) -> Optional[Dict[str, Any]]:
8686
"""Set the step configuration that is common to the whole
8787
pipeline.
8888
@@ -576,16 +576,13 @@ def update(slider_value):
576576
class Pipeline:
577577
"""A flexible NLP pipeline"""
578578

579-
#: all the possible parameters of the whole pipeline, that are
580-
#: shared between steps
581-
PipelineParameter = Literal["lang", "progress_reporter", "character_ner_tag"]
582-
583579
def __init__(
584580
self,
585581
steps: List[PipelineStep],
586582
lang: str = "eng",
587583
progress_report: Optional[Literal["tqdm"]] = "tqdm",
588584
warn: bool = True,
585+
**step_additional_params,
589586
) -> None:
590587
"""
591588
:param steps: a ``tuple`` of :class:``PipelineStep``, that
@@ -595,16 +592,28 @@ def __init__(
595592
progress.
596593
:param lang: ISO 639-3 language code
597594
:param warn:
595+
596+
:param step_additional_params: additional parameters passed to
597+
:meth:`._pipeline_init_` when
598+
:meth:`_pipeline_init_steps_` is called. The following
599+
values are currently used:
600+
601+
- ``'character_ner_tag'``: the NER tag corresponding
602+
to characters (default: ``PER``)
598603
"""
599604
self.steps = steps
600605

601606
self.progress_report: Optional[Literal["tqdm"]] = progress_report
602607
self.progress_reporter = get_progress_reporter(progress_report)
603608

604609
self.lang = lang
605-
self.character_ner_tag = "PER"
606610
self.warn = warn
607611

612+
self.step_additional_params = step_additional_params
613+
self.step_additional_params["character_ner_tag"] = (
614+
self.step_additional_params.get("character_ner_tag", "PER")
615+
)
616+
608617
def _pipeline_init_steps_(self, ignored_steps: Optional[List[str]] = None):
609618
"""Initialise steps with global pipeline parameters.
610619
@@ -613,16 +622,18 @@ def _pipeline_init_steps_(self, ignored_steps: Optional[List[str]] = None):
613622
"""
614623
steps_progress_reporter = self.progress_reporter.get_subreporter()
615624
steps = self._non_ignored_steps(ignored_steps)
616-
pipeline_params = {
617-
"progress_reporter": steps_progress_reporter,
618-
"character_ner_tag": self.character_ner_tag,
619-
}
625+
pipeline_params = self.step_additional_params.copy()
620626
for step in steps:
621-
step_additional_params = step._pipeline_init_(self.lang, **pipeline_params)
627+
step_additional_params = step._pipeline_init_(
628+
self.lang, progress_reporter=steps_progress_reporter, **pipeline_params
629+
)
622630
if not step_additional_params is None:
623631
for key, value in step_additional_params.items():
624632
setattr(self, key, value)
625-
pipeline_params[key] = value
633+
# parameters set by the user have precedence over
634+
# step mandated parameters
635+
if not key in self.step_additional_params:
636+
pipeline_params[key] = value
626637

627638
def _non_ignored_steps(
628639
self, ignored_steps: Optional[List[str]]

0 commit comments

Comments
 (0)