diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 89acbc938b..05bfa222ad 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -9,6 +9,7 @@ - Added a client property `enableVolumeOperations` to enable GET/PUT/REMOVE volume operations on a stream. For backward compatibility, allowedVolumeIngestionPaths can also be used for REMOVE operation. - Support for fetching schemas across all catalogs (when catalog is specified as null or a wildcard) in `DatabaseMetaData#getSchemas` API in SQL Execution mode. - **Configurable SQL validation in isValid()**: Added `EnableSQLValidationForIsValid` connection property to control whether `isValid()` method executes an actual SQL query for server-side validation. Default value is 0. +- Implement multi-row INSERT batching optimization for prepared statements to improve performance when executing large batches of INSERT operations. ### Updated - Databricks SDK dependency upgraded to latest version 0.60.0 diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java index e2fe52c771..a699d0d1ef 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java @@ -1025,4 +1025,9 @@ public int getTelemetryFlushIntervalInMilliseconds() { return Math.max( 1000, Integer.parseInt(getParameter(DatabricksJdbcUrlParams.TELEMETRY_FLUSH_INTERVAL))); } + + @Override + public boolean isBatchedInsertsEnabled() { + return getParameter(DatabricksJdbcUrlParams.ENABLE_BATCHED_INSERTS).equals("1"); + } } diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksPreparedStatement.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksPreparedStatement.java index 93f8dc5556..1fb8158cdc 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksPreparedStatement.java +++ 
b/src/main/java/com/databricks/jdbc/api/impl/DatabricksPreparedStatement.java @@ -7,8 +7,10 @@ import static com.databricks.jdbc.common.util.SQLInterpolator.surroundPlaceholdersWithQuotes; import static com.databricks.jdbc.common.util.ValidationUtil.throwErrorIfNull; +import com.databricks.jdbc.common.DatabricksJdbcConstants; import com.databricks.jdbc.common.StatementType; import com.databricks.jdbc.common.util.DatabricksTypeUtil; +import com.databricks.jdbc.common.util.InsertStatementParser; import com.databricks.jdbc.exception.*; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; @@ -88,6 +90,138 @@ public int[] executeBatch() throws DatabricksBatchUpdateException { @Override public long[] executeLargeBatch() throws DatabricksBatchUpdateException { LOGGER.debug("public long executeLargeBatch()"); + + if (databricksBatchParameterMetaData.isEmpty()) { + return new long[0]; + } + + // Try to optimize INSERT statements with multi-row batching + if (canUseBatchedInsert()) { + return executeBatchedInsert(); + } else { + // Fall back to individual execution for non-INSERT or incompatible statements + return executeIndividualStatements(); + } + } + + /** + * Checks if the current batch can be optimized using multi-row INSERT. All statements must be + * compatible INSERT operations. + * + *

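A batch is eligible for multi-row INSERT optimization when: + * + *

+ * - the EnableBatchedInserts connection property is set to 1; + * - the statement text is accepted by InsertStatementParser.parseInsertStrict, i.e. it has the form INSERT INTO table (columns) VALUES (...); + * - at least one parameter set has been added to the batch. + * + *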

Compatible INSERT operations target the same table with the same columns in the same order. + * When compatible, multiple individual INSERTs like: + * + *

+   *   INSERT INTO users (id, name) VALUES (?, ?)  -- with parameters [1, "Alice"]
+   *   INSERT INTO users (id, name) VALUES (?, ?)  -- with parameters [2, "Bob"]
+   * 
+ * + * are combined into a single multi-row INSERT: + * + *
+   *   INSERT INTO users (id, name) VALUES (?, ?), (?, ?)  -- with parameters [1, "Alice", 2, "Bob"]
+   * 
+ */ + private boolean canUseBatchedInsert() { + // Check if batched inserts are enabled via connection property + if (!connection.getConnectionContext().isBatchedInsertsEnabled()) { + return false; + } + + // Use strict exception-based parsing for better error handling + try { + InsertStatementParser.parseInsertStrict(sql); + return !databricksBatchParameterMetaData.isEmpty(); + } catch (Exception e) { + // Not a valid INSERT statement suitable for batching + return false; + } + } + + /** Executes the batch as a single multi-row INSERT statement. */ + private long[] executeBatchedInsert() throws DatabricksBatchUpdateException { + LOGGER.debug("Executing batched INSERT with {} rows", databricksBatchParameterMetaData.size()); + + try { + InsertStatementParser.InsertInfo insertInfo = InsertStatementParser.parseInsertStrict(sql); + + // Calculate how many rows we can fit in one chunk based on parameter limit + int parametersPerRow = insertInfo.getColumnCount(); + int maxRowsPerChunk = DatabricksJdbcConstants.MAX_QUERY_PARAMETERS / parametersPerRow; + + // Ensure we have at least 1 row per chunk + if (maxRowsPerChunk < 1) { + maxRowsPerChunk = 1; + } + + long[] allUpdateCounts = new long[databricksBatchParameterMetaData.size()]; + int processedRows = 0; + + // Process batches in chunks + for (int startIndex = 0; + startIndex < databricksBatchParameterMetaData.size(); + startIndex += maxRowsPerChunk) { + int endIndex = + Math.min(startIndex + maxRowsPerChunk, databricksBatchParameterMetaData.size()); + int chunkSize = endIndex - startIndex; + + LOGGER.debug("Processing chunk {}-{} ({} rows)", startIndex + 1, endIndex, chunkSize); + + // Generate multi-row SQL for this chunk + String multiRowSql = InsertStatementParser.generateMultiRowInsert(insertInfo, chunkSize); + + // Combine parameters for this chunk + Map chunkParams = new HashMap<>(); + int paramIndex = 1; + + for (int i = startIndex; i < endIndex; i++) { + DatabricksParameterMetaData batchParams = 
databricksBatchParameterMetaData.get(i); + Map rowParams = batchParams.getParameterBindings(); + for (int j = 1; j <= rowParams.size(); j++) { + if (rowParams.containsKey(j)) { + chunkParams.put(paramIndex++, rowParams.get(j)); + } + } + } + + // Execute this chunk + executeInternal(multiRowSql, chunkParams, StatementType.UPDATE, false); + + // Set update counts for this chunk (each row typically affects 1 row) + for (int i = startIndex; i < endIndex; i++) { + allUpdateCounts[i] = 1; + } + + processedRows += chunkSize; + } + + LOGGER.debug("Successfully processed {} rows in chunks", processedRows); + return allUpdateCounts; + + } catch (Exception e) { + LOGGER.error("Error executing batched INSERT: {}", e.getMessage(), e); + long[] failedCounts = new long[databricksBatchParameterMetaData.size()]; + for (int i = 0; i < failedCounts.length; i++) { + failedCounts[i] = Statement.EXECUTE_FAILED; + } + throw new DatabricksBatchUpdateException( + e.getMessage(), DatabricksDriverErrorCode.BATCH_EXECUTE_EXCEPTION, failedCounts); + } + } + + /** Executes batch statements individually (fallback method). 
*/ + private long[] executeIndividualStatements() throws DatabricksBatchUpdateException { + LOGGER.debug( + "Executing batch individually with {} statements", databricksBatchParameterMetaData.size()); long[] largeUpdateCount = new long[databricksBatchParameterMetaData.size()]; for (int sqlQueryIndex = 0; diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksStatement.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksStatement.java index da5d82d5fc..091b6bd20d 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksStatement.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksStatement.java @@ -692,6 +692,14 @@ static boolean isSelectQuery(String query) { return SELECT_PATTERN.matcher(trimmedQuery).find(); } + static boolean isInsertQuery(String query) { + if (query == null || query.trim().isEmpty()) { + return false; + } + String trimmedQuery = trimCommentsAndWhitespaces(query); + return INSERT_PATTERN.matcher(trimmedQuery).find(); + } + DatabricksResultSet executeInternal( String sql, Map params, diff --git a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java index 1b22e1bdd4..98ab1fc797 100644 --- a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java @@ -353,4 +353,7 @@ public interface IDatabricksConnectionContext { /** Returns the HTTP connection request timeout in seconds */ Integer getHttpConnectionRequestTimeout(); + + /** Returns whether batched INSERT optimization is enabled */ + boolean isBatchedInsertsEnabled(); } diff --git a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcConstants.java b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcConstants.java index bbcacb1277..792768de6a 100644 --- a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcConstants.java +++ 
b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcConstants.java @@ -166,6 +166,12 @@ public enum FakeServiceType { Pattern.compile("^(\\s*\\()*\\s*REMOVE", Pattern.CASE_INSENSITIVE); public static final Pattern LIST_PATTERN = Pattern.compile("^(\\s*\\()*\\s*LIST", Pattern.CASE_INSENSITIVE); + public static final Pattern INSERT_PATTERN = + Pattern.compile("^(\\s*\\()*\\s*INSERT\\s+INTO", Pattern.CASE_INSENSITIVE); + + /** Maximum number of parameters allowed in a single Databricks query */ + public static final int MAX_QUERY_PARAMETERS = 256; + // Regex: match queries starting with "BEGIN" but not followed by "TRANSACTION" // (?i) -> case-insensitive // ^\s*BEGIN -> string starts with BEGIN (allow leading whitespace) diff --git a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java index d0b04dee37..973f96a11d 100644 --- a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java +++ b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java @@ -154,6 +154,7 @@ public enum DatabricksJdbcUrlParams { "HttpConnectionRequestTimeout", "HTTP connection request timeout in seconds"), CLOUD_FETCH_SPEED_THRESHOLD( "CloudFetchSpeedThreshold", "Minimum expected download speed in MB/s", "0.1"), + ENABLE_BATCHED_INSERTS("EnableBatchedInserts", "Enable batched INSERT optimization", "0"), ENABLE_SQL_VALIDATION_FOR_IS_VALID( "EnableSQLValidationForIsValid", "Enable SQL query execution for connection validation in isValid() method", diff --git a/src/main/java/com/databricks/jdbc/common/util/InsertStatementParser.java b/src/main/java/com/databricks/jdbc/common/util/InsertStatementParser.java new file mode 100644 index 0000000000..050cc08020 --- /dev/null +++ b/src/main/java/com/databricks/jdbc/common/util/InsertStatementParser.java @@ -0,0 +1,215 @@ +package com.databricks.jdbc.common.util; + +import static 
com.databricks.jdbc.common.DatabricksJdbcConstants.INSERT_PATTERN; + +import com.databricks.jdbc.exception.DatabricksParsingException; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** + * Utility class for parsing INSERT statements to extract table and column information. Supports + * detecting compatible INSERT statements that can be combined into multi-row batches. + */ +public class InsertStatementParser { + + // Pattern to extract table and columns from INSERT INTO table (col1, col2, ...) VALUES format + private static final Pattern INSERT_DETAILS_PATTERN = + Pattern.compile( + "^\\s*INSERT\\s+INTO\\s+([\\w`\\.]+)\\s*\\(([^)]+)\\)\\s+VALUES\\s*\\(", + Pattern.CASE_INSENSITIVE | Pattern.DOTALL); + + /** Represents the parsed components of an INSERT statement. */ + public static class InsertInfo { + private final String tableName; + private final List columns; + private final String originalSql; + + public InsertInfo(String tableName, List columns, String originalSql) { + this.tableName = tableName; + this.columns = columns; + this.originalSql = originalSql; + } + + public String getTableName() { + return tableName; + } + + public List getColumns() { + return columns; + } + + public String getOriginalSql() { + return originalSql; + } + + public int getColumnCount() { + return columns.size(); + } + + /** + * Checks if this INSERT is compatible with another INSERT for batching. Two INSERTs are + * compatible if they target the same table with the same columns in the same order. + * + *

Compatible INSERT operations can be combined into multi-row INSERT statements for improved + * performance. For example, these two statements are compatible: + * + *

+     *   INSERT INTO users (id, name, email) VALUES (?, ?, ?)
+     *   INSERT INTO users (id, name, email) VALUES (?, ?, ?)
+     * 
+ * + * These can be batched into: + * + *
+     *   INSERT INTO users (id, name, email) VALUES (?, ?, ?), (?, ?, ?)
+     * 
+ */ + public boolean isCompatibleWith(InsertInfo other) { + return Objects.equals(this.tableName, other.tableName) + && Objects.equals(this.columns, other.columns); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InsertInfo that = (InsertInfo) o; + return Objects.equals(tableName, that.tableName) && Objects.equals(columns, that.columns); + } + + @Override + public int hashCode() { + return Objects.hash(tableName, columns); + } + } + + /** + * Parses an INSERT statement to extract table and column information. + * + * @param sql the INSERT SQL statement to parse + * @return InsertInfo object containing parsed information, or null if not a valid INSERT + */ + public static InsertInfo parseInsert(String sql) { + try { + return parseInsertStrict(sql); + } catch (DatabricksParsingException e) { + return null; + } + } + + /** + * Parses an INSERT statement to extract table and column information with strict error handling. 
+ * + * @param sql the INSERT SQL statement to parse + * @return InsertInfo object containing parsed information + * @throws DatabricksParsingException if the SQL is not a properly formatted INSERT statement + */ + public static InsertInfo parseInsertStrict(String sql) throws DatabricksParsingException { + if (sql == null || sql.trim().isEmpty()) { + throw new DatabricksParsingException( + "SQL statement cannot be null or empty", + DatabricksDriverErrorCode.INPUT_VALIDATION_ERROR); + } + + String trimmedSql = sql.trim(); + + // First check if it's an INSERT query using the shared pattern + if (!INSERT_PATTERN.matcher(trimmedSql).find()) { + throw new DatabricksParsingException( + "SQL statement is not an INSERT operation: " + trimmedSql, + DatabricksDriverErrorCode.INPUT_VALIDATION_ERROR); + } + + // Then extract detailed information using our specific pattern + Matcher matcher = INSERT_DETAILS_PATTERN.matcher(trimmedSql); + + if (!matcher.find()) { + throw new DatabricksParsingException( + "INSERT statement does not match the expected format 'INSERT INTO table (columns) VALUES': " + + trimmedSql, + DatabricksDriverErrorCode.INPUT_VALIDATION_ERROR); + } + + String tableName = matcher.group(1).trim(); + String columnsStr = matcher.group(2).trim(); + + // Parse column names, handling quoted identifiers and whitespace + List columns = parseColumns(columnsStr); + + if (columns.isEmpty()) { + throw new DatabricksParsingException( + "INSERT statement does not contain any valid column names: " + trimmedSql, + DatabricksDriverErrorCode.INPUT_VALIDATION_ERROR); + } + + return new InsertInfo(tableName, columns, trimmedSql); + } + + /** Parses a comma-separated list of column names, handling quoted identifiers. 
*/ + private static List parseColumns(String columnsStr) { + return Arrays.stream(columnsStr.split(",")) + .map(String::trim) + .map(col -> col.replaceAll("^`|`$", "")) // Remove backticks if present + .filter(col -> !col.isEmpty()) + .collect(Collectors.toList()); + } + + /** + * Checks if the given SQL statement is a parametrized INSERT statement suitable for batching. + * + * @param sql the SQL statement to check + * @return true if it's a parametrized INSERT that can be batched, false otherwise + */ + public static boolean isParametrizedInsert(String sql) { + // Use the shared INSERT pattern for efficient detection + if (sql == null || !INSERT_PATTERN.matcher(sql.trim()).find()) { + return false; + } + return sql.contains("?"); + } + + /** + * Generates a multi-row INSERT statement from the template and number of rows. + * + * @param insertInfo the parsed INSERT information + * @param numberOfRows the number of rows to include in the batch + * @return the multi-row INSERT SQL statement + * @throws DatabricksParsingException if insertInfo is null or numberOfRows is invalid + */ + public static String generateMultiRowInsert(InsertInfo insertInfo, int numberOfRows) + throws DatabricksParsingException { + if (insertInfo == null) { + throw new DatabricksParsingException( + "InsertInfo cannot be null", DatabricksDriverErrorCode.INPUT_VALIDATION_ERROR); + } + if (numberOfRows <= 0) { + throw new DatabricksParsingException( + "Number of rows must be positive, got: " + numberOfRows, + DatabricksDriverErrorCode.INPUT_VALIDATION_ERROR); + } + + StringBuilder sql = new StringBuilder(); + sql.append("INSERT INTO ") + .append(insertInfo.getTableName()) + .append(" (") + .append(String.join(", ", insertInfo.getColumns())) + .append(") VALUES "); + + // Generate placeholders for each row + String valueClause = "(" + "?, ".repeat(insertInfo.getColumns().size() - 1) + "?)"; + + for (int i = 0; i < numberOfRows; i++) { + if (i > 0) { + sql.append(", "); + } + 
sql.append(valueClause); + } + + return sql.toString(); + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/DatabricksPreparedStatementTest.java b/src/test/java/com/databricks/jdbc/api/impl/DatabricksPreparedStatementTest.java index 4cdf8650f4..1b3cad75bd 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/DatabricksPreparedStatementTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/DatabricksPreparedStatementTest.java @@ -5,6 +5,7 @@ import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -43,6 +44,8 @@ public class DatabricksPreparedStatementTest { "INSERT INTO orders (user_id, shard, region_code, namespace) VALUES (?, ?, ?, ?)"; private static final String JDBC_URL = "jdbc:databricks://sample-host.18.azuredatabricks.net:4423/default;transportMode=http;ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/99999999;"; + private static final String JDBC_URL_WITH_BATCHED_INSERTS = + "jdbc:databricks://sample-host.18.azuredatabricks.net:4423/default;transportMode=http;ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/99999999;EnableBatchedInserts=1;"; private static final String JDBC_URL_WITH_MANY_PARAMETERS = "jdbc:databricks://sample-host.18.azuredatabricks.net:4423/default;transportMode=http;ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/99999999;supportManyParameters=1;"; private static final String JDBC_CLUSTER_URL_WITH_MANY_PARAMETERS = @@ -215,7 +218,7 @@ public void testExecuteLargeUpdateStatement() throws Exception { @Test public void testExecuteBatchStatement() throws Exception { IDatabricksConnectionContext connectionContext = - DatabricksConnectionContext.parse(JDBC_URL, new Properties()); + DatabricksConnectionContext.parse(JDBC_URL_WITH_BATCHED_INSERTS, new Properties()); DatabricksConnection connection = new 
DatabricksConnection(connectionContext, client); DatabricksPreparedStatement statement = new DatabricksPreparedStatement(connection, BATCH_STATEMENT); @@ -227,15 +230,20 @@ public void testExecuteBatchStatement() throws Exception { statement.setString(4, "value"); statement.addBatch(); } + // Our implementation converts single INSERT to multi-row INSERT for batching + String expectedMultiRowSQL = + "INSERT INTO orders (user_id, shard, region_code, namespace) VALUES (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?)"; when(client.executeStatement( - eq(BATCH_STATEMENT), + eq(expectedMultiRowSQL), eq(new Warehouse(WAREHOUSE_ID)), any(HashMap.class), eq(StatementType.UPDATE), any(IDatabricksSession.class), eq(statement))) .thenReturn(resultSet); - when(resultSet.getUpdateCount()).thenReturn(1L); + lenient() + .when(resultSet.getUpdateCount()) + .thenReturn(4L); // Multi-row INSERT returns total rows affected int[] expectedCountsResult = {1, 1, 1, 1}; int[] updateCounts = statement.executeBatch(); @@ -267,7 +275,7 @@ public void testGetMetaData_NoResultSet_NonSelectQuery_ReturnNull() throws Excep @Test public void testExecuteBatchStatementThrowsError() throws Exception { IDatabricksConnectionContext connectionContext = - DatabricksConnectionContext.parse(JDBC_URL, new Properties()); + DatabricksConnectionContext.parse(JDBC_URL_WITH_BATCHED_INSERTS, new Properties()); DatabricksConnection connection = new DatabricksConnection(connectionContext, client); DatabricksPreparedStatement statement = new DatabricksPreparedStatement(connection, BATCH_STATEMENT); @@ -280,26 +288,24 @@ public void testExecuteBatchStatementThrowsError() throws Exception { statement.addBatch(); } - // First call succeeds, subsequent calls fail + // Our implementation batches all into one multi-row INSERT, so if it fails, all fail + String expectedMultiRowSQL = + "INSERT INTO orders (user_id, shard, region_code, namespace) VALUES (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?)"; 
when(client.executeStatement( - eq(BATCH_STATEMENT), + eq(expectedMultiRowSQL), eq(new Warehouse(WAREHOUSE_ID)), any(HashMap.class), eq(StatementType.UPDATE), any(IDatabricksSession.class), eq(statement))) - .thenReturn(resultSet) .thenThrow(new SQLException()); - when(resultSet.getUpdateCount()).thenReturn(1L); DatabricksBatchUpdateException exception = assertThrows(DatabricksBatchUpdateException.class, statement::executeBatch); int[] updateCounts = exception.getUpdateCounts(); assertEquals(4, updateCounts.length); - // First statement should succeed - assertEquals(1, updateCounts[0]); - // Remaining statements should fail - for (int i = 1; i < 4; i++) { + // All statements should fail since they're batched into one multi-row INSERT + for (int i = 0; i < 4; i++) { assertEquals(Statement.EXECUTE_FAILED, updateCounts[i]); } } @@ -307,7 +313,7 @@ public void testExecuteBatchStatementThrowsError() throws Exception { @Test public void testExecuteLargeBatchStatement() throws Exception { IDatabricksConnectionContext connectionContext = - DatabricksConnectionContext.parse(JDBC_URL, new Properties()); + DatabricksConnectionContext.parse(JDBC_URL_WITH_BATCHED_INSERTS, new Properties()); DatabricksConnection connection = new DatabricksConnection(connectionContext, client); DatabricksPreparedStatement statement = new DatabricksPreparedStatement(connection, BATCH_STATEMENT); @@ -319,15 +325,20 @@ public void testExecuteLargeBatchStatement() throws Exception { statement.setString(4, "value"); statement.addBatch(); } + // Our implementation converts single INSERT to multi-row INSERT for batching + String expectedMultiRowSQL = + "INSERT INTO orders (user_id, shard, region_code, namespace) VALUES (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?)"; when(client.executeStatement( - eq(BATCH_STATEMENT), + eq(expectedMultiRowSQL), eq(new Warehouse(WAREHOUSE_ID)), any(HashMap.class), eq(StatementType.UPDATE), any(IDatabricksSession.class), eq(statement))) .thenReturn(resultSet); - 
when(resultSet.getUpdateCount()).thenReturn(1L); + lenient() + .when(resultSet.getUpdateCount()) + .thenReturn(4L); // Multi-row INSERT returns total rows affected long[] expectedCountsResult = {1, 1, 1, 1}; long[] updateCounts = statement.executeLargeBatch(); @@ -340,7 +351,7 @@ public void testExecuteLargeBatchStatement() throws Exception { @Test public void testExecuteLargeBatchStatementThrowsError() throws Exception { IDatabricksConnectionContext connectionContext = - DatabricksConnectionContext.parse(JDBC_URL, new Properties()); + DatabricksConnectionContext.parse(JDBC_URL_WITH_BATCHED_INSERTS, new Properties()); DatabricksConnection connection = new DatabricksConnection(connectionContext, client); DatabricksPreparedStatement statement = new DatabricksPreparedStatement(connection, BATCH_STATEMENT); @@ -353,26 +364,24 @@ public void testExecuteLargeBatchStatementThrowsError() throws Exception { statement.addBatch(); } - // First call succeeds, subsequent calls fail + // Our implementation batches all into one multi-row INSERT, so if it fails, all fail + String expectedMultiRowSQL = + "INSERT INTO orders (user_id, shard, region_code, namespace) VALUES (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?)"; when(client.executeStatement( - eq(BATCH_STATEMENT), + eq(expectedMultiRowSQL), eq(new Warehouse(WAREHOUSE_ID)), any(HashMap.class), eq(StatementType.UPDATE), any(IDatabricksSession.class), eq(statement))) - .thenReturn(resultSet) .thenThrow(new SQLException()); - when(resultSet.getUpdateCount()).thenReturn(1L); DatabricksBatchUpdateException exception = assertThrows(DatabricksBatchUpdateException.class, statement::executeLargeBatch); long[] updateCounts = exception.getLargeUpdateCounts(); assertEquals(4, updateCounts.length); - // First statement should succeed - assertEquals(1, updateCounts[0]); - // Remaining statements should fail - for (int i = 1; i < 4; i++) { + // All statements should fail since they're batched into one multi-row INSERT + for (int i = 
0; i < 4; i++) { assertEquals(Statement.EXECUTE_FAILED, updateCounts[i]); } } @@ -635,6 +644,172 @@ public void testSetCharacterStreamWithoutLength() throws DatabricksSQLException assertDoesNotThrow(() -> preparedStatement.setCharacterStream(1, characterStream)); } + @Test + public void testExecuteLargeBatchWithParameterChunking() throws Exception { + // Test scenario that would exceed the 256 parameter limit and verify chunking works + // 5 columns × 60 rows = 300 parameters (exceeds 256 limit) + // Should be split into chunks: 51 rows + 9 rows (51 = 255/5, leaving 1 parameter short for + // safety) + + String largeBatchStatement = + "INSERT INTO products (id, name, price, category, description) VALUES (?, ?, ?, ?, ?)"; + IDatabricksConnectionContext connectionContext = + DatabricksConnectionContext.parse(JDBC_URL_WITH_BATCHED_INSERTS, new Properties()); + DatabricksConnection connection = new DatabricksConnection(connectionContext, client); + DatabricksPreparedStatement statement = + new DatabricksPreparedStatement(connection, largeBatchStatement); + + // Add 60 batches (5 columns each = 300 total parameters) + int totalBatches = 60; + for (int i = 1; i <= totalBatches; i++) { + statement.setInt(1, i); // id + statement.setString(2, "Product " + i); // name + statement.setBigDecimal(3, new BigDecimal("19.99")); // price + statement.setString(4, "Category " + (i % 5)); // category + statement.setString(5, "Description for product " + i); // description + statement.addBatch(); + } + + // Mock client to verify chunking behavior + // With 5 columns, max rows per chunk = 256/5 = 51 rows + // So 60 total rows should be split into 2 chunks: 51 + 9 + when(client.executeStatement( + any(String.class), // SQL will vary based on chunk size + eq(new Warehouse(WAREHOUSE_ID)), + any(HashMap.class), + eq(StatementType.UPDATE), + any(IDatabricksSession.class), + eq(statement))) + .thenReturn(resultSet); + lenient() + .when(resultSet.getUpdateCount()) + .thenReturn(51L) // First 
chunk: 51 rows + .thenReturn(9L); // Second chunk: 9 rows + + long[] updateCounts = statement.executeLargeBatch(); + + // Verify results + assertEquals(totalBatches, updateCounts.length); + + // All update counts should be 1 (each row affects 1 row) + for (int i = 0; i < totalBatches; i++) { + assertEquals(1, updateCounts[i], "Update count for batch " + i + " should be 1"); + } + + assertFalse(statement.isClosed()); + statement.close(); + assertTrue(statement.isClosed()); + } + + @Test + public void testExecuteLargeBatchWithManyColumnsChunking() throws Exception { + // Test edge case with very wide table that forces 1 row per chunk + // 300 columns would result in 0 rows per chunk calculation, should default to 1 + + StringBuilder largeSqlBuilder = new StringBuilder("INSERT INTO wide_table ("); + StringBuilder valuesBuilder = new StringBuilder("("); + + // Create SQL with 300 columns + int columnCount = 300; + for (int i = 1; i <= columnCount; i++) { + if (i > 1) { + largeSqlBuilder.append(", "); + valuesBuilder.append(", "); + } + largeSqlBuilder.append("col").append(i); + valuesBuilder.append("?"); + } + largeSqlBuilder.append(") VALUES ").append(valuesBuilder).append(")"); + + String wideTableStatement = largeSqlBuilder.toString(); + IDatabricksConnectionContext connectionContext = + DatabricksConnectionContext.parse(JDBC_URL_WITH_BATCHED_INSERTS, new Properties()); + DatabricksConnection connection = new DatabricksConnection(connectionContext, client); + DatabricksPreparedStatement statement = + new DatabricksPreparedStatement(connection, wideTableStatement); + + // Add 3 batches - each should be executed separately due to parameter limit + int totalBatches = 3; + for (int batchNum = 1; batchNum <= totalBatches; batchNum++) { + // Set all 300 parameters for this batch + for (int col = 1; col <= columnCount; col++) { + statement.setString(col, "value_" + batchNum + "_" + col); + } + statement.addBatch(); + } + + // Mock client - each batch should be executed 
individually due to parameter limit + when(client.executeStatement( + any(String.class), + eq(new Warehouse(WAREHOUSE_ID)), + any(HashMap.class), + eq(StatementType.UPDATE), + any(IDatabricksSession.class), + eq(statement))) + .thenReturn(resultSet); + lenient().when(resultSet.getUpdateCount()).thenReturn(1L); // Each execution affects 1 row + + long[] updateCounts = statement.executeLargeBatch(); + + // Verify results + assertEquals(totalBatches, updateCounts.length); + for (int i = 0; i < totalBatches; i++) { + assertEquals(1, updateCounts[i], "Update count for batch " + i + " should be 1"); + } + + assertFalse(statement.isClosed()); + statement.close(); + assertTrue(statement.isClosed()); + } + + @Test + public void testExecuteLargeBatchParameterChunkingOptimization() throws Exception { + // Test that we're actually getting the chunking optimization vs individual execution + // Use a 2-column table with 200 rows = 400 parameters (exceeds 256 limit) + // Should be chunked into: 128 rows + 72 rows (128 = 256/2) + + String simpleStatement = "INSERT INTO users (id, name) VALUES (?, ?)"; + IDatabricksConnectionContext connectionContext = + DatabricksConnectionContext.parse(JDBC_URL_WITH_BATCHED_INSERTS, new Properties()); + DatabricksConnection connection = new DatabricksConnection(connectionContext, client); + DatabricksPreparedStatement statement = + new DatabricksPreparedStatement(connection, simpleStatement); + + // Add 200 batches (2 columns each = 400 total parameters) + int totalBatches = 200; + for (int i = 1; i <= totalBatches; i++) { + statement.setInt(1, i); + statement.setString(2, "User " + i); + statement.addBatch(); + } + + // Mock the client to capture the generated SQL + when(client.executeStatement( + any(String.class), + eq(new Warehouse(WAREHOUSE_ID)), + any(HashMap.class), + eq(StatementType.UPDATE), + any(IDatabricksSession.class), + eq(statement))) + .thenReturn(resultSet); + lenient() + .when(resultSet.getUpdateCount()) + .thenReturn(128L) + 
.thenReturn(72L); // Two chunks: 128 + 72 + + long[] updateCounts = statement.executeLargeBatch(); + + // Verify results + assertEquals(totalBatches, updateCounts.length); + for (int i = 0; i < totalBatches; i++) { + assertEquals(1, updateCounts[i], "Update count for batch " + i + " should be 1"); + } + + assertFalse(statement.isClosed()); + statement.close(); + assertTrue(statement.isClosed()); + } + @Test void testUnsupportedMethods() throws DatabricksSQLException { IDatabricksConnectionContext connectionContext = diff --git a/src/test/java/com/databricks/jdbc/api/impl/DatabricksStatementTest.java b/src/test/java/com/databricks/jdbc/api/impl/DatabricksStatementTest.java index 81ac51a899..f1a897e1cb 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/DatabricksStatementTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/DatabricksStatementTest.java @@ -666,6 +666,35 @@ public void testIsSelectQuery() { assertFalse(DatabricksStatement.isSelectQuery(query)); } + @Test + public void testIsInsertQuery() { + // Test basic INSERT statements + assertTrue(DatabricksStatement.isInsertQuery("INSERT INTO users (id, name) VALUES (?, ?)")); + assertTrue(DatabricksStatement.isInsertQuery("insert into users (id, name) values (?, ?)")); + assertTrue( + DatabricksStatement.isInsertQuery( + " INSERT INTO users (id, name) VALUES (?, ?)")); + + // Test INSERT with comments + String queryWithComments = + "-- Comment\n/* Multi-line */ INSERT INTO users (id) VALUES (?); -- End"; + assertTrue(DatabricksStatement.isInsertQuery(queryWithComments)); + + // Test non-INSERT statements + assertFalse(DatabricksStatement.isInsertQuery("SELECT * FROM users")); + assertFalse(DatabricksStatement.isInsertQuery("UPDATE users SET name = ?")); + assertFalse(DatabricksStatement.isInsertQuery("DELETE FROM users")); + assertFalse(DatabricksStatement.isInsertQuery("CREATE TABLE users (id INT)")); + assertFalse(DatabricksStatement.isInsertQuery("")); + 
assertFalse(DatabricksStatement.isInsertQuery(null)); + + // Test INSERT with schema prefix + assertTrue(DatabricksStatement.isInsertQuery("INSERT INTO schema.users (id) VALUES (?)")); + + // Test with parentheses at the beginning + assertTrue(DatabricksStatement.isInsertQuery("(INSERT INTO users (id) VALUES (?))")); + } + private DatabricksConnection getTestConnection() throws DatabricksSQLException { IDatabricksConnectionContext connectionContext = DatabricksConnectionContext.parse(JDBC_URL, new Properties()); diff --git a/src/test/java/com/databricks/jdbc/common/util/InsertStatementParserTest.java b/src/test/java/com/databricks/jdbc/common/util/InsertStatementParserTest.java new file mode 100644 index 0000000000..fbacca0379 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/common/util/InsertStatementParserTest.java @@ -0,0 +1,269 @@ +package com.databricks.jdbc.common.util; + +import static org.junit.jupiter.api.Assertions.*; + +import com.databricks.jdbc.common.util.InsertStatementParser.InsertInfo; +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; + +class InsertStatementParserTest { + + @Test + void testParseBasicInsert() { + String sql = "INSERT INTO users (id, name, email) VALUES (?, ?, ?)"; + InsertInfo info = InsertStatementParser.parseInsert(sql); + + assertNotNull(info); + assertEquals("users", info.getTableName()); + assertEquals(Arrays.asList("id", "name", "email"), info.getColumns()); + assertEquals(sql, info.getOriginalSql()); + } + + @Test + void testParseInsertWithWhitespace() { + String sql = " INSERT INTO users ( id , name , email ) VALUES ( ?, ?, ? 
)"; + InsertInfo info = InsertStatementParser.parseInsert(sql); + + assertNotNull(info); + assertEquals("users", info.getTableName()); + assertEquals(Arrays.asList("id", "name", "email"), info.getColumns()); + } + + @Test + void testParseInsertWithBackticks() { + String sql = "INSERT INTO `my_table` (`id`, `user_name`, `email_address`) VALUES (?, ?, ?)"; + InsertInfo info = InsertStatementParser.parseInsert(sql); + + assertNotNull(info); + assertEquals("`my_table`", info.getTableName()); + assertEquals(Arrays.asList("id", "user_name", "email_address"), info.getColumns()); + } + + @Test + void testParseInsertWithSchemaPrefix() { + String sql = "INSERT INTO schema.users (id, name) VALUES (?, ?)"; + InsertInfo info = InsertStatementParser.parseInsert(sql); + + assertNotNull(info); + assertEquals("schema.users", info.getTableName()); + assertEquals(Arrays.asList("id", "name"), info.getColumns()); + } + + @Test + void testParseInsertCaseInsensitive() { + String sql = "insert into Users (ID, Name) values (?, ?)"; + InsertInfo info = InsertStatementParser.parseInsert(sql); + + assertNotNull(info); + assertEquals("Users", info.getTableName()); + assertEquals(Arrays.asList("ID", "Name"), info.getColumns()); + } + + @Test + void testParseInvalidSql() { + assertNull(InsertStatementParser.parseInsert("SELECT * FROM users")); + assertNull(InsertStatementParser.parseInsert("UPDATE users SET name = ?")); + assertNull(InsertStatementParser.parseInsert("DELETE FROM users")); + assertNull(InsertStatementParser.parseInsert(null)); + assertNull(InsertStatementParser.parseInsert("")); + assertNull(InsertStatementParser.parseInsert(" ")); + } + + @Test + void testParseInsertWithoutValues() { + String sql = "INSERT INTO users (id, name)"; + InsertInfo info = InsertStatementParser.parseInsert(sql); + assertNull(info); + } + + @Test + void testParseInsertWithoutColumns() { + String sql = "INSERT INTO users VALUES (?, ?)"; + InsertInfo info = InsertStatementParser.parseInsert(sql); + 
assertNull(info); + } + + @Test + void testIsParametrizedInsert() { + assertTrue( + InsertStatementParser.isParametrizedInsert("INSERT INTO users (id, name) VALUES (?, ?)")); + assertFalse( + InsertStatementParser.isParametrizedInsert( + "INSERT INTO users (id, name) VALUES (1, 'John')")); + assertFalse(InsertStatementParser.isParametrizedInsert("SELECT * FROM users")); + assertFalse(InsertStatementParser.isParametrizedInsert(null)); + } + + @Test + void testInsertInfoCompatibility() { + InsertInfo info1 = + InsertStatementParser.parseInsert("INSERT INTO users (id, name) VALUES (?, ?)"); + InsertInfo info2 = + InsertStatementParser.parseInsert("INSERT INTO users (id, name) VALUES (?, ?)"); + InsertInfo info3 = + InsertStatementParser.parseInsert("INSERT INTO users (id, email) VALUES (?, ?)"); + InsertInfo info4 = + InsertStatementParser.parseInsert("INSERT INTO orders (id, name) VALUES (?, ?)"); + + assertNotNull(info1); + assertNotNull(info2); + assertNotNull(info3); + assertNotNull(info4); + + assertTrue(info1.isCompatibleWith(info2)); + assertFalse(info1.isCompatibleWith(info3)); // Different columns + assertFalse(info1.isCompatibleWith(info4)); // Different table + } + + @Test + void testGenerateMultiRowInsert() throws Exception { + InsertInfo info = + InsertStatementParser.parseInsert("INSERT INTO users (id, name, email) VALUES (?, ?, ?)"); + assertNotNull(info); + + String multiRowSql = InsertStatementParser.generateMultiRowInsert(info, 3); + String expected = "INSERT INTO users (id, name, email) VALUES (?, ?, ?), (?, ?, ?), (?, ?, ?)"; + assertEquals(expected, multiRowSql); + } + + @Test + void testGenerateMultiRowInsertSingleRow() throws Exception { + InsertInfo info = + InsertStatementParser.parseInsert("INSERT INTO users (id, name) VALUES (?, ?)"); + assertNotNull(info); + + String multiRowSql = InsertStatementParser.generateMultiRowInsert(info, 1); + String expected = "INSERT INTO users (id, name) VALUES (?, ?)"; + assertEquals(expected, multiRowSql); + 
} + + @Test + void testGenerateMultiRowInsertInvalidInput() { + InsertInfo info = + InsertStatementParser.parseInsert("INSERT INTO users (id, name) VALUES (?, ?)"); + assertNotNull(info); + + // Test that exceptions are thrown for invalid inputs + assertThrows(Exception.class, () -> InsertStatementParser.generateMultiRowInsert(null, 3)); + assertThrows(Exception.class, () -> InsertStatementParser.generateMultiRowInsert(info, 0)); + assertThrows(Exception.class, () -> InsertStatementParser.generateMultiRowInsert(info, -1)); + } + + @Test + void testInsertInfoEqualsAndHashCode() { + InsertInfo info1 = + new InsertInfo( + "users", List.of("id", "name"), "INSERT INTO users (id, name) VALUES (?, ?)"); + InsertInfo info2 = + new InsertInfo( + "users", List.of("id", "name"), "INSERT INTO users (id, name) VALUES (?, ?)"); + InsertInfo info3 = + new InsertInfo( + "users", List.of("id", "email"), "INSERT INTO users (id, email) VALUES (?, ?)"); + + assertEquals(info1, info2); + assertNotEquals(info1, info3); + assertEquals(info1.hashCode(), info2.hashCode()); + assertNotEquals(info1.hashCode(), info3.hashCode()); + } + + @Test + void testGetColumnCount() { + InsertInfo info2Cols = + InsertStatementParser.parseInsert("INSERT INTO users (id, name) VALUES (?, ?)"); + assertNotNull(info2Cols); + assertEquals(2, info2Cols.getColumnCount()); + + InsertInfo info5Cols = + InsertStatementParser.parseInsert( + "INSERT INTO products (id, name, price, category, description) VALUES (?, ?, ?, ?, ?)"); + assertNotNull(info5Cols); + assertEquals(5, info5Cols.getColumnCount()); + + InsertInfo info1Col = InsertStatementParser.parseInsert("INSERT INTO simple (id) VALUES (?)"); + assertNotNull(info1Col); + assertEquals(1, info1Col.getColumnCount()); + } + + @Test + void testParameterLimitCalculations() { + // Test parameter limit calculations that would be used in chunking logic + + // 5 columns: 256/5 = 51 rows per chunk + InsertInfo info5Cols = + InsertStatementParser.parseInsert( + "INSERT 
INTO products (id, name, price, category, description) VALUES (?, ?, ?, ?, ?)"); + assertNotNull(info5Cols); + int maxRowsFor5Cols = 256 / info5Cols.getColumnCount(); + assertEquals(51, maxRowsFor5Cols); + + // 2 columns: 256/2 = 128 rows per chunk + InsertInfo info2Cols = + InsertStatementParser.parseInsert("INSERT INTO users (id, name) VALUES (?, ?)"); + assertNotNull(info2Cols); + int maxRowsFor2Cols = 256 / info2Cols.getColumnCount(); + assertEquals(128, maxRowsFor2Cols); + + // 10 columns: 256/10 = 25 rows per chunk + InsertInfo info10Cols = + InsertStatementParser.parseInsert( + "INSERT INTO wide_table (c1, c2, c3, c4, c5, c6, c7, c8, c9, c10) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"); + assertNotNull(info10Cols); + int maxRowsFor10Cols = 256 / info10Cols.getColumnCount(); + assertEquals(25, maxRowsFor10Cols); + + // Edge case: 300 columns would result in 0 rows per chunk, should be handled as 1 + InsertInfo info300Cols = InsertStatementParser.parseInsert(generateLargeInsert(300)); + assertNotNull(info300Cols); + assertEquals(300, info300Cols.getColumnCount()); + int maxRowsFor300Cols = 256 / info300Cols.getColumnCount(); + assertEquals(0, maxRowsFor300Cols); // This would need to be handled as 1 in the actual code + } + + @Test + void testChunkingScenarios() { + // Test realistic chunking scenarios + + // Scenario 1: Large batch with 5 columns, 10000 rows + InsertInfo info5Cols = + InsertStatementParser.parseInsert( + "INSERT INTO products (id, name, price, category, description) VALUES (?, ?, ?, ?, ?)"); + assertNotNull(info5Cols); + assertEquals(5, info5Cols.getColumnCount()); + + int totalRows = 10000; + int maxRowsPerChunk = 256 / info5Cols.getColumnCount(); // 51 rows per chunk + int expectedChunks = (int) Math.ceil((double) totalRows / maxRowsPerChunk); // 197 chunks + assertEquals(51, maxRowsPerChunk); + assertEquals(197, expectedChunks); + + // Scenario 2: Batch that would exceed parameter limit in one go + InsertInfo info2Cols = + 
InsertStatementParser.parseInsert("INSERT INTO users (id, name) VALUES (?, ?)"); + assertNotNull(info2Cols); + assertEquals(2, info2Cols.getColumnCount()); + + int batchSize = 200; // Would be 400 parameters (200 * 2 columns), exceeding 256 limit + int maxRowsFor2Cols = 256 / info2Cols.getColumnCount(); // 128 rows per chunk + int neededChunks = (int) Math.ceil((double) batchSize / maxRowsFor2Cols); // 2 chunks + assertEquals(128, maxRowsFor2Cols); + assertEquals(2, neededChunks); + } + + private String generateLargeInsert(int columnCount) { + StringBuilder columns = new StringBuilder(); + StringBuilder values = new StringBuilder(); + + for (int i = 1; i <= columnCount; i++) { + if (i > 1) { + columns.append(", "); + values.append(", "); + } + columns.append("col").append(i); + values.append("?"); + } + + return "INSERT INTO large_table (" + columns + ") VALUES (" + values + ")"; + } +}