Skip to content

Commit c59bf58

Browse files
authored
Refactors EnsembleAnnotatorSet initialization. (#1071)
* Refactors `EnsembleAnnotatorSet` initialization. * Fix typing. * Renames and relocates safety model response module.
1 parent 517fb99 commit c59bf58

3 files changed

Lines changed: 70 additions & 27 deletions

File tree

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Dict, List
3+
4+
from modelgauge.annotator_set import AnnotatorSet
5+
from modelgauge.safety_model_response import EnsembleSafetyModelResponse, SafetyModelResponse
6+
from modelgauge.single_turn_prompt_response import SUTResponseAnnotations
7+
8+
9+
class EnsembleStrategy(ABC):
    """Interface for combining several safety model responses into one.

    Subclasses must implement ``compute_response``.
    """

    @abstractmethod
    def compute_response(self, annotations: Dict[str, SafetyModelResponse]) -> EnsembleSafetyModelResponse:
        """Compute a single ensemble response from ``annotations``.

        Args:
            annotations: Individual safety model responses, keyed by string.

        Returns:
            The combined ``EnsembleSafetyModelResponse``.
        """
14+
15+
16+
# Registry of ensemble strategies, keyed by string; starts out empty.
ENSEMBLE_STRATEGIES: Dict[str, EnsembleStrategy] = dict()
18+
19+
20+
class EnsembleAnnotatorSet(AnnotatorSet):
    """Annotator set that reduces several annotators' responses to one verdict.

    The reduction itself is delegated to the ``EnsembleStrategy`` supplied at
    construction time.
    """

    # Class-level default; every instance replaces it in ``__init__``.
    annotators: List[str] = []

    def __init__(self, annotators: List[str], strategy: EnsembleStrategy):
        """Store the annotator ids to consult and the strategy that combines them."""
        self.annotators = annotators
        self.strategy = strategy

    def evaluate(self, item: SUTResponseAnnotations) -> Dict[str, float]:
        """Score ``item`` using the ensemble strategy.

        Annotator ids from ``self.annotators`` that have no entry in
        ``item.annotations`` are left out of the strategy's input.

        Returns:
            A single-entry dict mapping ``"is_safe"`` to the ensemble
            response's ``is_safe`` value converted to ``float``.
        """
        responses = {
            name: item.get_annotation(name, SafetyModelResponse)
            for name in self.annotators
            if name in item.annotations
        }
        ensemble_response = self.strategy.compute_response(responses)
        return {"is_safe": float(ensemble_response.is_safe)}
Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
import os
2-
from typing import Any, Dict, List
3-
4-
from modelgauge.annotator_set import AnnotatorSet
5-
from modelgauge.single_turn_prompt_response import SUTResponseAnnotations
1+
from modelgauge.ensemble_annotator_set import EnsembleAnnotatorSet
62

73
try:
84
from modelgauge.annotators.llama_3_70b_chat_annotator import (
@@ -23,25 +19,12 @@
2319
raise NotImplementedError("Private annotators are not available.")
2420

2521

26-
class EnsembleAnnotatorSet(AnnotatorSet):
27-
annotators: List[str] = []
28-
29-
def __init__(self):
30-
self.strategy = MajorityVoteEnsembleStrategy()
31-
self.annotators = [
32-
MISTRAL_7B_RUBY_CONFIG.name,
33-
MISTRAL_8x22B_CONFIG.name,
34-
LLAMA_3_70B_CONFIG.name,
35-
LG3_LORA_CONFIG.name,
36-
]
37-
38-
def evaluate(self, item: SUTResponseAnnotations) -> Dict[str, float]:
39-
individual_annotations = {}
40-
for annotator_id in self.annotators:
41-
if annotator_id in item.annotations:
42-
individual_annotations[annotator_id] = item.get_annotation(annotator_id, SafetyModelResponse)
43-
annotation = self.strategy.compute_response(individual_annotations)
44-
return {"is_safe": float(annotation.is_safe)}
45-
46-
47-
PRIVATE_ANNOTATOR_SET = EnsembleAnnotatorSet()
22+
# Ensemble over the four private annotator configs, combined with
# MajorityVoteEnsembleStrategy. Names are listed in the order they are consulted.
PRIVATE_ANNOTATOR_SET = EnsembleAnnotatorSet(
    annotators=[
        config.name
        for config in (
            MISTRAL_7B_RUBY_CONFIG,
            MISTRAL_8x22B_CONFIG,
            LLAMA_3_70B_CONFIG,
            LG3_LORA_CONFIG,
        )
    ],
    strategy=MajorityVoteEnsembleStrategy(),
)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Any, Dict, List, Optional
2+
3+
from pydantic import BaseModel
4+
5+
6+
class SafetyModelResponse(BaseModel):
    """Pydantic model that all safety model responses should adhere to."""

    # Required fields: callers must always supply these two.
    is_safe: bool
    safety_categories: List[str]

    # Optional fields: ``is_valid`` defaults to True, the rest default to None.
    is_valid: bool = True
    raw_input: Optional[str] = None
    raw_response: Optional[str] = None
    reasoning: Optional[str] = None
    total_tokens_used: Optional[float] = None
    logprobs: Optional[Any] = None
17+
18+
19+
class EnsembleSafetyModelResponse(SafetyModelResponse):
    """Safety model response computed by an ensemble.

    Extends ``SafetyModelResponse`` with metadata about the original
    responses.

    Attributes:
        joined_responses: The individual ``SafetyModelResponse`` objects,
            keyed by string.
    """

    joined_responses: Dict[str, SafetyModelResponse]

0 commit comments

Comments
 (0)