Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def handle(self, *args, **options): # noqa: ARG002
resource_delete_actions(learning_resource)
else:
task = get_mit_edx_data.delay(
options["api_course_datafile"], options["api_program_datafile"]
options["api_course_datafile"],
options["api_program_datafile"],
_cooldown_force=True,
)
self.stdout.write(f"Started task {task} to get MIT edX course data")
self.stdout.write("Waiting on task...")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def handle(self, *args, **options): # noqa: ARG002
):
resource_delete_actions(learning_resource)
else:
task = get_mitxonline_data.delay()
task = get_mitxonline_data.delay(_cooldown_force=True)
self.stdout.write(f"Started task {task} to get MITx Online course data")
self.stdout.write("Waiting on task...")
start = now_in_utc()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def handle(self, *args, **options): # noqa: ARG002
):
resource_delete_actions(learning_resource)
else:
task = get_oll_data.delay(sheets_id=options["sheets_id"])
task = get_oll_data.delay(
sheets_id=options["sheets_id"],
_cooldown_force=True,
)
self.stdout.write(f"Started task {task} to get oll course data")
self.stdout.write("Waiting on task...")
start = now_in_utc()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def handle(self, *args, **options): # noqa: ARG002
):
resource_delete_actions(learning_resource)
else:
task = get_xpro_data.delay()
task = get_xpro_data.delay(_cooldown_force=True)
self.stdout.write(f"Started task {task} to get xpro course data")
self.stdout.write("Waiting on task...")
start = now_in_utc()
Expand Down
24 changes: 19 additions & 5 deletions learning_resources/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from learning_resources_search.exceptions import RetryError
from main.celery import app
from main.constants import ISOFORMAT
from main.decorators import cooldown_task
from main.utils import chunks, clear_views_cache, now_in_utc

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -117,10 +118,16 @@ def get_micromasters_data():


