Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -942,4 +942,14 @@ private Map<String, String> parseCustomHeaders(ImmutableMap<String, String> para
Collectors.toMap(
entry -> entry.getKey().substring(filterPrefix.length()), Map.Entry::getValue));
}

@Override
public boolean forceEnableTelemetry() {
return getParameter(DatabricksJdbcUrlParams.FORCE_ENABLE_TELEMETRY).equals("1");
}

@Override
public int getTelemetryFlushIntervalInMilliseconds() {
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.TELEMETRY_FLUSH_INTERVAL));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -320,4 +320,10 @@ public interface IDatabricksConnectionContext {

/** Returns the application name using JDBC Connection */
String getApplicationName();

/** Returns whether telemetry is enabled for all connections */
boolean forceEnableTelemetry();

/** Returns the flush interval in milliseconds for telemetry */
int getTelemetryFlushIntervalInMilliseconds();
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ public enum DatabricksJdbcUrlParams {
TOKEN_CACHE_PASS_PHRASE("TokenCachePassPhrase", "Pass phrase to use for OAuth U2M Token Cache"),
ENABLE_TOKEN_CACHE("EnableTokenCache", "Enable caching OAuth tokens", "1"),
APPLICATION_NAME("ApplicationName", "Name of application using the driver", ""),
FORCE_ENABLE_TELEMETRY("ForceEnableTelemetry", "Force enable telemetry", "0"),
TELEMETRY_FLUSH_INTERVAL("TelemetryFlushInterval", "Flush interval in milliseconds", "5000"),
;

private final String paramName;
Expand Down
31 changes: 26 additions & 5 deletions src/main/java/com/databricks/jdbc/telemetry/TelemetryClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

public class TelemetryClient implements ITelemetryClient {

Expand All @@ -16,6 +18,7 @@ public class TelemetryClient implements ITelemetryClient {
private final int eventsBatchSize;
private final boolean isAuthEnabled;
private final ExecutorService executorService;
private final ScheduledExecutorService scheduledExecutorService;
private List<TelemetryFrontendLog> eventsBatch;

public TelemetryClient(
Expand All @@ -28,6 +31,9 @@ public TelemetryClient(
this.context = connectionContext;
this.databricksConfig = config;
this.executorService = executorService;
this.scheduledExecutorService =
java.util.concurrent.Executors.newSingleThreadScheduledExecutor();
schedulePeriodicFlush();
}

public TelemetryClient(
Expand All @@ -38,6 +44,16 @@ public TelemetryClient(
this.context = connectionContext;
this.databricksConfig = null;
this.executorService = executorService;
this.scheduledExecutorService =
java.util.concurrent.Executors.newSingleThreadScheduledExecutor();
schedulePeriodicFlush();
}

private void schedulePeriodicFlush() {
// Ensure minimum 1 second interval to avoid over-calling flush
int intervalMillis = Math.max(1000, context.getTelemetryFlushIntervalInMilliseconds());
scheduledExecutorService.scheduleAtFixedRate(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's reset the last flushedTime on flush, so that timer also gets reset

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, addressed ✅

this::flush, intervalMillis, intervalMillis, TimeUnit.MILLISECONDS);
}

@Override
Expand All @@ -61,6 +77,7 @@ public void close() {
TelemetryHelper.exportChunkLatencyTelemetry(chunkDetails, statementId);
});
flush();
scheduledExecutorService.shutdown();
}

@Override
Expand All @@ -75,14 +92,18 @@ public void closeStatement(String statementId) {

private void flush() {
synchronized (this) {
List<TelemetryFrontendLog> logsToBeFlushed = eventsBatch;
executorService.submit(
new TelemetryPushTask(logsToBeFlushed, isAuthEnabled, context, databricksConfig));
eventsBatch = new LinkedList<>();
if (!eventsBatch.isEmpty()) {
List<TelemetryFrontendLog> logsToBeFlushed = eventsBatch;
executorService.submit(
new TelemetryPushTask(logsToBeFlushed, isAuthEnabled, context, databricksConfig));
eventsBatch = new LinkedList<>();
}
}
}

int getCurrentSize() {
return eventsBatch.size();
synchronized (this) {
return eventsBatch.size();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ public static void updateClientAppName(String clientAppName) {
}

public static boolean isTelemetryAllowedForConnection(IDatabricksConnectionContext context) {
if (context.forceEnableTelemetry()) {
return true;
}
return context != null
&& context.isTelemetryEnabled()
&& DatabricksDriverFeatureFlagsContextFactory.getInstance(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,50 @@ public void testExportEventDoesNotThrowErrorsInFailures() throws Exception {
() -> client.exportEvent(new TelemetryFrontendLog().setFrontendLogEventId("event2")));
}
}

@Test
public void testPeriodicFlushWithAuthenticatedClient() throws Exception {
try (MockedStatic<DatabricksHttpClientFactory> factoryMocked =
mockStatic(DatabricksHttpClientFactory.class)) {
DatabricksHttpClientFactory mockFactory = mock(DatabricksHttpClientFactory.class);
factoryMocked.when(DatabricksHttpClientFactory::getInstance).thenReturn(mockFactory);
when(mockFactory.getClient(any())).thenReturn(mockHttpClient);
when(mockHttpClient.execute(any())).thenReturn(mockHttpResponse);
when(mockHttpResponse.getStatusLine()).thenReturn(mockStatusLine);
when(mockStatusLine.getStatusCode()).thenReturn(200);
TelemetryResponse response = new TelemetryResponse().setNumSuccess(1L).setNumProtoSuccess(1L);
when(mockHttpResponse.getEntity())
.thenReturn(new StringEntity(new ObjectMapper().writeValueAsString(response)));

Map<String, String> headers = Map.of(HttpHeaders.AUTHORIZATION, "token");
when(databricksConfig.authenticate()).thenReturn(headers);

// JDBC URL with 2 seconds flush interval
String jdbcUrlWith2SecondsFlush =
"jdbc:databricks://adb-20.azuredatabricks.net:4423/default;transportMode=http;ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/ghgjhgj;UserAgentEntry=MyApp;EnableTelemetry=1;TelemetryBatchSize=2;TelemetryFlushInterval=2000";

IDatabricksConnectionContext context =
DatabricksConnectionContext.parse(jdbcUrlWith2SecondsFlush, new Properties());
TelemetryClient client =
new TelemetryClient(context, MoreExecutors.newDirectExecutorService(), databricksConfig);

// Add a single event that won't trigger batch flush
client.exportEvent(new TelemetryFrontendLog().setFrontendLogEventId("event1"));
assertEquals(1, client.getCurrentSize());

// Wait for a short time to verify the periodic flush doesn't trigger immediately
Thread.sleep(100);
assertEquals(1, client.getCurrentSize());

// Wait for 2 seconds to trigger the periodic flush
Thread.sleep(2000);
assertEquals(0, client.getCurrentSize());

client.exportEvent(new TelemetryFrontendLog().setFrontendLogEventId("event2"));
assertEquals(1, client.getCurrentSize());
// Close the client to trigger final flush
client.close();
assertEquals(0, client.getCurrentSize());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,23 @@ public void testGetDatabricksConfigSafely_ReturnsNullOnError() {

@Test
public void testGetDatabricksConfigSafely_HandlesNullContext() {
DatabricksConfig result = TelemetryHelper.getDatabricksConfigSafely(null);
DatabricksConfig result = TelemetryHelper.getDatabricksConfigSafely(connectionContext);
assertNull(result, "Should return null when context is null");
}

@Test
public void testTelemetryNotAllowedUsecase() {
assertFalse(() -> isTelemetryAllowedForConnection(null));
assertFalse(() -> isTelemetryAllowedForConnection(connectionContext));
when(connectionContext.getComputeResource()).thenReturn(WAREHOUSE_COMPUTE);
enableFeatureFlagForTesting(connectionContext, Collections.emptyMap());
assertFalse(() -> isTelemetryAllowedForConnection(connectionContext));
}

@Test
public void testTelemetryAllowedWithForceTelemetryFlag() {
when(connectionContext.getComputeResource()).thenReturn(WAREHOUSE_COMPUTE);
when(connectionContext.forceEnableTelemetry()).thenReturn(true);
enableFeatureFlagForTesting(connectionContext, Collections.emptyMap());
assertTrue(() -> isTelemetryAllowedForConnection(connectionContext));
}
}
Loading