diff --git a/src/modelgauge/annotator.py b/src/modelgauge/annotator.py index 7abc4cf84..5173baeb7 100644 --- a/src/modelgauge/annotator.py +++ b/src/modelgauge/annotator.py @@ -44,3 +44,9 @@ def annotate(self, annotation_request): def translate_response(self, request, response) -> SafetyAnnotation: """Convert the raw response into the standardized SafetyAnnotation.""" pass + + def process(self, prompt: TextPrompt | ChatPrompt, response: SUTResponse) -> SafetyAnnotation: + """End-to-end processing of a single prompt+response pair.""" + annotator_request = self.translate_prompt(prompt, response) + annotator_response = self.annotate(annotator_request) + return self.translate_response(annotator_request, annotator_response) diff --git a/src/modelgauge/ensemble_annotator.py b/src/modelgauge/ensemble_annotator.py index 43874dc08..91756268b 100644 --- a/src/modelgauge/ensemble_annotator.py +++ b/src/modelgauge/ensemble_annotator.py @@ -15,11 +15,19 @@ class EnsembleAnnotator(Annotator): def __init__(self, uid, annotators: list[str], ensemble_strategy: str, secrets: RawSecrets | None = None): super().__init__(uid) - self.annotators = self._make_annotators(annotators, secrets) + self.annotator_uids: list[str] = annotators + self._secrets: RawSecrets | None = secrets + self._annotators: dict[str, Annotator] = {} if ensemble_strategy not in ENSEMBLE_STRATEGIES: raise ValueError(f"Ensemble strategy {ensemble_strategy} not recognized.") self.ensemble_strategy = ENSEMBLE_STRATEGIES[ensemble_strategy] + @property + def annotators(self) -> dict[str, Annotator]: + if len(self._annotators) != len(self.annotator_uids): + self._annotators = self._make_annotators(self.annotator_uids, self._secrets) + return self._annotators + def _make_annotators(self, annotator_uids: list[str], secrets: RawSecrets | None) -> dict[str, Annotator]: if secrets is None: secrets = load_secrets_from_config() @@ -53,3 +61,10 @@ def translate_response(self, request: dict[str, Any], response: dict[str, Any]): joined_responses=annotations, metadata=ensemble_annotation.metadata, # TODO: Merge metadata here instead of in strategy ) + + +def get_annotator_component_ids(annotator: Annotator) -> list[str]: + if isinstance(annotator, EnsembleAnnotator): + return annotator.annotator_uids + else: + return [annotator.uid]