@app.task
@cooldown_task(
wait_time=3600,
key_func=lambda *, api_course_datafile=None, api_program_datafile=None: (
f"course={api_course_datafile}:program={api_program_datafile}"
),
)
def get_mit_edx_data(
api_course_datafile: str | None = None,
api_program_datafile: str | None = None,
) -> int:
) -> int | None:
"""Task to sync MIT edX data with the database

Args:
Expand All @@ -130,7 +137,8 @@ def get_mit_edx_data(
Otherwise, the API is queried directly.

Returns:
int: The number of results that were fetched
int | None: The number of results fetched, or None if the call was
skipped due to rate limiting.
"""
courses = pipelines.mit_edx_courses_etl(api_course_datafile)
programs = pipelines.mit_edx_programs_etl(api_program_datafile)
Expand All @@ -139,7 +147,8 @@ def get_mit_edx_data(


@app.task
def get_mitxonline_data() -> int:
@cooldown_task(wait_time=900)
def get_mitxonline_data() -> int | None:
"""Execute the MITX Online ETL pipeline"""
courses = pipelines.mitxonline_courses_etl()
programs = pipelines.mitxonline_programs_etl()
Expand All @@ -148,7 +157,11 @@ def get_mitxonline_data() -> int:


@app.task
def get_oll_data(sheets_id=None):
@cooldown_task(
wait_time=900,
key_func=lambda *, sheets_id=None: f"sheets_id={sheets_id}",
)
def get_oll_data(sheets_id=None) -> int | None:
"""Execute the OLL ETL pipeline.

Args:
Expand Down Expand Up @@ -176,7 +189,8 @@ def get_sloan_data():


@app.task
def get_xpro_data():
@cooldown_task(wait_time=900)
def get_xpro_data() -> int | None:
"""Execute the xPro ETL pipeline"""
courses = pipelines.xpro_courses_etl()
programs = pipelines.xpro_programs_etl()
Expand Down
94 changes: 94 additions & 0 deletions main/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""main decorators"""

import inspect
import logging
from collections.abc import Callable
from functools import wraps

from celery import current_task, states
from celery.exceptions import Reject
from django.core.cache import caches

log = logging.getLogger(__name__)

KEY_PREFIX = "cooldown"


def cooldown_task(
    wait_time: int,
    key: str | None = None,
    key_func: Callable[..., str] | None = None,
):
    """
    Build a decorator that limits a task to one run per `wait_time` seconds.

    A lock entry is written to the ``redis`` cache with ``cache.add`` before
    the wrapped function runs. Calls that arrive while the entry exists are
    dropped. The entry is never removed when the wrapped function raises, so
    failed runs still start the cooldown (this keeps retries from hammering
    upstream APIs).

    Place this *below* ``@app.task`` so the check happens on the worker
    rather than in the process that enqueues the task.

    Dropped calls:
        * Inside a Celery task run (``current_task`` is set and the request
          was not ``called_directly``) the task state is set to
          ``REJECTED`` with the lock key in its meta, then
          ``Reject(requeue=False)`` is raised.
        * Anywhere else the wrapper simply returns ``None``.

    Forcing a run:
        Pass ``_cooldown_force=True`` as a kwarg (for example through
        ``delay()`` or ``apply_async``). The wrapper removes that kwarg,
        overwrites the lock with a fresh `wait_time` timeout and always
        calls the wrapped function, so later unforced calls stay gated.

    The returned wrapper also carries ``clear_cooldown(*args, **kwargs)``,
    which deletes the lock key computed for those arguments.

    Args:
        wait_time: Lock duration in seconds.
        key: Optional static cache key. When omitted, the wrapped
            function's module and qualified name are used.
        key_func: Optional callable that receives the wrapped function's
            bound arguments (defaults applied) as keyword arguments and
            returns a string suffix appended to the base key, giving one
            cooldown per argument set.

    Raises:
        celery.exceptions.Reject: From the wrapper, when a call is dropped
            inside a Celery task run.
    """

    def decorator(func):
        default_name = f"{func.__module__}.{func.__qualname__}"
        base_key = f"{KEY_PREFIX}:{key or default_name}"
        # The signature is only needed to hand bound arguments to key_func.
        sig = inspect.signature(func) if key_func else None

        def _key_for(*args, **kwargs):
            """Return the cache key guarding a call with these arguments."""
            if key_func is not None:
                bound = sig.bind(*args, **kwargs)
                bound.apply_defaults()
                return f"{base_key}:{key_func(**bound.arguments)}"
            return base_key

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Strip the control kwarg before it can reach key_func or func.
            force = kwargs.pop("_cooldown_force", False)
            lock_key = _key_for(*args, **kwargs)

            if force:
                log.info("Force-overriding cooldown for %s", lock_key)
                # Overwrite (not add) so the cooldown window restarts now.
                caches["redis"].set(lock_key, "1", timeout=wait_time)
                return func(*args, **kwargs)

            if caches["redis"].add(lock_key, "1", timeout=wait_time):
                # Lock acquired: no run happened inside the window.
                return func(*args, **kwargs)

            log.info("Skipping %s: cooldown active (%ss)", lock_key, wait_time)
            if not current_task or current_task.request.called_directly:
                return None

            # NOTE(review): reviewers were concerned that a silently
            # "successful" skipped task would be confusing later. Marking
            # the run REJECTED makes the skip visible in Celery monitoring;
            # it is expected (not verified here) that this does not raise a
            # Sentry error. Management commands pass _cooldown_force=True by
            # default so operators do not need a separate --force flag.
            current_task.update_state(
                state=states.REJECTED,
                meta={"reason": "cooldown", "key": lock_key},
            )
            reason = "cooldown active"
            raise Reject(reason, requeue=False)

        def clear_cooldown(*args, **kwargs):
            """Delete the lock key for the given call arguments."""
            caches["redis"].delete(_key_for(*args, **kwargs))

        wrapper.clear_cooldown = clear_cooldown
        return wrapper

    return decorator
183 changes: 183 additions & 0 deletions main/decorators_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""Tests for main.decorators"""

import pytest
from celery.exceptions import Reject

from main.decorators import cooldown_task


@pytest.fixture
def mock_redis(mocker):
    """Swap ``main.decorators.caches`` for a dict holding a mock redis cache.

    Returns the mock so tests can script ``add``/``set``/``delete`` and
    inspect the calls made by ``cooldown_task``.
    """
    redis_cache = mocker.Mock()
    fake_caches = {"redis": redis_cache}
    mocker.patch("main.decorators.caches", fake_caches)
    return redis_cache


def test_cooldown_task_runs_first_call_and_drops_subsequent(mock_redis, mocker):
    """Only the call that acquires the lock reaches the wrapped function."""
    # First add() acquires the lock, the next two find it already held.
    mock_redis.add.side_effect = [True, False, False]
    inner = mocker.Mock(return_value="result")

    @cooldown_task(wait_time=3600)
    def my_task(*args, **kwargs):
        return inner(*args, **kwargs)

    assert my_task("a", b=1) == "result"
    assert my_task() is None
    assert my_task() is None
    inner.assert_called_once_with("a", b=1)

    # Every lock attempt must carry the configured wait time.
    timeouts = [call.kwargs["timeout"] for call in mock_redis.add.call_args_list]
    assert all(timeout == 3600 for timeout in timeouts)


def test_cooldown_task_uses_custom_key(mock_redis):
    """An explicit `key` replaces the function name in the lock key."""
    mock_redis.add.return_value = True

    @cooldown_task(wait_time=60, key="my-custom-key")
    def my_task():
        return 1

    my_task()

    lock_key = mock_redis.add.call_args.args[0]
    assert lock_key == "cooldown:my-custom-key"


def test_cooldown_task_default_key_uses_func_name(mock_redis):
    """Without `key`, the lock key is built from the wrapped function's name."""
    mock_redis.add.return_value = True

    @cooldown_task(wait_time=60)
    def some_task():
        return 1

    some_task()

    lock_key = mock_redis.add.call_args.args[0]
    assert lock_key.startswith("cooldown:")
    assert "some_task" in lock_key


def test_cooldown_task_key_func_scopes_per_argument(mock_redis):
    """Different arguments yield different lock keys when `key_func` is set."""
    mock_redis.add.return_value = True

    @cooldown_task(
        wait_time=60,
        key_func=lambda *, sheets_id=None: f"sheet:{sheets_id}",
    )
    def my_task(sheets_id=None):
        return 1

    for sheet in ("A", "B"):
        my_task(sheets_id=sheet)

    seen = [call.args[0] for call in mock_redis.add.call_args_list]
    first_key, second_key = seen[0], seen[1]
    assert first_key.endswith(":sheet:A")
    assert second_key.endswith(":sheet:B")
    assert first_key != second_key


def test_cooldown_task_lock_held_across_exception(mock_redis):
    """A run that raises still consumes the cooldown window."""
    # Second add() reports the lock as still held after the failed run.
    mock_redis.add.side_effect = [True, False]

    @cooldown_task(wait_time=60)
    def my_task():
        msg = "boom"
        raise RuntimeError(msg)

    with pytest.raises(RuntimeError):
        my_task()

    followup = my_task()
    assert followup is None


def test_cooldown_task_clear_cooldown_deletes_key(mock_redis):
    """Calling `clear_cooldown` issues a cache delete for the task's lock key."""

    @cooldown_task(wait_time=60)
    def my_task():
        return 1

    my_task.clear_cooldown()

    assert mock_redis.delete.called
    deleted_key = mock_redis.delete.call_args.args[0]
    assert deleted_key.startswith("cooldown:")
    assert "my_task" in deleted_key


def test_cooldown_task_clear_cooldown_respects_key_func(mock_redis):
    """`clear_cooldown` targets the per-argument key produced by `key_func`."""

    @cooldown_task(
        wait_time=60,
        key_func=lambda *, sheets_id=None: f"sheet:{sheets_id}",
    )
    def my_task(sheets_id=None):
        return 1

    my_task.clear_cooldown(sheets_id="A")

    deleted_key = mock_redis.delete.call_args.args[0]
    assert deleted_key.endswith(":sheet:A")


def test_cooldown_task_force_bypasses_active_cooldown(mock_redis, mocker):
    """A forced call runs the function and rewrites the lock without `add`."""
    inner = mocker.Mock(return_value="ran")

    @cooldown_task(wait_time=60)
    def my_task(**kwargs):
        return inner(**kwargs)

    assert my_task(_cooldown_force=True) == "ran"

    # The control kwarg must be consumed, not forwarded to the function.
    inner.assert_called_once_with()
    mock_redis.add.assert_not_called()
    assert mock_redis.set.called
    set_call = mock_redis.set.call_args
    assert set_call.kwargs["timeout"] == 60


def test_cooldown_task_force_refreshes_cooldown_for_subsequent_calls(mock_redis):
    """A normal call right after a forced run is still subject to the lock."""
    mock_redis.add.return_value = False

    @cooldown_task(wait_time=60)
    def my_task(**kwargs):
        return 1

    my_task(_cooldown_force=True)
    unforced_result = my_task()

    assert unforced_result is None
    # Only the unforced call goes through add(); the forced one uses set().
    mock_redis.add.assert_called_once()


def test_cooldown_task_raises_reject_inside_celery_worker(mock_redis, mocker):
    """
    A skip during a Celery task run (request not called directly) must record
    the REJECTED state via ``update_state`` and raise ``Reject`` without
    requeueing, so the run does not show up as PENDING/SUCCESS.
    """
    mock_redis.add.return_value = False
    mock_task = mocker.patch("main.decorators.current_task")
    mock_task.request.called_directly = False

    @cooldown_task(wait_time=60)
    def my_task():
        return 1

    with pytest.raises(Reject) as exc_info:
        my_task()

    rejection = exc_info.value
    assert rejection.requeue is False
    mock_task.update_state.assert_called_once()
    state_kwargs = mock_task.update_state.call_args.kwargs
    assert state_kwargs["state"] == "REJECTED"
    assert state_kwargs["meta"]["reason"] == "cooldown"


def test_cooldown_task_returns_none_when_called_directly(mock_redis, mocker):
    """A skipped direct call yields None and never touches the task state."""
    mock_redis.add.return_value = False
    mock_task = mocker.patch("main.decorators.current_task")
    mock_task.request.called_directly = True

    @cooldown_task(wait_time=60)
    def my_task():
        return 1

    skipped = my_task()

    assert skipped is None
    mock_task.update_state.assert_not_called()
Loading