Skip to content

Commit 40f3e41

Browse files
Ark-kuncopybara-github
authored andcommitted
chore: De-hardcoded model parameter defaults
Model interface classes support different models that might have different defaults for their parameters. SDK should not hardcode these parameters by default, letting the user to either use the model's defaults or explicitly override them. There was a recent similar case where the tuning parameter defaults were different for different tuning methods. PiperOrigin-RevId: 555129237
1 parent dec8ffd commit 40f3e41

2 files changed

Lines changed: 91 additions & 89 deletions

File tree

tests/unit/aiplatform/test_language_models.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,7 +1126,6 @@ def test_code_generation(self):
11261126
# Validating the parameters
11271127
predict_temperature = 0.1
11281128
predict_max_output_tokens = 100
1129-
default_temperature = language_models.CodeGenerationModel._DEFAULT_TEMPERATURE
11301129
default_max_output_tokens = (
11311130
language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
11321131
)
@@ -1149,7 +1148,7 @@ def test_code_generation(self):
11491148
prefix="Write a function that checks if a year is a leap year.",
11501149
)
11511150
prediction_parameters = mock_predict.call_args[1]["parameters"]
1152-
assert prediction_parameters["temperature"] == default_temperature
1151+
assert "temperature" not in prediction_parameters
11531152
assert prediction_parameters["maxOutputTokens"] == default_max_output_tokens
11541153

11551154
def test_code_completion(self):
@@ -1192,7 +1191,6 @@ def test_code_completion(self):
11921191
# Validating the parameters
11931192
predict_temperature = 0.1
11941193
predict_max_output_tokens = 100
1195-
default_temperature = language_models.CodeGenerationModel._DEFAULT_TEMPERATURE
11961194
default_max_output_tokens = (
11971195
language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
11981196
)
@@ -1215,7 +1213,7 @@ def test_code_completion(self):
12151213
prefix="def reverse_string(s):",
12161214
)
12171215
prediction_parameters = mock_predict.call_args[1]["parameters"]
1218-
assert prediction_parameters["temperature"] == default_temperature
1216+
assert "temperature" not in prediction_parameters
12191217
assert prediction_parameters["maxOutputTokens"] == default_max_output_tokens
12201218

12211219
def test_text_embedding(self):

vertexai/language_models/_language_models.py

Lines changed: 89 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -237,28 +237,25 @@ class _TextGenerationModel(_LanguageModel):
237237

238238
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"
239239

240-
_DEFAULT_TEMPERATURE = 0.0
241240
_DEFAULT_MAX_OUTPUT_TOKENS = 128
242-
_DEFAULT_TOP_P = 0.95
243-
_DEFAULT_TOP_K = 40
244241

