diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkDownloadTask.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkDownloadTask.java index 741c587826..df6e810e96 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkDownloadTask.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ChunkDownloadTask.java @@ -5,7 +5,6 @@ import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.common.util.DatabricksThreadContextHolder; import com.databricks.jdbc.dbclient.IDatabricksHttpClient; -import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.log.JdbcLogger; @@ -26,7 +25,7 @@ class ChunkDownloadTask implements DatabricksCallableTask { private final IDatabricksHttpClient httpClient; private final ChunkDownloadCallback chunkDownloader; private final IDatabricksConnectionContext connectionContext; - private final StatementId statementId; + private final String statementId; private final ChunkLinkDownloadService linkDownloadService; Throwable uncaughtException = null; diff --git a/src/main/java/com/databricks/jdbc/common/util/DatabricksThreadContextHolder.java b/src/main/java/com/databricks/jdbc/common/util/DatabricksThreadContextHolder.java index 3ad3bb5e86..8066739370 100644 --- a/src/main/java/com/databricks/jdbc/common/util/DatabricksThreadContextHolder.java +++ b/src/main/java/com/databricks/jdbc/common/util/DatabricksThreadContextHolder.java @@ -4,13 +4,15 @@ import com.databricks.jdbc.common.StatementType; import com.databricks.jdbc.dbclient.impl.common.StatementId; +/* TODO : eliminate the use of thread local completely. Currently, we are limiting the usage of this for non-critical flows such as telemetry.*/ public class DatabricksThreadContextHolder { private static final ThreadLocal localConnectionContext = new ThreadLocal<>(); - private static final ThreadLocal localStatementId = new ThreadLocal<>(); + private static final ThreadLocal localStatementId = new ThreadLocal<>(); private static final ThreadLocal localChunkId = new ThreadLocal<>(); private static final ThreadLocal localRetryCount = new ThreadLocal<>(); private static final ThreadLocal localStatementType = new ThreadLocal<>(); + private static final ThreadLocal localSessionId = new ThreadLocal<>(); public static void setConnectionContext(IDatabricksConnectionContext context) { localConnectionContext.set(context); @@ -21,13 +23,28 @@ public static IDatabricksConnectionContext getConnectionContext() { } public static void setStatementId(StatementId statementId) { + if (statementId != null) { + localStatementId.set( + statementId.toSQLExecStatementId()); // This is because only GUID is relevant for tracking + } + } + + public static void setStatementId(String statementId) { localStatementId.set(statementId); } - public static StatementId getStatementId() { + public static String getStatementId() { return localStatementId.get(); } + public static void setSessionId(String sessionId) { + localSessionId.set(sessionId); + } + + public static String getSessionId() { + return localSessionId.get(); + } + public static void setStatementType(StatementType statementType) { localStatementType.set(statementType); } diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksSdkClient.java b/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksSdkClient.java index 210d19861e..3af22949aa 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksSdkClient.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksSdkClient.java @@ -128,6 +128,7 @@ public ImmutableSessionInfo createSession( LOGGER.error(errorMessage, e); throw new DatabricksSQLException(errorMessage, e, DatabricksDriverErrorCode.SDK_CLIENT_ERROR); } + DatabricksThreadContextHolder.setSessionId(createSessionResponse.getSessionId()); return ImmutableSessionInfo.builder() .computeResource(warehouse) .sessionId(createSessionResponse.getSessionId()) @@ -137,6 +138,7 @@ public ImmutableSessionInfo createSession( @Override public void deleteSession(ImmutableSessionInfo sessionInfo) throws DatabricksSQLException { LOGGER.debug("public void deleteSession(String sessionId = {})", sessionInfo.sessionId()); + DatabricksThreadContextHolder.setSessionId(sessionInfo.sessionId()); DeleteSessionRequest request = new DeleteSessionRequest() .setSessionId(sessionInfo.sessionId()) @@ -164,11 +166,14 @@ public DatabricksResultSet executeStatement( IDatabricksStatementInternal parentStatement) throws SQLException { LOGGER.debug( - "public DatabricksResultSet executeStatement(String sql = {}, compute resource = {}, Map parameters = {}, StatementType statementType = {}, IDatabricksSession session)", + "public DatabricksResultSet executeStatement(String sql = {}, compute resource = {}, Map parameters = {}, StatementType statementType = {}, IDatabricksSession session = {}, parentStatement = {})", sql, computeResource.toString(), parameters, - statementType); + statementType, + session, + parentStatement); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); long pollCount = 0; long executionStartTime = Instant.now().toEpochMilli(); DatabricksThreadContextHolder.setStatementType(statementType); @@ -207,6 +212,7 @@ public DatabricksResultSet executeStatement( computeResource, statementId); StatementId typedStatementId = new StatementId(statementId); + DatabricksThreadContextHolder.setStatementId(typedStatementId); if (parentStatement != null) { parentStatement.setStatementId(typedStatementId); } @@ -279,9 +285,12 @@ public DatabricksResultSet executeStatementAsync( IDatabricksStatementInternal parentStatement) throws SQLException { LOGGER.debug( - "public DatabricksResultSet executeStatementAsync(String sql = {}, compute resource = {}, Map parameters, IDatabricksSession session)", + "public DatabricksResultSet executeStatementAsync(String sql = {}, compute resource = {}, Map parameters, IDatabricksSession session = {}, IDatabricksStatementInternal parentStatement = {})", sql, - computeResource.toString()); + computeResource.toString(), + session, + parentStatement); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); ExecuteStatementRequest request = getRequest( StatementType.SQL, @@ -307,6 +316,7 @@ public DatabricksResultSet executeStatementAsync( handleFailedExecution(response, "", sql); } StatementId typedStatementId = new StatementId(statementId); + DatabricksThreadContextHolder.setStatementId(typedStatementId); if (parentStatement != null) { parentStatement.setStatementId(typedStatementId); } @@ -328,6 +338,8 @@ public DatabricksResultSet getStatementResult( IDatabricksSession session, IDatabricksStatementInternal parentStatement) throws DatabricksSQLException { + DatabricksThreadContextHolder.setStatementId(typedStatementId); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); String statementId = typedStatementId.toSQLExecStatementId(); GetStatementRequest request = new GetStatementRequest().setStatementId(statementId); String getStatusPath = String.format(STATEMENT_PATH_WITH_ID, statementId); @@ -354,6 +366,7 @@ public DatabricksResultSet getStatementResult( @Override public void closeStatement(StatementId typedStatementId) throws DatabricksSQLException { String statementId = typedStatementId.toSQLExecStatementId(); + DatabricksThreadContextHolder.setStatementId(typedStatementId); LOGGER.debug(String.format("public void closeStatement(String statementId = {})", statementId)); CloseStatementRequest request = new CloseStatementRequest().setStatementId(statementId); String path = String.format(STATEMENT_PATH_WITH_ID, request.getStatementId()); @@ -371,6 +384,7 @@ public void closeStatement(StatementId typedStatementId) throws DatabricksSQLExc @Override public void cancelStatement(StatementId typedStatementId) throws DatabricksSQLException { String statementId = typedStatementId.toSQLExecStatementId(); + DatabricksThreadContextHolder.setStatementId(typedStatementId); LOGGER.debug("public void cancelStatement(String statementId = {})", statementId); CancelStatementRequest request = new CancelStatementRequest().setStatementId(statementId); String path = String.format(CANCEL_STATEMENT_PATH_WITH_ID, request.getStatementId()); @@ -388,6 +402,7 @@ public void cancelStatement(StatementId typedStatementId) throws DatabricksSQLEx @Override public Collection getResultChunks(StatementId typedStatementId, long chunkIndex) throws DatabricksSQLException { + DatabricksThreadContextHolder.setStatementId(typedStatementId); String statementId = typedStatementId.toSQLExecStatementId(); LOGGER.debug( "public Optional getResultChunk(String statementId = {}, long chunkIndex = {})", diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java index 59fdc86608..8459b58a5a 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java @@ -9,6 +9,7 @@ import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.DatabricksClientConfiguratorManager; import com.databricks.jdbc.common.StatementType; +import com.databricks.jdbc.common.util.DatabricksThreadContextHolder; import com.databricks.jdbc.common.util.DriverUtil; import com.databricks.jdbc.common.util.ProtocolFeatureUtil; import com.databricks.jdbc.dbclient.impl.common.StatementId; @@ -221,6 +222,7 @@ DatabricksResultSet execute( checkResponseForErrors(response); StatementId statementId = new StatementId(response.getOperationHandle().operationId); + DatabricksThreadContextHolder.setStatementId(statementId); if (parentStatement != null) { parentStatement.setStatementId(statementId); } @@ -322,6 +324,7 @@ DatabricksResultSet executeAsync( } } StatementId statementId = new StatementId(response.getOperationHandle().operationId); + DatabricksThreadContextHolder.setStatementId(statementId); if (parentStatement != null) { parentStatement.setStatementId(statementId); } diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java index 76ce3cd3bd..493e8c84f1 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java @@ -110,6 +110,7 @@ public ImmutableSessionInfo createSession( } String sessionId = byteBufferToString(response.sessionHandle.getSessionId().guid); + DatabricksThreadContextHolder.setSessionId(sessionId); LOGGER.debug("Session created with ID {}", sessionId); return ImmutableSessionInfo.builder() .sessionId(sessionId) @@ -123,6 +124,7 @@ public void deleteSession(ImmutableSessionInfo sessionInfo) throws DatabricksSQL LOGGER.debug( String.format( "public void deleteSession(Session session = {%s}))", sessionInfo.toString())); + DatabricksThreadContextHolder.setSessionId(sessionInfo.sessionId()); TCloseSessionReq closeSessionReq = new TCloseSessionReq().setSessionHandle(sessionInfo.sessionHandle()); TCloseSessionResp response = @@ -191,7 +193,7 @@ private TExecuteStatementReq getRequest( IDatabricksStatementInternal parentStatement, boolean runAsync) throws SQLException { - + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); TSparkArrowTypes arrowNativeTypes = new TSparkArrowTypes().setTimestampAsArrow(true); // Convert the parameters to a list of TSparkParameter objects. @@ -247,6 +249,7 @@ public void closeStatement(StatementId statementId) throws DatabricksSQLExceptio String.format( "public void closeStatement(String statementId = {%s}) using Thrift client", statementId)); + DatabricksThreadContextHolder.setStatementId(statementId); TCloseOperationReq request = new TCloseOperationReq().setOperationHandle(getOperationHandle(statementId)); TCloseOperationResp resp = thriftAccessor.closeOperation(request); @@ -259,6 +262,7 @@ public void cancelStatement(StatementId statementId) throws DatabricksSQLExcepti String.format( "public void cancelStatement(String statementId = {%s}) using Thrift client", statementId)); + DatabricksThreadContextHolder.setStatementId(statementId); TCancelOperationReq request = new TCancelOperationReq().setOperationHandle(getOperationHandle(statementId)); TCancelOperationResp resp = thriftAccessor.cancelOperation(request); @@ -275,6 +279,8 @@ public DatabricksResultSet getStatementResult( String.format( "public DatabricksResultSet getStatementResult(String statementId = {%s}) using Thrift client", statementId)); + DatabricksThreadContextHolder.setStatementId(statementId); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); return thriftAccessor.getStatementResult( getOperationHandle(statementId), parentStatement, session); } @@ -287,6 +293,7 @@ public Collection getResultChunks(StatementId statementId, long ch "public Optional getResultChunk(String statementId = {%s}, long chunkIndex = {%s}) using Thrift client", statementId, chunkIndex); LOGGER.debug(context); + DatabricksThreadContextHolder.setStatementId(statementId); TFetchResultsResp fetchResultsResp; List externalLinks = new ArrayList<>(); AtomicInteger index = new AtomicInteger(0); @@ -318,6 +325,7 @@ public DatabricksResultSet listCatalogs(IDatabricksSession session) throws SQLEx String context = String.format("Fetching catalogs using Thrift client. Session {%s}", session.toString()); LOGGER.debug(context); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); TGetCatalogsReq request = new TGetCatalogsReq() .setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle()); @@ -337,6 +345,7 @@ public DatabricksResultSet listSchemas( "Fetching schemas using Thrift client. Session {%s}, catalog {%s}, schemaNamePattern {%s}", session.toString(), catalog, schemaNamePattern); LOGGER.debug(context); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); TGetSchemasReq request = new TGetSchemasReq() .setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle()) @@ -365,6 +374,7 @@ public DatabricksResultSet listTables( "Fetching tables using Thrift client. Session {%s}, catalog {%s}, schemaNamePattern {%s}, tableNamePattern {%s}", session.toString(), catalog, schemaNamePattern, tableNamePattern); LOGGER.debug(context); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); TGetTablesReq request = new TGetTablesReq() .setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle()) @@ -387,6 +397,7 @@ public DatabricksResultSet listTableTypes(IDatabricksSession session) { LOGGER.debug( String.format( "Fetching table types using Thrift client. Session {%s}", session.toString())); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); return metadataResultSetBuilder.getTableTypesResult(); } @@ -403,6 +414,7 @@ public DatabricksResultSet listColumns( "Fetching columns using Thrift client. Session {%s}, catalog {%s}, schemaNamePattern {%s}, tableNamePattern {%s}, columnNamePattern {%s}", session.toString(), catalog, schemaNamePattern, tableNamePattern, columnNamePattern); LOGGER.debug(context); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); TGetColumnsReq request = new TGetColumnsReq() .setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle()) @@ -429,6 +441,7 @@ public DatabricksResultSet listFunctions( String.format( "Fetching functions using Thrift client. Session {%s}, catalog {%s}, schemaNamePattern {%s}, functionNamePattern {%s}.", session.toString(), catalog, schemaNamePattern, functionNamePattern); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); LOGGER.debug(context); TGetFunctionsReq request = new TGetFunctionsReq() @@ -452,6 +465,7 @@ public DatabricksResultSet listPrimaryKeys( "Fetching primary keys using Thrift client. session {%s}, catalog {%s}, schema {%s}, table {%s}", session.toString(), catalog, schema, table); LOGGER.debug(context); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); TGetPrimaryKeysReq request = new TGetPrimaryKeysReq() .setSessionHandle(Objects.requireNonNull(session.getSessionInfo()).sessionHandle()) @@ -474,6 +488,7 @@ public DatabricksResultSet listImportedKeys( "Fetching imported keys using Thrift client for session {%s}, catalog {%s}, schema {%s}, table {%s}", session.toString(), catalog, schema, table); LOGGER.debug(context); + DatabricksThreadContextHolder.setSessionId(session.getSessionId()); // GetImportedKeys is implemented using GetCrossReferences // When only foreign table name is provided, we get imported keys TGetCrossReferenceReq request = diff --git a/src/main/java/com/databricks/jdbc/telemetry/TelemetryHelper.java b/src/main/java/com/databricks/jdbc/telemetry/TelemetryHelper.java index c0870b8d6a..3aa856c143 100644 --- a/src/main/java/com/databricks/jdbc/telemetry/TelemetryHelper.java +++ b/src/main/java/com/databricks/jdbc/telemetry/TelemetryHelper.java @@ -6,7 +6,6 @@ import com.databricks.jdbc.common.util.DatabricksThreadContextHolder; import com.databricks.jdbc.common.util.DriverUtil; import com.databricks.jdbc.common.util.StringUtil; -import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; @@ -57,7 +56,6 @@ public static boolean isTelemetryAllowedForConnection(IDatabricksConnectionConte .isFeatureEnabled(TELEMETRY_FEATURE_FLAG_NAME); } - // TODO : add an export even before connection context is built public static void exportInitialTelemetryLog(IDatabricksConnectionContext connectionContext) { if (connectionContext == null) { return; @@ -116,7 +114,8 @@ public static void exportLatencyLog(long executionTime) { DatabricksThreadContextHolder.getConnectionContext(), executionTime, executionEvent, - DatabricksThreadContextHolder.getStatementId()); + DatabricksThreadContextHolder.getStatementId(), + DatabricksThreadContextHolder.getSessionId()); } @VisibleForTesting @@ -124,7 +123,8 @@ static void exportLatencyLog( IDatabricksConnectionContext connectionContext, long latencyMilliseconds, SqlExecutionEvent executionEvent, - StatementId statementId) { + String statementId, + String sessionId) { // Though we already handle null connectionContext in the downstream implementation, // we are adding this check for extra sanity if (connectionContext != null) { @@ -132,10 +132,9 @@ static void exportLatencyLog( new TelemetryEvent() .setLatency(latencyMilliseconds) .setSqlOperation(executionEvent) - .setDriverConnectionParameters(getDriverConnectionParameter(connectionContext)); - if (statementId != null) { - telemetryEvent.setSqlStatementId(statementId.toString()); - } + .setDriverConnectionParameters(getDriverConnectionParameter(connectionContext)) + .setSqlStatementId(statementId) + .setSessionId(sessionId); TelemetryFrontendLog telemetryFrontendLog = new TelemetryFrontendLog() .setFrontendLogEventId(getEventUUID()) diff --git a/src/test/java/com/databricks/jdbc/TestConstants.java b/src/test/java/com/databricks/jdbc/TestConstants.java index 1417f0b3d7..8cd9bf97d4 100644 --- a/src/test/java/com/databricks/jdbc/TestConstants.java +++ b/src/test/java/com/databricks/jdbc/TestConstants.java @@ -32,6 +32,7 @@ public class TestConstants { public static final String TEST_FOREIGN_TABLE = "foreignTable"; public static final String TEST_FUNCTION_PATTERN = "functionPattern"; public static final String TEST_STRING = "test"; + public static final String TEST_STRING_2 = "test2"; public static final String TEST_USER = "testUser"; public static final String TEST_PASSWORD = "testPassword"; public static final StatementId TEST_STATEMENT_ID = new StatementId("statement_id"); diff --git a/src/test/java/com/databricks/jdbc/telemetry/TelemetryHelperTest.java b/src/test/java/com/databricks/jdbc/telemetry/TelemetryHelperTest.java index 6c694e9755..e45b2ea70f 100644 --- a/src/test/java/com/databricks/jdbc/telemetry/TelemetryHelperTest.java +++ b/src/test/java/com/databricks/jdbc/telemetry/TelemetryHelperTest.java @@ -1,7 +1,6 @@ package com.databricks.jdbc.telemetry; -import static com.databricks.jdbc.TestConstants.TEST_STRING; -import static com.databricks.jdbc.TestConstants.WAREHOUSE_COMPUTE; +import static com.databricks.jdbc.TestConstants.*; import static com.databricks.jdbc.common.safe.FeatureFlagTestUtil.enableFeatureFlagForTesting; import static com.databricks.jdbc.telemetry.TelemetryHelper.isTelemetryAllowedForConnection; import static org.junit.jupiter.api.Assertions.*; @@ -63,11 +62,24 @@ void testHostFetchThrowsErrorInTelemetryLog() throws DatabricksParsingException @Test void testLatencyTelemetryLogDoesNotThrowError() { + TelemetryHelper telemetryHelper = new TelemetryHelper(); // Increasing coverage for class + when(connectionContext.getConnectionUuid()).thenReturn(TEST_STRING_2); + when(connectionContext.getClientType()).thenReturn(DatabricksClientType.SEA); + SqlExecutionEvent event = new SqlExecutionEvent().setDriverStatementType(StatementType.QUERY); + assertDoesNotThrow( + () -> + telemetryHelper.exportLatencyLog( + connectionContext, 150, event, TEST_STRING, SESSION_ID)); + } + + @Test + void testLatencyTelemetryLogDoesNotThrowErrorWithNullStatementId() { TelemetryHelper telemetryHelper = new TelemetryHelper(); // Increasing coverage for class when(connectionContext.getConnectionUuid()).thenReturn(TEST_STRING); when(connectionContext.getClientType()).thenReturn(DatabricksClientType.SEA); SqlExecutionEvent event = new SqlExecutionEvent().setDriverStatementType(StatementType.QUERY); - assertDoesNotThrow(() -> telemetryHelper.exportLatencyLog(connectionContext, 150, event, null)); + assertDoesNotThrow( + () -> telemetryHelper.exportLatencyLog(connectionContext, 150, event, null, SESSION_ID)); } @Test