Skip to content

Commit 68273e9

Browse files
committed
Close owned agent clients on exit
* Track AgentClient ownership for public and internal Polyaxon clients.
* Close owned clients from sync and async agent exit hooks to avoid leaking async SDK sessions.
1 parent 2d2a556 commit 68273e9

7 files changed

Lines changed: 214 additions & 17 deletions

File tree

cli/polyaxon/_runner/agent/async_agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ async def __aenter__(self):
6161
return await self._enter()
6262

6363
async def __aexit__(self, exc_type, exc_val, exc_tb):
64-
await self._exit()
64+
try:
65+
await self._exit()
66+
finally:
67+
await self.client.aclose()
6568

6669
async def refresh_executor(self):
6770
if (

cli/polyaxon/_runner/agent/base_agent.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from polyaxon._auxiliaries import V1PolyaxonInitContainer, V1PolyaxonSidecarContainer
99
from polyaxon._connections import V1Connection
1010
from polyaxon._constants.globals import DEFAULT
11-
from polyaxon._runner.agent.client import AgentClient
11+
from polyaxon._runner.agent.client import AgentClient, AsyncAgentClient
1212
from polyaxon._runner.executor import BaseExecutor
1313
from polyaxon._schemas.checks import ChecksConfig
1414
from polyaxon._schemas.lifecycle import LiveState, V1Statuses
@@ -44,9 +44,8 @@ def __init__(
4444
self._graceful_shutdown = False
4545
self._last_data_collected_at = last_hour
4646
self._last_reconciled_at = last_hour
47-
self.client = AgentClient(
48-
owner=owner, agent_uuid=agent_uuid, is_async=self.IS_ASYNC
49-
)
47+
agent_client_cls = AsyncAgentClient if self.IS_ASYNC else AgentClient
48+
self.client = agent_client_cls(owner=owner, agent_uuid=agent_uuid)
5049
self.executor = self.EXECUTOR()
5150
self.content = settings.AGENT_CONFIG.to_json()
5251

cli/polyaxon/_runner/agent/client.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,39 +3,57 @@
33

44
from polyaxon._schemas.lifecycle import V1StatusCondition, V1Statuses
55
from polyaxon.client import PolyaxonClient, V1Agent, V1AgentStateResponse
6+
from polyaxon.exceptions import PolyaxonClientException
67
from polyaxon.logger import logger
78

89

9-
class AgentClient:
10+
class _AgentClientBase:
11+
_IS_ASYNC = False
12+
1013
def __init__(
1114
self,
1215
owner: Optional[str] = None,
1316
agent_uuid: Optional[str] = None,
1417
client: Optional[PolyaxonClient] = None,
1518
internal_client: Optional[PolyaxonClient] = None,
16-
is_async: bool = False,
1719
):
1820
self.owner = owner
1921
self.agent_uuid = agent_uuid
20-
self.is_async = is_async
22+
self._validate_client_mode(client, "client")
23+
self._validate_client_mode(internal_client, "internal_client")
2124
self._client = client
2225
self._internal_client = internal_client
26+
self._created_client = None
27+
self._created_internal_client = None
28+
29+
def _validate_client_mode(self, client: Optional[PolyaxonClient], name: str):
30+
if client is None:
31+
return
32+
is_async = getattr(client, "is_async", None)
33+
if isinstance(is_async, bool) and is_async != self._IS_ASYNC:
34+
raise PolyaxonClientException(
35+
"Injected `{}` transport mode does not match AgentClient mode.".format(
36+
name
37+
)
38+
)
2339

2440
@property
2541
def client(self):
2642
if self._client:
2743
return self._client
28-
self._client = PolyaxonClient(is_async=self.is_async)
44+
self._client = PolyaxonClient(is_async=self._IS_ASYNC)
45+
self._created_client = self._client
2946
return self._client
3047

3148
@property
3249
def internal_client(self):
3350
if self._internal_client:
3451
return self._internal_client
3552
self._internal_client = PolyaxonClient(
36-
is_async=self.is_async,
53+
is_async=self._IS_ASYNC,
3754
is_internal=True,
3855
)
56+
self._created_internal_client = self._internal_client
3957
return self._internal_client
4058

4159
@property
@@ -181,3 +199,21 @@ def log_run_status(
181199
uuid=run_uuid,
182200
body={"condition": status_condition},
183201
)
202+
203+
204+
class AgentClient(_AgentClientBase):
205+
def close(self):
206+
if self._created_client is not None:
207+
self._created_client.close()
208+
if self._created_internal_client is not None:
209+
self._created_internal_client.close()
210+
211+
212+
class AsyncAgentClient(_AgentClientBase):
213+
_IS_ASYNC = True
214+
215+
async def aclose(self):
216+
if self._created_client is not None:
217+
await self._created_client.aclose()
218+
if self._created_internal_client is not None:
219+
await self._created_internal_client.aclose()

cli/polyaxon/_runner/agent/sync_agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ def __enter__(self):
6060
return self._enter()
6161

6262
def __exit__(self, exc_type, exc_val, exc_tb):
63-
self._exit()
63+
try:
64+
self._exit()
65+
finally:
66+
self.client.close()
6467

6568
def refresh_executor(self):
6669
if (

cli/tests/test_k8s/test_async_agent.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from polyaxon._k8s.agent.async_agent import AsyncAgent
55
from polyaxon._k8s.executor.async_executor import AsyncExecutor
6-
from polyaxon._runner.agent.client import AgentClient
6+
from polyaxon._runner.agent.client import AsyncAgentClient
77
from polyaxon._utils.test_utils import AsyncMock, patch_settings
88

99

@@ -18,7 +18,7 @@ async def test_init_agent_component(register):
1818
agent = AsyncAgent(owner="foo", agent_uuid="uuid")
1919
assert agent.max_interval == 6
2020
assert isinstance(agent.executor, AsyncExecutor)
21-
assert isinstance(agent.client, AgentClient)
21+
assert isinstance(agent.client, AsyncAgentClient)
2222
assert register.call_count == 0
2323

2424

@@ -43,7 +43,7 @@ async def test_init_agent(
4343
agent.executor.manager.get_version.return_value = {}
4444
assert agent.max_interval == 6
4545
assert agent.executor is not None
46-
assert isinstance(agent.client, AgentClient)
46+
assert isinstance(agent.client, AsyncAgentClient)
4747
assert get_agent.call_count == 0
4848
assert get_agent_state.call_count == 0
4949
assert create_agent_status.call_count == 0
@@ -53,9 +53,24 @@ async def test_init_agent(
5353
await agent._enter()
5454
assert agent.max_interval == 6
5555
assert agent.executor is not None
56-
assert isinstance(agent.client, AgentClient)
56+
assert isinstance(agent.client, AsyncAgentClient)
5757
assert get_agent.call_count == 1
5858
assert get_agent_state.call_count == 0
5959
assert create_agent_status.call_count == 1
6060
assert sync_agent.call_count == 1
6161
assert agent.executor.manager.get_version.call_count == 1
62+
63+
64+
@pytest.mark.agent_mark
65+
@pytest.mark.asyncio
66+
async def test_async_agent_aexit_closes_client_in_finally():
67+
patch_settings()
68+
agent = AsyncAgent(owner="foo", agent_uuid="uuid")
69+
agent.client = MagicMock()
70+
agent.client.aclose = AsyncMock()
71+
agent._exit = AsyncMock(side_effect=RuntimeError("exit failed"))
72+
73+
with pytest.raises(RuntimeError):
74+
await agent.__aexit__(None, None, None)
75+
76+
agent.client.aclose.assert_called_once()

(newly added file; its path was not captured in this extract)

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
from mock import MagicMock, patch
2+
import pytest
3+
4+
from polyaxon._runner.agent.client import AgentClient, AsyncAgentClient
5+
from polyaxon.exceptions import PolyaxonClientException
6+
7+
8+
pytestmark = pytest.mark.agent_mark
9+
10+
11+
class ClientMock:
12+
def __init__(self, is_async: bool):
13+
self.is_async = is_async
14+
self.agents_v1 = MagicMock()
15+
self.runs_v1 = MagicMock()
16+
self.close_calls = 0
17+
self.aclose_calls = 0
18+
19+
def close(self):
20+
self.close_calls += 1
21+
22+
async def aclose(self):
23+
self.aclose_calls += 1
24+
25+
26+
@patch("polyaxon._runner.agent.client.PolyaxonClient")
27+
def test_agent_client_close_closes_owned_clients(client_cls):
28+
public_client = ClientMock(is_async=False)
29+
internal_client = ClientMock(is_async=False)
30+
client_cls.side_effect = [public_client, internal_client]
31+
client = AgentClient(owner="foo", agent_uuid="uuid")
32+
33+
assert client.client is public_client
34+
assert client.internal_client is internal_client
35+
36+
client.close()
37+
38+
assert public_client.close_calls == 1
39+
assert internal_client.close_calls == 1
40+
41+
42+
@pytest.mark.asyncio
43+
@patch("polyaxon._runner.agent.client.PolyaxonClient")
44+
async def test_agent_client_aclose_closes_owned_clients(client_cls):
45+
public_client = ClientMock(is_async=True)
46+
internal_client = ClientMock(is_async=True)
47+
client_cls.side_effect = [public_client, internal_client]
48+
client = AsyncAgentClient(owner="foo", agent_uuid="uuid")
49+
50+
assert client.client is public_client
51+
assert client.internal_client is internal_client
52+
53+
await client.aclose()
54+
55+
assert public_client.aclose_calls == 1
56+
assert internal_client.aclose_calls == 1
57+
58+
59+
def test_agent_client_close_does_not_close_injected_clients():
60+
public_client = ClientMock(is_async=False)
61+
internal_client = ClientMock(is_async=False)
62+
client = AgentClient(
63+
owner="foo",
64+
agent_uuid="uuid",
65+
client=public_client,
66+
internal_client=internal_client,
67+
)
68+
69+
client.close()
70+
71+
assert public_client.close_calls == 0
72+
assert internal_client.close_calls == 0
73+
74+
75+
@pytest.mark.asyncio
76+
async def test_agent_client_aclose_does_not_close_injected_clients():
77+
public_client = ClientMock(is_async=True)
78+
internal_client = ClientMock(is_async=True)
79+
client = AsyncAgentClient(
80+
owner="foo",
81+
agent_uuid="uuid",
82+
client=public_client,
83+
internal_client=internal_client,
84+
)
85+
86+
await client.aclose()
87+
88+
assert public_client.aclose_calls == 0
89+
assert internal_client.aclose_calls == 0
90+
91+
92+
def test_async_agent_client_does_not_expose_sync_close():
93+
client = AsyncAgentClient(owner="foo", agent_uuid="uuid")
94+
95+
assert not hasattr(client, "close")
96+
97+
98+
def test_agent_client_does_not_expose_async_close():
99+
client = AgentClient(owner="foo", agent_uuid="uuid")
100+
101+
assert not hasattr(client, "aclose")
102+
103+
104+
def test_agent_client_rejects_mode_mismatch_public_client():
105+
with pytest.raises(PolyaxonClientException):
106+
AsyncAgentClient(
107+
owner="foo",
108+
agent_uuid="uuid",
109+
client=ClientMock(is_async=False),
110+
)
111+
112+
113+
def test_agent_client_rejects_mode_mismatch_internal_client():
114+
with pytest.raises(PolyaxonClientException):
115+
AgentClient(
116+
owner="foo",
117+
agent_uuid="uuid",
118+
internal_client=ClientMock(is_async=True),
119+
)
120+
121+
122+
def test_agent_client_allows_loose_mocks_without_mode_attribute():
123+
client = AgentClient(
124+
owner="foo",
125+
agent_uuid="uuid",
126+
client=MagicMock(),
127+
internal_client=MagicMock(),
128+
)
129+
130+
assert client.client is not None
131+
assert client.internal_client is not None

cli/tests/test_runner/test_base_agent.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33

44
from polyaxon._constants.globals import DEFAULT
5-
from polyaxon._runner.agent.client import AgentClient
5+
from polyaxon._runner.agent.client import AgentClient, AsyncAgentClient
66
from polyaxon._runner.agent.sync_agent import BaseSyncAgent
77
from polyaxon._utils.test_utils import BaseTestCase
88

@@ -148,8 +148,18 @@ def test_agent_client_uses_internal_client_for_collect_agent_data(self):
148148

149149
@patch("polyaxon._runner.agent.client.PolyaxonClient")
150150
def test_agent_client_creates_internal_client_with_internal_mode(self, client_cls):
151-
client = AgentClient(owner="foo", agent_uuid="uuid", is_async=True)
151+
client = AsyncAgentClient(owner="foo", agent_uuid="uuid")
152152

153153
_ = client.internal_client
154154

155155
client_cls.assert_called_once_with(is_async=True, is_internal=True)
156+
157+
def test_sync_agent_exit_closes_client_in_finally(self):
158+
agent = DummyAgent(owner="foo", agent_uuid="uuid")
159+
agent.client = MagicMock()
160+
agent._exit = MagicMock(side_effect=RuntimeError("exit failed"))
161+
162+
with pytest.raises(RuntimeError):
163+
agent.__exit__(None, None, None)
164+
165+
agent.client.close.assert_called_once()

0 commit comments

Comments
 (0)