2323from google .cloud .aiplatform import base
2424from google .cloud .aiplatform import initializer as aiplatform_initializer
2525from google .cloud .aiplatform import utils as aiplatform_utils
26+ from google .cloud .aiplatform .compat import types as aiplatform_types
2627from google .cloud .aiplatform .utils import gcs_utils
2728from vertexai ._model_garden import _model_garden_models
2829from vertexai .language_models import (
@@ -148,18 +149,24 @@ def tune_model(
148149 self ,
149150 training_data : Union [str , "pandas.core.frame.DataFrame" ],
150151 * ,
151- train_steps : int = 1000 ,
152+ train_steps : Optional [ int ] = None ,
152153 learning_rate : Optional [float ] = None ,
153154 learning_rate_multiplier : Optional [float ] = None ,
154155 tuning_job_location : Optional [str ] = None ,
155156 tuned_model_location : Optional [str ] = None ,
156157 model_display_name : Optional [str ] = None ,
157158 tuning_evaluation_spec : Optional ["TuningEvaluationSpec" ] = None ,
158159 default_context : Optional [str ] = None ,
159- ):
160+ ) -> "_LanguageModelTuningJob" :
160161 """Tunes a model based on training data.
161162
162- This method launches a model tuning job that can take some time.
163+ This method launches and returns an asynchronous model tuning job.
164+ Usage:
165+ ```
166+ tuning_job = model.tune_model(...)
167+ ... do some other work
168+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
169+ ```
163170
164171 Args:
165172 training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -303,16 +310,68 @@ def _tune_model(
303310 base_model = self ,
304311 job = pipeline_job ,
305312 )
306- self ._job = job
307- tuned_model = job .result ()
308- # The UXR study attendees preferred to tune model in place
309- self ._endpoint = tuned_model ._endpoint
310- self ._endpoint_name = tuned_model ._endpoint_name
313+ return job
311314
312315
313316class _TunableTextModelMixin (_TunableModelMixin ):
314317 """Text model that can be tuned."""
315318
319+ def tune_model (
320+ self ,
321+ training_data : Union [str , "pandas.core.frame.DataFrame" ],
322+ * ,
323+ train_steps : Optional [int ] = None ,
324+ learning_rate_multiplier : Optional [float ] = None ,
325+ tuning_job_location : Optional [str ] = None ,
326+ tuned_model_location : Optional [str ] = None ,
327+ model_display_name : Optional [str ] = None ,
328+ tuning_evaluation_spec : Optional ["TuningEvaluationSpec" ] = None ,
329+ ) -> "_LanguageModelTuningJob" :
330+ """Tunes a model based on training data.
331+
332+ This method launches and returns an asynchronous model tuning job.
333+ Usage:
334+ ```
335+ tuning_job = model.tune_model(...)
336+ ... do some other work
337+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
338+ ```
339+ Args:
340+ training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
341+ The dataset schema is model-specific.
342+ See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
343+ train_steps: Number of training batches to tune on (batch size is 8 samples).
344+ learning_rate_multiplier: Learning rate multiplier to use in tuning.
345+ tuning_job_location: GCP location where the tuning job should be run.
346+ Only "europe-west4" and "us-central1" locations are supported for now.
347+ tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
348+ model_display_name: Custom display name for the tuned model.
349+ tuning_evaluation_spec: Specification for the model evaluation during tuning.
350+
351+ Returns:
352+ A `LanguageModelTuningJob` object that represents the tuning job.
353+ Calling `job.get_tuned_model()` blocks until the tuning is complete and returns a `LanguageModel` object.
354+
355+ Raises:
356+ ValueError: If the "tuning_job_location" value is not supported
357+ ValueError: If the "tuned_model_location" value is not supported
358+ RuntimeError: If the model does not support tuning
359+ """
360+ # Note: Text models do not support default_context
361+ return super ().tune_model (
362+ training_data = training_data ,
363+ train_steps = train_steps ,
364+ learning_rate_multiplier = learning_rate_multiplier ,
365+ tuning_job_location = tuning_job_location ,
366+ tuned_model_location = tuned_model_location ,
367+ model_display_name = model_display_name ,
368+ tuning_evaluation_spec = tuning_evaluation_spec ,
369+ )
370+
371+
372+ class _PreviewTunableTextModelMixin (_TunableModelMixin ):
373+ """Text model that can be tuned."""
374+
316375 def tune_model (
317376 self ,
318377 training_data : Union [str , "pandas.core.frame.DataFrame" ],
@@ -324,10 +383,20 @@ def tune_model(
324383 tuned_model_location : Optional [str ] = None ,
325384 model_display_name : Optional [str ] = None ,
326385 tuning_evaluation_spec : Optional ["TuningEvaluationSpec" ] = None ,
327- ):
386+ ) -> "_LanguageModelTuningJob" :
328387 """Tunes a model based on training data.
329388
330- This method launches a model tuning job that can take some time.
389+ This method launches a model tuning job, waits for completion,
390+ and updates the model in-place. It returns the job object for forward
391+ compatibility.
392+ In the future (GA), this method will become asynchronous and will stop
393+ updating the model in-place.
394+
395+ Usage:
396+ ```
397+ tuning_job = model.tune_model(...) # Blocks until tuning is complete
398+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
399+ ```
331400
332401 Args:
333402 training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -353,7 +422,7 @@ def tune_model(
353422 RuntimeError: If the model does not support tuning
354423 """
355424 # Note: Text models do not support default_context
356- return super ().tune_model (
425+ job = super ().tune_model (
357426 training_data = training_data ,
358427 train_steps = train_steps ,
359428 learning_rate = learning_rate ,
@@ -363,11 +432,74 @@ def tune_model(
363432 model_display_name = model_display_name ,
364433 tuning_evaluation_spec = tuning_evaluation_spec ,
365434 )
435+ tuned_model = job .get_tuned_model ()
436+ self ._endpoint = tuned_model ._endpoint
437+ self ._endpoint_name = tuned_model ._endpoint_name
438+ return job
366439
367440
368441class _TunableChatModelMixin (_TunableModelMixin ):
369442 """Chat model that can be tuned."""
370443
444+ def tune_model (
445+ self ,
446+ training_data : Union [str , "pandas.core.frame.DataFrame" ],
447+ * ,
448+ train_steps : Optional [int ] = None ,
449+ learning_rate_multiplier : Optional [float ] = None ,
450+ tuning_job_location : Optional [str ] = None ,
451+ tuned_model_location : Optional [str ] = None ,
452+ model_display_name : Optional [str ] = None ,
453+ default_context : Optional [str ] = None ,
454+ ) -> "_LanguageModelTuningJob" :
455+ """Tunes a model based on training data.
456+
457+ This method launches and returns an asynchronous model tuning job.
458+ Usage:
459+ ```
460+ tuning_job = model.tune_model(...)
461+ ... do some other work
462+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
463+ ```
464+
465+ Args:
466+ training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
467+ The dataset schema is model-specific.
468+ See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
469+ train_steps: Number of training batches to tune on (batch size is 8 samples).
470+ learning_rate_multiplier: Learning rate multiplier to use in tuning.
471+ Replaces the deprecated `learning_rate` argument, which this
472+ method does not accept.
473+ tuning_job_location: GCP location where the tuning job should be run.
474+ Only "europe-west4" and "us-central1" locations are supported for now.
475+ tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
476+ model_display_name: Custom display name for the tuned model.
477+ default_context: The context to use for all training samples by default.
478+
479+ Returns:
480+ A `LanguageModelTuningJob` object that represents the tuning job.
481+ Calling `job.get_tuned_model()` blocks until the tuning is complete and returns a `LanguageModel` object.
482+
483+ Raises:
484+ ValueError: If the "tuning_job_location" value is not supported
485+ ValueError: If the "tuned_model_location" value is not supported
486+ RuntimeError: If the model does not support tuning
487+ """
488+ # Note: Chat models do not support tuning_evaluation_spec
489+ return super ().tune_model (
490+ training_data = training_data ,
491+ train_steps = train_steps ,
492+ learning_rate_multiplier = learning_rate_multiplier ,
493+ tuning_job_location = tuning_job_location ,
494+ tuned_model_location = tuned_model_location ,
495+ model_display_name = model_display_name ,
496+ default_context = default_context ,
497+ )
498+
499+
500+ class _PreviewTunableChatModelMixin (_TunableModelMixin ):
501+ """Chat model that can be tuned."""
502+
371503 def tune_model (
372504 self ,
373505 training_data : Union [str , "pandas.core.frame.DataFrame" ],
@@ -379,10 +511,20 @@ def tune_model(
379511 tuned_model_location : Optional [str ] = None ,
380512 model_display_name : Optional [str ] = None ,
381513 default_context : Optional [str ] = None ,
382- ):
514+ ) -> "_LanguageModelTuningJob" :
383515 """Tunes a model based on training data.
384516
385- This method launches a model tuning job that can take some time.
517+ This method launches a model tuning job, waits for completion,
518+ and updates the model in-place. It returns the job object for forward
519+ compatibility.
520+ In the future (GA), this method will become asynchronous and will stop
521+ updating the model in-place.
522+
523+ Usage:
524+ ```
525+ tuning_job = model.tune_model(...) # Blocks until tuning is complete
526+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
527+ ```
386528
387529 Args:
388530 training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -408,7 +550,7 @@ def tune_model(
408550 RuntimeError: If the model does not support tuning
409551 """
410552 # Note: Chat models do not support tuning_evaluation_spec
411- return super ().tune_model (
553+ job = super ().tune_model (
412554 training_data = training_data ,
413555 train_steps = train_steps ,
414556 learning_rate = learning_rate ,
@@ -418,6 +560,10 @@ def tune_model(
418560 model_display_name = model_display_name ,
419561 default_context = default_context ,
420562 )
563+ tuned_model = job .get_tuned_model ()
564+ self ._endpoint = tuned_model ._endpoint
565+ self ._endpoint_name = tuned_model ._endpoint_name
566+ return job
421567
422568
423569@dataclasses .dataclass
@@ -746,7 +892,7 @@ class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):
746892
747893class _PreviewTextGenerationModel (
748894 _TextGenerationModel ,
749- _TunableTextModelMixin ,
895+ _PreviewTunableTextModelMixin ,
750896 _PreviewModelWithBatchPredict ,
751897 _evaluatable_language_models ._EvaluatableLanguageModel ,
752898):
@@ -1076,7 +1222,7 @@ class ChatModel(_ChatModelBase):
10761222 _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"
10771223
10781224
1079- class _PreviewChatModel (ChatModel , _TunableChatModelMixin ):
1225+ class _PreviewChatModel (ChatModel , _PreviewTunableChatModelMixin ):
10801226 _LAUNCH_STAGE = _model_garden_models ._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
10811227
10821228
@@ -1650,11 +1796,12 @@ def __init__(
16501796 base_model : _LanguageModel ,
16511797 job : aiplatform .PipelineJob ,
16521798 ):
1799+ """Internal constructor. Do not call directly."""
16531800 self ._base_model = base_model
16541801 self ._job = job
16551802 self ._model : Optional [_LanguageModel ] = None
16561803
1657- def result (self ) -> "_LanguageModel" :
1804+ def get_tuned_model (self ) -> "_LanguageModel" :
16581805 """Blocks until the tuning is complete and returns a `LanguageModel` object."""
16591806 if self ._model :
16601807 return self ._model
@@ -1681,11 +1828,12 @@ def result(self) -> "_LanguageModel":
16811828 return self ._model
16821829
16831830 @property
1684- def status (self ):
1685- """Job status"""
1831+ def _status (self ) -> Optional [ aiplatform_types . pipeline_state . PipelineState ] :
1832+ """Job status. """
16861833 return self ._job .state
16871834
1688- def cancel (self ):
1835+ def _cancel (self ):
1836+ """Cancels the job."""
16891837 self ._job .cancel ()
16901838
16911839
0 commit comments