Skip to content

Commit df7672c

Browse files
committed
initial changes
1 parent 1253e7e commit df7672c

5 files changed

Lines changed: 56 additions & 0 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#### New Features
3939

4040
- Added `artifact_repository` support to `udtf_configs` in `session.read.dbapi()`, enabling users to specify a custom artifact repository (e.g. PyPI) for packages used by the internal UDTF during distributed ingestion.
41+
- Added `get_wif_token` to `snowflake.snowpark.secrets` for workload identity federation tokens on the Snowflake server (not available in SPCS file-based secret environments).
4142

4243
#### Bug Fixes
4344

docs/source/snowpark/secrets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ Snowpark Secrets
2222
get_secret_type
2323
get_username_password
2424
get_cloud_provider_token
25+
get_wif_token

src/snowflake/snowpark/secrets.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"get_secret_type",
2121
"get_username_password",
2222
"get_cloud_provider_token",
23+
"get_wif_token",
2324
"UsernamePassword",
2425
"CloudProviderToken",
2526
]
@@ -61,6 +62,10 @@ def get_username_password(self, secret_name: str) -> UsernamePassword:
6162
def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken:
6263
pass
6364

65+
@abstractmethod
66+
def get_wif_token(self, secret_name: str, audience: str) -> str:
67+
pass
68+
6469

6570
class _SnowflakeSecretsServer(_SnowflakeSecrets):
6671
"""Secret instance for Snowflake server environment (using _snowflake module)."""
@@ -89,6 +94,9 @@ def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken:
8994
secret_object.token,
9095
)
9196

97+
def get_wif_token(self, secret_name: str, audience: str) -> str:
98+
return self._snowflake.get_wif_token(secret_name, audience)
99+
92100

93101
class _SnowflakeSecretsSPCS(_SnowflakeSecrets):
94102
"""Secret instance for SPCS container environment (file-based secrets)."""
@@ -173,6 +181,11 @@ def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken:
173181
"Cloud provider token secrets are not supported in SPCS container environments."
174182
)
175183

184+
def get_wif_token(self, secret_name: str, audience: str) -> str:
185+
raise NotImplementedError(
186+
"WIF token secrets are not supported in SPCS container environments."
187+
)
188+
176189

177190
def _is_spcs_environment() -> bool:
178191
return os.getenv(_SCLS_SPCS_SECRET_ENV_NAME, None) is not None
@@ -259,3 +272,29 @@ def get_cloud_provider_token(secret_name: str) -> CloudProviderToken:
259272
NotImplementedError: If running outside Snowflake server environment.
260273
"""
261274
return _get_secrets_instance().get_cloud_provider_token(secret_name)
275+
276+
277+
def get_wif_token(secret_name: str, audience: str) -> str:
278+
"""Get a workload identity federation (WIF) token from Snowflake.
279+
280+
Note:
281+
Requires a Snowflake environment with a WIF secret configured and an
282+
external access integration that allows the UDF or stored procedure to
283+
use that secret. The ``audience`` must match the token audience expected
284+
by the external system (for example, an OAuth token endpoint URL).
285+
286+
Args:
287+
secret_name: The secret reference name bound to the WIF secret.
288+
audience: The intended audience (``aud``) for the issued token.
289+
290+
Returns:
291+
The issued token as a string (typically a JWT).
292+
293+
Raises:
294+
NotImplementedError: If running outside the Snowflake server environment
295+
(including SPCS file-based secret environments, where WIF tokens cannot
296+
be minted).
297+
ValueError: If the secret does not exist or is not authorized (when
298+
applicable in supported environments).
299+
"""
300+
return _get_secrets_instance().get_wif_token(secret_name, audience)

tests/integ/test_secrets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
get_secret_type,
1010
get_cloud_provider_token,
1111
get_oauth_access_token,
12+
get_wif_token,
1213
)
1314
from snowflake.snowpark.types import BooleanType, StringType
1415
from tests.utils import IS_NOT_ON_GITHUB, RUNNING_ON_JENKINS, IS_IN_STORED_PROC, Utils
@@ -169,3 +170,5 @@ def test_secrets_import_error():
169170
get_cloud_provider_token("c1")
170171
with pytest.raises(NotImplementedError):
171172
get_oauth_access_token("o1")
173+
with pytest.raises(NotImplementedError):
174+
get_wif_token("w1", "https://audience")

tests/unit/test_secrets.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
get_secret_type,
1313
get_username_password,
1414
get_cloud_provider_token,
15+
get_wif_token,
1516
UsernamePassword,
1617
CloudProviderToken,
1718
_SCLS_SPCS_SECRET_ENV_NAME,
@@ -31,6 +32,7 @@ def _build_fake_snowflake_module() -> object:
3132
get_secret_type=lambda secret_name: "PASSWORD",
3233
get_username_password=lambda secret_name: fake_username_password,
3334
get_cloud_provider_token=lambda secret_name: fake_cloud_token,
35+
get_wif_token=lambda secret_name, audience: f"wif:{secret_name}:{audience}",
3436
)
3537

3638

@@ -52,6 +54,11 @@ def test_secrets_mock_server_paths():
5254
assert cloud.secret_access_key == "SECRET_TEST"
5355
assert cloud.token == "STS_TOKEN_TEST"
5456

57+
assert (
58+
get_wif_token("w1", "https://example.com/aud")
59+
== "wif:w1:https://example.com/aud"
60+
)
61+
5562

5663
@pytest.fixture
5764
def scls_spcs_mock_env(tmp_path):
@@ -135,6 +142,9 @@ def test_secrets_mock_scls_spcs_error_cases(scls_spcs_mock_env):
135142
with pytest.raises(NotImplementedError):
136143
get_cloud_provider_token("any_secret")
137144

145+
with pytest.raises(NotImplementedError):
146+
get_wif_token("any_secret", "https://audience")
147+
138148
with pytest.raises(ValueError, match="Unknown secret type"):
139149
get_secret_type("unknown_secret")
140150

@@ -159,6 +169,8 @@ def test_secrets_import_error_paths():
159169
get_username_password("p1")
160170
with pytest.raises(NotImplementedError):
161171
get_cloud_provider_token("c1")
172+
with pytest.raises(NotImplementedError):
173+
get_wif_token("w1", "https://audience")
162174
finally:
163175
if original_env is not None:
164176
os.environ[_SCLS_SPCS_SECRET_ENV_NAME] = original_env

0 commit comments

Comments
 (0)