diff --git a/CHANGELOG.md b/CHANGELOG.md index bbefc6fd34..2c33c12b7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,6 +83,7 @@ - Fixed UDTF ingestion failure with `pyodbc` driver caused by unprocessed row data. - Fixed SQL Server query input failure due to incorrect select query generation. - Fixed UDTF ingestion not preserving column nullability in the output schema. + - Fixed an issue that caused the program to hang during multithreaded Parquet based ingestion when a data fetching error occurred. #### Improvements diff --git a/src/snowflake/snowpark/_internal/data_source/utils.py b/src/snowflake/snowpark/_internal/data_source/utils.py index f3fbebbd71..58dc17bcaa 100644 --- a/src/snowflake/snowpark/_internal/data_source/utils.py +++ b/src/snowflake/snowpark/_internal/data_source/utils.py @@ -308,6 +308,9 @@ def worker_process( ): """Worker process that fetches data from multiple partitions""" while True: + if stop_event and stop_event.is_set(): + # other worker has set the stop event signalling me to stop, exit gracefully + break try: # Get item from queue with timeout partition_idx, query = partition_queue.get(timeout=1.0) diff --git a/tests/integ/datasource/test_oracledb.py b/tests/integ/datasource/test_oracledb.py index 6b5178590c..ea8457b942 100644 --- a/tests/integ/datasource/test_oracledb.py +++ b/tests/integ/datasource/test_oracledb.py @@ -6,6 +6,7 @@ import math import sys from collections import namedtuple +from unittest.mock import patch import pytest @@ -334,3 +335,32 @@ def test_oracledb_driver_udtf_class_builder(): # Verify we got data with the right structure (2 columns) assert len(column_result_rows) > 0 assert len(column_result_rows[0]) == 2 # Two columns + + +def test_dbapi_no_hang_on_exit_when_worker_error(session): + """ + Test that the dbapi reader does not hang on exit when a worker raises an error + + Ideally the test should be put in test_data_source_api.py, + however, reproducing using SQLite is hard to achieve while pure mocking gets the test code too complex. + Hence, we use Oracledb here which can repro the issue reliably without the fix. + """ + with patch( + "snowflake.snowpark._internal.data_source.drivers.base_driver.BaseDriver.data_source_data_to_pandas_df" + ) as mock_data_source_data_to_pandas_df: + # Mock the data_source_data_to_pandas_df method to raise RuntimeError + mock_data_source_data_to_pandas_df.side_effect = RuntimeError( + "conversion error" + ) + + # Expect the dbapi call to raise a SnowparkDataframeReaderException due to the worker error + with pytest.raises(SnowparkDataframeReaderException, match="conversion error"): + session.read.dbapi( + create_connection_oracledb, + table=ORACLEDB_TABLE_NAME, + column="ID", + lower_bound=0, + upper_bound=100, + num_partitions=10, + max_workers=2, + )