Skip to content

Commit e68c909

Browse files
committed
Accept Pydantic models as alternatives to dicts in resource client methods
1 parent 57a55d6 commit e68c909

File tree

8 files changed

+190
-112
lines changed

8 files changed

+190
-112
lines changed

src/apify_client/_internal_models.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
"""Internal Pydantic models that are not part of the public API and are therefore not generated."""
2-
31
from __future__ import annotations
42

5-
from pydantic import BaseModel, ConfigDict
3+
import json
4+
from base64 import b64encode
5+
from typing import Annotated
6+
7+
from pydantic import BaseModel, ConfigDict, Field, RootModel
68

7-
from apify_client._models import ActorJobStatus # noqa: TC001
9+
from apify_client._models import ActorJobStatus, WebhookCreate # noqa: TC001
810

911

1012
class ActorJob(BaseModel):
@@ -27,3 +29,44 @@ class ActorJobResponse(BaseModel):
2729
model_config = ConfigDict(extra='allow')
2830

2931
data: ActorJob
32+
33+
34+
class WebhookRepresentation(BaseModel):
35+
"""Representation of a webhook for base64-encoded API transmission.
36+
37+
Contains only the fields needed for the webhook payload sent via query parameters.
38+
"""
39+
40+
model_config = ConfigDict(populate_by_name=True, extra='ignore')
41+
42+
event_types: Annotated[list[str], Field(alias='eventTypes')]
43+
request_url: Annotated[str, Field(alias='requestUrl')]
44+
payload_template: Annotated[str | None, Field(alias='payloadTemplate')] = None
45+
headers_template: Annotated[str | None, Field(alias='headersTemplate')] = None
46+
47+
48+
class WebhookRepresentationList(RootModel[list[WebhookRepresentation]]):
49+
"""List of webhook representations with base64 encoding support."""
50+
51+
@classmethod
52+
def from_webhooks(cls, webhooks: list[WebhookCreate]) -> WebhookRepresentationList:
53+
"""Construct from a list of `WebhookCreate` models."""
54+
representations = list[WebhookRepresentation]()
55+
56+
for w in webhooks:
57+
webhook_dict = w.model_dump(mode='json', exclude_none=True)
58+
representations.append(WebhookRepresentation.model_validate(webhook_dict))
59+
60+
return cls(representations)
61+
62+
def to_base64(self) -> str | None:
63+
"""Encode this list of webhook representations to a base64 string.
64+
65+
Returns `None` if the list is empty, so that the query parameter is omitted.
66+
"""
67+
if not self.root:
68+
return None
69+
70+
data = [r.model_dump(by_alias=True, exclude_none=True) for r in self.root]
71+
json_string = json.dumps(data).encode(encoding='utf-8')
72+
return b64encode(json_string).decode(encoding='ascii')

src/apify_client/_resource_clients/actor.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic import TypeAdapter
66

77
from apify_client._docs import docs_group
8+
from apify_client._internal_models import WebhookRepresentationList
89
from apify_client._models import (
910
Actor,
1011
ActorPermissionLevel,
@@ -23,14 +24,10 @@
2324
RunOrigin,
2425
RunResponse,
2526
UpdateActorRequest,
27+
WebhookCreate,
2628
)
2729
from apify_client._resource_clients._resource_client import ResourceClient, ResourceClientAsync
28-
from apify_client._utils import (
29-
encode_key_value_store_record_value,
30-
encode_webhook_list_to_base64,
31-
response_to_dict,
32-
to_seconds,
33-
)
30+
from apify_client._utils import encode_key_value_store_record_value, response_to_dict, to_seconds
3431

3532
if TYPE_CHECKING:
3633
from datetime import timedelta
@@ -224,7 +221,7 @@ def start(
224221
timeout: timedelta | None = None,
225222
force_permission_level: ActorPermissionLevel | None = None,
226223
wait_for_finish: int | None = None,
227-
webhooks: list[dict] | None = None,
224+
webhooks: list[dict | WebhookCreate] | None = None,
228225
) -> Run:
229226
"""Start the Actor and immediately return the Run object.
230227
@@ -261,6 +258,10 @@ def start(
261258
"""
262259
run_input, content_type = encode_key_value_store_record_value(run_input, content_type)
263260

