From 52ffd88862d85983bb13aa1a2f36d12165c755d5 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 4 Sep 2025 11:21:27 -0700 Subject: [PATCH 1/4] placeholder --- .../snowpark/_internal/data_source/utils.py | 46 +++++++++++++------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/_internal/data_source/utils.py b/src/snowflake/snowpark/_internal/data_source/utils.py index f3fbebbd71..ce6ec17dac 100644 --- a/src/snowflake/snowpark/_internal/data_source/utils.py +++ b/src/snowflake/snowpark/_internal/data_source/utils.py @@ -195,7 +195,9 @@ def _task_fetch_data_from_source_with_retry( partition_idx: int, parquet_queue: Union[mp.Queue, queue.Queue], stop_event: threading.Event = None, -): +) -> Dict: + start = time.perf_counter() + logger.debug(f"Partition {partition_idx} fetch start") _retry_run( _task_fetch_data_from_source, worker, @@ -204,6 +206,14 @@ def _task_fetch_data_from_source_with_retry( parquet_queue, stop_event, ) + end = time.perf_counter() + logger.debug(f"Partition {partition_idx} fetch finished") + telemetry_dict = { + "duration": end - start, + "partition_query": partition, + "partition_idx": partition_idx, + } + return telemetry_dict def _upload_and_copy_into_table( @@ -307,21 +317,23 @@ def worker_process( stop_event: threading.Event = None, ): """Worker process that fetches data from multiple partitions""" + telemetry_set = set() while True: try: # Get item from queue with timeout partition_idx, query = partition_queue.get(timeout=1.0) - _task_fetch_data_from_source_with_retry( + telemetry_dict = _task_fetch_data_from_source_with_retry( reader, query, partition_idx, parquet_queue, stop_event, ) + telemetry_set.add(telemetry_dict) except queue.Empty: # indicate whether a process is exit gracefully - process_or_thread_error_indicator.put(os.getpid()) + process_or_thread_error_indicator.put((os.getpid(), telemetry_set)) # No more work available, exit gracefully break except Exception as e: @@ -357,14 +369,17 @@ def process_completed_futures(thread_futures) -> float: def _drain_process_status_queue( process_or_thread_error_indicator: Union[mp.Queue, queue.Queue], -) -> Set: - result = set() +) -> Tuple[Set, Set]: + process_id = set() + telemetry = set() while True: try: - result.add(process_or_thread_error_indicator.get(block=False)) + result_indicator = process_or_thread_error_indicator.get(block=False) + process_id.add(result_indicator[0]) + telemetry = telemetry.union(result_indicator[1]) except queue.Empty: break - return result + return process_id, telemetry def process_parquet_queue_with_threads( @@ -410,6 +425,7 @@ def process_parquet_queue_with_threads( completed_partitions = set() gracefully_exited_processes = set() + telemetries = set() # process parquet_queue may produce more data than the threads can handle, # so we use semaphore to limit the number of threads backpressure_semaphore = BoundedSemaphore(value=_MAX_WORKER_SCALE * max_workers) @@ -470,13 +486,13 @@ def process_parquet_queue_with_threads( # Check if any processes have failed for i, process in enumerate(workers): if not process.is_alive(): + ids, telemetry = _drain_process_status_queue( + process_or_thread_error_indicator + ) gracefully_exited_processes = ( - gracefully_exited_processes.union( - _drain_process_status_queue( - process_or_thread_error_indicator - ) - ) + gracefully_exited_processes.union(ids) ) + telemetries = telemetries.union(telemetry) if process.pid not in gracefully_exited_processes: raise SnowparkDataframeReaderException( f"Partition {i} data fetching process failed with exit code {process.exitcode} or failed silently" @@ -501,9 +517,9 @@ def process_parquet_queue_with_threads( for process in workers: process.join() # empty parquet queue to get all signals after each process ends - gracefully_exited_processes = gracefully_exited_processes.union( - _drain_process_status_queue(process_or_thread_error_indicator) - ) + ids, telemetry = _drain_process_status_queue(process_or_thread_error_indicator) + gracefully_exited_processes = gracefully_exited_processes.union(ids) + telemetries = telemetries.union(telemetry) # check if any process fails for idx, process in enumerate(workers): From 44b600007d8787cdb390e2a7955ca6b9904fc2a4 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 10 Sep 2025 16:16:53 -0700 Subject: [PATCH 2/4] add more telemetry and test --- .../snowpark/_internal/data_source/utils.py | 60 +++++++++++++++---- src/snowflake/snowpark/dataframe_reader.py | 8 +++ tests/integ/test_data_source_api.py | 23 ++++++- 3 files changed, 76 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/_internal/data_source/utils.py b/src/snowflake/snowpark/_internal/data_source/utils.py index ce6ec17dac..88204b20b5 100644 --- a/src/snowflake/snowpark/_internal/data_source/utils.py +++ b/src/snowflake/snowpark/_internal/data_source/utils.py @@ -1,6 +1,7 @@ # # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +import json import math import os import queue @@ -209,9 +210,11 @@ def _task_fetch_data_from_source_with_retry( end = time.perf_counter() logger.debug(f"Partition {partition_idx} fetch finished") telemetry_dict = { + "partition_idx": partition_idx, + "thread_id": threading.get_ident(), + "process_id": os.getpid(), "duration": end - start, "partition_query": partition, - "partition_idx": partition_idx, } return telemetry_dict @@ -262,7 +265,9 @@ def _upload_and_copy_into_table_with_retry( snowflake_table_name: Optional[str] = None, on_error: Optional[str] = "abort_statement", statements_params: Optional[Dict[str, str]] = None, -): +) -> Dict: + start = time.perf_counter() + logger.debug(f"Parquet file {parquet_id} upload and copy into table start") try: _retry_run( _upload_and_copy_into_table, @@ -279,6 +284,14 @@ def _upload_and_copy_into_table_with_retry( # proactively close the buffer to release memory parquet_buffer.close() backpressure_semaphore.release() + end = time.perf_counter() + logger.debug(f"Parquet file {parquet_id} upload and copy into table finished") + telemetry_dict = { + "parquet_id": parquet_id, + "thread_id": threading.get_ident(), + "duration": end - start, + } + return telemetry_dict def _retry_run(func: Callable, *args, **kwargs) -> Any: @@ -330,7 +343,7 @@ def worker_process( parquet_queue, stop_event, ) - telemetry_set.add(telemetry_dict) + telemetry_set.add(json.dumps(telemetry_dict)) except queue.Empty: # indicate whether a process is exit gracefully process_or_thread_error_indicator.put((os.getpid(), telemetry_set)) @@ -342,13 +355,14 @@ def worker_process( break -def process_completed_futures(thread_futures) -> float: +def process_completed_futures(thread_futures) -> Tuple[float, Set]: """Process completed futures with simplified error handling.""" + telemetries = set() for parquet_id, future in list(thread_futures): # Iterate over a copy of the set if future.done(): thread_futures.discard((parquet_id, future)) try: - future.result() + telemetries.add(json.dumps(future.result())) logger.debug( f"Thread future for parquet {parquet_id} completed successfully." ) @@ -364,7 +378,7 @@ def process_completed_futures(thread_futures) -> float: ) thread_futures.clear() # Clear the set since all are cancelled raise - return time.perf_counter() + return time.perf_counter(), telemetries def _drain_process_status_queue( @@ -394,7 +408,7 @@ def process_parquet_queue_with_threads( statements_params: Optional[Dict[str, str]] = None, on_error: str = "abort_statement", fetch_with_process: bool = False, -) -> Tuple[float, float, float]: +) -> Tuple[float, float, float, Set, Set]: """ Process parquet data from a multiprocessing queue using a thread pool. @@ -425,7 +439,8 @@ def process_parquet_queue_with_threads( completed_partitions = set() gracefully_exited_processes = set() - telemetries = set() + fetch_to_local_workers_telemetries = set() + upload_to_sf_worker_telemetries = set() # process parquet_queue may produce more data than the threads can handle, # so we use semaphore to limit the number of threads backpressure_semaphore = BoundedSemaphore(value=_MAX_WORKER_SCALE * max_workers) @@ -436,8 +451,13 @@ def process_parquet_queue_with_threads( thread_futures = set() # stores tuples of (parquet_id, thread_future) while len(completed_partitions) < total_partitions or thread_futures: # Process any completed futures and handle errors - upload_to_sf_end_time = process_completed_futures(thread_futures) - + ( + upload_to_sf_end_time, + upload_to_sf_worker_telemetry, + ) = process_completed_futures(thread_futures) + upload_to_sf_worker_telemetries = upload_to_sf_worker_telemetries.union( + upload_to_sf_worker_telemetry + ) try: backpressure_semaphore.acquire() parquet_id, parquet_buffer = parquet_queue.get(block=False) @@ -492,7 +512,9 @@ def process_parquet_queue_with_threads( gracefully_exited_processes = ( gracefully_exited_processes.union(ids) ) - telemetries = telemetries.union(telemetry) + fetch_to_local_workers_telemetries = ( + fetch_to_local_workers_telemetries.union(telemetry) + ) if process.pid not in gracefully_exited_processes: raise SnowparkDataframeReaderException( f"Partition {i} data fetching process failed with exit code {process.exitcode} or failed silently" @@ -519,7 +541,9 @@ def process_parquet_queue_with_threads( # empty parquet queue to get all signals after each process ends ids, telemetry = _drain_process_status_queue(process_or_thread_error_indicator) gracefully_exited_processes = gracefully_exited_processes.union(ids) - telemetries = telemetries.union(telemetry) + fetch_to_local_workers_telemetries = fetch_to_local_workers_telemetries.union( + telemetry + ) # check if any process fails for idx, process in enumerate(workers): @@ -538,10 +562,20 @@ def process_parquet_queue_with_threads( raise SnowparkDataframeReaderException( f"Partition {idx} data fetching thread failed with error: {e}" ) + _, telemetry = _drain_process_status_queue(process_or_thread_error_indicator) + fetch_to_local_workers_telemetries = fetch_to_local_workers_telemetries.union( + telemetry + ) logger.debug(f"fetch to local end at {fetch_to_local_end_time}") logger.debug(f"upload and copy into end at {upload_to_sf_end_time}") logger.debug( f"upload and copy into total time: {upload_to_sf_end_time - upload_to_sf_start_time}" ) - return fetch_to_local_end_time, upload_to_sf_start_time, upload_to_sf_end_time + return ( + fetch_to_local_end_time, + upload_to_sf_start_time, + upload_to_sf_end_time, + fetch_to_local_workers_telemetries, + upload_to_sf_worker_telemetries, + ) diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 5533ebb246..26ca31a01d 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -1959,6 +1959,8 @@ def create_oracledb_connection(): fetch_to_local_end_time, upload_to_sf_start_time, upload_to_sf_end_time, + fetch_to_local_workers_telemetry, + upload_to_sf_workers_telemetry, ) = process_parquet_queue_with_threads( session=self._session, parquet_queue=parquet_queue, @@ -2023,6 +2025,12 @@ def create_oracledb_connection(): self, snowflake_table_name, struct_schema, _emit_ast=_emit_ast ) telemetry_json_string["schema"] = res_df.schema.simple_string() + telemetry_json_string["fetch_to_local_workers_telemetries"] = [ + json.loads(telemetry) for telemetry in fetch_to_local_workers_telemetry + ] + telemetry_json_string["upload_to_sf_workers_telemetries"] = [ + json.loads(telemetry) for telemetry in upload_to_sf_workers_telemetry + ] self._session._conn._telemetry_client.send_data_source_perf_telemetry( telemetry_json_string ) diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index 69403dabed..7af07ae6b9 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -110,8 +110,8 @@ "input_type, input_value", [ ("table", SQL_SERVER_TABLE_NAME), - ("query", f"SELECT * FROM {SQL_SERVER_TABLE_NAME}"), - ("query", f"(SELECT * FROM {SQL_SERVER_TABLE_NAME})"), + # ("query", f"SELECT * FROM {SQL_SERVER_TABLE_NAME}"), + # ("query", f"(SELECT * FROM {SQL_SERVER_TABLE_NAME})"), ], ) @pytest.mark.parametrize("fetch_with_process", [True, False]) @@ -406,6 +406,25 @@ def test_telemetry(session, fetch_with_process): assert "upload_and_copy_into_sf_table_duration" in telemetry_json assert "end_to_end_duration" in telemetry_json + assert "fetch_to_local_workers_telemetries" in telemetry_json + assert "upload_to_sf_workers_telemetries" in telemetry_json + + # fetch_to_local_workers_telemetries + for entry in telemetry_json["fetch_to_local_workers_telemetries"]: + assert "partition_idx" in entry + assert "thread_id" in entry + assert "process_id" in entry + assert "duration" in entry + assert "partition_query" in entry + + # upload_to_sf_workers_telemetries + for entry in telemetry_json["upload_to_sf_workers_telemetries"]: + assert "parquet_id" in entry + assert "thread_id" in entry + assert "duration" in entry + + print(telemetry_json) + @pytest.mark.parametrize("fetch_with_process", [True, False]) def test_telemetry_tracking(caplog, session, fetch_with_process): From b0ada84e700cedd16adfbf4cb0e08e5bead065dc Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 25 Sep 2025 16:47:54 -0700 Subject: [PATCH 3/4] remove comment and print --- tests/integ/test_data_source_api.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index 7af07ae6b9..ebdc067e75 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -110,8 +110,8 @@ "input_type, input_value", [ ("table", SQL_SERVER_TABLE_NAME), - # ("query", f"SELECT * FROM {SQL_SERVER_TABLE_NAME}"), - # ("query", f"(SELECT * FROM {SQL_SERVER_TABLE_NAME})"), + ("query", f"SELECT * FROM {SQL_SERVER_TABLE_NAME}"), + ("query", f"(SELECT * FROM {SQL_SERVER_TABLE_NAME})"), ], ) @pytest.mark.parametrize("fetch_with_process", [True, False]) @@ -423,8 +423,6 @@ def test_telemetry(session, fetch_with_process): assert "thread_id" in entry assert "duration" in entry - print(telemetry_json) - @pytest.mark.parametrize("fetch_with_process", [True, False]) def test_telemetry_tracking(caplog, session, fetch_with_process): From 91b5ad0b3a03eefc9a2c7269dfd0a045b8b6bea2 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 26 Sep 2025 12:05:13 -0700 Subject: [PATCH 4/4] revert sending telemetry as it could cause memory issue --- .../snowpark/_internal/data_source/utils.py | 94 ++++++------------- src/snowflake/snowpark/dataframe_reader.py | 8 -- tests/integ/test_data_source_api.py | 25 ++--- 3 files changed, 36 insertions(+), 91 deletions(-) diff --git a/src/snowflake/snowpark/_internal/data_source/utils.py b/src/snowflake/snowpark/_internal/data_source/utils.py index 88204b20b5..626529abfa 100644 --- a/src/snowflake/snowpark/_internal/data_source/utils.py +++ b/src/snowflake/snowpark/_internal/data_source/utils.py @@ -1,7 +1,6 @@ # # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # -import json import math import os import queue @@ -196,7 +195,7 @@ def _task_fetch_data_from_source_with_retry( partition_idx: int, parquet_queue: Union[mp.Queue, queue.Queue], stop_event: threading.Event = None, -) -> Dict: +): start = time.perf_counter() logger.debug(f"Partition {partition_idx} fetch start") _retry_run( @@ -208,15 +207,9 @@ def _task_fetch_data_from_source_with_retry( stop_event, ) end = time.perf_counter() - logger.debug(f"Partition {partition_idx} fetch finished") - telemetry_dict = { - "partition_idx": partition_idx, - "thread_id": threading.get_ident(), - "process_id": os.getpid(), - "duration": end - start, - "partition_query": partition, - } - return telemetry_dict + logger.debug( + f"Partition {partition_idx} fetch finished, used {end - start} seconds" + ) def _upload_and_copy_into_table( @@ -265,7 +258,7 @@ def _upload_and_copy_into_table_with_retry( snowflake_table_name: Optional[str] = None, on_error: Optional[str] = "abort_statement", statements_params: Optional[Dict[str, str]] = None, -) -> Dict: +): start = time.perf_counter() logger.debug(f"Parquet file {parquet_id} upload and copy into table start") try: @@ -285,13 +278,9 @@ def _upload_and_copy_into_table_with_retry( parquet_buffer.close() backpressure_semaphore.release() end = time.perf_counter() - logger.debug(f"Parquet file {parquet_id} upload and copy into table finished") - telemetry_dict = { - "parquet_id": parquet_id, - "thread_id": threading.get_ident(), - "duration": end - start, - } - return telemetry_dict + logger.debug( + f"Parquet file {parquet_id} upload and copy into table finished, used {end - start} seconds" + ) def _retry_run(func: Callable, *args, **kwargs) -> Any: @@ -330,23 +319,21 @@ def worker_process( stop_event: threading.Event = None, ): """Worker process that fetches data from multiple partitions""" - telemetry_set = set() while True: try: # Get item from queue with timeout partition_idx, query = partition_queue.get(timeout=1.0) - telemetry_dict = _task_fetch_data_from_source_with_retry( + _task_fetch_data_from_source_with_retry( reader, query, partition_idx, parquet_queue, stop_event, ) - telemetry_set.add(json.dumps(telemetry_dict)) except queue.Empty: # indicate whether a process is exit gracefully - process_or_thread_error_indicator.put((os.getpid(), telemetry_set)) + process_or_thread_error_indicator.put(os.getpid()) # No more work available, exit gracefully break except Exception as e: @@ -355,14 +342,13 @@ def worker_process( break -def process_completed_futures(thread_futures) -> Tuple[float, Set]: +def process_completed_futures(thread_futures) -> float: """Process completed futures with simplified error handling.""" - telemetries = set() for parquet_id, future in list(thread_futures): # Iterate over a copy of the set if future.done(): thread_futures.discard((parquet_id, future)) try: - telemetries.add(json.dumps(future.result())) + future.result() logger.debug( f"Thread future for parquet {parquet_id} completed successfully." ) @@ -378,22 +364,19 @@ def process_completed_futures(thread_futures) -> Tuple[float, Set]: ) thread_futures.clear() # Clear the set since all are cancelled raise - return time.perf_counter(), telemetries + return time.perf_counter() def _drain_process_status_queue( process_or_thread_error_indicator: Union[mp.Queue, queue.Queue], -) -> Tuple[Set, Set]: - process_id = set() - telemetry = set() +) -> Set: + result = set() while True: try: - result_indicator = process_or_thread_error_indicator.get(block=False) - process_id.add(result_indicator[0]) - telemetry = telemetry.union(result_indicator[1]) + result.add(process_or_thread_error_indicator.get(block=False)) except queue.Empty: break - return process_id, telemetry + return result def process_parquet_queue_with_threads( @@ -408,7 +391,7 @@ def process_parquet_queue_with_threads( statements_params: Optional[Dict[str, str]] = None, on_error: str = "abort_statement", fetch_with_process: bool = False, -) -> Tuple[float, float, float, Set, Set]: +) -> Tuple[float, float, float]: """ Process parquet data from a multiprocessing queue using a thread pool. @@ -439,8 +422,6 @@ def process_parquet_queue_with_threads( completed_partitions = set() gracefully_exited_processes = set() - fetch_to_local_workers_telemetries = set() - upload_to_sf_worker_telemetries = set() # process parquet_queue may produce more data than the threads can handle, # so we use semaphore to limit the number of threads backpressure_semaphore = BoundedSemaphore(value=_MAX_WORKER_SCALE * max_workers) @@ -451,13 +432,8 @@ def process_parquet_queue_with_threads( thread_futures = set() # stores tuples of (parquet_id, thread_future) while len(completed_partitions) < total_partitions or thread_futures: # Process any completed futures and handle errors - ( - upload_to_sf_end_time, - upload_to_sf_worker_telemetry, - ) = process_completed_futures(thread_futures) - upload_to_sf_worker_telemetries = upload_to_sf_worker_telemetries.union( - upload_to_sf_worker_telemetry - ) + upload_to_sf_end_time = process_completed_futures(thread_futures) + try: backpressure_semaphore.acquire() parquet_id, parquet_buffer = parquet_queue.get(block=False) @@ -506,14 +482,12 @@ def process_parquet_queue_with_threads( # Check if any processes have failed for i, process in enumerate(workers): if not process.is_alive(): - ids, telemetry = _drain_process_status_queue( - process_or_thread_error_indicator - ) gracefully_exited_processes = ( - gracefully_exited_processes.union(ids) - ) - fetch_to_local_workers_telemetries = ( - fetch_to_local_workers_telemetries.union(telemetry) + gracefully_exited_processes.union( + _drain_process_status_queue( + process_or_thread_error_indicator + ) + ) ) if process.pid not in gracefully_exited_processes: raise SnowparkDataframeReaderException( @@ -539,10 +513,8 @@ def process_parquet_queue_with_threads( for process in workers: process.join() # empty parquet queue to get all signals after each process ends - ids, telemetry = _drain_process_status_queue(process_or_thread_error_indicator) - gracefully_exited_processes = gracefully_exited_processes.union(ids) - fetch_to_local_workers_telemetries = fetch_to_local_workers_telemetries.union( - telemetry + gracefully_exited_processes = gracefully_exited_processes.union( + _drain_process_status_queue(process_or_thread_error_indicator) ) # check if any process fails @@ -562,20 +534,10 @@ def process_parquet_queue_with_threads( raise SnowparkDataframeReaderException( f"Partition {idx} data fetching thread failed with error: {e}" ) - _, telemetry = _drain_process_status_queue(process_or_thread_error_indicator) - fetch_to_local_workers_telemetries = fetch_to_local_workers_telemetries.union( - telemetry - ) logger.debug(f"fetch to local end at {fetch_to_local_end_time}") logger.debug(f"upload and copy into end at {upload_to_sf_end_time}") logger.debug( f"upload and copy into total time: {upload_to_sf_end_time - upload_to_sf_start_time}" ) - return ( - fetch_to_local_end_time, - upload_to_sf_start_time, - upload_to_sf_end_time, - fetch_to_local_workers_telemetries, - upload_to_sf_worker_telemetries, - ) + return fetch_to_local_end_time, upload_to_sf_start_time, upload_to_sf_end_time diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 26ca31a01d..5533ebb246 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -1959,8 +1959,6 @@ def create_oracledb_connection(): fetch_to_local_end_time, upload_to_sf_start_time, upload_to_sf_end_time, - fetch_to_local_workers_telemetry, - upload_to_sf_workers_telemetry, ) = process_parquet_queue_with_threads( session=self._session, parquet_queue=parquet_queue, @@ -2025,12 +2023,6 @@ def create_oracledb_connection(): self, snowflake_table_name, struct_schema, _emit_ast=_emit_ast ) telemetry_json_string["schema"] = res_df.schema.simple_string() - telemetry_json_string["fetch_to_local_workers_telemetries"] = [ - json.loads(telemetry) for telemetry in fetch_to_local_workers_telemetry - ] - telemetry_json_string["upload_to_sf_workers_telemetries"] = [ - json.loads(telemetry) for telemetry in upload_to_sf_workers_telemetry - ] self._session._conn._telemetry_client.send_data_source_perf_telemetry( telemetry_json_string ) diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index ebdc067e75..ad736ab2aa 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -387,7 +387,7 @@ def test_partition_unsupported_type(session): @pytest.mark.parametrize("fetch_with_process", [True, False]) -def test_telemetry(session, fetch_with_process): +def test_telemetry(session, fetch_with_process, caplog): with patch( "snowflake.snowpark._internal.telemetry.TelemetryClient.send_data_source_perf_telemetry" ) as mock_telemetry: @@ -406,22 +406,13 @@ def test_telemetry(session, fetch_with_process): assert "upload_and_copy_into_sf_table_duration" in telemetry_json assert "end_to_end_duration" in telemetry_json - assert "fetch_to_local_workers_telemetries" in telemetry_json - assert "upload_to_sf_workers_telemetries" in telemetry_json - - # fetch_to_local_workers_telemetries - for entry in telemetry_json["fetch_to_local_workers_telemetries"]: - assert "partition_idx" in entry - assert "thread_id" in entry - assert "process_id" in entry - assert "duration" in entry - assert "partition_query" in entry - - # upload_to_sf_workers_telemetries - for entry in telemetry_json["upload_to_sf_workers_telemetries"]: - assert "parquet_id" in entry - assert "thread_id" in entry - assert "duration" in entry + assert "upload and copy into table start" in caplog.text + if not fetch_with_process: + assert "fetch start" in caplog.text + + assert "upload and copy into table finished, used" in caplog.text + if not fetch_with_process: + assert "fetch finished, used" in caplog.text @pytest.mark.parametrize("fetch_with_process", [True, False])