Skip to content

Commit e2bc17d

Browse files
[bugfix] ODPS reader: pin pyodps==0.12.5.1 (retry connection resets) + log session/table context on errors (#537)
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 664fbb9 commit e2bc17d

4 files changed

Lines changed: 146 additions & 3 deletions

File tree

requirements/runtime.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ psutil
1515
pyfg @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/pyfg-1.0.5-cp312-cp312-linux_x86_64.whl ; python_version=="3.12"
1616
pyfg @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/pyfg-1.0.5-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
1717
pyfg @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/pyfg-1.0.5-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
18-
pyodps>=0.12.4
18+
pyodps==0.12.5.1
1919
safetensors
2020
scikit-learn
2121
tensorboard

tzrec/datasets/odps_dataset.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,21 +206,52 @@ def _parse_table_path(
206206
return str_list[2], table_name, table_partitions, schema
207207

208208

209+
def _storage_debug_info(client: StorageApiArrowClient, **req_fields: Any) -> str:
210+
"""Format MaxCompute storage-api context for error logging."""
211+
parts = {
212+
"table": client.table.full_table_name,
213+
"quota": getattr(client, "_quota_name", None),
214+
"endpoint": getattr(client, "_rest_endpoint", None),
215+
**req_fields,
216+
}
217+
return ", ".join(f"{k}={v}" for k, v in parts.items())
218+
219+
209220
def _read_rows_arrow_with_retry(
210221
client: StorageApiArrowClient,
211222
read_req: ReadRowsRequest,
212223
) -> ArrowReader:
224+
def debug_info() -> str:
225+
return _storage_debug_info(
226+
client,
227+
session_id=read_req.session_id,
228+
row_index=read_req.row_index,
229+
row_count=read_req.row_count,
230+
max_batch_rows=read_req.max_batch_rows,
231+
)
232+
213233
max_retry_count = 3
214234
retry_cnt = 0
215235
while True:
216236
try:
217237
reader = client.read_rows_arrow(read_req)
218238
except ODPSError as e:
219239
if retry_cnt >= max_retry_count:
240+
logger.error(
241+
f"read_rows_arrow failed after {retry_cnt} retries "
242+
f"({debug_info()}): {e!r}"
243+
)
220244
raise e
221245
retry_cnt += 1
246+
logger.warning(
247+
f"read_rows_arrow retry {retry_cnt}/{max_retry_count} "
248+
f"({debug_info()}): {e!r}"
249+
)
222250
time.sleep(random.choice([5, 9, 12]))
223251
continue
252+
except Exception as e:
253+
logger.error(f"read_rows_arrow failed ({debug_info()}): {e!r}")
254+
raise
224255
break
225256
return reader
226257

@@ -231,7 +262,12 @@ def _get_session_record_count(
231262
) -> int:
232263
"""Get record count from a session, waiting until ready."""
233264
while True:
234-
scan_resp = client.get_read_session(sess_req)
265+
try:
266+
scan_resp = client.get_read_session(sess_req)
267+
except Exception as e:
268+
debug_info = _storage_debug_info(client, session_id=sess_req.session_id)
269+
logger.error(f"get_read_session failed ({debug_info}): {e!r}")
270+
raise
235271
if scan_resp.session_status == SessionStatus.INIT:
236272
time.sleep(1)
237273
continue

tzrec/datasets/odps_dataset_test.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414
import os
1515
import time
1616
import unittest
17+
from types import SimpleNamespace
18+
from unittest import mock
1719

1820
import numpy as np
1921
import pyarrow as pa
22+
import requests
2023
from odps import ODPS
24+
from odps.errors import ODPSError
2125
from parameterized import parameterized
2226
from torch import distributed as dist
2327
from torch.utils.data import DataLoader
2428

29+
from tzrec.datasets import odps_dataset
2530
from tzrec.datasets.odps_dataset import OdpsDataset, OdpsWriter, _create_odps_account
2631
from tzrec.features.feature import FgMode, create_features
2732
from tzrec.protos import data_pb2, feature_pb2, sampler_pb2
@@ -689,5 +694,107 @@ def _writer_worker(rank, port):
689694
self.assertEqual(reader.count, 1280)
690695

691696

697+
class _FailingStorageClient:
698+
"""Stub storage-api client whose calls raise a connection-reset error."""
699+
700+
table = SimpleNamespace(full_table_name="test_project.test_table")
701+
_quota_name = "test_quota"
702+
_rest_endpoint = "http://service.odps.test"
703+
704+
def _raise(self):
705+
raise requests.exceptions.ConnectionError(
706+
"Connection aborted.",
707+
ConnectionResetError(104, "Connection reset by peer"),
708+
)
709+
710+
def read_rows_arrow(self, read_req):
711+
self._raise()
712+
713+
def get_read_session(self, sess_req):
714+
self._raise()
715+
716+
717+
class _RetryingStorageClient:
718+
"""Stub whose read_rows_arrow raises ODPSError `fail_times` times, then succeeds."""
719+
720+
table = SimpleNamespace(full_table_name="test_project.test_table")
721+
_quota_name = "test_quota"
722+
_rest_endpoint = "http://service.odps.test"
723+
724+
def __init__(self, fail_times):
725+
self._fail_times = fail_times
726+
self.calls = 0
727+
self.sentinel = object()
728+
729+
def read_rows_arrow(self, read_req):
730+
self.calls += 1
731+
if self.calls <= self._fail_times:
732+
raise ODPSError("synthetic ODPS error", request_id="REQ-XYZ")
733+
return self.sentinel
734+
735+
736+
class OdpsStorageErrorLogTest(unittest.TestCase):
737+
def test_read_rows_arrow_logs_session_and_reraises(self):
738+
client = _FailingStorageClient()
739+
read_req = SimpleNamespace(
740+
session_id="sess-read-123",
741+
row_index=0,
742+
row_count=10,
743+
max_batch_rows=100,
744+
)
745+
with mock.patch.object(odps_dataset, "logger") as m_logger:
746+
with self.assertRaises(requests.exceptions.ConnectionError):
747+
odps_dataset._read_rows_arrow_with_retry(client, read_req)
748+
m_logger.error.assert_called_once()
749+
logged = m_logger.error.call_args[0][0]
750+
self.assertIn("sess-read-123", logged)
751+
self.assertIn("test_project.test_table", logged)
752+
753+
def test_get_read_session_logs_session_and_reraises(self):
754+
client = _FailingStorageClient()
755+
sess_req = SimpleNamespace(session_id="sess-scan-456")
756+
with mock.patch.object(odps_dataset, "logger") as m_logger:
757+
with self.assertRaises(requests.exceptions.ConnectionError):
758+
odps_dataset._get_session_record_count(client, sess_req)
759+
m_logger.error.assert_called_once()
760+
logged = m_logger.error.call_args[0][0]
761+
self.assertIn("sess-scan-456", logged)
762+
763+
def test_read_rows_arrow_retries_odps_error_then_succeeds(self):
764+
client = _RetryingStorageClient(fail_times=2)
765+
read_req = SimpleNamespace(
766+
session_id="sess-retry-1",
767+
row_index=0,
768+
row_count=10,
769+
max_batch_rows=100,
770+
)
771+
with mock.patch.object(odps_dataset.time, "sleep"):
772+
with mock.patch.object(odps_dataset, "logger") as m_logger:
773+
reader = odps_dataset._read_rows_arrow_with_retry(client, read_req)
774+
self.assertIs(reader, client.sentinel)
775+
self.assertEqual(client.calls, 3) # 2 failures + 1 success
776+
self.assertEqual(m_logger.warning.call_count, 2)
777+
m_logger.error.assert_not_called()
778+
779+
def test_read_rows_arrow_odps_error_exhausts_retries_and_raises(self):
780+
client = _RetryingStorageClient(fail_times=99)
781+
read_req = SimpleNamespace(
782+
session_id="sess-retry-2",
783+
row_index=5,
784+
row_count=10,
785+
max_batch_rows=100,
786+
)
787+
with mock.patch.object(odps_dataset.time, "sleep"):
788+
with mock.patch.object(odps_dataset, "logger") as m_logger:
789+
with self.assertRaises(ODPSError):
790+
odps_dataset._read_rows_arrow_with_retry(client, read_req)
791+
self.assertEqual(client.calls, 4) # initial + 3 retries
792+
self.assertEqual(m_logger.warning.call_count, 3)
793+
m_logger.error.assert_called_once()
794+
logged = m_logger.error.call_args[0][0]
795+
self.assertIn("after 3 retries", logged)
796+
self.assertIn("sess-retry-2", logged)
797+
798+
692799
if __name__ == "__main__":
693800
unittest.main()

tzrec/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
__version__ = "1.2.19"
12+
__version__ = "1.2.20"

0 commit comments

Comments
 (0)