@@ -237,28 +237,25 @@ class _TextGenerationModel(_LanguageModel):
237237
238238 _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"
239239
240- _DEFAULT_TEMPERATURE = 0.0
241240 _DEFAULT_MAX_OUTPUT_TOKENS = 128
242- _DEFAULT_TOP_P = 0.95
243- _DEFAULT_TOP_K = 40
244241
245242 def predict (
246243 self ,
247244 prompt : str ,
248245 * ,
249- max_output_tokens : int = _DEFAULT_MAX_OUTPUT_TOKENS ,
250- temperature : float = _DEFAULT_TEMPERATURE ,
251- top_k : int = _DEFAULT_TOP_K ,
252- top_p : float = _DEFAULT_TOP_P ,
246+ max_output_tokens : Optional [ int ] = _DEFAULT_MAX_OUTPUT_TOKENS ,
247+ temperature : Optional [ float ] = None ,
248+ top_k : Optional [ int ] = None ,
249+ top_p : Optional [ float ] = None ,
253250 ) -> "TextGenerationResponse" :
254251 """Gets model response for a single prompt.
255252
256253 Args:
257254 prompt: Question to ask the model.
258- max_output_tokens: Max length of the output text in tokens.
259- temperature: Controls the randomness of predictions. Range: [0, 1].
260- top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
261- top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
255+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
256+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
257+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
258+ top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
262259
263260 Returns:
264261 A `TextGenerationResponse` object that contains the text produced by the model.
@@ -275,19 +272,19 @@ def predict(
275272 def _batch_predict (
276273 self ,
277274 prompts : List [str ],
278- max_output_tokens : int = _DEFAULT_MAX_OUTPUT_TOKENS ,
279- temperature : float = _DEFAULT_TEMPERATURE ,
280- top_k : int = _DEFAULT_TOP_K ,
281- top_p : float = _DEFAULT_TOP_P ,
275+ max_output_tokens : Optional [ int ] = _DEFAULT_MAX_OUTPUT_TOKENS ,
276+ temperature : Optional [ float ] = None ,
277+ top_k : Optional [ int ] = None ,
278+ top_p : Optional [ float ] = None ,
282279 ) -> List ["TextGenerationResponse" ]:
283280 """Gets model responses for a batch of prompts.
284281
285282 Args:
286283 prompts: Questions to ask the model.
287- max_output_tokens: Max length of the output text in tokens.
288- temperature: Controls the randomness of predictions. Range: [0, 1].
289- top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
290- top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
284+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
285+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
286+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
287+ top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
291288
292289 Returns:
293290 A list of `TextGenerationResponse` objects that contain the texts produced by the model.
@@ -458,17 +455,17 @@ class _ChatModel(_TextGenerationModel):
458455 def start_chat (
459456 self ,
460457 max_output_tokens : int = _TextGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS ,
461- temperature : float = _TextGenerationModel . _DEFAULT_TEMPERATURE ,
462- top_k : int = _TextGenerationModel . _DEFAULT_TOP_K ,
463- top_p : float = _TextGenerationModel . _DEFAULT_TOP_P ,
458+ temperature : Optional [ float ] = None ,
459+ top_k : Optional [ int ] = None ,
460+ top_p : Optional [ float ] = None ,
464461 ) -> "_ChatSession" :
465462 """Starts a chat session with the model.
466463
467464 Args:
468- max_output_tokens: Max length of the output text in tokens.
469- temperature: Controls the randomness of predictions. Range: [0, 1].
470- top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
471- top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
465+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
466+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
467+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
468+ top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
472469
473470 Returns:
474471 A `ChatSession` object.
@@ -492,9 +489,9 @@ def __init__(
492489 self ,
493490 model : _ChatModel ,
494491 max_output_tokens : int = _TextGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS ,
495- temperature : float = _TextGenerationModel . _DEFAULT_TEMPERATURE ,
496- top_k : int = _TextGenerationModel . _DEFAULT_TOP_K ,
497- top_p : float = _TextGenerationModel . _DEFAULT_TOP_P ,
492+ temperature : Optional [ float ] = None ,
493+ top_k : Optional [ int ] = None ,
494+ top_p : Optional [ float ] = None ,
498495 ):
499496 self ._model = model
500497 self ._history = []
@@ -517,13 +514,13 @@ def send_message(
517514
518515 Args:
519516 message: Message to send to the model
520- max_output_tokens: Max length of the output text in tokens.
517+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
521518 Uses the value specified when calling `ChatModel.start_chat` by default.
522- temperature: Controls the randomness of predictions. Range: [0, 1].
519+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
523520 Uses the value specified when calling `ChatModel.start_chat` by default.
524- top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
521+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
525522 Uses the value specified when calling `ChatModel.start_chat` by default.
526- top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
523+ top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
527524 Uses the value specified when calling `ChatModel.start_chat` by default.
528525
529526 Returns:
@@ -633,10 +630,10 @@ def start_chat(
633630 * ,
634631 context : Optional [str ] = None ,
635632 examples : Optional [List [InputOutputTextPair ]] = None ,
636- max_output_tokens : int = _TextGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS ,
637- temperature : float = _TextGenerationModel . _DEFAULT_TEMPERATURE ,
638- top_k : int = _TextGenerationModel . _DEFAULT_TOP_K ,
639- top_p : float = _TextGenerationModel . _DEFAULT_TOP_P ,
633+ max_output_tokens : Optional [ int ] = _TextGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS ,
634+ temperature : Optional [ float ] = None ,
635+ top_k : Optional [ int ] = None ,
636+ top_p : Optional [ float ] = None ,
640637 message_history : Optional [List [ChatMessage ]] = None ,
641638 ) -> "ChatSession" :
642639 """Starts a chat session with the model.
@@ -646,10 +643,10 @@ def start_chat(
646643 For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style
647644 examples: List of structured messages to the model to learn how to respond to the conversation.
648645 A list of `InputOutputTextPair` objects.
649- max_output_tokens: Max length of the output text in tokens.
650- temperature: Controls the randomness of predictions. Range: [0, 1].
651- top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]
652- top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
646+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
647+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
648+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
649+ top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
653650 message_history: A list of previously sent and received messages.
654651
655652 Returns:
@@ -717,19 +714,18 @@ class CodeChatModel(_ChatModelBase):
717714 _LAUNCH_STAGE = _model_garden_models ._SDK_GA_LAUNCH_STAGE
718715
719716 _DEFAULT_MAX_OUTPUT_TOKENS = 128
720- _DEFAULT_TEMPERATURE = 0.5
721717
722718 def start_chat (
723719 self ,
724720 * ,
725- max_output_tokens : int = _DEFAULT_MAX_OUTPUT_TOKENS ,
726- temperature : float = _DEFAULT_TEMPERATURE ,
721+ max_output_tokens : Optional [ int ] = _DEFAULT_MAX_OUTPUT_TOKENS ,
722+ temperature : Optional [ float ] = None ,
727723 message_history : Optional [List [ChatMessage ]] = None ,
728724 ) -> "CodeChatSession" :
729725 """Starts a chat session with the code chat model.
730726
731727 Args:
732- max_output_tokens: Max length of the output text in tokens.
728+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
733729 temperature: Controls the randomness of predictions. Range: [0, 1].
734730
735731 Returns:
@@ -754,11 +750,10 @@ def __init__(
754750 model : _ChatModelBase ,
755751 context : Optional [str ] = None ,
756752 examples : Optional [List [InputOutputTextPair ]] = None ,
757- max_output_tokens : int = _TextGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS ,
758- temperature : float = _TextGenerationModel ._DEFAULT_TEMPERATURE ,
759- top_k : int = _TextGenerationModel ._DEFAULT_TOP_K ,
760- top_p : float = _TextGenerationModel ._DEFAULT_TOP_P ,
761- is_code_chat_session : bool = False ,
753+ max_output_tokens : Optional [int ] = _TextGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS ,
754+ temperature : Optional [float ] = None ,
755+ top_k : Optional [int ] = None ,
756+ top_p : Optional [float ] = None ,
762757 message_history : Optional [List [ChatMessage ]] = None ,
763758 ):
764759 self ._model = model
@@ -768,7 +763,6 @@ def __init__(
768763 self ._temperature = temperature
769764 self ._top_k = top_k
770765 self ._top_p = top_p
771- self ._is_code_chat_session = is_code_chat_session
772766 self ._message_history : List [ChatMessage ] = message_history or []
773767
774768 @property
@@ -789,30 +783,36 @@ def send_message(
789783
790784 Args:
791785 message: Message to send to the model
792- max_output_tokens: Max length of the output text in tokens.
786+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
793787 Uses the value specified when calling `ChatModel.start_chat` by default.
794- temperature: Controls the randomness of predictions. Range: [0, 1].
788+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
795789 Uses the value specified when calling `ChatModel.start_chat` by default.
796- top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
790+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
797791 Uses the value specified when calling `ChatModel.start_chat` by default.
798- top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
792+ top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
799793 Uses the value specified when calling `ChatModel.start_chat` by default.
800794
801795 Returns:
802796 A `TextGenerationResponse` object that contains the text produced by the model.
803797 """
804- prediction_parameters = {
805- "temperature" : temperature
806- if temperature is not None
807- else self ._temperature ,
808- "maxDecodeSteps" : max_output_tokens
809- if max_output_tokens is not None
810- else self ._max_output_tokens ,
811- }
798+ prediction_parameters = {}
799+
800+ max_output_tokens = max_output_tokens or self ._max_output_tokens
801+ if max_output_tokens :
802+ prediction_parameters ["maxDecodeSteps" ] = max_output_tokens
812803
813- if not self ._is_code_chat_session :
814- prediction_parameters ["topP" ] = top_p if top_p is not None else self ._top_p
815- prediction_parameters ["topK" ] = top_k if top_k is not None else self ._top_k
804+ if temperature is None :
805+ temperature = self ._temperature
806+ if temperature is not None :
807+ prediction_parameters ["temperature" ] = temperature
808+
809+ top_p = top_p or self ._top_p
810+ if top_p :
811+ prediction_parameters ["topP" ] = top_p
812+
813+ top_k = top_k or self ._top_k
814+ if top_k :
815+ prediction_parameters ["topK" ] = top_k
816816
817817 message_structs = []
818818 for past_message in self ._message_history :
@@ -830,9 +830,9 @@ def send_message(
830830 )
831831
832832 prediction_instance = {"messages" : message_structs }
833- if not self . _is_code_chat_session and self ._context :
833+ if self ._context :
834834 prediction_instance ["context" ] = self ._context
835- if not self . _is_code_chat_session and self ._examples :
835+ if self ._examples :
836836 prediction_instance ["examples" ] = [
837837 {
838838 "input" : {"content" : example .input_text },
@@ -885,10 +885,10 @@ def __init__(
885885 model : ChatModel ,
886886 context : Optional [str ] = None ,
887887 examples : Optional [List [InputOutputTextPair ]] = None ,
888- max_output_tokens : int = _TextGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS ,
889- temperature : float = _TextGenerationModel . _DEFAULT_TEMPERATURE ,
890- top_k : int = _TextGenerationModel . _DEFAULT_TOP_K ,
891- top_p : float = _TextGenerationModel . _DEFAULT_TOP_P ,
888+ max_output_tokens : Optional [ int ] = _TextGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS ,
889+ temperature : Optional [ float ] = None ,
890+ top_k : Optional [ int ] = None ,
891+ top_p : Optional [ float ] = None ,
892892 message_history : Optional [List [ChatMessage ]] = None ,
893893 ):
894894 super ().__init__ (
@@ -913,14 +913,13 @@ def __init__(
913913 self ,
914914 model : CodeChatModel ,
915915 max_output_tokens : int = CodeChatModel ._DEFAULT_MAX_OUTPUT_TOKENS ,
916- temperature : float = CodeChatModel . _DEFAULT_TEMPERATURE ,
916+ temperature : Optional [ float ] = None ,
917917 message_history : Optional [List [ChatMessage ]] = None ,
918918 ):
919919 super ().__init__ (
920920 model = model ,
921921 max_output_tokens = max_output_tokens ,
922922 temperature = temperature ,
923- is_code_chat_session = True ,
924923 message_history = message_history ,
925924 )
926925
@@ -935,7 +934,7 @@ def send_message(
935934
936935 Args:
937936 message: Message to send to the model
938- max_output_tokens: Max length of the output text in tokens.
937+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
939938 Uses the value specified when calling `CodeChatModel.start_chat` by default.
940939 temperature: Controls the randomness of predictions. Range: [0, 1].
941940 Uses the value specified when calling `CodeChatModel.start_chat` by default.
@@ -970,33 +969,38 @@ class CodeGenerationModel(_LanguageModel):
970969 _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
971970
972971 _LAUNCH_STAGE = _model_garden_models ._SDK_GA_LAUNCH_STAGE
973- _DEFAULT_TEMPERATURE = 0.0
974972 _DEFAULT_MAX_OUTPUT_TOKENS = 128
975973
976974 def predict (
977975 self ,
978976 prefix : str ,
979- suffix : Optional [str ] = "" ,
977+ suffix : Optional [str ] = None ,
980978 * ,
981- max_output_tokens : int = _DEFAULT_MAX_OUTPUT_TOKENS ,
982- temperature : float = _DEFAULT_TEMPERATURE ,
979+ max_output_tokens : Optional [ int ] = _DEFAULT_MAX_OUTPUT_TOKENS ,
980+ temperature : Optional [ float ] = None ,
983981 ) -> "TextGenerationResponse" :
984982 """Gets model response for a single prompt.
985983
986984 Args:
987985 prefix: Code before the current point.
988986 suffix: Code after the current point.
989- max_output_tokens: Max length of the output text in tokens.
987+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
990988 temperature: Controls the randomness of predictions. Range: [0, 1].
991989
992990 Returns:
993991 A `TextGenerationResponse` object that contains the text produced by the model.
994992 """
995- instance = {"prefix" : prefix , "suffix" : suffix }
996- prediction_parameters = {
997- "temperature" : temperature ,
998- "maxOutputTokens" : max_output_tokens ,
999- }
993+ instance = {"prefix" : prefix }
994+ if suffix :
995+ instance ["suffix" ] = suffix
996+
997+ prediction_parameters = {}
998+
999+ if temperature is not None :
1000+ prediction_parameters ["temperature" ] = temperature
1001+
1002+ if max_output_tokens :
1003+ prediction_parameters ["maxOutputTokens" ] = max_output_tokens
10001004
10011005 prediction_response = self ._endpoint .predict (
10021006 instances = [instance ],
0 commit comments