Skip to content

Commit 67097eb

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add Create Skill method for Vertex AI Skill Registry SDK
PiperOrigin-RevId: 911930424
1 parent f5909b2 commit 67097eb

5 files changed

Lines changed: 923 additions & 0 deletions

File tree

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Tests the skills.create() method against the Vertex AI endpoint using replays."""
16+
17+
import io
18+
import zipfile
19+
20+
from tests.unit.vertexai.genai.replays import pytest_helper
21+
from vertexai._genai import types
22+
23+
# MANDATORY: Initialize the replay test framework for this module
24+
pytestmark = pytest_helper.setup(
25+
file=__file__,
26+
globals_for_file=globals(),
27+
)
28+
29+
30+
def test_create_skill(client, tmp_path):
31+
client._api_client._http_options.base_url = (
32+
"https://us-central1-aiplatform.googleapis.com"
33+
)
34+
35+
# Create a dummy skill structure (SKILL.md is required by the spec)
36+
with open(tmp_path / "SKILL.md", "w") as f:
37+
f.write("# My Replay Skill\nThis is a test skill for replay tests.")
38+
39+
skill = client.skills.create(
40+
display_name="My Replay Skill",
41+
description="My Replay Skill Description",
42+
config=types.CreateSkillConfig(
43+
local_path=str(tmp_path), wait_for_completion=True
44+
),
45+
)
46+
47+
assert skill.name is not None
48+
assert skill.display_name == "My Replay Skill"
49+
assert skill.description == "My Replay Skill Description"
50+
51+
52+
def test_create_skill_with_prezipped_bytes(client):
53+
"""Tests the creation of a skill with pre-zipped bytes."""
54+
client._api_client._http_options.base_url = (
55+
"https://us-central1-aiplatform.googleapis.com"
56+
)
57+
58+
zip_buffer = io.BytesIO()
59+
zinfo = zipfile.ZipInfo("SKILL.md", date_time=(1980, 1, 1, 0, 0, 0))
60+
with zipfile.ZipFile(zip_buffer, "w") as zip_file:
61+
zip_file.writestr(zinfo, "# My Zipped Replay Skill\nThis is a test.")
62+
zipped_bytes = zip_buffer.getvalue()
63+
64+
skill = client.skills.create(
65+
display_name="My Zipped Replay Skill",
66+
description="My Zipped Replay Skill Description",
67+
config=types.CreateSkillConfig(
68+
zipped_filesystem=zipped_bytes, wait_for_completion=True
69+
),
70+
)
71+
72+
assert skill.name is not None
73+
assert skill.display_name == "My Zipped Replay Skill"
74+
assert skill.description == "My Zipped Replay Skill Description"

vertexai/_genai/_skills_utils.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Utility functions for Skills."""
16+
17+
import asyncio
18+
import base64
19+
import datetime
20+
import io
21+
import os
22+
import pathlib
23+
import time
24+
from typing import Any, Awaitable, Callable
25+
import zipfile
26+
27+
28+
def zip_directory(directory_path: pathlib.Path | str) -> bytes:
29+
"""Zips a directory into memory and returns the bytes.
30+
31+
Args:
32+
directory_path (pathlib.Path | str): Required. The local path to the
33+
directory.
34+
35+
Returns:
36+
bytes: The zipped directory content.
37+
"""
38+
directory_str = os.fspath(directory_path)
39+
if not os.path.isdir(directory_str):
40+
raise ValueError(f"Path is not a directory: {directory_str}")
41+
42+
zip_buffer = io.BytesIO()
43+
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
44+
for root, _, files in os.walk(directory_str):
45+
for file in files:
46+
file_path = os.path.join(root, file)
47+
arcname = os.path.relpath(file_path, directory_str)
48+
49+
# Read actual file data
50+
with open(file_path, "rb") as f:
51+
file_data = f.read()
52+
53+
# Use deterministic ZipInfo (mtime: 1980-01-01 00:00:00)
54+
zinfo = zipfile.ZipInfo(arcname, date_time=(1980, 1, 1, 0, 0, 0))
55+
zinfo.compress_type = zipfile.ZIP_DEFLATED
56+
zinfo.external_attr = 0o644 << 16 # Constant file permissions
57+
58+
zip_file.writestr(zinfo, file_data)
59+
return zip_buffer.getvalue()
60+
61+
62+
def get_zipped_filesystem_payload(directory_path: pathlib.Path | str) -> str:
63+
"""Zips a directory and base64-encodes the result to a UTF-8 string.
64+
65+
Args:
66+
directory_path (pathlib.Path | str): Required. The local path to the
67+
directory.
68+
69+
Returns:
70+
str: The base64-encoded zipped directory.
71+
"""
72+
zip_bytes = zip_directory(directory_path)
73+
return base64.b64encode(zip_bytes).decode("utf-8")
74+
75+
76+
def await_operation(
77+
*,
78+
operation_name: str,
79+
get_operation_fn: Callable[..., Any],
80+
poll_interval: datetime.timedelta | float = 10.0,
81+
timeout_seconds: float = 300.0,
82+
) -> Any:
83+
"""Waits for a long running operation to complete.
84+
85+
Args:
86+
operation_name (str): Required. The name of the operation.
87+
get_operation_fn (Callable): Required. Function to get the operation
88+
status.
89+
poll_interval (datetime.timedelta | float): The interval between polls.
90+
timeout_seconds (float): The maximum wait duration in seconds.
91+
92+
Returns:
93+
Any: The completed operation.
94+
"""
95+
if isinstance(poll_interval, datetime.timedelta):
96+
poll_seconds = poll_interval.total_seconds()
97+
else:
98+
poll_seconds = float(poll_interval)
99+
100+
start_time = time.time()
101+
operation = get_operation_fn(operation_name=operation_name)
102+
while not operation.done:
103+
if (time.time() - start_time) > timeout_seconds:
104+
raise TimeoutError(
105+
f"Operation {operation_name} did not complete within the timeout "
106+
f"of {timeout_seconds} seconds."
107+
)
108+
time.sleep(poll_seconds)
109+
operation = get_operation_fn(operation_name=operation.name)
110+
return operation
111+
112+
113+
async def await_operation_async(
114+
*,
115+
operation_name: str,
116+
get_operation_fn: Callable[..., Awaitable[Any]],
117+
poll_interval: datetime.timedelta | float = 10.0,
118+
timeout_seconds: float = 300.0,
119+
) -> Any:
120+
"""Waits for a long running operation to complete asynchronously.
121+
122+
Args:
123+
operation_name (str): Required. The name of the operation.
124+
get_operation_fn (Callable): Required. Async function to get the operation
125+
status.
126+
poll_interval (datetime.timedelta | float): The interval between polls.
127+
timeout_seconds (float): The maximum wait duration in seconds.
128+
129+
Returns:
130+
Any: The completed operation.
131+
"""
132+
if isinstance(poll_interval, datetime.timedelta):
133+
poll_seconds = poll_interval.total_seconds()
134+
else:
135+
poll_seconds = float(poll_interval)
136+
137+
start_time = time.time()
138+
operation = await get_operation_fn(operation_name=operation_name)
139+
while not operation.done:
140+
if (time.time() - start_time) > timeout_seconds:
141+
raise TimeoutError(
142+
f"Operation {operation_name} did not complete within the timeout "
143+
f"of {timeout_seconds} seconds."
144+
)
145+
await asyncio.sleep(poll_seconds)
146+
operation = await get_operation_fn(operation_name=operation.name)
147+
return operation

0 commit comments

Comments
 (0)