Skip to content

Commit ff47513

Browse files
Ark-kuncopybara-github
authored andcommitted
chore: LLM - Switched to a more robust way to get the tuned model resource name from the pipeline job
PiperOrigin-RevId: 553633219
1 parent be01f31 commit ff47513

2 files changed

Lines changed: 13 additions & 11 deletions

File tree

tests/unit/aiplatform/test_language_models.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,12 +360,14 @@ def make_pipeline_job(state):
360360
task_details=[
361361
gca_pipeline_job.PipelineTaskDetail(
362362
task_id=456,
363-
task_name="upload-llm-model",
363+
task_name="tune-large-model-20230724214903",
364364
execution=GapicExecution(
365-
name="test-execution-name",
366-
display_name="evaluation_metrics",
365+
name="projects/123/locations/europe-west4/metadataStores/default/executions/...",
366+
display_name="tune-large-model-20230724214903",
367+
schema_title="system.Run",
367368
metadata={
368-
"output:model_resource_name": "projects/123/locations/us-central1/models/456"
369+
"output:model_resource_name": "projects/123/locations/us-central1/models/456",
370+
"output:endpoint_resource_name": "projects/123/locations/us-central1/endpoints/456",
369371
},
370372
),
371373
),

vertexai/language_models/_language_models.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,19 +1037,19 @@ def result(self) -> "_LanguageModel":
10371037
if self._model:
10381038
return self._model
10391039
self._job.wait()
1040-
upload_model_tasks = [
1041-
task_info
1042-
for task_info in self._job.gca_resource.job_detail.task_details
1043-
if task_info.task_name == "upload-llm-model"
1040+
root_pipeline_tasks = [
1041+
task_detail
1042+
for task_detail in self._job.gca_resource.job_detail.task_details
1043+
if task_detail.execution.schema_title == "system.Run"
10441044
]
1045-
if len(upload_model_tasks) != 1:
1045+
if len(root_pipeline_tasks) != 1:
10461046
raise RuntimeError(
10471047
f"Failed to get the model name from the tuning pipeline: {self._job.name}"
10481048
)
1049-
upload_model_task = upload_model_tasks[0]
1049+
root_pipeline_task = root_pipeline_tasks[0]
10501050

10511051
# Trying to get model name from output parameter
1052-
vertex_model_name = upload_model_task.execution.metadata[
1052+
vertex_model_name = root_pipeline_task.execution.metadata[
10531053
"output:model_resource_name"
10541054
].strip()
10551055
_LOGGER.info(f"Tuning has completed. Created Vertex Model: {vertex_model_name}")

0 commit comments

Comments
 (0)