Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -942,4 +942,16 @@ private Map<String, String> parseCustomHeaders(ImmutableMap<String, String> para
Collectors.toMap(
entry -> entry.getKey().substring(filterPrefix.length()), Map.Entry::getValue));
}

@Override
public boolean forceEnableTelemetry() {
return getParameter(DatabricksJdbcUrlParams.FORCE_ENABLE_TELEMETRY).equals("1");
}

@Override
public int getTelemetryFlushIntervalInMilliseconds() {
// There is a minimum threshold of 1000ms for the flush interval
return Math.max(
1000, Integer.parseInt(getParameter(DatabricksJdbcUrlParams.TELEMETRY_FLUSH_INTERVAL)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -320,4 +320,10 @@ public interface IDatabricksConnectionContext {

/** Returns the application name using JDBC Connection */
String getApplicationName();

/** Returns whether telemetry is enabled for all connections */
boolean forceEnableTelemetry();

/** Returns the flush interval in milliseconds for telemetry */
int getTelemetryFlushIntervalInMilliseconds();
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ public enum DatabricksJdbcUrlParams {
TOKEN_CACHE_PASS_PHRASE("TokenCachePassPhrase", "Pass phrase to use for OAuth U2M Token Cache"),
ENABLE_TOKEN_CACHE("EnableTokenCache", "Enable caching OAuth tokens", "1"),
APPLICATION_NAME("ApplicationName", "Name of application using the driver", ""),
FORCE_ENABLE_TELEMETRY("ForceEnableTelemetry", "Force enable telemetry", "0"),
TELEMETRY_FLUSH_INTERVAL("TelemetryFlushInterval", "Flush interval in milliseconds", "5000"),
;

private final String paramName;
Expand Down
52 changes: 47 additions & 5 deletions src/main/java/com/databricks/jdbc/telemetry/TelemetryClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

public class TelemetryClient implements ITelemetryClient {

Expand All @@ -16,7 +19,11 @@ public class TelemetryClient implements ITelemetryClient {
private final int eventsBatchSize;
private final boolean isAuthEnabled;
private final ExecutorService executorService;
private final ScheduledExecutorService scheduledExecutorService;
private List<TelemetryFrontendLog> eventsBatch;
private volatile long lastFlushedTime;
private ScheduledFuture<?> flushTask;
private final int flushIntervalMillis;

public TelemetryClient(
IDatabricksConnectionContext connectionContext,
Expand All @@ -28,6 +35,11 @@ public TelemetryClient(
this.context = connectionContext;
this.databricksConfig = config;
this.executorService = executorService;
this.scheduledExecutorService =
java.util.concurrent.Executors.newSingleThreadScheduledExecutor();
this.flushIntervalMillis = context.getTelemetryFlushIntervalInMilliseconds();
this.lastFlushedTime = System.currentTimeMillis();
schedulePeriodicFlush();
}

public TelemetryClient(
Expand All @@ -38,6 +50,27 @@ public TelemetryClient(
this.context = connectionContext;
this.databricksConfig = null;
this.executorService = executorService;
this.scheduledExecutorService =
java.util.concurrent.Executors.newSingleThreadScheduledExecutor();
this.flushIntervalMillis = context.getTelemetryFlushIntervalInMilliseconds();
this.lastFlushedTime = System.currentTimeMillis();
schedulePeriodicFlush();
}

private void schedulePeriodicFlush() {
if (flushTask != null) {
flushTask.cancel(false);
}
flushTask =
scheduledExecutorService.scheduleAtFixedRate(
this::periodicFlush, flushIntervalMillis, flushIntervalMillis, TimeUnit.MILLISECONDS);
}

private void periodicFlush() {
long now = System.currentTimeMillis();
if (now - lastFlushedTime >= flushIntervalMillis) {
flush();
}
}

@Override
Expand All @@ -61,6 +94,10 @@ public void close() {
TelemetryHelper.exportChunkLatencyTelemetry(chunkDetails, statementId);
});
flush();
if (flushTask != null) {
flushTask.cancel(false);
}
scheduledExecutorService.shutdown();
}

@Override
Expand All @@ -75,14 +112,19 @@ public void closeStatement(String statementId) {

private void flush() {
synchronized (this) {
List<TelemetryFrontendLog> logsToBeFlushed = eventsBatch;
executorService.submit(
new TelemetryPushTask(logsToBeFlushed, isAuthEnabled, context, databricksConfig));
eventsBatch = new LinkedList<>();
if (!eventsBatch.isEmpty()) {
List<TelemetryFrontendLog> logsToBeFlushed = eventsBatch;
executorService.submit(
new TelemetryPushTask(logsToBeFlushed, isAuthEnabled, context, databricksConfig));
eventsBatch = new LinkedList<>();
}
lastFlushedTime = System.currentTimeMillis();
}
}

int getCurrentSize() {
return eventsBatch.size();
synchronized (this) {
return eventsBatch.size();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ public static void updateClientAppName(String clientAppName) {
}

public static boolean isTelemetryAllowedForConnection(IDatabricksConnectionContext context) {
if (context.forceEnableTelemetry()) {
return true;
}
return context != null
&& context.isTelemetryEnabled()
&& DatabricksDriverFeatureFlagsContextFactory.getInstance(context)
Expand Down
112 changes: 112 additions & 0 deletions src/test/java/com/databricks/jdbc/telemetry/TelemetryClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import com.google.common.util.concurrent.MoreExecutors;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.http.HttpHeaders;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.CloseableHttpResponse;
Expand Down Expand Up @@ -139,4 +141,114 @@ public void testExportEventDoesNotThrowErrorsInFailures() throws Exception {
() -> client.exportEvent(new TelemetryFrontendLog().setFrontendLogEventId("event2")));
}
}

@Test
public void testPeriodicFlushWithAuthenticatedClient() throws Exception {
try (MockedStatic<DatabricksHttpClientFactory> factoryMocked =
mockStatic(DatabricksHttpClientFactory.class)) {
DatabricksHttpClientFactory mockFactory = mock(DatabricksHttpClientFactory.class);
factoryMocked.when(DatabricksHttpClientFactory::getInstance).thenReturn(mockFactory);
when(mockFactory.getClient(any())).thenReturn(mockHttpClient);
when(mockHttpClient.execute(any())).thenReturn(mockHttpResponse);
when(mockHttpResponse.getStatusLine()).thenReturn(mockStatusLine);
when(mockStatusLine.getStatusCode()).thenReturn(200);
TelemetryResponse response = new TelemetryResponse().setNumSuccess(1L).setNumProtoSuccess(1L);
when(mockHttpResponse.getEntity())
.thenReturn(new StringEntity(new ObjectMapper().writeValueAsString(response)));

Map<String, String> headers = Map.of(HttpHeaders.AUTHORIZATION, "token");
when(databricksConfig.authenticate()).thenReturn(headers);

// JDBC URL with 2 seconds flush interval
String jdbcUrlWith2SecondsFlush =
"jdbc:databricks://adb-20.azuredatabricks.net:4423/default;transportMode=http;ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/ghgjhgj;UserAgentEntry=MyApp;EnableTelemetry=1;TelemetryBatchSize=2;TelemetryFlushInterval=2000";

IDatabricksConnectionContext context =
DatabricksConnectionContext.parse(jdbcUrlWith2SecondsFlush, new Properties());
TelemetryClient client =
new TelemetryClient(context, MoreExecutors.newDirectExecutorService(), databricksConfig);

// Add a single event that won't trigger batch flush
client.exportEvent(new TelemetryFrontendLog().setFrontendLogEventId("event1"));
assertEquals(1, client.getCurrentSize());

// Wait for a short time to verify the periodic flush doesn't trigger immediately
Thread.sleep(100);
assertEquals(1, client.getCurrentSize());

// Wait for 2 seconds to trigger the periodic flush
Thread.sleep(2000);
assertEquals(0, client.getCurrentSize());

client.exportEvent(new TelemetryFrontendLog().setFrontendLogEventId("event2"));
assertEquals(1, client.getCurrentSize());
// Close the client to trigger final flush
client.close();
assertEquals(0, client.getCurrentSize());
}
}

@Test
public void testTimerResetOnBatchSizeFlush() throws Exception {
TelemetryClient client = null;
ExecutorService executor = null;
try (MockedStatic<DatabricksHttpClientFactory> factoryMocked =
mockStatic(DatabricksHttpClientFactory.class)) {
DatabricksHttpClientFactory mockFactory = mock(DatabricksHttpClientFactory.class);
factoryMocked.when(DatabricksHttpClientFactory::getInstance).thenReturn(mockFactory);
when(mockFactory.getClient(any())).thenReturn(mockHttpClient);
when(mockHttpClient.execute(any())).thenReturn(mockHttpResponse);
when(mockHttpResponse.getStatusLine()).thenReturn(mockStatusLine);
when(mockStatusLine.getStatusCode()).thenReturn(200);
TelemetryResponse response = new TelemetryResponse().setNumSuccess(1L).setNumProtoSuccess(1L);
when(mockHttpResponse.getEntity())
.thenReturn(new StringEntity(new ObjectMapper().writeValueAsString(response)));

// Set up a client with 3 second flush interval and batch size of 2
String jdbcUrl =
"jdbc:databricks://adb-20.azuredatabricks.net:4423/default;transportMode=http;ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/ghgjhgj;UserAgentEntry=MyApp;EnableTelemetry=1;TelemetryBatchSize=2;TelemetryFlushInterval=3000";
IDatabricksConnectionContext context =
DatabricksConnectionContext.parse(jdbcUrl, new Properties());
executor = MoreExecutors.newDirectExecutorService();
client = new TelemetryClient(context, executor);

// Add events to trigger batch size flush
client.exportEvent(new TelemetryFrontendLog().setFrontendLogEventId("event1"));
client.exportEvent(
new TelemetryFrontendLog()
.setFrontendLogEventId("event2")); // This should trigger flush due to batch size

// Wait 2 seconds (less than the flush interval)
Thread.sleep(2000);

// Add another event
client.exportEvent(new TelemetryFrontendLog().setFrontendLogEventId("event3"));

// Verify it's still in the batch (shouldn't have been flushed yet since timer was reset)
assertEquals(1, client.getCurrentSize());

// Wait another 2 seconds (still less than full interval from last flush)
Thread.sleep(2000);

// Verify it's still not flushed
assertEquals(1, client.getCurrentSize());

} finally {
// Clean up resources
if (client != null) {
client.close();
}
if (executor != null) {
executor.shutdown();
// Wait for any pending tasks to complete
if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
executor.shutdownNow();
}
}
// Verify mocks were properly used
verify(mockHttpClient, atLeastOnce()).execute(any());
verify(mockHttpResponse, atLeastOnce()).getStatusLine();
verify(mockStatusLine, atLeastOnce()).getStatusCode();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,23 @@ public void testGetDatabricksConfigSafely_ReturnsNullOnError() {

@Test
public void testGetDatabricksConfigSafely_HandlesNullContext() {
DatabricksConfig result = TelemetryHelper.getDatabricksConfigSafely(null);
DatabricksConfig result = TelemetryHelper.getDatabricksConfigSafely(connectionContext);
assertNull(result, "Should return null when context is null");
}

@Test
public void testTelemetryNotAllowedUsecase() {
assertFalse(() -> isTelemetryAllowedForConnection(null));
assertFalse(() -> isTelemetryAllowedForConnection(connectionContext));
when(connectionContext.getComputeResource()).thenReturn(WAREHOUSE_COMPUTE);
enableFeatureFlagForTesting(connectionContext, Collections.emptyMap());
assertFalse(() -> isTelemetryAllowedForConnection(connectionContext));
}

@Test
public void testTelemetryAllowedWithForceTelemetryFlag() {
when(connectionContext.getComputeResource()).thenReturn(WAREHOUSE_COMPUTE);
when(connectionContext.forceEnableTelemetry()).thenReturn(true);
enableFeatureFlagForTesting(connectionContext, Collections.emptyMap());
assertTrue(() -> isTelemetryAllowedForConnection(connectionContext));
}
}
Loading