Skip to content

Commit 0d584bc

Browse files
committed
Accept Pydantic models as alternatives to dicts in resource client methods
1 parent 723ec6e commit 0d584bc

File tree

8 files changed

+244
-113
lines changed

8 files changed

+244
-113
lines changed

src/apify_client/_resource_clients/actor.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,11 @@
2323
RunOrigin,
2424
RunResponse,
2525
UpdateActorRequest,
26+
WebhookCreate,
2627
)
2728
from apify_client._resource_clients._resource_client import ResourceClient, ResourceClientAsync
28-
from apify_client._utils import (
29-
encode_key_value_store_record_value,
30-
encode_webhook_list_to_base64,
31-
response_to_dict,
32-
to_seconds,
33-
)
29+
from apify_client._types import WebhookRepresentationList
30+
from apify_client._utils import encode_key_value_store_record_value, response_to_dict, to_seconds
3431

3532
if TYPE_CHECKING:
3633
from datetime import timedelta
@@ -232,7 +229,7 @@ def start(
232229
run_timeout: timedelta | None = None,
233230
force_permission_level: ActorPermissionLevel | None = None,
234231
wait_for_finish: int | None = None,
235-
webhooks: list[dict] | None = None,
232+
webhooks: list[dict | WebhookCreate] | None = None,
236233
timeout: Timeout = 'long',
237234
) -> Run:
238235
"""Start the Actor and immediately return the Run object.
@@ -271,6 +268,10 @@ def start(
271268
"""
272269
run_input, content_type = encode_key_value_store_record_value(run_input, content_type)
273270

