Skip to content

Commit 5316f4b

Browse files
committed
update charge_lock
1 parent d919005 commit 5316f4b

File tree

9 files changed

+248
-156
lines changed

9 files changed

+248
-156
lines changed

src/apify/_actor.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
EventSystemInfoData,
2727
)
2828

29-
from apify._charging import ChargeResult, ChargingManager, ChargingManagerImplementation
29+
from apify._charging import DEFAULT_DATASET_ITEM_EVENT, ChargeResult, ChargingManager, ChargingManagerImplementation
3030
from apify._configuration import Configuration
3131
from apify._consts import EVENT_LISTENERS_TIMEOUT
3232
from apify._crypto import decrypt_input_secrets, load_private_key
@@ -613,53 +613,63 @@ async def open_request_queue(
613613
storage_client=self._storage_client.get_suitable_storage_client(force_cloud=force_cloud),
614614
)
615615

616-
@overload
617-
async def push_data(self, data: dict | list[dict]) -> None: ...
618-
@overload
619-
async def push_data(self, data: dict | list[dict], charged_event_name: str) -> ChargeResult: ...
620616
@_ensure_context
621-
async def push_data(self, data: dict | list[dict], charged_event_name: str | None = None) -> ChargeResult | None:
617+
async def push_data(self, data: dict | list[dict], charged_event_name: str | None = None) -> ChargeResult:
622618
"""Store an object or a list of objects to the default dataset of the current Actor run.
623619
624620
Args:
625621
data: The data to push to the default dataset.
626622
charged_event_name: If provided and if the Actor uses the pay-per-event pricing model,
627623
the method will attempt to charge for the event for each pushed item.
628624
"""
629-
if not data:
630-
return None
631-
632-
data = data if isinstance(data, list) else [data]
633-
634625
if charged_event_name and charged_event_name.startswith('apify-'):
635626
raise ValueError(f'Cannot charge for synthetic event "{charged_event_name}" manually')
636627

637628
charging_manager = self.get_charging_manager()
638629

630+
if not data:
631+
charged_event_name = charged_event_name or DEFAULT_DATASET_ITEM_EVENT
632+
charge_limit_reached = charging_manager.is_event_charge_limit_reached(charged_event_name)
633+
634+
return ChargeResult(
635+
event_charge_limit_reached=charge_limit_reached,
636+
charged_count=0,
637+
chargeable_within_limit=charging_manager.compute_chargeable(),
638+
)
639+
640+
data = data if isinstance(data, list) else [data]
641+
642+
dataset = await self.open_dataset()
643+
639644
# Acquire the charge lock to prevent race conditions between concurrent
640645
# push_data calls. We need to hold the lock for the entire push_data + charge sequence.
641-
async with charging_manager.charge_lock:
642-
# No explicit charging requested; synthetic events are handled within dataset.push_data.
646+
async with charging_manager.charge_lock():
647+
# Synthetic events are handled within dataset.push_data; here we only collect the counts needed to build the `ChargeResult`.
643648
if charged_event_name is None:
644-
dataset = await self.open_dataset()
649+
before = charging_manager.get_charged_event_count(DEFAULT_DATASET_ITEM_EVENT)
645650
await dataset.push_data(data)
646-
return None
651+
after = charging_manager.get_charged_event_count(DEFAULT_DATASET_ITEM_EVENT)
652+
return ChargeResult(
653+
event_charge_limit_reached=charging_manager.is_event_charge_limit_reached(
654+
DEFAULT_DATASET_ITEM_EVENT
655+
),
656+
charged_count=after - before,
657+
chargeable_within_limit=charging_manager.compute_chargeable(),
658+
)
647659