261+
validated_webhooks = (
262+
[WebhookCreate.model_validate(w) if isinstance(w, dict) else w for w in webhooks] if webhooks else []
263+
)
264+
264265
request_params = self._build_params(
265266
build=build,
266267
maxItems=max_items,
@@ -270,7 +271,7 @@ def start(
270271
timeout=to_seconds(timeout, as_int=True),
271272
waitForFinish=wait_for_finish,
272273
forcePermissionLevel=force_permission_level.value if force_permission_level is not None else None,
273-
webhooks=encode_webhook_list_to_base64(webhooks) if webhooks is not None else None,
274+
webhooks=WebhookRepresentationList.from_webhooks(validated_webhooks).to_base64(),
274275
)
275276

276277
response = self._http_client.call(
@@ -295,7 +296,7 @@ def call(
295296
restart_on_error: bool | None = None,
296297
memory_mbytes: int | None = None,
297298
timeout: timedelta | None = None,
298-
webhooks: list[dict] | None = None,
299+
webhooks: list[dict | WebhookCreate] | None = None,
299300
force_permission_level: ActorPermissionLevel | None = None,
300301
wait_duration: timedelta | None = None,
301302
logger: Logger | None | Literal['default'] = 'default',
@@ -689,7 +690,7 @@ async def start(
689690
timeout: timedelta | None = None,
690691
force_permission_level: ActorPermissionLevel | None = None,
691692
wait_for_finish: int | None = None,
692-
webhooks: list[dict] | None = None,
693+
webhooks: list[dict | WebhookCreate] | None = None,
693694
) -> Run:
694695
"""Start the Actor and immediately return the Run object.
695696
@@ -726,6 +727,10 @@ async def start(
726727
"""
727728
run_input, content_type = encode_key_value_store_record_value(run_input, content_type)
728729

730+
validated_webhooks = (
731+
[WebhookCreate.model_validate(w) if isinstance(w, dict) else w for w in webhooks] if webhooks else []
732+
)
733+
729734
request_params = self._build_params(
730735
build=build,
731736
maxItems=max_items,
@@ -735,7 +740,7 @@ async def start(
735740
timeout=to_seconds(timeout, as_int=True),
736741
waitForFinish=wait_for_finish,
737742
forcePermissionLevel=force_permission_level.value if force_permission_level is not None else None,
738-
webhooks=encode_webhook_list_to_base64(webhooks) if webhooks is not None else None,
743+
webhooks=WebhookRepresentationList.from_webhooks(validated_webhooks).to_base64(),
739744
)
740745

741746
response = await self._http_client.call(
@@ -760,7 +765,7 @@ async def call(
760765
restart_on_error: bool | None = None,
761766
memory_mbytes: int | None = None,
762767
timeout: timedelta | None = None,
763-
webhooks: list[dict] | None = None,
768+
webhooks: list[dict | WebhookCreate] | None = None,
764769
force_permission_level: ActorPermissionLevel | None = None,
765770
wait_duration: timedelta | None = None,
766771
logger: Logger | None | Literal['default'] = 'default',

src/apify_client/_resource_clients/request_queue.py

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def list_and_lock_head(self, *, lock_duration: timedelta, limit: int | None = No
166166
result = response_to_dict(response)
167167
return HeadAndLockResponse.model_validate(result).data
168168

169-
def add_request(self, request: dict, *, forefront: bool | None = None) -> RequestRegistration:
169+
def add_request(self, request: dict | RequestDraft, *, forefront: bool | None = None) -> RequestRegistration:
170170
"""Add a request to the queue.
171171
172172
https://docs.apify.com/api/v2#/reference/request-queues/request-collection/add-request
@@ -178,12 +178,15 @@ def add_request(self, request: dict, *, forefront: bool | None = None) -> Reques
178178
Returns:
179179
The added request.
180180
"""
181+
if isinstance(request, dict):
182+
request = RequestDraft.model_validate(request)
183+
181184
request_params = self._build_params(forefront=forefront, clientKey=self.client_key)
182185

183186
response = self._http_client.call(
184187
url=self._build_url('requests'),
185188
method='POST',
186-
json=request,
189+
json=request.model_dump(by_alias=True, exclude_none=True),
187190
params=request_params,
188191
timeout=FAST_OPERATION_TIMEOUT,
189192
)
@@ -217,7 +220,7 @@ def get_request(self, request_id: str) -> Request | None:
217220

218221
return None
219222

220-
def update_request(self, request: dict, *, forefront: bool | None = None) -> RequestRegistration:
223+
def update_request(self, request: dict | Request, *, forefront: bool | None = None) -> RequestRegistration:
221224
"""Update a request in the queue.
222225
223226
https://docs.apify.com/api/v2#/reference/request-queues/request/update-request
@@ -229,14 +232,15 @@ def update_request(self, request: dict, *, forefront: bool | None = None) -> Req
229232
Returns:
230233
The updated request.
231234
"""
232-
request_id = request['id']
235+
if isinstance(request, dict):
236+
request = Request.model_validate(request)
233237

234238
request_params = self._build_params(forefront=forefront, clientKey=self.client_key)
235239

236240
response = self._http_client.call(
237-
url=self._build_url(f'requests/{request_id}'),
241+
url=self._build_url(f'requests/{request.id}'),
238242
method='PUT',
239-
json=request,
243+
json=request.model_dump(by_alias=True, exclude_none=True),
240244
params=request_params,
241245
timeout=STANDARD_OPERATION_TIMEOUT,
242246
)
@@ -315,7 +319,7 @@ def delete_request_lock(self, request_id: str, *, forefront: bool | None = None)
315319

316320
def batch_add_requests(
317321
self,
318-
requests: list[dict],
322+
requests: list[dict | RequestDraft],
319323
*,
320324
forefront: bool = False,
321325
max_parallel: int = 1,
@@ -348,14 +352,19 @@ def batch_add_requests(
348352
if max_parallel != 1:
349353
raise NotImplementedError('max_parallel is only supported in async client')
350354

355+
requests_as_dicts = [
356+
(RequestDraft.model_validate(r) if isinstance(r, dict) else r).model_dump(by_alias=True, exclude_none=True)
357+
for r in requests
358+
]
359+
351360
request_params = self._build_params(clientKey=self.client_key, forefront=forefront)
352361

353362
# Compute the payload size limit to ensure it doesn't exceed the maximum allowed size.
354363
payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math.ceil(_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT)
355364

356365
# Split the requests into batches, constrained by the max payload size and max requests per batch.
357366
batches = constrained_batches(
358-
requests,
367+
requests_as_dicts,
359368
max_size=payload_size_limit_bytes,
360369
max_count=_RQ_MAX_REQUESTS_PER_BATCH,
361370
)
@@ -394,21 +403,26 @@ def batch_add_requests(
394403
)
395404
).data
396405

397-
def batch_delete_requests(self, requests: list[dict]) -> BatchDeleteResult:
406+
def batch_delete_requests(self, requests: list[dict | RequestDraft]) -> BatchDeleteResult:
398407
"""Delete given requests from the queue.
399408
400409
https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/delete-requests
401410
402411
Args:
403412
requests: List of the requests to delete.
404413
"""
414+
requests_as_dicts = [
415+
(RequestDraft.model_validate(r) if isinstance(r, dict) else r).model_dump(by_alias=True, exclude_none=True)
416+
for r in requests
417+
]
418+
405419
request_params = self._build_params(clientKey=self.client_key)
406420

407421
response = self._http_client.call(
408422
url=self._build_url('requests/batch'),
409423
method='DELETE',
410424
params=request_params,
411-
json=requests,
425+
json=requests_as_dicts,
412426
timeout=FAST_OPERATION_TIMEOUT,
413427
)
414428

@@ -580,7 +594,7 @@ async def list_and_lock_head(self, *, lock_duration: timedelta, limit: int | Non
580594
result = response_to_dict(response)
581595
return HeadAndLockResponse.model_validate(result).data
582596

583-
async def add_request(self, request: dict, *, forefront: bool | None = None) -> RequestRegistration:
597+
async def add_request(self, request: dict | RequestDraft, *, forefront: bool | None = None) -> RequestRegistration:
584598
"""Add a request to the queue.
585599
586600
https://docs.apify.com/api/v2#/reference/request-queues/request-collection/add-request
@@ -592,12 +606,15 @@ async def add_request(self, request: dict, *, forefront: bool | None = None) ->
592606
Returns:
593607
The added request.
594608
"""
609+
if isinstance(request, dict):
610+
request = RequestDraft.model_validate(request)
611+
595612
request_params = self._build_params(forefront=forefront, clientKey=self.client_key)
596613

597614
response = await self._http_client.call(
598615
url=self._build_url('requests'),
599616
method='POST',
600-
json=request,
617+
json=request.model_dump(by_alias=True, exclude_none=True),
601618
params=request_params,
602619
timeout=FAST_OPERATION_TIMEOUT,
603620
)
@@ -629,7 +646,7 @@ async def get_request(self, request_id: str) -> Request | None:
629646
catch_not_found_or_throw(exc)
630647
return None
631648

632-
async def update_request(self, request: dict, *, forefront: bool | None = None) -> RequestRegistration:
649+
async def update_request(self, request: dict | Request, *, forefront: bool | None = None) -> RequestRegistration:
633650
"""Update a request in the queue.
634651
635652
https://docs.apify.com/api/v2#/reference/request-queues/request/update-request
@@ -641,14 +658,15 @@ async def update_request(self, request: dict, *, forefront: bool | None = None)
641658
Returns:
642659
The updated request.
643660
"""
644-
request_id = request['id']
661+
if isinstance(request, dict):
662+
request = Request.model_validate(request)
645663

646664
request_params = self._build_params(forefront=forefront, clientKey=self.client_key)
647665

648666
response = await self._http_client.call(
649-
url=self._build_url(f'requests/{request_id}'),
667+
url=self._build_url(f'requests/{request.id}'),
650668
method='PUT',
651-
json=request,
669+
json=request.model_dump(by_alias=True, exclude_none=True),
652670
params=request_params,
653671
timeout=STANDARD_OPERATION_TIMEOUT,
654672
)
@@ -777,7 +795,7 @@ async def _batch_add_requests_worker(
777795

778796
async def batch_add_requests(
779797
self,
780-
requests: list[dict],
798+
requests: list[dict | RequestDraft],
781799
*,
782800
forefront: bool = False,
783801
max_parallel: int = 5,
@@ -807,6 +825,11 @@ async def batch_add_requests(
807825
if min_delay_between_unprocessed_requests_retries:
808826
logger.warning('`min_delay_between_unprocessed_requests_retries` is deprecated and not used anymore.')
809827

828+
requests_as_dicts = [
829+
(RequestDraft.model_validate(r) if isinstance(r, dict) else r).model_dump(by_alias=True, exclude_none=True)
830+
for r in requests
831+
]
832+
810833
asyncio_queue: asyncio.Queue[Iterable[dict]] = asyncio.Queue()
811834
request_params = self._build_params(clientKey=self.client_key, forefront=forefront)
812835

@@ -815,7 +838,7 @@ async def batch_add_requests(
815838

816839
# Split the requests into batches, constrained by the max payload size and max requests per batch.
817840
batches = constrained_batches(
818-
requests,
841+
requests_as_dicts,
819842
max_size=payload_size_limit_bytes,
820843
max_count=_RQ_MAX_REQUESTS_PER_BATCH,
821844
)
@@ -858,21 +881,26 @@ async def batch_add_requests(
858881
)
859882
).data
860883

861-
async def batch_delete_requests(self, requests: list[dict]) -> BatchDeleteResult:
884+
async def batch_delete_requests(self, requests: list[dict | RequestDraft]) -> BatchDeleteResult:
862885
"""Delete given requests from the queue.
863886
864887
https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/delete-requests
865888
866889
Args:
867890
requests: List of the requests to delete.
868891
"""
892+
requests_as_dicts = [
893+
(RequestDraft.model_validate(r) if isinstance(r, dict) else r).model_dump(by_alias=True, exclude_none=True)
894+
for r in requests
895+
]
896+
869897
request_params = self._build_params(clientKey=self.client_key)
870898

871899
response = await self._http_client.call(
872900
url=self._build_url('requests/batch'),
873901
method='DELETE',
874902
params=request_params,
875-
json=requests,
903+
json=requests_as_dicts,
876904
timeout=FAST_OPERATION_TIMEOUT,
877905
)
878906
result = response_to_dict(response)

0 commit comments

Comments
 (0)