271+
validated_webhooks = (
272+
[WebhookCreate.model_validate(w) if isinstance(w, dict) else w for w in webhooks] if webhooks else []
273+
)
274+
274275
request_params = self._build_params(
275276
build=build,
276277
maxItems=max_items,
@@ -280,7 +281,7 @@ def start(
280281
timeout=to_seconds(run_timeout, as_int=True),
281282
waitForFinish=wait_for_finish,
282283
forcePermissionLevel=force_permission_level.value if force_permission_level is not None else None,
283-
webhooks=encode_webhook_list_to_base64(webhooks) if webhooks is not None else None,
284+
webhooks=WebhookRepresentationList.from_webhooks(validated_webhooks).to_base64(),
284285
)
285286

286287
response = self._http_client.call(
@@ -306,7 +307,7 @@ def call(
306307
restart_on_error: bool | None = None,
307308
memory_mbytes: int | None = None,
308309
run_timeout: timedelta | None = None,
309-
webhooks: list[dict] | None = None,
310+
webhooks: list[dict | WebhookCreate] | None = None,
310311
force_permission_level: ActorPermissionLevel | None = None,
311312
wait_duration: timedelta | None = None,
312313
logger: Logger | None | Literal['default'] = 'default',
@@ -728,7 +729,7 @@ async def start(
728729
run_timeout: timedelta | None = None,
729730
force_permission_level: ActorPermissionLevel | None = None,
730731
wait_for_finish: int | None = None,
731-
webhooks: list[dict] | None = None,
732+
webhooks: list[dict | WebhookCreate] | None = None,
732733
timeout: Timeout = 'long',
733734
) -> Run:
734735
"""Start the Actor and immediately return the Run object.
@@ -767,6 +768,10 @@ async def start(
767768
"""
768769
run_input, content_type = encode_key_value_store_record_value(run_input, content_type)
769770

771+
validated_webhooks = (
772+
[WebhookCreate.model_validate(w) if isinstance(w, dict) else w for w in webhooks] if webhooks else []
773+
)
774+
770775
request_params = self._build_params(
771776
build=build,
772777
maxItems=max_items,
@@ -776,7 +781,7 @@ async def start(
776781
timeout=to_seconds(run_timeout, as_int=True),
777782
waitForFinish=wait_for_finish,
778783
forcePermissionLevel=force_permission_level.value if force_permission_level is not None else None,
779-
webhooks=encode_webhook_list_to_base64(webhooks) if webhooks is not None else None,
784+
webhooks=WebhookRepresentationList.from_webhooks(validated_webhooks).to_base64(),
780785
)
781786

782787
response = await self._http_client.call(
@@ -802,7 +807,7 @@ async def call(
802807
restart_on_error: bool | None = None,
803808
memory_mbytes: int | None = None,
804809
run_timeout: timedelta | None = None,
805-
webhooks: list[dict] | None = None,
810+
webhooks: list[dict | WebhookCreate] | None = None,
806811
force_permission_level: ActorPermissionLevel | None = None,
807812
wait_duration: timedelta | None = None,
808813
logger: Logger | None | Literal['default'] = 'default',

src/apify_client/_resource_clients/request_queue.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
UnlockRequestsResult,
3636
)
3737
from apify_client._resource_clients._resource_client import ResourceClient, ResourceClientAsync
38+
from apify_client._types import RequestDeleteInput, RequestInput
3839
from apify_client._utils import catch_not_found_or_throw, response_to_dict, to_seconds
3940
from apify_client.errors import ApifyApiError
4041

@@ -189,7 +190,7 @@ def list_and_lock_head(
189190

190191
def add_request(
191192
self,
192-
request: dict,
193+
request: dict | RequestInput,
193194
*,
194195
forefront: bool | None = None,
195196
timeout: Timeout = 'short',
@@ -206,12 +207,15 @@ def add_request(
206207
Returns:
207208
The added request.
208209
"""
210+
if isinstance(request, dict):
211+
request = RequestInput.model_validate(request)
212+
209213
request_params = self._build_params(forefront=forefront, clientKey=self.client_key)
210214

211215
response = self._http_client.call(
212216
url=self._build_url('requests'),
213217
method='POST',
214-
json=request,
218+
json=request.model_dump(by_alias=True, exclude_none=True),
215219
params=request_params,
216220
timeout=timeout,
217221
)
@@ -248,7 +252,7 @@ def get_request(self, request_id: str, *, timeout: Timeout = 'short') -> Request
248252

249253
def update_request(
250254
self,
251-
request: dict,
255+
request: dict | Request,
252256
*,
253257
forefront: bool | None = None,
254258
timeout: Timeout = 'medium',
@@ -265,14 +269,15 @@ def update_request(
265269
Returns:
266270
The updated request.
267271
"""
268-
request_id = request['id']
272+
if isinstance(request, dict):
273+
request = Request.model_validate(request)
269274

270275
request_params = self._build_params(forefront=forefront, clientKey=self.client_key)
271276

272277
response = self._http_client.call(
273-
url=self._build_url(f'requests/{request_id}'),
278+
url=self._build_url(f'requests/{request.id}'),
274279
method='PUT',
275-
json=request,
280+
json=request.model_dump(by_alias=True, exclude_none=True),
276281
params=request_params,
277282
timeout=timeout,
278283
)
@@ -361,7 +366,7 @@ def delete_request_lock(
361366

362367
def batch_add_requests(
363368
self,
364-
requests: list[dict],
369+
requests: list[dict | RequestInput],
365370
*,
366371
forefront: bool = False,
367372
max_parallel: int = 1,
@@ -396,14 +401,19 @@ def batch_add_requests(
396401
if max_parallel != 1:
397402
raise NotImplementedError('max_parallel is only supported in async client')
398403

404+
requests_as_dicts = [
405+
(RequestInput.model_validate(r) if isinstance(r, dict) else r).model_dump(by_alias=True, exclude_none=True)
406+
for r in requests
407+
]
408+
399409
request_params = self._build_params(clientKey=self.client_key, forefront=forefront)
400410

401411
# Compute the payload size limit to ensure it doesn't exceed the maximum allowed size.
402412
payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math.ceil(_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT)
403413

404414
# Split the requests into batches, constrained by the max payload size and max requests per batch.
405415
batches = constrained_batches(
406-
requests,
416+
requests_as_dicts,
407417
max_size=payload_size_limit_bytes,
408418
max_count=_RQ_MAX_REQUESTS_PER_BATCH,
409419
)
@@ -444,7 +454,7 @@ def batch_add_requests(
444454

445455
def batch_delete_requests(
446456
self,
447-
requests: list[dict],
457+
requests: list[dict | RequestDeleteInput],
448458
*,
449459
timeout: Timeout = 'short',
450460
) -> BatchDeleteResult:
@@ -456,13 +466,20 @@ def batch_delete_requests(
456466
requests: List of the requests to delete.
457467
timeout: Timeout for the API HTTP request.
458468
"""
469+
requests_as_dicts = [
470+
(RequestDeleteInput.model_validate(r) if isinstance(r, dict) else r).model_dump(
471+
by_alias=True, exclude_none=True
472+
)
473+
for r in requests
474+
]
475+
459476
request_params = self._build_params(clientKey=self.client_key)
460477

461478
response = self._http_client.call(
462479
url=self._build_url('requests/batch'),
463480
method='DELETE',
464481
params=request_params,
465-
json=requests,
482+
json=requests_as_dicts,
466483
timeout=timeout,
467484
)
468485

@@ -658,7 +675,7 @@ async def list_and_lock_head(
658675

659676
async def add_request(
660677
self,
661-
request: dict,
678+
request: dict | RequestInput,
662679
*,
663680
forefront: bool | None = None,
664681
timeout: Timeout = 'short',
@@ -675,12 +692,15 @@ async def add_request(
675692
Returns:
676693
The added request.
677694
"""
695+
if isinstance(request, dict):
696+
request = RequestInput.model_validate(request)
697+
678698
request_params = self._build_params(forefront=forefront, clientKey=self.client_key)
679699

680700
response = await self._http_client.call(
681701
url=self._build_url('requests'),
682702
method='POST',
683-
json=request,
703+
json=request.model_dump(by_alias=True, exclude_none=True),
684704
params=request_params,
685705
timeout=timeout,
686706
)
@@ -715,7 +735,7 @@ async def get_request(self, request_id: str, *, timeout: Timeout = 'short') -> R
715735

716736
async def update_request(
717737
self,
718-
request: dict,
738+
request: dict | Request,
719739
*,
720740
forefront: bool | None = None,
721741
timeout: Timeout = 'medium',
@@ -732,14 +752,15 @@ async def update_request(
732752
Returns:
733753
The updated request.
734754
"""
735-
request_id = request['id']
755+
if isinstance(request, dict):
756+
request = Request.model_validate(request)
736757

737758
request_params = self._build_params(forefront=forefront, clientKey=self.client_key)
738759

739760
response = await self._http_client.call(
740-
url=self._build_url(f'requests/{request_id}'),
761+
url=self._build_url(f'requests/{request.id}'),
741762
method='PUT',
742-
json=request,
763+
json=request.model_dump(by_alias=True, exclude_none=True),
743764
params=request_params,
744765
timeout=timeout,
745766
)
@@ -874,7 +895,7 @@ async def _batch_add_requests_worker(
874895

875896
async def batch_add_requests(
876897
self,
877-
requests: list[dict],
898+
requests: list[dict | RequestInput],
878899
*,
879900
forefront: bool = False,
880901
max_parallel: int = 5,
@@ -906,6 +927,11 @@ async def batch_add_requests(
906927
if min_delay_between_unprocessed_requests_retries:
907928
logger.warning('`min_delay_between_unprocessed_requests_retries` is deprecated and not used anymore.')
908929

930+
requests_as_dicts = [
931+
(RequestInput.model_validate(r) if isinstance(r, dict) else r).model_dump(by_alias=True, exclude_none=True)
932+
for r in requests
933+
]
934+
909935
asyncio_queue: asyncio.Queue[Iterable[dict]] = asyncio.Queue()
910936
request_params = self._build_params(clientKey=self.client_key, forefront=forefront)
911937

@@ -914,7 +940,7 @@ async def batch_add_requests(
914940

915941
# Split the requests into batches, constrained by the max payload size and max requests per batch.
916942
batches = constrained_batches(
917-
requests,
943+
requests_as_dicts,
918944
max_size=payload_size_limit_bytes,
919945
max_count=_RQ_MAX_REQUESTS_PER_BATCH,
920946
)
@@ -959,7 +985,7 @@ async def batch_add_requests(
959985

960986
async def batch_delete_requests(
961987
self,
962-
requests: list[dict],
988+
requests: list[dict | RequestDeleteInput],
963989
*,
964990
timeout: Timeout = 'short',
965991
) -> BatchDeleteResult:
@@ -971,13 +997,20 @@ async def batch_delete_requests(
971997
requests: List of the requests to delete.
972998
timeout: Timeout for the API HTTP request.
973999
"""
1000+
requests_as_dicts = [
1001+
(RequestDeleteInput.model_validate(r) if isinstance(r, dict) else r).model_dump(
1002+
by_alias=True, exclude_none=True
1003+
)
1004+
for r in requests
1005+
]
1006+
9741007
request_params = self._build_params(clientKey=self.client_key)
9751008

9761009
response = await self._http_client.call(
9771010
url=self._build_url('requests/batch'),
9781011
method='DELETE',
9791012
params=request_params,
980-
json=requests,
1013+
json=requests_as_dicts,
9811014
timeout=timeout,
9821015
)
9831016
result = response_to_dict(response)

src/apify_client/_resource_clients/run.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@
1313
from apify_client._resource_clients._resource_client import ResourceClient, ResourceClientAsync
1414
from apify_client._status_message_watcher import StatusMessageWatcher, StatusMessageWatcherAsync
1515
from apify_client._streamed_log import StreamedLog, StreamedLogAsync
16-
from apify_client._utils import (
17-
encode_key_value_store_record_value,
18-
response_to_dict,
19-
to_safe_id,
20-
to_seconds,
21-
)
16+
from apify_client._utils import encode_key_value_store_record_value, response_to_dict, to_safe_id, to_seconds
2217

2318
if TYPE_CHECKING:
2419
import logging

0 commit comments

Comments
 (0)