Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,11 @@ public Integer getHttpConnectionRequestTimeout() {
return null;
}

@Override
public boolean enableShowCommandsForGetFunctions() {
return getParameter(DatabricksJdbcUrlParams.ENABLE_SHOW_COMMAND_FOR_GET_FUNCTIONS).equals("1");
}

private static boolean nullOrEmptyString(String s) {
return s == null || s.isEmpty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.common.*;
import com.databricks.jdbc.common.util.DriverUtil;
import com.databricks.jdbc.common.util.WildcardUtil;
import com.databricks.jdbc.dbclient.impl.common.MetadataResultSetBuilder;
import com.databricks.jdbc.dbclient.impl.common.StatementId;
import com.databricks.jdbc.exception.DatabricksSQLException;
Expand Down Expand Up @@ -1527,18 +1528,21 @@ public ResultSet getClientInfoProperties() throws SQLException {
public ResultSet getFunctions(String catalog, String schemaPattern, String functionNamePattern)
throws SQLException {
LOGGER.debug(
String.format(
"public ResultSet getFunctions(String catalog = {}, String schemaPattern = {}, String functionNamePattern = {})",
catalog,
schemaPattern,
functionNamePattern));
"public ResultSet getFunctions(String catalog = {}, String schemaPattern = {}, String functionNamePattern = {})",
catalog,
schemaPattern,
functionNamePattern);
throwExceptionIfConnectionIsClosed();
try {
if (WildcardUtil.isNullOrEmpty(functionNamePattern)) {
functionNamePattern =
"%"; // This is because functionName is a required parameter in thrift flow.
}
return session
.getDatabricksMetadataClient()
.listFunctions(session, catalog, schemaPattern, functionNamePattern);
} catch (Exception e) {
LOGGER.error(e, "Unable to fetch functions, returning empty result set");
LOGGER.error(e, "Unable to fetch functions, returning empty result set {}", e);
return metadataResultSetBuilder.getFunctionsResult(catalog, List.of());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ public class InlineChunkProvider implements ChunkProvider {
this.currentChunkIndex = -1;
this.totalRows = 0;
ByteArrayInputStream byteStream = initializeByteStream(resultsResp, session, parentStatement);
arrowResultChunk =
ArrowResultChunk.builder()
.withInputStream(byteStream, totalRows)
.withStatementId(parentStatement.getStatementId())
.build();
ArrowResultChunk.Builder builder =
ArrowResultChunk.builder().withInputStream(byteStream, totalRows);

if (parentStatement != null) {
builder.withStatementId(parentStatement.getStatementId());
}
arrowResultChunk = builder.build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ public interface IDatabricksConnectionContext {
/** Returns the HTTP connection request timeout in seconds */
Integer getHttpConnectionRequestTimeout();

boolean enableShowCommandsForGetFunctions();

/** Returns whether batched INSERT optimization is enabled */
boolean isBatchedInsertsEnabled();
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ public enum DatabricksJdbcUrlParams {
"HttpConnectionRequestTimeout", "HTTP connection request timeout in seconds"),
CLOUD_FETCH_SPEED_THRESHOLD(
"CloudFetchSpeedThreshold", "Minimum expected download speed in MB/s", "0.1"),
ENABLE_SHOW_COMMAND_FOR_GET_FUNCTIONS(
"EnableShowCommandForGetFunctions", "Use SQL command to fetch function list", "0"),
ENABLE_BATCHED_INSERTS("EnableBatchedInserts", "Enable batched INSERT optimization", "0"),
ENABLE_SQL_VALIDATION_FOR_IS_VALID(
"EnableSQLValidationForIsValid",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.common.util.WildcardUtil;
import com.databricks.jdbc.exception.DatabricksSQLFeatureNotSupportedException;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.exception.DatabricksValidationException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import java.sql.SQLException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -74,7 +74,7 @@ private String fetchCatalogSQL() {
return SHOW_CATALOGS_SQL;
}

private String fetchSchemaSQL() throws SQLException {
private String fetchSchemaSQL() {
LOGGER.debug(
"Building command for fetching schema. Catalog %s, SchemaPattern %s and session context %s",
catalogName, schemaPattern, sessionContext);
Expand All @@ -91,7 +91,7 @@ private String fetchSchemaSQL() throws SQLException {
return showSchemasSQL;
}

private String fetchTablesSQL() throws SQLException {
private String fetchTablesSQL() {
LOGGER.debug(
"Building command for fetching tables. Catalog %s, SchemaPattern %s, TablePattern %s and session context %s",
catalogName, schemaPattern, tablePattern, sessionContext);
Expand All @@ -111,7 +111,7 @@ private String fetchTablesSQL() throws SQLException {
return showTablesSQL;
}

private String fetchColumnsSQL() throws SQLException {
private String fetchColumnsSQL() throws DatabricksSQLException {
String contextString =
String.format(
"Building command for fetching columns. Catalog %s, SchemaPattern %s, TablePattern %s, ColumnPattern %s and session context : %s",
Expand All @@ -134,7 +134,7 @@ private String fetchColumnsSQL() throws SQLException {
return showColumnsSQL;
}

private String fetchFunctionsSQL() throws SQLException {
private String fetchFunctionsSQL() throws DatabricksSQLException {
String contextString =
String.format(
"Building command for fetching functions. Catalog %s, SchemaPattern %s, FunctionPattern %s. With session context %s",
Expand All @@ -156,7 +156,7 @@ private String fetchTableTypesSQL() {
return SHOW_TABLE_TYPES_SQL;
}

private String fetchPrimaryKeysSQL() throws SQLException {
private String fetchPrimaryKeysSQL() throws DatabricksSQLException {
String contextString =
String.format(
"Building command for fetching primary keys. Catalog %s, Schema %s, Table %s. With session context: %s",
Expand All @@ -170,7 +170,7 @@ private String fetchPrimaryKeysSQL() throws SQLException {
return String.format(SHOW_PRIMARY_KEYS_SQL, catalogName, schemaName, tableName);
}

private String fetchForeignKeysSQL() throws SQLException {
private String fetchForeignKeysSQL() throws DatabricksSQLException {
String contextString =
String.format(
"Building command for fetching foreign keys. Catalog %s, Schema %s, Table %s. With session context: %s",
Expand All @@ -184,7 +184,7 @@ private String fetchForeignKeysSQL() throws SQLException {
return String.format(SHOW_FOREIGN_KEYS_SQL, catalogName, schemaName, tableName);
}

public String getSQLString(CommandName command) throws SQLException {
public String getSQLString(CommandName command) throws DatabricksSQLException {
switch (command) {
case LIST_CATALOGS:
return fetchCatalogSQL();
Expand All @@ -203,7 +203,7 @@ public String getSQLString(CommandName command) throws SQLException {
case LIST_FOREIGN_KEYS:
return fetchForeignKeysSQL();
}
throw new DatabricksSQLFeatureNotSupportedException(
throw new DatabricksValidationException(
String.format("Invalid command issued %s. Context: %s", command, sessionContext));
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.databricks.jdbc.dbclient.impl.thrift;

import static com.databricks.jdbc.common.EnvironmentVariables.DEFAULT_STATEMENT_TIMEOUT_SECONDS;
import static com.databricks.jdbc.common.EnvironmentVariables.JDBC_THRIFT_VERSION;
import static com.databricks.jdbc.common.util.DatabricksThriftUtil.*;
import static com.databricks.jdbc.common.util.DatabricksTypeUtil.DECIMAL;
import static com.databricks.jdbc.common.util.DatabricksTypeUtil.getDecimalTypeString;
import static com.databricks.jdbc.dbclient.impl.sqlexec.CommandName.LIST_FUNCTIONS;
import static com.databricks.jdbc.dbclient.impl.sqlexec.ResultConstants.TYPE_INFO_RESULT;

import com.databricks.jdbc.api.impl.*;
Expand All @@ -19,6 +21,7 @@
import com.databricks.jdbc.dbclient.IDatabricksMetadataClient;
import com.databricks.jdbc.dbclient.impl.common.MetadataResultSetBuilder;
import com.databricks.jdbc.dbclient.impl.common.StatementId;
import com.databricks.jdbc.dbclient.impl.sqlexec.CommandBuilder;
import com.databricks.jdbc.exception.DatabricksHttpException;
import com.databricks.jdbc.exception.DatabricksParsingException;
import com.databricks.jdbc.exception.DatabricksSQLException;
Expand Down Expand Up @@ -144,7 +147,7 @@ public DatabricksResultSet executeStatement(
LOGGER.debug(
String.format(
"public DatabricksResultSet executeStatement(String sql = {%s}, Compute cluster = {%s}, Map<Integer, ImmutableSqlParameter> parameters = {%s}, StatementType statementType = {%s}, IDatabricksSession session)",
sql, computeResource.toString(), parameters.toString(), statementType));
sql, computeResource, parameters.toString(), statementType));

DatabricksThreadContextHolder.setStatementType(statementType);

Expand Down Expand Up @@ -200,11 +203,14 @@ private TExecuteStatementReq getRequest(
parameters.values().stream()
.map(this::mapToSparkParameterListItem)
.collect(Collectors.toList());

int timeout = DEFAULT_STATEMENT_TIMEOUT_SECONDS;
if (parentStatement != null && parentStatement.getStatement() != null) {
timeout = parentStatement.getStatement().getQueryTimeout();
}
TExecuteStatementReq request =
new TExecuteStatementReq()
.setStatement(sql)
.setQueryTimeout(parentStatement.getStatement().getQueryTimeout())
.setQueryTimeout(timeout)
.setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle())
.setCanReadArrowResult(this.connectionContext.shouldEnableArrow())
.setUseArrowNativeTypes(arrowNativeTypes);
Expand All @@ -228,7 +234,7 @@ private TExecuteStatementReq getRequest(
request.setUseArrowNativeTypes(arrowNativeTypes);
}

int maxRows = parentStatement.getMaxRows();
int maxRows = (parentStatement == null) ? 0 : parentStatement.getMaxRows();
if (maxRows > 0) { // set request param only if user has set maxRows.
// Similar
// behavior
Expand Down Expand Up @@ -436,13 +442,34 @@ public DatabricksResultSet listFunctions(
String catalog,
String schemaNamePattern,
String functionNamePattern)
throws DatabricksSQLException {
throws SQLException {
String context =
String.format(
"Fetching functions using Thrift client. Session {%s}, catalog {%s}, schemaNamePattern {%s}, functionNamePattern {%s}.",
session.toString(), catalog, schemaNamePattern, functionNamePattern);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
LOGGER.debug(context);
if (connectionContext.enableShowCommandsForGetFunctions()) {
String showFunctionsSqlCommand =
new CommandBuilder(catalog, session)
.setSchemaPattern(schemaNamePattern)
.setFunctionPattern(functionNamePattern)
.getSQLString(LIST_FUNCTIONS);
LOGGER.debug(
"Fetching functions using SQL Command {{}}. Session {{}}",
showFunctionsSqlCommand,
session.toString());
try (DatabricksResultSet rs =
executeStatement(
showFunctionsSqlCommand,
session.getComputeResource(),
Collections.emptyMap(),
StatementType.METADATA,
session,
null)) {
return metadataResultSetBuilder.getFunctionsResult(rs, catalog);
}
}
TGetFunctionsReq request =
new TGetFunctionsReq()
.setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.common.util.WildcardUtil;
import com.databricks.jdbc.exception.DatabricksSQLFeatureNotSupportedException;
import com.databricks.jdbc.exception.DatabricksValidationException;
import java.sql.SQLException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
Expand Down Expand Up @@ -200,7 +200,6 @@ void shouldThrowExceptionForUnsupportedCommand() {

CommandName mockCommand = mock(CommandName.class);

assertThrows(
DatabricksSQLFeatureNotSupportedException.class, () -> builder.getSQLString(mockCommand));
assertThrows(DatabricksValidationException.class, () -> builder.getSQLString(mockCommand));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ void testListCrossReferences_notAvailable() throws Exception {

@ParameterizedTest
@MethodSource("listFunctionsTestParams")
void testTestFunctions(
void testGetFunctions(
String sql, String catalog, String schema, String functionPattern, String description)
throws SQLException {
when(session.getComputeResource()).thenReturn(WAREHOUSE_COMPUTE);
Expand All @@ -768,7 +768,6 @@ void testTestFunctions(
when(mockedResultSet.getMetaData()).thenReturn(mockedMetaData);
DatabricksResultSet actualResult =
metadataClient.listFunctions(session, catalog, schema, functionPattern);

assertEquals(
actualResult.getStatementStatus().getState(), StatementState.SUCCEEDED, description);
assertEquals(actualResult.getStatementId(), GET_FUNCTIONS_STATEMENT_ID, description);
Expand Down
Loading
Loading