Skip to content

Commit 3f851cf

Browse files
authored
[PECOBLR-1309] Add validation in user input (in case of integers) (#1112)
## Description - This addresses the issue in #1103 (i.e., rowsFetchedPerBlock field) and also adds validation across other integer fields that require to be positive. ## Testing - Added unit tests ## Additional Notes to the Reviewer <!-- Share any additional context or insights that may help the reviewer understand the changes better. This could include challenges faced, limitations, or compromises made during the development process. Also, mention any areas of the code that you would like the reviewer to focus on specifically. -->
1 parent c7f0868 commit 3f851cf

21 files changed

Lines changed: 451 additions & 300 deletions

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- Added the EnableTokenFederation url param to enable or disable Token federation feature. By default it is set to 1
77

88
### Updated
9+
- Added validation for positive integer configuration properties (RowsFetchedPerBlock, BatchInsertSize, etc.) to prevent hangs and errors when set to zero or negative values.
910
- Updated Circuit breaker to be triggered by 429 errors too.
1011

1112
### Fixed

src/main/java/com/databricks/jdbc/api/impl/DatabricksConnection.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,15 @@ public IDatabricksSession getSession() {
8787
}
8888

8989
@Override
90-
public Statement createStatement() {
90+
public Statement createStatement() throws SQLException {
9191
LOGGER.debug("public Statement createStatement()");
9292
DatabricksStatement statement = new DatabricksStatement(this);
9393
statementSet.add(statement);
9494
return statement;
9595
}
9696

9797
@Override
98-
public PreparedStatement prepareStatement(String sql) {
98+
public PreparedStatement prepareStatement(String sql) throws SQLException {
9999
LOGGER.debug(
100100
String.format("public PreparedStatement prepareStatement(String sql = {%s})", sql));
101101
DatabricksPreparedStatement statement = new DatabricksPreparedStatement(this, sql);

src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,10 @@ public String getPassThroughAccessToken() {
275275
}
276276

277277
@Override
278-
public int getAsyncExecPollInterval() {
279-
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.POLL_INTERVAL));
278+
public int getAsyncExecPollInterval() throws DatabricksValidationException {
279+
return ValidationUtil.validateAndParsePositiveInteger(
280+
getParameter(DatabricksJdbcUrlParams.POLL_INTERVAL),
281+
DatabricksJdbcUrlParams.POLL_INTERVAL.getParamName());
280282
}
281283

282284
@Override
@@ -405,13 +407,17 @@ public String getLogPathString() {
405407
}
406408

407409
@Override
408-
public int getLogFileSize() {
409-
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.LOG_FILE_SIZE));
410+
public int getLogFileSize() throws DatabricksValidationException {
411+
return ValidationUtil.validateAndParsePositiveInteger(
412+
getParameter(DatabricksJdbcUrlParams.LOG_FILE_SIZE),
413+
DatabricksJdbcUrlParams.LOG_FILE_SIZE.getParamName());
410414
}
411415

412416
@Override
413-
public int getLogFileCount() {
414-
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.LOG_FILE_COUNT));
417+
public int getLogFileCount() throws DatabricksValidationException {
418+
return ValidationUtil.validateAndParsePositiveInteger(
419+
getParameter(DatabricksJdbcUrlParams.LOG_FILE_COUNT),
420+
DatabricksJdbcUrlParams.LOG_FILE_COUNT.getParamName());
415421
}
416422

417423
@Override
@@ -476,8 +482,10 @@ public void setClientType(DatabricksClientType clientType) {
476482
}
477483

478484
@Override
479-
public int getCloudFetchThreadPoolSize() {
480-
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.CLOUD_FETCH_THREAD_POOL_SIZE));
485+
public int getCloudFetchThreadPoolSize() throws DatabricksValidationException {
486+
return ValidationUtil.validateAndParsePositiveInteger(
487+
getParameter(DatabricksJdbcUrlParams.CLOUD_FETCH_THREAD_POOL_SIZE),
488+
DatabricksJdbcUrlParams.CLOUD_FETCH_THREAD_POOL_SIZE.getParamName());
481489
}
482490

483491
@Override
@@ -833,8 +841,10 @@ public String getSSLKeyStoreProvider() {
833841
}
834842

835843
@Override
836-
public int getMaxBatchSize() {
837-
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.MAX_BATCH_SIZE));
844+
public int getMaxBatchSize() throws DatabricksValidationException {
845+
return ValidationUtil.validateAndParsePositiveInteger(
846+
getParameter(DatabricksJdbcUrlParams.MAX_BATCH_SIZE),
847+
DatabricksJdbcUrlParams.MAX_BATCH_SIZE.getParamName());
838848
}
839849

