Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@

#### Improvements

- Improved `DataFrameReader.dbapi` (PuPr) so that it no longer retries on non-retryable errors, such as a SQL syntax error in the external data source query.
- Removed unnecessary warnings about local package version mismatch when using `session.read.option('rowTag', <tag_name>).xml(<stage_file_path>)` or `xpath` functions.
- Improved `DataFrameReader.dbapi` (PuPr) reading performance by setting the default `fetch_size` parameter value to 100000.
- Improved error message for XSD validation failure when reading XML files using `session.read.option('rowValidationXSDPath', <xsd_path>).xml(<stage_file_path>)`.
Expand Down

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about ingestion (COPY INTO)? If we hit a SnowparkSQLException there, should we retry?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a non-retryable exception is meant for errors that we know will fail again even if we retry.
SnowparkSQLException is a very general failure; it can be triggered by much more than a SQL syntax error.
For example, a UDTF could fail with SnowparkSQLException because of a Python error that is retryable.

Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def read(self, partition: str) -> Iterator[List[Any]]:
batch = []
else:
raise ValueError("fetch size cannot be smaller than 0")
except Exception as exc:
if self.driver.non_retryable_error_checker(exc):
raise SnowparkDataframeReaderException(message=str(exc))
else:
raise
finally:
try:
cursor.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
f"{self.__class__.__name__} has not implemented to_snow_type function"
)

def non_retryable_error_checker(self, error: Exception) -> bool:
    """Report whether ``error`` should be treated as non-retryable.

    This implementation does not inspect ``error`` and classifies nothing
    as non-retryable.

    Args:
        error: The exception to classify.

    Returns:
        Always ``False``.
    """
    return False

@staticmethod
def prepare_connection(
conn: "Connection",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
all_columns.append(StructField(column_name, data_type, True))
return StructType(all_columns)

def non_retryable_error_checker(self, error: Exception) -> bool:
    """Tell whether ``error`` is a Databricks error flagged as non-retryable.

    Args:
        error: The exception to classify.

    Returns:
        ``True`` when ``error`` is a ``databricks.sql.ServerOperationError``
        whose string form contains one of the known syntax-error markers,
        ``False`` otherwise.
    """
    import databricks.sql

    # Only server operation errors are inspected; everything else is retryable.
    if not isinstance(error, databricks.sql.ServerOperationError):
        return False

    non_retryable_markers = (
        "PARSE_SYNTAX_ERROR",  # syntax error
    )
    error_text = str(error)
    return any(marker in error_text for marker in non_retryable_markers)

def udtf_class_builder(
self,
fetch_size: int = 1000,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,30 @@ def prepare_connection(
conn.outputtypehandler = output_type_handler
return conn

def non_retryable_error_checker(self, error: Exception) -> bool:
    """Tell whether ``error`` is an Oracle error flagged as non-retryable.

    Args:
        error: The exception to classify.

    Returns:
        ``True`` when ``error`` is an ``oracledb.DatabaseError`` whose string
        form contains one of the listed ``ORA-`` codes, ``False`` otherwise.
    """
    import oracledb

    # Only database errors are inspected; everything else is retryable.
    if not isinstance(error, oracledb.DatabaseError):
        return False

    non_retryable_codes = (
        "ORA-00900",  # invalid SQL statement
        "ORA-00901",  # invalid CREATE command
        "ORA-00904",  # invalid identifier
        "ORA-00905",  # missing keyword
        "ORA-00906",  # missing left parenthesis
        "ORA-00907",  # missing right parenthesis
        "ORA-00911",  # invalid character
        "ORA-00920",  # invalid relational operator
        "ORA-00921",  # unexpected end of SQL command
        "ORA-00923",  # FROM keyword not found where expected
        "ORA-00933",  # SQL command not properly ended
        "ORA-00936",  # missing expression
        "ORA-00942",  # table or view does not exist
    )
    error_text = str(error)
    return any(code in error_text for code in non_retryable_codes)

def udtf_class_builder(
self,
fetch_size: int = 1000,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,18 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
fields.append(StructField(name, data_type, True))
return StructType(fields)

def non_retryable_error_checker(self, error: Exception) -> bool:
    """Tell whether ``error`` is a PostgreSQL error flagged as non-retryable.

    Args:
        error: The exception to classify.

    Returns:
        ``True`` when ``error`` is a ``psycopg2.errors.SyntaxError`` whose
        ``pgcode`` is one of the listed SQLSTATE codes, ``False`` otherwise.
    """
    # Import the submodule explicitly: a bare ``import psycopg2`` is not
    # guaranteed to expose ``psycopg2.errors`` as an attribute unless some
    # other code has already imported it, which would make the isinstance
    # check below fail with AttributeError instead of classifying the error.
    import psycopg2.errors

    if not isinstance(error, psycopg2.errors.SyntaxError):
        return False

    syntax_error_codes = (
        "42601",  # syntax error
    )
    # pgcode is stringified before comparison, as the codes above are strings.
    return str(error.pgcode) in syntax_error_codes

@staticmethod
def to_result_snowpark_df(
session: "Session", table_name, schema, _emit_ast: bool = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,18 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
fields.append(StructField(name, data_type, null_ok))
return StructType(fields)

def non_retryable_error_checker(self, error: Exception) -> bool:
    """Tell whether ``error`` is a MySQL error flagged as non-retryable.

    The numeric error code carried as the first exception argument is
    compared against the known syntax-error codes. Matching on the code
    itself avoids false positives when the digits "1064" merely appear in
    the error message (for example inside an identifier or an echoed SQL
    fragment).

    Args:
        error: The exception to classify.

    Returns:
        ``True`` when ``error`` is a ``pymysql.err.ProgrammingError`` that
        carries one of the listed error codes, ``False`` otherwise.
    """
    import pymysql

    if not isinstance(error, pymysql.err.ProgrammingError):
        return False

    syntax_error_codes = (
        1064,  # syntax error
    )
    errno = error.args[0] if error.args else None
    if isinstance(errno, int):
        return errno in syntax_error_codes

    # No numeric code available (exception built without an errno): fall
    # back to searching the rendered message so such errors are still caught.
    error_text = str(error)
    return any(str(code) in error_text for code in syntax_error_codes)

def udtf_class_builder(
self,
fetch_size: int = 1000,
Expand Down
12 changes: 12 additions & 0 deletions tests/integ/datasource/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,18 @@ def test_unsupported_type():
assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)])


def test_databricks_non_retryable_error(session):
    """An invalid predicate must raise a reader exception naming the parse error."""
    expected_failure = pytest.raises(
        SnowparkDataframeReaderException, match="PARSE_SYNTAX_ERROR"
    )
    invalid_predicates = ["invalid syntax"]
    with expected_failure:
        session.read.dbapi(
            create_databricks_connection,
            table=TEST_TABLE_NAME,
            predicates=invalid_predicates,
        )


def test_session_init(session):
with pytest.raises(
SnowparkDataframeReaderException,
Expand Down
12 changes: 12 additions & 0 deletions tests/integ/datasource/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,18 @@ def test_unsupported_type():
assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)])


def test_mysql_non_retryable_error(session):
    """An invalid predicate must raise a reader exception with MySQL's syntax message."""
    expected_failure = pytest.raises(
        SnowparkDataframeReaderException,
        match="You have an error in your SQL syntax",
    )
    invalid_predicates = ["invalid syntax"]
    with expected_failure:
        session.read.dbapi(
            create_connection_mysql,
            table=TEST_TABLE_NAME,
            predicates=invalid_predicates,
        )


def test_session_init(session):
with pytest.raises(
SnowparkDataframeReaderException,
Expand Down
12 changes: 12 additions & 0 deletions tests/integ/datasource/test_oracledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,18 @@ def test_unsupported_type():
assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)])


