Skip to content

Commit 1971711

Browse files
Add init_runtime_native_unified credential provider with backwards compatibility (#1380)
## What are we changing Introduces a new unified runtime auth provider that returns (host, account_id, workspace_id, inner) instead of (host, inner). Falls back to the existing providers when the new import is unavailable or returns None. ## How is this tested? * New unit tests * Run existing runtime integration tests. NOTE: The new interface cannot be tested with integration tests until Runtime adds support, but the SDK needs support before Runtime. So it is a chicken and egg problem. Existing integration tests validate that we are not breaking existing users by falling back to the previous auth system.
1 parent b16b649 commit 1971711

3 files changed

Lines changed: 108 additions & 0 deletions

File tree

databricks/sdk/credentials_provider.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,21 @@ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
160160
# runtime and no config variables are set.
161161
from databricks.sdk.runtime import (init_runtime_legacy_auth,
162162
init_runtime_native_auth,
163+
init_runtime_native_unified,
163164
init_runtime_repl_auth)
164165

166+
# Try the unified provider first (returns host, account_id, workspace_id, inner).
167+
if init_runtime_native_unified is not None:
168+
host, account_id, workspace_id, inner = init_runtime_native_unified()
169+
if host is not None:
170+
cfg.host = host
171+
cfg.account_id = account_id
172+
cfg.workspace_id = workspace_id
173+
logger.debug("[init_runtime_native_unified] runtime native auth configured")
174+
return inner
175+
logger.debug("[init_runtime_native_unified] no host detected")
176+
177+
# Fall back to legacy providers (return host, inner).
165178
for init in [
166179
init_runtime_native_auth,
167180
init_runtime_repl_auth,

databricks/sdk/runtime/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@
2323
]
2424

2525
# DO NOT MOVE THE TRY-CATCH BLOCK BELOW AND DO NOT ADD THINGS BEFORE IT! WILL MAKE TEST FAIL.
26+
try:
27+
from dbruntime.sdk_credential_provider import init_runtime_native_unified
28+
29+
logger.debug("runtime SDK credential provider (unified) available")
30+
except ImportError:
31+
init_runtime_native_unified = None
32+
2633
try:
2734
# We don't want to expose additional entity to user namespace, so
2835
# a workaround here for exposing required information in notebook environment
@@ -34,6 +41,7 @@
3441
init_runtime_native_auth = None
3542

3643
globals()["init_runtime_native_auth"] = init_runtime_native_auth
44+
globals()["init_runtime_native_unified"] = init_runtime_native_unified
3745

3846

3947
def init_runtime_repl_auth():

tests/test_notebook_oauth.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def fake_init_runtime_repl_auth():
4343
pass
4444

4545
fake_runtime.init_runtime_native_auth = fake_init_runtime_native_auth
46+
fake_runtime.init_runtime_native_unified = None
4647
fake_runtime.init_runtime_legacy_auth = fake_init_runtime_legacy_auth
4748
fake_runtime.init_runtime_repl_auth = fake_init_runtime_repl_auth
4849

@@ -187,3 +188,89 @@ def test_workspace_client_integration(
187188
assert w.config.scopes == expected_scopes
188189
headers = w.config.authenticate()
189190
assert headers["Authorization"] == "Bearer exchanged-oauth-token"
191+
192+
193+
@pytest.fixture
194+
def mock_runtime_native_unified():
195+
"""Mock the runtime module with init_runtime_native_unified returning 4-tuple."""
196+
fake_runtime = types.ModuleType("databricks.sdk.runtime")
197+
198+
def fake_init_runtime_native_unified():
199+
def inner():
200+
return {"Authorization": "Bearer unified-token"}
201+
202+
return "https://unified.cloud.databricks.com", "acc-123", "ws-456", inner
203+
204+
fake_runtime.init_runtime_native_unified = fake_init_runtime_native_unified
205+
fake_runtime.init_runtime_native_auth = None
206+
fake_runtime.init_runtime_legacy_auth = None
207+
fake_runtime.init_runtime_repl_auth = None
208+
209+
sys.modules["databricks.sdk.runtime"] = fake_runtime
210+
yield
211+
212+
213+
@pytest.fixture
214+
def mock_runtime_native_unified_returns_none():
215+
"""Mock the runtime module with init_runtime_native_unified returning None host."""
216+
fake_runtime = types.ModuleType("databricks.sdk.runtime")
217+
218+
def fake_init_runtime_native_unified():
219+
return None, None, None, None
220+
221+
def fake_init_runtime_native_auth():
222+
def inner():
223+
return {"Authorization": "Bearer fallback-token"}
224+
225+
return "https://fallback.cloud.databricks.com", inner
226+
227+
fake_runtime.init_runtime_native_unified = fake_init_runtime_native_unified
228+
fake_runtime.init_runtime_native_auth = fake_init_runtime_native_auth
229+
fake_runtime.init_runtime_legacy_auth = None
230+
fake_runtime.init_runtime_repl_auth = None
231+
232+
sys.modules["databricks.sdk.runtime"] = fake_runtime
233+
yield
234+
235+
236+
def test_runtime_unified_auth_sets_host_and_ids(mock_runtime_env, mock_runtime_native_unified):
237+
"""Test that init_runtime_native_unified sets host, account_id, and workspace_id on Config."""
238+
cfg = Config(host="https://unified.cloud.databricks.com")
239+
240+
headers = cfg.authenticate()
241+
assert headers["Authorization"] == "Bearer unified-token"
242+
assert cfg.host == "https://unified.cloud.databricks.com"
243+
assert cfg.account_id == "acc-123"
244+
assert cfg.workspace_id == "ws-456"
245+
246+
247+
def test_runtime_unified_auth_fallback_when_none(mock_runtime_env, mock_runtime_native_unified_returns_none):
248+
"""Test fallback to init_runtime_native_auth when unified returns None."""
249+
cfg = Config(host="https://fallback.cloud.databricks.com")
250+
251+
headers = cfg.authenticate()
252+
assert headers["Authorization"] == "Bearer fallback-token"
253+
assert cfg.host == "https://fallback.cloud.databricks.com"
254+
255+
256+
def test_runtime_unified_auth_fallback_when_not_available(mock_runtime_env, mock_runtime_native_auth):
257+
"""Test fallback to init_runtime_native_auth when unified is None (import failed)."""
258+
cfg = Config(host="https://test.cloud.databricks.com")
259+
260+
headers = cfg.authenticate()
261+
assert headers["Authorization"] == "Bearer test-notebook-pat-token"
262+
assert cfg.host == "https://test.cloud.databricks.com"
263+
264+
265+
def test_runtime_unified_auth_priority_over_native(mock_runtime_env, mock_runtime_native_unified):
266+
"""Test that unified provider is used over native auth in DefaultCredentials chain."""
267+
cfg = Config(host="https://unified.cloud.databricks.com")
268+
269+
default_creds = DefaultCredentials()
270+
creds_provider = default_creds(cfg)
271+
272+
headers = creds_provider()
273+
assert headers["Authorization"] == "Bearer unified-token"
274+
assert default_creds.auth_type() == "runtime"
275+
assert cfg.account_id == "acc-123"
276+
assert cfg.workspace_id == "ws-456"

0 commit comments

Comments
 (0)