Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
import com.databricks.jdbc.common.util.DatabricksThreadContextHolder;
import com.databricks.jdbc.dbclient.IDatabricksHttpClient;
import com.databricks.jdbc.dbclient.impl.common.StatementId;
import com.databricks.jdbc.exception.DatabricksParsingException;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.log.JdbcLogger;
Expand All @@ -26,7 +25,7 @@ class ChunkDownloadTask implements DatabricksCallableTask {
private final IDatabricksHttpClient httpClient;
private final ChunkDownloadCallback chunkDownloader;
private final IDatabricksConnectionContext connectionContext;
private final StatementId statementId;
private final String statementId;
private final ChunkLinkDownloadService linkDownloadService;
Throwable uncaughtException = null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import com.databricks.jdbc.common.StatementType;
import com.databricks.jdbc.dbclient.impl.common.StatementId;

/* TODO : eliminate the use of thread local completely. Currently, we are limiting the usage of this for non-critical flows such as telemetry.*/
public class DatabricksThreadContextHolder {
private static final ThreadLocal<IDatabricksConnectionContext> localConnectionContext =
new ThreadLocal<>();
private static final ThreadLocal<StatementId> localStatementId = new ThreadLocal<>();
private static final ThreadLocal<String> localStatementId = new ThreadLocal<>();
private static final ThreadLocal<Long> localChunkId = new ThreadLocal<>();
private static final ThreadLocal<Integer> localRetryCount = new ThreadLocal<>();
private static final ThreadLocal<StatementType> localStatementType = new ThreadLocal<>();
private static final ThreadLocal<String> localSessionId = new ThreadLocal<>();

public static void setConnectionContext(IDatabricksConnectionContext context) {
localConnectionContext.set(context);
Expand All @@ -21,13 +23,28 @@ public static IDatabricksConnectionContext getConnectionContext() {
}

public static void setStatementId(StatementId statementId) {
if (statementId != null) {
localStatementId.set(
statementId.toSQLExecStatementId()); // This is because only GUID is relevant for tracking
}
}

public static void setStatementId(String statementId) {
localStatementId.set(statementId);
}

public static StatementId getStatementId() {
public static String getStatementId() {
return localStatementId.get();
}

public static void setSessionId(String sessionId) {
localSessionId.set(sessionId);
}

public static String getSessionId() {
return localSessionId.get();
}

public static void setStatementType(StatementType statementType) {
localStatementType.set(statementType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ public ImmutableSessionInfo createSession(
LOGGER.error(errorMessage, e);
throw new DatabricksSQLException(errorMessage, e, DatabricksDriverErrorCode.SDK_CLIENT_ERROR);
}
DatabricksThreadContextHolder.setSessionId(createSessionResponse.getSessionId());
return ImmutableSessionInfo.builder()
.computeResource(warehouse)
.sessionId(createSessionResponse.getSessionId())
Expand All @@ -137,6 +138,7 @@ public ImmutableSessionInfo createSession(
@Override
public void deleteSession(ImmutableSessionInfo sessionInfo) throws DatabricksSQLException {
LOGGER.debug("public void deleteSession(String sessionId = {})", sessionInfo.sessionId());
DatabricksThreadContextHolder.setSessionId(sessionInfo.sessionId());
DeleteSessionRequest request =
new DeleteSessionRequest()
.setSessionId(sessionInfo.sessionId())
Expand Down Expand Up @@ -164,11 +166,14 @@ public DatabricksResultSet executeStatement(
IDatabricksStatementInternal parentStatement)
throws SQLException {
LOGGER.debug(
"public DatabricksResultSet executeStatement(String sql = {}, compute resource = {}, Map<Integer, ImmutableSqlParameter> parameters = {}, StatementType statementType = {}, IDatabricksSession session)",
"public DatabricksResultSet executeStatement(String sql = {}, compute resource = {}, Map<Integer, ImmutableSqlParameter> parameters = {}, StatementType statementType = {}, IDatabricksSession session = {}, parentStatement = {})",
sql,
computeResource.toString(),
parameters,
statementType);
statementType,
session,
parentStatement);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
long pollCount = 0;
long executionStartTime = Instant.now().toEpochMilli();
DatabricksThreadContextHolder.setStatementType(statementType);
Expand Down Expand Up @@ -207,6 +212,7 @@ public DatabricksResultSet executeStatement(
computeResource,
statementId);
StatementId typedStatementId = new StatementId(statementId);
DatabricksThreadContextHolder.setStatementId(typedStatementId);
if (parentStatement != null) {
parentStatement.setStatementId(typedStatementId);
}
Expand Down Expand Up @@ -279,9 +285,12 @@ public DatabricksResultSet executeStatementAsync(
IDatabricksStatementInternal parentStatement)
throws SQLException {
LOGGER.debug(
"public DatabricksResultSet executeStatementAsync(String sql = {}, compute resource = {}, Map<Integer, ImmutableSqlParameter> parameters, IDatabricksSession session)",
"public DatabricksResultSet executeStatementAsync(String sql = {}, compute resource = {}, Map<Integer, ImmutableSqlParameter> parameters, IDatabricksSession session = {}, IDatabricksStatementInternal parentStatement = {})",
sql,
computeResource.toString());
computeResource.toString(),
session,
parentStatement);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
ExecuteStatementRequest request =
getRequest(
StatementType.SQL,
Expand All @@ -307,6 +316,7 @@ public DatabricksResultSet executeStatementAsync(
handleFailedExecution(response, "", sql);
}
StatementId typedStatementId = new StatementId(statementId);
DatabricksThreadContextHolder.setStatementId(typedStatementId);
if (parentStatement != null) {
parentStatement.setStatementId(typedStatementId);
}
Expand All @@ -328,6 +338,8 @@ public DatabricksResultSet getStatementResult(
IDatabricksSession session,
IDatabricksStatementInternal parentStatement)
throws DatabricksSQLException {
DatabricksThreadContextHolder.setStatementId(typedStatementId);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
String statementId = typedStatementId.toSQLExecStatementId();
GetStatementRequest request = new GetStatementRequest().setStatementId(statementId);
String getStatusPath = String.format(STATEMENT_PATH_WITH_ID, statementId);
Expand All @@ -354,6 +366,7 @@ public DatabricksResultSet getStatementResult(
@Override
public void closeStatement(StatementId typedStatementId) throws DatabricksSQLException {
String statementId = typedStatementId.toSQLExecStatementId();
DatabricksThreadContextHolder.setStatementId(typedStatementId);
LOGGER.debug(String.format("public void closeStatement(String statementId = {})", statementId));
CloseStatementRequest request = new CloseStatementRequest().setStatementId(statementId);
String path = String.format(STATEMENT_PATH_WITH_ID, request.getStatementId());
Expand All @@ -371,6 +384,7 @@ public void closeStatement(StatementId typedStatementId) throws DatabricksSQLExc
@Override
public void cancelStatement(StatementId typedStatementId) throws DatabricksSQLException {
String statementId = typedStatementId.toSQLExecStatementId();
DatabricksThreadContextHolder.setStatementId(typedStatementId);
LOGGER.debug("public void cancelStatement(String statementId = {})", statementId);
CancelStatementRequest request = new CancelStatementRequest().setStatementId(statementId);
String path = String.format(CANCEL_STATEMENT_PATH_WITH_ID, request.getStatementId());
Expand All @@ -388,6 +402,7 @@ public void cancelStatement(StatementId typedStatementId) throws DatabricksSQLEx
@Override
public Collection<ExternalLink> getResultChunks(StatementId typedStatementId, long chunkIndex)
throws DatabricksSQLException {
DatabricksThreadContextHolder.setStatementId(typedStatementId);
String statementId = typedStatementId.toSQLExecStatementId();
LOGGER.debug(
"public Optional<ExternalLink> getResultChunk(String statementId = {}, long chunkIndex = {})",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.databricks.jdbc.api.internal.IDatabricksStatementInternal;
import com.databricks.jdbc.common.DatabricksClientConfiguratorManager;
import com.databricks.jdbc.common.StatementType;
import com.databricks.jdbc.common.util.DatabricksThreadContextHolder;
import com.databricks.jdbc.common.util.DriverUtil;
import com.databricks.jdbc.common.util.ProtocolFeatureUtil;
import com.databricks.jdbc.dbclient.impl.common.StatementId;
Expand Down Expand Up @@ -221,6 +222,7 @@ DatabricksResultSet execute(
checkResponseForErrors(response);

StatementId statementId = new StatementId(response.getOperationHandle().operationId);
DatabricksThreadContextHolder.setStatementId(statementId);
if (parentStatement != null) {
parentStatement.setStatementId(statementId);
}
Expand Down Expand Up @@ -322,6 +324,7 @@ DatabricksResultSet executeAsync(
}
}
StatementId statementId = new StatementId(response.getOperationHandle().operationId);
DatabricksThreadContextHolder.setStatementId(statementId);
if (parentStatement != null) {
parentStatement.setStatementId(statementId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ public ImmutableSessionInfo createSession(
}

String sessionId = byteBufferToString(response.sessionHandle.getSessionId().guid);
DatabricksThreadContextHolder.setSessionId(sessionId);
LOGGER.debug("Session created with ID {}", sessionId);
return ImmutableSessionInfo.builder()
.sessionId(sessionId)
Expand All @@ -123,6 +124,7 @@ public void deleteSession(ImmutableSessionInfo sessionInfo) throws DatabricksSQL
LOGGER.debug(
String.format(
"public void deleteSession(Session session = {%s}))", sessionInfo.toString()));
DatabricksThreadContextHolder.setSessionId(sessionInfo.sessionId());
TCloseSessionReq closeSessionReq =
new TCloseSessionReq().setSessionHandle(sessionInfo.sessionHandle());
TCloseSessionResp response =
Expand Down Expand Up @@ -191,7 +193,7 @@ private TExecuteStatementReq getRequest(
IDatabricksStatementInternal parentStatement,
boolean runAsync)
throws SQLException {

DatabricksThreadContextHolder.setSessionId(session.getSessionId());
TSparkArrowTypes arrowNativeTypes = new TSparkArrowTypes().setTimestampAsArrow(true);

// Convert the parameters to a list of TSparkParameter objects.
Expand Down Expand Up @@ -247,6 +249,7 @@ public void closeStatement(StatementId statementId) throws DatabricksSQLExceptio
String.format(
"public void closeStatement(String statementId = {%s}) using Thrift client",
statementId));
DatabricksThreadContextHolder.setStatementId(statementId);
TCloseOperationReq request =
new TCloseOperationReq().setOperationHandle(getOperationHandle(statementId));
TCloseOperationResp resp = thriftAccessor.closeOperation(request);
Expand All @@ -259,6 +262,7 @@ public void cancelStatement(StatementId statementId) throws DatabricksSQLExcepti
String.format(
"public void cancelStatement(String statementId = {%s}) using Thrift client",
statementId));
DatabricksThreadContextHolder.setStatementId(statementId);
TCancelOperationReq request =
new TCancelOperationReq().setOperationHandle(getOperationHandle(statementId));
TCancelOperationResp resp = thriftAccessor.cancelOperation(request);
Expand All @@ -275,6 +279,8 @@ public DatabricksResultSet getStatementResult(
String.format(
"public DatabricksResultSet getStatementResult(String statementId = {%s}) using Thrift client",
statementId));
DatabricksThreadContextHolder.setStatementId(statementId);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
return thriftAccessor.getStatementResult(
getOperationHandle(statementId), parentStatement, session);
}
Expand All @@ -287,6 +293,7 @@ public Collection<ExternalLink> getResultChunks(StatementId statementId, long ch
"public Optional<ExternalLink> getResultChunk(String statementId = {%s}, long chunkIndex = {%s}) using Thrift client",
statementId, chunkIndex);
LOGGER.debug(context);
DatabricksThreadContextHolder.setStatementId(statementId);
TFetchResultsResp fetchResultsResp;
List<ExternalLink> externalLinks = new ArrayList<>();
AtomicInteger index = new AtomicInteger(0);
Expand Down Expand Up @@ -318,6 +325,7 @@ public DatabricksResultSet listCatalogs(IDatabricksSession session) throws SQLEx
String context =
String.format("Fetching catalogs using Thrift client. Session {%s}", session.toString());
LOGGER.debug(context);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
TGetCatalogsReq request =
new TGetCatalogsReq()
.setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle());
Expand All @@ -337,6 +345,7 @@ public DatabricksResultSet listSchemas(
"Fetching schemas using Thrift client. Session {%s}, catalog {%s}, schemaNamePattern {%s}",
session.toString(), catalog, schemaNamePattern);
LOGGER.debug(context);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
TGetSchemasReq request =
new TGetSchemasReq()
.setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle())
Expand Down Expand Up @@ -365,6 +374,7 @@ public DatabricksResultSet listTables(
"Fetching tables using Thrift client. Session {%s}, catalog {%s}, schemaNamePattern {%s}, tableNamePattern {%s}",
session.toString(), catalog, schemaNamePattern, tableNamePattern);
LOGGER.debug(context);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
TGetTablesReq request =
new TGetTablesReq()
.setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle())
Expand All @@ -387,6 +397,7 @@ public DatabricksResultSet listTableTypes(IDatabricksSession session) {
LOGGER.debug(
String.format(
"Fetching table types using Thrift client. Session {%s}", session.toString()));
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
return metadataResultSetBuilder.getTableTypesResult();
}

Expand All @@ -403,6 +414,7 @@ public DatabricksResultSet listColumns(
"Fetching columns using Thrift client. Session {%s}, catalog {%s}, schemaNamePattern {%s}, tableNamePattern {%s}, columnNamePattern {%s}",
session.toString(), catalog, schemaNamePattern, tableNamePattern, columnNamePattern);
LOGGER.debug(context);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
TGetColumnsReq request =
new TGetColumnsReq()
.setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle())
Expand All @@ -429,6 +441,7 @@ public DatabricksResultSet listFunctions(
String.format(
"Fetching functions using Thrift client. Session {%s}, catalog {%s}, schemaNamePattern {%s}, functionNamePattern {%s}.",
session.toString(), catalog, schemaNamePattern, functionNamePattern);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
LOGGER.debug(context);
TGetFunctionsReq request =
new TGetFunctionsReq()
Expand All @@ -452,6 +465,7 @@ public DatabricksResultSet listPrimaryKeys(
"Fetching primary keys using Thrift client. session {%s}, catalog {%s}, schema {%s}, table {%s}",
session.toString(), catalog, schema, table);
LOGGER.debug(context);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
TGetPrimaryKeysReq request =
new TGetPrimaryKeysReq()
.setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle())
Expand All @@ -474,6 +488,7 @@ public DatabricksResultSet listImportedKeys(
"Fetching imported keys using Thrift client for session {%s}, catalog {%s}, schema {%s}, table {%s}",
session.toString(), catalog, schema, table);
LOGGER.debug(context);
DatabricksThreadContextHolder.setSessionId(session.getSessionId());
// GetImportedKeys is implemented using GetCrossReferences
// When only foreign table name is provided, we get imported keys
TGetCrossReferenceReq request =
Expand Down
15 changes: 7 additions & 8 deletions src/main/java/com/databricks/jdbc/telemetry/TelemetryHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import com.databricks.jdbc.common.util.DatabricksThreadContextHolder;
import com.databricks.jdbc.common.util.DriverUtil;
import com.databricks.jdbc.common.util.StringUtil;
import com.databricks.jdbc.dbclient.impl.common.StatementId;
import com.databricks.jdbc.exception.DatabricksParsingException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
Expand Down Expand Up @@ -57,7 +56,6 @@ public static boolean isTelemetryAllowedForConnection(IDatabricksConnectionConte
.isFeatureEnabled(TELEMETRY_FEATURE_FLAG_NAME);
}

// TODO : add an export even before connection context is built
public static void exportInitialTelemetryLog(IDatabricksConnectionContext connectionContext) {
if (connectionContext == null) {
return;
Expand Down Expand Up @@ -116,26 +114,27 @@ public static void exportLatencyLog(long executionTime) {
DatabricksThreadContextHolder.getConnectionContext(),
executionTime,
executionEvent,
DatabricksThreadContextHolder.getStatementId());
DatabricksThreadContextHolder.getStatementId(),
DatabricksThreadContextHolder.getSessionId());
}

@VisibleForTesting
static void exportLatencyLog(
IDatabricksConnectionContext connectionContext,
long latencyMilliseconds,
SqlExecutionEvent executionEvent,
StatementId statementId) {
String statementId,
String sessionId) {
// Though we already handle null connectionContext in the downstream implementation,
// we are adding this check for extra sanity
if (connectionContext != null) {
TelemetryEvent telemetryEvent =
new TelemetryEvent()
.setLatency(latencyMilliseconds)
.setSqlOperation(executionEvent)
.setDriverConnectionParameters(getDriverConnectionParameter(connectionContext));
if (statementId != null) {
telemetryEvent.setSqlStatementId(statementId.toString());
}
.setDriverConnectionParameters(getDriverConnectionParameter(connectionContext))
.setSqlStatementId(statementId)
.setSessionId(sessionId);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can sessionId be null?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, in the 2 cases :

  1. if create session fails for a reason other than auth
  2. if it is a multiThreaded call

TelemetryFrontendLog telemetryFrontendLog =
new TelemetryFrontendLog()
.setFrontendLogEventId(getEventUUID())
Expand Down
1 change: 1 addition & 0 deletions src/test/java/com/databricks/jdbc/TestConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class TestConstants {
public static final String TEST_FOREIGN_TABLE = "foreignTable";
public static final String TEST_FUNCTION_PATTERN = "functionPattern";
public static final String TEST_STRING = "test";
public static final String TEST_STRING_2 = "test2";
public static final String TEST_USER = "testUser";
public static final String TEST_PASSWORD = "testPassword";
public static final StatementId TEST_STATEMENT_ID = new StatementId("statement_id");
Expand Down
Loading
Loading