Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@
- `nvl2`
- `regr_valx`

#### Improvements

- Improve `DataFrameReader.dbapi`(PuPr) that dbapi will not retry on non-retryable error such as SQL syntax error on external data source query.

### Snowpark pandas API Updates

Expand Down

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about ingestion (copy into), if we have SnowparkSQLException, should we retry?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think non-retryable exception is for error that we know that it will not work even if we retry.
I think SnowparkSQLException is a very general failure that not only exception like SQL syntax error could trigger it.
For example, a udtf could fail with SnowparkSQLException because of python error that is retryable.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from snowflake.snowpark._internal.data_source.datasource_typing import Connection
from snowflake.snowpark._internal.data_source.drivers.base_driver import BaseDriver
from snowflake.snowpark.exceptions import SnowparkDataframeReaderException
from snowflake.snowpark.exceptions import (
SnowparkDataframeReaderException,
_SnowparkDataSourceNonRetryableException,
)
from snowflake.snowpark.types import StructType
from snowflake.connector.options import pandas as pd
import logging
Expand Down Expand Up @@ -85,6 +88,11 @@ def read(self, partition: str) -> Iterator[List[Any]]:
batch = []
else:
raise ValueError("fetch size cannot be smaller than 0")
except Exception as exc:
if self.driver.non_retryable_error_checker(exc):
raise _SnowparkDataSourceNonRetryableException(exc)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as per our discussion, we do not need the private class and can just raise the existing reader exception here

else:
raise
finally:
try:
cursor.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
f"{self.__class__.__name__} has not implemented to_snow_type function"
)

def non_retryable_error_checker(self, error: Exception) -> bool:
return False