245242
def predict(
246243
self,
247244
prompt: str,
248245
*,
249-
max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
250-
temperature: float = _DEFAULT_TEMPERATURE,
251-
top_k: int = _DEFAULT_TOP_K,
252-
top_p: float = _DEFAULT_TOP_P,
246+
max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
247+
temperature: Optional[float] = None,
248+
top_k: Optional[int] = None,
249+
top_p: Optional[float] = None,
253250
) -> "TextGenerationResponse":
254251
"""Gets model response for a single prompt.
255252
256253
Args:
257254
prompt: Question to ask the model.
258-
max_output_tokens: Max length of the output text in tokens.
259-
temperature: Controls the randomness of predictions. Range: [0, 1].
260-
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
261-
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
255+
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
256+
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
257+
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
258+
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
262259
263260
Returns:
264261
A `TextGenerationResponse` object that contains the text produced by the model.
@@ -275,19 +272,19 @@ def predict(
275272
def _batch_predict(
276273
self,
277274
prompts: List[str],
278-
max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
279-
temperature: float = _DEFAULT_TEMPERATURE,
280-
top_k: int = _DEFAULT_TOP_K,
281-
top_p: float = _DEFAULT_TOP_P,
275+
max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
276+
temperature: Optional[float] = None,
277+
top_k: Optional[int] = None,
278+
top_p: Optional[float] = None,
282279
) -> List["TextGenerationResponse"]:
283280
"""Gets model response for a single prompt.
284281
285282
Args:
286283
prompts: Questions to ask the model.
287-
max_output_tokens: Max length of the output text in tokens.
288-
temperature: Controls the randomness of predictions. Range: [0, 1].
289-
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
290-
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
284+
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
285+
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
286+
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
287+
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
291288
292289
Returns:
293290
A list of `TextGenerationResponse` objects that contain the texts produced by the model.
@@ -458,17 +455,17 @@ class _ChatModel(_TextGenerationModel):
458455
def start_chat(
459456
self,
460457
max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
461-
temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
462-
top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
463-
top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
458+
temperature: Optional[float] = None,
459+
top_k: Optional[int] = None,
460+
top_p: Optional[float] = None,
464461
) -> "_ChatSession":
465462
"""Starts a chat session with the model.
466463
467464
Args:
468-
max_output_tokens: Max length of the output text in tokens.
469-
temperature: Controls the randomness of predictions. Range: [0, 1].
470-
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
471-
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
465+
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
466+
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
467+
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
468+
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
472469
473470
Returns:
474471
A `ChatSession` object.
@@ -492,9 +489,9 @@ def __init__(
492489
self,
493490
model: _ChatModel,
494491
max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
495-
temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
496-
top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
497-
top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
492+
temperature: Optional[float] = None,
493+
top_k: Optional[int] = None,
494+
top_p: Optional[float] = None,
498495
):
499496
self._model = model
500497
self._history = []
@@ -517,13 +514,13 @@ def send_message(
517514
518515
Args:
519516
message: Message to send to the model
520-
max_output_tokens: Max length of the output text in tokens.
517+
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
521518
Uses the value specified when calling `ChatModel.start_chat` by default.
522-
temperature: Controls the randomness of predictions. Range: [0, 1].
519+
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
523520
Uses the value specified when calling `ChatModel.start_chat` by default.
524-
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
521+
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
525522
Uses the value specified when calling `ChatModel.start_chat` by default.
526-
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
523+
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
527524
Uses the value specified when calling `ChatModel.start_chat` by default.
528525
529526
Returns:
@@ -633,10 +630,10 @@ def start_chat(
633630
*,
634631
context: Optional[str] = None,
635632
examples: Optional[List[InputOutputTextPair]] = None,
636-
max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
637-
temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
638-
top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
639-
top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
633+
max_output_tokens: Optional[int] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
634+
temperature: Optional[float] = None,
635+
top_k: Optional[int] = None,
636+
top_p: Optional[float] = None,
640637
message_history: Optional[List[ChatMessage]] = None,
641638
) -> "ChatSession":
642639
"""Starts a chat session with the model.
@@ -646,10 +643,10 @@ def start_chat(
646643
For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style
647644
examples: List of structured messages to the model to learn how to respond to the conversation.
648645
A list of `InputOutputTextPair` objects.
649-
max_output_tokens: Max length of the output text in tokens.
650-
temperature: Controls the randomness of predictions. Range: [0, 1].
651-
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]
652-
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
646+
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
647+
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
648+
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
649+
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
653650
message_history: A list of previously sent and received messages.
654651
655652
Returns:
@@ -717,19 +714,18 @@ class CodeChatModel(_ChatModelBase):
717714
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
718715

719716
_DEFAULT_MAX_OUTPUT_TOKENS = 128
720-
_DEFAULT_TEMPERATURE = 0.5
721717

722718
def start_chat(
723719
self,
724720
*,
725-
max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
726-
temperature: float = _DEFAULT_TEMPERATURE,
721+
max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
722+
temperature: Optional[float] = None,
727723
message_history: Optional[List[ChatMessage]] = None,
728724
) -> "CodeChatSession":
729725
"""Starts a chat session with the code chat model.
730726
731727
Args:
732-
max_output_tokens: Max length of the output text in tokens.
728+
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
733729
temperature: Controls the randomness of predictions. Range: [0, 1].
734730
735731
Returns:
@@ -754,11 +750,10 @@ def __init__(
754750
model: _ChatModelBase,
755751
context: Optional[str] = None,
756752
examples: Optional[List[InputOutputTextPair]] = None,
757-
max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
758-
temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
759-
top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
760-
top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
761-
is_code_chat_session: bool = False,
753+
max_output_tokens: Optional[int] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
754+
temperature: Optional[float] = None,
755+
top_k: Optional[int] = None,
756+
top_p: Optional[float] = None,
762757
message_history: Optional[List[ChatMessage]] = None,
763758
):
764759
self._model = model
@@ -768,7 +763,6 @@ def __init__(
768763
self._temperature = temperature
769764
self._top_k = top_k
770765
self._top_p = top_p
771-
self._is_code_chat_session = is_code_chat_session
772766
self._message_history: List[ChatMessage] = message_history or []
773767

774768
@property
@@ -789,30 +783,36 @@ def send_message(
789783
790784
Args:
791785
message: Message to send to the model
792-
max_output_tokens: Max length of the output text in tokens.
786+
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
793787
Uses the value specified when calling `ChatModel.start_chat` by default.
794-
temperature: Controls the randomness of predictions. Range: [0, 1].
788+
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
795789
Uses the value specified when calling `ChatModel.start_chat` by default.
796-
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
790+
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
797791
Uses the value specified when calling `ChatModel.start_chat` by default.
798-
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
792+
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
799793
Uses the value specified when calling `ChatModel.start_chat` by default.
800794
801795
Returns:
802796
A `TextGenerationResponse` object that contains the text produced by the model.
803797
"""
804-
prediction_parameters = {
805-
"temperature": temperature
806-
if temperature is not None
807-
else self._temperature,
808-
"maxDecodeSteps": max_output_tokens
809-
if max_output_tokens is not None
810-
else self._max_output_tokens,
811-
}
798+
prediction_parameters = {}
799+
800+
max_output_tokens = max_output_tokens or self._max_output_tokens
801+
if max_output_tokens:
802+
prediction_parameters["maxDecodeSteps"] = max_output_tokens
812803

