diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
index 9cabd1ac41..75a81228b8 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
@@ -35,7 +35,6 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
 {
     internal class SparkHttpConnection : SparkConnection
     {
-        private static readonly string s_userAgent = $"{DriverName.Replace(" ", "")}/{ProductVersionDefault}";
         private const string BasicAuthenticationScheme = "Basic";
         private const string BearerAuthenticationScheme = "Bearer";
 
@@ -164,7 +163,7 @@ protected override TTransport CreateTransport()
             HttpClient httpClient = new(httpClientHandler);
             httpClient.BaseAddress = baseAddress;
             httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue;
-            httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent);
+            httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(GetUserAgent());
             httpClient.DefaultRequestHeaders.AcceptEncoding.Clear();
             httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity"));
             httpClient.DefaultRequestHeaders.ExpectContinue = false;
@@ -252,5 +251,43 @@ protected internal override Task GetRowSetAsync(TGetPrimaryKeysResp res
         internal override SparkServerType ServerType => SparkServerType.Http;
 
         protected override int ColumnMapIndexOffset => 1;
+
+        /// <summary>
+        /// Builds the HTTP User-Agent value: the driver name and version, a Thrift component,
+        /// and, when the client supplied one, the <see cref="SparkParameters.UserAgentEntry"/> value.
+        /// </summary>
+        /// <returns>The space-separated user agent string.</returns>
+        private string GetUserAgent()
+        {
+            string thriftVersion = GetThriftVersion();
+            string thriftComponent = string.IsNullOrEmpty(thriftVersion) ? "Thrift" : $"Thrift/{thriftVersion}";
+            string baseUserAgent = $"{DriverName.Replace(" ", "")}/{ProductVersionDefault} {thriftComponent}";
+
+            // Trim the client's entry so surrounding padding does not end up in the header value;
+            // a missing, empty, or whitespace-only entry leaves the base user agent unchanged.
+            string? clientEntry = Properties.TryGetValue(SparkParameters.UserAgentEntry, out string? userAgentEntry)
+                ? userAgentEntry?.Trim()
+                : null;
+
+            return string.IsNullOrEmpty(clientEntry) ? baseUserAgent : $"{baseUserAgent} {clientEntry}";
+        }
+
+        /// <summary>
+        /// Reads the major.minor.build version of the assembly that defines <see cref="TProtocol"/>.
+        /// </summary>
+        /// <returns>The version string, or an empty string when it cannot be determined.</returns>
+        private static string GetThriftVersion()
+        {
+            try
+            {
+                var version = typeof(TProtocol).Assembly.GetName().Version;
+                return version != null ? $"{version.Major}.{version.Minor}.{version.Build}" : "";
+            }
+            catch
+            {
+                // Best effort: the version is only informational, so a failed lookup falls back to "".
+                return "";
+            }
+        }
     }
 }
diff --git a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
index 66d329814f..99275d7d0e 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
@@ -33,6 +33,7 @@ public class SparkParameters
         public const string Type = "adbc.spark.type";
         public const string DataTypeConv = "adbc.spark.data_type_conv";
         public const string ConnectTimeoutMilliseconds = "adbc.spark.connect_timeout_ms";
+        public const string UserAgentEntry = "adbc.spark.user_agent_entry";
     }
 
     public static class SparkAuthTypeConstants