def test_oracledb_non_retryable_error(session):
    """An invalid predicate must raise a reader exception carrying ORA-00920."""
    expected_failure = pytest.raises(
        SnowparkDataframeReaderException,
        match="ORA-00920: invalid relational operator",
    )
    invalid_predicates = ["invalid syntax"]
    with expected_failure:
        df = session.read.dbapi(
            create_connection_oracledb,
            table=ORACLEDB_TABLE_NAME,
            predicates=invalid_predicates,
        )
        df.collect()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why some tests use `.collect()` while others don't; my impression is that `dbapi()` is not lazily evaluated?



def test_query_timeout_and_session_init(session):
statement = """
BEGIN
Expand Down
12 changes: 12 additions & 0 deletions tests/integ/datasource/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,3 +507,15 @@ def test_server_side_cursor(session):
assert cursor.name is not None # Server-side cursor should have a name
cursor.close()
conn.close()


def test_postgres_non_retryable_error(session):
    """An invalid predicate must raise a reader exception mentioning a syntax error."""
    expected_failure = pytest.raises(
        SnowparkDataframeReaderException, match="syntax error"
    )
    invalid_predicates = ["invalid syntax"]
    with expected_failure:
        df = session.read.dbapi(
            create_postgres_connection,
            table=POSTGRES_TABLE_NAME,
            predicates=invalid_predicates,
        )
        df.collect()
23 changes: 23 additions & 0 deletions tests/integ/test_data_source_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,29 @@ def test_dbapi_retry(session, fetch_with_process):
assert mock_task.call_count == _MAX_RETRY_TIME


@pytest.mark.parametrize("fetch_with_process", [True, False])
def test_dbapi_non_retryable_error(session, fetch_with_process):
    """A reader exception from the fetch task must propagate after a single call."""
    patched_target = (
        "snowflake.snowpark._internal.data_source.utils._task_fetch_data_from_source"
    )
    with mock.patch(
        patched_target,
        side_effect=SnowparkDataframeReaderException("mock error"),
    ) as mock_task:
        mock_task.__name__ = "_task_fetch_from_data_source"

        # Use the queue flavour matching the fetch mode under test.
        if fetch_with_process:
            parquet_queue = multiprocessing.Queue()
        else:
            parquet_queue = queue.Queue()

        with pytest.raises(SnowparkDataframeReaderException, match="mock error"):
            _task_fetch_data_from_source_with_retry(
                worker=DataSourceReader(
                    PyodbcDriver,
                    sql_server_create_connection,
                    StructType([StructField("col1", IntegerType(), False)]),
                    DBMS_TYPE.SQL_SERVER_DB,
                ),
                partition="SELECT * FROM test_table",
                partition_idx=0,
                parquet_queue=parquet_queue,
            )

        # Exactly one invocation: the task was not called again after raising.
        assert mock_task.call_count == 1


@pytest.mark.skipif(
IS_WINDOWS,
reason="sqlite3 file can not be shared across processes on windows",
Expand Down