Skip to content

Commit e0854a6

Browse files
google-genai-bot authored and copybara-github committed
feat: Support ReinforcementTuning in GenAI SDK including ValidateReward API method.
PiperOrigin-RevId: 926249899
1 parent 53ea3f6 commit e0854a6

3 files changed

Lines changed: 1938 additions & 269 deletions

File tree

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
17+
"""Tests for tunings.validate_reward()."""
18+
19+
from ... import types as genai_types
20+
from .. import pytest_helper
21+
22+
# Table of request fixtures for the "tunings.validate_reward" test method.
# Every case targets the same parent location and carries the same
# `exception_if_mldev` message; the cases differ only in the example payload
# and in which reward configuration is supplied.
# NOTE(review): `exception_if_mldev` is presumably the error text expected when
# the call runs in mldev mode -- confirm against pytest_helper.
test_table: list[pytest_helper.TestTableItem] = [
    # Case 1: single reward config holding only an autorater scorer; the
    # example has no references and the config has no reward name.
    pytest_helper.TestTableItem(
        name="test_validate_reward_single_autorater",
        parameters=genai_types._ValidateRewardParameters(
            parent="projects/801452371447/locations/us-central1",
            sample_response=genai_types.Content(
                parts=[genai_types.Part(text="The answer is 42.")]
            ),
            example=genai_types.ReinforcementTuningExample(
                contents=[
                    genai_types.Content(
                        parts=[
                            genai_types.Part(text="What is the answer to life?")
                        ]
                    )
                ],
            ),
            single_reward_config=genai_types.SingleReinforcementTuningRewardConfig(
                autorater_scorer=genai_types.ReinforcementTuningAutoraterScorer(
                    autorater_config=genai_types.AutoraterConfig(
                        autorater_model="test-model"
                    )
                )
            ),
        ),
        exception_if_mldev=(
            "only supported in Gemini Enterprise Agent Platform mode"
        ),
    ),
    # Case 2: single reward config with a code-execution scorer and IDENTITY
    # response parsing; the snippet compares the response to the "reference"
    # entry supplied in the example's references.
    pytest_helper.TestTableItem(
        name="test_validate_reward_code_execution",
        parameters=genai_types._ValidateRewardParameters(
            parent="projects/801452371447/locations/us-central1",
            sample_response=genai_types.Content(
                parts=[genai_types.Part(text="print('hello')")]
            ),
            example=genai_types.ReinforcementTuningExample(
                contents=[
                    genai_types.Content(
                        parts=[
                            genai_types.Part(
                                text="Write a hello world program."
                            )
                        ]
                    )
                ],
                references={"reference": "hello"},
            ),
            single_reward_config=genai_types.SingleReinforcementTuningRewardConfig(
                reward_name="codeExecReward",
                parse_response_config=genai_types.ReinforcementTuningParseResponseConfig(
                    parse_type="IDENTITY",
                ),
                code_execution_reward_scorer=genai_types.ReinforcementTuningCodeExecutionRewardScorer(
                    python_code_snippet=(
                        "reward = 1.0 if response == references['reference']"
                        " else 0.0"
                    ),
                ),
            ),
        ),
        exception_if_mldev=(
            "only supported in Gemini Enterprise Agent Platform mode"
        ),
    ),
    # Case 3: single reward config with a string-match scorer and
    # REGEX_EXTRACT parsing; the example also carries references and a
    # system instruction.
    pytest_helper.TestTableItem(
        name="test_validate_reward_string_match",
        parameters=genai_types._ValidateRewardParameters(
            parent="projects/801452371447/locations/us-central1",
            sample_response=genai_types.Content(
                parts=[genai_types.Part(text="42")]
            ),
            example=genai_types.ReinforcementTuningExample(
                contents=[
                    genai_types.Content(
                        parts=[genai_types.Part(text="What is 6 times 7?")]
                    )
                ],
                references={"answer": "42"},
                system_instruction=genai_types.Content(
                    parts=[genai_types.Part(text="You are a math tutor.")]
                ),
            ),
            single_reward_config=genai_types.SingleReinforcementTuningRewardConfig(
                reward_name="stringMatchReward",
                parse_response_config=genai_types.ReinforcementTuningParseResponseConfig(
                    parse_type="REGEX_EXTRACT",
                    regex_extract_expression=r"(\d+)",
                ),
                string_match_reward_scorer=genai_types.ReinforcementTuningStringMatchRewardScorer(
                    correct_answer_reward=1.0,
                    wrong_answer_reward=-1.0,
                    string_match_expression=genai_types.ReinforcementTuningStringMatchRewardScorerStringMatchExpression(
                        match_operation="EXACT_MATCH",
                        expression="{{references.answer}}",
                    ),
                ),
            ),
        ),
        exception_if_mldev=(
            "only supported in Gemini Enterprise Agent Platform mode"
        ),
    ),
    # Case 4: composite reward config combining two weighted single configs:
    # an autorater scorer with an exact-match scorer (weight 0.7) and a
    # code-execution scorer (weight 0.3).
    pytest_helper.TestTableItem(
        name="test_validate_reward_composite",
        parameters=genai_types._ValidateRewardParameters(
            parent="projects/801452371447/locations/us-central1",
            sample_response=genai_types.Content(
                parts=[genai_types.Part(text="The answer is 42.")]
            ),
            example=genai_types.ReinforcementTuningExample(
                contents=[
                    genai_types.Content(
                        parts=[
                            genai_types.Part(text="What is the answer to life?")
                        ]
                    )
                ],
            ),
            composite_reward_config=genai_types.CompositeReinforcementTuningRewardConfig(
                weighted_reward_configs=[
                    genai_types.CompositeReinforcementTuningRewardConfigWeightedRewardConfig(
                        weight=0.7,
                        reward_config=genai_types.SingleReinforcementTuningRewardConfig(
                            reward_name="autoraterReward",
                            autorater_scorer=genai_types.ReinforcementTuningAutoraterScorer(
                                autorater_config=genai_types.AutoraterConfig(
                                    autorater_model="test-model"
                                ),
                                autorater_prompt=(
                                    "Rate the response: {{response}}"
                                ),
                                exact_match_scorer=genai_types.ReinforcementTuningAutoraterScorerExactMatchScorer(
                                    correct_answer_reward=1.0,
                                    wrong_answer_reward=0.0,
                                    expression="good",
                                ),
                            ),
                        ),
                    ),
                    genai_types.CompositeReinforcementTuningRewardConfigWeightedRewardConfig(
                        weight=0.3,
                        reward_config=genai_types.SingleReinforcementTuningRewardConfig(
                            reward_name="codeReward",
                            code_execution_reward_scorer=genai_types.ReinforcementTuningCodeExecutionRewardScorer(
                                python_code_snippet="reward = 1.0",
                            ),
                        ),
                    ),
                ]
            ),
        ),
        exception_if_mldev=(
            "only supported in Gemini Enterprise Agent Platform mode"
        ),
    ),
]
179+
180+
# Hand the case table to the shared helper, naming the client method under
# test. The helper also receives this module's path and its globals dict.
# NOTE(review): setup() presumably uses these to generate the parametrized
# tests for this module -- confirm against pytest_helper.
pytestmark = pytest_helper.setup(
    test_method="tunings.validate_reward",
    test_table=test_table,
    file=__file__,
    globals_for_file=globals(),
)

# Plugin declaration read by pytest: requests the pytest_asyncio plugin.
pytest_plugins = ("pytest_asyncio",)

0 commit comments

Comments
 (0)