-
Notifications
You must be signed in to change notification settings - Fork 28
Prep for annotator logprobs #1449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2b0c1f0
ed7ccf7
91da439
83b5837
14e585e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| from typing import Optional, List, Sequence | ||
|
|
||
| from pydantic import BaseModel, model_validator | ||
|
|
||
|
|
||
| class ModelOptions(BaseModel): | ||
| """ | ||
| An exhaustive set of options that could potentially be desired by a model. | ||
|
|
||
| Not all SUTs and annotators respect all options. | ||
| """ | ||
|
|
||
| max_tokens: Optional[int] = None | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This no longer defaults to 100. The CLI sets the default instead. |
||
| """Maximum number of tokens to generate (per completion)""" | ||
|
|
||
| max_total_output_tokens: Optional[int] = None | ||
| """Maximum number of tokens for all generated SUT outputs, including reasoning.""" | ||
|
|
||
| temperature: Optional[float] = None | ||
| """Temperature parameter that governs diversity""" | ||
|
|
||
| top_k_per_token: Optional[int] = None | ||
| """Take this many highest probability candidates per token in the completion""" | ||
|
|
||
| stop_sequences: Optional[List[str]] = None | ||
| """Stop generating once we hit one of these strings.""" | ||
|
|
||
| top_p: Optional[float] = None | ||
| """Sample from tokens that occupy this probability mass (nucleus sampling)""" | ||
|
|
||
| presence_penalty: Optional[float] = None | ||
| """Penalize repetition (OpenAI & Writer only)""" | ||
|
|
||
| frequency_penalty: Optional[float] = None | ||
| """Penalize repetition (OpenAI & Writer only)""" | ||
|
|
||
| random: Optional[str] = None | ||
| """Used to control randomness. Expect different responses for the same | ||
| request but with different values for `random`.""" | ||
|
|
||
| # Must specify SUTCapabilities for these | ||
| top_logprobs: Optional[int] = None | ||
| """If present, will request the log probabilities for this | ||
| many of the top tokens at each token position.""" | ||
|
|
||
| @model_validator(mode="after") | ||
| def check_max_total_output_tokens(self): | ||
| if ( | ||
| self.max_total_output_tokens is not None | ||
| and self.max_tokens is not None | ||
| and self.max_total_output_tokens < self.max_tokens | ||
| ): | ||
| raise ValueError( | ||
| f"Invalid ModelOptions. max_total_output_tokens ({self.max_total_output_tokens}) must be >= max_tokens ({self.max_tokens})." | ||
| ) | ||
| return self | ||
|
|
||
| @staticmethod | ||
| def create_from_arguments(max_tokens=None, temp=None, top_p=None, top_k=None, top_logprobs=None): | ||
| options = ModelOptions() | ||
| if max_tokens is not None: | ||
| options.max_tokens = max_tokens | ||
| if temp is not None: | ||
| options.temperature = temp | ||
| if top_p is not None: | ||
| options.top_p = top_p | ||
| if top_k is not None: | ||
| options.top_k_per_token = top_k | ||
| if top_logprobs is not None: | ||
| options.top_logprobs = top_logprobs | ||
|
|
||
| return options | ||
|
|
||
|
|
||
| class TokenProbability(BaseModel): | ||
| """Probability assigned to a given token.""" | ||
|
|
||
| token: str | ||
| logprob: float | ||
|
|
||
|
|
||
| class TopTokens(BaseModel): | ||
| """List of most likely tokens and their probabilities.""" | ||
|
|
||
| top_tokens: Sequence[TokenProbability] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,7 @@ | |
| from modelgauge.pipeline import Pipeline | ||
| from modelgauge.prompt_pipeline import PromptSink, PromptSource, PromptSutAssigner, PromptSutWorkers | ||
| from modelgauge.ready import ReadyResponses, Readyable | ||
| from modelgauge.sut import SUTOptions | ||
| from modelgauge.model_options import ModelOptions | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
@@ -26,14 +26,12 @@ def __init__( | |
| input_dataset, | ||
| output_dir, | ||
| cache_dir=None, | ||
| sut_options=SUTOptions(), | ||
| tag=None, | ||
| ): | ||
| self.num_workers = num_workers | ||
| self.input_dataset = input_dataset | ||
| self.root_dir = output_dir | ||
| self.cache_dir = cache_dir | ||
| self.sut_options = sut_options | ||
| self.tag = tag | ||
| self.pipeline_segments = [] | ||
| self.start_time = datetime.datetime.now() | ||
|
|
@@ -120,7 +118,9 @@ def _write_metadata(self): | |
|
|
||
|
|
||
| class PromptRunner(PipelineRunner): | ||
| def __init__(self, suts, **kwargs): | ||
| def __init__(self, suts, sut_options=ModelOptions(), **kwargs): | ||
| self.sut_options = sut_options | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I moved this out of the |
||
| logger.info(f"Using SUT options: {self.sut_options}") | ||
| self.suts = suts | ||
| self.sut_worker = None # Convenience pointer. | ||
| super().__init__(**kwargs) | ||
|
|
@@ -278,6 +278,7 @@ def build_runner( | |
| sut_uid_col=None, | ||
| sut_response_col=None, | ||
| jailbreak=False, | ||
| sut_options=None, | ||
| **kwargs, | ||
| ): | ||
| if jailbreak and not (annotators and suts): | ||
|
|
@@ -304,9 +305,11 @@ def build_runner( | |
| ) | ||
| # Build runner | ||
| if suts and annotators: | ||
| pipeline_runner = PromptPlusAnnotatorRunner(suts=suts, annotators=annotators, input_dataset=dataset, **kwargs) | ||
| pipeline_runner = PromptPlusAnnotatorRunner( | ||
| suts=suts, annotators=annotators, input_dataset=dataset, sut_options=sut_options, **kwargs | ||
| ) | ||
| elif suts: | ||
| pipeline_runner = PromptRunner(suts=suts, input_dataset=dataset, **kwargs) | ||
| pipeline_runner = PromptRunner(suts=suts, input_dataset=dataset, sut_options=sut_options, **kwargs) | ||
| elif annotators: | ||
| pipeline_runner = AnnotatorRunner(annotators=annotators, input_dataset=dataset, **kwargs) | ||
| else: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Everything in this file was moved from
sut.py