Skip to content

Commit a88751d

Browse files
authored
feat(csharp/src/Drivers/Databricks): Implement ClientCredentialsProvider (#2743)
First PR: adds a class that obtains a token from the OAuth service for M2M authentication. Includes simple expiration and refresh logic. For reference, the SDK refresh logic is [here](https://github.com/databricks/databricks-sdk-java/blob/main/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java). A follow-up will integrate it with the rest of the driver. To test it out: 1. Create a `databricks_test_config.json` ``` { "oauth_client_id": "...", "oauth_client_secret": "...", "host": "databricks....com" // workspace hostname } ``` 2. Point the `DATABRICKS_TEST_CONFIG_FILE` environment variable at that file. On macOS/Linux: export DATABRICKS_TEST_CONFIG_FILE=/path/to/your/databricks_test_config.json On Windows PowerShell: $env:DATABRICKS_TEST_CONFIG_FILE = "C:\path\to\your\databricks_test_config.json" 3. Run the tests: ```/csharp% dotnet test --filter "FullyQualifiedName~OAuthClientCredentialsProviderTests"```
1 parent 332e145 commit a88751d

3 files changed

Lines changed: 303 additions & 0 deletions

File tree

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
using System;
19+
using System.Collections.Generic;
20+
using System.Net.Http;
21+
using System.Net.Http.Headers;
22+
using System.Text.Json;
23+
using System.Threading;
24+
using System.Threading.Tasks;
25+
26+
namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
27+
{
28+
/// <summary>
/// Obtains OAuth access tokens using the client credentials grant type and caches the
/// most recent token until shortly before it expires.
/// </summary>
internal class OAuthClientCredentialsProvider : IDisposable
{
    // A cached token is treated as stale this long before its reported expiry, so callers
    // are not handed a token that is about to lapse.
    private static readonly TimeSpan s_refreshBuffer = TimeSpan.FromMinutes(5);

    private readonly HttpClient _httpClient;
    private readonly string _clientId;
    private readonly string _clientSecret;
    private readonly string _host;
    private readonly string _tokenEndpoint;
    private readonly int _timeoutMinutes;
    private readonly SemaphoreSlim _tokenLock = new SemaphoreSlim(1, 1);

    // Written while holding _tokenLock but read without it on the fast path in
    // GetAccessToken; volatile plus an immutable TokenInfo keeps that lock-free read safe.
    private volatile TokenInfo? _cachedToken;

    /// <summary>
    /// Immutable snapshot of an access token and the UTC instant at which it expires.
    /// </summary>
    private sealed class TokenInfo
    {
        public TokenInfo(string accessToken, DateTime expiresAt)
        {
            AccessToken = accessToken;
            ExpiresAt = expiresAt;
        }

        public string AccessToken { get; }

        public DateTime ExpiresAt { get; }

        // True once the current time is within the refresh buffer of the expiry.
        public bool NeedsRefresh => DateTime.UtcNow >= ExpiresAt - s_refreshBuffer;
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="OAuthClientCredentialsProvider"/> class.
    /// </summary>
    /// <param name="clientId">The OAuth client ID.</param>
    /// <param name="clientSecret">The OAuth client secret.</param>
    /// <param name="host">The workspace host name; the token endpoint is built as <c>https://{host}/oidc/v1/token</c>.</param>
    /// <param name="timeoutMinutes">Timeout, in minutes, applied to the token HTTP request. Must be positive.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="clientId"/>, <paramref name="clientSecret"/> or <paramref name="host"/> is null.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="timeoutMinutes"/> is zero or negative.</exception>
    public OAuthClientCredentialsProvider(
        string clientId,
        string clientSecret,
        string host,
        int timeoutMinutes = 1)
    {
        _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId));
        _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret));
        _host = host ?? throw new ArgumentNullException(nameof(host));

        // Reject a bad timeout before allocating the HttpClient; the HttpClient.Timeout
        // setter would throw the same exception type, but only after the client exists.
        if (timeoutMinutes <= 0)
        {
            throw new ArgumentOutOfRangeException(
                nameof(timeoutMinutes),
                timeoutMinutes,
                "Timeout must be a positive number of minutes.");
        }

        _timeoutMinutes = timeoutMinutes;
        _tokenEndpoint = DetermineTokenEndpoint();

        _httpClient = new HttpClient();
        _httpClient.Timeout = TimeSpan.FromMinutes(_timeoutMinutes);
    }

    /// <summary>
    /// Builds the token endpoint URL from the configured host.
    /// </summary>
    private string DetermineTokenEndpoint()
    {
        // For workspace URLs, the token endpoint is always /oidc/v1/token
        return $"https://{_host}/oidc/v1/token";
    }

    /// <summary>
    /// Returns the cached access token if one exists and is not yet due for refresh; otherwise null.
    /// </summary>
    private string? GetValidCachedToken()
    {
        // Snapshot the field once so a concurrent refresh cannot swap it between checks.
        TokenInfo? token = _cachedToken;
        return token != null && !token.NeedsRefresh ? token.AccessToken : null;
    }

    /// <summary>
    /// Requests a new token from the token endpoint, stores it in the cache and returns it.
    /// Callers are expected to hold <c>_tokenLock</c>.
    /// </summary>
    /// <exception cref="DatabricksException">The request failed or the response could not be parsed.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
    private async Task<string> RefreshTokenInternalAsync(CancellationToken cancellationToken)
    {
        string content;
        try
        {
            // Dispose both messages so their content buffers and connections are released promptly.
            using (HttpRequestMessage request = CreateTokenRequest())
            using (HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false))
            {
                response.EnsureSuccessStatusCode();
                content = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
            }
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            // Caller-requested cancellation is not a token failure; let it surface unchanged.
            // An HttpClient timeout leaves the caller's token un-cancelled and is wrapped below.
            throw;
        }
        catch (Exception ex)
        {
            throw new DatabricksException($"Failed to acquire OAuth access token: {ex.Message}", ex);
        }

        try
        {
            TokenInfo token = ParseTokenResponse(content);
            _cachedToken = token;
            return token.AccessToken;
        }
        catch (JsonException ex)
        {
            throw new DatabricksException($"Failed to parse OAuth response: {ex.Message}", ex);
        }
    }

    /// <summary>
    /// Creates the POST request for the client credentials grant, authenticated with HTTP Basic auth.
    /// </summary>
    private HttpRequestMessage CreateTokenRequest()
    {
        var requestContent = new FormUrlEncodedContent(new[]
        {
            new KeyValuePair<string, string>("grant_type", "client_credentials"),
            new KeyValuePair<string, string>("scope", "all-apis")
        });

        var request = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint)
        {
            Content = requestContent
        };

        // Use Basic Auth with client ID and secret. UTF-8 yields the same bytes as ASCII for
        // ASCII input but does not silently replace other characters with '?'.
        var authHeader = Convert.ToBase64String(
            System.Text.Encoding.UTF8.GetBytes($"{_clientId}:{_clientSecret}"));
        request.Headers.Authorization = new AuthenticationHeaderValue("Basic", authHeader);
        request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));

        return request;
    }

    /// <summary>
    /// Parses the token endpoint's JSON response into a <see cref="TokenInfo"/>.
    /// </summary>
    /// <param name="content">The raw response body.</param>
    /// <exception cref="JsonException"><paramref name="content"/> is not valid JSON.</exception>
    /// <exception cref="DatabricksException">
    /// The JSON is not an object, or <c>access_token</c> / <c>expires_in</c> is missing or invalid.
    /// </exception>
    private TokenInfo ParseTokenResponse(string content)
    {
        using var jsonDoc = JsonDocument.Parse(content);
        JsonElement root = jsonDoc.RootElement;

        // TryGetProperty throws InvalidOperationException on anything but an object.
        if (root.ValueKind != JsonValueKind.Object)
        {
            throw new DatabricksException("OAuth response was not a JSON object");
        }

        if (!root.TryGetProperty("access_token", out var accessTokenElement))
        {
            throw new DatabricksException("OAuth response did not contain an access_token");
        }

        // GetString throws on numbers, objects, etc.; treat those like a missing value.
        string? accessToken = accessTokenElement.ValueKind == JsonValueKind.String
            ? accessTokenElement.GetString()
            : null;
        if (string.IsNullOrEmpty(accessToken))
        {
            throw new DatabricksException("OAuth access_token was null or empty");
        }

        // Get expiration time from response
        if (!root.TryGetProperty("expires_in", out var expiresInElement))
        {
            throw new DatabricksException("OAuth response did not contain expires_in");
        }

        if (expiresInElement.ValueKind != JsonValueKind.Number || !expiresInElement.TryGetInt32(out int expiresIn))
        {
            throw new DatabricksException("OAuth expires_in value must be an integer");
        }

        if (expiresIn <= 0)
        {
            throw new DatabricksException("OAuth expires_in value must be positive");
        }

        return new TokenInfo(accessToken!, DateTime.UtcNow.AddSeconds(expiresIn));
    }

    /// <summary>
    /// Acquires the token lock, then returns the cached token if it is still valid or
    /// refreshes it otherwise.
    /// </summary>
    private async Task<string> GetAccessTokenAsync(CancellationToken cancellationToken = default)
    {
        await _tokenLock.WaitAsync(cancellationToken).ConfigureAwait(false);

        try
        {
            // Re-check in case another caller refreshed the token while we were waiting.
            if (GetValidCachedToken() is string refreshedToken)
            {
                return refreshedToken;
            }

            return await RefreshTokenInternalAsync(cancellationToken).ConfigureAwait(false);
        }
        finally
        {
            _tokenLock.Release();
        }
    }

    /// <summary>
    /// Gets an OAuth access token using the client credentials grant type.
    /// </summary>
    /// <param name="cancellationToken">A cancellation token to cancel the operation.</param>
    /// <returns>The access token.</returns>
    /// <exception cref="DatabricksException">The token could not be acquired or parsed.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
    public string GetAccessToken(CancellationToken cancellationToken = default)
    {
        // First try to get cached token without acquiring lock
        if (GetValidCachedToken() is string cachedToken)
        {
            return cachedToken;
        }

        // Blocks the calling thread; every await underneath uses ConfigureAwait(false) so the
        // continuation never needs to resume on the blocked caller's synchronization context.
        return GetAccessTokenAsync(cancellationToken).GetAwaiter().GetResult();
    }

    /// <summary>
    /// Releases the token lock and the underlying <see cref="HttpClient"/>.
    /// </summary>
    public void Dispose()
    {
        _tokenLock.Dispose();
        _httpClient.Dispose();
    }
}
216+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
using System;
19+
using System.Threading;
20+
using System.Threading.Tasks;
21+
using Apache.Arrow.Adbc.Drivers.Databricks.Auth;
22+
using Xunit;
23+
using Xunit.Abstractions;
24+
using Apache.Arrow.Adbc.Tests.Drivers.Databricks;
25+
26+
namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth
27+
{
28+
/// <summary>
/// Tests for <see cref="OAuthClientCredentialsProvider"/>. Each test is skipped unless an
/// OAuth client ID is present in the test configuration.
/// </summary>
public class OAuthClientCredentialsProviderTests : TestBase<DatabricksTestConfiguration, DatabricksTestEnvironment>, IDisposable
{
    public OAuthClientCredentialsProviderTests(ITestOutputHelper? outputHelper)
        : base(outputHelper, new DatabricksTestEnvironment.Factory())
    {
    }

    /// <summary>
    /// Creates a provider from the configured client ID, client secret and host name.
    /// The provider is disposable, so the caller must dispose it.
    /// </summary>
    private OAuthClientCredentialsProvider CreateService()
    {
        return new OAuthClientCredentialsProvider(
            TestConfiguration.OAuthClientId,
            TestConfiguration.OAuthClientSecret,
            TestConfiguration.HostName,
            timeoutMinutes: 1);
    }

    [SkippableFact]
    public void GetAccessToken_WithValidCredentials_ReturnsToken()
    {
        Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured");

        // Dispose the provider so it is released after each test.
        using var service = CreateService();
        var token = service.GetAccessToken();

        Assert.NotNull(token);
        Assert.NotEmpty(token);
    }

    [SkippableFact]
    public void GetAccessToken_WithCancellation_ThrowsOperationCanceledException()
    {
        Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured");

        using var service = CreateService();
        using var cts = new CancellationTokenSource();

        // Cancel before the call so the request is rejected without contacting the server.
        cts.Cancel();

        var ex = Assert.ThrowsAny<OperationCanceledException>(() =>
            service.GetAccessToken(cts.Token));
        Assert.IsType<TaskCanceledException>(ex);
    }

    [SkippableFact]
    public void GetAccessToken_MultipleCalls_ReusesCachedToken()
    {
        Skip.IfNot(!string.IsNullOrEmpty(TestConfiguration.OAuthClientId), "OAuth credentials not configured");

        using var service = CreateService();
        var token1 = service.GetAccessToken();
        var token2 = service.GetAccessToken();

        // The second call should be served from the cache and so return the same token.
        Assert.Equal(token1, token2);
    }
}
82+
}

csharp/test/Drivers/Databricks/DatabricksTestConfiguration.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515
* limitations under the License.
1616
*/
1717

18+
using System.Text.Json.Serialization;
1819
using Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark;
1920

2021
namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
2122
{
2223
/// <summary>
/// Test configuration for the Databricks driver; extends the Spark test configuration
/// with OAuth client credential settings.
/// </summary>
public class DatabricksTestConfiguration : SparkTestConfiguration
{
    /// <summary>
    /// OAuth client ID, mapped to the <c>oauth_client_id</c> JSON property. Defaults to an empty string.
    /// </summary>
    [JsonPropertyName("oauth_client_id")]
    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
    public string OAuthClientId { get; set; } = string.Empty;

    /// <summary>
    /// OAuth client secret, mapped to the <c>oauth_client_secret</c> JSON property. Defaults to an empty string.
    /// </summary>
    [JsonPropertyName("oauth_client_secret")]
    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
    public string OAuthClientSecret { get; set; } = string.Empty;
}
2631
}

0 commit comments

Comments
 (0)