813-
if not self._is_code_chat_session:
814-
prediction_parameters["topP"] = top_p if top_p is not None else self._top_p
815-
prediction_parameters["topK"] = top_k if top_k is not None else self._top_k
804+
if temperature is None:
805+
temperature = self._temperature
806+
if temperature is not None:
807+
prediction_parameters["temperature"] = temperature
808+
809+
top_p = top_p or self._top_p
810+
if top_p:
811+
prediction_parameters["topP"] = top_p
812+
813+
top_k = top_k or self._top_k
814+
if top_k:
815+
prediction_parameters["topK"] = top_k
816816

817817
message_structs = []
818818
for past_message in self._message_history:
@@ -830,9 +830,9 @@ def send_message(
830830
)
831831

832832
prediction_instance = {"messages": message_structs}
833-
if not self._is_code_chat_session and self._context:
833+
if self._context:
834834
prediction_instance["context"] = self._context
835-
if not self._is_code_chat_session and self._examples:
835+
if self._examples:
836836
prediction_instance["examples"] = [
837837
{
838838
"input": {"content": example.input_text},
@@ -885,10 +885,10 @@ def __init__(
885885
model: ChatModel,
886886
context: Optional[str] = None,
887887
examples: Optional[List[InputOutputTextPair]] = None,
888-
max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
889-
temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
890-
top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
891-
top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
888+
max_output_tokens: Optional[int] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
889+
temperature: Optional[float] = None,
890+
top_k: Optional[int] = None,
891+
top_p: Optional[float] = None,
892892
message_history: Optional[List[ChatMessage]] = None,
893893
):
894894
super().__init__(
@@ -913,14 +913,13 @@ def __init__(
913913
self,
914914
model: CodeChatModel,
915915
max_output_tokens: int = CodeChatModel._DEFAULT_MAX_OUTPUT_TOKENS,
916-
temperature: float = CodeChatModel._DEFAULT_TEMPERATURE,
916+
temperature: Optional[float] = None,
917917
message_history: Optional[List[ChatMessage]] = None,
918918
):
919919
super().__init__(
920920
model=model,
921921
max_output_tokens=max_output_tokens,
922922
temperature=temperature,
923-
is_code_chat_session=True,
924923
message_history=message_history,
925924
)
926925

@@ -935,7 +934,7 @@ def send_message(
935934
936935
Args:
937936
message: Message to send to the model
938-
max_output_tokens: Max length of the output text in tokens.
937+
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
939938
Uses the value specified when calling `CodeChatModel.start_chat` by default.
940939
temperature: Controls the randomness of predictions. Range: [0, 1].
941940
Uses the value specified when calling `CodeChatModel.start_chat` by default.
@@ -970,33 +969,38 @@ class CodeGenerationModel(_LanguageModel):
970969
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
971970

972971
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
973-
_DEFAULT_TEMPERATURE = 0.0
974972
_DEFAULT_MAX_OUTPUT_TOKENS = 128
975973

976974
def predict(
977975
self,
978976
prefix: str,
979-
suffix: Optional[str] = "",
977+
suffix: Optional[str] = None,
980978
*,
981-
max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
982-
temperature: float = _DEFAULT_TEMPERATURE,
979+
max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
980+
temperature: Optional[float] = None,
983981
) -> "TextGenerationResponse":
984982
"""Gets model response for a single prompt.
985983
986984
Args:
987985
prefix: Code before the current point.
988986
suffix: Code after the current point.
989-
max_output_tokens: Max length of the output text in tokens.
987+
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
990988
temperature: Controls the randomness of predictions. Range: [0, 1].
991989
992990
Returns:
993991
A `TextGenerationResponse` object that contains the text produced by the model.
994992
"""
995-
instance = {"prefix": prefix, "suffix": suffix}
996-
prediction_parameters = {
997-
"temperature": temperature,
998-
"maxOutputTokens": max_output_tokens,
999-
}
993+
instance = {"prefix": prefix}
994+
if suffix:
995+
instance["suffix"] = suffix
996+
997+
prediction_parameters = {}
998+
999+
if temperature is not None:
1000+
prediction_parameters["temperature"] = temperature
1001+
1002+
if max_output_tokens:
1003+
prediction_parameters["maxOutputTokens"] = max_output_tokens
10001004

10011005
prediction_response = self._endpoint.predict(
10021006
instances=[instance],

0 commit comments

Comments
 (0)