Skip to content

Commit f302d1f

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: In run_query_job, rename gcs_bucket to output_gcs_uri and allow the user to set the filename for the output.
PiperOrigin-RevId: 893135854
1 parent 7a8f703 commit f302d1f

File tree

3 files changed

+147
-69
lines changed

3 files changed

+147
-69
lines changed

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 102 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2948,7 +2948,7 @@ def test_run_query_job_agent_engine(self, mock_uuid, get_mock, mock_storage_clie
29482948
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
29492949
config={
29502950
"query": _TEST_QUERY_PROMPT,
2951-
"gcs_bucket": "gs://my-input-bucket/",
2951+
"output_gcs_uri": "gs://my-input-bucket/",
29522952
},
29532953
)
29542954

@@ -2959,17 +2959,17 @@ def test_run_query_job_agent_engine(self, mock_uuid, get_mock, mock_storage_clie
29592959

29602960
assert result == _genai_types.RunQueryJobResult(
29612961
job_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
2962-
input_gcs_uri="gs://my-input-bucket/input_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
2963-
output_gcs_uri="gs://my-input-bucket/output_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
2962+
input_gcs_uri="gs://my-input-bucket/b92b9b89-4585-4146-8ee5-22fe99802a8e_input.json",
2963+
output_gcs_uri="gs://my-input-bucket/b92b9b89-4585-4146-8ee5-22fe99802a8e_output.json",
29642964
)
29652965

29662966
request_mock.assert_called_with(
29672967
"post",
29682968
f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:asyncQuery",
29692969
{
29702970
"_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME},
2971-
"inputGcsUri": "gs://my-input-bucket/input_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
2972-
"outputGcsUri": "gs://my-input-bucket/output_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
2971+
"inputGcsUri": "gs://my-input-bucket/b92b9b89-4585-4146-8ee5-22fe99802a8e_input.json",
2972+
"outputGcsUri": "gs://my-input-bucket/b92b9b89-4585-4146-8ee5-22fe99802a8e_output.json",
29732973
},
29742974
None,
29752975
)
@@ -2980,38 +2980,18 @@ def test_run_query_job_agent_engine_missing_query(self):
29802980
):
29812981
self.client.agent_engines.run_query_job(
29822982
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2983-
config={"gcs_bucket": "gs://my-input-bucket/"},
2983+
config={"output_gcs_uri": "gs://my-input-bucket/"},
29842984
)
29852985

2986-
def test_run_query_job_agent_engine_missing_bucket(self):
2986+
def test_run_query_job_agent_engine_missing_uri(self):
29872987
with pytest.raises(
2988-
ValueError, match="`gcs_bucket` is required in the config object."
2988+
ValueError, match="`output_gcs_uri` is required in the config object."
29892989
):
29902990
self.client.agent_engines.run_query_job(
29912991
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
29922992
config={"query": _TEST_QUERY_PROMPT},
29932993
)
29942994

2995-
@mock.patch.object(agent_engines.AgentEngines, "_get")
2996-
def test_run_query_job_agent_engine_missing_cloud_run_job(self, get_mock):
2997-
get_mock.return_value = _genai_types.ReasoningEngine(
2998-
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2999-
spec=_genai_types.ReasoningEngineSpec(
3000-
deployment_spec=_genai_types.ReasoningEngineSpecDeploymentSpec(env=[])
3001-
),
3002-
)
3003-
with pytest.raises(
3004-
ValueError,
3005-
match="Your ReasoningEngine does not support long running queries, please update your ReasoningEngine and try again.",
3006-
):
3007-
self.client.agent_engines.run_query_job(
3008-
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
3009-
config={
3010-
"query": _TEST_QUERY_PROMPT,
3011-
"gcs_bucket": "gs://my-input-bucket/",
3012-
},
3013-
)
3014-
30152995
@mock.patch("google.cloud.storage.Client")
30162996
@mock.patch.object(agent_engines.AgentEngines, "_get")
30172997
@mock.patch("uuid.uuid4")
@@ -3053,10 +3033,103 @@ def test_run_query_job_agent_engine_bucket_creation_forbidden(
30533033
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
30543034
config={
30553035
"query": _TEST_QUERY_PROMPT,
3056-
"gcs_bucket": "gs://my-input-bucket/",
3036+
"output_gcs_uri": "gs://my-input-bucket/",
30573037
},
30583038
)
30593039

