|
14 | 14 | import os |
15 | 15 | import time |
16 | 16 | import unittest |
| 17 | +from types import SimpleNamespace |
| 18 | +from unittest import mock |
17 | 19 |
|
18 | 20 | import numpy as np |
19 | 21 | import pyarrow as pa |
| 22 | +import requests |
20 | 23 | from odps import ODPS |
| 24 | +from odps.errors import ODPSError |
21 | 25 | from parameterized import parameterized |
22 | 26 | from torch import distributed as dist |
23 | 27 | from torch.utils.data import DataLoader |
24 | 28 |
|
| 29 | +from tzrec.datasets import odps_dataset |
25 | 30 | from tzrec.datasets.odps_dataset import OdpsDataset, OdpsWriter, _create_odps_account |
26 | 31 | from tzrec.features.feature import FgMode, create_features |
27 | 32 | from tzrec.protos import data_pb2, feature_pb2, sampler_pb2 |
@@ -689,5 +694,107 @@ def _writer_worker(rank, port): |
689 | 694 | self.assertEqual(reader.count, 1280) |
690 | 695 |
|
691 | 696 |
|
| 697 | +class _FailingStorageClient: |
| 698 | + """Stub storage-api client whose calls raise a connection-reset error.""" |
| 699 | + |
| 700 | + table = SimpleNamespace(full_table_name="test_project.test_table") |
| 701 | + _quota_name = "test_quota" |
| 702 | + _rest_endpoint = "http://service.odps.test" |
| 703 | + |
| 704 | + def _raise(self): |
| 705 | + raise requests.exceptions.ConnectionError( |
| 706 | + "Connection aborted.", |
| 707 | + ConnectionResetError(104, "Connection reset by peer"), |
| 708 | + ) |
| 709 | + |
| 710 | + def read_rows_arrow(self, read_req): |
| 711 | + self._raise() |
| 712 | + |
| 713 | + def get_read_session(self, sess_req): |
| 714 | + self._raise() |
| 715 | + |
| 716 | + |
| 717 | +class _RetryingStorageClient: |
| 718 | + """Stub whose read_rows_arrow raises ODPSError `fail_times` times, then succeeds.""" |
| 719 | + |
| 720 | + table = SimpleNamespace(full_table_name="test_project.test_table") |
| 721 | + _quota_name = "test_quota" |
| 722 | + _rest_endpoint = "http://service.odps.test" |
| 723 | + |
| 724 | + def __init__(self, fail_times): |
| 725 | + self._fail_times = fail_times |
| 726 | + self.calls = 0 |
| 727 | + self.sentinel = object() |
| 728 | + |
| 729 | + def read_rows_arrow(self, read_req): |
| 730 | + self.calls += 1 |
| 731 | + if self.calls <= self._fail_times: |
| 732 | + raise ODPSError("synthetic ODPS error", request_id="REQ-XYZ") |
| 733 | + return self.sentinel |
| 734 | + |
| 735 | + |
| 736 | +class OdpsStorageErrorLogTest(unittest.TestCase): |
| 737 | + def test_read_rows_arrow_logs_session_and_reraises(self): |
| 738 | + client = _FailingStorageClient() |
| 739 | + read_req = SimpleNamespace( |
| 740 | + session_id="sess-read-123", |
| 741 | + row_index=0, |
| 742 | + row_count=10, |
| 743 | + max_batch_rows=100, |
| 744 | + ) |
| 745 | + with mock.patch.object(odps_dataset, "logger") as m_logger: |
| 746 | + with self.assertRaises(requests.exceptions.ConnectionError): |
| 747 | + odps_dataset._read_rows_arrow_with_retry(client, read_req) |
| 748 | + m_logger.error.assert_called_once() |
| 749 | + logged = m_logger.error.call_args[0][0] |
| 750 | + self.assertIn("sess-read-123", logged) |
| 751 | + self.assertIn("test_project.test_table", logged) |
| 752 | + |
| 753 | + def test_get_read_session_logs_session_and_reraises(self): |
| 754 | + client = _FailingStorageClient() |
| 755 | + sess_req = SimpleNamespace(session_id="sess-scan-456") |
| 756 | + with mock.patch.object(odps_dataset, "logger") as m_logger: |
| 757 | + with self.assertRaises(requests.exceptions.ConnectionError): |
| 758 | + odps_dataset._get_session_record_count(client, sess_req) |
| 759 | + m_logger.error.assert_called_once() |
| 760 | + logged = m_logger.error.call_args[0][0] |
| 761 | + self.assertIn("sess-scan-456", logged) |
| 762 | + |
| 763 | + def test_read_rows_arrow_retries_odps_error_then_succeeds(self): |
| 764 | + client = _RetryingStorageClient(fail_times=2) |
| 765 | + read_req = SimpleNamespace( |
| 766 | + session_id="sess-retry-1", |
| 767 | + row_index=0, |
| 768 | + row_count=10, |
| 769 | + max_batch_rows=100, |
| 770 | + ) |
| 771 | + with mock.patch.object(odps_dataset.time, "sleep"): |
| 772 | + with mock.patch.object(odps_dataset, "logger") as m_logger: |
| 773 | + reader = odps_dataset._read_rows_arrow_with_retry(client, read_req) |
| 774 | + self.assertIs(reader, client.sentinel) |
| 775 | + self.assertEqual(client.calls, 3) # 2 failures + 1 success |
| 776 | + self.assertEqual(m_logger.warning.call_count, 2) |
| 777 | + m_logger.error.assert_not_called() |
| 778 | + |
| 779 | + def test_read_rows_arrow_odps_error_exhausts_retries_and_raises(self): |
| 780 | + client = _RetryingStorageClient(fail_times=99) |
| 781 | + read_req = SimpleNamespace( |
| 782 | + session_id="sess-retry-2", |
| 783 | + row_index=5, |
| 784 | + row_count=10, |
| 785 | + max_batch_rows=100, |
| 786 | + ) |
| 787 | + with mock.patch.object(odps_dataset.time, "sleep"): |
| 788 | + with mock.patch.object(odps_dataset, "logger") as m_logger: |
| 789 | + with self.assertRaises(ODPSError): |
| 790 | + odps_dataset._read_rows_arrow_with_retry(client, read_req) |
| 791 | + self.assertEqual(client.calls, 4) # initial + 3 retries |
| 792 | + self.assertEqual(m_logger.warning.call_count, 3) |
| 793 | + m_logger.error.assert_called_once() |
| 794 | + logged = m_logger.error.call_args[0][0] |
| 795 | + self.assertIn("after 3 retries", logged) |
| 796 | + self.assertIn("sess-retry-2", logged) |
| 797 | + |
| 798 | + |
692 | 799 | if __name__ == "__main__": |
693 | 800 | unittest.main() |
0 commit comments