diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a9dcf8d65..92c57851dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -104,6 +104,7 @@ #### Improvements +- Improved `DataFrameReader.dbapi`(PuPr) that dbapi will not retry on non-retryable error such as SQL syntax error on external data source query. - Removed unnecessary warnings about local package version mismatch when using `session.read.option('rowTag', ).xml()` or `xpath` functions. - Improved `DataFrameReader.dbapi` (PuPr) reading performance by setting the default `fetch_size` parameter value to 100000. - Improved error message for XSD validation failure when reading XML files using `session.read.option('rowValidationXSDPath', ).xml()`. diff --git a/src/snowflake/snowpark/_internal/data_source/datasource_reader.py b/src/snowflake/snowpark/_internal/data_source/datasource_reader.py index cb2ea1e01d..24f1eefb0d 100644 --- a/src/snowflake/snowpark/_internal/data_source/datasource_reader.py +++ b/src/snowflake/snowpark/_internal/data_source/datasource_reader.py @@ -85,6 +85,11 @@ def read(self, partition: str) -> Iterator[List[Any]]: batch = [] else: raise ValueError("fetch size cannot be smaller than 0") + except Exception as exc: + if self.driver.non_retryable_error_checker(exc): + raise SnowparkDataframeReaderException(message=str(exc)) + else: + raise finally: try: cursor.close() diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py index e431bbc645..01d083cb10 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py @@ -55,6 +55,9 @@ def to_snow_type(self, schema: List[Any]) -> StructType: f"{self.__class__.__name__} has not implemented to_snow_type function" ) + def non_retryable_error_checker(self, error: Exception) -> bool: + return False + @staticmethod def prepare_connection( conn: "Connection", diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py index 62d1e2bcce..8cc4457f7c 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py @@ -68,6 +68,18 @@ def to_snow_type(self, schema: List[Any]) -> StructType: all_columns.append(StructField(column_name, data_type, True)) return StructType(all_columns) + def non_retryable_error_checker(self, error: Exception) -> bool: + import databricks.sql + + if isinstance(error, databricks.sql.ServerOperationError): + syntax_error_codes = [ + "PARSE_SYNTAX_ERROR", # syntax error + ] + for error_code in syntax_error_codes: + if error_code in str(error): + return True + return False + def udtf_class_builder( self, fetch_size: int = 1000, diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py index 745adec64f..d479fa705a 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py @@ -111,6 +111,30 @@ def prepare_connection( conn.outputtypehandler = output_type_handler return conn + def non_retryable_error_checker(self, error: Exception) -> bool: + import oracledb + + if isinstance(error, oracledb.DatabaseError): + syntax_error_codes = [ + "ORA-00900", # invalid SQL statement + "ORA-00901", # invalid CREATE command + "ORA-00904", # invalid identifier + "ORA-00905", # missing keyword + "ORA-00906", # missing left parenthesis + "ORA-00907", # missing right parenthesis + "ORA-00911", # invalid character + "ORA-00920", # invalid relational operator + "ORA-00921", # unexpected end of SQL command + "ORA-00923", # FROM keyword not found where expected + "ORA-00933", # SQL command not properly ended + "ORA-00936", # missing expression + "ORA-00942", # table or view does not exist + ] + for error_code in syntax_error_codes: + if error_code in str(error): + return True + return False + def udtf_class_builder( self, fetch_size: int = 1000, diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py index ee9575f090..455c6a8641 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py @@ -211,6 +211,18 @@ def to_snow_type(self, schema: List[Any]) -> StructType: fields.append(StructField(name, data_type, True)) return StructType(fields) + def non_retryable_error_checker(self, error: Exception) -> bool: + import psycopg2 + + if isinstance(error, psycopg2.errors.SyntaxError): + syntax_error_codes = [ + "42601", # syntax error + ] + for error_code in syntax_error_codes: + if error_code == str(error.pgcode): + return True + return False + @staticmethod def to_result_snowpark_df( session: "Session", table_name, schema, _emit_ast: bool = True diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py index 7428bce4af..7eae0f7b15 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py @@ -183,6 +183,18 @@ def to_snow_type(self, schema: List[Any]) -> StructType: fields.append(StructField(name, data_type, null_ok)) return StructType(fields) + def non_retryable_error_checker(self, error: Exception) -> bool: + import pymysql + + if isinstance(error, pymysql.err.ProgrammingError): + syntax_error_codes = [ + "1064", # syntax error + ] + for error_code in syntax_error_codes: + if error_code in str(error): + return True + return False + def udtf_class_builder( self, fetch_size: int = 1000, diff --git a/tests/integ/datasource/test_databricks.py b/tests/integ/datasource/test_databricks.py index 63c496f1c5..9c643322da 100644 --- a/tests/integ/datasource/test_databricks.py +++ b/tests/integ/datasource/test_databricks.py @@ -265,6 +265,18 @@ def test_unsupported_type(): assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)]) +def test_databricks_non_retryable_error(session): + with pytest.raises( + SnowparkDataframeReaderException, + match="PARSE_SYNTAX_ERROR", + ): + session.read.dbapi( + create_databricks_connection, + table=TEST_TABLE_NAME, + predicates=["invalid syntax"], + ) + + def test_session_init(session): with pytest.raises( SnowparkDataframeReaderException, diff --git a/tests/integ/datasource/test_mysql.py b/tests/integ/datasource/test_mysql.py index 842f386dcf..1e1a927ea3 100644 --- a/tests/integ/datasource/test_mysql.py +++ b/tests/integ/datasource/test_mysql.py @@ -305,6 +305,18 @@ def test_unsupported_type(): assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)]) +def test_mysql_non_retryable_error(session): + with pytest.raises( + SnowparkDataframeReaderException, + match="You have an error in your SQL syntax", + ): + session.read.dbapi( + create_connection_mysql, + table=TEST_TABLE_NAME, + predicates=["invalid syntax"], + ) + + def test_session_init(session): with pytest.raises( SnowparkDataframeReaderException, diff --git a/tests/integ/datasource/test_oracledb.py b/tests/integ/datasource/test_oracledb.py index ea8457b942..c1d5bba6d3 100644 --- a/tests/integ/datasource/test_oracledb.py +++ b/tests/integ/datasource/test_oracledb.py @@ -255,6 +255,18 @@ def test_unsupported_type(): assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)]) +def test_oracledb_non_retryable_error(session): + with pytest.raises( + SnowparkDataframeReaderException, + match="ORA-00920: invalid relational operator", + ): + session.read.dbapi( + create_connection_oracledb, + table=ORACLEDB_TABLE_NAME, + predicates=["invalid syntax"], + ).collect() + + def test_query_timeout_and_session_init(session): statement = """ BEGIN diff --git a/tests/integ/datasource/test_postgres.py b/tests/integ/datasource/test_postgres.py index 99088268e2..464ae7d1f0 100644 --- a/tests/integ/datasource/test_postgres.py +++ b/tests/integ/datasource/test_postgres.py @@ -507,3 +507,15 @@ def test_server_side_cursor(session): assert cursor.name is not None # Server-side cursor should have a name cursor.close() conn.close() + + +def test_postgres_non_retryable_error(session): + with pytest.raises( + SnowparkDataframeReaderException, + match="syntax error", + ): + session.read.dbapi( + create_postgres_connection, + table=POSTGRES_TABLE_NAME, + predicates=["invalid syntax"], + ).collect() diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index cc9e0ab3a0..7ec189ac96 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -210,6 +210,29 @@ def test_dbapi_retry(session, fetch_with_process): assert mock_task.call_count == _MAX_RETRY_TIME +@pytest.mark.parametrize("fetch_with_process", [True, False]) +def test_dbapi_non_retryable_error(session, fetch_with_process): + with mock.patch( + "snowflake.snowpark._internal.data_source.utils._task_fetch_data_from_source", + side_effect=SnowparkDataframeReaderException("mock error"), + ) as mock_task: + mock_task.__name__ = "_task_fetch_from_data_source" + parquet_queue = multiprocessing.Queue() if fetch_with_process else queue.Queue() + with pytest.raises(SnowparkDataframeReaderException, match="mock error"): + _task_fetch_data_from_source_with_retry( + worker=DataSourceReader( + PyodbcDriver, + sql_server_create_connection, + StructType([StructField("col1", IntegerType(), False)]), + DBMS_TYPE.SQL_SERVER_DB, + ), + partition="SELECT * FROM test_table", + partition_idx=0, + parquet_queue=parquet_queue, + ) + assert mock_task.call_count == 1 + + @pytest.mark.skipif( IS_WINDOWS, reason="sqlite3 file can not be shared across processes on windows",