From 07ad7b6db8408c18cc56415edb3d74d51fad75eb Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Thu, 24 Apr 2025 17:06:59 -0700 Subject: [PATCH 01/11] OAuthClientCredentialsService --- .../Auth/OAuthClientCredentialsService.cs | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs diff --git a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs new file mode 100644 index 0000000000..66a2f6b63e --- /dev/null +++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs @@ -0,0 +1,145 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth +{ + /// + /// Service for obtaining OAuth access tokens using the client credentials grant type. + /// + internal class OAuthClientCredentialsService + { + private readonly HttpClient _httpClient; + private readonly string _clientId; + private readonly string _clientSecret; + private readonly Uri _baseUri; + private readonly string? _tenantId; + private readonly string _scope; + private readonly string _tokenEndpoint; + + /// + /// Initializes a new instance of the class. + /// + /// The HTTP client to use for token requests. + /// The OAuth client ID. + /// The OAuth client secret. + /// The base URI of the Databricks workspace. + /// The Azure AD tenant ID. Required for Azure Databricks. + /// The OAuth scope to request. Default is "all-apis". + public OAuthClientCredentialsService( + HttpClient httpClient, + string clientId, + string clientSecret, + Uri baseUri, + string? tenantId = null, + string scope = "all-apis") + { + _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); + _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret)); + _baseUri = baseUri ?? throw new ArgumentNullException(nameof(baseUri)); + _tenantId = tenantId; + _scope = scope ?? "all-apis"; + _tokenEndpoint = DetermineTokenEndpoint(); + } + + private string DetermineTokenEndpoint() + { + string host = _baseUri.Host.ToLowerInvariant(); + if (host.Contains("azuredatabricks.net")) + { + if (string.IsNullOrEmpty(_tenantId)) + { + throw new ArgumentException("Azure Databricks requires a tenantId to determine the token endpoint."); + } + + return $"https://login.microsoftonline.com/{_tenantId}/oauth2/v2.0/token"; + } + else + { + // Applies to AWS and GCP (if using Databricks-hosted IdP) + return "https://accounts.cloud.databricks.com/oidc/token"; + } + } + + /// + /// Gets an OAuth access token using the client credentials grant type. + /// + /// A cancellation token to cancel the operation. + /// The access token. + /// Thrown when the token request fails or the response is invalid. + public async Task GetAccessTokenAsync(CancellationToken cancellationToken = default) + { + var requestContent = new FormUrlEncodedContent(new[] + { + new KeyValuePair("grant_type", "client_credentials"), + new KeyValuePair("client_id", _clientId), + new KeyValuePair("client_secret", _clientSecret), + new KeyValuePair("scope", _scope) + }); + + var request = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint) + { + Content = requestContent + }; + + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + + HttpResponseMessage response; + try + { + response = await _httpClient.SendAsync(request, cancellationToken); + response.EnsureSuccessStatusCode(); + } + catch (HttpRequestException ex) + { + throw new DatabricksException($"Failed to acquire OAuth access token: {ex.Message}", ex); + } + + string content = await response.Content.ReadAsStringAsync(); + + try + { + using var jsonDoc = JsonDocument.Parse(content); + + if (!jsonDoc.RootElement.TryGetProperty("access_token", out var accessTokenElement)) + { + throw new DatabricksException("OAuth response did not contain an access_token"); + } + + string? accessToken = accessTokenElement.GetString(); + if (string.IsNullOrEmpty(accessToken)) + { + throw new DatabricksException("OAuth access_token was null or empty"); + } + + return accessToken!; + } + catch (JsonException ex) + { + throw new DatabricksException($"Failed to parse OAuth response: {ex.Message}", ex); + } + } + } +} \ No newline at end of file From 136af431d0f90c310d30e3bd93004333000e9e74 Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Thu, 24 Apr 2025 22:29:46 -0700 Subject: [PATCH 02/11] OAuthClientCredentialsService --- .../Auth/OAuthClientCredentialsService.cs | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs index 66a2f6b63e..a42a11f2e1 100644 --- a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs +++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs @@ -30,40 +30,50 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth /// internal class OAuthClientCredentialsService { - private readonly HttpClient _httpClient; + private readonly Lazy _httpClient; private readonly string _clientId; private readonly string _clientSecret; private readonly Uri _baseUri; private readonly string? _tenantId; private readonly string _scope; private readonly string _tokenEndpoint; + private readonly int _timeoutMinutes; /// /// Initializes a new instance of the class. /// - /// The HTTP client to use for token requests. /// The OAuth client ID. /// The OAuth client secret. /// The base URI of the Databricks workspace. /// The Azure AD tenant ID. Required for Azure Databricks. /// The OAuth scope to request. Default is "all-apis". + /// The timeout in minutes for HTTP requests. Default is 5 minutes. public OAuthClientCredentialsService( - HttpClient httpClient, string clientId, string clientSecret, Uri baseUri, string? tenantId = null, - string scope = "all-apis") + string scope = "all-apis", + int timeoutMinutes = 5) { - _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret)); _baseUri = baseUri ?? throw new ArgumentNullException(nameof(baseUri)); _tenantId = tenantId; _scope = scope ?? "all-apis"; + _timeoutMinutes = timeoutMinutes; _tokenEndpoint = DetermineTokenEndpoint(); + + _httpClient = new Lazy(() => + { + var client = new HttpClient(); + client.Timeout = TimeSpan.FromMinutes(_timeoutMinutes); + return client; + }); } + private HttpClient HttpClient => _httpClient.Value; + private string DetermineTokenEndpoint() { string host = _baseUri.Host.ToLowerInvariant(); @@ -109,7 +119,7 @@ public async Task GetAccessTokenAsync(CancellationToken cancellationToke HttpResponseMessage response; try { - response = await _httpClient.SendAsync(request, cancellationToken); + response = await HttpClient.SendAsync(request, cancellationToken); response.EnsureSuccessStatusCode(); } catch (HttpRequestException ex) From 405bc54d6cd8335ba4909e40003f430b638af570 Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Fri, 25 Apr 2025 12:36:45 -0700 Subject: [PATCH 03/11] token refresh + test files --- .../Auth/OAuthClientCredentialsService.cs | 192 ++++++++++++------ .../OAuthClientCredentialsServiceTests.cs | 60 ++++++ 2 files changed, 194 insertions(+), 58 deletions(-) create mode 100644 csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs diff --git a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs index a42a11f2e1..bb05c0d52e 100644 --- a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs +++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs @@ -28,16 +28,27 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth /// /// Service for obtaining OAuth access tokens using the client credentials grant type. /// - internal class OAuthClientCredentialsService + internal class OAuthClientCredentialsService : IDisposable { private readonly Lazy _httpClient; private readonly string _clientId; private readonly string _clientSecret; private readonly Uri _baseUri; - private readonly string? _tenantId; - private readonly string _scope; private readonly string _tokenEndpoint; private readonly int _timeoutMinutes; + private readonly SemaphoreSlim _tokenLock = new SemaphoreSlim(1, 1); + private TokenInfo? _cachedToken; + + private class TokenInfo + { + public string? AccessToken { get; set; } + public DateTime ExpiresAt { get; set; } + + public bool IsExpired => DateTime.UtcNow >= ExpiresAt; + + // Add buffer time to refresh token before actual expiration + public bool NeedsRefresh => DateTime.UtcNow >= ExpiresAt.AddMinutes(-5); + } /// /// Initializes a new instance of the class. @@ -45,68 +56,80 @@ internal class OAuthClientCredentialsService /// The OAuth client ID. /// The OAuth client secret. /// The base URI of the Databricks workspace. - /// The Azure AD tenant ID. Required for Azure Databricks. - /// The OAuth scope to request. Default is "all-apis". - /// The timeout in minutes for HTTP requests. Default is 5 minutes. public OAuthClientCredentialsService( string clientId, string clientSecret, Uri baseUri, - string? tenantId = null, - string scope = "all-apis", - int timeoutMinutes = 5) + int timeoutMinutes = 1, + HttpClient? httpClient = null) { _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret)); _baseUri = baseUri ?? throw new ArgumentNullException(nameof(baseUri)); - _tenantId = tenantId; - _scope = scope ?? "all-apis"; _timeoutMinutes = timeoutMinutes; _tokenEndpoint = DetermineTokenEndpoint(); - _httpClient = new Lazy(() => - { - var client = new HttpClient(); - client.Timeout = TimeSpan.FromMinutes(_timeoutMinutes); - return client; - }); + _httpClient = httpClient != null + ? new Lazy(() => httpClient) + : new Lazy(() => + { + var client = new HttpClient(); + client.Timeout = TimeSpan.FromMinutes(_timeoutMinutes); + return client; + }); } private HttpClient HttpClient => _httpClient.Value; private string DetermineTokenEndpoint() { - string host = _baseUri.Host.ToLowerInvariant(); - if (host.Contains("azuredatabricks.net")) + // For workspace URLs, the token endpoint is always /oidc/v1/token + // TODO: Might be different for Azure AAD SPs + return $"{_baseUri.Scheme}://{_baseUri.Host}/oidc/v1/token"; + } + + private string? GetValidCachedToken() + { + return _cachedToken != null && !_cachedToken.NeedsRefresh && _cachedToken.AccessToken != null + ? _cachedToken.AccessToken + : null; + } + + + private async Task RefreshTokenInternalAsync(CancellationToken cancellationToken) + { + var request = CreateTokenRequest(); + + HttpResponseMessage response; + try { - if (string.IsNullOrEmpty(_tenantId)) - { - throw new ArgumentException("Azure Databricks requires a tenantId to determine the token endpoint."); - } + response = await HttpClient.SendAsync(request, cancellationToken); + response.EnsureSuccessStatusCode(); + } + catch (HttpRequestException ex) + { + throw new DatabricksException($"Failed to acquire OAuth access token: {ex.Message}", ex); + } - return $"https://login.microsoftonline.com/{_tenantId}/oauth2/v2.0/token"; + string content = await response.Content.ReadAsStringAsync(); + + try + { + _cachedToken = ParseTokenResponse(content); + return _cachedToken.AccessToken!; } - else + catch (JsonException ex) { - // Applies to AWS and GCP (if using Databricks-hosted IdP) - return "https://accounts.cloud.databricks.com/oidc/token"; + throw new DatabricksException($"Failed to parse OAuth response: {ex.Message}", ex); } } - /// - /// Gets an OAuth access token using the client credentials grant type. - /// - /// A cancellation token to cancel the operation. - /// The access token. - /// Thrown when the token request fails or the response is invalid. - public async Task GetAccessTokenAsync(CancellationToken cancellationToken = default) + private HttpRequestMessage CreateTokenRequest() { var requestContent = new FormUrlEncodedContent(new[] { new KeyValuePair("grant_type", "client_credentials"), - new KeyValuePair("client_id", _clientId), - new KeyValuePair("client_secret", _clientSecret), - new KeyValuePair("scope", _scope) + new KeyValuePair("scope", "sql") }); var request = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint) @@ -114,42 +137,95 @@ public async Task GetAccessTokenAsync(CancellationToken cancellationToke Content = requestContent }; + // Use Basic Auth with client ID and secret + var authHeader = Convert.ToBase64String( + System.Text.Encoding.ASCII.GetBytes($"{_clientId}:{_clientSecret}")); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", authHeader); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); - HttpResponseMessage response; - try + return request; + } + + private TokenInfo ParseTokenResponse(string content) + { + using var jsonDoc = JsonDocument.Parse(content); + + if (!jsonDoc.RootElement.TryGetProperty("access_token", out var accessTokenElement)) { - response = await HttpClient.SendAsync(request, cancellationToken); - response.EnsureSuccessStatusCode(); + throw new DatabricksException("OAuth response did not contain an access_token"); } - catch (HttpRequestException ex) + + string? accessToken = accessTokenElement.GetString(); + if (string.IsNullOrEmpty(accessToken)) { - throw new DatabricksException($"Failed to acquire OAuth access token: {ex.Message}", ex); + throw new DatabricksException("OAuth access_token was null or empty"); } - string content = await response.Content.ReadAsStringAsync(); - - try + // Get expiration time from response + if (!jsonDoc.RootElement.TryGetProperty("expires_in", out var expiresInElement)) { - using var jsonDoc = JsonDocument.Parse(content); - - if (!jsonDoc.RootElement.TryGetProperty("access_token", out var accessTokenElement)) - { - throw new DatabricksException("OAuth response did not contain an access_token"); - } + throw new DatabricksException("OAuth response did not contain expires_in"); + } + + int expiresIn = expiresInElement.GetInt32(); + if (expiresIn <= 0) + { + throw new DatabricksException("OAuth expires_in value must be positive"); + } + + return new TokenInfo + { + AccessToken = accessToken!, + ExpiresAt = DateTime.UtcNow.AddSeconds(expiresIn) + }; + } + + /// + /// Gets an OAuth access token using the client credentials grant type. + /// + /// A cancellation token to cancel the operation. + /// The access token. + /// Thrown when the token request fails or the response is invalid. + public async Task GetAccessTokenAsync(CancellationToken cancellationToken = default) + { + // First try to get cached token without acquiring lock + if (GetValidCachedToken() is string cachedToken) + { + return cachedToken; + } + + // If token needs refresh, acquire lock with timeout + var lockTimeout = TimeSpan.FromSeconds(30); // Reasonable timeout for lock acquisition + if (!await _tokenLock.WaitAsync(lockTimeout, cancellationToken).ConfigureAwait(false)) + { + throw new TimeoutException("Timeout waiting for token refresh lock"); + } - string? accessToken = accessTokenElement.GetString(); - if (string.IsNullOrEmpty(accessToken)) + try + { + // Double-check pattern in case another thread refreshed while we were waiting + if (GetValidCachedToken() is string refreshedToken) { - throw new DatabricksException("OAuth access_token was null or empty"); + return refreshedToken; } - return accessToken!; + return await RefreshTokenInternalAsync(cancellationToken).ConfigureAwait(false); } - catch (JsonException ex) + finally { - throw new DatabricksException($"Failed to parse OAuth response: {ex.Message}", ex); + _tokenLock.Release(); } } + + + public void Dispose() + { + _tokenLock.Dispose(); + if (_httpClient.IsValueCreated) + { + HttpClient.Dispose(); + } + } + } } \ No newline at end of file diff --git a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs new file mode 100644 index 0000000000..3de5c7274c --- /dev/null +++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs @@ -0,0 +1,60 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Databricks.Auth; +using Apache.Arrow.Adbc.Tests; +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth +{ + public class OAuthClientCredentialsServiceTests : TestBase, IDisposable + { + private readonly OAuthClientCredentialsService _service; + + public OAuthClientCredentialsServiceTests(ITestOutputHelper? outputHelper) + : base(outputHelper, new SparkTestEnvironment.Factory()) + { + _service = new OAuthClientCredentialsService( + TestConfiguration.Username, + TestConfiguration.Password, + new Uri(TestConfiguration.Uri), + timeoutMinutes: 1); + } + + [Fact] + public async Task GetAccessToken_WithValidCredentials_ReturnsToken() + { + var token = await _service.GetAccessTokenAsync(CancellationToken.None); + + Assert.NotNull(token); + Assert.NotEmpty(token); + } + + [Fact] + public async Task GetAccessToken_WithCancellation_ThrowsOperationCanceledException() + { + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAsync(() => + _service.GetAccessTokenAsync(cts.Token)); + } + + [Fact] + public async Task GetAccessToken_MultipleCalls_ReusesCachedToken() + { + var token1 = await _service.GetAccessTokenAsync(CancellationToken.None); + var token2 = await _service.GetAccessTokenAsync(CancellationToken.None); + + Assert.Equal(token1, token2); + } + + void IDisposable.Dispose() + { + _service.Dispose(); + base.Dispose(); + } + } +} \ No newline at end of file From 3d2c2b76a3b6c9d901751ea3a0dbf4da1e04fa46 Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Fri, 25 Apr 2025 15:53:37 -0700 Subject: [PATCH 04/11] Test fixes --- .../OAuthClientCredentialsServiceTests.cs | 19 ++++++++++++------- .../Databricks/DatabricksTestConfiguration.cs | 5 +++++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs index 3de5c7274c..8697633d6d 100644 --- a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs +++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs @@ -2,23 +2,22 @@ using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Databricks.Auth; -using Apache.Arrow.Adbc.Tests; -using Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark; +using Apache.Arrow.Adbc.Tests.Drivers.Databricks; using Xunit; using Xunit.Abstractions; namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth { - public class OAuthClientCredentialsServiceTests : TestBase, IDisposable + public class OAuthClientCredentialsServiceTests : TestBase, IDisposable { private readonly OAuthClientCredentialsService _service; public OAuthClientCredentialsServiceTests(ITestOutputHelper? outputHelper) - : base(outputHelper, new SparkTestEnvironment.Factory()) + : base(outputHelper, new DatabricksTestEnvironment.Factory()) { _service = new OAuthClientCredentialsService( - TestConfiguration.Username, - TestConfiguration.Password, + TestConfiguration.OAuthClientId, + TestConfiguration.OAuthClientSecret, new Uri(TestConfiguration.Uri), timeoutMinutes: 1); } @@ -26,6 +25,8 @@ public OAuthClientCredentialsServiceTests(ITestOutputHelper? outputHelper) [Fact] public async Task GetAccessToken_WithValidCredentials_ReturnsToken() { + Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); + var token = await _service.GetAccessTokenAsync(CancellationToken.None); Assert.NotNull(token); @@ -35,16 +36,20 @@ public async Task GetAccessToken_WithValidCredentials_ReturnsToken() [Fact] public async Task GetAccessToken_WithCancellation_ThrowsOperationCanceledException() { + Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); + using var cts = new CancellationTokenSource(); cts.Cancel(); - await Assert.ThrowsAsync(() => + await Assert.ThrowsAsync(() => _service.GetAccessTokenAsync(cts.Token)); } [Fact] public async Task GetAccessToken_MultipleCalls_ReusesCachedToken() { + Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); + var token1 = await _service.GetAccessTokenAsync(CancellationToken.None); var token2 = await _service.GetAccessTokenAsync(CancellationToken.None); diff --git a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs index fb221560b7..0f0366c200 100644 --- a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs +++ b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs @@ -15,12 +15,17 @@ * limitations under the License. */ +using System.Text.Json.Serialization; using Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark; namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks { public class DatabricksTestConfiguration : SparkTestConfiguration { + [JsonPropertyName("oauth_client_id"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string OAuthClientId { get; set; } = string.Empty; + [JsonPropertyName("oauth_client_secret"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string OAuthClientSecret { get; set; } = string.Empty; } } From d4edad77c009c2fb8d87533d603ead91965444c2 Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Fri, 25 Apr 2025 15:57:13 -0700 Subject: [PATCH 05/11] lint fixes --- .../Databricks/Auth/OAuthClientCredentialsService.cs | 10 +++++----- .../Auth/OAuthClientCredentialsServiceTests.cs | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs index bb05c0d52e..a17bd5551a 100644 --- a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs +++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs @@ -43,9 +43,9 @@ private class TokenInfo { public string? AccessToken { get; set; } public DateTime ExpiresAt { get; set; } - + public bool IsExpired => DateTime.UtcNow >= ExpiresAt; - + // Add buffer time to refresh token before actual expiration public bool NeedsRefresh => DateTime.UtcNow >= ExpiresAt.AddMinutes(-5); } @@ -112,7 +112,7 @@ private async Task RefreshTokenInternalAsync(CancellationToken cancellat } string content = await response.Content.ReadAsStringAsync(); - + try { _cachedToken = ParseTokenResponse(content); @@ -149,7 +149,7 @@ private HttpRequestMessage CreateTokenRequest() private TokenInfo ParseTokenResponse(string content) { using var jsonDoc = JsonDocument.Parse(content); - + if (!jsonDoc.RootElement.TryGetProperty("access_token", out var accessTokenElement)) { throw new DatabricksException("OAuth response did not contain an access_token"); @@ -228,4 +228,4 @@ public void Dispose() } } -} \ No newline at end of file +} diff --git a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs index 8697633d6d..47dd4c9824 100644 --- a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs +++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs @@ -62,4 +62,4 @@ void IDisposable.Dispose() base.Dispose(); } } -} \ No newline at end of file +} From 217c58ae056d1be5c0a547d04709aa8f8d0ade55 Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Fri, 25 Apr 2025 16:29:26 -0700 Subject: [PATCH 06/11] lint fix --- .../OAuthClientCredentialsServiceTests.cs | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs index 47dd4c9824..464e247ce2 100644 --- a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs +++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs @@ -1,8 +1,24 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + using System; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Databricks.Auth; -using Apache.Arrow.Adbc.Tests.Drivers.Databricks; using Xunit; using Xunit.Abstractions; @@ -62,4 +78,4 @@ void IDisposable.Dispose() base.Dispose(); } } -} +} From efe06c25ec2c8b0ee329f148a303cea81b28cbaa Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Sun, 27 Apr 2025 21:56:26 -0700 Subject: [PATCH 07/11] feedback fixes --- ...e.cs => OAuthClientCredentialsProvider.cs} | 48 ++++++------------- ...=> OAuthClientCredentialsProviderTests.cs} | 10 ++-- 2 files changed, 20 insertions(+), 38 deletions(-) rename csharp/src/Drivers/Databricks/Auth/{OAuthClientCredentialsService.cs => OAuthClientCredentialsProvider.cs} (82%) rename csharp/test/Drivers/Databricks/Auth/{OAuthClientCredentialsServiceTests.cs => OAuthClientCredentialsProviderTests.cs} (87%) diff --git a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs similarity index 82% rename from csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs rename to csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs index a17bd5551a..aed1ed1f15 100644 --- a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs +++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs @@ -28,12 +28,12 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth /// /// Service for obtaining OAuth access tokens using the client credentials grant type. /// - internal class OAuthClientCredentialsService : IDisposable + internal class OAuthClientCredentialsProvider : IDisposable { - private readonly Lazy _httpClient; + private readonly HttpClient _httpClient; private readonly string _clientId; private readonly string _clientSecret; - private readonly Uri _baseUri; + private readonly string _host; private readonly string _tokenEndpoint; private readonly int _timeoutMinutes; private readonly SemaphoreSlim _tokenLock = new SemaphoreSlim(1, 1); @@ -56,36 +56,27 @@ private class TokenInfo /// The OAuth client ID. /// The OAuth client secret. /// The base URI of the Databricks workspace. - public OAuthClientCredentialsService( + public OAuthClientCredentialsProvider( string clientId, string clientSecret, - Uri baseUri, - int timeoutMinutes = 1, - HttpClient? httpClient = null) + string host, + int timeoutMinutes = 1) { _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret)); - _baseUri = baseUri ?? throw new ArgumentNullException(nameof(baseUri)); + _host = host ?? throw new ArgumentNullException(nameof(host)); _timeoutMinutes = timeoutMinutes; _tokenEndpoint = DetermineTokenEndpoint(); - _httpClient = httpClient != null - ? new Lazy(() => httpClient) - : new Lazy(() => - { - var client = new HttpClient(); - client.Timeout = TimeSpan.FromMinutes(_timeoutMinutes); - return client; - }); + _httpClient = new HttpClient(); + _httpClient.Timeout = TimeSpan.FromMinutes(_timeoutMinutes); } - private HttpClient HttpClient => _httpClient.Value; - private string DetermineTokenEndpoint() { // For workspace URLs, the token endpoint is always /oidc/v1/token // TODO: Might be different for Azure AAD SPs - return $"{_baseUri.Scheme}://{_baseUri.Host}/oidc/v1/token"; + return $"https://{_host}/oidc/v1/token"; } private string? GetValidCachedToken() @@ -103,7 +94,7 @@ private async Task RefreshTokenInternalAsync(CancellationToken cancellat HttpResponseMessage response; try { - response = await HttpClient.SendAsync(request, cancellationToken); + response = await _httpClient.SendAsync(request, cancellationToken); response.EnsureSuccessStatusCode(); } catch (HttpRequestException ex) @@ -129,7 +120,7 @@ private HttpRequestMessage CreateTokenRequest() var requestContent = new FormUrlEncodedContent(new[] { new KeyValuePair("grant_type", "client_credentials"), - new KeyValuePair("scope", "sql") + new KeyValuePair("scope", "all-apis") }); var request = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint) @@ -194,12 +185,7 @@ public async Task GetAccessTokenAsync(CancellationToken cancellationToke return cachedToken; } - // If token needs refresh, acquire lock with timeout - var lockTimeout = TimeSpan.FromSeconds(30); // Reasonable timeout for lock acquisition - if (!await _tokenLock.WaitAsync(lockTimeout, cancellationToken).ConfigureAwait(false)) - { - throw new TimeoutException("Timeout waiting for token refresh lock"); - } + await _tokenLock.WaitAsync(cancellationToken); try { @@ -209,7 +195,7 @@ public async Task GetAccessTokenAsync(CancellationToken cancellationToke return refreshedToken; } - return await RefreshTokenInternalAsync(cancellationToken).ConfigureAwait(false); + return await RefreshTokenInternalAsync(cancellationToken); } finally { @@ -217,14 +203,10 @@ public async Task GetAccessTokenAsync(CancellationToken cancellationToke } } - public void Dispose() { _tokenLock.Dispose(); - if (_httpClient.IsValueCreated) - { - HttpClient.Dispose(); - } + _httpClient.Dispose(); } } diff --git a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs similarity index 87% rename from csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs rename to csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs index 464e247ce2..11d6044b50 100644 --- a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsServiceTests.cs +++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs @@ -24,17 +24,17 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth { - public class OAuthClientCredentialsServiceTests : TestBase, IDisposable + public class OAuthClientCredentialsProviderTests : TestBase, IDisposable { - private readonly OAuthClientCredentialsService _service; + private readonly OAuthClientCredentialsProvider _service; - public OAuthClientCredentialsServiceTests(ITestOutputHelper? outputHelper) + public OAuthClientCredentialsProviderTests(ITestOutputHelper? outputHelper) : base(outputHelper, new DatabricksTestEnvironment.Factory()) { - _service = new OAuthClientCredentialsService( + _service = new OAuthClientCredentialsProvider( TestConfiguration.OAuthClientId, TestConfiguration.OAuthClientSecret, - new Uri(TestConfiguration.Uri), + TestConfiguration.HostName, timeoutMinutes: 1); } From d7a67bd1829ad0d0c57df0590a15de8a691aa7f1 Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Mon, 28 Apr 2025 15:21:29 -0700 Subject: [PATCH 08/11] non-async behavior --- .../Auth/OAuthClientCredentialsProvider.cs | 32 +++++++++++-------- .../OAuthClientCredentialsProviderTests.cs | 17 +++++----- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs index aed1ed1f15..de1729efd9 100644 --- a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs +++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs @@ -171,20 +171,8 @@ private TokenInfo ParseTokenResponse(string content) }; } - /// - /// Gets an OAuth access token using the client credentials grant type. - /// - /// A cancellation token to cancel the operation. - /// The access token. - /// Thrown when the token request fails or the response is invalid. - public async Task GetAccessTokenAsync(CancellationToken cancellationToken = default) + private async Task GetAccessTokenAsync(CancellationToken cancellationToken = default) { - // First try to get cached token without acquiring lock - if (GetValidCachedToken() is string cachedToken) - { - return cachedToken; - } - await _tokenLock.WaitAsync(cancellationToken); try @@ -203,6 +191,24 @@ public async Task GetAccessTokenAsync(CancellationToken cancellationToke } } + + /// + /// Gets an OAuth access token using the client credentials grant type. + /// + /// A cancellation token to cancel the operation. + /// The access token. + public string GetAccessToken(CancellationToken cancellationToken = default) + { + // First try to get cached token without acquiring lock + if (GetValidCachedToken() is string cachedToken) + { + return cachedToken; + } + + return GetAccessTokenAsync(cancellationToken).GetAwaiter().GetResult(); + } + + public void Dispose() { _tokenLock.Dispose(); diff --git a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs index 11d6044b50..08017fa3d2 100644 --- a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs +++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs @@ -21,6 +21,7 @@ using Apache.Arrow.Adbc.Drivers.Databricks.Auth; using Xunit; using Xunit.Abstractions; +using Apache.Arrow.Adbc.Tests.Drivers.Databricks; namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth { @@ -39,35 +40,35 @@ public OAuthClientCredentialsProviderTests(ITestOutputHelper? outputHelper) } [Fact] - public async Task GetAccessToken_WithValidCredentials_ReturnsToken() + public void GetAccessToken_WithValidCredentials_ReturnsToken() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); - var token = await _service.GetAccessTokenAsync(CancellationToken.None); + var token = _service.GetAccessToken(); Assert.NotNull(token); Assert.NotEmpty(token); } [Fact] - public async Task GetAccessToken_WithCancellation_ThrowsOperationCanceledException() + public void GetAccessToken_WithCancellation_ThrowsOperationCanceledException() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); using var cts = new CancellationTokenSource(); cts.Cancel(); - await Assert.ThrowsAsync(() => - _service.GetAccessTokenAsync(cts.Token)); + Assert.Throws(() => + _service.GetAccessToken(cts.Token)); } [Fact] - public async Task GetAccessToken_MultipleCalls_ReusesCachedToken() + public void GetAccessToken_MultipleCalls_ReusesCachedToken() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); - var token1 = await _service.GetAccessTokenAsync(CancellationToken.None); - var token2 = await _service.GetAccessTokenAsync(CancellationToken.None); + var token1 = _service.GetAccessToken(); + var token2 = _service.GetAccessToken(); Assert.Equal(token1, token2); } From 6a899a97595d88b757326c138a12726f41e4f594 Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Mon, 28 Apr 2025 17:28:49 -0700 Subject: [PATCH 09/11] add azure ad service principal support test bug fix --- .../Auth/OAuthClientCredentialsProvider.cs | 21 +++++++++--- .../OAuthClientCredentialsProviderTests.cs | 34 +++++++++---------- .../Databricks/DatabricksTestConfiguration.cs | 3 ++ 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs index de1729efd9..ac791507a2 100644 --- a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs +++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs @@ -36,6 +36,8 @@ internal class OAuthClientCredentialsProvider : IDisposable private readonly string _host; private readonly string _tokenEndpoint; private readonly int _timeoutMinutes; + private readonly string? _tenantId; + private readonly string? _accountId; private readonly SemaphoreSlim _tokenLock = new SemaphoreSlim(1, 1); private TokenInfo? _cachedToken; @@ -55,16 +57,23 @@ private class TokenInfo /// /// The OAuth client ID. /// The OAuth client secret. - /// The base URI of the Databricks workspace. + /// The host of the Databricks workspace. + /// The Azure AD tenant ID (optional). + /// The Databricks account ID (optional). + /// The timeout in minutes for HTTP requests. public OAuthClientCredentialsProvider( string clientId, string clientSecret, string host, + string? tenantId = null, + string? accountId = null, int timeoutMinutes = 1) { _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret)); _host = host ?? throw new ArgumentNullException(nameof(host)); + _tenantId = tenantId; + _accountId = accountId; _timeoutMinutes = timeoutMinutes; _tokenEndpoint = DetermineTokenEndpoint(); @@ -74,8 +83,12 @@ public OAuthClientCredentialsProvider( private string DetermineTokenEndpoint() { - // For workspace URLs, the token endpoint is always /oidc/v1/token - // TODO: Might be different for Azure AAD SPs + if (!string.IsNullOrEmpty(_tenantId)) + { + // Use the tenant-specific Azure OIDC token endpoint + return $"https://login.microsoftonline.com/{_tenantId}/oauth2/v2.0/token"; + } + return $"https://{_host}/oidc/v1/token"; } @@ -120,7 +133,7 @@ private HttpRequestMessage CreateTokenRequest() var requestContent = new FormUrlEncodedContent(new[] { new KeyValuePair("grant_type", "client_credentials"), - new KeyValuePair("scope", "all-apis") + new KeyValuePair("scope", !string.IsNullOrEmpty(_tenantId) ? "https://databricks.azure.net/.default" : "all-apis") }); var request = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint) diff --git a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs index 08017fa3d2..de9b01684d 100644 --- a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs +++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs @@ -27,56 +27,56 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth { public class OAuthClientCredentialsProviderTests : TestBase, IDisposable { - private readonly OAuthClientCredentialsProvider _service; - public OAuthClientCredentialsProviderTests(ITestOutputHelper? outputHelper) : base(outputHelper, new DatabricksTestEnvironment.Factory()) { - _service = new OAuthClientCredentialsProvider( + } + + private OAuthClientCredentialsProvider CreateService() + { + return new OAuthClientCredentialsProvider( TestConfiguration.OAuthClientId, TestConfiguration.OAuthClientSecret, TestConfiguration.HostName, timeoutMinutes: 1); } - [Fact] + [SkippableFact] public void GetAccessToken_WithValidCredentials_ReturnsToken() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); - var token = _service.GetAccessToken(); + var service = CreateService(); + var token = service.GetAccessToken(); Assert.NotNull(token); Assert.NotEmpty(token); } - [Fact] + [SkippableFact] public void GetAccessToken_WithCancellation_ThrowsOperationCanceledException() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); + var service = CreateService(); using var cts = new CancellationTokenSource(); cts.Cancel(); - Assert.Throws(() => - _service.GetAccessToken(cts.Token)); + var ex = Assert.ThrowsAny(() => + service.GetAccessToken(cts.Token)); + Assert.IsType(ex); } - [Fact] + [SkippableFact] public void GetAccessToken_MultipleCalls_ReusesCachedToken() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); - var token1 = _service.GetAccessToken(); - var token2 = _service.GetAccessToken(); + var service = CreateService(); + var token1 = service.GetAccessToken(); + var token2 = service.GetAccessToken(); Assert.Equal(token1, token2); } - - void IDisposable.Dispose() - { - _service.Dispose(); - base.Dispose(); - } } } diff --git a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs index 0f0366c200..84d99c51f7 100644 --- a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs +++ b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs @@ -27,5 +27,8 @@ public class DatabricksTestConfiguration : SparkTestConfiguration [JsonPropertyName("oauth_client_secret"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public string OAuthClientSecret { get; set; } = string.Empty; + + [JsonPropertyName("azure_tenant_id"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string AzureTenantId { get; set; } = string.Empty; } } From 01cb38edfa89cfc1546187212a3209496a7ca430 Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Tue, 29 Apr 2025 10:17:59 -0700 Subject: [PATCH 10/11] Revert "add azure ad service principal support" This reverts commit 6a899a97595d88b757326c138a12726f41e4f594. --- .../Auth/OAuthClientCredentialsProvider.cs | 21 +++--------- .../OAuthClientCredentialsProviderTests.cs | 34 +++++++++---------- .../Databricks/DatabricksTestConfiguration.cs | 3 -- 3 files changed, 21 insertions(+), 37 deletions(-) diff --git a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs index ac791507a2..de1729efd9 100644 --- a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs +++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs @@ -36,8 +36,6 @@ internal class OAuthClientCredentialsProvider : IDisposable private readonly string _host; private readonly string _tokenEndpoint; private readonly int _timeoutMinutes; - private readonly string? _tenantId; - private readonly string? _accountId; private readonly SemaphoreSlim _tokenLock = new SemaphoreSlim(1, 1); private TokenInfo? _cachedToken; @@ -57,23 +55,16 @@ private class TokenInfo /// /// The OAuth client ID. /// The OAuth client secret. - /// The host of the Databricks workspace. - /// The Azure AD tenant ID (optional). - /// The Databricks account ID (optional). - /// The timeout in minutes for HTTP requests. + /// The base URI of the Databricks workspace. public OAuthClientCredentialsProvider( string clientId, string clientSecret, string host, - string? tenantId = null, - string? accountId = null, int timeoutMinutes = 1) { _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret)); _host = host ?? throw new ArgumentNullException(nameof(host)); - _tenantId = tenantId; - _accountId = accountId; _timeoutMinutes = timeoutMinutes; _tokenEndpoint = DetermineTokenEndpoint(); @@ -83,12 +74,8 @@ public OAuthClientCredentialsProvider( private string DetermineTokenEndpoint() { - if (!string.IsNullOrEmpty(_tenantId)) - { - // Use the tenant-specific Azure OIDC token endpoint - return $"https://login.microsoftonline.com/{_tenantId}/oauth2/v2.0/token"; - } - + // For workspace URLs, the token endpoint is always /oidc/v1/token + // TODO: Might be different for Azure AAD SPs return $"https://{_host}/oidc/v1/token"; } @@ -133,7 +120,7 @@ private HttpRequestMessage CreateTokenRequest() var requestContent = new FormUrlEncodedContent(new[] { new KeyValuePair("grant_type", "client_credentials"), - new KeyValuePair("scope", !string.IsNullOrEmpty(_tenantId) ? "https://databricks.azure.net/.default" : "all-apis") + new KeyValuePair("scope", "all-apis") }); var request = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint) diff --git a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs index de9b01684d..08017fa3d2 100644 --- a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs +++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs @@ -27,56 +27,56 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth { public class OAuthClientCredentialsProviderTests : TestBase, IDisposable { + private readonly OAuthClientCredentialsProvider _service; + public OAuthClientCredentialsProviderTests(ITestOutputHelper? outputHelper) : base(outputHelper, new DatabricksTestEnvironment.Factory()) { - } - - private OAuthClientCredentialsProvider CreateService() - { - return new OAuthClientCredentialsProvider( + _service = new OAuthClientCredentialsProvider( TestConfiguration.OAuthClientId, TestConfiguration.OAuthClientSecret, TestConfiguration.HostName, timeoutMinutes: 1); } - [SkippableFact] + [Fact] public void GetAccessToken_WithValidCredentials_ReturnsToken() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); - var service = CreateService(); - var token = service.GetAccessToken(); + var token = _service.GetAccessToken(); Assert.NotNull(token); Assert.NotEmpty(token); } - [SkippableFact] + [Fact] public void GetAccessToken_WithCancellation_ThrowsOperationCanceledException() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); - var service = CreateService(); using var cts = new CancellationTokenSource(); cts.Cancel(); - var ex = Assert.ThrowsAny(() => - service.GetAccessToken(cts.Token)); - Assert.IsType(ex); + Assert.Throws(() => + _service.GetAccessToken(cts.Token)); } - [SkippableFact] + [Fact] public void GetAccessToken_MultipleCalls_ReusesCachedToken() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); - var service = CreateService(); - var token1 = service.GetAccessToken(); - var token2 = service.GetAccessToken(); + var token1 = _service.GetAccessToken(); + var token2 = _service.GetAccessToken(); Assert.Equal(token1, token2); } + + void IDisposable.Dispose() + { + _service.Dispose(); + base.Dispose(); + } } } diff --git a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs index 84d99c51f7..0f0366c200 100644 --- a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs +++ b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs @@ -27,8 +27,5 @@ public class DatabricksTestConfiguration : SparkTestConfiguration [JsonPropertyName("oauth_client_secret"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public string OAuthClientSecret { get; set; } = string.Empty; - - [JsonPropertyName("azure_tenant_id"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] - public string AzureTenantId { get; set; } = string.Empty; } } From ae6bb29e809f7945e1849f847ce76ddf103732bb Mon Sep 17 00:00:00 2001 From: toddmeng-db Date: Tue, 29 Apr 2025 10:19:32 -0700 Subject: [PATCH 11/11] fix revert --- .../Auth/OAuthClientCredentialsProvider.cs | 5 +-- .../OAuthClientCredentialsProviderTests.cs | 34 +++++++++---------- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs index de1729efd9..8fa797586e 100644 --- a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs +++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs @@ -44,8 +44,6 @@ private class TokenInfo public string? AccessToken { get; set; } public DateTime ExpiresAt { get; set; } - public bool IsExpired => DateTime.UtcNow >= ExpiresAt; - // Add buffer time to refresh token before actual expiration public bool NeedsRefresh => DateTime.UtcNow >= ExpiresAt.AddMinutes(-5); } @@ -75,7 +73,6 @@ public OAuthClientCredentialsProvider( private string DetermineTokenEndpoint() { // For workspace URLs, the token endpoint is always /oidc/v1/token - // TODO: Might be different for Azure AAD SPs return $"https://{_host}/oidc/v1/token"; } @@ -97,7 +94,7 @@ private async Task RefreshTokenInternalAsync(CancellationToken cancellat response = await _httpClient.SendAsync(request, cancellationToken); response.EnsureSuccessStatusCode(); } - catch (HttpRequestException ex) + catch (Exception ex) { throw new DatabricksException($"Failed to acquire OAuth access token: {ex.Message}", ex); } diff --git a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs index 08017fa3d2..de9b01684d 100644 --- a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs +++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs @@ -27,56 +27,56 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth { public class OAuthClientCredentialsProviderTests : TestBase, IDisposable { - private readonly OAuthClientCredentialsProvider _service; - public OAuthClientCredentialsProviderTests(ITestOutputHelper? outputHelper) : base(outputHelper, new DatabricksTestEnvironment.Factory()) { - _service = new OAuthClientCredentialsProvider( + } + + private OAuthClientCredentialsProvider CreateService() + { + return new OAuthClientCredentialsProvider( TestConfiguration.OAuthClientId, TestConfiguration.OAuthClientSecret, TestConfiguration.HostName, timeoutMinutes: 1); } - [Fact] + [SkippableFact] public void GetAccessToken_WithValidCredentials_ReturnsToken() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); - var token = _service.GetAccessToken(); + var service = CreateService(); + var token = service.GetAccessToken(); Assert.NotNull(token); Assert.NotEmpty(token); } - [Fact] + [SkippableFact] public void GetAccessToken_WithCancellation_ThrowsOperationCanceledException() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); + var service = CreateService(); using var cts = new CancellationTokenSource(); cts.Cancel(); - Assert.Throws(() => - _service.GetAccessToken(cts.Token)); + var ex = Assert.ThrowsAny(() => + service.GetAccessToken(cts.Token)); + Assert.IsType(ex); } - [Fact] + [SkippableFact] public void GetAccessToken_MultipleCalls_ReusesCachedToken() { Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); - var token1 = _service.GetAccessToken(); - var token2 = _service.GetAccessToken(); + var service = CreateService(); + var token1 = service.GetAccessToken(); + var token2 = service.GetAccessToken(); Assert.Equal(token1, token2); } - - void IDisposable.Dispose() - { - _service.Dispose(); - base.Dispose(); - } } }