Skip to content

Commit 3bc3cdb

Browse files
authored
New Google sdk (#1054)
1 parent 4b42dfa commit 3bc3cdb

6 files changed

Lines changed: 435 additions & 41 deletions

File tree

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""
2+
This file defines google SUTs that use Google's genai python SDK.
3+
"""
4+
5+
from typing import Optional
6+
7+
from google import genai
8+
from google.api_core.exceptions import (
9+
InternalServerError,
10+
ResourceExhausted,
11+
RetryError,
12+
TooManyRequests,
13+
)
14+
from google.genai.types import GenerateContentConfig, GenerateContentResponse, ThinkingConfig
15+
from pydantic import BaseModel
16+
17+
from modelgauge.general import APIException
18+
from modelgauge.prompt import TextPrompt
19+
from modelgauge.retry_decorator import retry
20+
from modelgauge.secret_values import InjectSecret
21+
from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTOptions, SUTResponse # usort: skip
22+
from modelgauge.suts.google_generativeai import (
23+
GOOGLE_REFUSAL_FINISH_REASONS,
24+
GoogleAiApiKey,
25+
) # Both SDKs use the same API key.
26+
from modelgauge.sut_capabilities import AcceptsTextPrompt
27+
from modelgauge.sut_decorator import modelgauge_sut
28+
from modelgauge.sut_registry import SUTS
29+
30+
31+
class GenAiRequest(BaseModel):
    """Request payload for a single `generate_content` call made through the genai SDK.

    The field names match the keyword arguments the request is unpacked into when it is
    sent (see `model_dump` usage in the SUT's `evaluate`).
    """

    # Name of the model to query.
    model: str
    # Prompt text sent as the request contents.
    contents: str
    # Generation settings; left out of the dumped request when None.
    config: Optional[GenerateContentConfig] = None
35+
36+
37+
@modelgauge_sut(capabilities=[AcceptsTextPrompt])
class GoogleGenAiSUT(PromptResponseSUT[GenAiRequest, GenerateContentResponse]):
    """Text-prompt SUT that talks to Google models through the `genai` SDK client.

    The SDK client is created lazily on the first call to `evaluate`. When `reasoning`
    is False, every request carries a `ThinkingConfig` with a thinking budget of 0.
    """

    def __init__(self, uid: str, model_name: str, reasoning: bool, api_key: GoogleAiApiKey):
        """Store the model settings and API key; the client itself is not created here."""
        super().__init__(uid)
        self.model_name = model_name
        self.client: Optional[genai.Client] = None
        self.reasoning = reasoning
        self.api_key = api_key.value

    def _load_client(self) -> genai.Client:
        """Build a genai client authenticated with the stored API key."""
        return genai.Client(api_key=self.api_key)

    def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> GenAiRequest:
        """Convert a text prompt and SUT options into a `GenAiRequest`."""
        optional = {}
        if not self.reasoning:
            optional["thinking_config"] = ThinkingConfig(
                thinking_budget=0,  # Turn off reasoning.
            )
        generation_config = GenerateContentConfig(
            stop_sequences=options.stop_sequences,
            max_output_tokens=options.max_tokens,
            temperature=options.temperature,
            top_p=options.top_p,
            top_k=options.top_k_per_token,
            presence_penalty=options.presence_penalty,
            frequency_penalty=options.frequency_penalty,
            **optional,
        )
        return GenAiRequest(
            model=self.model_name,
            contents=prompt.text,
            config=generation_config,
        )

    # NOTE(review): these exception classes come from `google.api_core`; the genai SDK also
    # ships its own `google.genai.errors` hierarchy. Confirm which exceptions
    # `generate_content` actually raises so that transient failures are really retried.
    @retry(transient_exceptions=[InternalServerError, ResourceExhausted, RetryError, TooManyRequests])
    def evaluate(self, request: GenAiRequest) -> GenerateContentResponse:
        """Send the request to the model, creating the client on first use."""
        if self.client is None:
            # Handle lazy init.
            self.client = self._load_client()
        return self.client.models.generate_content(**request.model_dump(exclude_none=True))

    def translate_response(self, request: GenAiRequest, response: GenerateContentResponse) -> SUTResponse:
        """Turn the SDK response into a `SUTResponse`.

        A response without candidates, or whose candidate finished for a refusal reason,
        is reported as `REFUSAL_RESPONSE`.

        Raises:
            AssertionError: if the response holds more than one candidate.
            APIException: if the candidate has no usable content and its finish reason
                is not a refusal.
        """
        # `candidates` is an optional field on the SDK response, so a missing list (None)
        # is treated the same as an empty one instead of crashing in `len()`.
        candidates = response.candidates or []
        assert len(candidates) <= 1, f"Expected a single candidate in the response, got {len(candidates)}."
        if not candidates:
            # This is apparently a refusal. At least, it's what happens consistently with a set of
            # prompts in the CSE, SRC, and SXC hazards
            return SUTResponse(text=REFUSAL_RESPONSE)

        candidate = candidates[0]
        # NOTE(review): GOOGLE_REFUSAL_FINISH_REASONS is imported from the older SDK's module;
        # confirm its members compare equal to the finish_reason values this SDK returns.
        if candidate.finish_reason in GOOGLE_REFUSAL_FINISH_REASONS:
            response_text = REFUSAL_RESPONSE
        elif candidate.content is not None and candidate.content.parts:
            response_text = candidate.content.parts[0].text
        else:
            # Content that is absent, or present without any parts, has no text to return.
            raise APIException(
                f"Unexpected candidate in response from GoogleGenAiSUT {self.uid}: {candidate}. "
                f"The candidate does not have any content,"
                f" but it's finish reason {candidate.finish_reason} does not qualify as a refusal."
            )

        return SUTResponse(text=response_text)
100+
101+
102+
# Models registered as genai-SDK SUTs. Each one is registered once, with reasoning turned off.
models = ["gemini-2.5-flash-preview-05-20"]
for model in models:
    # Positional arguments after the class and uid are forwarded as the SUT's
    # model_name, reasoning flag, and API-key secret.
    SUTS.register(GoogleGenAiSUT, f"google-genai-{model}-no-reasoning", model, False, InjectSecret(GoogleAiApiKey))

plugins/google/modelgauge/suts/google_genai_client.py renamed to plugins/google/modelgauge/suts/google_generativeai.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
"""
2+
This file defines google SUTs that use Google's generativeai python SDK.
3+
This SDK is older and does not include the latest features, e.g. reasoning configuration. The generativeai SDK will be deprecated in the future.
4+
The SUTs defined in this file should be migrated to Google's newer `genai` SDK, as implemented in google_genai.py.
5+
"""
6+
17
from abc import abstractmethod
28
from typing import Dict, List, Optional
39

@@ -49,8 +55,8 @@ def description(cls) -> SecretDescription:
4955
)
5056

5157

52-
class GoogleGenAiConfig(BaseModel):
53-
"""Generation config for Google Gen AI requests.
58+
class GenerativeAiConfig(BaseModel):
59+
"""Generation config for Google Generative AI requests.
5460
5561
Based on https://ai.google.dev/api/generate-content#v1beta.GenerationConfig
5662
"""
@@ -64,13 +70,13 @@ class GoogleGenAiConfig(BaseModel):
6470
frequency_penalty: Optional[float] = None
6571

6672

67-
class GoogleGenAiRequest(BaseModel):
73+
class GenerativeAiRequest(BaseModel):
6874
contents: str
69-
generation_config: GoogleGenAiConfig
75+
generation_config: GenerativeAiConfig
7076
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
7177

7278

73-
class GoogleGenAiResponse(BaseModel):
79+
class GenerativeAiResponse(BaseModel):
7480
class Candidate(BaseModel):
7581
content: Optional[Dict] = None
7682
finish_reason: int
@@ -79,7 +85,7 @@ class Candidate(BaseModel):
7985
usage_metadata: Dict
8086

8187

82-
class GoogleGenAiBaseSUT(PromptResponseSUT[GoogleGenAiRequest, GoogleGenAiResponse]):
88+
class GoogleGenerativeAiBaseSUT(PromptResponseSUT[GenerativeAiRequest, GenerativeAiResponse]):
8389
def __init__(self, uid: str, model_name: str, api_key: GoogleAiApiKey):
8490
super().__init__(uid)
8591
self.model_name = model_name
@@ -101,8 +107,8 @@ def safety_settings(self) -> Optional[Dict[HarmCategory, HarmBlockThreshold]]:
101107
def _load_client(self) -> genai.GenerativeModel:
102108
return genai.GenerativeModel(self.model_name)
103109

104-
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> GoogleGenAiRequest:
105-
generation_config = GoogleGenAiConfig(
110+
def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> GenerativeAiRequest:
111+
generation_config = GenerativeAiConfig(
106112
stop_sequences=options.stop_sequences,
107113
max_output_tokens=options.max_tokens,
108114
temperature=options.temperature,
@@ -111,20 +117,20 @@ def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> Goog
111117
presence_penalty=options.presence_penalty,
112118
frequency_penalty=options.frequency_penalty,
113119
)
114-
return GoogleGenAiRequest(
120+
return GenerativeAiRequest(
115121
contents=prompt.text, generation_config=generation_config, safety_settings=self.safety_settings
116122
)
117123

118124
@retry(transient_exceptions=[InternalServerError, ResourceExhausted, RetryError, TooManyRequests])
119-
def evaluate(self, request: GoogleGenAiRequest) -> GoogleGenAiResponse:
125+
def evaluate(self, request: GenerativeAiRequest) -> GenerativeAiResponse:
120126
if self.model is None:
121127
# Handle lazy init.
122128
self.model = self._load_client()
123129
response = self.model.generate_content(**request.model_dump(exclude_none=True))
124130
# Convert to pydantic model
125-
return GoogleGenAiResponse(**response.to_dict())
131+
return GenerativeAiResponse(**response.to_dict())
126132

127-
def translate_response(self, request: GoogleGenAiRequest, response: GoogleGenAiResponse) -> SUTResponse:
133+
def translate_response(self, request: GenerativeAiRequest, response: GenerativeAiResponse) -> SUTResponse:
128134
assert (
129135
len(response.candidates) <= 1
130136
), f"Expected a single candidate in the response, got {len(response.candidates)}."
@@ -140,7 +146,7 @@ def translate_response(self, request: GoogleGenAiRequest, response: GoogleGenAiR
140146
response_text = candidate.content["parts"][0]["text"]
141147
else:
142148
raise APIException(
143-
f"Unexpected candidate in response from GoogleGenAiSUT {self.uid}: {candidate}. "
149+
f"Unexpected candidate in response from GoogleGenerativeAiSUT {self.uid}: {candidate}. "
144150
f"The candidate does not have any content,"
145151
f" but it's finish reason {candidate.finish_reason} does not qualify as a refusal."
146152
)
@@ -149,7 +155,7 @@ def translate_response(self, request: GoogleGenAiRequest, response: GoogleGenAiR
149155

150156

151157
@modelgauge_sut(capabilities=[AcceptsTextPrompt])
152-
class GoogleGenAiDefaultSUT(GoogleGenAiBaseSUT):
158+
class GoogleGenerativeAiDefaultSUT(GoogleGenerativeAiBaseSUT):
153159
"""SUT for Google Generative AI model with the model's default safety settings.
154160
As of 11/20/2024: The default settings are:
155161
"Block most (for gemini-1.5-pro-002 and gemini-1.5-flash-002 only) or Block some (in all other models)
@@ -168,7 +174,7 @@ def safety_settings(self) -> Optional[Dict[HarmCategory, HarmBlockThreshold]]:
168174

169175

170176
@modelgauge_sut(capabilities=[AcceptsTextPrompt])
171-
class GoogleGeminiDisabledSafetySettingsSUT(GoogleGenAiBaseSUT):
177+
class GoogleGeminiDisabledSafetySettingsSUT(GoogleGenerativeAiBaseSUT):
172178
"""SUT for Google Gemini model that removes that harm block threshold for all Gemini-specific harm categories."""
173179

174180
@property
@@ -182,7 +188,7 @@ def safety_settings(self) -> Optional[Dict[HarmCategory, HarmBlockThreshold]]:
182188

183189

184190
@modelgauge_sut(capabilities=[AcceptsTextPrompt])
185-
class GoogleGenAiSafetyOnSUT(GoogleGenAiBaseSUT):
191+
class GoogleGenerativeAiSafetyOnSUT(GoogleGenerativeAiBaseSUT):
186192
"""SUT for Google Generative AI model with the explicit safety settings turned on (ie BLOCK_LOW_AND_ABOVE).
187193
188194
Finish reasons related to safety are treated as refusal responses."""
@@ -207,8 +213,8 @@ def safety_settings(self) -> Optional[Dict[HarmCategory, HarmBlockThreshold]]:
207213
"gemini-2.5-pro-preview-05-06",
208214
]
209215
for model in gemini_models:
210-
SUTS.register(GoogleGenAiDefaultSUT, model, model, InjectSecret(GoogleAiApiKey))
216+
SUTS.register(GoogleGenerativeAiDefaultSUT, model, model, InjectSecret(GoogleAiApiKey))
211217
SUTS.register(
212218
GoogleGeminiDisabledSafetySettingsSUT, f"{model}-safety_block_none", model, InjectSecret(GoogleAiApiKey)
213219
)
214-
SUTS.register(GoogleGenAiSafetyOnSUT, f"{model}-safety_block_most", model, InjectSecret(GoogleAiApiKey))
220+
SUTS.register(GoogleGenerativeAiSafetyOnSUT, f"{model}-safety_block_most", model, InjectSecret(GoogleAiApiKey))

plugins/google/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ packages = [{include = "modelgauge"}]
99
[tool.poetry.dependencies]
1010
python = "^3.10"
1111
google-generativeai = "^0.8.0"
12+
google-genai = "^1.17.0"
1213

1314

1415
[build-system]

0 commit comments

Comments
 (0)