840850
@Override
@@ -844,12 +854,21 @@ public String getConnectionUuid() {
844854

845855
@Override
846856
public int getTelemetryBatchSize() {
847-
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.TELEMETRY_BATCH_SIZE));
857+
try {
858+
return ValidationUtil.validateAndParsePositiveInteger(
859+
getParameter(DatabricksJdbcUrlParams.TELEMETRY_BATCH_SIZE),
860+
DatabricksJdbcUrlParams.TELEMETRY_BATCH_SIZE.getParamName());
861+
} catch (DatabricksValidationException e) {
862+
// We don't want to throw any errors related to telemetry.
863+
return Integer.parseInt(DatabricksJdbcUrlParams.TELEMETRY_BATCH_SIZE.getDefaultValue());
864+
}
848865
}
849866

850867
@Override
851-
public int getBatchInsertSize() {
852-
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.BATCH_INSERT_SIZE));
868+
public int getBatchInsertSize() throws DatabricksValidationException {
869+
return ValidationUtil.validateAndParsePositiveInteger(
870+
getParameter(DatabricksJdbcUrlParams.BATCH_INSERT_SIZE),
871+
DatabricksJdbcUrlParams.BATCH_INSERT_SIZE.getParamName());
853872
}
854873

855874
@Override
@@ -928,8 +947,10 @@ public boolean isRequestTracingEnabled() {
928947
}
929948

930949
@Override
931-
public int getHttpConnectionPoolSize() {
932-
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.HTTP_CONNECTION_POOL_SIZE));
950+
public int getHttpConnectionPoolSize() throws DatabricksValidationException {
951+
return ValidationUtil.validateAndParsePositiveInteger(
952+
getParameter(DatabricksJdbcUrlParams.HTTP_CONNECTION_POOL_SIZE),
953+
DatabricksJdbcUrlParams.HTTP_CONNECTION_POOL_SIZE.getParamName());
933954
}
934955