@staticmethod
def prepare_connection(
conn: "Connection",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
all_columns.append(StructField(column_name, data_type, True))
return StructType(all_columns)

def non_retryable_error_checker(self, error: Exception) -> bool:
import databricks.sql

if isinstance(error, databricks.sql.ServerOperationError):
syntax_error_codes = [
"PARSE_SYNTAX_ERROR", # syntax error
]
for error_code in syntax_error_codes:
if error_code in str(error):
return True
return False

def udtf_class_builder(
self,
fetch_size: int = 1000,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,30 @@ def prepare_connection(
conn.outputtypehandler = output_type_handler
return conn

def non_retryable_error_checker(self, error: Exception) -> bool:
import oracledb

if isinstance(error, oracledb.DatabaseError):
syntax_error_codes = [
"ORA-00900", # invalid SQL statement
"ORA-00901", # invalid CREATE command
"ORA-00904", # invalid identifier
"ORA-00905", # missing keyword
"ORA-00906", # missing left parenthesis
"ORA-00907", # missing right parenthesis
"ORA-00911", # invalid character
"ORA-00920", # invalid relational operator
"ORA-00921", # unexpected end of SQL command
"ORA-00923", # FROM keyword not found where expected
"ORA-00933", # SQL command not properly ended
"ORA-00936", # missing expression
"ORA-00942", # table or view does not exist
]
for error_code in syntax_error_codes:
if error_code in str(error):
return True
return False

def udtf_class_builder(
self,
fetch_size: int = 1000,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,18 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
fields.append(StructField(name, data_type, True))
return StructType(fields)

def non_retryable_error_checker(self, error: Exception) -> bool:
import psycopg2

if isinstance(error, psycopg2.errors.SyntaxError):
syntax_error_codes = [
"42601", # syntax error
]
for error_code in syntax_error_codes:
if error_code == str(error.pgcode):
return True
return False

@staticmethod
def to_result_snowpark_df(
session: "Session", table_name, schema, _emit_ast: bool = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,18 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
fields.append(StructField(name, data_type, null_ok))
return StructType(fields)

def non_retryable_error_checker(self, error: Exception) -> bool:
import pymysql

if isinstance(error, pymysql.err.ProgrammingError):
syntax_error_codes = [
"1064", # syntax error
]
for error_code in syntax_error_codes:
if error_code in str(error):
return True
return False

def udtf_class_builder(
self,
fetch_size: int = 1000,
Expand Down
9 changes: 7 additions & 2 deletions src/snowflake/snowpark/_internal/data_source/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
)
import snowflake
from snowflake.snowpark._internal.data_source import DataSourceReader
from snowflake.snowpark.exceptions import SnowparkDataframeReaderException

from snowflake.snowpark.exceptions import (
SnowparkDataframeReaderException,
_SnowparkDataSourceNonRetryableException,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -282,6 +284,9 @@ def _retry_run(func: Callable, *args, **kwargs) -> Any:
except SnowparkDataframeReaderException:
# SnowparkDataframeReaderException is a non-retryable exception
raise
except _SnowparkDataSourceNonRetryableException:
# SnowparkDataSourceNonRetryableException is a non-retryable exception
raise
except Exception as e:
last_error = e
error_trace = traceback.format_exc()
Expand Down
9 changes: 8 additions & 1 deletion src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from snowflake.snowpark.exceptions import (
SnowparkSessionException,
SnowparkDataframeReaderException,
_SnowparkDataSourceNonRetryableException,
)
from snowflake.snowpark.functions import sql_expr, col, concat, lit, to_file
from snowflake.snowpark.mock._connection import MockServerConnection
Expand Down Expand Up @@ -2003,7 +2004,13 @@ def create_oracledb_connection():
f"Cancelled a remaining data fetching future {future} due to error in another thread."
)

if isinstance(exc, SnowparkDataframeReaderException):
if isinstance(
exc,
(
SnowparkDataframeReaderException,
_SnowparkDataSourceNonRetryableException,
),
):
raise exc

raise SnowparkDataframeReaderException(
Expand Down
21 changes: 21 additions & 0 deletions src/snowflake/snowpark/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,24 @@ class SnowparkInvalidObjectNameException(SnowparkGeneralException):
"""

pass


class _SnowparkDataSourceNonRetryableException(SnowparkGeneralException):
"""Exception for data source non-retryable error."""

def __init__(
self,
error: Exception,
) -> None:
self.error: Exception = error

self._pretty_msg = f"{self.__class__.__name__}({self.error})"

def __repr__(self):
return f"{self.__class__.__name__}({self.error})"

def __str__(self):
return self._pretty_msg

def __reduce__(self):
return (self.__class__, (self.error,))
13 changes: 13 additions & 0 deletions tests/integ/datasource/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from snowflake.snowpark.exceptions import (
SnowparkDataframeReaderException,
_SnowparkDataSourceNonRetryableException,
SnowparkSQLException,
)
from snowflake.snowpark.types import (
Expand Down Expand Up @@ -265,6 +266,18 @@ def test_unsupported_type():
assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)])


def test_databricks_non_retryable_error(session):
with pytest.raises(
_SnowparkDataSourceNonRetryableException,
match="PARSE_SYNTAX_ERROR",
):
session.read.dbapi(
create_databricks_connection,
table=TEST_TABLE_NAME,
predicates=["invalid syntax"],
)


def test_session_init(session):
with pytest.raises(
SnowparkDataframeReaderException,
Expand Down
13 changes: 13 additions & 0 deletions tests/integ/datasource/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
PymysqlTypeCode,
)
from snowflake.snowpark._internal.data_source.utils import DBMS_TYPE
from snowflake.snowpark.exceptions import _SnowparkDataSourceNonRetryableException
from snowflake.snowpark.types import StructType, StructField, StringType
from snowflake.snowpark.exceptions import (
SnowparkDataframeReaderException,
Expand Down Expand Up @@ -305,6 +306,18 @@ def test_unsupported_type():
assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)])


def test_mysql_non_retryable_error(session):
with pytest.raises(
_SnowparkDataSourceNonRetryableException,
match="You have an error in your SQL syntax",
):
session.read.dbapi(
create_connection_mysql,
table=TEST_TABLE_NAME,
predicates=["invalid syntax"],
)


def test_session_init(session):
with pytest.raises(
SnowparkDataframeReaderException,
Expand Down
13 changes: 13 additions & 0 deletions tests/integ/datasource/test_oracledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from snowflake.snowpark.types import StructType, StructField, StringType
from snowflake.snowpark.exceptions import (
_SnowparkDataSourceNonRetryableException,
SnowparkDataframeReaderException,
SnowparkSQLException,
)
Expand Down Expand Up @@ -254,6 +255,18 @@ def test_unsupported_type():
assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)])


def test_oracledb_non_retryable_error(session):
with pytest.raises(
_SnowparkDataSourceNonRetryableException,
match="ORA-00920: invalid relational operator",
):
session.read.dbapi(
create_connection_oracledb,
table=ORACLEDB_TABLE_NAME,
predicates=["invalid syntax"],
).collect()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious about why some tests use .collect() while others don't, my impression is that dbapi() is not lazily evaluated?



def test_query_timeout_and_session_init(session):
statement = """
BEGIN
Expand Down
13 changes: 13 additions & 0 deletions tests/integ/datasource/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from snowflake.snowpark._internal.data_source.utils import DBMS_TYPE
from snowflake.snowpark.exceptions import (
SnowparkDataframeReaderException,
_SnowparkDataSourceNonRetryableException,
SnowparkSQLException,
)
from snowflake.snowpark.types import (
Expand Down Expand Up @@ -507,3 +508,15 @@ def test_server_side_cursor(session):
assert cursor.name is not None # Server-side cursor should have a name
cursor.close()
conn.close()


def test_postgres_non_retryable_error(session):
with pytest.raises(
_SnowparkDataSourceNonRetryableException,
match="syntax error",
):
session.read.dbapi(
create_postgres_connection,
table=POSTGRES_TABLE_NAME,
predicates=["invalid syntax"],
).collect()
30 changes: 29 additions & 1 deletion tests/integ/test_data_source_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@
random_name_for_temp_object,
)
from snowflake.snowpark.dataframe_reader import _MAX_RETRY_TIME
from snowflake.snowpark.exceptions import SnowparkDataframeReaderException
from snowflake.snowpark.exceptions import (
SnowparkDataframeReaderException,
_SnowparkDataSourceNonRetryableException,
)
from snowflake.snowpark.types import (
StructType,
StructField,
Expand Down Expand Up @@ -210,6 +213,31 @@ def test_dbapi_retry(session, fetch_with_process):
assert mock_task.call_count == _MAX_RETRY_TIME


@pytest.mark.parametrize("fetch_with_process", [True, False])
def test_dbapi_non_retryable_error(session, fetch_with_process):
with mock.patch(
"snowflake.snowpark._internal.data_source.utils._task_fetch_data_from_source",
side_effect=_SnowparkDataSourceNonRetryableException(Exception("mock error")),
) as mock_task:
mock_task.__name__ = "_task_fetch_from_data_source"
parquet_queue = multiprocessing.Queue() if fetch_with_process else queue.Queue()
with pytest.raises(
_SnowparkDataSourceNonRetryableException, match="mock error"
):
_task_fetch_data_from_source_with_retry(
worker=DataSourceReader(
PyodbcDriver,
sql_server_create_connection,
StructType([StructField("col1", IntegerType(), False)]),
DBMS_TYPE.SQL_SERVER_DB,
),
partition="SELECT * FROM test_table",
partition_idx=0,
parquet_queue=parquet_queue,
)
assert mock_task.call_count == 1


@pytest.mark.skipif(
IS_WINDOWS,
reason="sqlite3 file can not be shared across processes on windows",
Expand Down
Loading