Skip to content

Commit dd9129f

Browse files
SNOW-3338458: Snowpark python support for WIF secret in EA (#4202)
1 parent dff6633 commit dd9129f

6 files changed

Lines changed: 154 additions & 0 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
### Snowpark Python API Updates
66

7+
#### New Features
8+
9+
- Added `get_wif_token` to `snowflake.snowpark.secrets` for workload identity federation tokens on the Snowflake server (not available in SPCS file-based secret environments).
10+
711
#### Documentation
812

913
- Clarified that the JDBC driver JAR referenced via `udtf_configs.imports` in `DataFrameReader.jdbc()` must be downloaded from the database vendor and uploaded to a Snowflake stage.

docs/source/snowpark/secrets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ Snowpark Secrets
2222
get_secret_type
2323
get_username_password
2424
get_cloud_provider_token
25+
get_wif_token

src/snowflake/snowpark/secrets.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"get_secret_type",
2121
"get_username_password",
2222
"get_cloud_provider_token",
23+
"get_wif_token",
2324
"UsernamePassword",
2425
"CloudProviderToken",
2526
]
@@ -61,6 +62,10 @@ def get_username_password(self, secret_name: str) -> UsernamePassword:
6162
def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken:
6263
pass
6364

65+
@abstractmethod
66+
def get_wif_token(self, secret_name: str, audience: str) -> str:
67+
pass # pragma: no cover
68+
6469

6570
class _SnowflakeSecretsServer(_SnowflakeSecrets):
6671
"""Secret instance for Snowflake server environment (using _snowflake module)."""
@@ -89,6 +94,9 @@ def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken:
8994
secret_object.token,
9095
)
9196

97+
def get_wif_token(self, secret_name: str, audience: str) -> str:
98+
return self._snowflake.get_wif_token(secret_name, audience)
99+
92100

93101
class _SnowflakeSecretsSPCS(_SnowflakeSecrets):
94102
"""Secret instance for SPCS container environment (file-based secrets)."""
@@ -173,6 +181,11 @@ def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken:
173181
"Cloud provider token secrets are not supported in SPCS container environments."
174182
)
175183

184+
def get_wif_token(self, secret_name: str, audience: str) -> str:
185+
raise NotImplementedError(
186+
"WIF token secrets are not supported in SPCS container environments."
187+
)
188+
176189