diff --git a/csharp/test/Drivers/Apache/Spark/SparkHttpConnectionUserAgentTest.cs b/csharp/test/Drivers/Apache/Spark/SparkHttpConnectionUserAgentTest.cs
new file mode 100644
index 0000000000..84871ad78f
--- /dev/null
+++ b/csharp/test/Drivers/Apache/Spark/SparkHttpConnectionUserAgentTest.cs
@@ -0,0 +1,145 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Net.Http;
+using System.Reflection;
+using Apache.Arrow.Adbc.Drivers.Apache.Spark;
+using Xunit;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
+{
+    /// <summary>
+    /// Tests for the SparkHttpConnection user agent functionality.
+    /// </summary>
+    public class SparkHttpConnectionUserAgentTest
+    {
+        // Shape of the user agent without a client entry; each assertion adds the ^ and $ anchors.
+        private const string BaseUserAgentPattern = @"ADBCSparkDriver/[\d\.]+ Thrift(/[\d\.]+)?";
+
+        [Fact]
+        public void UserAgentEntry_WhenNotProvided_UsesBaseUserAgentWithThrift()
+        {
+            // Act
+            string userAgent = GetUserAgentFromConnection(CreateProperties());
+
+            // Assert
+            Assert.Matches($"^{BaseUserAgentPattern}$", userAgent);
+        }
+
+        [Fact]
+        public void UserAgentEntry_WhenProvided_AppendsToBaseUserAgent()
+        {
+            // Arrange
+            Dictionary<string, string> properties = CreateProperties();
+            properties[SparkParameters.UserAgentEntry] = "PowerBI";
+
+            // Act
+            string userAgent = GetUserAgentFromConnection(properties);
+
+            // Assert
+            Assert.Matches($"^{BaseUserAgentPattern} PowerBI$", userAgent);
+        }
+
+        [Fact]
+        public void UserAgentEntry_WhenPadded_IsTrimmedBeforeAppending()
+        {
+            // Arrange
+            Dictionary<string, string> properties = CreateProperties();
+            properties[SparkParameters.UserAgentEntry] = "  PowerBI  ";
+
+            // Act
+            string userAgent = GetUserAgentFromConnection(properties);
+
+            // Assert
+            Assert.Matches($"^{BaseUserAgentPattern} PowerBI$", userAgent);
+        }
+
+        [Fact]
+        public void UserAgentEntry_WhenEmpty_UsesBaseUserAgent()
+        {
+            // Arrange
+            Dictionary<string, string> properties = CreateProperties();
+            properties[SparkParameters.UserAgentEntry] = "";
+
+            // Act
+            string userAgent = GetUserAgentFromConnection(properties);
+
+            // Assert
+            Assert.Matches($"^{BaseUserAgentPattern}$", userAgent);
+        }
+
+        [Fact]
+        public void UserAgentEntry_WhenWhitespace_UsesBaseUserAgent()
+        {
+            // Arrange
+            Dictionary<string, string> properties = CreateProperties();
+            properties[SparkParameters.UserAgentEntry] = " ";
+
+            // Act
+            string userAgent = GetUserAgentFromConnection(properties);
+
+            // Assert
+            Assert.Matches($"^{BaseUserAgentPattern}$", userAgent);
+        }
+
+        [Fact]
+        public void UserAgent_IncludesThriftComponent()
+        {
+            // Act
+            string userAgent = GetUserAgentFromConnection(CreateProperties());
+
+            // Assert
+            Assert.Contains("Thrift", userAgent);
+        }
+
+        /// <summary>
+        /// Creates the minimal property set the tests use to construct an HTTP connection.
+        /// </summary>
+        private static Dictionary<string, string> CreateProperties()
+        {
+            return new Dictionary<string, string>
+            {
+                [SparkParameters.Type] = SparkServerTypeConstants.Http,
+                [SparkParameters.HostName] = "valid.server.com",
+                [SparkParameters.Path] = "/path",
+                [SparkParameters.AuthType] = SparkAuthTypeConstants.None
+            };
+        }
+
+        /// <summary>
+        /// Constructs a connection from <paramref name="properties"/> and returns the result of its
+        /// private GetUserAgent method, invoked through reflection.
+        /// </summary>
+        private static string GetUserAgentFromConnection(Dictionary<string, string> properties)
+        {
+            var connection = new SparkHttpConnection(properties);
+
+            var getUserAgentMethod = typeof(SparkHttpConnection).GetMethod(
+                "GetUserAgent",
+                BindingFlags.NonPublic | BindingFlags.Instance);
+
+            if (getUserAgentMethod == null)
+            {
+                throw new InvalidOperationException("GetUserAgent method not found");
+            }
+
+            return (string)getUserAgentMethod.Invoke(connection, null)!;
+        }
+    }
+}