Skip to content

Commit dde26f1

Browse files
authored
python(feat): Add better report creation and wait_until_complete to SiftClient (#475)
1 parent c03f4c8 commit dde26f1

7 files changed

Lines changed: 590 additions & 102 deletions

File tree

python/lib/sift_client/_internal/low_level_wrappers/rules.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from sift.rules.v1.rules_pb2_grpc import RuleServiceStub
5454

5555
from sift_client._internal.low_level_wrappers.base import DEFAULT_PAGE_SIZE, LowLevelClientBase
56-
from sift_client._internal.low_level_wrappers.reports import ReportsLowLevelClient
56+
from sift_client._internal.low_level_wrappers.jobs import JobsLowLevelClient
5757
from sift_client._internal.util.timestamp import to_pb_timestamp
5858
from sift_client._internal.util.util import count_non_none
5959
from sift_client.sift_types.rule import (
@@ -69,7 +69,7 @@
6969
from datetime import datetime
7070

7171
from sift_client.sift_types.channel import ChannelReference
72-
from sift_client.sift_types.report import Report
72+
from sift_client.sift_types.job import Job
7373

7474
# Configure logging
7575
logger = logging.getLogger(__name__)
@@ -587,8 +587,8 @@ async def evaluate_rules(
587587
report_name: str | None = None,
588588
tags: list[str | Tag] | None = None,
589589
organization_id: str | None = None,
590-
) -> tuple[int, Report | None, str | None]:
591-
"""Evaluate a rule.
590+
) -> tuple[int, str | None, Job | None]:
591+
"""Evaluate rules.
592592
593593
Args:
594594
run_id: The run ID to evaluate.
@@ -604,7 +604,7 @@ async def evaluate_rules(
604604
organization_id: The organization ID to evaluate.
605605
606606
Returns:
607-
The result of the rule execution.
607+
The annotation_count, report_id, and job for the pending report.
608608
"""
609609
if count_non_none(run_id, asset_ids) > 1:
610610
raise ValueError(
@@ -664,13 +664,13 @@ async def evaluate_rules(
664664
request
665665
)
666666
response = cast("EvaluateRulesResponse", response)
667-
created_annotation_count = response.created_annotation_count
668-
report_id = response.report_id
669667
job_id = response.job_id
670-
if report_id:
671-
report = await ReportsLowLevelClient(self._grpc_client).get_report(report_id=report_id)
672-
return created_annotation_count, report, job_id
673-
return created_annotation_count, None, job_id
668+
669+
if job_id:
670+
job = await JobsLowLevelClient(self._grpc_client).get_job(job_id=job_id)
671+
else:
672+
job = None
673+
return response.created_annotation_count, response.report_id, job
674674

675675
async def get_rule_version(self, rule_version_id: str) -> Rule:
676676
"""Get a rule at a specific version by rule_version_id.

python/lib/sift_client/_tests/resources/test_jobs.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
from datetime import datetime, timedelta, timezone
11+
from unittest.mock import AsyncMock, MagicMock, patch
1112

1213
import pytest
1314
from grpc.aio import AioRpcError
@@ -291,6 +292,106 @@ async def test_retry_finished_job_no_effect(self, jobs_api_async):
291292
with pytest.raises(AioRpcError, match="job cannot be retried"):
292293
await jobs_api_async.retry(job)
293294

295+
class TestWaitUntilComplete:
296+
"""Tests for the async wait_until_complete method."""
297+
298+
@pytest.mark.asyncio
299+
async def test_returns_immediately_when_job_already_complete(self, jobs_api_async):
300+
"""When get returns a completed job on first call, wait returns immediately."""
301+
job_id = "test-job-id"
302+
mock_job = MagicMock()
303+
mock_job.job_status = JobStatus.FINISHED
304+
305+
with patch(
306+
"sift_client.resources.jobs.JobsAPIAsync.get",
307+
new_callable=AsyncMock,
308+
return_value=mock_job,
309+
) as mock_get:
310+
result = await jobs_api_async.wait_until_complete(job=job_id)
311+
312+
assert result is mock_job
313+
assert result.job_status == JobStatus.FINISHED
314+
mock_get.assert_called_once_with(job_id)
315+
316+
@pytest.mark.asyncio
317+
async def test_returns_immediately_when_job_already_failed(self, jobs_api_async):
318+
"""When get returns a failed job on first call, wait returns immediately."""
319+
job_id = "test-job-id"
320+
mock_job = MagicMock()
321+
mock_job.job_status = JobStatus.FAILED
322+
323+
with patch(
324+
"sift_client.resources.jobs.JobsAPIAsync.get",
325+
new_callable=AsyncMock,
326+
return_value=mock_job,
327+
) as mock_get:
328+
result = await jobs_api_async.wait_until_complete(job=job_id)
329+
330+
assert result is mock_job
331+
assert result.job_status == JobStatus.FAILED
332+
mock_get.assert_called_once_with(job_id)
333+
334+
@pytest.mark.asyncio
335+
async def test_returns_immediately_when_job_already_cancelled(self, jobs_api_async):
336+
"""When get returns a cancelled job on first call, wait returns immediately."""
337+
job_id = "test-job-id"
338+
mock_job = MagicMock()
339+
mock_job.job_status = JobStatus.CANCELLED
340+
341+
with patch(
342+
"sift_client.resources.jobs.JobsAPIAsync.get",
343+
new_callable=AsyncMock,
344+
return_value=mock_job,
345+
) as mock_get:
346+
result = await jobs_api_async.wait_until_complete(job=job_id)
347+
348+
assert result is mock_job
349+
assert result.job_status == JobStatus.CANCELLED
350+
mock_get.assert_called_once_with(job_id)
351+
352+
@pytest.mark.asyncio
353+
async def test_polls_until_complete(self, jobs_api_async):
354+
"""When get returns running then finished, wait returns after second poll."""
355+
job_id = "test-job-id"
356+
running_job = MagicMock()
357+
running_job.job_status = JobStatus.RUNNING
358+
finished_job = MagicMock()
359+
finished_job.job_status = JobStatus.FINISHED
360+
361+
with patch(
362+
"sift_client.resources.jobs.JobsAPIAsync.get",
363+
new_callable=AsyncMock,
364+
side_effect=[running_job, finished_job],
365+
) as mock_get:
366+
result = await jobs_api_async.wait_until_complete(
367+
job=job_id,
368+
polling_interval_secs=0.01,
369+
timeout_secs=10.0,
370+
)
371+
372+
assert result is finished_job
373+
assert result.job_status == JobStatus.FINISHED
374+
assert mock_get.call_count == 2
375+
376+
@pytest.mark.asyncio
377+
async def test_raises_timeout_error_when_not_complete_in_time(self, jobs_api_async):
378+
"""When job never reaches a completed state, TimeoutError is raised."""
379+
job_id = "test-job-id"
380+
running_job = MagicMock()
381+
running_job.job_status = JobStatus.RUNNING
382+
383+
with patch(
384+
"sift_client.resources.jobs.JobsAPIAsync.get",
385+
new_callable=AsyncMock,
386+
return_value=running_job,
387+
):
388+
with pytest.raises(TimeoutError):
389+
await jobs_api_async.wait_until_complete(
390+
job=job_id,
391+
polling_interval_secs=0.05,
392+
timeout_secs=0.1,
393+
)
394+
294395
class TestJobProperties:
295396
"""Tests for job property methods."""
296397

0 commit comments

Comments
 (0)