Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
internal class SparkHttpConnection : SparkConnection
{
private static readonly string s_userAgent = $"{DriverName.Replace(" ", "")}/{ProductVersionDefault}";
private const string BasicAuthenticationScheme = "Basic";
private const string BearerAuthenticationScheme = "Bearer";

Expand Down Expand Up @@ -164,7 +163,7 @@ protected override TTransport CreateTransport()
HttpClient httpClient = new(httpClientHandler);
httpClient.BaseAddress = baseAddress;
httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue;
httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent);
httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(GetUserAgent());
httpClient.DefaultRequestHeaders.AcceptEncoding.Clear();
httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity"));
httpClient.DefaultRequestHeaders.ExpectContinue = false;
Expand Down Expand Up @@ -252,5 +251,36 @@ protected internal override Task<TRowSet> GetRowSetAsync(TGetPrimaryKeysResp res
internal override SparkServerType ServerType => SparkServerType.Http;

protected override int ColumnMapIndexOffset => 1;

/// <summary>
/// Builds the User-Agent header value for this connection.
/// </summary>
/// <remarks>
/// The base value is the driver name (spaces removed), the default product version and a
/// <c>Thrift</c> component, which carries the Thrift assembly version when one can be determined.
/// When the <see cref="SparkParameters.UserAgentEntry"/> property is set to a non-blank value,
/// that value is trimmed and appended after a single space.
/// </remarks>
/// <returns>The composed user agent string.</returns>
private string GetUserAgent()
{
    // Build the base user agent string with the Thrift version (when available).
    string thriftVersion = GetThriftVersion();
    string thriftComponent = string.IsNullOrEmpty(thriftVersion) ? "Thrift" : $"Thrift/{thriftVersion}";
    string baseUserAgent = $"{DriverName.Replace(" ", "")}/{ProductVersionDefault} {thriftComponent}";

    // Check whether the client has provided a user-agent entry.
    if (Properties.TryGetValue(SparkParameters.UserAgentEntry, out string? userAgentEntry))
    {
        // Trim so that padding in the property value does not end up as doubled or
        // trailing whitespace in the header. A null or all-whitespace entry trims to
        // empty and is ignored, exactly as before.
        string trimmedEntry = userAgentEntry?.Trim() ?? string.Empty;
        if (trimmedEntry.Length > 0)
        {
            return $"{baseUserAgent} {trimmedEntry}";
        }
    }

    return baseUserAgent;
}

/// <summary>
/// Reads the version of the assembly that defines <c>TProtocol</c>.
/// </summary>
/// <returns>
/// The version formatted as <c>major.minor.build</c>, or an empty string when the assembly
/// reports no version or the lookup throws.
/// </returns>
private string GetThriftVersion()
{
    try
    {
        var assemblyVersion = typeof(TProtocol).Assembly.GetName().Version;
        if (assemblyVersion == null)
        {
            return "";
        }

        return $"{assemblyVersion.Major}.{assemblyVersion.Minor}.{assemblyVersion.Build}";
    }
    catch
    {
        // Best effort only: the version is informational, so any failure while
        // inspecting the assembly falls back to "no version".
        return "";
    }
}
}
}
1 change: 1 addition & 0 deletions csharp/src/Drivers/Apache/Spark/SparkParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class SparkParameters
public const string Type = "adbc.spark.type";
public const string DataTypeConv = "adbc.spark.data_type_conv";
public const string ConnectTimeoutMilliseconds = "adbc.spark.connect_timeout_ms";
public const string UserAgentEntry = "adbc.spark.user_agent_entry";
}

public static class SparkAuthTypeConstants
Expand Down
147 changes: 147 additions & 0 deletions csharp/test/Drivers/Apache/Spark/SparkHttpConnectionUserAgentTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Reflection;
using Apache.Arrow.Adbc.Drivers.Apache.Spark;
using Xunit;

namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
{
/// <summary>
/// Tests for the SparkHttpConnection user agent functionality.
/// </summary>
public class SparkHttpConnectionUserAgentTest
{
/// <summary>
/// Without a user-agent entry property, the user agent is exactly the base value:
/// driver/version followed by the Thrift component.
/// </summary>
[Fact]
public void UserAgentEntry_WhenNotProvided_UsesBaseUserAgentWithThrift()
{
    // Arrange
    var properties = new Dictionary<string, string>
    {
        [SparkParameters.Type] = SparkServerTypeConstants.Http,
        [SparkParameters.HostName] = "valid.server.com",
        [SparkParameters.Path] = "/path",
        [SparkParameters.AuthType] = SparkAuthTypeConstants.None
    };

    // Act
    string userAgent = GetUserAgentFromConnection(properties);

    // Assert
    // Anchored at both ends so the test fails if anything precedes or follows the base value.
    Assert.Matches(@"^ADBCSparkDriver/[\d\.]+ Thrift(/[\d\.]+)?$", userAgent);
}

/// <summary>
/// A non-blank user-agent entry property is appended after the base user agent,
/// separated by a single space.
/// </summary>
[Fact]
public void UserAgentEntry_WhenProvided_AppendsToBaseUserAgent()
{
    // Arrange
    var properties = new Dictionary<string, string>
    {
        [SparkParameters.Type] = SparkServerTypeConstants.Http,
        [SparkParameters.HostName] = "valid.server.com",
        [SparkParameters.Path] = "/path",
        [SparkParameters.AuthType] = SparkAuthTypeConstants.None,
        [SparkParameters.UserAgentEntry] = "PowerBI"
    };

    // Act
    string userAgent = GetUserAgentFromConnection(properties);

    // Assert
    // Anchored at both ends so the entry must be the final token and nothing may trail it.
    Assert.Matches(@"^ADBCSparkDriver/[\d\.]+ Thrift(/[\d\.]+)? PowerBI$", userAgent);
}

/// <summary>
/// An empty user-agent entry property is expected to leave the user agent as the base value.
/// </summary>
[Fact]
public void UserAgentEntry_WhenEmpty_UsesBaseUserAgent()
{
// Arrange
var properties = new Dictionary<string, string>
{
[SparkParameters.Type] = SparkServerTypeConstants.Http,
[SparkParameters.HostName] = "valid.server.com",
[SparkParameters.Path] = "/path",
[SparkParameters.AuthType] = SparkAuthTypeConstants.None,
[SparkParameters.UserAgentEntry] = ""
};

// Act
string userAgent = GetUserAgentFromConnection(properties);

// Assert
// NOTE(review): the pattern is not anchored, so this assertion would also pass if
// extra text followed the Thrift component; consider anchoring with ^ and $.
Assert.Matches(@"ADBCSparkDriver/[\d\.]+ Thrift(/[\d\.]+)?", userAgent);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are also trailing spaces somewhere around here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they're fixed now? Thanks for giving me the headsup

}

/// <summary>
/// A whitespace-only user-agent entry property is ignored, leaving exactly the base user agent.
/// </summary>
[Fact]
public void UserAgentEntry_WhenWhitespace_UsesBaseUserAgent()
{
    // Arrange
    var properties = new Dictionary<string, string>
    {
        [SparkParameters.Type] = SparkServerTypeConstants.Http,
        [SparkParameters.HostName] = "valid.server.com",
        [SparkParameters.Path] = "/path",
        [SparkParameters.AuthType] = SparkAuthTypeConstants.None,
        [SparkParameters.UserAgentEntry] = " "
    };

    // Act
    string userAgent = GetUserAgentFromConnection(properties);

    // Assert
    // Anchored at both ends so a stray trailing space or appended token fails the test.
    Assert.Matches(@"^ADBCSparkDriver/[\d\.]+ Thrift(/[\d\.]+)?$", userAgent);
}

/// <summary>
/// The user agent built for a minimal HTTP configuration mentions Thrift.
/// </summary>
[Fact]
public void UserAgent_IncludesThriftComponent()
{
    // Arrange: minimal HTTP connection settings, no user-agent entry.
    var connectionSettings = new Dictionary<string, string>
    {
        { SparkParameters.Type, SparkServerTypeConstants.Http },
        { SparkParameters.HostName, "valid.server.com" },
        { SparkParameters.Path, "/path" },
        { SparkParameters.AuthType, SparkAuthTypeConstants.None },
    };

    // Act
    string actual = GetUserAgentFromConnection(connectionSettings);

    // Assert
    Assert.Contains("Thrift", actual);
}

/// <summary>
/// Creates a <see cref="SparkHttpConnection"/> from <paramref name="properties"/> and returns
/// the result of its private <c>GetUserAgent</c> method, invoked via reflection.
/// </summary>
/// <param name="properties">Connection properties passed to the connection constructor.</param>
/// <returns>The user agent string produced by the connection.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown when no non-public instance method named <c>GetUserAgent</c> is found.
/// </exception>
private string GetUserAgentFromConnection(Dictionary<string, string> properties)
{
    var sparkConnection = new SparkHttpConnection(properties);

    // GetUserAgent is private, so it has to be looked up reflectively.
    const BindingFlags privateInstance = BindingFlags.NonPublic | BindingFlags.Instance;
    var userAgentGetter = typeof(SparkHttpConnection).GetMethod("GetUserAgent", privateInstance)
        ?? throw new InvalidOperationException("GetUserAgent method not found");

    object? result = userAgentGetter.Invoke(sparkConnection, null);
    return (string)result!;
}
}
}
Loading