1515
1616"""Base class for working with Model Garden models."""
1717
18- import abc
1918import dataclasses
2019from typing import Dict , Optional , Type
2120
2524from google .cloud .aiplatform import models as aiplatform_models
2625from google .cloud .aiplatform import _publisher_models
2726
27+ from google .cloud .aiplatform .compat .types import (
28+ publisher_model as gca_publisher_model ,
29+ )
2830
2931_SUPPORTED_PUBLISHERS = ["google" ]
3032
3133_SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = {
3234 "text-bison" : "https://us-kfp.pkg.dev/vertex-ai/large-language-model-pipelines/tune-large-model/sdk-1-25"
3335}
3436
37+ _SDK_PUBLIC_PREVIEW_LAUNCH_STAGE = frozenset (
38+ [
39+ gca_publisher_model .PublisherModel .LaunchStage .PUBLIC_PREVIEW ,
40+ gca_publisher_model .PublisherModel .LaunchStage .GA ,
41+ ]
42+ )
43+ _SDK_GA_LAUNCH_STAGE = frozenset ([gca_publisher_model .PublisherModel .LaunchStage .GA ])
44+
3545_LOGGER = base .Logger (__name__ )
3646
3747
3848@dataclasses .dataclass
3949class _ModelInfo :
4050 endpoint_name : str
4151 interface_class : Type ["_ModelGardenModel" ]
52+ publisher_model_resource : _publisher_models ._PublisherModel
4253 tuning_pipeline_uri : Optional [str ] = None
4354 tuning_model_id : Optional [str ] = None
4455
@@ -114,6 +125,7 @@ def _get_model_info(
114125 return _ModelInfo (
115126 endpoint_name = endpoint_name ,
116127 interface_class = interface_class ,
128+ publisher_model_resource = publisher_model_res ,
117129 tuning_pipeline_uri = tuning_pipeline_uri ,
118130 tuning_model_id = tuning_model_id ,
119131 )
@@ -122,6 +134,28 @@ def _get_model_info(
122134class _ModelGardenModel :
123135 """Base class for shared methods and properties across Model Garden models."""
124136
137+ _LAUNCH_STAGE : gca_publisher_model .PublisherModel .LaunchStage = (
138+ _SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
139+ )
140+
141+ def _validate_launch_stage (
142+ self ,
143+ publisher_model_resource : gca_publisher_model .PublisherModel ,
144+ ) -> None :
145+ """Validates the model class _LAUNCH_STAGE matches the PublisherModel resource's launch stage.
146+
147+ Args:
148+ publisher_model_resource (gca_publisher_model.PublisherModel
149+ The GAPIC PublisherModel resource for this model.
150+ """
151+
152+ publisher_launch_stage = publisher_model_resource .launch_stage
153+
154+ if publisher_launch_stage not in self ._LAUNCH_STAGE :
155+ raise ValueError (
156+ f"The model you are trying to instantiate does not support the launch stage: { publisher_launch_stage .name } "
157+ )
158+
125159 # Subclasses override this attribute to specify their instance schema
126160 _INSTANCE_SCHEMA_URI : Optional [str ] = None
127161
@@ -174,6 +208,8 @@ def from_pretrained(cls, model_name: str) -> "_ModelGardenModel":
174208 f"{ model_name } is of type { model_info .interface_class .__name__ } not of type { cls .__name__ } "
175209 )
176210
211+ cls ._validate_launch_stage (cls , model_info .publisher_model_resource )
212+
177213 return model_info .interface_class (
178214 model_id = model_name ,
179215 endpoint_name = model_info .endpoint_name ,
0 commit comments