935956
@Override
@@ -972,13 +993,14 @@ public String getAzureWorkspaceResourceId() {
972993

973994
@Override
974995
public int getRowsFetchedPerBlock() {
975-
int maxRows = DEFAULT_ROW_LIMIT_PER_BLOCK;
976996
try {
977-
maxRows = Integer.parseInt(getParameter(DatabricksJdbcUrlParams.ROWS_FETCHED_PER_BLOCK));
978-
} catch (NumberFormatException e) {
997+
return ValidationUtil.validateAndParsePositiveInteger(
998+
getParameter(DatabricksJdbcUrlParams.ROWS_FETCHED_PER_BLOCK),
999+
DatabricksJdbcUrlParams.ROWS_FETCHED_PER_BLOCK.getParamName());
1000+
} catch (DatabricksValidationException exception) {
9791001
LOGGER.warn("Invalid value for RowsFetchedPerBlock, using default value");
9801002
}
981-
return maxRows;
1003+
return DEFAULT_ROW_LIMIT_PER_BLOCK;
9821004
}
9831005

9841006
/** {@inheritDoc} */

src/main/java/com/databricks/jdbc/api/impl/DatabricksPreparedStatement.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ public class DatabricksPreparedStatement extends DatabricksStatement implements
3737
private final boolean interpolateParameters;
3838
private final int CHUNK_SIZE = 8192;
3939

40-
public DatabricksPreparedStatement(DatabricksConnection connection, String sql) {
40+
public DatabricksPreparedStatement(DatabricksConnection connection, String sql)
41+
throws DatabricksValidationException {
4142
super(connection);
4243
this.sql = sql;
4344
this.interpolateParameters = connection.getConnectionContext().supportManyParameters();
@@ -49,7 +50,8 @@ public DatabricksPreparedStatement(DatabricksConnection connection, String sql)
4950
DatabricksConnection connection,
5051
String sql,
5152
boolean interpolateParameters,
52-
DatabricksParameterMetaData databricksParameterMetaData) {
53+
DatabricksParameterMetaData databricksParameterMetaData)
54+
throws DatabricksValidationException {
5355
super(connection);
5456
this.sql = sql;
5557
this.interpolateParameters = interpolateParameters;

src/main/java/com/databricks/jdbc/api/impl/DatabricksStatement.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public class DatabricksStatement implements IDatabricksStatement, IDatabricksSta
4848
private final DatabricksBatchExecutor databricksBatchExecutor;
4949
private boolean noMoreResults = false; // JDBC end-of-results indicator
5050

51-
public DatabricksStatement(DatabricksConnection connection) {
51+
public DatabricksStatement(DatabricksConnection connection) throws DatabricksValidationException {
5252
this.connection = connection;
5353
this.resultSet = null;
5454
this.statementId = null;
@@ -58,7 +58,8 @@ public DatabricksStatement(DatabricksConnection connection) {
5858
new DatabricksBatchExecutor(this, connection.getConnectionContext().getMaxBatchSize());
5959
}
6060

61-
public DatabricksStatement(DatabricksConnection connection, StatementId statementId) {
61+
public DatabricksStatement(DatabricksConnection connection, StatementId statementId)
62+
throws DatabricksValidationException {
6263
this.connection = connection;
6364
this.statementId = statementId;
6465
this.resultSet = null;

src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.databricks.jdbc.common.*;
44
import com.databricks.jdbc.exception.DatabricksParsingException;
5+
import com.databricks.jdbc.exception.DatabricksValidationException;
56
import com.databricks.sdk.core.ProxyConfig;
67
import com.databricks.sdk.core.utils.Cloud;
78
import java.util.List;
@@ -78,9 +79,9 @@ public interface IDatabricksConnectionContext {
7879

7980
String getLogPathString();
8081

81-
int getLogFileSize();
82+
int getLogFileSize() throws DatabricksValidationException;
8283

83-
int getLogFileCount();
84+
int getLogFileCount() throws DatabricksValidationException;
8485

8586
/** Returns the userAgent string specific to client used to fetch results. */
8687
String getClientUserAgent();
@@ -145,7 +146,7 @@ public interface IDatabricksConnectionContext {
145146

146147
String getEndpointURL() throws DatabricksParsingException;
147148

148-
int getAsyncExecPollInterval();
149+
int getAsyncExecPollInterval() throws DatabricksValidationException;
149150

150151
Boolean shouldEnableArrow();
151152

@@ -156,7 +157,7 @@ public interface IDatabricksConnectionContext {
156157
Boolean getUseEmptyMetadata();
157158

158159
/** Returns the number of threads to be used for fetching data from cloud storage */
159-
int getCloudFetchThreadPoolSize();
160+
int getCloudFetchThreadPoolSize() throws DatabricksValidationException;
160161

161162
/** Returns the minimum expected download speed threshold in MB/s for CloudFetch operations */
162163
double getCloudFetchSpeedThreshold();
@@ -260,7 +261,7 @@ public interface IDatabricksConnectionContext {
260261
String getSSLTrustStoreProvider();
261262

262263
/** Returns the maximum number of commands that can be executed in a single batch. */
263-
int getMaxBatchSize();
264+
int getMaxBatchSize() throws DatabricksValidationException;
264265

265266
/** Checks if Telemetry is enabled */
266267
boolean isTelemetryEnabled();
@@ -269,7 +270,7 @@ public interface IDatabricksConnectionContext {
269270
int getTelemetryBatchSize();
270271

271272
/** Returns the maximum number of rows per batch insert execution */
272-
int getBatchInsertSize();
273+
int getBatchInsertSize() throws DatabricksValidationException;
273274

274275
/**
275276
* Returns a unique identifier for this connection context.
@@ -307,7 +308,7 @@ public interface IDatabricksConnectionContext {
307308
boolean isGeoSpatialSupportEnabled();
308309

309310
/** Returns the size for HTTP connection pool */
310-
int getHttpConnectionPoolSize();
311+
int getHttpConnectionPoolSize() throws DatabricksValidationException;
311312

312313
/** Returns the list of HTTP codes to retry for UC Volume Ingestion */
313314
List<Integer> getUCIngestionRetriableHttpCodes();

src/main/java/com/databricks/jdbc/common/DatabricksClientConfiguratorManager.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import com.databricks.jdbc.dbclient.impl.common.ClientConfigurator;
55
import com.databricks.jdbc.exception.DatabricksDriverException;
66
import com.databricks.jdbc.exception.DatabricksSSLException;
7+
import com.databricks.jdbc.exception.DatabricksValidationException;
78
import com.databricks.jdbc.log.JdbcLogger;
89
import com.databricks.jdbc.log.JdbcLoggerFactory;
910
import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode;
@@ -33,6 +34,13 @@ public ClientConfigurator getConfigurator(IDatabricksConnectionContext context)
3334
String.format("client configurator failed due to SSL error: %s", e.getMessage());
3435
LOGGER.error(e, message);
3536
throw new DatabricksDriverException(message, DatabricksDriverErrorCode.AUTH_ERROR);
37+
} catch (DatabricksValidationException e) {
38+
String message =
39+
String.format(
40+
"client configurator failed due to validation error: %s", e.getMessage());
41+
LOGGER.error(e, message);
42+
throw new DatabricksDriverException(
43+
message, DatabricksDriverErrorCode.INPUT_VALIDATION_ERROR);
3644
}
3745
});
3846
} catch (Exception ex) {

src/main/java/com/databricks/jdbc/common/util/ValidationUtil.java

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import com.databricks.jdbc.common.DatabricksJdbcUrlParams;
66
import com.databricks.jdbc.exception.DatabricksHttpException;
7-
import com.databricks.jdbc.exception.DatabricksSQLException;
87
import com.databricks.jdbc.exception.DatabricksValidationException;
98
import com.databricks.jdbc.exception.DatabricksVendorCode;
109
import com.databricks.jdbc.log.JdbcLogger;
@@ -22,7 +21,7 @@ public class ValidationUtil {
2221
private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(ValidationUtil.class);
2322

2423
public static <T extends Number> void checkIfNonNegative(T number, String fieldName)
25-
throws DatabricksSQLException {
24+
throws DatabricksValidationException {
2625
if (number.longValue() < 0) {
2726
String errorMessage =
2827
String.format("Invalid input for %s, : %d", fieldName, number.longValue());
@@ -31,8 +30,51 @@ public static <T extends Number> void checkIfNonNegative(T number, String fieldN
3130
}
3231
}
3332

33+
/**
34+
* Validates that a number is positive (greater than 0).
35+
*
36+
* @param number the number to validate
37+
* @param fieldName the name of the field being validated
38+
* @throws DatabricksValidationException if the number is not positive
39+
*/
40+
public static <T extends Number> void checkIfPositive(T number, String fieldName)
41+
throws DatabricksValidationException {
42+
if (number.longValue() <= 0) {
43+
String errorMessage =
44+
String.format(
45+
"Invalid value for %s: %d. Value must be a positive integer (> 0).",
46+
fieldName, number.longValue());
47+
LOGGER.error(errorMessage);
48+
throw new DatabricksValidationException(errorMessage);
49+
}
50+
}
51+
52+
/**
53+
* Parses a string to an integer and validates that it is positive (greater than 0).
54+
*
55+
* @param value the string value to parse
56+
* @param fieldName the name of the field being validated
57+
* @return the parsed positive integer
58+
* @throws DatabricksValidationException if the value cannot be parsed or is not positive
59+
*/
60+
public static int validateAndParsePositiveInteger(String value, String fieldName)
61+
throws DatabricksValidationException {
62+
try {
63+
int parsedValue = Integer.parseInt(value);
64+
checkIfPositive(parsedValue, fieldName);
65+
return parsedValue;
66+
} catch (NumberFormatException e) {
67+
String errorMessage =
68+
String.format(
69+
"Invalid value for %s: '%s'. Value must be a valid positive integer.",
70+
fieldName, value);
71+
LOGGER.error(errorMessage);
72+
throw new DatabricksValidationException(errorMessage);
73+
}
74+
}
75+
3476
public static void throwErrorIfNull(Map<String, String> fields, String context)
35-
throws DatabricksSQLException {
77+
throws DatabricksValidationException {
3678
for (Map.Entry<String, String> field : fields.entrySet()) {
3779
if (field.getValue() == null) {
3880
String errorMessage =
@@ -44,7 +86,8 @@ public static void throwErrorIfNull(Map<String, String> fields, String context)
4486
}
4587
}
4688

47-
public static void throwErrorIfNull(String field, Object value) throws DatabricksSQLException {
89+
public static void throwErrorIfNull(String field, Object value)
90+
throws DatabricksValidationException {
4891
if (value != null) {
4992
return;
5093
}
@@ -116,10 +159,10 @@ public static boolean isValidJdbcUrl(String url) {
116159
* maintainability and extensibility.
117160
*
118161
* @param parameters Map of JDBC connection parameters to validate
119-
* @throws DatabricksSQLException if any validation fails
162+
* @throws DatabricksValidationException if any validation fails
120163
*/
121164
public static void validateInputProperties(Map<String, String> parameters)
122-
throws DatabricksSQLException {
165+
throws DatabricksValidationException {
123166
// Validate UID parameter
124167
validateUidParameter(parameters);
125168

@@ -131,10 +174,10 @@ public static void validateInputProperties(Map<String, String> parameters)
131174
* "token".
132175
*
133176
* @param parameters Map of JDBC connection parameters
134-
* @throws DatabricksSQLException if UID validation fails
177+
* @throws DatabricksValidationException if UID validation fails
135178
*/
136179
public static void validateUidParameter(Map<String, String> parameters)
137-
throws DatabricksSQLException {
180+
throws DatabricksValidationException {
138181
String uid = parameters.get(DatabricksJdbcUrlParams.UID.getParamName());
139182
// UID must either be omitted or set to "token"
140183
if (uid != null && !uid.equals(VALID_UID_VALUE)) {

0 commit comments

Comments
 (0)