From 01a9490e76c546b312b13262136fe13a46e55bc8 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Thu, 10 Apr 2025 09:33:17 -0700 Subject: [PATCH 1/7] feat(csharp/drivers): enhance GetColumns with BASE_TYPE_NAME column --- .../Apache/Hive2/HiveServer2Connection.cs | 11 +- .../Apache/Hive2/HiveServer2HttpConnection.cs | 2 +- .../Apache/Hive2/HiveServer2Statement.cs | 117 ++++++++++ .../Drivers/Apache/Hive2/SqlTypeNameParser.cs | 6 +- .../Drivers/Apache/Impala/ImpalaConnection.cs | 2 +- .../Drivers/Apache/Spark/SparkConnection.cs | 2 +- .../test/Drivers/Databricks/StatementTests.cs | 216 ++++++++++++++++++ 7 files changed, 344 insertions(+), 12 deletions(-) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs index 2159642842..1bb986529d 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs @@ -768,7 +768,7 @@ protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, return baseAddress; } - protected IReadOnlyDictionary GetColumnIndexMap(List columns) => columns + internal IReadOnlyDictionary GetColumnIndexMap(List columns) => columns .Select(t => new { Index = t.Position - ColumnMapIndexOffset, t.ColumnName }) .ToDictionary(t => t.ColumnName, t => t.Index); @@ -1242,12 +1242,7 @@ private static StructArray GetColumnSchema(TableInfo tableInfo) nullBitmapBuffer.Build()); } - protected abstract void SetPrecisionScaleAndTypeName( - short colType, - string typeName, - TableInfo? tableInfo, - int columnSize, - int decimalDigits); + public abstract void SetPrecisionScaleAndTypeName(short columnType, string typeName, TableInfo? tableInfo, int columnSize, int decimalDigits); public override Schema GetTableSchema(string? catalog, string? dbSchema, string? tableName) { @@ -1364,7 +1359,7 @@ private static IArrowType GetArrowType(int columnTypeId, string typeName, bool i } } - protected async Task FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default) + public async Task FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default) { await PollForResponseAsync(operationHandle, Client, PollTimeMillisecondsDefault, cancellationToken); diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs index aeee4b998b..6ae19e6c8f 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs @@ -224,7 +224,7 @@ protected override TOpenSessionReq CreateSessionRequest() return req; } - protected override void SetPrecisionScaleAndTypeName( + public override void SetPrecisionScaleAndTypeName( short colType, string typeName, TableInfo? tableInfo, diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index f43bad07a0..0bf64c20ea 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -21,6 +21,7 @@ using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Ipc; +using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; using Thrift.Transport; @@ -426,12 +427,128 @@ private async Task GetQueryResult(TSparkDirectResults? directResult int columnCount = HiveServer2Reader.GetColumnCount(rowSet); int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); + + // Enhance column schema results if this is a GetColumns query + if (SqlQuery?.ToLowerInvariant() == GetColumnsCommandName) + { + return EnhanceGetColumnsResult(schema, data, rowCount, resultSetMetadata, rowSet); + } + return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(schema, data)); } await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); + + // For GetColumns operation, we need to fetch the results and enhance them + if (SqlQuery?.ToLowerInvariant() == GetColumnsCommandName) + { + // Fetch the results manually to enhance them + TRowSet rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken); + int columnCount = HiveServer2Reader.GetColumnCount(rowSet); + int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); + + // Get metadata again to ensure we have the latest + TGetResultSetMetadataResp metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken); + + // Get the arrays from the row set + IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); + + return EnhanceGetColumnsResult(schema, data, rowCount, metadata, rowSet); + } + return new QueryResult(-1, Connection.NewReader(this, schema)); } + + private QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList originalData, + int rowCount, TGetResultSetMetadataResp metadata, TRowSet rowSet) + { + // Create a column map using Connection's GetColumnIndexMap method + var columnMap = Connection.GetColumnIndexMap(metadata.Schema.Columns); + + // Get column indices - we know these columns always exist + int typeNameIndex = columnMap["TYPE_NAME"]; + int dataTypeIndex = columnMap["DATA_TYPE"]; + int columnSizeIndex = columnMap["COLUMN_SIZE"]; + int decimalDigitsIndex = columnMap["DECIMAL_DIGITS"]; + + // Extract the existing arrays + StringArray typeNames = (StringArray)originalData[typeNameIndex]; + Int32Array originalColumnSizes = (Int32Array)originalData[columnSizeIndex]; + Int32Array originalDecimalDigits = (Int32Array)originalData[decimalDigitsIndex]; + + // Create enhanced schema with BASE_TYPE_NAME column + var enhancedFields = originalSchema.FieldsList.ToList(); + enhancedFields.Add(new Field("BASE_TYPE_NAME", StringType.Default, true)); + Schema enhancedSchema = new Schema(enhancedFields, originalSchema.Metadata); + + // Pre-allocate arrays to store our values + int length = typeNames.Length; + List baseTypeNames = new List(length); + List columnSizeValues = new List(length); + List decimalDigitsValues = new List(length); + + // Process each row + for (int i = 0; i < length; i++) + { + string? typeName = typeNames.GetString(i); + short colType = (short)rowSet.Columns[dataTypeIndex].I32Val.Values.Values[i]; + int columnSize = originalColumnSizes.GetValue(i).GetValueOrDefault(); + int decimalDigits = originalDecimalDigits.GetValue(i).GetValueOrDefault(); + + // Create a TableInfo for this row + var tableInfo = new HiveServer2Connection.TableInfo(string.Empty); + + // Process all types through SetPrecisionScaleAndTypeName + Connection.SetPrecisionScaleAndTypeName(colType, typeName ?? string.Empty, tableInfo, columnSize, decimalDigits); + + // Get base type name + string baseTypeName; + if (tableInfo.BaseTypeName.Count > 0) + { + string? baseTypeNameValue = tableInfo.BaseTypeName[0]; + baseTypeName = baseTypeNameValue ?? string.Empty; + } + else + { + baseTypeName = typeName ?? string.Empty; + } + baseTypeNames.Add(baseTypeName); + + // Get precision/scale values + if (tableInfo.Precision.Count > 0) + { + int? precisionValue = tableInfo.Precision[0]; + columnSizeValues.Add(precisionValue.GetValueOrDefault(columnSize)); + } + else + { + columnSizeValues.Add(columnSize); + } + + if (tableInfo.Scale.Count > 0) + { + int? scaleValue = tableInfo.Scale[0]; + decimalDigitsValues.Add(scaleValue.GetValueOrDefault(decimalDigits)); + } + else + { + decimalDigitsValues.Add(decimalDigits); + } + } + + // Create the Arrow arrays directly from our data arrays + StringArray baseTypeNameArray = new StringArray.Builder().AppendRange(baseTypeNames).Build(); + Int32Array columnSizeArray = new Int32Array.Builder().AppendRange(columnSizeValues).Build(); + Int32Array decimalDigitsArray = new Int32Array.Builder().AppendRange(decimalDigitsValues).Build(); + + // Create enhanced data with modified columns + var enhancedData = new List(originalData); + enhancedData[columnSizeIndex] = columnSizeArray; + enhancedData[decimalDigitsIndex] = decimalDigitsArray; + enhancedData.Add(baseTypeNameArray); + + return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(enhancedSchema, enhancedData)); + } } } diff --git a/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs b/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs index bd7d285e21..c994080455 100644 --- a/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs +++ b/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs @@ -89,7 +89,11 @@ internal abstract class SqlTypeNameParser : ISqlTypeNameParser where T : SqlT // Note: the INTERVAL sql type does not have an associated column type id. private static readonly HashSet s_parsers = new HashSet(s_parserMap.Values - .Concat([SqlIntervalTypeParser.Default, SqlSimpleTypeParser.Default("VOID")])); + .Concat([ + SqlIntervalTypeParser.Default, + SqlSimpleTypeParser.Default("VOID"), + SqlSimpleTypeParser.Default("VARIANT"), + ])); /// /// Gets the base SQL type name without decoration or sub clauses diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs index 3d51af5218..a63ae3cd79 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs @@ -86,7 +86,7 @@ protected override Task GetRowSetAsync(TGetSchemasResp response, Cancel protected internal override Task GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default) => FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); - protected override void SetPrecisionScaleAndTypeName( + public override void SetPrecisionScaleAndTypeName( short colType, string typeName, TableInfo? tableInfo, diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index 8b3014cc8d..8682f9c26b 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -63,7 +63,7 @@ public override AdbcStatement CreateStatement() protected internal override int PositionRequiredOffset => 1; - protected override void SetPrecisionScaleAndTypeName( + public override void SetPrecisionScaleAndTypeName( short colType, string typeName, TableInfo? tableInfo, diff --git a/csharp/test/Drivers/Databricks/StatementTests.cs b/csharp/test/Drivers/Databricks/StatementTests.cs index 8f06c63c8d..bf139bbf1e 100644 --- a/csharp/test/Drivers/Databricks/StatementTests.cs +++ b/csharp/test/Drivers/Databricks/StatementTests.cs @@ -18,10 +18,14 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache; using Apache.Arrow.Adbc.Drivers.Databricks; using Apache.Arrow.Adbc.Tests.Drivers.Apache.Common; +using Apache.Arrow.Adbc.Tests.Xunit; +using Apache.Arrow.Types; using Xunit; using Xunit.Abstractions; +using System.Linq; namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks { @@ -114,6 +118,218 @@ public async Task CanGetCrossReferenceFromChildTableDatabricks() await base.CanGetCrossReferenceFromChildTable(TestConfiguration.Metadata.Catalog, TestConfiguration.Metadata.Schema); } + [SkippableFact] + public async Task CanGetColumnsWithBaseTypeName() + { + var statement = Connection.CreateStatement(); + statement.SetOption(ApacheParameters.IsMetadataCommand, "true"); + statement.SetOption(ApacheParameters.CatalogName, TestConfiguration.Metadata.Catalog); + statement.SetOption(ApacheParameters.SchemaName, TestConfiguration.Metadata.Schema); + statement.SetOption(ApacheParameters.TableName, TestConfiguration.Metadata.Table); + statement.SqlQuery = "GetColumns"; + + QueryResult queryResult = await statement.ExecuteQueryAsync(); + Assert.NotNull(queryResult.Stream); + + // We should have 24 columns now (the original 23 + BASE_TYPE_NAME) + Assert.Equal(24, queryResult.Stream.Schema.FieldsList.Count); + + // Verify the BASE_TYPE_NAME column is present + bool hasBaseTypeNameColumn = false; + int baseTypeNameIndex = -1; + int typeNameIndex = -1; + + for (int i = 0; i < queryResult.Stream.Schema.FieldsList.Count; i++) + { + if (queryResult.Stream.Schema.FieldsList[i].Name.Equals("BASE_TYPE_NAME", StringComparison.OrdinalIgnoreCase)) + { + hasBaseTypeNameColumn = true; + baseTypeNameIndex = i; + } + else if (queryResult.Stream.Schema.FieldsList[i].Name.Equals("TYPE_NAME", StringComparison.OrdinalIgnoreCase)) + { + typeNameIndex = i; + } + } + + Assert.True(hasBaseTypeNameColumn, "BASE_TYPE_NAME column not found in GetColumns result"); + Assert.True(typeNameIndex >= 0, "TYPE_NAME column not found in GetColumns result"); + + // Read batches and verify BASE_TYPE_NAME values + int actualBatchLength = 0; + + // Track if we've seen specific complex types + bool foundDecimal = false; + bool foundInterval = false; + bool foundMap = false; + bool foundArray = false; + bool foundStruct = false; + + Dictionary typeNameToBaseTypeName = new Dictionary(); + + // For tracking decimal precision and scale + int columnSizeIndex = -1; + int decimalDigitsIndex = -1; + Dictionary decimalTypeInfo = new Dictionary(); + + // Find COLUMN_SIZE and DECIMAL_DIGITS columns + for (int i = 0; i < queryResult.Stream.Schema.FieldsList.Count; i++) + { + if (queryResult.Stream.Schema.FieldsList[i].Name.Equals("COLUMN_SIZE", StringComparison.OrdinalIgnoreCase)) + { + columnSizeIndex = i; + } + else if (queryResult.Stream.Schema.FieldsList[i].Name.Equals("DECIMAL_DIGITS", StringComparison.OrdinalIgnoreCase)) + { + decimalDigitsIndex = i; + } + + if (columnSizeIndex >= 0 && decimalDigitsIndex >= 0) + break; + } + + while (queryResult.Stream != null) + { + RecordBatch? batch = await queryResult.Stream.ReadNextRecordBatchAsync(); + if (batch == null) + { + break; + } + + actualBatchLength += batch.Length; + + // Verify relationships between TYPE_NAME and BASE_TYPE_NAME for each row + for (int i = 0; i < batch.Length; i++) + { + string? typeName = ((StringArray)batch.Column(typeNameIndex)).GetString(i); + string? baseTypeName = ((StringArray)batch.Column(baseTypeNameIndex)).GetString(i); + + // Store for later analysis + if (!string.IsNullOrEmpty(typeName) && !string.IsNullOrEmpty(baseTypeName)) + { + typeNameToBaseTypeName[typeName] = baseTypeName; + + // Collect precision and scale for DECIMAL types + if (typeName.StartsWith("DECIMAL(") && columnSizeIndex >= 0 && decimalDigitsIndex >= 0) + { + int? precision = ((Int32Array)batch.Column(columnSizeIndex)).GetValue(i); + int? scale = ((Int32Array)batch.Column(decimalDigitsIndex)).GetValue(i); + + if (precision.HasValue && scale.HasValue) + { + decimalTypeInfo[typeName] = (precision.Value, (short)scale.Value); + } + } + + // Track if we've found specific complex types + if (typeName.StartsWith("DECIMAL(")) + foundDecimal = true; + else if (typeName.StartsWith("INTERVAL")) + foundInterval = true; + else if (typeName.StartsWith("MAP<")) + foundMap = true; + else if (typeName.StartsWith("ARRAY<")) + foundArray = true; + else if (typeName.StartsWith("STRUCT<")) + foundStruct = true; + } + + // BASE_TYPE_NAME should not be null if TYPE_NAME is not null + if (!string.IsNullOrEmpty(typeName)) + { + Assert.NotNull(baseTypeName); + + // BASE_TYPE_NAME should be contained within TYPE_NAME or equal to it + // But we might have cases like "ARRAY" where baseTypeName would be "ARRAY" + if (!typeName.Contains("<") && !typeName.Contains("(") && !typeName.Contains(" ")) + { + // Simple types should match exactly, with special handling for INT vs INTEGER + bool isEquivalentType = + typeName == baseTypeName || + ((typeName == "INT" && baseTypeName == "INTEGER")) || + ((typeName == "TIMESTAMP_NTZ" || typeName == "TIMESTAMP_LTZ") && baseTypeName == "TIMESTAMP"); + + Assert.True(isEquivalentType, + $"TypeName '{typeName}' should be equivalent to BaseTypeName '{baseTypeName}'"); + } + else + { + // Complex types should have BASE_TYPE_NAME as a prefix (without parameters) + Assert.True(typeName.StartsWith(baseTypeName), + $"TypeName '{typeName}' should start with BaseTypeName '{baseTypeName}'"); + + // The BASE_TYPE_NAME should not contain angle brackets or parentheses + Assert.DoesNotContain("(", baseTypeName); + Assert.DoesNotContain("<", baseTypeName); + } + + OutputHelper?.WriteLine($"TYPE_NAME: {typeName}, BASE_TYPE_NAME: {baseTypeName}"); + } + } + } + + // Specific tests for complex types - if we found them in the results + if (foundDecimal) + { + string decimalTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("DECIMAL(")); + string decimalBaseTypeName = typeNameToBaseTypeName[decimalTypeName]; + Assert.Equal("DECIMAL", decimalBaseTypeName); + OutputHelper?.WriteLine($"Verified DECIMAL: {decimalTypeName} -> {decimalBaseTypeName}"); + + // Extract precision and scale from the type name (e.g., "DECIMAL(38,10)" -> precision=38, scale=10) + string typePart = decimalTypeName.Substring(decimalTypeName.IndexOf('(') + 1); + typePart = typePart.Remove(typePart.Length - 1); // Remove closing parenthesis + string[] parts = typePart.Split(','); + + int expectedPrecision = int.Parse(parts[0]); + int expectedScale = parts.Length > 1 ? int.Parse(parts[1]) : 0; + + // Verify that the precision and scale from the data match what's in the type name + Assert.True(decimalTypeInfo.ContainsKey(decimalTypeName), + $"Could not find precision and scale information for {decimalTypeName}"); + + var (actualPrecision, actualScale) = decimalTypeInfo[decimalTypeName]; + Assert.Equal(expectedPrecision, actualPrecision); + Assert.Equal(expectedScale, actualScale); + + OutputHelper?.WriteLine($"Verified DECIMAL precision/scale: {decimalTypeName} -> precision={actualPrecision}, scale={actualScale}"); + } + + if (foundInterval) + { + string intervalTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("INTERVAL")); + string intervalBaseTypeName = typeNameToBaseTypeName[intervalTypeName]; + Assert.Equal("INTERVAL", intervalBaseTypeName); + OutputHelper?.WriteLine($"Verified INTERVAL: {intervalTypeName} -> {intervalBaseTypeName}"); + } + + if (foundMap) + { + string mapTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("MAP<")); + string mapBaseTypeName = typeNameToBaseTypeName[mapTypeName]; + Assert.Equal("MAP", mapBaseTypeName); + OutputHelper?.WriteLine($"Verified MAP: {mapTypeName} -> {mapBaseTypeName}"); + } + + if (foundArray) + { + string arrayTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("ARRAY<")); + string arrayBaseTypeName = typeNameToBaseTypeName[arrayTypeName]; + Assert.Equal("ARRAY", arrayBaseTypeName); + OutputHelper?.WriteLine($"Verified ARRAY: {arrayTypeName} -> {arrayBaseTypeName}"); + } + + if (foundStruct) + { + string structTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("STRUCT<")); + string structBaseTypeName = typeNameToBaseTypeName[structTypeName]; + Assert.Equal("STRUCT", structBaseTypeName); + OutputHelper?.WriteLine($"Verified STRUCT: {structTypeName} -> {structBaseTypeName}"); + } + + Assert.Equal(TestConfiguration.Metadata.ExpectedColumnCount, actualBatchLength); + } + protected override void PrepareCreateTableWithPrimaryKeys(out string sqlUpdate, out string tableNameParent, out string fullTableNameParent, out IReadOnlyList primaryKeys) { CreateNewTableName(out tableNameParent, out fullTableNameParent); From de5157040945a38170d243462e4a58e4fb475791 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Thu, 10 Apr 2025 11:07:36 -0700 Subject: [PATCH 2/7] lint --- .../Apache/Hive2/HiveServer2Statement.cs | 42 +++++------ .../Drivers/Apache/Hive2/SqlTypeNameParser.cs | 2 +- .../test/Drivers/Databricks/StatementTests.cs | 74 +++++++++---------- 3 files changed, 59 insertions(+), 59 deletions(-) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index 0bf64c20ea..0b806ec8bc 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -427,19 +427,19 @@ private async Task GetQueryResult(TSparkDirectResults? directResult int columnCount = HiveServer2Reader.GetColumnCount(rowSet); int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); - + // Enhance column schema results if this is a GetColumns query if (SqlQuery?.ToLowerInvariant() == GetColumnsCommandName) { return EnhanceGetColumnsResult(schema, data, rowCount, resultSetMetadata, rowSet); } - + return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(schema, data)); } await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); - + // For GetColumns operation, we need to fetch the results and enhance them if (SqlQuery?.ToLowerInvariant() == GetColumnsCommandName) { @@ -447,47 +447,47 @@ private async Task GetQueryResult(TSparkDirectResults? directResult TRowSet rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken); int columnCount = HiveServer2Reader.GetColumnCount(rowSet); int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); - + // Get metadata again to ensure we have the latest TGetResultSetMetadataResp metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken); - + // Get the arrays from the row set IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); - + return EnhanceGetColumnsResult(schema, data, rowCount, metadata, rowSet); } - + return new QueryResult(-1, Connection.NewReader(this, schema)); } - private QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList originalData, + private QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList originalData, int rowCount, TGetResultSetMetadataResp metadata, TRowSet rowSet) { // Create a column map using Connection's GetColumnIndexMap method var columnMap = Connection.GetColumnIndexMap(metadata.Schema.Columns); - + // Get column indices - we know these columns always exist int typeNameIndex = columnMap["TYPE_NAME"]; int dataTypeIndex = columnMap["DATA_TYPE"]; int columnSizeIndex = columnMap["COLUMN_SIZE"]; int decimalDigitsIndex = columnMap["DECIMAL_DIGITS"]; - + // Extract the existing arrays StringArray typeNames = (StringArray)originalData[typeNameIndex]; Int32Array originalColumnSizes = (Int32Array)originalData[columnSizeIndex]; Int32Array originalDecimalDigits = (Int32Array)originalData[decimalDigitsIndex]; - + // Create enhanced schema with BASE_TYPE_NAME column var enhancedFields = originalSchema.FieldsList.ToList(); enhancedFields.Add(new Field("BASE_TYPE_NAME", StringType.Default, true)); Schema enhancedSchema = new Schema(enhancedFields, originalSchema.Metadata); - + // Pre-allocate arrays to store our values int length = typeNames.Length; List baseTypeNames = new List(length); List columnSizeValues = new List(length); List decimalDigitsValues = new List(length); - + // Process each row for (int i = 0; i < length; i++) { @@ -495,13 +495,13 @@ private QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList short colType = (short)rowSet.Columns[dataTypeIndex].I32Val.Values.Values[i]; int columnSize = originalColumnSizes.GetValue(i).GetValueOrDefault(); int decimalDigits = originalDecimalDigits.GetValue(i).GetValueOrDefault(); - + // Create a TableInfo for this row var tableInfo = new HiveServer2Connection.TableInfo(string.Empty); - + // Process all types through SetPrecisionScaleAndTypeName Connection.SetPrecisionScaleAndTypeName(colType, typeName ?? string.Empty, tableInfo, columnSize, decimalDigits); - + // Get base type name string baseTypeName; if (tableInfo.BaseTypeName.Count > 0) @@ -514,7 +514,7 @@ private QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList baseTypeName = typeName ?? string.Empty; } baseTypeNames.Add(baseTypeName); - + // Get precision/scale values if (tableInfo.Precision.Count > 0) { @@ -525,7 +525,7 @@ private QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList { columnSizeValues.Add(columnSize); } - + if (tableInfo.Scale.Count > 0) { int? scaleValue = tableInfo.Scale[0]; @@ -536,18 +536,18 @@ private QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList decimalDigitsValues.Add(decimalDigits); } } - + // Create the Arrow arrays directly from our data arrays StringArray baseTypeNameArray = new StringArray.Builder().AppendRange(baseTypeNames).Build(); Int32Array columnSizeArray = new Int32Array.Builder().AppendRange(columnSizeValues).Build(); Int32Array decimalDigitsArray = new Int32Array.Builder().AppendRange(decimalDigitsValues).Build(); - + // Create enhanced data with modified columns var enhancedData = new List(originalData); enhancedData[columnSizeIndex] = columnSizeArray; enhancedData[decimalDigitsIndex] = decimalDigitsArray; enhancedData.Add(baseTypeNameArray); - + return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(enhancedSchema, enhancedData)); } } diff --git a/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs b/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs index c994080455..4445334d01 100644 --- a/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs +++ b/csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs @@ -90,7 +90,7 @@ internal abstract class SqlTypeNameParser : ISqlTypeNameParser where T : SqlT // Note: the INTERVAL sql type does not have an associated column type id. private static readonly HashSet s_parsers = new HashSet(s_parserMap.Values .Concat([ - SqlIntervalTypeParser.Default, + SqlIntervalTypeParser.Default, SqlSimpleTypeParser.Default("VOID"), SqlSimpleTypeParser.Default("VARIANT"), ])); diff --git a/csharp/test/Drivers/Databricks/StatementTests.cs b/csharp/test/Drivers/Databricks/StatementTests.cs index bf139bbf1e..81f0807306 100644 --- a/csharp/test/Drivers/Databricks/StatementTests.cs +++ b/csharp/test/Drivers/Databricks/StatementTests.cs @@ -133,12 +133,12 @@ public async Task CanGetColumnsWithBaseTypeName() // We should have 24 columns now (the original 23 + BASE_TYPE_NAME) Assert.Equal(24, queryResult.Stream.Schema.FieldsList.Count); - + // Verify the BASE_TYPE_NAME column is present bool hasBaseTypeNameColumn = false; int baseTypeNameIndex = -1; int typeNameIndex = -1; - + for (int i = 0; i < queryResult.Stream.Schema.FieldsList.Count; i++) { if (queryResult.Stream.Schema.FieldsList[i].Name.Equals("BASE_TYPE_NAME", StringComparison.OrdinalIgnoreCase)) @@ -151,27 +151,27 @@ public async Task CanGetColumnsWithBaseTypeName() typeNameIndex = i; } } - + Assert.True(hasBaseTypeNameColumn, "BASE_TYPE_NAME column not found in GetColumns result"); Assert.True(typeNameIndex >= 0, "TYPE_NAME column not found in GetColumns result"); - + // Read batches and verify BASE_TYPE_NAME values int actualBatchLength = 0; - + // Track if we've seen specific complex types bool foundDecimal = false; bool foundInterval = false; bool foundMap = false; bool foundArray = false; bool foundStruct = false; - + Dictionary typeNameToBaseTypeName = new Dictionary(); - + // For tracking decimal precision and scale int columnSizeIndex = -1; int decimalDigitsIndex = -1; Dictionary decimalTypeInfo = new Dictionary(); - + // Find COLUMN_SIZE and DECIMAL_DIGITS columns for (int i = 0; i < queryResult.Stream.Schema.FieldsList.Count; i++) { @@ -183,11 +183,11 @@ public async Task CanGetColumnsWithBaseTypeName() { decimalDigitsIndex = i; } - + if (columnSizeIndex >= 0 && decimalDigitsIndex >= 0) break; } - + while (queryResult.Stream != null) { RecordBatch? batch = await queryResult.Stream.ReadNextRecordBatchAsync(); @@ -195,32 +195,32 @@ public async Task CanGetColumnsWithBaseTypeName() { break; } - + actualBatchLength += batch.Length; - + // Verify relationships between TYPE_NAME and BASE_TYPE_NAME for each row for (int i = 0; i < batch.Length; i++) { string? typeName = ((StringArray)batch.Column(typeNameIndex)).GetString(i); string? baseTypeName = ((StringArray)batch.Column(baseTypeNameIndex)).GetString(i); - + // Store for later analysis if (!string.IsNullOrEmpty(typeName) && !string.IsNullOrEmpty(baseTypeName)) { typeNameToBaseTypeName[typeName] = baseTypeName; - + // Collect precision and scale for DECIMAL types if (typeName.StartsWith("DECIMAL(") && columnSizeIndex >= 0 && decimalDigitsIndex >= 0) { int? precision = ((Int32Array)batch.Column(columnSizeIndex)).GetValue(i); int? scale = ((Int32Array)batch.Column(decimalDigitsIndex)).GetValue(i); - + if (precision.HasValue && scale.HasValue) { decimalTypeInfo[typeName] = (precision.Value, (short)scale.Value); } } - + // Track if we've found specific complex types if (typeName.StartsWith("DECIMAL(")) foundDecimal = true; @@ -233,41 +233,41 @@ public async Task CanGetColumnsWithBaseTypeName() else if (typeName.StartsWith("STRUCT<")) foundStruct = true; } - + // BASE_TYPE_NAME should not be null if TYPE_NAME is not null if (!string.IsNullOrEmpty(typeName)) { Assert.NotNull(baseTypeName); - + // BASE_TYPE_NAME should be contained within TYPE_NAME or equal to it // But we might have cases like "ARRAY" where baseTypeName would be "ARRAY" if (!typeName.Contains("<") && !typeName.Contains("(") && !typeName.Contains(" ")) { // Simple types should match exactly, with special handling for INT vs INTEGER - bool isEquivalentType = - typeName == baseTypeName || + bool isEquivalentType = + typeName == baseTypeName || ((typeName == "INT" && baseTypeName == "INTEGER")) || ((typeName == "TIMESTAMP_NTZ" || typeName == "TIMESTAMP_LTZ") && baseTypeName == "TIMESTAMP"); - - Assert.True(isEquivalentType, + + Assert.True(isEquivalentType, $"TypeName '{typeName}' should be equivalent to BaseTypeName '{baseTypeName}'"); } else { // Complex types should have BASE_TYPE_NAME as a prefix (without parameters) - Assert.True(typeName.StartsWith(baseTypeName), + Assert.True(typeName.StartsWith(baseTypeName), $"TypeName '{typeName}' should start with BaseTypeName '{baseTypeName}'"); - + // The BASE_TYPE_NAME should not contain angle brackets or parentheses Assert.DoesNotContain("(", baseTypeName); Assert.DoesNotContain("<", baseTypeName); } - + OutputHelper?.WriteLine($"TYPE_NAME: {typeName}, BASE_TYPE_NAME: {baseTypeName}"); } } } - + // Specific tests for complex types - if we found them in the results if (foundDecimal) { @@ -275,26 +275,26 @@ public async Task CanGetColumnsWithBaseTypeName() string decimalBaseTypeName = typeNameToBaseTypeName[decimalTypeName]; Assert.Equal("DECIMAL", decimalBaseTypeName); OutputHelper?.WriteLine($"Verified DECIMAL: {decimalTypeName} -> {decimalBaseTypeName}"); - + // Extract precision and scale from the type name (e.g., "DECIMAL(38,10)" -> precision=38, scale=10) string typePart = decimalTypeName.Substring(decimalTypeName.IndexOf('(') + 1); typePart = typePart.Remove(typePart.Length - 1); // Remove closing parenthesis string[] parts = typePart.Split(','); - + int expectedPrecision = int.Parse(parts[0]); int expectedScale = parts.Length > 1 ? int.Parse(parts[1]) : 0; - + // Verify that the precision and scale from the data match what's in the type name - Assert.True(decimalTypeInfo.ContainsKey(decimalTypeName), + Assert.True(decimalTypeInfo.ContainsKey(decimalTypeName), $"Could not find precision and scale information for {decimalTypeName}"); - + var (actualPrecision, actualScale) = decimalTypeInfo[decimalTypeName]; Assert.Equal(expectedPrecision, actualPrecision); Assert.Equal(expectedScale, actualScale); - + OutputHelper?.WriteLine($"Verified DECIMAL precision/scale: {decimalTypeName} -> precision={actualPrecision}, scale={actualScale}"); } - + if (foundInterval) { string intervalTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("INTERVAL")); @@ -302,7 +302,7 @@ public async Task CanGetColumnsWithBaseTypeName() Assert.Equal("INTERVAL", intervalBaseTypeName); OutputHelper?.WriteLine($"Verified INTERVAL: {intervalTypeName} -> {intervalBaseTypeName}"); } - + if (foundMap) { string mapTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("MAP<")); @@ -310,7 +310,7 @@ public async Task CanGetColumnsWithBaseTypeName() Assert.Equal("MAP", mapBaseTypeName); OutputHelper?.WriteLine($"Verified MAP: {mapTypeName} -> {mapBaseTypeName}"); } - + if (foundArray) { string arrayTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("ARRAY<")); @@ -318,7 +318,7 @@ public async Task CanGetColumnsWithBaseTypeName() Assert.Equal("ARRAY", arrayBaseTypeName); OutputHelper?.WriteLine($"Verified ARRAY: {arrayTypeName} -> {arrayBaseTypeName}"); } - + if (foundStruct) { string structTypeName = typeNameToBaseTypeName.Keys.First(k => k.StartsWith("STRUCT<")); @@ -326,7 +326,7 @@ public async Task CanGetColumnsWithBaseTypeName() Assert.Equal("STRUCT", structBaseTypeName); OutputHelper?.WriteLine($"Verified STRUCT: {structTypeName} -> {structBaseTypeName}"); } - + Assert.Equal(TestConfiguration.Metadata.ExpectedColumnCount, actualBatchLength); } From e9d5e7eceb229e120115309bb9ae6d997c6dbf09 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Thu, 10 Apr 2025 16:06:58 -0700 Subject: [PATCH 3/7] address comments --- .../Apache/Hive2/HiveServer2Statement.cs | 56 ++++++++++--------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index 0b806ec8bc..8a4094ae34 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -407,7 +407,36 @@ private async Task GetColumnsAsync(CancellationToken cancellationTo cancellationToken); OperationHandle = resp.OperationHandle; - return await GetQueryResult(resp.DirectResults, cancellationToken); + // For GetColumns, we need to enhance the result with BASE_TYPE_NAME + if (Connection.AreResultsAvailableDirectly() && resp.DirectResults?.ResultSet?.Results != null) + { + TGetResultSetMetadataResp resultSetMetadata = resp.DirectResults.ResultSetMetadata; + Schema schema = Connection.SchemaParser.GetArrowSchema(resultSetMetadata.Schema, Connection.DataTypeConversion); + TRowSet rowSet = resp.DirectResults.ResultSet.Results; + int columnCount = HiveServer2Reader.GetColumnCount(rowSet); + int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); + IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); + + return EnhanceGetColumnsResult(schema, data, rowCount, resultSetMetadata, rowSet); + } + else + { + await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); + Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); + + // Fetch the results manually to enhance them + TRowSet rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken); + int columnCount = HiveServer2Reader.GetColumnCount(rowSet); + int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); + + // Get metadata again to ensure we have the latest + TGetResultSetMetadataResp metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken); + + // Get the arrays from the row set + IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); + + return EnhanceGetColumnsResult(schema, data, rowCount, metadata, rowSet); + } } private async Task GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) @@ -428,39 +457,16 @@ private async Task GetQueryResult(TSparkDirectResults? directResult int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); - // Enhance column schema results if this is a GetColumns query - if (SqlQuery?.ToLowerInvariant() == GetColumnsCommandName) - { - return EnhanceGetColumnsResult(schema, data, rowCount, resultSetMetadata, rowSet); - } - return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(schema, data)); } await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); - // For GetColumns operation, we need to fetch the results and enhance them - if (SqlQuery?.ToLowerInvariant() == GetColumnsCommandName) - { - // Fetch the results manually to enhance them - TRowSet rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken); - int columnCount = HiveServer2Reader.GetColumnCount(rowSet); - int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); - - // Get metadata again to ensure we have the latest - TGetResultSetMetadataResp metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken); - - // Get the arrays from the row set - IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); - - return EnhanceGetColumnsResult(schema, data, rowCount, metadata, rowSet); - } - return new QueryResult(-1, Connection.NewReader(this, schema)); } - private QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList originalData, + protected internal QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList originalData, int rowCount, TGetResultSetMetadataResp metadata, TRowSet rowSet) { // Create a column map using Connection's GetColumnIndexMap method From eae907a470fc8a74cd86de4797d8844227154385 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Thu, 10 Apr 2025 16:12:24 -0700 Subject: [PATCH 4/7] address comments --- csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs | 4 ++-- csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs | 2 +- csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs | 2 +- csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs index 1bb986529d..9a17eb6c1b 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs @@ -1242,7 +1242,7 @@ private static StructArray GetColumnSchema(TableInfo tableInfo) nullBitmapBuffer.Build()); } - public abstract void SetPrecisionScaleAndTypeName(short columnType, string typeName, TableInfo? tableInfo, int columnSize, int decimalDigits); + internal abstract void SetPrecisionScaleAndTypeName(short columnType, string typeName, TableInfo? tableInfo, int columnSize, int decimalDigits); public override Schema GetTableSchema(string? catalog, string? dbSchema, string? tableName) { @@ -1359,7 +1359,7 @@ private static IArrowType GetArrowType(int columnTypeId, string typeName, bool i } } - public async Task FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default) + internal async Task FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default) { await PollForResponseAsync(operationHandle, Client, PollTimeMillisecondsDefault, cancellationToken); diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs index 6ae19e6c8f..1c882b78c6 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs @@ -224,7 +224,7 @@ protected override TOpenSessionReq CreateSessionRequest() return req; } - public override void SetPrecisionScaleAndTypeName( + internal override void SetPrecisionScaleAndTypeName( short colType, string typeName, TableInfo? tableInfo, diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs index a63ae3cd79..c5d8e45c8f 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs @@ -86,7 +86,7 @@ protected override Task GetRowSetAsync(TGetSchemasResp response, Cancel protected internal override Task GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default) => FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken); - public override void SetPrecisionScaleAndTypeName( + internal override void SetPrecisionScaleAndTypeName( short colType, string typeName, TableInfo? tableInfo, diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index 8682f9c26b..925073e351 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -63,7 +63,7 @@ public override AdbcStatement CreateStatement() protected internal override int PositionRequiredOffset => 1; - public override void SetPrecisionScaleAndTypeName( + internal override void SetPrecisionScaleAndTypeName( short colType, string typeName, TableInfo? tableInfo, From 0c3eb3d77cda5d9c46f070d7275985be955761b0 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Thu, 10 Apr 2025 16:18:25 -0700 Subject: [PATCH 5/7] lint --- csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index 8a4094ae34..983a92385d 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -423,7 +423,7 @@ private async Task GetColumnsAsync(CancellationToken cancellationTo { await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); - + // Fetch the results manually to enhance them TRowSet rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken); int columnCount = HiveServer2Reader.GetColumnCount(rowSet); From cb139a59e3509df06311883505b238f22e3ebebb Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Mon, 21 Apr 2025 11:07:46 -0700 Subject: [PATCH 6/7] address comments --- .../Apache/Hive2/HiveServer2Statement.cs | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index 983a92385d..f025021d7e 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -407,36 +407,39 @@ private async Task GetColumnsAsync(CancellationToken cancellationTo cancellationToken); OperationHandle = resp.OperationHandle; + // Common variables declared upfront + TGetResultSetMetadataResp metadata; + Schema schema; + TRowSet rowSet; + // For GetColumns, we need to enhance the result with BASE_TYPE_NAME if (Connection.AreResultsAvailableDirectly() && resp.DirectResults?.ResultSet?.Results != null) { - TGetResultSetMetadataResp resultSetMetadata = resp.DirectResults.ResultSetMetadata; - Schema schema = Connection.SchemaParser.GetArrowSchema(resultSetMetadata.Schema, Connection.DataTypeConversion); - TRowSet rowSet = resp.DirectResults.ResultSet.Results; - int columnCount = HiveServer2Reader.GetColumnCount(rowSet); - int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); - IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); - - return EnhanceGetColumnsResult(schema, data, rowCount, resultSetMetadata, rowSet); + // Get data from direct results + metadata = resp.DirectResults.ResultSetMetadata; + schema = Connection.SchemaParser.GetArrowSchema(metadata.Schema, Connection.DataTypeConversion); + rowSet = resp.DirectResults.ResultSet.Results; } else { + // Poll and fetch results await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); - Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); + + // Get metadata + metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken); + schema = Connection.SchemaParser.GetArrowSchema(metadata.Schema, Connection.DataTypeConversion); - // Fetch the results manually to enhance them - TRowSet rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken); - int columnCount = HiveServer2Reader.GetColumnCount(rowSet); - int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); - - // Get metadata again to ensure we have the latest - TGetResultSetMetadataResp metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken); + // Fetch the results + rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken); + } - // Get the arrays from the row set - IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); + // Common processing for both paths + int columnCount = HiveServer2Reader.GetColumnCount(rowSet); + int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount); + IReadOnlyList data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion); - return EnhanceGetColumnsResult(schema, data, rowCount, metadata, rowSet); - } + // Return the enhanced result with added BASE_TYPE_NAME column + return EnhanceGetColumnsResult(schema, data, rowCount, metadata, rowSet); } private async Task GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) From a3b14338d80f264bf5fc67086a2bd469382d3d3c Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Mon, 21 Apr 2025 11:21:11 -0700 Subject: [PATCH 7/7] lint --- csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index f025021d7e..6079253336 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -424,7 +424,7 @@ private async Task GetColumnsAsync(CancellationToken cancellationTo { // Poll and fetch results await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); - + // Get metadata metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken); schema = Connection.SchemaParser.GetArrowSchema(metadata.Schema, Connection.DataTypeConversion);