177190
def _is_spcs_environment() -> bool:
178191
return os.getenv(_SCLS_SPCS_SECRET_ENV_NAME, None) is not None
@@ -259,3 +272,29 @@ def get_cloud_provider_token(secret_name: str) -> CloudProviderToken:
259272
NotImplementedError: If running outside Snowflake server environment.
260273
"""
261274
return _get_secrets_instance().get_cloud_provider_token(secret_name)
275+
276+
277+
def get_wif_token(secret_name: str, audience: str) -> str:
278+
"""Get a workload identity federation (WIF) token from Snowflake.
279+
280+
Note:
281+
Requires a Snowflake environment with a WIF secret configured and an
282+
external access integration that allows the UDF or stored procedure to
283+
use that secret. The ``audience`` must match the token audience expected
284+
by the external system (for example, an OAuth token endpoint URL).
285+
286+
Args:
287+
secret_name: The secret reference name bound to the WIF secret.
288+
audience: The intended audience (``aud``) for the issued token.
289+
290+
Returns:
291+
The issued token as a string (typically a JWT).
292+
293+
Raises:
294+
NotImplementedError: If running outside the Snowflake server environment
295+
(including SPCS file-based secret environments, where WIF tokens cannot
296+
be minted).
297+
ValueError: If the secret does not exist or is not authorized (when
298+
applicable in supported environments).
299+
"""
300+
return _get_secrets_instance().get_wif_token(secret_name, audience)

tests/integ/conftest.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def set_up_external_access_integration_resources(
7474
integration1,
7575
integration2,
7676
integration3,
77+
key4,
78+
integration4,
79+
wif_audience,
7780
):
7881
try:
7982
# IMPORTANT SETUP NOTES: the test role needs to be granted the creation privilege
@@ -128,6 +131,12 @@ def set_up_external_access_integration_resources(
128131
).collect()
129132
session.sql(
130133
f"""
134+
CREATE SECRET IF NOT EXISTS {key4}
135+
TYPE = WORKLOAD_IDENTITY_FEDERATION;
136+
"""
137+
).collect()
138+
session.sql(
139+
f"""
131140
CREATE IF NOT EXISTS EXTERNAL ACCESS INTEGRATION {integration1}
132141
ALLOWED_NETWORK_RULES = ({rule1})
133142
ALLOWED_AUTHENTICATION_SECRETS = ({key1})
@@ -148,6 +157,14 @@ def set_up_external_access_integration_resources(
148157
ALLOWED_NETWORK_RULES = ({rule3})
149158
ALLOWED_AUTHENTICATION_SECRETS = ({key3})
150159
ENABLED = true;
160+
"""
161+
).collect()
162+
session.sql(
163+
f"""
164+
CREATE EXTERNAL ACCESS INTEGRATION IF NOT EXISTS {integration4}
165+
ALLOWED_NETWORK_RULES = ({rule1})
166+
ALLOWED_AUTHENTICATION_SECRETS = ({key4})
167+
ENABLED = true;
151168
"""
152169
).collect()
153170
CONNECTION_PARAMETERS["external_access_rule1"] = rule1
@@ -156,9 +173,12 @@ def set_up_external_access_integration_resources(
156173
CONNECTION_PARAMETERS["external_access_key1"] = key1
157174
CONNECTION_PARAMETERS["external_access_key2"] = key2
158175
CONNECTION_PARAMETERS["external_access_key3"] = key3
176+
CONNECTION_PARAMETERS["external_access_key4"] = key4
159177
CONNECTION_PARAMETERS["external_access_integration1"] = integration1
160178
CONNECTION_PARAMETERS["external_access_integration2"] = integration2
161179
CONNECTION_PARAMETERS["external_access_integration3"] = integration3
180+
CONNECTION_PARAMETERS["external_access_integration4"] = integration4
181+
CONNECTION_PARAMETERS["wif_audience"] = wif_audience
162182
except SnowparkSQLException:
163183
# GCP currently does not support external access integration
164184
# we can remove the exception once the integration is available on GCP
@@ -184,9 +204,12 @@ def clean_up_external_access_integration_resources():
184204
CONNECTION_PARAMETERS.pop("external_access_key1", None)
185205
CONNECTION_PARAMETERS.pop("external_access_key2", None)
186206
CONNECTION_PARAMETERS.pop("external_access_key3", None)
207+
CONNECTION_PARAMETERS.pop("external_access_key4", None)
187208
CONNECTION_PARAMETERS.pop("external_access_integration1", None)
188209
CONNECTION_PARAMETERS.pop("external_access_integration2", None)
189210
CONNECTION_PARAMETERS.pop("external_access_integration3", None)
211+
CONNECTION_PARAMETERS.pop("external_access_integration4", None)
212+
CONNECTION_PARAMETERS.pop("wif_audience", None)
190213

191214

192215
def set_up_dataframe_processor_parameters(
@@ -315,9 +338,12 @@ def session(
315338
key1 = "snowpark_python_test_key1"
316339
key2 = "snowpark_python_test_key2"
317340
key3 = "snowpark_python_test_key3"
341+
key4 = "snowpark_python_test_key4"
318342
integration1 = "snowpark_python_test_integration1"
319343
integration2 = "snowpark_python_test_integration2"
320344
integration3 = "snowpark_python_test_integration3"
345+
integration4 = "snowpark_python_test_integration4"
346+
wif_audience = "https://replace-with-your-wif-audience"
321347

322348
session = (
323349
Session.builder.configs(db_parameters)
@@ -351,6 +377,9 @@ def session(
351377
integration1,
352378
integration2,
353379
integration3,
380+
key4,
381+
integration4,
382+
wif_audience,
354383
)
355384

356385
if validate_ast:
@@ -387,9 +416,12 @@ def profiler_session(
387416
key1 = "snowpark_python_profiler_test_key1"
388417
key2 = "snowpark_python_profiler_test_key2"
389418
key3 = "snowpark_python_profiler_test_key3"
419+
key4 = "snowpark_python_profiler_test_key4"
390420
integration1 = "snowpark_python_profiler_test_integration1"
391421
integration2 = "snowpark_python_profiler_test_integration2"
392422
integration3 = "snowpark_python_profiler_test_integration3"
423+
integration4 = "snowpark_python_profiler_test_integration4"
424+
wif_audience = "https://replace-with-your-wif-audience"
393425
session = (
394426
Session.builder.configs(db_parameters)
395427
.config("local_testing", local_testing_mode)
@@ -409,6 +441,9 @@ def profiler_session(
409441
integration1,
410442
integration2,
411443
integration3,
444+
key4,
445+
integration4,
446+
wif_audience,
412447
)
413448
set_up_test_session_parameters(session, local_testing_mode)
414449
try:

tests/integ/test_secrets.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
get_secret_type,
1010
get_cloud_provider_token,
1111
get_oauth_access_token,
12+
get_wif_token,
1213
)
1314
from snowflake.snowpark.types import BooleanType, StringType
1415
from tests.utils import IS_NOT_ON_GITHUB, RUNNING_ON_JENKINS, IS_IN_STORED_PROC, Utils
@@ -152,6 +153,66 @@ def get_secret():
152153
)
153154

154155

156+
@pytest.mark.skipif(
157+
IS_NOT_ON_GITHUB or not RUNNING_ON_JENKINS,
158+
reason="Secret API is only supported on Snowflake server environment",
159+
)
160+
def test_get_wif_token_udf(session, db_parameters):
161+
try:
162+
wif_audience = db_parameters["wif_audience"]
163+
164+
def get_wif():
165+
return get_wif_token("cred", wif_audience)
166+
167+
get_wif_udf = session.udf.register(
168+
get_wif,
169+
return_type=StringType(),
170+
packages=["snowflake-snowpark-python"],
171+
external_access_integrations=[
172+
db_parameters["external_access_integration4"]
173+
],
174+
secrets={"cred": f"{db_parameters['external_access_key4']}"},
175+
)
176+
df = session.create_dataframe([[1], [2]]).to_df("x")
177+
rows = df.select(get_wif_udf()).collect()
178+
for row in rows:
179+
token = row[0]
180+
assert (
181+
isinstance(token, str) and len(token.split(".")) == 3
182+
), f"expected JWT-shaped token (header.payload.signature), got {token!r}"
183+
except KeyError:
184+
pytest.skip("External Access Integration is not supported on the deployment.")
185+
186+
187+
@pytest.mark.skipif(
188+
IS_NOT_ON_GITHUB or not RUNNING_ON_JENKINS,
189+
reason="Secret API is only supported on Snowflake server environment",
190+
)
191+
def test_get_wif_token_sproc(session, db_parameters):
192+
try:
193+
wif_audience = db_parameters["wif_audience"]
194+
195+
def get_wif_in_sproc(session_):
196+
return get_wif_token("cred", wif_audience)
197+
198+
get_wif_sp = session.sproc.register(
199+
get_wif_in_sproc,
200+
return_type=StringType(),
201+
packages=["snowflake-snowpark-python"],
202+
external_access_integrations=[
203+
db_parameters["external_access_integration4"]
204+
],
205+
secrets={"cred": f"{db_parameters['external_access_key4']}"},
206+
anonymous=True,
207+
)
208+
token = get_wif_sp()
209+
assert (
210+
isinstance(token, str) and len(token.split(".")) == 3
211+
), f"expected JWT-shaped token (header.payload.signature), got {token!r}"
212+
except KeyError:
213+
pytest.skip("External Access Integration is not supported on the deployment.")
214+
215+
155216
@pytest.mark.skipif(
156217
IS_IN_STORED_PROC,
157218
reason="Run only outside Snowflake server to validate NotImplementedError",
@@ -169,3 +230,5 @@ def test_secrets_import_error():
169230
get_cloud_provider_token("c1")
170231
with pytest.raises(NotImplementedError):
171232
get_oauth_access_token("o1")
233+
with pytest.raises(NotImplementedError):
234+
get_wif_token("w1", "https://audience")

tests/unit/test_secrets.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
get_secret_type,
1313
get_username_password,
1414
get_cloud_provider_token,
15+
get_wif_token,
1516
UsernamePassword,
1617
CloudProviderToken,
1718
_SCLS_SPCS_SECRET_ENV_NAME,
@@ -31,6 +32,7 @@ def _build_fake_snowflake_module() -> object:
3132
get_secret_type=lambda secret_name: "PASSWORD",
3233
get_username_password=lambda secret_name: fake_username_password,
3334
get_cloud_provider_token=lambda secret_name: fake_cloud_token,
35+
get_wif_token=lambda secret_name, audience: f"wif:{secret_name}:{audience}",
3436
)
3537

3638

@@ -52,6 +54,11 @@ def test_secrets_mock_server_paths():
5254
assert cloud.secret_access_key == "SECRET_TEST"
5355
assert cloud.token == "STS_TOKEN_TEST"
5456

57+
assert (
58+
get_wif_token("w1", "https://example.com/aud")
59+
== "wif:w1:https://example.com/aud"
60+
)
61+
5562

5663
@pytest.fixture
5764
def scls_spcs_mock_env(tmp_path):
@@ -135,6 +142,9 @@ def test_secrets_mock_scls_spcs_error_cases(scls_spcs_mock_env):
135142
with pytest.raises(NotImplementedError):
136143
get_cloud_provider_token("any_secret")
137144

145+
with pytest.raises(NotImplementedError):
146+
get_wif_token("any_secret", "https://audience")
147+
138148
with pytest.raises(ValueError, match="Unknown secret type"):
139149
get_secret_type("unknown_secret")
140150

@@ -159,6 +169,8 @@ def test_secrets_import_error_paths():
159169
get_username_password("p1")
160170
with pytest.raises(NotImplementedError):
161171
get_cloud_provider_token("c1")
172+
with pytest.raises(NotImplementedError):
173+
get_wif_token("w1", "https://audience")
162174
finally:
163175
if original_env is not None:
164176
os.environ[_SCLS_SPCS_SECRET_ENV_NAME] = original_env

0 commit comments

Comments
 (0)