|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import asyncio |
| 6 | +import errno |
6 | 7 | import json |
| 8 | +import shutil |
7 | 9 | import tempfile |
8 | 10 | import uuid |
9 | 11 | from pathlib import Path |
@@ -755,6 +757,79 @@ def capture_transport(*, prompt, options): |
755 | 757 | # Cleanup removed the temp dir. |
756 | 758 | assert not Path(config_dir).exists() |
757 | 759 |
|
| 760 | + @pytest.mark.asyncio |
| 761 | + async def test_custom_transport_skips_materialization( |
| 762 | + self, |
| 763 | + cwd: Path, |
| 764 | + project_key: str, |
| 765 | + isolated_home: Path, |
| 766 | + track_resume_dirs: list[Path], |
| 767 | + ) -> None: |
| 768 | + """A pre-constructed custom transport never sees the materialized |
| 769 | + options, so loading the store and writing .credentials.json to a |
| 770 | + temp dir would be wasted (and leave the access token on disk for |
| 771 | + the session lifetime). connect() must skip materialization.""" |
| 772 | + |
| 773 | + class SpyStore(InMemorySessionStore): |
| 774 | + load_calls = 0 |
| 775 | + |
| 776 | + async def load(self, key): # type: ignore[override] |
| 777 | + SpyStore.load_calls += 1 |
| 778 | + return await super().load(key) |
| 779 | + |
| 780 | + store = SpyStore() |
| 781 | + await store.append( |
| 782 | + {"project_key": project_key, "session_id": SESSION_ID}, |
| 783 | + [{"type": "user", "uuid": "u1"}], |
| 784 | + ) |
| 785 | + |
| 786 | + opts = ClaudeAgentOptions(cwd=cwd, session_store=store, resume=SESSION_ID) |
| 787 | + client = ClaudeSDKClient(options=opts, transport=_make_mock_transport()) |
| 788 | + |
| 789 | + with patch( |
| 790 | + "claude_agent_sdk._internal.query.Query.initialize", |
| 791 | + new_callable=AsyncMock, |
| 792 | + ): |
| 793 | + await client.connect() |
| 794 | + assert SpyStore.load_calls == 0 |
| 795 | + assert not track_resume_dirs # mkdtemp never ran |
| 796 | + assert client._materialized is None |
| 797 | + await client.disconnect() |
| 798 | + |
| 799 | + @pytest.mark.asyncio |
| 800 | + async def test_query_custom_transport_skips_materialization( |
| 801 | + self, |
| 802 | + cwd: Path, |
| 803 | + project_key: str, |
| 804 | + isolated_home: Path, |
| 805 | + track_resume_dirs: list[Path], |
| 806 | + ) -> None: |
| 807 | + """Same gate for the one-shot ``query()`` path.""" |
| 808 | + |
| 809 | + class SpyStore(InMemorySessionStore): |
| 810 | + load_calls = 0 |
| 811 | + |
| 812 | + async def load(self, key): # type: ignore[override] |
| 813 | + SpyStore.load_calls += 1 |
| 814 | + return await super().load(key) |
| 815 | + |
| 816 | + store = SpyStore() |
| 817 | + await store.append( |
| 818 | + {"project_key": project_key, "session_id": SESSION_ID}, |
| 819 | + [{"type": "user", "uuid": "u1"}], |
| 820 | + ) |
| 821 | + |
| 822 | + opts = ClaudeAgentOptions(cwd=cwd, session_store=store, resume=SESSION_ID) |
| 823 | + custom = _make_mock_transport() |
| 824 | + custom.connect = AsyncMock(side_effect=OSError("spawn failed")) |
| 825 | + with pytest.raises(OSError, match="spawn failed"): |
| 826 | + async for _ in query(prompt="hi", options=opts, transport=custom): |
| 827 | + pass # pragma: no cover |
| 828 | + |
| 829 | + # Gate runs before transport.connect(); materialization never happened. |
| 830 | + assert SpyStore.load_calls == 0 |
| 831 | + assert not track_resume_dirs |
| 832 | + |
758 | 833 | @pytest.mark.asyncio |
759 | 834 | async def test_connect_no_materialization_passthrough( |
760 | 835 | self, cwd: Path, isolated_home: Path |
@@ -815,6 +890,84 @@ class TestSpawnFailureCleanup: |
815 | 890 | removed even when transport.connect() raises before any try/finally that |
816 | 891 | normally guards cleanup.""" |
817 | 892 |
|
| 893 | + @pytest.mark.asyncio |
| 894 | + async def test_cleanup_retries_on_transient_os_error( |
| 895 | + self, |
| 896 | + cwd: Path, |
| 897 | + project_key: str, |
| 898 | + isolated_home: Path, |
| 899 | + monkeypatch: pytest.MonkeyPatch, |
| 900 | + ) -> None: |
| 901 | + """Windows AV/indexer can briefly hold ``.credentials.json`` open; |
| 902 | + ``cleanup()`` must retry rmtree on EPERM/EBUSY so the access token |
| 903 | + doesn't leak in temp.""" |
| 904 | + store = InMemorySessionStore() |
| 905 | + await store.append( |
| 906 | + {"project_key": project_key, "session_id": SESSION_ID}, |
| 907 | + [{"type": "user", "uuid": "u1"}], |
| 908 | + ) |
| 909 | + opts = ClaudeAgentOptions(cwd=cwd, session_store=store, resume=SESSION_ID) |
| 910 | + m = await materialize_resume_session(opts) |
| 911 | + assert m is not None |
| 912 | + config_dir = m.config_dir |
| 913 | + |
| 914 | + calls: list[Any] = [] |
| 915 | + real_rmtree = shutil.rmtree |
| 916 | + |
| 917 | + def fake_rmtree(p: Any, **kw: Any) -> None: |
| 918 | + calls.append((p, kw)) |
| 919 | + if len(calls) <= 2 and not kw.get("ignore_errors"): |
| 920 | + raise PermissionError(errno.EPERM, "held by indexer") |
| 921 | + if Path(p).exists(): |
| 922 | + real_rmtree(p, **kw) |
| 923 | + |
| 924 | + monkeypatch.setattr(shutil, "rmtree", fake_rmtree) |
| 925 | + await m.cleanup() |
| 926 | + |
| 927 | + assert not config_dir.exists() |
| 928 | + assert len(calls) >= 3 # 2 failures + 1 success |
| 929 | + |
| 930 | + @pytest.mark.asyncio |
| 931 | + async def test_failure_path_retries_rmtree( |
| 932 | + self, |
| 933 | + cwd: Path, |
| 934 | + project_key: str, |
| 935 | + isolated_home: Path, |
| 936 | + monkeypatch: pytest.MonkeyPatch, |
| 937 | + track_resume_dirs: list[Path], |
| 938 | + ) -> None: |
| 939 | + """The except-BaseException cleanup path also retries on EPERM.""" |
| 940 | + |
| 941 | + class FailLateStore(InMemorySessionStore): |
| 942 | + async def list_subkeys(self, key): # type: ignore[override] |
| 943 | + raise OSError("boom") |
| 944 | + |
| 945 | + store = FailLateStore() |
| 946 | + await store.append( |
| 947 | + {"project_key": project_key, "session_id": SESSION_ID}, |
| 948 | + [{"type": "user", "uuid": "u1"}], |
| 949 | + ) |
| 950 | + |
| 951 | + calls: list[Any] = [] |
| 952 | + real_rmtree = shutil.rmtree |
| 953 | + |
| 954 | + def fake_rmtree(p: Any, **kw: Any) -> None: |
| 955 | + calls.append((p, kw)) |
| 956 | + if len(calls) <= 2 and not kw.get("ignore_errors"): |
| 957 | + raise PermissionError(errno.EPERM, "held by indexer") |
| 958 | + if Path(p).exists(): |
| 959 | + real_rmtree(p, **kw) |
| 960 | + |
| 961 | + monkeypatch.setattr(shutil, "rmtree", fake_rmtree) |
| 962 | + |
| 963 | + opts = ClaudeAgentOptions(cwd=cwd, session_store=store, resume=SESSION_ID) |
| 964 | + with pytest.raises(RuntimeError, match="boom"): |
| 965 | + await materialize_resume_session(opts) |
| 966 | + |
| 967 | + assert track_resume_dirs |
| 968 | + assert not track_resume_dirs[0].exists() |
| 969 | + assert len(calls) >= 3 |
| 970 | + |
818 | 971 | @pytest.mark.asyncio |
819 | 972 | async def test_client_connect_failure_removes_temp_dir( |
820 | 973 | self, |
|
0 commit comments