diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java index d9cad02e69..210d62f73b 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java @@ -969,6 +969,11 @@ public Integer getHttpConnectionRequestTimeout() { return null; } + @Override + public boolean enableShowCommandsForGetFunctions() { + return getParameter(DatabricksJdbcUrlParams.ENABLE_SHOW_COMMAND_FOR_GET_FUNCTIONS).equals("1"); + } + private static boolean nullOrEmptyString(String s) { return s == null || s.isEmpty(); } diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksDatabaseMetaData.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksDatabaseMetaData.java index f9d8d14bf2..a8e6b67458 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksDatabaseMetaData.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksDatabaseMetaData.java @@ -9,6 +9,7 @@ import com.databricks.jdbc.api.internal.IDatabricksSession; import com.databricks.jdbc.common.*; import com.databricks.jdbc.common.util.DriverUtil; +import com.databricks.jdbc.common.util.WildcardUtil; import com.databricks.jdbc.dbclient.impl.common.MetadataResultSetBuilder; import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.exception.DatabricksSQLException; @@ -1527,18 +1528,21 @@ public ResultSet getClientInfoProperties() throws SQLException { public ResultSet getFunctions(String catalog, String schemaPattern, String functionNamePattern) throws SQLException { LOGGER.debug( - String.format( - "public ResultSet getFunctions(String catalog = {}, String schemaPattern = {}, String functionNamePattern = {})", - catalog, - schemaPattern, - functionNamePattern)); + "public ResultSet getFunctions(String catalog = {}, String schemaPattern = {}, String functionNamePattern = {})", + catalog, + schemaPattern, + functionNamePattern); throwExceptionIfConnectionIsClosed(); try { + if (WildcardUtil.isNullOrEmpty(functionNamePattern)) { + functionNamePattern = + "%"; // This is because functionName is a required parameter in thrift flow. + } return session .getDatabricksMetadataClient() .listFunctions(session, catalog, schemaPattern, functionNamePattern); } catch (Exception e) { - LOGGER.error(e, "Unable to fetch functions, returning empty result set"); + LOGGER.error(e, "Unable to fetch functions, returning empty result set {}", e); return metadataResultSetBuilder.getFunctionsResult(catalog, List.of()); } } diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java index 39150c6da7..e22d974a40 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java @@ -45,11 +45,13 @@ public class InlineChunkProvider implements ChunkProvider { this.currentChunkIndex = -1; this.totalRows = 0; ByteArrayInputStream byteStream = initializeByteStream(resultsResp, session, parentStatement); - arrowResultChunk = - ArrowResultChunk.builder() - .withInputStream(byteStream, totalRows) - .withStatementId(parentStatement.getStatementId()) - .build(); + ArrowResultChunk.Builder builder = + ArrowResultChunk.builder().withInputStream(byteStream, totalRows); + + if (parentStatement != null) { + builder.withStatementId(parentStatement.getStatementId()); + } + arrowResultChunk = builder.build(); } /** diff --git a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java index efcca61d7d..e8fb6c4811 100644 --- a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java @@ -361,6 +361,8 @@ public interface IDatabricksConnectionContext { /** Returns the HTTP connection request timeout in seconds */ Integer getHttpConnectionRequestTimeout(); + boolean enableShowCommandsForGetFunctions(); + /** Returns whether batched INSERT optimization is enabled */ boolean isBatchedInsertsEnabled(); } diff --git a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java index 3150c8e4b8..764c2c9446 100644 --- a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java +++ b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java @@ -154,6 +154,8 @@ public enum DatabricksJdbcUrlParams { "HttpConnectionRequestTimeout", "HTTP connection request timeout in seconds"), CLOUD_FETCH_SPEED_THRESHOLD( "CloudFetchSpeedThreshold", "Minimum expected download speed in MB/s", "0.1"), + ENABLE_SHOW_COMMAND_FOR_GET_FUNCTIONS( + "EnableShowCommandForGetFunctions", "Use SQL command to fetch function list", "0"), ENABLE_BATCHED_INSERTS("EnableBatchedInserts", "Enable batched INSERT optimization", "0"), ENABLE_SQL_VALIDATION_FOR_IS_VALID( "EnableSQLValidationForIsValid", diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/CommandBuilder.java b/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/CommandBuilder.java index a1ffba7db8..8645f1726d 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/CommandBuilder.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/CommandBuilder.java @@ -6,10 +6,10 @@ import com.databricks.jdbc.api.internal.IDatabricksSession; import com.databricks.jdbc.common.util.WildcardUtil; -import com.databricks.jdbc.exception.DatabricksSQLFeatureNotSupportedException; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.exception.DatabricksValidationException; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; -import java.sql.SQLException; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -74,7 +74,7 @@ private String fetchCatalogSQL() { return SHOW_CATALOGS_SQL; } - private String fetchSchemaSQL() throws SQLException { + private String fetchSchemaSQL() { LOGGER.debug( "Building command for fetching schema. Catalog %s, SchemaPattern %s and session context %s", catalogName, schemaPattern, sessionContext); @@ -91,7 +91,7 @@ private String fetchSchemaSQL() throws SQLException { return showSchemasSQL; } - private String fetchTablesSQL() throws SQLException { + private String fetchTablesSQL() { LOGGER.debug( "Building command for fetching tables. Catalog %s, SchemaPattern %s, TablePattern %s and session context %s", catalogName, schemaPattern, tablePattern, sessionContext); @@ -111,7 +111,7 @@ private String fetchTablesSQL() throws SQLException { return showTablesSQL; } - private String fetchColumnsSQL() throws SQLException { + private String fetchColumnsSQL() throws DatabricksSQLException { String contextString = String.format( "Building command for fetching columns. Catalog %s, SchemaPattern %s, TablePattern %s, ColumnPattern %s and session context : %s", @@ -134,7 +134,7 @@ private String fetchColumnsSQL() throws SQLException { return showColumnsSQL; } - private String fetchFunctionsSQL() throws SQLException { + private String fetchFunctionsSQL() throws DatabricksSQLException { String contextString = String.format( "Building command for fetching functions. Catalog %s, SchemaPattern %s, FunctionPattern %s. With session context %s", @@ -156,7 +156,7 @@ private String fetchTableTypesSQL() { return SHOW_TABLE_TYPES_SQL; } - private String fetchPrimaryKeysSQL() throws SQLException { + private String fetchPrimaryKeysSQL() throws DatabricksSQLException { String contextString = String.format( "Building command for fetching primary keys. Catalog %s, Schema %s, Table %s. With session context: %s", @@ -170,7 +170,7 @@ private String fetchPrimaryKeysSQL() throws SQLException { return String.format(SHOW_PRIMARY_KEYS_SQL, catalogName, schemaName, tableName); } - private String fetchForeignKeysSQL() throws SQLException { + private String fetchForeignKeysSQL() throws DatabricksSQLException { String contextString = String.format( "Building command for fetching foreign keys. Catalog %s, Schema %s, Table %s. With session context: %s", @@ -184,7 +184,7 @@ private String fetchForeignKeysSQL() throws SQLException { return String.format(SHOW_FOREIGN_KEYS_SQL, catalogName, schemaName, tableName); } - public String getSQLString(CommandName command) throws SQLException { + public String getSQLString(CommandName command) throws DatabricksSQLException { switch (command) { case LIST_CATALOGS: return fetchCatalogSQL(); @@ -203,7 +203,7 @@ public String getSQLString(CommandName command) throws SQLException { case LIST_FOREIGN_KEYS: return fetchForeignKeysSQL(); } - throw new DatabricksSQLFeatureNotSupportedException( + throw new DatabricksValidationException( String.format("Invalid command issued %s. Context: %s", command, sessionContext)); } } diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java index 55635d3841..9f2df55888 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java @@ -1,9 +1,11 @@ package com.databricks.jdbc.dbclient.impl.thrift; +import static com.databricks.jdbc.common.EnvironmentVariables.DEFAULT_STATEMENT_TIMEOUT_SECONDS; import static com.databricks.jdbc.common.EnvironmentVariables.JDBC_THRIFT_VERSION; import static com.databricks.jdbc.common.util.DatabricksThriftUtil.*; import static com.databricks.jdbc.common.util.DatabricksTypeUtil.DECIMAL; import static com.databricks.jdbc.common.util.DatabricksTypeUtil.getDecimalTypeString; +import static com.databricks.jdbc.dbclient.impl.sqlexec.CommandName.LIST_FUNCTIONS; import static com.databricks.jdbc.dbclient.impl.sqlexec.ResultConstants.TYPE_INFO_RESULT; import com.databricks.jdbc.api.impl.*; @@ -19,6 +21,7 @@ import com.databricks.jdbc.dbclient.IDatabricksMetadataClient; import com.databricks.jdbc.dbclient.impl.common.MetadataResultSetBuilder; import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.dbclient.impl.sqlexec.CommandBuilder; import com.databricks.jdbc.exception.DatabricksHttpException; import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; @@ -144,7 +147,7 @@ public DatabricksResultSet executeStatement( LOGGER.debug( String.format( "public DatabricksResultSet executeStatement(String sql = {%s}, Compute cluster = {%s}, Map parameters = {%s}, StatementType statementType = {%s}, IDatabricksSession session)", - sql, computeResource.toString(), parameters.toString(), statementType)); + sql, computeResource, parameters.toString(), statementType)); DatabricksThreadContextHolder.setStatementType(statementType); @@ -200,11 +203,14 @@ private TExecuteStatementReq getRequest( parameters.values().stream() .map(this::mapToSparkParameterListItem) .collect(Collectors.toList()); - + int timeout = DEFAULT_STATEMENT_TIMEOUT_SECONDS; + if (parentStatement != null && parentStatement.getStatement() != null) { + timeout = parentStatement.getStatement().getQueryTimeout(); + } TExecuteStatementReq request = new TExecuteStatementReq() .setStatement(sql) - .setQueryTimeout(parentStatement.getStatement().getQueryTimeout()) + .setQueryTimeout(timeout) .setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle()) .setCanReadArrowResult(this.connectionContext.shouldEnableArrow()) .setUseArrowNativeTypes(arrowNativeTypes); @@ -228,7 +234,7 @@ private TExecuteStatementReq getRequest( request.setUseArrowNativeTypes(arrowNativeTypes); } - int maxRows = parentStatement.getMaxRows(); + int maxRows = (parentStatement == null) ? 0 : parentStatement.getMaxRows(); if (maxRows > 0) { // set request param only if user has set maxRows. // Similar // behavior @@ -436,13 +442,34 @@ public DatabricksResultSet listFunctions( String catalog, String schemaNamePattern, String functionNamePattern) - throws DatabricksSQLException { + throws SQLException { String context = String.format( "Fetching functions using Thrift client. Session {%s}, catalog {%s}, schemaNamePattern {%s}, functionNamePattern {%s}.", session.toString(), catalog, schemaNamePattern, functionNamePattern); DatabricksThreadContextHolder.setSessionId(session.getSessionId()); LOGGER.debug(context); + if (connectionContext.enableShowCommandsForGetFunctions()) { + String showFunctionsSqlCommand = + new CommandBuilder(catalog, session) + .setSchemaPattern(schemaNamePattern) + .setFunctionPattern(functionNamePattern) + .getSQLString(LIST_FUNCTIONS); + LOGGER.debug( + "Fetching functions using SQL Command {{}}. Session {{}}", + showFunctionsSqlCommand, + session.toString()); + try (DatabricksResultSet rs = + executeStatement( + showFunctionsSqlCommand, + session.getComputeResource(), + Collections.emptyMap(), + StatementType.METADATA, + session, + null)) { + return metadataResultSetBuilder.getFunctionsResult(rs, catalog); + } + } TGetFunctionsReq request = new TGetFunctionsReq() .setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle()) diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/CommandBuilderTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/CommandBuilderTest.java index 4130df7543..0de1fe43ab 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/CommandBuilderTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/CommandBuilderTest.java @@ -6,7 +6,7 @@ import com.databricks.jdbc.api.internal.IDatabricksSession; import com.databricks.jdbc.common.util.WildcardUtil; -import com.databricks.jdbc.exception.DatabricksSQLFeatureNotSupportedException; +import com.databricks.jdbc.exception.DatabricksValidationException; import java.sql.SQLException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; @@ -200,7 +200,6 @@ void shouldThrowExceptionForUnsupportedCommand() { CommandName mockCommand = mock(CommandName.class); - assertThrows( - DatabricksSQLFeatureNotSupportedException.class, () -> builder.getSQLString(mockCommand)); + assertThrows(DatabricksValidationException.class, () -> builder.getSQLString(mockCommand)); } } diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksMetadataSdkClientTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksMetadataSdkClientTest.java index 647f71f93a..f71dab3bdd 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksMetadataSdkClientTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksMetadataSdkClientTest.java @@ -742,7 +742,7 @@ void testListCrossReferences_notAvailable() throws Exception { @ParameterizedTest @MethodSource("listFunctionsTestParams") - void testTestFunctions( + void testGetFunctions( String sql, String catalog, String schema, String functionPattern, String description) throws SQLException { when(session.getComputeResource()).thenReturn(WAREHOUSE_COMPUTE); @@ -768,7 +768,6 @@ void testTestFunctions( when(mockedResultSet.getMetaData()).thenReturn(mockedMetaData); DatabricksResultSet actualResult = metadataClient.listFunctions(session, catalog, schema, functionPattern); - assertEquals( actualResult.getStatementStatus().getState(), StatementState.SUCCEEDED, description); assertEquals(actualResult.getStatementId(), GET_FUNCTIONS_STATEMENT_ID, description); diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java index 54c7921be7..dfa5f545a7 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java @@ -27,6 +27,8 @@ import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.service.sql.StatementState; import java.math.BigDecimal; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.util.*; import java.util.stream.Stream; @@ -57,6 +59,7 @@ public class DatabricksThriftServiceClientTest { @Mock IDatabricksStatementInternal parentStatement; @Mock DatabricksStatement statement; @Mock DatabricksConfig databricksConfig; + @Mock ResultSetMetaData mockedMetaData; @Test void testCreateSession() throws DatabricksSQLException { @@ -580,6 +583,112 @@ void testListFunctions() throws SQLException { assertEquals(resultSet.getStatementStatus().getState(), StatementState.SUCCEEDED); } + @Test + void testListFunctionsWithSQLEnabled() throws SQLException { + DatabricksThriftServiceClient client = + new DatabricksThriftServiceClient(thriftAccessor, connectionContext); + when(connectionContext.enableShowCommandsForGetFunctions()).thenReturn(true); + when(connectionContext.shouldEnableArrow()).thenReturn(true); + when(session.getSessionInfo()).thenReturn(SESSION_INFO); + TSparkArrowTypes arrowNativeTypes = + new TSparkArrowTypes() + .setComplexTypesAsArrow(true) + .setIntervalTypesAsArrow(true) + .setNullTypeAsArrow(true) + .setDecimalAsArrow(true) + .setTimestampAsArrow(true); + TExecuteStatementReq executeStatementReq = + new TExecuteStatementReq() + .setStatement("SHOW FUNCTIONS IN CATALOG catalog1 SCHEMA LIKE 'testSchema' LIKE 'test'") + .setSessionHandle(SESSION_HANDLE) + .setCanReadArrowResult(true) + .setCanDecompressLZ4Result(true) + .setCanDownloadResult(true) + .setQueryTimeout(0) + .setParameters(Collections.emptyList()) + .setRunAsync(true) + .setUseArrowNativeTypes(arrowNativeTypes); + when(thriftAccessor.execute(executeStatementReq, null, session, StatementType.METADATA)) + .thenReturn(resultSet); + when(resultSet.getMetaData()).thenReturn(mockedMetaData); + when(mockedMetaData.getColumnCount()).thenReturn(6); + when(mockedMetaData.getColumnName(1)).thenReturn("functionName"); + when(mockedMetaData.getColumnName(2)).thenReturn("namespace"); + when(mockedMetaData.getColumnName(3)).thenReturn("catalogName"); + when(mockedMetaData.getColumnName(4)).thenReturn("remarks"); + when(mockedMetaData.getColumnName(5)).thenReturn("functionType"); + when(mockedMetaData.getColumnName(6)).thenReturn("specificName"); + when(resultSet.next()).thenReturn(true, false); + when(resultSet.getObject("functionName")).thenReturn("my_fn"); + when(resultSet.getObject("namespace")).thenReturn(TEST_SCHEMA); + when(resultSet.getObject("remarks")).thenReturn("remark"); + when(resultSet.getObject("functionType")).thenReturn(1); + when(resultSet.getObject("specificName")).thenReturn("my_fn"); + + ResultSet actualResultSet = + client.listFunctions(session, TEST_CATALOG, TEST_SCHEMA, TEST_STRING); + assertNotNull(actualResultSet); + assertTrue(actualResultSet.next()); + assertEquals(TEST_CATALOG, actualResultSet.getString("FUNCTION_CAT")); + assertEquals("my_fn", actualResultSet.getString("FUNCTION_NAME")); + } + + @Test + void testGetRequest_DefaultTimeoutAndNoRowLimit_WhenParentStatementNull() throws SQLException { + when(connectionContext.shouldEnableArrow()).thenReturn(true); + DatabricksThriftServiceClient client = + new DatabricksThriftServiceClient(thriftAccessor, connectionContext); + when(session.getSessionInfo()).thenReturn(SESSION_INFO); + + when(thriftAccessor.execute( + any(TExecuteStatementReq.class), eq(null), eq(session), eq(StatementType.SQL))) + .thenReturn(resultSet); + + client.executeStatement( + TEST_STRING, CLUSTER_COMPUTE, Collections.emptyMap(), StatementType.SQL, session, null); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(TExecuteStatementReq.class); + verify(thriftAccessor) + .execute(requestCaptor.capture(), eq(null), eq(session), eq(StatementType.SQL)); + TExecuteStatementReq request = requestCaptor.getValue(); + assertEquals(0, request.getQueryTimeout()); + assertFalse(request.isSetResultRowLimit()); + } + + @Test + void testGetRequest_DefaultTimeout_WhenStatementNull() throws SQLException { + when(connectionContext.shouldEnableArrow()).thenReturn(true); + DatabricksThriftServiceClient client = + new DatabricksThriftServiceClient(thriftAccessor, connectionContext); + when(session.getSessionInfo()).thenReturn(SESSION_INFO); + when(parentStatement.getStatement()).thenReturn(null); + when(parentStatement.getMaxRows()).thenReturn(0); + + when(thriftAccessor.execute( + any(TExecuteStatementReq.class), + eq(parentStatement), + eq(session), + eq(StatementType.SQL))) + .thenReturn(resultSet); + + client.executeStatement( + TEST_STRING, + CLUSTER_COMPUTE, + Collections.emptyMap(), + StatementType.SQL, + session, + parentStatement); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(TExecuteStatementReq.class); + verify(thriftAccessor) + .execute(requestCaptor.capture(), eq(parentStatement), eq(session), eq(StatementType.SQL)); + TExecuteStatementReq request = requestCaptor.getValue(); + assertEquals(0, request.getQueryTimeout()); + assertFalse(request.isSetResultRowLimit()); + } + @Test void testListPrimaryKeys() throws SQLException { DatabricksThriftServiceClient client =