648-
pushed_items_count = self.get_charging_manager().compute_push_data_limit(
660+
pushed_items_count = charging_manager.compute_push_data_limit(
649661
items_count=len(data),
650662
event_name=charged_event_name,
651663
is_default_dataset=True,
652664
)
653665

654-
dataset = await self.open_dataset()
655-
656666
if pushed_items_count < len(data):
657667
await dataset.push_data(data[:pushed_items_count])
658668
elif pushed_items_count > 0:
659669
await dataset.push_data(data)
660670

661671
# Only charge explicit events; synthetic events will be processed within the client.
662-
return await self.get_charging_manager().charge(
672+
return await charging_manager.charge(
663673
event_name=charged_event_name,
664674
count=pushed_items_count,
665675
)
@@ -723,10 +733,9 @@ async def charge(self, event_name: str, count: int = 1) -> ChargeResult:
723733
event_name: Name of the event to be charged for.
724734
count: Number of events to charge for.
725735
"""
726-
# Acquire lock to prevent race conditions with concurrent charge/push_data calls.
736+
# charging_manager.charge() acquires charge_lock internally.
727737
charging_manager = self.get_charging_manager()
728-
async with charging_manager.charge_lock:
729-
return await charging_manager.charge(event_name, count)
738+
return await charging_manager.charge(event_name, count)
730739

731740
@overload
732741
def on(

src/apify/_charging.py

Lines changed: 84 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import math
54
from contextvars import ContextVar
65
from dataclasses import dataclass
@@ -18,7 +17,7 @@
1817
PricePerDatasetItemActorPricingInfo,
1918
PricingModel,
2019
)
21-
from apify._utils import docs_group, ensure_context
20+
from apify._utils import ReentrantLock, docs_group, ensure_context
2221
from apify.log import logger
2322
from apify.storages import Dataset
2423

@@ -53,7 +52,7 @@ class ChargingManager(Protocol):
5352
- Apify platform documentation: https://docs.apify.com/platform/actors/publishing/monetize
5453
"""
5554

56-
charge_lock: asyncio.Lock
55+
charge_lock: ReentrantLock
5756
"""Lock to synchronize charge operations. Prevents race conditions between `charge` and `push_data` calls."""
5857

5958
async def charge(self, event_name: str, count: int = 1) -> ChargeResult:
@@ -114,6 +113,12 @@ def compute_push_data_limit(
114113
Max number of items that can be pushed within the budget.
115114
"""
116115

116+
def is_event_charge_limit_reached(self, event_name: str) -> bool:
117+
"""Return True if the remaining budget is insufficient to charge even a single event of the given type."""
118+
119+
def compute_chargeable(self) -> dict[str, int | None]:
120+
"""Compute the maximum number of events of each type that can be charged within the current budget."""
121+
117122

118123
@docs_group('Charging')
119124
@dataclass(frozen=True)
@@ -170,7 +175,7 @@ def __init__(self, configuration: Configuration, client: ApifyClientAsync) -> No
170175
self._not_ppe_warning_printed = False
171176
self.active = False
172177

173-
self.charge_lock = asyncio.Lock()
178+
self.charge_lock = ReentrantLock()
174179

175180
async def __aenter__(self) -> None:
176181
"""Initialize the charging manager - this is called by the `Actor` class and shouldn't be invoked manually."""
@@ -244,13 +249,6 @@ async def __aexit__(
244249

245250
@_ensure_context
246251
async def charge(self, event_name: str, count: int = 1) -> ChargeResult:
247-
def calculate_chargeable() -> dict[str, int | None]:
248-
"""Calculate the maximum number of events of each type that can be charged within the current budget."""
249-
return {
250-
event_name: self.calculate_max_event_charge_count_within_limit(event_name)
251-
for event_name in self._pricing_info
252-
}
253-
254252
# For runs that do not use the pay-per-event pricing model, just print a warning and return
255253
if self._pricing_model != 'PAY_PER_EVENT':
256254
if not self._not_ppe_warning_printed:
@@ -262,79 +260,81 @@ def calculate_chargeable() -> dict[str, int | None]:
262260
return ChargeResult(
263261
event_charge_limit_reached=False,
264262
charged_count=0,
265-
chargeable_within_limit=calculate_chargeable(),
263+
chargeable_within_limit=self.compute_chargeable(),
266264
)
267265

268-
# START OF CRITICAL SECTION - no awaits here
269-
270-
# Determine the maximum amount of events that can be charged within the budget
271-
max_chargeable = self.calculate_max_event_charge_count_within_limit(event_name)
272-
charged_count = min(count, max_chargeable if max_chargeable is not None else count)
273-
274-
if charged_count == 0:
266+
if count <= 0:
275267
return ChargeResult(
276-
event_charge_limit_reached=True,
268+
event_charge_limit_reached=self.is_event_charge_limit_reached(event_name),
277269
charged_count=0,
278-
chargeable_within_limit=calculate_chargeable(),
270+
chargeable_within_limit=self.compute_chargeable(),
279271
)
280272

281-
pricing_info = self._pricing_info.get(
282-
event_name,
283-
PricingInfoItem(
284-
# Use a nonzero price for local development so that the maximum budget can be reached.
285-
price=Decimal() if self._is_at_home else Decimal(1),
286-
title=f"Unknown event '{event_name}'",
287-
),
288-
)
273+
async with self.charge_lock():
274+
# Determine the maximum amount of events that can be charged within the budget
275+
max_chargeable = self.calculate_max_event_charge_count_within_limit(event_name)
276+
charged_count = min(count, max_chargeable if max_chargeable is not None else count)
289277

290-
# Update the charging state
291-
self._charging_state.setdefault(event_name, ChargingStateItem(0, Decimal()))
292-
self._charging_state[event_name].charge_count += charged_count
293-
self._charging_state[event_name].total_charged_amount += charged_count * pricing_info.price
294-
295-
# END OF CRITICAL SECTION
278+
if charged_count == 0:
279+
return ChargeResult(
280+
event_charge_limit_reached=True,
281+
charged_count=0,
282+
chargeable_within_limit=self.compute_chargeable(),
283+
)
296284

297-
# If running on the platform, call the charge endpoint
298-
if self._is_at_home:
299-
if self._actor_run_id is None:
300-
raise RuntimeError('Actor run ID not configured')
301-
302-
if event_name.startswith('apify-'):
303-
# Synthetic events (e.g. apify-default-dataset-item) are tracked internally only,
304-
# the platform handles them automatically based on dataset writes.
305-
pass
306-
elif event_name in self._pricing_info:
307-
await self._client.run(self._actor_run_id).charge(event_name, charged_count)
308-
else:
309-
logger.warning(f"Attempting to charge for an unknown event '{event_name}'")
310-
311-
# Log the charged operation (if enabled)
312-
if self._charging_log_dataset:
313-
await self._charging_log_dataset.push_data(
314-
{
315-
'event_name': event_name,
316-
'event_title': pricing_info.title,
317-
'event_price_usd': round(pricing_info.price, 3),
318-
'charged_count': charged_count,
319-
'timestamp': datetime.now(timezone.utc).isoformat(),
320-
}
285+
pricing_info = self._pricing_info.get(
286+
event_name,
287+
PricingInfoItem(
288+
# Use a nonzero price for local development so that the maximum budget can be reached.
289+
price=Decimal() if self._is_at_home else Decimal(1),
290+
title=f"Unknown event '{event_name}'",
291+
),
321292
)
322293

323-
# If it is not possible to charge the full amount, log that fact
324-
if charged_count < count:
325-
subject = 'instance' if count == 1 else 'instances'
326-
logger.info(
327-
f"Charging {count} {subject} of '{event_name}' event would exceed max_total_charge_usd "
328-
f'- only {charged_count} events were charged'
329-
)
294+
# Update the charging state
295+
self._charging_state.setdefault(event_name, ChargingStateItem(0, Decimal()))
296+
self._charging_state[event_name].charge_count += charged_count
297+
self._charging_state[event_name].total_charged_amount += charged_count * pricing_info.price
298+
299+
# If running on the platform, call the charge endpoint
300+
if self._is_at_home:
301+
if self._actor_run_id is None:
302+
raise RuntimeError('Actor run ID not configured')
303+
304+
if event_name.startswith('apify-'):
305+
# Synthetic events (e.g. apify-default-dataset-item) are tracked internally only,
306+
# the platform handles them automatically based on dataset writes.
307+
pass
308+
elif event_name in self._pricing_info:
309+
await self._client.run(self._actor_run_id).charge(event_name, charged_count)
310+
else:
311+
logger.warning(f"Attempting to charge for an unknown event '{event_name}'")
312+
313+
# Log the charged operation (if enabled)
314+
if self._charging_log_dataset:
315+
await self._charging_log_dataset.push_data(
316+
{
317+
'event_name': event_name,
318+
'event_title': pricing_info.title,
319+
'event_price_usd': round(pricing_info.price, 3),
320+
'charged_count': charged_count,
321+
'timestamp': datetime.now(timezone.utc).isoformat(),
322+
}
323+
)
330324

331-
max_charge_count = self.calculate_max_event_charge_count_within_limit(event_name)
325+
# If it is not possible to charge the full amount, log that fact
326+
if charged_count < count:
327+
subject = 'instance' if count == 1 else 'instances'
328+
logger.info(
329+
f"Charging {count} {subject} of '{event_name}' event would exceed max_total_charge_usd "
330+
f'- only {charged_count} events were charged'
331+
)
332332

333-
return ChargeResult(
334-
event_charge_limit_reached=max_charge_count is not None and max_charge_count <= 0,
335-
charged_count=charged_count,
336-
chargeable_within_limit=calculate_chargeable(),
337-
)
333+
return ChargeResult(
334+
event_charge_limit_reached=self.is_event_charge_limit_reached(event_name),
335+
charged_count=charged_count,
336+
chargeable_within_limit=self.compute_chargeable(),
337+
)
338338

339339
@_ensure_context
340340
def calculate_total_charged_amount(self) -> Decimal:
@@ -394,6 +394,18 @@ def compute_push_data_limit(
394394
max_count = max(0, math.floor(result)) if result.is_finite() else items_count
395395
return min(items_count, max_count)
396396

397+
@_ensure_context
398+
def is_event_charge_limit_reached(self, event_name: str) -> bool:
399+
max_charge_count = self.calculate_max_event_charge_count_within_limit(event_name)
400+
return max_charge_count is not None and max_charge_count <= 0
401+
402+
@_ensure_context
403+
def compute_chargeable(self) -> dict[str, int | None]:
404+
return {
405+
event_name: self.calculate_max_event_charge_count_within_limit(event_name)
406+
for event_name in self._pricing_info
407+
}
408+
397409
async def _fetch_pricing_info(self) -> _FetchedPricingInfoDict:
398410
"""Fetch pricing information from environment variables or API."""
399411
# Check if pricing info is available via environment variables

src/apify/_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import builtins
45
import inspect
56
import sys
67
from collections.abc import Callable
8+
from contextlib import asynccontextmanager
79
from enum import Enum
810
from functools import wraps
911
from importlib import metadata
10-
from typing import Any, Literal, TypeVar, cast
12+
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
13+
14+
if TYPE_CHECKING:
15+
from collections.abc import AsyncIterator
1116

1217
T = TypeVar('T', bound=Callable[..., Any])
1318

@@ -123,3 +128,25 @@ def maybe_extract_enum_member_value(maybe_enum_member: Any) -> Any:
123128
if isinstance(maybe_enum_member, Enum):
124129
return maybe_enum_member.value
125130
return maybe_enum_member
131+
132+
133+
class ReentrantLock:
134+
"""A reentrant lock implementation for asyncio using asyncio.Lock."""
135+
136+
def __init__(self) -> None:
137+
self._lock = asyncio.Lock()
138+
self._owner: asyncio.Task | None = None
139+
140+
@asynccontextmanager
141+
async def __call__(self) -> AsyncIterator[None]:
142+
"""Acquire the lock unless the current task already owns it; in that case, proceed without acquiring it again."""
143+
me = asyncio.current_task()
144+
if self._owner is me:
145+
yield
146+
return
147+
async with self._lock:
148+
self._owner = me
149+
try:
150+
yield
151+
finally:
152+
self._owner = None

src/apify/storage_clients/_apify/_dataset_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ async def payloads_generator(items: list[Any]) -> AsyncIterator[str]:
142142
for index, item in enumerate(items):
143143
yield await self._check_and_serialize(item, index)
144144

145-
async with self._lock, self._charge_lock():
145+
async with self._charge_lock(), self._lock:
146146
items = data if isinstance(data, list) else [data]
147147
limit = self._compute_limit_for_push(len(items))
148148
items = items[:limit]

0 commit comments

Comments
 (0)