@@ -724,6 +724,53 @@ def test_tune_chat_model(
724724 ].runtime_config .parameter_values
725725 assert pipeline_arguments ["large_model_reference" ] == "chat-bison@001"
726726
727+ @pytest .mark .parametrize (
728+ "job_spec" ,
729+ [_TEST_PIPELINE_SPEC_JSON ],
730+ )
731+ @pytest .mark .parametrize (
732+ "mock_request_urlopen" ,
733+ ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest" ],
734+ indirect = True ,
735+ )
736+ def test_tune_code_chat_model (
737+ self ,
738+ mock_pipeline_service_create ,
739+ mock_pipeline_job_get ,
740+ mock_pipeline_bucket_exists ,
741+ job_spec ,
742+ mock_load_yaml_and_json ,
743+ mock_gcs_from_string ,
744+ mock_gcs_upload ,
745+ mock_request_urlopen ,
746+ mock_get_tuned_model ,
747+ ):
748+ """Tests tuning a code chat model."""
749+ aiplatform .init (project = _TEST_PROJECT , location = _TEST_LOCATION )
750+ with mock .patch .object (
751+ target = model_garden_service_client .ModelGardenServiceClient ,
752+ attribute = "get_publisher_model" ,
753+ return_value = gca_publisher_model .PublisherModel (
754+ _CODECHAT_BISON_PUBLISHER_MODEL_DICT
755+ ),
756+ ):
757+ model = preview_language_models .CodeChatModel .from_pretrained (
758+ "codechat-bison@001"
759+ )
760+
761+ # The tune_model call needs to be inside the PublisherModel mock
762+ # since it gets a new PublisherModel when tuning completes.
763+ model .tune_model (
764+ training_data = _TEST_TEXT_BISON_TRAINING_DF ,
765+ tuning_job_location = "europe-west4" ,
766+ tuned_model_location = "us-central1" ,
767+ )
768+ call_kwargs = mock_pipeline_service_create .call_args [1 ]
769+ pipeline_arguments = call_kwargs [
770+ "pipeline_job"
771+ ].runtime_config .parameter_values
772+ assert pipeline_arguments ["large_model_reference" ] == "codechat-bison@001"
773+
727774 @pytest .mark .usefixtures (
728775 "get_model_with_tuned_version_label_mock" ,
729776 "get_endpoint_with_models_mock" ,
0 commit comments