|
1 | | -import os |
2 | | -from typing import Any, Dict, List |
3 | | - |
4 | | -from modelgauge.annotator_set import AnnotatorSet |
5 | | -from modelgauge.single_turn_prompt_response import SUTResponseAnnotations |
| 1 | +from modelgauge.ensemble_annotator_set import EnsembleAnnotatorSet |
6 | 2 |
|
7 | 3 | try: |
8 | 4 | from modelgauge.annotators.llama_3_70b_chat_annotator import ( |
|
23 | 19 | raise NotImplementedError("Private annotators are not available.") |
24 | 20 |
|
25 | 21 |
|
26 | | -class EnsembleAnnotatorSet(AnnotatorSet): |
27 | | - annotators: List[str] = [] |
28 | | - |
29 | | - def __init__(self): |
30 | | - self.strategy = MajorityVoteEnsembleStrategy() |
31 | | - self.annotators = [ |
32 | | - MISTRAL_7B_RUBY_CONFIG.name, |
33 | | - MISTRAL_8x22B_CONFIG.name, |
34 | | - LLAMA_3_70B_CONFIG.name, |
35 | | - LG3_LORA_CONFIG.name, |
36 | | - ] |
37 | | - |
38 | | - def evaluate(self, item: SUTResponseAnnotations) -> Dict[str, float]: |
39 | | - individual_annotations = {} |
40 | | - for annotator_id in self.annotators: |
41 | | - if annotator_id in item.annotations: |
42 | | - individual_annotations[annotator_id] = item.get_annotation(annotator_id, SafetyModelResponse) |
43 | | - annotation = self.strategy.compute_response(individual_annotations) |
44 | | - return {"is_safe": float(annotation.is_safe)} |
45 | | - |
46 | | - |
47 | | -PRIVATE_ANNOTATOR_SET = EnsembleAnnotatorSet() |
| 22 | +PRIVATE_ANNOTATOR_SET = EnsembleAnnotatorSet( |
| 23 | + annotators=[ |
| 24 | + MISTRAL_7B_RUBY_CONFIG.name, |
| 25 | + MISTRAL_8x22B_CONFIG.name, |
| 26 | + LLAMA_3_70B_CONFIG.name, |
| 27 | + LG3_LORA_CONFIG.name, |
| 28 | + ], |
| 29 | + strategy=MajorityVoteEnsembleStrategy(), |
| 30 | +) |
0 commit comments