5353 model as gca_model ,
5454)
5555
56- from vertexai .preview import language_models
56+ from vertexai .preview import (
57+ language_models as preview_language_models ,
58+ )
59+ from vertexai import language_models
5760from google .cloud .aiplatform_v1 import Execution as GapicExecution
5861from google .cloud .aiplatform .compat .types import (
5962 encryption_spec as gca_encryption_spec ,
@@ -456,7 +459,7 @@ def get_endpoint_mock():
456459@pytest .fixture
457460def mock_get_tuned_model (get_endpoint_mock ):
458461 with mock .patch .object (
459- language_models .TextGenerationModel , "get_tuned_model"
462+ preview_language_models .TextGenerationModel , "get_tuned_model"
460463 ) as mock_text_generation_model :
461464 mock_text_generation_model ._model_id = (
462465 test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME
@@ -519,6 +522,50 @@ def teardown_method(self):
519522 initializer .global_pool .shutdown (wait = True )
520523
521524 def test_text_generation (self ):
525+ """Tests the text generation model."""
526+ aiplatform .init (
527+ project = _TEST_PROJECT ,
528+ location = _TEST_LOCATION ,
529+ )
530+ with mock .patch .object (
531+ target = model_garden_service_client .ModelGardenServiceClient ,
532+ attribute = "get_publisher_model" ,
533+ return_value = gca_publisher_model .PublisherModel (
534+ _TEXT_BISON_PUBLISHER_MODEL_DICT
535+ ),
536+ ) as mock_get_publisher_model :
537+ model = preview_language_models .TextGenerationModel .from_pretrained (
538+ "text-bison@001"
539+ )
540+
541+ mock_get_publisher_model .assert_called_once_with (
542+ name = "publishers/google/models/text-bison@001" , retry = base ._DEFAULT_RETRY
543+ )
544+
545+ assert (
546+ model ._model_resource_name
547+ == f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /publishers/google/models/text-bison@001"
548+ )
549+
550+ gca_predict_response = gca_prediction_service .PredictResponse ()
551+ gca_predict_response .predictions .append (_TEST_TEXT_GENERATION_PREDICTION )
552+
553+ with mock .patch .object (
554+ target = prediction_service_client .PredictionServiceClient ,
555+ attribute = "predict" ,
556+ return_value = gca_predict_response ,
557+ ):
558+ response = model .predict (
559+ "What is the best recipe for banana bread? Recipe:" ,
560+ max_output_tokens = 128 ,
561+ temperature = 0 ,
562+ top_p = 1 ,
563+ top_k = 5 ,
564+ )
565+
566+ assert response .text == _TEST_TEXT_GENERATION_PREDICTION ["content" ]
567+
568+ def test_text_generation_ga (self ):
522569 """Tests the text generation model."""
523570 aiplatform .init (
524571 project = _TEST_PROJECT ,
@@ -596,7 +643,7 @@ def test_tune_model(
596643 _TEXT_BISON_PUBLISHER_MODEL_DICT
597644 ),
598645 ):
599- model = language_models .TextGenerationModel .from_pretrained (
646+ model = preview_language_models .TextGenerationModel .from_pretrained (
600647 "text-bison@001"
601648 )
602649
@@ -631,7 +678,7 @@ def test_get_tuned_model(
631678 _TEXT_BISON_PUBLISHER_MODEL_DICT
632679 ),
633680 ):
634- tuned_model = language_models .TextGenerationModel .get_tuned_model (
681+ tuned_model = preview_language_models .TextGenerationModel .get_tuned_model (
635682 test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME
636683 )
637684
@@ -651,7 +698,7 @@ def get_tuned_model_raises_if_not_called_with_mg_model(self):
651698 )
652699
653700 with pytest .raises (ValueError ):
654- language_models .TextGenerationModel .get_tuned_model (
701+ preview_language_models .TextGenerationModel .get_tuned_model (
655702 test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME
656703 )
657704
@@ -668,7 +715,7 @@ def test_chat(self):
668715 _CHAT_BISON_PUBLISHER_MODEL_DICT
669716 ),
670717 ) as mock_get_publisher_model :
671- model = language_models .ChatModel .from_pretrained ("chat-bison@001" )
718+ model = preview_language_models .ChatModel .from_pretrained ("chat-bison@001" )
672719
673720 mock_get_publisher_model .assert_called_once_with (
674721 name = "publishers/google/models/chat-bison@001" , retry = base ._DEFAULT_RETRY
@@ -681,11 +728,11 @@ def test_chat(self):
681728 My favorite movies are Lord of the Rings and Hobbit.
682729 """ ,
683730 examples = [
684- language_models .InputOutputTextPair (
731+ preview_language_models .InputOutputTextPair (
685732 input_text = "Who do you work for?" ,
686733 output_text = "I work for Ned." ,
687734 ),
688- language_models .InputOutputTextPair (
735+ preview_language_models .InputOutputTextPair (
689736 input_text = "What do I like?" ,
690737 output_text = "Ned likes watching movies." ,
691738 ),
@@ -786,7 +833,7 @@ def test_code_chat(self):
786833 _CODECHAT_BISON_PUBLISHER_MODEL_DICT
787834 ),
788835 ) as mock_get_publisher_model :
789- model = language_models .CodeChatModel .from_pretrained (
836+ model = preview_language_models .CodeChatModel .from_pretrained (
790837 "google/codechat-bison@001"
791838 )
792839
@@ -882,7 +929,7 @@ def test_code_generation(self):
882929 _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
883930 ),
884931 ) as mock_get_publisher_model :
885- model = language_models .CodeGenerationModel .from_pretrained (
932+ model = preview_language_models .CodeGenerationModel .from_pretrained (
886933 "google/code-bison@001"
887934 )
888935
@@ -909,9 +956,11 @@ def test_code_generation(self):
909956 # Validating the parameters
910957 predict_temperature = 0.1
911958 predict_max_output_tokens = 100
912- default_temperature = language_models .CodeGenerationModel ._DEFAULT_TEMPERATURE
959+ default_temperature = (
960+ preview_language_models .CodeGenerationModel ._DEFAULT_TEMPERATURE
961+ )
913962 default_max_output_tokens = (
914- language_models .CodeGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS
963+ preview_language_models .CodeGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS
915964 )
916965
917966 with mock .patch .object (
@@ -948,7 +997,7 @@ def test_code_completion(self):
948997 _CODE_COMPLETION_BISON_PUBLISHER_MODEL_DICT
949998 ),
950999 ) as mock_get_publisher_model :
951- model = language_models .CodeGenerationModel .from_pretrained (
1000+ model = preview_language_models .CodeGenerationModel .from_pretrained (
9521001 "google/code-gecko@001"
9531002 )
9541003
@@ -975,9 +1024,11 @@ def test_code_completion(self):
9751024 # Validating the parameters
9761025 predict_temperature = 0.1
9771026 predict_max_output_tokens = 100
978- default_temperature = language_models .CodeGenerationModel ._DEFAULT_TEMPERATURE
1027+ default_temperature = (
1028+ preview_language_models .CodeGenerationModel ._DEFAULT_TEMPERATURE
1029+ )
9791030 default_max_output_tokens = (
980- language_models .CodeGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS
1031+ preview_language_models .CodeGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS
9811032 )
9821033
9831034 with mock .patch .object (
@@ -1002,6 +1053,43 @@ def test_code_completion(self):
10021053 assert prediction_parameters ["maxOutputTokens" ] == default_max_output_tokens
10031054
10041055 def test_text_embedding (self ):
1056+ """Tests the text embedding model."""
1057+ aiplatform .init (
1058+ project = _TEST_PROJECT ,
1059+ location = _TEST_LOCATION ,
1060+ )
1061+ with mock .patch .object (
1062+ target = model_garden_service_client .ModelGardenServiceClient ,
1063+ attribute = "get_publisher_model" ,
1064+ return_value = gca_publisher_model .PublisherModel (
1065+ _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
1066+ ),
1067+ ) as mock_get_publisher_model :
1068+ model = preview_language_models .TextEmbeddingModel .from_pretrained (
1069+ "textembedding-gecko@001"
1070+ )
1071+
1072+ mock_get_publisher_model .assert_called_once_with (
1073+ name = "publishers/google/models/textembedding-gecko@001" ,
1074+ retry = base ._DEFAULT_RETRY ,
1075+ )
1076+
1077+ gca_predict_response = gca_prediction_service .PredictResponse ()
1078+ gca_predict_response .predictions .append (_TEST_TEXT_EMBEDDING_PREDICTION )
1079+
1080+ with mock .patch .object (
1081+ target = prediction_service_client .PredictionServiceClient ,
1082+ attribute = "predict" ,
1083+ return_value = gca_predict_response ,
1084+ ):
1085+ embeddings = model .get_embeddings (["What is life?" ])
1086+ assert embeddings
1087+ for embedding in embeddings :
1088+ vector = embedding .values
1089+ assert len (vector ) == _TEXT_EMBEDDING_VECTOR_LENGTH
1090+ assert vector == _TEST_TEXT_EMBEDDING_PREDICTION ["embeddings" ]["values" ]
1091+
1092+ def test_text_embedding_ga (self ):
10051093 """Tests the text embedding model."""
10061094 aiplatform .init (
10071095 project = _TEST_PROJECT ,
0 commit comments