@@ -615,7 +615,7 @@ def test_text_generation_ga(self):
615615 target = prediction_service_client .PredictionServiceClient ,
616616 attribute = "predict" ,
617617 return_value = gca_predict_response ,
618- ):
618+ ) as mock_predict :
619619 response = model .predict (
620620 "What is the best recipe for banana bread? Recipe:" ,
621621 max_output_tokens = 128 ,
@@ -624,8 +624,33 @@ def test_text_generation_ga(self):
624624 top_k = 5 ,
625625 )
626626
627+ prediction_parameters = mock_predict .call_args [1 ]["parameters" ]
628+ assert prediction_parameters ["maxDecodeSteps" ] == 128
629+ assert prediction_parameters ["temperature" ] == 0
630+ assert prediction_parameters ["topP" ] == 1
631+ assert prediction_parameters ["topK" ] == 5
627632 assert response .text == _TEST_TEXT_GENERATION_PREDICTION ["content" ]
628633
634+ # Validating that unspecified parameters are not passed to the model
635+ # (except `max_output_tokens`).
636+ with mock .patch .object (
637+ target = prediction_service_client .PredictionServiceClient ,
638+ attribute = "predict" ,
639+ return_value = gca_predict_response ,
640+ ) as mock_predict :
641+ model .predict (
642+ "What is the best recipe for banana bread? Recipe:" ,
643+ )
644+
645+ prediction_parameters = mock_predict .call_args [1 ]["parameters" ]
646+ assert (
647+ prediction_parameters ["maxDecodeSteps" ]
648+ == language_models .TextGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS
649+ )
650+ assert "temperature" not in prediction_parameters
651+ assert "topP" not in prediction_parameters
652+ assert "topK" not in prediction_parameters
653+
629654 @pytest .mark .parametrize (
630655 "job_spec" ,
631656 [_TEST_PIPELINE_SPEC_JSON , _TEST_PIPELINE_JOB ],
0 commit comments