diff --git a/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs new file mode 100644 index 0000000000..8fa797586e --- /dev/null +++ b/csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs @@ -0,0 +1,216 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth +{ + /// + /// Service for obtaining OAuth access tokens using the client credentials grant type. + /// + internal class OAuthClientCredentialsProvider : IDisposable + { + private readonly HttpClient _httpClient; + private readonly string _clientId; + private readonly string _clientSecret; + private readonly string _host; + private readonly string _tokenEndpoint; + private readonly int _timeoutMinutes; + private readonly SemaphoreSlim _tokenLock = new SemaphoreSlim(1, 1); + private TokenInfo? _cachedToken; + + private class TokenInfo + { + public string? AccessToken { get; set; } + public DateTime ExpiresAt { get; set; } + + // Add buffer time to refresh token before actual expiration + public bool NeedsRefresh => DateTime.UtcNow >= ExpiresAt.AddMinutes(-5); + } + + /// + /// Initializes a new instance of the class. + /// + /// The OAuth client ID. + /// The OAuth client secret. + /// The base URI of the Databricks workspace. + public OAuthClientCredentialsProvider( + string clientId, + string clientSecret, + string host, + int timeoutMinutes = 1) + { + _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret)); + _host = host ?? throw new ArgumentNullException(nameof(host)); + _timeoutMinutes = timeoutMinutes; + _tokenEndpoint = DetermineTokenEndpoint(); + + _httpClient = new HttpClient(); + _httpClient.Timeout = TimeSpan.FromMinutes(_timeoutMinutes); + } + + private string DetermineTokenEndpoint() + { + // For workspace URLs, the token endpoint is always /oidc/v1/token + return $"https://{_host}/oidc/v1/token"; + } + + private string? GetValidCachedToken() + { + return _cachedToken != null && !_cachedToken.NeedsRefresh && _cachedToken.AccessToken != null + ? _cachedToken.AccessToken + : null; + } + + + private async Task RefreshTokenInternalAsync(CancellationToken cancellationToken) + { + var request = CreateTokenRequest(); + + HttpResponseMessage response; + try + { + response = await _httpClient.SendAsync(request, cancellationToken); + response.EnsureSuccessStatusCode(); + } + catch (Exception ex) + { + throw new DatabricksException($"Failed to acquire OAuth access token: {ex.Message}", ex); + } + + string content = await response.Content.ReadAsStringAsync(); + + try + { + _cachedToken = ParseTokenResponse(content); + return _cachedToken.AccessToken!; + } + catch (JsonException ex) + { + throw new DatabricksException($"Failed to parse OAuth response: {ex.Message}", ex); + } + } + + private HttpRequestMessage CreateTokenRequest() + { + var requestContent = new FormUrlEncodedContent(new[] + { + new KeyValuePair("grant_type", "client_credentials"), + new KeyValuePair("scope", "all-apis") + }); + + var request = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint) + { + Content = requestContent + }; + + // Use Basic Auth with client ID and secret + var authHeader = Convert.ToBase64String( + System.Text.Encoding.ASCII.GetBytes($"{_clientId}:{_clientSecret}")); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", authHeader); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + + return request; + } + + private TokenInfo ParseTokenResponse(string content) + { + using var jsonDoc = JsonDocument.Parse(content); + + if (!jsonDoc.RootElement.TryGetProperty("access_token", out var accessTokenElement)) + { + throw new DatabricksException("OAuth response did not contain an access_token"); + } + + string? accessToken = accessTokenElement.GetString(); + if (string.IsNullOrEmpty(accessToken)) + { + throw new DatabricksException("OAuth access_token was null or empty"); + } + + // Get expiration time from response + if (!jsonDoc.RootElement.TryGetProperty("expires_in", out var expiresInElement)) + { + throw new DatabricksException("OAuth response did not contain expires_in"); + } + + int expiresIn = expiresInElement.GetInt32(); + if (expiresIn <= 0) + { + throw new DatabricksException("OAuth expires_in value must be positive"); + } + + return new TokenInfo + { + AccessToken = accessToken!, + ExpiresAt = DateTime.UtcNow.AddSeconds(expiresIn) + }; + } + + private async Task GetAccessTokenAsync(CancellationToken cancellationToken = default) + { + await _tokenLock.WaitAsync(cancellationToken); + + try + { + // Double-check pattern in case another thread refreshed while we were waiting + if (GetValidCachedToken() is string refreshedToken) + { + return refreshedToken; + } + + return await RefreshTokenInternalAsync(cancellationToken); + } + finally + { + _tokenLock.Release(); + } + } + + + /// + /// Gets an OAuth access token using the client credentials grant type. + /// + /// A cancellation token to cancel the operation. + /// The access token. + public string GetAccessToken(CancellationToken cancellationToken = default) + { + // First try to get cached token without acquiring lock + if (GetValidCachedToken() is string cachedToken) + { + return cachedToken; + } + + return GetAccessTokenAsync(cancellationToken).GetAwaiter().GetResult(); + } + + + public void Dispose() + { + _tokenLock.Dispose(); + _httpClient.Dispose(); + } + + } +} diff --git a/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs new file mode 100644 index 0000000000..de9b01684d --- /dev/null +++ b/csharp/test/Drivers/Databricks/Auth/OAuthClientCredentialsProviderTests.cs @@ -0,0 +1,82 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Databricks.Auth; +using Xunit; +using Xunit.Abstractions; +using Apache.Arrow.Adbc.Tests.Drivers.Databricks; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth +{ + public class OAuthClientCredentialsProviderTests : TestBase, IDisposable + { + public OAuthClientCredentialsProviderTests(ITestOutputHelper? outputHelper) + : base(outputHelper, new DatabricksTestEnvironment.Factory()) + { + } + + private OAuthClientCredentialsProvider CreateService() + { + return new OAuthClientCredentialsProvider( + TestConfiguration.OAuthClientId, + TestConfiguration.OAuthClientSecret, + TestConfiguration.HostName, + timeoutMinutes: 1); + } + + [SkippableFact] + public void GetAccessToken_WithValidCredentials_ReturnsToken() + { + Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); + + var service = CreateService(); + var token = service.GetAccessToken(); + + Assert.NotNull(token); + Assert.NotEmpty(token); + } + + [SkippableFact] + public void GetAccessToken_WithCancellation_ThrowsOperationCanceledException() + { + Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); + + var service = CreateService(); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + var ex = Assert.ThrowsAny(() => + service.GetAccessToken(cts.Token)); + Assert.IsType(ex); + } + + [SkippableFact] + public void GetAccessToken_MultipleCalls_ReusesCachedToken() + { + Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured"); + + var service = CreateService(); + var token1 = service.GetAccessToken(); + var token2 = service.GetAccessToken(); + + Assert.Equal(token1, token2); + } + } +} diff --git a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs index fb221560b7..0f0366c200 100644 --- a/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs +++ b/csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs @@ -15,12 +15,17 @@ * limitations under the License. */ +using System.Text.Json.Serialization; using Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark; namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks { public class DatabricksTestConfiguration : SparkTestConfiguration { + [JsonPropertyName("oauth_client_id"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string OAuthClientId { get; set; } = string.Empty; + [JsonPropertyName("oauth_client_secret"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string OAuthClientSecret { get; set; } = string.Empty; } }