3040+
@mock.patch("google.cloud.storage.Client")
3041+
@mock.patch.object(agent_engines.AgentEngines, "_get")
3042+
@mock.patch("uuid.uuid4")
3043+
def test_run_query_job_agent_engine_file_uri(
3044+
self, mock_uuid, get_mock, mock_storage_client
3045+
):
3046+
with mock.patch.object(
3047+
self.client.agent_engines._api_client, "request"
3048+
) as request_mock:
3049+
request_mock.return_value = genai_types.HttpResponse(
3050+
body='{"name": "projects/123/locations/us-central1/reasoningEngines/456/operations/789"}'
3051+
)
3052+
3053+
mock_bucket = mock.Mock()
3054+
mock_bucket.exists.return_value = True
3055+
mock_blob = mock.Mock()
3056+
mock_bucket.blob.return_value = mock_blob
3057+
mock_storage_client.return_value.bucket.return_value = mock_bucket
3058+
3059+
get_mock.return_value = _genai_types.ReasoningEngine(
3060+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
3061+
spec=_genai_types.ReasoningEngineSpec(
3062+
deployment_spec=_genai_types.ReasoningEngineSpecDeploymentSpec(
3063+
env=[_genai_types.EnvVar(name="input_gcs_uri", value="")]
3064+
)
3065+
),
3066+
)
3067+
3068+
result = self.client.agent_engines.run_query_job(
3069+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
3070+
config={
3071+
"query": _TEST_QUERY_PROMPT,
3072+
"output_gcs_uri": "gs://my-input-bucket/path/output.json",
3073+
},
3074+
)
3075+
3076+
mock_blob.upload_from_string.assert_called_once_with(_TEST_QUERY_PROMPT)
3077+
mock_bucket.blob.assert_called_with("path/output_input.json")
3078+
3079+
assert result == _genai_types.RunQueryJobResult(
3080+
job_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
3081+
input_gcs_uri="gs://my-input-bucket/path/output_input.json",
3082+
output_gcs_uri="gs://my-input-bucket/path/output.json",
3083+
)
3084+
3085+
@mock.patch("google.cloud.storage.Client")
3086+
@mock.patch.object(agent_engines.AgentEngines, "_get")
3087+
@mock.patch("uuid.uuid4")
3088+
def test_run_query_job_agent_engine_directory_no_slash(
3089+
self, mock_uuid, get_mock, mock_storage_client
3090+
):
3091+
with mock.patch.object(
3092+
self.client.agent_engines._api_client, "request"
3093+
) as request_mock:
3094+
request_mock.return_value = genai_types.HttpResponse(
3095+
body='{"name": "projects/123/locations/us-central1/reasoningEngines/456/operations/789"}'
3096+
)
3097+
3098+
mock_bucket = mock.Mock()
3099+
mock_bucket.exists.return_value = True
3100+
mock_blob = mock.Mock()
3101+
mock_bucket.blob.return_value = mock_blob
3102+
mock_storage_client.return_value.bucket.return_value = mock_bucket
3103+
3104+
mock_uuid.return_value.hex = "b92b9b89-4585-4146-8ee5-22fe99802a8e"
3105+
3106+
get_mock.return_value = _genai_types.ReasoningEngine(
3107+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
3108+
spec=_genai_types.ReasoningEngineSpec(
3109+
deployment_spec=_genai_types.ReasoningEngineSpecDeploymentSpec(
3110+
env=[_genai_types.EnvVar(name="input_gcs_uri", value="")]
3111+
)
3112+
),
3113+
)
3114+
3115+
result = self.client.agent_engines.run_query_job(
3116+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
3117+
config={
3118+
"query": _TEST_QUERY_PROMPT,
3119+
"output_gcs_uri": "gs://my-input-bucket/path",
3120+
},
3121+
)
3122+
3123+
mock_bucket.blob.assert_called_with(
3124+
"path/b92b9b89-4585-4146-8ee5-22fe99802a8e_input.json"
3125+
)
3126+
3127+
assert result == _genai_types.RunQueryJobResult(
3128+
job_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
3129+
input_gcs_uri="gs://my-input-bucket/path/b92b9b89-4585-4146-8ee5-22fe99802a8e_input.json",
3130+
output_gcs_uri="gs://my-input-bucket/path/b92b9b89-4585-4146-8ee5-22fe99802a8e_output.json",
3131+
)
3132+
30603133
def test_query_agent_engine_async(self):
30613134
agent = self.client.agent_engines._register_api_methods(
30623135
agent_engine=_genai_types.AgentEngine(

vertexai/_genai/agent_engines.py

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,7 @@ def run_query_job(
10771077
the default configuration will be used. This can be used to specify
10781078
the following fields:
10791079
- query: The query to send to the agent engine.
1080-
- gcs_bucket: The GCS bucket path to use for the query.
1080+
- output_gcs_uri: The GCS URI to use for the output.
10811081
"""
10821082
from google.cloud import storage # type: ignore[attr-defined]
10831083
from google.api_core import exceptions
@@ -1090,41 +1090,40 @@ def run_query_job(
10901090

10911091
if not config.query:
10921092
raise ValueError("`query` is required in the config object.")
1093-
if not config.gcs_bucket:
1094-
raise ValueError("`gcs_bucket` is required in the config object.")
1095-
1096-
api_resource = self._get(name=name)
1097-
1098-
is_supported = False
1099-
if (
1100-
api_resource.spec
1101-
and api_resource.spec.deployment_spec
1102-
and api_resource.spec.deployment_spec.env
1103-
):
1104-
for env in api_resource.spec.deployment_spec.env:
1105-
if env.name in [
1106-
"INPUT_GCS_URI",
1107-
"OUTPUT_GCS_URI",
1108-
"input_gcs_uri",
1109-
"output_gcs_uri",
1110-
]:
1111-
is_supported = True
1112-
break
1113-
1114-
if not is_supported:
1115-
raise ValueError(
1116-
"Your ReasoningEngine does not support long running queries, "
1117-
"please update your ReasoningEngine and try again."
1118-
)
1119-
1120-
gcs_bucket = config.gcs_bucket.rstrip("/")
1093+
if not config.output_gcs_uri:
1094+
raise ValueError("`output_gcs_uri` is required in the config object.")
1095+
1096+
output_gcs_uri = config.output_gcs_uri
1097+
is_file = False
1098+
last_part = ""
1099+
if not output_gcs_uri.endswith("/"):
1100+
last_part = output_gcs_uri.split("/")[-1]
1101+
if "." in last_part:
1102+
is_file = True
1103+
1104+
if is_file:
1105+
path_parts = output_gcs_uri.split("/")
1106+
file_name = path_parts[-1]
1107+
base_uri = "/".join(path_parts[:-1])
1108+
name_parts = file_name.rsplit(".", 1)
1109+
if len(name_parts) == 2:
1110+
name_part, ext = name_parts[0], "." + name_parts[1]
1111+
else:
1112+
name_part = name_parts[0]
1113+
ext = ""
1114+
input_gcs_uri = f"{base_uri}/{name_part}_input{ext}"
1115+
else:
1116+
job_uuid = uuid.uuid4().hex
1117+
gcs_path = output_gcs_uri.rstrip("/")
1118+
input_gcs_uri = f"{gcs_path}/{job_uuid}_input.json"
1119+
output_gcs_uri = f"{gcs_path}/{job_uuid}_output.json"
11211120

11221121
storage_client = storage.Client(
11231122
project=self._api_client.project, credentials=self._api_client._credentials
11241123
)
11251124

11261125
# Handle creating the bucket if it does not exist
1127-
bucket_name = gcs_bucket.replace("gs://", "").split("/")[0]
1126+
bucket_name = config.output_gcs_uri.replace("gs://", "").split("/")[0]
11281127
bucket = storage_client.bucket(bucket_name)
11291128

11301129
try:
@@ -1144,15 +1143,10 @@ def run_query_job(
11441143
"The service account may lack 'storage.buckets.create' permission."
11451144
) from e
11461145

1147-
job_uuid = uuid.uuid4().hex
1148-
input_blob_name = f"input_{job_uuid}.json"
1149-
input_gcs_uri = f"{gcs_bucket}/{input_blob_name}"
1146+
input_blob_name = input_gcs_uri.replace(f"gs://{bucket_name}/", "")
11501147
blob = bucket.blob(input_blob_name)
11511148
blob.upload_from_string(config.query)
11521149

1153-
output_blob_name = f"output_{job_uuid}.json"
1154-
output_gcs_uri = f"{gcs_bucket}/{output_blob_name}"
1155-
11561150
new_config = types._RunQueryJobAgentEngineConfig(
11571151
input_gcs_uri=input_gcs_uri,
11581152
output_gcs_uri=output_gcs_uri,

vertexai/_genai/types/common.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15736,8 +15736,14 @@ class RunQueryJobAgentEngineConfig(_common.BaseModel):
1573615736
query: Optional[str] = Field(
1573715737
default=None, description="""The query to send to the agent engine."""
1573815738
)
15739-
gcs_bucket: Optional[str] = Field(
15740-
default=None, description="""The GCS bucket to use for the query."""
15739+
output_gcs_uri: Optional[str] = Field(
15740+
default=None,
15741+
description="""The GCS URI to use for the output.
15742+
If it is a file, the system uses this file to store the response.
15743+
If it represents a directory, the system automatically generates a file
15744+
for the response.
15745+
In both cases, the input query will be stored in the same directory under
15746+
the same file name prefix as the output file.""",
1574115747
)
1574215748

1574315749

@@ -15750,8 +15756,13 @@ class RunQueryJobAgentEngineConfigDict(TypedDict, total=False):
1575015756
query: Optional[str]
1575115757
"""The query to send to the agent engine."""
1575215758

15753-
gcs_bucket: Optional[str]
15754-
"""The GCS bucket to use for the query."""
15759+
output_gcs_uri: Optional[str]
15760+
"""The GCS URI to use for the output.
15761+
If it is a file, the system uses this file to store the response.
15762+
If it represents a directory, the system automatically generates a file
15763+
for the response.
15764+
In both cases, the input query will be stored in the same directory under
15765+
the same file name prefix as the output file."""
1575515766

1575615767

1575715768
RunQueryJobAgentEngineConfigOrDict = Union[

0 commit comments

Comments
 (0)