Skip to content

Commit 93d0a43

Browse files
committed
Add async run logs and status polling
* Add AsyncRunClient support for get_logs, inspect, wait_for_condition, and watch_statuses. * Share status wait condition helpers while keeping sync and async polling control flow explicit.
1 parent 68273e9 commit 93d0a43

2 files changed

Lines changed: 184 additions & 19 deletions

File tree

cli/polyaxon/_client/run.py

Lines changed: 76 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -999,23 +999,25 @@ def get_statuses(
999999
except (ApiException, HTTPError) as e:
10001000
raise PolyaxonClientException("Api error: %s" % e) from e
10011001

1002-
def _wait_for_condition(self, statuses: Optional[List[str]] = None):
1003-
statuses = to_list(statuses, check_none=True)
1002+
def _should_stop_waiting(self, last_status, statuses) -> bool:
1003+
if statuses:
1004+
return last_status in statuses
1005+
return LifeCycle.is_done(last_status)
10041006

1005-
def condition():
1006-
if statuses:
1007-
return last_status in statuses
1008-
return LifeCycle.is_done(last_status)
1007+
def _is_retryable_status_error(self, error: ApiException) -> bool:
1008+
return error.status in {500, 502, 503, 504}
10091009

1010+
def _wait_for_condition(self, statuses: Optional[List[str]] = None):
1011+
statuses = to_list(statuses, check_none=True)
10101012
last_status = None
1011-
while not condition():
1013+
while not self._should_stop_waiting(last_status, statuses):
10121014
if last_status:
10131015
time.sleep(settings.CLIENT_CONFIG.watch_interval)
10141016
try:
10151017
last_status, _conditions = self.get_statuses(last_status)
10161018
yield last_status, _conditions
10171019
except ApiException as e:
1018-
if e.status in {500, 502, 503, 504}:
1020+
if self._is_retryable_status_error(e):
10191021
yield last_status, []
10201022
else:
10211023
raise e
@@ -3435,25 +3437,82 @@ async def get_statuses(
34353437
except ApiException as e:
34363438
raise PolyaxonClientException("Api error: %s" % e) from e
34373439

3438-
@async_client_handler(check_no_op=True, check_offline=True)
3439-
async def wait_for_condition(self, *args, **kwargs):
3440-
self._raise_sync_only("wait_for_condition")
3440+
async def _wait_for_condition(self, statuses: Optional[List[str]] = None):
3441+
statuses = to_list(statuses, check_none=True)
3442+
last_status = None
3443+
while not self._should_stop_waiting(last_status, statuses):
3444+
if last_status:
3445+
await asyncio.sleep(settings.CLIENT_CONFIG.watch_interval)
3446+
try:
3447+
last_status, conditions = await self.get_statuses(last_status)
3448+
yield last_status, conditions
3449+
except ApiException as e:
3450+
if self._is_retryable_status_error(e):
3451+
yield last_status, []
3452+
else:
3453+
raise e
34413454

34423455
@async_client_handler(check_no_op=True, check_offline=True)
3443-
async def watch_statuses(self, *args, **kwargs):
3444-
self._raise_sync_only("watch_statuses")
3456+
async def wait_for_condition(
3457+
self,
3458+
statuses: Optional[List[str]] = None,
3459+
print_status: bool = False,
3460+
live_update: Any = None,
3461+
):
3462+
async for status, _conditions in self._wait_for_condition(statuses):
3463+
self._run_data.status = status # type: ignore
3464+
if print_status:
3465+
print("Last received status: {}\n".format(status))
3466+
if live_update:
3467+
latest_status = Printer.add_status_color(
3468+
{"status": status}, status_key="status"
3469+
)
3470+
live_update.update(status="{}\n".format(latest_status["status"]))
3471+
3472+
async def watch_statuses(self, statuses: Optional[List[str]] = None):
3473+
if self._no_op or self._is_offline:
3474+
return
3475+
async for status, conditions in self._wait_for_condition(statuses):
3476+
self._run_data.status = status # type: ignore
3477+
yield status, conditions
34453478

34463479
@async_client_handler(check_no_op=True, check_offline=True)
3447-
async def get_logs(self, *args, **kwargs):
3448-
self._raise_sync_only("get_logs")
3480+
async def get_logs(self, last_file=None, last_time=None) -> "V1Logs":
3481+
if not self.settings:
3482+
await self.refresh_data()
3483+
await self._use_agent_host()
3484+
params = get_logs_params(
3485+
last_file=last_file,
3486+
last_time=last_time,
3487+
connection=self.artifacts_store,
3488+
kind=self.run_data.kind,
3489+
)
3490+
return await self.client.runs_v1.get_run_logs(
3491+
self.namespace,
3492+
self.owner,
3493+
self.project,
3494+
self.run_uuid,
3495+
**params,
3496+
)
34493497

34503498
@async_client_handler(check_no_op=True, check_offline=True)
34513499
async def watch_logs(self, *args, **kwargs):
34523500
self._raise_sync_only("watch_logs")
34533501

34543502
@async_client_handler(check_no_op=True, check_offline=True)
3455-
async def inspect(self, *args, **kwargs):
3456-
self._raise_sync_only("inspect")
3503+
async def inspect(self):
3504+
if not self.settings:
3505+
await self.refresh_data()
3506+
await self._use_agent_host()
3507+
params = get_streams_params(connection=self.artifacts_store, status=self.status)
3508+
return await self.client.runs_v1.inspect_run(
3509+
self.namespace,
3510+
self.owner,
3511+
self.project,
3512+
self.run_uuid,
3513+
self.run_data.kind,
3514+
**params,
3515+
)
34573516

34583517
@async_client_handler(check_no_op=True, check_offline=True)
34593518
async def shell(self, *args, **kwargs):

cli/tests/test_client/test_async_run_client.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from mock import mock
33
import pytest
44

5+
from polyaxon import settings
56
from polyaxon._client.run import AsyncRunClient, RunClient
67
from polyaxon._schemas.lifecycle import (
78
V1ProjectVersionKind,
@@ -162,7 +163,10 @@ def test_async_run_client_rejects_sync_client():
162163
def test_async_run_client_method_surface_is_async():
163164
for method in ASYNC_METHODS:
164165
assert method in AsyncRunClient.__dict__
165-
assert inspect.iscoroutinefunction(getattr(AsyncRunClient, method))
166+
if method == "watch_statuses":
167+
assert inspect.isasyncgenfunction(getattr(AsyncRunClient, method))
168+
else:
169+
assert inspect.iscoroutinefunction(getattr(AsyncRunClient, method))
166170

167171

168172
def test_async_run_client_local_helpers_stay_shared():
@@ -329,6 +333,109 @@ async def test_status_methods_await_api_without_async_req():
329333
assert "async_req" not in sdk_client.runs_v1.create_run_status.call_args[1]
330334

331335

336+
@pytest.mark.asyncio
337+
async def test_wait_for_condition_updates_status():
338+
patch_settings()
339+
sdk_client = AsyncPolyaxonClientMock()
340+
sdk_client.runs_v1.get_run_statuses = AsyncMock(
341+
return_value=mock.Mock(
342+
status=V1Statuses.RUNNING,
343+
status_conditions=[
344+
V1StatusCondition.model_construct(type=V1Statuses.RUNNING)
345+
],
346+
)
347+
)
348+
client = make_client(sdk_client)
349+
350+
await client.wait_for_condition(statuses=[V1Statuses.RUNNING])
351+
352+
assert client.status == V1Statuses.RUNNING
353+
sdk_client.runs_v1.get_run_statuses.assert_called_once_with(
354+
OWNER,
355+
PROJECT,
356+
RUN_UUID,
357+
)
358+
359+
360+
@pytest.mark.asyncio
361+
async def test_watch_statuses_yields_until_done_status():
362+
patch_settings()
363+
settings.CLIENT_CONFIG.watch_interval = 0
364+
sdk_client = AsyncPolyaxonClientMock()
365+
running_condition = V1StatusCondition.model_construct(type=V1Statuses.RUNNING)
366+
succeeded_condition = V1StatusCondition.model_construct(type=V1Statuses.SUCCEEDED)
367+
sdk_client.runs_v1.get_run_statuses = AsyncMock(
368+
side_effect=[
369+
mock.Mock(
370+
status=V1Statuses.RUNNING,
371+
status_conditions=[running_condition],
372+
),
373+
mock.Mock(
374+
status=V1Statuses.SUCCEEDED,
375+
status_conditions=[running_condition, succeeded_condition],
376+
),
377+
]
378+
)
379+
client = make_client(sdk_client)
380+
381+
statuses = []
382+
async for status, _conditions in client.watch_statuses():
383+
statuses.append(status)
384+
385+
assert statuses == [V1Statuses.RUNNING, V1Statuses.SUCCEEDED]
386+
assert client.status == V1Statuses.SUCCEEDED
387+
assert sdk_client.runs_v1.get_run_statuses.call_count == 2
388+
389+
390+
@pytest.mark.asyncio
391+
async def test_get_logs_awaits_refresh_and_api():
392+
patch_settings()
393+
sdk_client = AsyncPolyaxonClientMock()
394+
response = mock.Mock()
395+
sdk_client.runs_v1.get_run = AsyncMock(return_value=make_run())
396+
sdk_client.runs_v1.get_run_logs = AsyncMock(return_value=response)
397+
client = make_client(sdk_client)
398+
client._run_data.settings = None
399+
400+
result = await client.get_logs(last_file="last.log", last_time="123")
401+
402+
assert result is response
403+
sdk_client.runs_v1.get_run.assert_called_once_with(OWNER, PROJECT, RUN_UUID)
404+
assert sdk_client.runs_v1.get_run_logs.call_args[0] == (
405+
"test-namespace",
406+
OWNER,
407+
PROJECT,
408+
RUN_UUID,
409+
)
410+
assert "async_req" not in sdk_client.runs_v1.get_run_logs.call_args[1]
411+
412+
413+
@pytest.mark.asyncio
414+
async def test_inspect_awaits_refresh_and_api():
415+
patch_settings()
416+
sdk_client = AsyncPolyaxonClientMock()
417+
response = {"pods": {}}
418+
sdk_client.runs_v1.get_run = AsyncMock(
419+
return_value=make_run(status=V1Statuses.RUNNING)
420+
)
421+
sdk_client.runs_v1.inspect_run = AsyncMock(return_value=response)
422+
client = make_client(sdk_client)
423+
client._run_data.settings = None
424+
425+
result = await client.inspect()
426+
427+
assert result is response
428+
sdk_client.runs_v1.get_run.assert_called_once_with(OWNER, PROJECT, RUN_UUID)
429+
assert sdk_client.runs_v1.inspect_run.call_args[0] == (
430+
"test-namespace",
431+
OWNER,
432+
PROJECT,
433+
RUN_UUID,
434+
None,
435+
)
436+
assert "async_req" not in sdk_client.runs_v1.inspect_run.call_args[1]
437+
438+
332439
@pytest.mark.asyncio
333440
async def test_events_metrics_and_lineage_methods_await_api():
334441
patch_settings()
@@ -504,7 +611,6 @@ async def test_promote_methods_use_async_project_client_with_injected_client():
504611
@pytest.mark.parametrize(
505612
"method,args",
506613
[
507-
("get_logs", ()),
508614
("upload_artifact", ("file.txt",)),
509615
("download_artifacts", ()),
510616
("persist_run", ("/tmp/run",)),

0 commit comments

Comments
 (0)