@@ -92,13 +92,15 @@ def _model_resource_name(self) -> str:
@dataclasses.dataclass
class _PredictionRequest:
    """Payload for a prediction call that carries exactly one instance.

    Attributes:
        instance: The single instance to send for prediction.
        parameters: Optional prediction parameters; ``None`` when not set.
    """

    instance: Dict[str, Any]
    parameters: Optional[Dict[str, Any]] = None
9798
9899
@dataclasses.dataclass
class _MultiInstancePredictionRequest:
    """Payload for a prediction call that carries several instances at once.

    Attributes:
        instances: The instances to send for prediction.
        parameters: Optional prediction parameters; ``None`` when not set.
    """

    instances: List[Dict[str, Any]]
    parameters: Optional[Dict[str, Any]] = None
104106
@@ -573,6 +575,62 @@ def tune_model(
573575 return job
574576
575577
@dataclasses.dataclass
class CountTokensResponse:
    """Result of a ``count_tokens`` request.

    Attributes:
        total_tokens (int):
            The total number of tokens counted across all
            instances passed to the request.
        total_billable_characters (int):
            The total number of billable characters
            counted across all instances from the request.
        _count_tokens_response (Any):
            The underlying response object, kept for internal use.
    """

    total_tokens: int
    total_billable_characters: int
    _count_tokens_response: Any
593+
594+
class _CountTokensMixin(_LanguageModel):
    """Mixin for models that support the CountTokens API."""

    def count_tokens(
        self,
        prompts: List[str],
    ) -> CountTokensResponse:
        """Counts the tokens and billable characters for the given prompts.

        Note: this does not send a prediction request to the model; it calls
        the CountTokens API, which only counts the tokens in the request.

        Args:
            prompts (List[str]):
                Required. A list of prompts to ask the model. For example: ["What should I do today?", "How's it going?"]
                A bare string is treated as a single prompt instead of being
                split into one instance per character.

        Returns:
            A `CountTokensResponse` object that contains the number of tokens
            in the text and the number of billable characters.
        """
        # A lone `str` is itself iterable; without this guard every character
        # would be sent as its own instance and the totals would be wrong.
        if isinstance(prompts, str):
            prompts = [prompts]

        instances = [{"content": prompt} for prompt in prompts]

        # count_tokens is invoked on the "v1beta1" version of the prediction
        # client, selected explicitly here.
        versioned_client = self._endpoint._prediction_client.select_version(
            "v1beta1"
        )
        count_tokens_response = versioned_client.count_tokens(
            endpoint=self._endpoint_name,
            instances=instances,
        )

        return CountTokensResponse(
            total_tokens=count_tokens_response.total_tokens,
            total_billable_characters=count_tokens_response.total_billable_characters,
            _count_tokens_response=count_tokens_response,
        )
632+
633+
576634@dataclasses .dataclass
577635class TuningEvaluationSpec :
578636 """Specification for model evaluation to perform during tuning.
@@ -587,6 +645,7 @@ class TuningEvaluationSpec:
587645 tensorboard: Vertex Tensorboard where to write the evaluation metrics.
588646 The Tensorboard must be in the same location as the tuning job.
589647 """
648+
590649 __module__ = "vertexai.language_models"
591650
592651 evaluation_data : str
@@ -605,6 +664,7 @@ class TextGenerationResponse:
605664 Learn more about the safety attributes here:
606665 https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
607666 """
667+
608668 __module__ = "vertexai.language_models"
609669
610670 text : str
@@ -761,7 +821,9 @@ def predict_streaming(
761821 )
762822
763823 prediction_service_client = self ._endpoint ._prediction_client
764- for prediction_dict in _streaming_prediction .predict_stream_of_dicts_from_single_dict (
824+ for (
825+ prediction_dict
826+ ) in _streaming_prediction .predict_stream_of_dicts_from_single_dict (
765827 prediction_service_client = prediction_service_client ,
766828 endpoint_name = self ._endpoint_name ,
767829 instance = prediction_request .instance ,
@@ -955,6 +1017,7 @@ class _PreviewTextGenerationModel(
9551017 _PreviewTunableTextModelMixin ,
9561018 _PreviewModelWithBatchPredict ,
9571019 _evaluatable_language_models ._EvaluatableLanguageModel ,
1020+ _CountTokensMixin ,
9581021):
9591022 # Do not add docstring so that it's inherited from the base class.
9601023 __name__ = "TextGenerationModel"
@@ -1094,6 +1157,7 @@ class TextEmbeddingInput:
10941157 Specifies that the embeddings will be used for clustering.
10951158 title: Optional identifier of the text content.
10961159 """
1160+
10971161 __module__ = "vertexai.language_models"
10981162
10991163 text : str
@@ -1113,6 +1177,7 @@ class TextEmbeddingModel(_LanguageModel):
11131177 vector = embedding.values
11141178 print(len(vector))
11151179 """
1180+
11161181 __module__ = "vertexai.language_models"
11171182
11181183 _LAUNCH_STAGE = _model_garden_models ._SDK_GA_LAUNCH_STAGE
@@ -1173,7 +1238,8 @@ def _parse_text_embedding_response(
11731238 _prediction_response = prediction_response ,
11741239 )
11751240
1176- def get_embeddings (self ,
1241+ def get_embeddings (
1242+ self ,
11771243 texts : List [Union [str , TextEmbeddingInput ]],
11781244 * ,
11791245 auto_truncate : bool = True ,
@@ -1207,7 +1273,8 @@ def get_embeddings(self,
12071273
12081274 return results
12091275
1210- async def get_embeddings_async (self ,
1276+ async def get_embeddings_async (
1277+ self ,
12111278 texts : List [Union [str , TextEmbeddingInput ]],
12121279 * ,
12131280 auto_truncate : bool = True ,
@@ -1242,7 +1309,9 @@ async def get_embeddings_async(self,
12421309 return results
12431310
12441311
1245- class _PreviewTextEmbeddingModel (TextEmbeddingModel , _ModelWithBatchPredict ):
1312+ class _PreviewTextEmbeddingModel (
1313+ TextEmbeddingModel , _ModelWithBatchPredict , _CountTokensMixin
1314+ ):
12461315 __name__ = "TextEmbeddingModel"
12471316 __module__ = "vertexai.preview.language_models"
12481317
@@ -1252,6 +1321,7 @@ class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
12521321@dataclasses .dataclass
12531322class TextEmbeddingStatistics :
12541323 """Text embedding statistics."""
1324+
12551325 __module__ = "vertexai.language_models"
12561326
12571327 token_count : int
@@ -1261,6 +1331,7 @@ class TextEmbeddingStatistics:
12611331@dataclasses .dataclass
12621332class TextEmbedding :
12631333 """Text embedding vector and statistics."""
1334+
12641335 __module__ = "vertexai.language_models"
12651336
12661337 values : List [float ]
@@ -1271,6 +1342,7 @@ class TextEmbedding:
12711342@dataclasses .dataclass
12721343class InputOutputTextPair :
12731344 """InputOutputTextPair represents a pair of input and output texts."""
1345+
12741346 __module__ = "vertexai.language_models"
12751347
12761348 input_text : str
@@ -1285,6 +1357,7 @@ class ChatMessage:
12851357 content: Content of the message.
12861358 author: Author of the message.
12871359 """
1360+
12881361 __module__ = "vertexai.language_models"
12891362
12901363 content : str
@@ -1362,6 +1435,7 @@ class ChatModel(_ChatModelBase, _TunableChatModelMixin):
13621435
13631436 chat.send_message("Do you know any cool events this weekend?")
13641437 """
1438+
13651439 __module__ = "vertexai.language_models"
13661440
13671441 _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"
@@ -1388,6 +1462,7 @@ class CodeChatModel(_ChatModelBase):
13881462
13891463 code_chat.send_message("Please help write a function to calculate the min of two numbers")
13901464 """
1465+
13911466 __module__ = "vertexai.language_models"
13921467
13931468 _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/codechat_generation_1.0.0.yaml"
@@ -1739,7 +1814,9 @@ def send_message_streaming(
17391814
17401815 full_response_text = ""
17411816
1742- for prediction_dict in _streaming_prediction .predict_stream_of_dicts_from_single_dict (
1817+ for (
1818+ prediction_dict
1819+ ) in _streaming_prediction .predict_stream_of_dicts_from_single_dict (
17431820 prediction_service_client = prediction_service_client ,
17441821 endpoint_name = self ._model ._endpoint_name ,
17451822 instance = prediction_request .instance ,
@@ -1770,6 +1847,7 @@ class ChatSession(_ChatSessionBase):
17701847
17711848 Within a chat session, the model keeps context and remembers the previous conversation.
17721849 """
1850+
17731851 __module__ = "vertexai.language_models"
17741852
17751853 def __init__ (
@@ -1802,6 +1880,7 @@ class CodeChatSession(_ChatSessionBase):
18021880
18031881 Within a code chat session, the model keeps context and remembers the previous converstion.
18041882 """
1883+
18051884 __module__ = "vertexai.language_models"
18061885
18071886 def __init__ (
@@ -1924,6 +2003,7 @@ class CodeGenerationModel(_LanguageModel):
19242003 prefix="def reverse_string(s):",
19252004 ))
19262005 """
2006+
19272007 __module__ = "vertexai.language_models"
19282008
19292009 _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
@@ -2074,7 +2154,9 @@ def predict_streaming(
20742154 )
20752155
20762156 prediction_service_client = self ._endpoint ._prediction_client
2077- for prediction_dict in _streaming_prediction .predict_stream_of_dicts_from_single_dict (
2157+ for (
2158+ prediction_dict
2159+ ) in _streaming_prediction .predict_stream_of_dicts_from_single_dict (
20782160 prediction_service_client = prediction_service_client ,
20792161 endpoint_name = self ._endpoint_name ,
20802162 instance = prediction_request .instance ,
0 commit comments