Skip to content

Commit 11258ba

Browse files
committed
feat(bq jdbc): run getStatementType in parallel
1 parent cbb2a7a commit 11258ba

File tree

6 files changed

+118
-18
lines changed

6 files changed

+118
-18
lines changed

java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryConnection.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
import java.util.Set;
6161
import java.util.concurrent.ConcurrentHashMap;
6262
import java.util.concurrent.Executor;
63+
import java.util.concurrent.ExecutorService;
64+
import java.util.concurrent.Executors;
6365
import java.util.concurrent.TimeUnit;
6466

6567
/**
@@ -138,6 +140,8 @@ public class BigQueryConnection extends BigQueryNoOpsConnection {
138140
Long connectionPoolSize;
139141
Long listenerPoolSize;
140142
String partnerToken;
143+
private int queryTaskThreadCount;
144+
private ExecutorService queryTaskExecutor;
141145

142146
BigQueryConnection(String url) throws IOException {
143147
this(url, DataSource.fromUrl(url));
@@ -238,6 +242,10 @@ public class BigQueryConnection extends BigQueryNoOpsConnection {
238242
this.filterTablesOnDefaultDataset = ds.getFilterTablesOnDefaultDataset();
239243
this.requestGoogleDriveScope = ds.getRequestGoogleDriveScope();
240244
this.metadataFetchThreadCount = ds.getMetadataFetchThreadCount();
245+
this.queryTaskThreadCount = ds.getQueryTaskThreadCount();
246+
this.queryTaskExecutor =
247+
Executors.newFixedThreadPool(
248+
this.queryTaskThreadCount, new BigQueryThreadFactory("BigQuery-query-task-"));
241249
this.requestReason = ds.getRequestReason();
242250
this.connectionPoolSize = ds.getConnectionPoolSize();
243251
this.listenerPoolSize = ds.getListenerPoolSize();
@@ -596,6 +604,10 @@ int getMetadataFetchThreadCount() {
596604
return this.metadataFetchThreadCount;
597605
}
598606

607+
public ExecutorService getQueryTaskExecutor() {
608+
return this.queryTaskExecutor;
609+
}
610+
599611
boolean isEnableWriteAPI() {
600612
return enableWriteAPI;
601613
}
@@ -836,6 +848,10 @@ public void close() throws SQLException {
836848
statement.close();
837849
}
838850
this.openStatements.clear();
851+
852+
if (this.queryTaskExecutor != null) {
853+
this.queryTaskExecutor.shutdown();
854+
}
839855
} catch (ConcurrentModificationException ex) {
840856
throw new BigQueryJdbcException(ex);
841857
} catch (InterruptedException e) {

java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryJdbcUrlUtility.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ protected boolean removeEldestEntry(Map.Entry<String, Map<String, String>> eldes
142142
Pattern.CASE_INSENSITIVE);
143143
static final String METADATA_FETCH_THREAD_COUNT_PROPERTY_NAME = "MetaDataFetchThreadCount";
144144
static final int DEFAULT_METADATA_FETCH_THREAD_COUNT_VALUE = 32;
145+
static final String QUERY_TASK_THREAD_COUNT_PROPERTY_NAME = "QueryTaskThreadCount";
146+
static final int DEFAULT_QUERY_TASK_THREAD_COUNT_VALUE = 16;
145147
static final String RETRY_TIMEOUT_IN_SECS_PROPERTY_NAME = "Timeout";
146148
static final long DEFAULT_RETRY_TIMEOUT_IN_SECS_VALUE = 0L;
147149
static final String JOB_TIMEOUT_PROPERTY_NAME = "JobTimeout";
@@ -535,6 +537,12 @@ protected boolean removeEldestEntry(Map.Entry<String, Map<String, String>> eldes
535537
"The number of threads used to call a DatabaseMetaData method.")
536538
.setDefaultValue(String.valueOf(DEFAULT_METADATA_FETCH_THREAD_COUNT_VALUE))
537539
.build(),
540+
BigQueryConnectionProperty.newBuilder()
541+
.setName(QUERY_TASK_THREAD_COUNT_PROPERTY_NAME)
542+
.setDescription(
543+
"The number of background threads used for executing queries parallel tasks.")
544+
.setDefaultValue(String.valueOf(DEFAULT_QUERY_TASK_THREAD_COUNT_VALUE))
545+
.build(),
538546
BigQueryConnectionProperty.newBuilder()
539547
.setName(ENABLE_WRITE_API_PROPERTY_NAME)
540548
.setDescription(

java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/BigQueryStatement.java

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@
7272
import java.util.Random;
7373
import java.util.UUID;
7474
import java.util.concurrent.BlockingQueue;
75-
import java.util.concurrent.ExecutorService;
76-
import java.util.concurrent.Executors;
75+
import java.util.concurrent.ExecutionException;
76+
import java.util.concurrent.Future;
7777
import java.util.concurrent.LinkedBlockingDeque;
7878
import java.util.concurrent.ThreadFactory;
7979
import java.util.logging.Level;
@@ -88,9 +88,6 @@
8888
public class BigQueryStatement extends BigQueryNoOpsStatement {
8989

9090
// TODO (obada): Update this after benchmarking
91-
private static final int MAX_PROCESS_QUERY_THREADS_CNT = 50;
92-
protected static ExecutorService queryTaskExecutor =
93-
Executors.newFixedThreadPool(MAX_PROCESS_QUERY_THREADS_CNT);
9491
private final BigQueryJdbcCustomLogger LOG = new BigQueryJdbcCustomLogger(this.toString());
9592
private static final String DEFAULT_DATASET_NAME = "_google_jdbc";
9693
private static final String DEFAULT_TABLE_NAME = "temp_table_";
@@ -594,15 +591,20 @@ void runQuery(String query, QueryJobConfiguration jobConfiguration)
594591

595592
try {
596593
resetStatementFields();
594+
595+
final QueryJobConfiguration finalJobConfiguration = jobConfiguration;
596+
Future<StatementType> statementTypeFuture =
597+
connection.getQueryTaskExecutor().submit(() -> getStatementType(finalJobConfiguration));
598+
597599
ExecuteResult executeResult = executeJob(jobConfiguration);
598-
StatementType statementType =
599-
executeResult.job == null
600-
? getStatementType(jobConfiguration)
601-
: ((QueryStatistics) executeResult.job.getStatistics()).getStatementType();
600+
601+
StatementType statementType = statementTypeFuture.get();
602602
SqlType queryType = getQueryType(jobConfiguration, statementType);
603603
handleQueryResult(query, executeResult.tableResult, queryType);
604604
} catch (InterruptedException ex) {
605605
throw new BigQueryJdbcRuntimeException(ex);
606+
} catch (ExecutionException e) {
607+
throw new BigQueryJdbcException(e.getCause());
606608
} catch (BigQueryException ex) {
607609
if (ex.getMessage().contains("Syntax error")) {
608610
throw new BigQueryJdbcSqlSyntaxErrorException(ex);
@@ -829,7 +831,8 @@ Thread populateArrowBufferedQueue(
829831
com.google.api.gax.rpc.ServerStream<ReadRowsResponse> stream =
830832
bqReadClient.readRowsCallable().call(readRowsRequest);
831833
for (ReadRowsResponse response : stream) {
832-
if (Thread.currentThread().isInterrupted() || queryTaskExecutor.isShutdown()) {
834+
if (Thread.currentThread().isInterrupted()
835+
|| connection.getQueryTaskExecutor().isShutdown()) {
833836
break;
834837
}
835838

@@ -1042,7 +1045,8 @@ Thread runNextPageTaskAsync(
10421045
try {
10431046
while (currentPageToken != null) {
10441047
// do not process further pages and shutdown
1045-
if (Thread.currentThread().isInterrupted() || queryTaskExecutor.isShutdown()) {
1048+
if (Thread.currentThread().isInterrupted()
1049+
|| connection.getQueryTaskExecutor().isShutdown()) {
10461050
LOG.warning(
10471051
"%s Interrupted @ runNextPageTaskAsync", Thread.currentThread().getName());
10481052
break;
@@ -1073,7 +1077,8 @@ Thread runNextPageTaskAsync(
10731077
// completes
10741078
Uninterruptibles.putUninterruptibly(rpcResponseQueue, Tuple.of(null, false));
10751079
}
1076-
// We cannot do queryTaskExecutor.shutdownNow() here as populate buffer method may not
1080+
// We cannot call connection.getQueryTaskExecutor().shutdownNow() here because the
1081+
// populate buffer method may not have finished processing the records,
10771082
// and shutdownNow() would interrupt it as well.
10781083
};
10791084

@@ -1117,7 +1122,7 @@ Thread parseAndPopulateRpcDataAsync(
11171122
}
11181123

11191124
if (Thread.currentThread().isInterrupted()
1120-
|| queryTaskExecutor.isShutdown()
1125+
|| connection.getQueryTaskExecutor().isShutdown()
11211126
|| fieldValueLists == null) {
11221127
// do not process further pages and shutdown (outerloop)
11231128
break;
@@ -1127,7 +1132,8 @@ Thread parseAndPopulateRpcDataAsync(
11271132
long results = 0;
11281133
for (FieldValueList fieldValueList : fieldValueLists) {
11291134

1130-
if (Thread.currentThread().isInterrupted() || queryTaskExecutor.isShutdown()) {
1135+
if (Thread.currentThread().isInterrupted()
1136+
|| connection.getQueryTaskExecutor().isShutdown()) {
11311137
// do not process further pages and shutdown (inner loop)
11321138
break;
11331139
}

java-bigquery/google-cloud-bigquery-jdbc/src/main/java/com/google/cloud/bigquery/jdbc/DataSource.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ public class DataSource implements javax.sql.DataSource {
8484
private Boolean filterTablesOnDefaultDataset;
8585
private Integer requestGoogleDriveScope;
8686
private Integer metadataFetchThreadCount;
87+
private Integer queryTaskThreadCount;
8788
private String sslTrustStorePath;
8889
private String sslTrustStorePassword;
8990
private Map<String, String> labels;
@@ -240,6 +241,9 @@ public class DataSource implements javax.sql.DataSource {
240241
.put(
241242
BigQueryJdbcUrlUtility.METADATA_FETCH_THREAD_COUNT_PROPERTY_NAME,
242243
(ds, val) -> ds.setMetadataFetchThreadCount(Integer.parseInt(val)))
244+
.put(
245+
BigQueryJdbcUrlUtility.QUERY_TASK_THREAD_COUNT_PROPERTY_NAME,
246+
(ds, val) -> ds.setQueryTaskThreadCount(Integer.parseInt(val)))
243247
.put(
244248
BigQueryJdbcUrlUtility.SSL_TRUST_STORE_PROPERTY_NAME,
245249
DataSource::setSSLTrustStorePath)
@@ -546,6 +550,11 @@ private Properties createProperties() {
546550
BigQueryJdbcUrlUtility.METADATA_FETCH_THREAD_COUNT_PROPERTY_NAME,
547551
String.valueOf(this.metadataFetchThreadCount));
548552
}
553+
if (this.queryTaskThreadCount != null) {
554+
connectionProperties.setProperty(
555+
BigQueryJdbcUrlUtility.QUERY_TASK_THREAD_COUNT_PROPERTY_NAME,
556+
String.valueOf(this.queryTaskThreadCount));
557+
}
549558
if (this.sslTrustStorePath != null) {
550559
connectionProperties.setProperty(
551560
BigQueryJdbcUrlUtility.SSL_TRUST_STORE_PROPERTY_NAME,
@@ -1024,6 +1033,16 @@ public void setMetadataFetchThreadCount(Integer metadataFetchThreadCount) {
10241033
this.metadataFetchThreadCount = metadataFetchThreadCount;
10251034
}
10261035

1036+
public Integer getQueryTaskThreadCount() {
1037+
return queryTaskThreadCount != null
1038+
? queryTaskThreadCount
1039+
: BigQueryJdbcUrlUtility.DEFAULT_QUERY_TASK_THREAD_COUNT_VALUE;
1040+
}
1041+
1042+
public void setQueryTaskThreadCount(Integer queryTaskThreadCount) {
1043+
this.queryTaskThreadCount = queryTaskThreadCount;
1044+
}
1045+
10271046
public String getSSLTrustStorePath() {
10281047
return sslTrustStorePath;
10291048
}

java-bigquery/google-cloud-bigquery-jdbc/src/test/java/com/google/cloud/bigquery/jdbc/BigQueryConnectionTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,38 @@ public void testMetaDataFetchThreadCountProperty() throws SQLException, IOExcept
363363
}
364364
}
365365

366+
@Test
367+
public void testQueryTaskThreadCountProperty() throws SQLException, IOException {
368+
// Test Case 1: Should use the default value when the property is not provided.
369+
String urlDefault =
370+
"jdbc:bigquery://https://www.googleapis.com/bigquery/v2:443;"
371+
+ "OAuthType=2;ProjectId=MyBigQueryProject;"
372+
+ "OAuthAccessToken=redactedToken;OAuthClientId=redactedToken;"
373+
+ "OAuthClientSecret=redactedToken;";
374+
try (BigQueryConnection connectionDefault = new BigQueryConnection(urlDefault)) {
375+
assertEquals(
376+
4,
377+
((java.util.concurrent.ThreadPoolExecutor) connectionDefault.getQueryTaskExecutor())
378+
.getCorePoolSize(),
379+
"Should use the default value of 4 when the property is not provided");
380+
}
381+
382+
// Test Case 2: Should use the custom value when a valid integer is provided.
383+
String urlCustom =
384+
"jdbc:bigquery://https://www.googleapis.com/bigquery/v2:443;"
385+
+ "OAuthType=2;ProjectId=MyBigQueryProject;"
386+
+ "OAuthAccessToken=redactedToken;OAuthClientId=redactedToken;"
387+
+ "OAuthClientSecret=redactedToken;"
388+
+ "QueryTaskThreadCount=16;";
389+
try (BigQueryConnection connectionCustom = new BigQueryConnection(urlCustom)) {
390+
assertEquals(
391+
16,
392+
((java.util.concurrent.ThreadPoolExecutor) connectionCustom.getQueryTaskExecutor())
393+
.getCorePoolSize(),
394+
"Should use the custom value when a valid integer is provided");
395+
}
396+
}
397+
366398
@Test
367399
public void testBigQueryReadClientKeepAliveSettings() throws SQLException, IOException {
368400
String url =

java-bigquery/google-cloud-bigquery-jdbc/src/test/java/com/google/cloud/bigquery/jdbc/BigQueryStatementTest.java

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,14 @@
6060
import java.util.Map;
6161
import java.util.UUID;
6262
import java.util.concurrent.BlockingQueue;
63+
import java.util.concurrent.ExecutorService;
64+
import java.util.concurrent.Executors;
6365
import org.apache.arrow.memory.RootAllocator;
6466
import org.apache.arrow.vector.BitVector;
6567
import org.apache.arrow.vector.FieldVector;
6668
import org.apache.arrow.vector.IntVector;
6769
import org.apache.arrow.vector.VectorSchemaRoot;
70+
import org.junit.jupiter.api.AfterEach;
6871
import org.junit.jupiter.api.BeforeEach;
6972
import org.junit.jupiter.api.Disabled;
7073
import org.junit.jupiter.api.Test;
@@ -84,6 +87,8 @@ public class BigQueryStatementTest {
8487

8588
private BigQueryStatement bigQueryStatement;
8689

90+
private ExecutorService queryTaskExecutor;
91+
8792
private final String query = "select * from test";
8893

8994
private final String jobIdVal = UUID.randomUUID().toString();
@@ -126,13 +131,16 @@ private Job getJobMock(
126131

127132
@BeforeEach
128133
public void setUp() throws IOException, SQLException {
134+
queryTaskExecutor = Executors.newFixedThreadPool(1);
129135
bigQueryConnection = mock(BigQueryConnection.class);
130136
rpcFactoryMock = mock(BigQueryRpcFactory.class);
131137
bigquery = mock(BigQuery.class);
132138
bigQueryConnection.bigQuery = bigquery;
133139
storageReadClient = mock(BigQueryReadClient.class);
134140
jobId = JobId.newBuilder().setJob(jobIdVal).build();
135141

142+
doReturn(queryTaskExecutor).when(bigQueryConnection).getQueryTaskExecutor();
143+
136144
doReturn(bigquery).when(bigQueryConnection).getBigQuery();
137145
doReturn(10L).when(bigQueryConnection).getJobTimeoutInSeconds();
138146
doReturn(10L).when(bigQueryConnection).getMaxBytesBilled();
@@ -148,7 +156,13 @@ public void setUp() throws IOException, SQLException {
148156
.setSerializedSchema(serializeSchema(vectorSchemaRoot.getSchema()))
149157
.build();
150158
// bigQueryConnection.addOpenStatements(bigQueryStatement);
159+
}
151160

161+
@AfterEach
162+
public void tearDown() {
163+
if (queryTaskExecutor != null) {
164+
queryTaskExecutor.shutdown();
165+
}
152166
}
153167

154168
private VectorSchemaRoot getTestVectorSchemaRoot() {
@@ -303,8 +317,13 @@ public void setQueryTimeoutTest() throws Exception {
303317
ArgumentCaptor<JobInfo> captor = ArgumentCaptor.forClass(JobInfo.class);
304318

305319
bigQueryStatementSpy.runQuery(query, jobConfiguration);
306-
verify(bigquery).create(captor.capture());
307-
QueryJobConfiguration jobConfig = captor.getValue().getConfiguration();
320+
verify(bigquery, Mockito.times(2)).create(captor.capture());
321+
QueryJobConfiguration jobConfig =
322+
captor.getAllValues().stream()
323+
.map(jobInfo -> (QueryJobConfiguration) jobInfo.getConfiguration())
324+
.filter(config -> config.dryRun() == null || !config.dryRun())
325+
.findFirst()
326+
.get();
308327
assertEquals(3000L, jobConfig.getJobTimeoutMs().longValue());
309328
}
310329

@@ -401,10 +420,10 @@ public void testJoblessQuery() throws SQLException, InterruptedException {
401420

402421
jobfulStatementSpy.executeQuery("SELECT 1");
403422

404-
verify(bigquery).create(any(JobInfo.class));
423+
verify(bigquery, Mockito.times(2)).create(any(JobInfo.class));
405424
assertTrue(
406425
jobfulCaptor.getAllValues().stream()
407-
.noneMatch(
426+
.anyMatch(
408427
jobInfo ->
409428
Boolean.TRUE.equals(
410429
((QueryJobConfiguration) jobInfo.getConfiguration()).dryRun())));

0 commit comments

Comments
 (0)