diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs index 2159642842..9a17eb6c1b 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs @@ -768,7 +768,7 @@ protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, return baseAddress; } - protected IReadOnlyDictionary GetColumnIndexMap(List columns) => columns + internal IReadOnlyDictionary GetColumnIndexMap(List columns) => columns .Select(t => new { Index = t.Position - ColumnMapIndexOffset, t.ColumnName }) .ToDictionary(t => t.ColumnName, t => t.Index); @@ -1242,12 +1242,7 @@ private static StructArray GetColumnSchema(TableInfo tableInfo) nullBitmapBuffer.Build()); } - protected abstract void SetPrecisionScaleAndTypeName( - short colType, - string typeName, - TableInfo? tableInfo, - int columnSize, - int decimalDigits); + internal abstract void SetPrecisionScaleAndTypeName(short columnType, string typeName, TableInfo? tableInfo, int columnSize, int decimalDigits); public override Schema GetTableSchema(string? catalog, string? dbSchema, string? tableName) { @@ -1364,7 +1359,7 @@ private static IArrowType GetArrowType(int columnTypeId, string typeName, bool i } } - protected async Task FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default) + internal async Task FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default) { await PollForResponseAsync(operationHandle, Client, PollTimeMillisecondsDefault, cancellationToken); diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs index aeee4b998b..1c882b78c6 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs @@ -224,7 +224,7 @@ protected override TOpenSessionReq CreateSessionRequest() return req; } - protected override void SetPrecisionScaleAndTypeName( + internal override void SetPrecisionScaleAndTypeName( short colType, string typeName, TableInfo? tableInfo, diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index f43bad07a0..6079253336 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -21,6 +21,7 @@ using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Ipc; +using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; using Thrift.Transport; @@ -406,7 +407,39 @@ private async Task GetColumnsAsync(CancellationToken cancellationTo cancellationToken); OperationHandle = resp.OperationHandle; - return await GetQueryResult(resp.DirectResults, cancellationToken); + // Common variables declared upfront + TGetResultSetMetadataResp metadata; + Schema schema; + TRowSet rowSet; + + // For GetColumns, we need to enhance the result with BASE_TYPE_NAME + if (Connection.AreResultsAvailableDirectly() && resp.DirectResults?.ResultSet?.Results != null) + { + // Get data from direct results + metadata = resp.DirectResults.ResultSetMetadata; + schema = Connection.SchemaParser.GetArrowSchema(metadata.Schema, Connection.DataTypeConversion); + rowSet = resp.DirectResults.ResultSet.Results; + } + else + { + // Poll and fetch results + await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); + + // Get metadata + metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken); + schema = Connection.SchemaParser.GetArrowSchema(metadata.Schema, Connection.DataTypeConversion); + + // Fetch the results + rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken); + } + + // Common processing for both paths + int columnCount = HiveServer2Reader.GetColumnCount(rowSet); + int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); + IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); + + // Return the enhanced result with added BASE_TYPE_NAME column + return EnhanceGetColumnsResult(schema, data, rowCount, metadata, rowSet); } private async Task GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) @@ -426,12 +459,105 @@ private async Task GetQueryResult(TSparkDirectResults? directResult int columnCount = HiveServer2Reader.GetColumnCount(rowSet); int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); + return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(schema, data)); } await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); + return new QueryResult(-1, Connection.NewReader(this, schema)); } + + protected internal QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList originalData, + int rowCount, TGetResultSetMetadataResp metadata, TRowSet rowSet) + { + // Create a column map using Connection's GetColumnIndexMap method + var columnMap = Connection.GetColumnIndexMap(metadata.Schema.Columns); + + // Get column indices - we know these columns always exist + int typeNameIndex = columnMap["TYPE_NAME"]; + int dataTypeIndex = columnMap["DATA_TYPE"]; + int columnSizeIndex = columnMap["COLUMN_SIZE"]; + int decimalDigitsIndex = columnMap["DECIMAL_DIGITS"]; + + // Extract the existing arrays + StringArray typeNames = (StringArray)originalData[typeNameIndex]; + Int32Array originalColumnSizes = (Int32Array)originalData[columnSizeIndex]; + Int32Array originalDecimalDigits = (Int32Array)originalData[decimalDigitsIndex]; + + // Create enhanced schema with BASE_TYPE_NAME column + var enhancedFields = originalSchema.FieldsList.ToList(); + enhancedFields.Add(new Field("BASE_TYPE_NAME", StringType.Default, true)); + Schema enhancedSchema = new Schema(enhancedFields, originalSchema.Metadata); + + // Pre-allocate arrays to store our values + int length = typeNames.Length; + List baseTypeNames = new List(length); + List columnSizeValues = new List(length); + List decimalDigitsValues = new List(length); + + // Process each row + for (int i = 0; i < length; i++) + { + string? typeName = typeNames.GetString(i); + short colType = (short)rowSet.Columns[dataTypeIndex].I32Val.Values.Values[i]; + int columnSize = originalColumnSizes.GetValue(i).GetValueOrDefault(); + int decimalDigits = originalDecimalDigits.GetValue(i).GetValueOrDefault(); + + // Create a TableInfo for this row + var tableInfo = new HiveServer2Connection.TableInfo(string.Empty); + + // Process all types through SetPrecisionScaleAndTypeName + Connection.SetPrecisionScaleAndTypeName(colType, typeName ?? string.Empty, tableInfo, columnSize, decimalDigits); + + // Get base type name + string baseTypeName; + if (tableInfo.BaseTypeName.Count > 0) + { + string? baseTypeNameValue = tableInfo.BaseTypeName[0]; + baseTypeName = baseTypeNameValue ?? string.Empty; + } + else + { + baseTypeName = typeName ?? string.Empty; + } + baseTypeNames.Add(baseTypeName); + + // Get precision/scale values + if (tableInfo.Precision.Count > 0) + { + int? precisionValue = tableInfo.Precision[0]; + columnSizeValues.Add(precisionValue.GetValueOrDefault(columnSize)); + } + else + { + columnSizeValues.Add(columnSize); + } + + if (tableInfo.Scale.Count > 0) + { + int? scaleValue = tableInfo.Scale[0]; + decimalDigitsValues.Add(scaleValue.GetValueOrDefault(decimalDigits)); + } + else + { + decimalDigitsValues.Add(decimalDigits); + } + } + + // Create the Arrow arrays directly from our data arrays + StringArray baseTypeNameArray = new StringArray.Builder().AppendRange(baseTypeNames).Build(); + Int32Array columnSizeArray = new Int32Array.Builder().AppendRange(columnSizeValues).Build(); + Int32Array decimalDigitsArray = new Int32Array.Builder().AppendRange(decimalDigitsValues).Build(); + + // Create enhanced data with modified columns + var enhancedData = new List(originalData); + enhancedData[columnSizeIndex] = columnSizeArray; + enhancedData[decimalDigitsIndex] = decimalDigitsArray; + enhancedData.Add(baseTypeNameArray); + + return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(enhancedSchema, enhancedData)); + } } } diff --git a/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs b/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs index bd7d285e21..4445334d01 100644 --- a/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs +++ b/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs @@ -89,7 +89,11 @@ internal abstract class SqlTypeNameParser : ISqlTypeNameParser where T : SqlT // Note: the INTERVAL sql type does not have an associated column type id. private static readonly HashSet s_parsers = new HashSet(s_parserMap.Values - .Concat([SqlIntervalTypeParser.Default, SqlSimpleTypeParser.Default("VOID")])); + .Concat([ + SqlIntervalTypeParser.Default, + SqlSimpleTypeParser.Default("VOID"), + SqlSimpleTypeParser.Default("VARIANT"), + ])); /// /// Gets the base SQL type name without decoration or sub clauses diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs index 3d51af5218..c5d8e45c8f 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs @@ -86,7 +86,7 @@ protected override Task GetRowSetAsync(TGetSchemasResp response, Cancel protected internal override Task GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default) => FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); - protected override void SetPrecisionScaleAndTypeName( + internal override void SetPrecisionScaleAndTypeName( short colType, string typeName, TableInfo? tableInfo, diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index 8b3014cc8d..925073e351 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -63,7 +63,7 @@ public override AdbcStatement CreateStatement() protected internal override int PositionRequiredOffset => 1; - protected override void SetPrecisionScaleAndTypeName( + internal override void SetPrecisionScaleAndTypeName( short colType, string typeName, TableInfo? tableInfo, diff --git a/csharp/test/Drivers/Databricks/StatementTests.cs b/csharp/test/Drivers/Databricks/StatementTests.cs index 8f06c63c8d..81f0807306 100644 --- a/csharp/test/Drivers/Databricks/StatementTests.cs +++ b/csharp/test/Drivers/Databricks/StatementTests.cs @@ -18,10 +18,14 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache; using Apache.Arrow.Adbc.Drivers.Databricks; using Apache.Arrow.Adbc.Tests.Drivers.Apache.Common; +using Apache.Arrow.Adbc.Tests.Xunit; +using Apache.Arrow.Types; using Xunit; using Xunit.Abstractions; +using System.Linq; namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks { @@ -114,6 +118,218 @@ public async Task CanGetCrossReferenceFromChildTableDatabricks() await base.CanGetCrossReferenceFromChildTable(TestConfiguration.Metadata.Catalog, TestConfiguration.Metadata.Schema); } + [SkippableFact] + public async Task CanGetColumnsWithBaseTypeName() + { + var statement = Connection.CreateStatement(); + statement.SetOption(ApacheParameters.IsMetadataCommand, "true"); + statement.SetOption(ApacheParameters.CatalogName, TestConfiguration.Metadata.Catalog); + statement.SetOption(ApacheParameters.SchemaName, TestConfiguration.Metadata.Schema); + statement.SetOption(ApacheParameters.TableName, TestConfiguration.Metadata.Table); + statement.SqlQuery = "GetColumns"; + + QueryResult queryResult = await statement.ExecuteQueryAsync(); + Assert.NotNull(queryResult.Stream); + + // We should have 24 columns now (the original 23 + BASE_TYPE_NAME) + Assert.Equal(24, queryResult.Stream.Schema.FieldsList.Count); + + // Verify the BASE_TYPE_NAME column is present + bool hasBaseTypeNameColumn = false; + int baseTypeNameIndex = -1; + int typeNameIndex = -1; + + for (int i = 0; i < queryResult.Stream.Schema.FieldsList.Count; i++) + { + if (queryResult.Stream.Schema.FieldsList[i].Name.Equals("BASE_TYPE_NAME", StringComparison.OrdinalIgnoreCase)) + { + hasBaseTypeNameColumn = true; + baseTypeNameIndex = i; + } + else if (queryResult.Stream.Schema.FieldsList[i].Name.Equals("TYPE_NAME", StringComparison.OrdinalIgnoreCase)) + { + typeNameIndex = i; + } + } + + Assert.True(hasBaseTypeNameColumn, "BASE_TYPE_NAME column not found in GetColumns result"); + Assert.True(typeNameIndex >= 0, "TYPE_NAME column not found in GetColumns result"); + + // Read batches and verify BASE_TYPE_NAME values + int actualBatchLength = 0; + + // Track if we've seen specific complex types + bool foundDecimal = false; + bool foundInterval = false; + bool foundMap = false; + bool foundArray = false; + bool foundStruct = false; + + Dictionary typeNameToBaseTypeName = new Dictionary(); + + // For tracking decimal precision and scale + int columnSizeIndex = -1; + int decimalDigitsIndex = -1; + Dictionary decimalTypeInfo = new Dictionary(); + + // Find COLUMN_SIZE and DECIMAL_DIGITS columns + for (int i = 0; i < queryResult.Stream.Schema.FieldsList.Count; i++) + { + if (queryResult.Stream.Schema.FieldsList[i].Name.Equals("COLUMN_SIZE", StringComparison.OrdinalIgnoreCase)) + { + columnSizeIndex = i; + } + else if (queryResult.Stream.Schema.FieldsList[i].Name.Equals("DECIMAL_DIGITS", StringComparison.OrdinalIgnoreCase)) + { + decimalDigitsIndex = i; + } + + if (columnSizeIndex >= 0 && decimalDigitsIndex >= 0) + break; + } + + while (queryResult.Stream != null) + { + RecordBatch? batch = await queryResult.Stream.ReadNextRecordBatchAsync(); + if (batch == null) + { + break; + } + + actualBatchLength += batch.Length; + + // Verify relationships between TYPE_NAME and BASE_TYPE_NAME for each row + for (int i = 0; i < batch.Length; i++) + { + string? typeName = ((StringArray)batch.Column(typeNameIndex)).GetString(i); + string? baseTypeName = ((StringArray)batch.Column(baseTypeNameIndex)).GetString(i); + + // Store for later analysis + if (!string.IsNullOrEmpty(typeName) && !string.IsNullOrEmpty(baseTypeName)) + { + typeNameToBaseTypeName[typeName] = baseTypeName; + + // Collect precision and scale for DECIMAL types + if (typeName.StartsWith("DECIMAL(") && columnSizeIndex >= 0 && decimalDigitsIndex >= 0) + { + int? precision = ((Int32Array)batch.Column(columnSizeIndex)).GetValue(i); + int? scale = ((Int32Array)batch.Column(decimalDigitsIndex)).GetValue(i); + + if (precision.HasValue && scale.HasValue) + { + decimalTypeInfo[typeName] = (precision.Value, (short)scale.Value); + } + } + + // Track if we've found specific complex types + if (typeName.StartsWith("DECIMAL(")) + foundDecimal = true; + else if (typeName.StartsWith("INTERVAL")) + foundInterval = true; + else if (typeName.StartsWith("MAP<")) + foundMap = true; + else if (typeName.StartsWith("ARRAY<")) + foundArray = true; + else if (typeName.StartsWith("STRUCT<")) + foundStruct = true; + } + + // BASE_TYPE_NAME should not be null if TYPE_NAME is not null + if (!string.IsNullOrEmpty(typeName)) + { + Assert.NotNull(baseTypeName); + + // BASE_TYPE_NAME should be contained within TYPE_NAME or equal to it + // But we might have cases like "ARRAY" where baseTypeName would be "ARRAY" + if (!typeName.Contains("<") && !typeName.Contains("(") && !typeName.Contains(" ")) + { + // Simple types should match exactly, with special handling for INT vs INTEGER + bool isEquivalentType = + typeName == baseTypeName || + ((typeName == "INT" && baseTypeName == "INTEGER")) || + ((typeName == "TIMESTAMP_NTZ" || typeName == "TIMESTAMP_LTZ") && baseTypeName == "TIMESTAMP"); + + Assert.True(isEquivalentType, + $"TypeName '{typeName}' should be equivalent to BaseTypeName '{baseTypeName}'"); + } + else + { + // Complex types should have BASE_TYPE_NAME as a prefix (without parameters) + Assert.True(typeName.StartsWith(baseTypeName), + $"TypeName '{typeName}' should start with BaseTypeName '{baseTypeName}'"); + + // The BASE_TYPE_NAME should not contain angle brackets or parentheses + Assert.DoesNotContain("(", baseTypeName); + Assert.DoesNotContain("<", baseTypeName); + } + + OutputHelper?.WriteLine($"TYPE_NAME: {typeName}, BASE_TYPE_NAME: {baseTypeName}"); + } + } + } + + // Specific tests for complex types - if we found them in the results + if (foundDecimal) + { + string decimalTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("DECIMAL(")); + string decimalBaseTypeName = typeNameToBaseTypeName[decimalTypeName]; + Assert.Equal("DECIMAL", decimalBaseTypeName); + OutputHelper?.WriteLine($"Verified DECIMAL: {decimalTypeName} -> {decimalBaseTypeName}"); + + // Extract precision and scale from the type name (e.g., "DECIMAL(38,10)" -> precision=38, scale=10) + string typePart = decimalTypeName.Substring(decimalTypeName.IndexOf('(') + 1); + typePart = typePart.Remove(typePart.Length - 1); // Remove closing parenthesis + string[] parts = typePart.Split(','); + + int expectedPrecision = int.Parse(parts[0]); + int expectedScale = parts.Length > 1 ? int.Parse(parts[1]) : 0; + + // Verify that the precision and scale from the data match what's in the type name + Assert.True(decimalTypeInfo.ContainsKey(decimalTypeName), + $"Could not find precision and scale information for {decimalTypeName}"); + + var (actualPrecision, actualScale) = decimalTypeInfo[decimalTypeName]; + Assert.Equal(expectedPrecision, actualPrecision); + Assert.Equal(expectedScale, actualScale); + + OutputHelper?.WriteLine($"Verified DECIMAL precision/scale: {decimalTypeName} -> precision={actualPrecision}, scale={actualScale}"); + } + + if (foundInterval) + { + string intervalTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("INTERVAL")); + string intervalBaseTypeName = typeNameToBaseTypeName[intervalTypeName]; + Assert.Equal("INTERVAL", intervalBaseTypeName); + OutputHelper?.WriteLine($"Verified INTERVAL: {intervalTypeName} -> {intervalBaseTypeName}"); + } + + if (foundMap) + { + string mapTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("MAP<")); + string mapBaseTypeName = typeNameToBaseTypeName[mapTypeName]; + Assert.Equal("MAP", mapBaseTypeName); + OutputHelper?.WriteLine($"Verified MAP: {mapTypeName} -> {mapBaseTypeName}"); + } + + if (foundArray) + { + string arrayTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("ARRAY<")); + string arrayBaseTypeName = typeNameToBaseTypeName[arrayTypeName]; + Assert.Equal("ARRAY", arrayBaseTypeName); + OutputHelper?.WriteLine($"Verified ARRAY: {arrayTypeName} -> {arrayBaseTypeName}"); + } + + if (foundStruct) + { + string structTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("STRUCT<")); + string structBaseTypeName = typeNameToBaseTypeName[structTypeName]; + Assert.Equal("STRUCT", structBaseTypeName); + OutputHelper?.WriteLine($"Verified STRUCT: {structTypeName} -> {structBaseTypeName}"); + } + + Assert.Equal(TestConfiguration.Metadata.ExpectedColumnCount, actualBatchLength); + } + protected override void PrepareCreateTableWithPrimaryKeys(out string sqlUpdate, out string tableNameParent, out string fullTableNameParent, out IReadOnlyList primaryKeys) { CreateNewTableName(out tableNameParent, out fullTableNameParent);