|
21 | 21 | from concurrent import futures |
22 | 22 | from typing import Callable, Optional, Set, TYPE_CHECKING, Union, List |
23 | 23 |
|
| 24 | +from google import genai |
| 25 | +from google.cloud import aiplatform |
24 | 26 | from google.cloud.aiplatform import base |
25 | 27 | from google.cloud.aiplatform_v1beta1.types import ( |
26 | 28 | content as gapic_content_types, |
@@ -70,6 +72,107 @@ def _assemble_prompt( |
70 | 72 | ) |
71 | 73 |
|
72 | 74 |
|
def _generate_content_text_response_genai(
    model: str, client: genai.Client, prompt: str, max_retries: int = 3
) -> str:
    """Gets a text completion for `prompt` from a Gemini model via the genai SDK.

    Each attempt either returns the model's text, or logs a warning (empty/
    blocked response, or a raised SDK error) and tries again, up to
    `max_retries` attempts in total.

    Args:
        model: The model name string.
        client: The genai client instance.
        prompt: The prompt to send to the model.
        max_retries: Maximum number of retries for response generation.

    Returns:
        The model's text response, or constants.RESPONSE_ERROR once every
        attempt has failed.
    """
    for attempt in range(max_retries):
        try:
            result = client.models.generate_content(
                model=model,
                contents=prompt,
            )
        except Exception as e:  # pylint: disable=broad-except
            # The new SDK surfaces blocked content as raised exceptions rather
            # than a block_reason field, so failures funnel through here.
            _LOGGER.warning(
                f"Failed to generate response candidates from GenAI model "
                f"{model}.\n"
                f"Error: {e}.\n"
                f"Prompt: {prompt}.\n"
                f"Retry attempt: {attempt + 1}/{max_retries}"
            )
        else:
            if result.text:
                return result.text
            _LOGGER.warning(
                "The model response was empty or blocked.\n"
                f"Prompt: {prompt}.\n"
                f"Retry attempt: {attempt + 1}/{max_retries}"
            )
        if attempt + 1 < max_retries:
            _LOGGER.info(
                f"Retrying response generation for prompt: {prompt}, attempt "
                f"{attempt + 1}/{max_retries}..."
            )

    _LOGGER.warning(
        f"Failed to generate response from GenAI model {model}.\n" f"Prompt: {prompt}."
    )
    return constants.RESPONSE_ERROR
| 126 | + |
| 127 | + |
def _generate_responses_from_genai_model(
    model: str,
    df: "pd.DataFrame",
    rubric_generation_prompt_template: Optional[str] = None,
) -> List[str]:
    """Generates responses from Google GenAI SDK for the given evaluation dataset."""
    _LOGGER.info(
        f"Generating a total of {df.shape[0]} "
        f"responses from Google GenAI model {model}."
    )
    genai_client = genai.Client(
        vertexai=True,
        project=aiplatform.initializer.global_config.project,
        location=aiplatform.initializer.global_config.location,
    )

    def _build_prompt(row_idx, row):
        # With no template, the dataset's prompt column is sent verbatim;
        # otherwise the template is filled from the row (multimodal-aware).
        if not rubric_generation_prompt_template:
            return row[constants.Dataset.PROMPT_COLUMN]
        template_vars = prompt_template_base.PromptTemplate(
            rubric_generation_prompt_template
        ).variables
        if multimodal_utils.is_multimodal_instance(
            row[list(template_vars)].to_dict()
        ):
            return multimodal_utils._assemble_multi_modal_prompt(
                rubric_generation_prompt_template, row, row_idx, template_vars
            )
        return _assemble_prompt(row, rubric_generation_prompt_template)

    pending = []
    with tqdm(total=len(df)) as progress:
        with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
            for row_idx, row in df.iterrows():
                future = executor.submit(
                    _generate_content_text_response_genai,
                    prompt=_build_prompt(row_idx, row),
                    model=model,
                    client=genai_client,
                )
                # Tick the progress bar as each request completes.
                future.add_done_callback(lambda _: progress.update(1))
                pending.append(future)
        # Executor context exit has already joined all workers here.
        return [future.result() for future in pending]
| 174 | + |
| 175 | + |
73 | 176 | def _generate_content_text_response( |
74 | 177 | model: generative_models.GenerativeModel, prompt: str, max_attempts: int = 3 |
75 | 178 | ) -> str: |
|
0 commit comments