diff --git a/chancy/__init__.py b/chancy/__init__.py index fbcb748..0eab11a 100644 --- a/chancy/__init__.py +++ b/chancy/__init__.py @@ -7,9 +7,10 @@ "Limit", "Reference", "job", + "ConcurrencyRule", ) from chancy.app import Chancy from chancy.queue import Queue from chancy.worker import Worker -from chancy.job import Limit, Job, QueuedJob, Reference, job +from chancy.job import Limit, Job, QueuedJob, Reference, job, ConcurrencyRule diff --git a/chancy/app.py b/chancy/app.py index 681eb41..19ace02 100644 --- a/chancy/app.py +++ b/chancy/app.py @@ -674,14 +674,27 @@ async def push_many_ex( :param jobs: The jobs to push onto the queue. :return: A list of references to the jobs in the queue. """ - references = [] - for job in jobs: - await cursor.execute( - self._push_job_sql(), - self._get_job_params(job), + # Insert concurrency configurations + concurrency_params = self._concurrency_params_iterator(jobs) + if concurrency_params: + await cursor.executemany( + self._push_concurrency_config_sql(), + concurrency_params, ) + + # Insert jobs + await cursor.executemany( + self._push_job_sql(), + self._job_params_iterator(jobs), + returning=True, + ) + references = [] + while True: record = await cursor.fetchone() - references.append(Reference(record["id"])) + if record: + references.append(Reference(record["id"])) + elif not cursor.nextset(): + break if self.notifications: for queue in set( @@ -693,7 +706,7 @@ async def push_many_ex( return references def sync_push_many_ex( - self, cursor: Cursor, jobs: list[Job] + self, cursor: Cursor, jobs: list[Job | IsAJob[..., Any]] ) -> list[Reference]: """ Synchronously push multiple jobs onto the queue using a specific cursor. @@ -710,16 +723,31 @@ def sync_push_many_ex( :param jobs: The jobs to push onto the queue. :return: A list of references to the jobs in the queue. 
""" - references = [] - for job in jobs: - cursor.execute( - self._push_job_sql(), - self._get_job_params(job), + # Insert concurrency configurations + concurrency_params = self._concurrency_params_iterator(jobs) + if concurrency_params: + cursor.executemany( + self._push_concurrency_config_sql(), + concurrency_params, ) - record = cursor.fetchone() - references.append(Reference(record["id"])) - for queue in set(job.queue for job in jobs): + # Insert jobs + cursor.executemany( + self._push_job_sql(), + self._job_params_iterator(jobs), + returning=True, + ) + references = [] + while True: + record = cursor.fetchone() + if record: + references.append(Reference(record["id"])) + elif not cursor.nextset(): + break + + for queue in set( + job.queue if isinstance(job, Job) else job.job.queue for job in jobs + ): self.sync_notify(cursor, "queue.pushed", {"q": queue}) return references @@ -1345,7 +1373,8 @@ def _push_job_sql(self): priority, max_attempts, scheduled_at, - unique_key + unique_key, + concurrency_key ) VALUES ( %(id)s, @@ -1357,7 +1386,8 @@ def _push_job_sql(self): %(priority)s, %(max_attempts)s, %(scheduled_at)s, - %(unique_key)s + %(unique_key)s, + %(concurrency_key)s ) ON CONFLICT (unique_key) WHERE @@ -1464,17 +1494,43 @@ def _declare_sql(self, upsert: bool): action=action, ) + def _push_concurrency_config_sql(self): + return sql.SQL( + """ + INSERT INTO {concurrency_configs} + (concurrency_key, concurrency_max, updated_at) + VALUES (%s, %s, NOW()) + ON CONFLICT (concurrency_key) DO UPDATE SET + concurrency_max = EXCLUDED.concurrency_max, + updated_at = NOW() + """ + ).format( + concurrency_configs=sql.Identifier( + f"{self.prefix}concurrency_configs" + ) + ) + @staticmethod - def _get_job_params(job: Job | IsAJob[..., Any]) -> dict: + def _get_concurrency_params(job: Job) -> tuple: + """ + Get the parameters for storing concurrency configuration. + + :param job: The job containing concurrency configuration. 
+ :return: A tuple of parameters for the concurrency config. + """ + return ( + job.evaluate_concurrency_key(), # prefixed concurrency key + job.concurrency_rule.max if job.concurrency_rule else None, + ) + + @staticmethod + def _get_job_params(job: Job) -> dict: """ Get the parameters for a job to be inserted into the database. :param job: The job to get parameters for. :return: A dictionary of parameters for the job. """ - if callable(job): - job = job.job - return { "id": chancy_uuid(), "queue": job.queue, @@ -1486,11 +1542,38 @@ def _get_job_params(job: Job | IsAJob[..., Any]) -> dict: "max_attempts": job.max_attempts, "scheduled_at": job.scheduled_at, "unique_key": job.unique_key, + "concurrency_key": job.evaluate_concurrency_key(), } + def _concurrency_params_iterator( + self, jobs: list[Job | IsAJob[..., Any]] + ) -> Iterator[tuple] | None: + """ + Collect and deduplicate concurrency configurations from jobs. + Create an iterator over the unique concurrency parameters. + """ + concurrency_configs = {} + + for job in jobs: + if callable(job): + job = job.job + if job.concurrency_rule: + params = self._get_concurrency_params(job) + concurrency_configs[params[0]] = params + + if concurrency_configs: + yield from concurrency_configs.values() + + def _job_params_iterator(self, jobs: list[Job | IsAJob[..., Any]]): + """Create iterator for job parameters.""" + for job in jobs: + if callable(job): + job = job.job + yield self._get_job_params(job) + -from chancy.plugins.pruner import Pruner # noqa: E402 -from chancy.plugins.recovery import Recovery # noqa: E402 from chancy.plugins.leadership import Leadership # noqa: E402 from chancy.plugins.metrics import Metrics # noqa: E402 +from chancy.plugins.pruner import Pruner # noqa: E402 +from chancy.plugins.recovery import Recovery # noqa: E402 from chancy.plugins.workflow import WorkflowPlugin # noqa: E402 diff --git a/chancy/job.py b/chancy/job.py index 8d34819..7e09160 100644 --- a/chancy/job.py +++ b/chancy/job.py @@ 
-97,8 +97,69 @@ def serialize(self) -> dict: return {"t": self.type_.value, "v": self.value} +@dataclasses.dataclass +class ConcurrencyRule: + """ + A concurrency that can be applied to a job and partition id. + """ + + #: The maximum number of jobs with the same concurrency key that can run + #: simultaneously across all workers. + max: int + #: The concurrency key specification for this job. Can be a field name string + #: or a callable that computes the key from job arguments. + key: str | Callable | None = None + + @classmethod + def deserialize(cls, data: dict) -> "ConcurrencyRule": + return cls(max=data["v"], key=data["k"]) + + def serialize(self) -> dict: + # Would be great to access global logger here to warn about callable not being serializable + return { + "k": self.key if not callable(self.key) else None, + "v": self.max, + } + + def compute_key(self, **job_kwargs) -> str | None: + """ + Compute the concurrency key from concurrency rule and jobs kwargs. + + This function takes the job's kwargs to compute the actual concurrency key + that will be used for concurrency limiting. + + :return: The computed concurrency key string, or None if no concurrency + rule is configured. + """ + if self.key is None: + return None + + try: + if callable(self.key): + key = self.key(**job_kwargs) + if key is None: + raise ValueError( + "Concurrency key function evaluated to None" + ) + elif isinstance(self.key, str): + # For string field names, look up the value in kwargs + key = job_kwargs.get(self.key) + if key is None: + raise ValueError( + f"Concurrency key '{self.key}' not found in job kwargs or evaluated to None" + ) + else: + raise TypeError( + f"Invalid concurrency key type '{type(self.key)}'." 
+ ) + except Exception as e: + raise ValueError("Failed to evaluate concurrency key") from e + + return key + + @dataclasses.dataclass(frozen=True, kw_only=True) -class Job: +class BaseJob: """ A job is an immutable, stateless unit of work that can be pushed onto a Chancy queue and executed elsewhere. @@ -146,30 +207,72 @@ def hello_world(): """ return cls(func=importable_name(func), **kwargs) - def with_priority(self, priority: int) -> "Job": + def with_priority(self, priority: int) -> "BaseJob": return dataclasses.replace(self, priority=priority) - def with_max_attempts(self, max_attempts: int) -> "Job": + def with_max_attempts(self, max_attempts: int) -> "BaseJob": return dataclasses.replace(self, max_attempts=max_attempts) - def with_scheduled_at(self, scheduled_at: datetime) -> "Job": + def with_scheduled_at(self, scheduled_at: datetime) -> "BaseJob": return dataclasses.replace(self, scheduled_at=scheduled_at) - def with_limits(self, limits: list[Limit]) -> "Job": + def with_limits(self, limits: list[Limit]) -> "BaseJob": return dataclasses.replace(self, limits=limits) - def with_kwargs(self, **kwargs) -> "Job": + def with_kwargs(self, **kwargs) -> "BaseJob": return dataclasses.replace(self, kwargs=kwargs) - def with_unique_key(self, unique_key: str) -> "Job": + def with_unique_key(self, unique_key: str) -> "BaseJob": return dataclasses.replace(self, unique_key=unique_key) - def with_queue(self, queue: str) -> "Job": + def with_queue(self, queue: str) -> "BaseJob": return dataclasses.replace(self, queue=queue) - def with_meta(self, meta: dict[str, Any]) -> "Job": + def with_meta(self, meta: dict[str, Any]) -> "BaseJob": return dataclasses.replace(self, meta=meta) + +@dataclasses.dataclass(frozen=True, kw_only=True) +class Job(BaseJob): + #: The concurrency rule for this job. This determines how many instances of + #: this job can run concurrently across all workers. 
+ concurrency_rule: ConcurrencyRule | None = None + + def with_concurrency( + self, + concurrency_rule: ConcurrencyRule, + ) -> "Job": + """ + Add concurrency constraints to this job. + + :param concurrency_rule: The concurrency rule for this job. This determines how many instances of + this job can run concurrently across all workers. + :return: A new Job instance with concurrency constraints. + """ + return dataclasses.replace( + self, + concurrency_rule=concurrency_rule, + ) + + def evaluate_concurrency_key(self) -> str | None: + """ + Evaluate the concurrency key from concurrency rule and jobs kwargs. + + This function takes a job's concurrency_rule specification and the job's + kwargs to compute the actual concurrency key that will be used + for concurrency limiting. The key is prefixed with the function name. + + :return: The computed concurrency key string (prefixed with func_name), + or None if no concurrency rule is configured. + """ + if self.concurrency_rule is None: + return None + + computed_key = self.concurrency_rule.compute_key(**(self.kwargs or {})) + if computed_key is not None: + return f"{self.func}:{computed_key}" + return self.func + def pack(self) -> dict: """ Pack the job into a dictionary that can be serialized and used to @@ -185,6 +288,9 @@ def pack(self) -> dict: "u": self.unique_key, "q": self.queue, "m": self.meta, + "c": self.concurrency_rule.serialize() + if self.concurrency_rule + else None, } @classmethod @@ -202,11 +308,14 @@ def unpack(cls, data: dict) -> "Job": unique_key=data["u"], queue=data["q"], meta=data["m"], + concurrency_rule=ConcurrencyRule.deserialize(data["c"]) + if data["c"] + else None, ) @dataclasses.dataclass(frozen=True, kw_only=True) -class QueuedJob(Job): +class QueuedJob(BaseJob): """ A job instance is a job that has been pushed onto a queue and now has stateful information associated with it, such as the number of attempts @@ -234,6 +343,10 @@ class State(enum.Enum): state: State = State.PENDING #: A list of 
errors that occurred during the execution of this job. errors: list[ErrorT] = dataclasses.field(default_factory=list) + #: The computed concurrency key for this specific job instance. This is + #: derived from the job's concurrency_rule specification and job arguments. + #: It is computed on job push. + computed_concurrency_key: str | None = None @classmethod def unpack(cls, data: dict) -> "QueuedJob": @@ -254,6 +367,7 @@ def unpack(cls, data: dict) -> "QueuedJob": errors=data["errors"], limits=[Limit.deserialize(limit) for limit in data["limits"]], meta=data["meta"], + computed_concurrency_key=data.get("concurrency_key"), ) diff --git a/chancy/migrations/v7.py b/chancy/migrations/v7.py new file mode 100644 index 0000000..6db6e40 --- /dev/null +++ b/chancy/migrations/v7.py @@ -0,0 +1,87 @@ +from psycopg import sql + +from chancy.migrate import Migration + + +class AddConcurrencySupport(Migration): + """ + Add support for job-level concurrency constraints. + + This migration adds: + 1. concurrency_key column to jobs table for storing computed concurrency keys (prefixed with func_name) + 2. concurrency_configs table using prefixed concurrency_key as primary key + 3. 
Optimized indexes for concurrency-aware job selection + """ + + async def up(self, migrator, cursor): + # Add concurrency_key column to jobs table + await cursor.execute( + sql.SQL( + """ + ALTER TABLE {jobs} + ADD COLUMN concurrency_key TEXT + """ + ).format(jobs=sql.Identifier(f"{migrator.prefix}jobs")) + ) + + # Create concurrency configurations table + await cursor.execute( + sql.SQL( + """ + CREATE TABLE {concurrency_configs} ( + concurrency_key TEXT PRIMARY KEY, + concurrency_max INTEGER NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() + ) + """ + ).format( + concurrency_configs=sql.Identifier( + f"{migrator.prefix}concurrency_configs" + ) + ) + ) + + # Create partial index for efficient concurrency lookups on running jobs + await cursor.execute( + sql.SQL( + """ + CREATE INDEX {index_name} ON {jobs} (concurrency_key) + WHERE state = 'running' AND concurrency_key IS NOT NULL + """ + ).format( + index_name=sql.Identifier( + f"{migrator.prefix}jobs_concurrency_key_running_idx" + ), + jobs=sql.Identifier(f"{migrator.prefix}jobs"), + ) + ) + + async def down(self, migrator, cursor): + # Drop the concurrency index + await cursor.execute( + sql.SQL("DROP INDEX IF EXISTS {index_name}").format( + index_name=sql.Identifier( + f"{migrator.prefix}jobs_concurrency_key_running_idx" + ) + ) + ) + + # Drop concurrency configurations table + await cursor.execute( + sql.SQL("DROP TABLE IF EXISTS {concurrency_configs}").format( + concurrency_configs=sql.Identifier( + f"{migrator.prefix}concurrency_configs" + ) + ) + ) + + # Remove concurrency_key column from jobs table + await cursor.execute( + sql.SQL( + """ + ALTER TABLE {jobs} + DROP COLUMN IF EXISTS concurrency_key + """ + ).format(jobs=sql.Identifier(f"{migrator.prefix}jobs")) + ) diff --git a/chancy/plugins/pruner.py b/chancy/plugins/pruner.py index 79b22e9..8430f43 100644 --- a/chancy/plugins/pruner.py +++ b/chancy/plugins/pruner.py @@ -1,14 +1,32 @@ +from abc import ABCMeta + 
from psycopg import AsyncCursor, sql from psycopg.rows import DictRow, dict_row from chancy.app import Chancy -from chancy.worker import Worker from chancy.plugin import Plugin -from chancy.rule import SQLAble, JobRules +from chancy.rule import ConcurrencyRules, JobRules, SQLAble from chancy.utils import timed_block +from chancy.worker import Worker + + +class PrunerMeta(ABCMeta): + """Metaclass to handle deprecated Rules attribute.""" + def __getattr__(cls, name): + if name == "Rules": + import warnings -class Pruner(Plugin): + warnings.warn( + "Pruner.Rules is deprecated. Use Pruner.JobRules instead.", + DeprecationWarning, + stacklevel=2, + ) + return cls.JobRules + raise AttributeError(f"'{cls.__name__}' has no attribute '{name}'") + + +class Pruner(Plugin, metaclass=PrunerMeta): """ A plugin that prunes stale data from the database. @@ -25,16 +43,17 @@ class Pruner(Plugin): async with Chancy(..., plugins=[ Leadership(), Pruner( - Pruner.Rules.Queue() == "default" & (Pruner.Rules.Age() > 60) + job_rule=Pruner.JobRules.Queue() == "default" & (Pruner.JobRules.Age() > 60), + concurrency_rule=Pruner.ConcurrencyRules.Age() > 60*60*24*3 ) ]) as chancy: ... The pruner will never prune jobs that haven't been run yet or are currently - running. When the pruner runs, it will also call the - :py:meth:`chancy.plugin.Plugin.cleanup` method on any plugins that - implement it, allowing them to clean up any data that is no longer - needed such as completed workflows. + running. It also cleans up stale concurrency configuration records. When + the pruner runs, it will also call the :py:meth:`chancy.plugin.Plugin.cleanup` + method on any plugins that implement it, allowing them to clean up any data + that is no longer needed such as completed workflows. Rules ----- @@ -46,14 +65,14 @@ class Pruner(Plugin): .. 
code-block:: python - Pruner(Pruner.Rules.Age() > 60) + Pruner(job_rule=Pruner.JobRules.Age() > 60) Or to prune jobs that are older than 60 seconds and are in the "default" queue: .. code-block:: python - Pruner(Pruner.Rules.Queue() == "default" & (Pruner.Rules.Age() > 60)) + Pruner(job_rule=Pruner.JobRules.Queue() == "default" & (Pruner.JobRules.Age() > 60)) Or to prune jobs that are older than 60 seconds and are in the "default" queue, or instantly deleted if the job is `update_cache`: @@ -61,13 +80,29 @@ class Pruner(Plugin): .. code-block:: python Pruner( - (Pruner.Rules.Queue() == "default" & (Pruner.Rules.Age() > 60)) | - Pruner.Rules.Job() == "update_cache" + job_rule=(Pruner.JobRules.Queue() == "default" & (Pruner.JobRules.Age() > 60)) | + Pruner.JobRules.Job() == "update_cache" ) + To customize concurrency rule cleanup: + + .. code-block:: python + + # Clean rules older than 3 days + Pruner(concurrency_rule=Pruner.ConcurrencyRules.Age() > 60*60*24*3) + + # Clean orphaned rules and those older than 12 hours + Pruner(concurrency_rule=( + Pruner.ConcurrencyRules.Orphaned() | + Pruner.ConcurrencyRules.Age() > 60*60*12 + )) + + # Disable concurrency rule cleanup + Pruner(concurrency_rule=None) + By default, the pruner will run every 60 seconds and will remove up to - 10,000 jobs in a single run that have been completed for more than 60 - seconds. + 10,000 jobs in a single run that have been completed for more than 1 day. + It will also clean up concurrency rules older than 3 days or that are orphaned. .. tip:: @@ -75,24 +110,54 @@ class Pruner(Plugin): multiple rules, you may need to create additional indexes to improve performance on busy queues. - :param rule: The rule that the pruner will use to match jobs. + .. deprecated:: + The ``rule`` parameter is deprecated. Use ``job_rule`` instead. + + :param rule: [DEPRECATED] The rule for pruning jobs. Use ``job_rule`` instead. + :param job_rule: The rule that the pruner will use to match jobs for pruning. 
+ :param concurrency_rule: The rule for pruning concurrency rules. + Defaults to rules older than 7 days or orphaned. + Set to None to disable. :param maximum_to_prune: The maximum number of jobs to prune in a single run of the pruner. :param poll_interval: The interval in seconds between each run of the pruner. """ - Rules = JobRules + JobRules = JobRules + ConcurrencyRules = ConcurrencyRules def __init__( self, - rule: SQLAble = Rules.Age() > 60 * 60 * 24, + rule: SQLAble | None = None, # Deprecated *, + job_rule: SQLAble | None = JobRules.Age() > 60 * 60 * 24, + concurrency_rule: SQLAble | None = ( + (ConcurrencyRules.Age() > 60 * 60 * 24 * 7) # 7 days + | ConcurrencyRules.Orphaned() + ), maximum_to_prune: int = 10000, poll_interval: int = 60, ): super().__init__() - self.rule = rule + + # Handle backward compatibility with deprecation warning + if rule is not None and job_rule is not None: + job_rule = None # For backward compatibility + + if rule is not None: + import warnings + + warnings.warn( + "The 'rule' parameter is deprecated and will be removed in a future version. " + "Use 'job_rule' instead.", + DeprecationWarning, + stacklevel=2, + ) + job_rule = rule + + self.job_rule = job_rule + self.concurrency_rule = concurrency_rule self.maximum_to_prune = maximum_to_prune self.poll_interval = poll_interval @@ -110,11 +175,18 @@ async def run(self, worker: Worker, chancy: Chancy): async with chancy.pool.connection() as conn: async with conn.cursor(row_factory=dict_row) as cursor: with timed_block() as chancy_time: - rows_removed = await self.prune(chancy, cursor) + # Prune jobs + job_rows_removed = await self.prune_jobs(chancy, cursor) + + # Prune concurrency configs + concurrency_rule_rows_removed = ( + await self.prune_concurrency_rules(chancy, cursor) + ) + chancy.log.info( - f"Pruner removed {rows_removed} row(s) from the" - f" database. Took {chancy_time.elapsed:.2f}" - f" seconds." 
+ f"Pruner removed {job_rows_removed} job(s) and " + f"{concurrency_rule_rows_removed} concurrency config(s). " + f"Took {chancy_time.elapsed:.2f} seconds." ) await chancy.notify( @@ -122,7 +194,8 @@ async def run(self, worker: Worker, chancy: Chancy): "pruner.removed", { "elapsed": chancy_time.elapsed, - "rows_removed": rows_removed, + "job_rows_removed": job_rows_removed, + "concurrency_rule_rows_removed": concurrency_rule_rows_removed, }, ) @@ -136,7 +209,9 @@ async def run(self, worker: Worker, chancy: Chancy): f" row(s) from the database." ) - async def prune(self, chancy: Chancy, cursor: AsyncCursor[DictRow]) -> int: + async def prune_jobs( + self, chancy: Chancy, cursor: AsyncCursor[DictRow] + ) -> int: """ Prune stale records from the database. @@ -144,6 +219,9 @@ async def prune(self, chancy: Chancy, cursor: AsyncCursor[DictRow]) -> int: :param cursor: The database cursor to use for the operation. :return: The number of rows removed from the database """ + if self.job_rule is None: + return 0 + job_query = sql.SQL( """ WITH jobs_to_prune AS ( @@ -159,9 +237,41 @@ async def prune(self, chancy: Chancy, cursor: AsyncCursor[DictRow]) -> int: """ ).format( table=sql.Identifier(f"{chancy.prefix}jobs"), - rule=self.rule.to_sql(), + rule=self.job_rule.to_sql(), maximum_to_prune=sql.Literal(self.maximum_to_prune), ) await cursor.execute(job_query) return cursor.rowcount + + async def prune_concurrency_rules( + self, chancy: "Chancy", cursor: AsyncCursor[DictRow] + ) -> int: + """ + Prune stale concurrency rule records from the database. + + :param chancy: The Chancy application. + :param cursor: The database cursor to use for the operation. 
+ :return: The number of rows removed from the database + """ + if self.concurrency_rule is None: + return 0 + + rule_sql = self.concurrency_rule.to_sql( + {"chancy_prefix": chancy.prefix} + ) + + config_query = sql.SQL( + """ + DELETE FROM {concurrency_configs} + WHERE ({rule}) + """ + ).format( + concurrency_configs=sql.Identifier( + f"{chancy.prefix}concurrency_configs" + ), + rule=rule_sql, + ) + + await cursor.execute(config_query) + return cursor.rowcount diff --git a/chancy/queue.py b/chancy/queue.py index 3b1993b..d1c128f 100644 --- a/chancy/queue.py +++ b/chancy/queue.py @@ -98,6 +98,14 @@ class State(enum.Enum): #: continuously available, as it can reduce latency between jobs. However, #: it can also increase load on the database and should be used with care. eager_polling: bool = False + #: The multiplier used to calculate the scan limit when fetching jobs. + #: The scan limit is calculated as `min(batch_size * scan_factor, scan_limit_upper_bound)`. + #: Higher values reduce the chance of starvation when many jobs are blocked + #: by concurrency limits, but increase query cost. + scan_factor: int = 20 + #: The maximum number of jobs to scan when fetching work, regardless of + #: the scan_factor calculation. 
+ scan_limit_upper_bound: int = 1000 @classmethod def unpack(cls, data: dict) -> "Queue": @@ -116,6 +124,8 @@ def unpack(cls, data: dict) -> "Queue": rate_limit_window=data.get("rate_limit_window"), resume_at=data.get("resume_at"), eager_polling=data.get("eager_polling", False), + scan_factor=data.get("scan_factor", 20), + scan_limit_upper_bound=data.get("scan_limit_upper_bound", 1000), ) def pack(self) -> dict: @@ -135,4 +145,6 @@ def pack(self) -> dict: "rate_limit_window": self.rate_limit_window, "resume_at": self.resume_at, "eager_polling": self.eager_polling, + "scan_factor": self.scan_factor, + "scan_limit_upper_bound": self.scan_limit_upper_bound, } diff --git a/chancy/rule.py b/chancy/rule.py index a00c7d9..44a30e3 100644 --- a/chancy/rule.py +++ b/chancy/rule.py @@ -4,11 +4,12 @@ """ from typing import Any + from psycopg import sql class SQLAble: - def to_sql(self) -> sql.Composable: + def to_sql(self, context: dict = {}) -> sql.Composable: raise NotImplementedError @@ -46,7 +47,7 @@ def contains(self, value: str) -> "Condition": """ return Condition(self.to_sql(), "ILIKE", f"%{value}%") - def to_sql(self) -> sql.Composable: + def to_sql(self, context: dict = {}) -> sql.Composable: return sql.Identifier(self.field) @@ -62,7 +63,7 @@ def __or__(self, other: "Condition") -> "OrCondition": def __and__(self, other: "Condition") -> "AndCondition": return AndCondition(self, other) - def to_sql(self) -> sql.Composable: + def to_sql(self, context: dict = {}) -> sql.Composable: return sql.SQL("{field} {op} {value}").format( field=self.field, op=sql.SQL(self.op), @@ -81,9 +82,9 @@ def __or__(self, other: SQLAble) -> "OrCondition": def __and__(self, other: SQLAble) -> "AndCondition": return AndCondition(self, other) - def to_sql(self) -> sql.Composable: + def to_sql(self, context: dict = {}) -> sql.Composable: return sql.SQL("({left}) OR ({right})").format( - left=self.left.to_sql(), right=self.right.to_sql() + left=self.left.to_sql(context), 
right=self.right.to_sql(context) ) @@ -98,9 +99,9 @@ def __or__(self, other: Condition) -> OrCondition: def __and__(self, other: Condition) -> "AndCondition": return AndCondition(self, other) - def to_sql(self) -> sql.Composable: + def to_sql(self, context: dict = {}) -> sql.Composable: return sql.SQL("({left}) AND ({right})").format( - left=self.left.to_sql(), right=self.right.to_sql() + left=self.left.to_sql(context), right=self.right.to_sql(context) ) @@ -113,7 +114,7 @@ class Age(Rule): def __init__(self): super().__init__("age") - def to_sql(self) -> sql.Composable: + def to_sql(self, context: dict = {}) -> sql.Composable: return sql.SQL("EXTRACT(EPOCH FROM (NOW() - created_at))") class Queue(Rule): @@ -139,3 +140,40 @@ def __init__(self): class ID(Rule): def __init__(self): super().__init__("id") + + +class ConcurrencyRules: + """ + A collection of rules that can be used to filter the concurrency_rules table. + """ + + class Age(Rule): + """Age since last update (updated_at)""" + + def __init__(self): + super().__init__("age") + + def to_sql(self, context: dict = {}) -> sql.Composable: + return sql.SQL("EXTRACT(EPOCH FROM (NOW() - updated_at))") + + class Key(Rule): + """Concurrency key pattern matching""" + + def __init__(self): + super().__init__("concurrency_key") + + class Orphaned(Rule): + """Configs with no corresponding jobs""" + + def __init__(self): + super().__init__("orphaned") + + def to_sql(self, context: dict = {}) -> sql.Composable: + return sql.SQL( + "NOT EXISTS (SELECT 1 FROM {jobs_table} j WHERE j.concurrency_key = {concurrency_configs}.concurrency_key)" + ).format( + jobs_table=sql.Identifier(f"{context['chancy_prefix']}jobs"), + concurrency_configs=sql.Identifier( + f"{context['chancy_prefix']}concurrency_configs" + ), + ) diff --git a/chancy/worker.py b/chancy/worker.py index 5634250..1f95b37 100644 --- a/chancy/worker.py +++ b/chancy/worker.py @@ -640,6 +640,104 @@ async def queue_update(self, update: QueuedJob): """ await 
self.outgoing.put(update) + def _fetch_jobs_sql(self, prefix: str) -> sql.SQL: + """ + Build the complete fetch_jobs query. + + Args: + prefix: The table prefix to use (e.g., "chancy_") + + Returns: + sql.SQL query template with parameter placeholders + """ + return sql.SQL( + """ + WITH candidate_jobs AS ( + -- Get a reasonable sample of pending jobs for this queue + SELECT j.id, j.priority, j.concurrency_key + FROM {jobs} j + WHERE j.queue = %(queue)s + AND j.state IN ('pending', 'retrying') + AND j.attempts < j.max_attempts + AND (j.scheduled_at IS NULL OR j.scheduled_at <= NOW()) + ORDER BY j.priority DESC, j.id ASC + LIMIT %(scan_limit)s + ), + lockable_configs AS ( + -- Lock only the specific configs we need + SELECT cc.concurrency_key, cc.concurrency_max + FROM {concurrency_configs} cc + WHERE cc.concurrency_key IN (SELECT DISTINCT cj.concurrency_key + FROM candidate_jobs cj + WHERE cj.concurrency_key IS NOT NULL) + FOR UPDATE SKIP LOCKED + ), + current_usage AS ( + -- Count running jobs only for the locked concurrency keys + SELECT j.concurrency_key, COUNT(*) as running_count + FROM {jobs} j + WHERE j.state = 'running' + AND j.concurrency_key IN (SELECT lc.concurrency_key FROM lockable_configs lc) + GROUP BY j.concurrency_key + ), + available_slots AS ( + -- Calculate available slots for locked configs + SELECT + lc.concurrency_key, + lc.concurrency_max, + COALESCE(cu.running_count, 0) as running_count, + GREATEST(0, lc.concurrency_max - COALESCE(cu.running_count, 0)) as slots_available + FROM lockable_configs lc + LEFT JOIN current_usage cu ON cu.concurrency_key = lc.concurrency_key + ), + ranked_jobs AS ( + SELECT + cj.id, + cj.priority, + cj.concurrency_key, + asl.slots_available, + CASE + WHEN cj.concurrency_key IS NULL THEN 1 + ELSE ROW_NUMBER() OVER ( + PARTITION BY cj.concurrency_key + ORDER BY cj.priority DESC, cj.id ASC + ) + END as job_rank + FROM candidate_jobs cj + LEFT JOIN available_slots asl ON cj.concurrency_key = asl.concurrency_key + WHERE 
+ -- Include non-constrained jobs + cj.concurrency_key IS NULL + OR + -- Include constrained jobs only if their config was lockable and has slots + (cj.concurrency_key IS NOT NULL AND asl.slots_available > 0) + ), + eligible_jobs AS ( + SELECT rj.id, rj.priority + FROM ranked_jobs rj + WHERE + -- Non-constrained jobs are always eligible + rj.concurrency_key IS NULL + OR + -- Constrained jobs must be within available slots + (rj.concurrency_key IS NOT NULL AND rj.job_rank <= rj.slots_available) + ORDER BY rj.priority DESC, rj.id ASC + LIMIT %(maximum_jobs_to_fetch)s + FOR UPDATE SKIP LOCKED + ) + UPDATE {jobs} SET + started_at = NOW(), + state = 'running', + taken_by = %(worker_id)s + FROM eligible_jobs ej + WHERE {jobs}.id = ej.id + RETURNING {jobs}.* + """ + ).format( + jobs=sql.Identifier(f"{prefix}jobs"), + concurrency_configs=sql.Identifier(f"{prefix}concurrency_configs"), + ) + async def fetch_jobs( self, queue: Queue, @@ -661,7 +759,6 @@ async def fetch_jobs( :param conn: The database connection to use. :param up_to: The maximum number of jobs to fetch. 
""" - jobs_table = sql.Identifier(f"{self.chancy.prefix}jobs") rate_limits_table = sql.Identifier( f"{self.chancy.prefix}queue_rate_limits" ) @@ -710,47 +807,17 @@ async def fetch_jobs( # Adjust up_to based on remaining rate limit up_to = min(up_to, queue.rate_limit - current_count) + # Use the centralized query builder method + query = self._fetch_jobs_sql(self.chancy.prefix) await cursor.execute( - sql.SQL( - """ - WITH selected_jobs AS ( - SELECT - id - FROM - {jobs} - WHERE - queue = %(queue)s - AND - (state = 'pending' OR state = 'retrying') - AND - attempts < max_attempts - AND - (scheduled_at IS NULL OR scheduled_at <= NOW()) - ORDER BY - priority DESC, - id ASC - LIMIT - %(maximum_jobs_to_fetch)s - FOR UPDATE OF {jobs} SKIP LOCKED - ) - UPDATE - {jobs} - SET - started_at = NOW(), - state = 'running', - taken_by = %(worker_id)s - FROM - selected_jobs - WHERE - {jobs}.id = selected_jobs.id - RETURNING {jobs}.* - """ - ).format( - jobs=jobs_table, - ), + query, { "queue": queue.name, "maximum_jobs_to_fetch": up_to, + "scan_limit": min( + up_to * queue.scan_factor, + queue.scan_limit_upper_bound, + ), "worker_id": self.worker_id, }, ) diff --git a/docs/howto/jobs.rst b/docs/howto/jobs.rst index bc0f899..16d3088 100644 --- a/docs/howto/jobs.rst +++ b/docs/howto/jobs.rst @@ -157,3 +157,124 @@ Prevent duplicate job execution by assigning a unique key: Unique jobs ensure only one job with the same ``unique_key`` is queued or running at a time, but any number can be completed or failed. + +Concurrency +----------------------- + +Control the number of jobs with the same concurrency key that can run +simultaneously across all workers and queues using with_concurrency(): + +.. 
code-block:: python + + from chancy import job, ConcurrencyRule + + @job() + def process_user_data(*, user_id: str, action: str): + print(f"Processing {action} for user {user_id}") + + async with Chancy("postgresql://localhost/postgres") as chancy: + # Limit to 1 concurrent job per user_id + job_with_limit = process_user_data.job.with_concurrency( + ConcurrencyRule( + max=1, + key="user_id" + ) + ) + await chancy.push(job_with_limit.with_kwargs(user_id="123", action="upload")) + +The ``key`` parameter determines how jobs are grouped for concurrency limits: + +**Field-based keys**: Use a parameter name to group by that field's value: + +.. code-block:: python + + # Limit by user_id - max 1 job per user + job.with_concurrency(ConcurrencyRule(max=1, key="user_id")) + +**Callable keys**: Use a function to compute complex grouping keys: + +.. code-block:: python + + # Limit by user + action combination + job.with_concurrency( + ConcurrencyRule( + max=2, + key=lambda user_id, action, **kw: f"{user_id}:{action}" + ) + ) + +**Function-level limits**: Omit the key to limit all jobs of this type: + +.. code-block:: python + + # Limit total concurrent jobs of this type to 5 + job.with_concurrency(ConcurrencyRule(max=5)) + +.. note:: + + Concurrency constraints are enforced globally across all workers in your + cluster. Jobs exceeding the limit will wait in the queue until a slot + becomes available. + +More about concurrency +~~~~~~~~~~~~~~~~~~~~~~ + +**Concurrency is checked at fetch time:** Chancy enforces concurrency limits +when workers fetch jobs by counting running jobs directly rather than +maintaining counters or leases. This pragmatic design keeps things simple +and robust: + +- **Atomic:** The check and claim happen in a single query, eliminating race + conditions between workers. +- **Easy recovery:** If a worker crashes, the recovery plugin marks the job as + pending again and the slot is automatically freed. No counters to decrement, + no leases to release. 
+- **Minimal overhead:** When few jobs use concurrency limits, performance + impact is negligible. Jobs without a concurrency key pass straight through, + and the database efficiently skips empty intermediate results. + +Any drawbacks of this approach (such as scan window limitations) can be +mitigated by leveraging Chancy's queue architecture to separate workloads. + +**How jobs are fetched:** Workers scan pending jobs by priority within a +configurable window. Jobs that cannot run due to concurrency limits are +skipped. If many jobs in the scan window are blocked, eligible jobs outside +the window won't be considered until the next fetch cycle. + +**When to use a dedicated queue:** If you have concurrency keys with strict +low limits but high volume, consider placing them in a dedicated queue. For +example, if you limit per-user processing to 1 concurrent job but have +thousands of users submitting work simultaneously, jobs waiting for a slot can +fill the scan window and delay other work. + +This separation ensures: + +- Unconstrained jobs aren't starved by concurrency-blocked jobs +- Each queue can be tuned separately + +**Tuning the scan window:** The scan limit determines how many pending jobs +the worker examines when looking for work. It is calculated as:: + + scan_limit = min(batch_size * scan_factor, scan_limit_upper_bound) + +Where: + +- **batch_size**: Number of jobs the worker wants to fetch (based on queue concurrency) +- **scan_factor**: Multiplier applied to batch_size (default: 20) +- **scan_limit_upper_bound**: Maximum scan limit regardless of batch_size (default: 1000) + +For example, with defaults and a worker fetching 10 jobs:: + + scan_limit = min(10 * 20, 1000) = 200 jobs scanned + +You can tune these parameters per-queue: + +.. code-block:: python + + Queue( + "user-processing", + scan_factor=50, # default: 20 + scan_limit_upper_bound=5000, # default: 1000 + ) + +Higher values reduce the chance of starvation but increase query cost. 
diff --git a/tests/plugins/test_pruner.py b/tests/plugins/test_pruner.py index 9d016cd..c76bb61 100644 --- a/tests/plugins/test_pruner.py +++ b/tests/plugins/test_pruner.py @@ -4,8 +4,9 @@ from psycopg.rows import dict_row from chancy import Chancy, Queue, Worker, job -from chancy.plugins.pruner import Pruner +from chancy.job import ConcurrencyRule from chancy.plugins.leadership import ImmediateLeadership +from chancy.plugins.pruner import Pruner @job() @@ -13,6 +14,39 @@ def job_to_run(): pass +@job(concurrency_rule=ConcurrencyRule(2, "user_id")) +def job_with_concurrency(user_id: int): + pass + + +async def _add_old_concurrency_rule( + chancy: Chancy, key: str, max_concurrency: int +): + """Helper to add an old concurrency rule directly to the database.""" + async with chancy.pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + f""" + INSERT INTO {chancy.prefix}concurrency_configs + (concurrency_key, concurrency_max, created_at, updated_at) + VALUES (%s, %s, NOW() - INTERVAL '8 days', NOW() - INTERVAL '8 days') + """, + (key, max_concurrency), + ) + await conn.commit() + + +async def _count_concurrency_rules(chancy: Chancy) -> int: + """Helper to count concurrency rules in the database.""" + async with chancy.pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + f"SELECT COUNT(*) FROM {chancy.prefix}concurrency_configs" + ) + result = await cursor.fetchone() + return result[0] if result else 0 + + @pytest.mark.parametrize( "chancy", [ @@ -26,11 +60,13 @@ def job_to_run(): indirect=True, ) @pytest.mark.asyncio -async def test_pruner_functionality(chancy: Chancy, worker: Worker): +async def test_job_rule_pruning(chancy: Chancy, worker: Worker): """ This test manually calls the prune method to avoid timing issues. 
""" - p = Pruner(Pruner.Rules.Queue() == "test_queue") + p = Pruner( + job_rule=Pruner.JobRules.Queue() == "test_queue", concurrency_rule=None + ) await chancy.declare(Queue("test_queue")) ref = await chancy.push(job_to_run.job.with_queue("test_queue")) @@ -39,13 +75,15 @@ async def test_pruner_functionality(chancy: Chancy, worker: Worker): async with chancy.pool.connection() as conn: async with conn.cursor(row_factory=dict_row) as cursor: - await p.prune(chancy, cursor) + await p.prune_jobs(chancy, cursor) pruned_job = await chancy.get_job(ref) assert pruned_job is None, "Job should be pruned" p = Pruner( - (Pruner.Rules.Queue() == "test_queue") & (Pruner.Rules.Age() > 10) + job_rule=(Pruner.JobRules.Queue() == "test_queue") + & (Pruner.JobRules.Age() > 10), + concurrency_rule=None, ) ref = await chancy.push(job_to_run.job.with_queue("test_queue")) initial_job = await chancy.wait_for_job(ref) @@ -53,7 +91,7 @@ async def test_pruner_functionality(chancy: Chancy, worker: Worker): async with chancy.pool.connection() as conn: async with conn.cursor(row_factory=dict_row) as cursor: - await p.prune(chancy, cursor) + await p.prune_jobs(chancy, cursor) not_pruned_job = await chancy.get_job(ref) assert not_pruned_job is not None, "Job should not be pruned yet" @@ -62,7 +100,169 @@ async def test_pruner_functionality(chancy: Chancy, worker: Worker): async with chancy.pool.connection() as conn: async with conn.cursor(row_factory=dict_row) as cursor: - await p.prune(chancy, cursor) + await p.prune_jobs(chancy, cursor) pruned_job = await chancy.get_job(ref) assert pruned_job is None, "Job should be pruned" + + +@pytest.mark.parametrize( + "chancy", + [ + { + "plugins": [ + ImmediateLeadership(), + ], + "no_default_plugins": True, + } + ], + indirect=True, +) +@pytest.mark.asyncio +async def test_concurrency_rule_pruning_by_age(chancy: Chancy, worker: Worker): + """Test pruning concurrency rules older than a certain age.""" + # Create a pruner that cleans rules older than 3 
days + p = Pruner( + job_rule=None, + concurrency_rule=Pruner.ConcurrencyRules.Age() + > 60 * 60 * 24 * 3, # 3 days + ) + + # Add an old concurrency rule (8 days old) + await _add_old_concurrency_rule( + chancy, "test.job_with_concurrency:user_123", 2 + ) + + # Verify concurrency rule exists + initial_count = await _count_concurrency_rules(chancy) + assert initial_count == 1, "Concurrency rule should exist before pruning" + + # Run concurrency rule pruning + async with chancy.pool.connection() as conn: + async with conn.cursor(row_factory=dict_row) as cursor: + rows_removed = await p.prune_concurrency_rules(chancy, cursor) + + # Verify concurrency rule was pruned + assert rows_removed == 1, "Should have removed 1 concurrency rule" + final_count = await _count_concurrency_rules(chancy) + assert final_count == 0, "Concurrency rule should be pruned" + + +@pytest.mark.parametrize( + "chancy", + [ + { + "plugins": [ + ImmediateLeadership(), + ], + "no_default_plugins": True, + } + ], + indirect=True, +) +@pytest.mark.asyncio +async def test_concurrency_rule_pruning_orphaned( + chancy: Chancy, worker: Worker +): + """Test pruning orphaned concurrency rules (no corresponding jobs).""" + # Create a pruner that only cleans orphaned concurrency rules + p = Pruner( + job_rule=None, + concurrency_rule=Pruner.ConcurrencyRules.Orphaned(), + ) + + # Add a concurrency rule without any corresponding jobs + await _add_old_concurrency_rule(chancy, "test.orphaned_job:user_456", 3) + + # Also add a config that will have a corresponding job + await chancy.declare(Queue("test_queue")) + await chancy.push( + job_with_concurrency.job.with_queue("test_queue").with_kwargs( + user_id=789 + ) + ) + + # Verify both concurrency rules exist + initial_count = await _count_concurrency_rules(chancy) + assert initial_count == 2, "Should have 2 concurrency rules before pruning" + + # Run concurrency rule pruning + async with chancy.pool.connection() as conn: + async with 
conn.cursor(row_factory=dict_row) as cursor: + rows_removed = await p.prune_concurrency_rules(chancy, cursor) + + # Verify only orphaned concurrency rule was pruned + assert rows_removed == 1, "Should have removed 1 orphaned concurrency rule" + final_count = await _count_concurrency_rules(chancy) + assert final_count == 1, ( + "Should have 1 concurrency rule remaining (non-orphaned)" + ) + + +@pytest.mark.parametrize( + "chancy", + [ + { + "plugins": [ + ImmediateLeadership(), + ], + "no_default_plugins": True, + } + ], + indirect=True, +) +@pytest.mark.asyncio +async def test_concurrency_rule_pruning_combined_rules( + chancy: Chancy, worker: Worker +): + """Test pruning with combined rules (age OR orphaned).""" + # Create a pruner with combined rules: old OR orphaned + p = Pruner( + job_rule=None, + concurrency_rule=( + (Pruner.ConcurrencyRules.Age() > 60 * 60 * 24 * 3) # 3 days + | Pruner.ConcurrencyRules.Orphaned() + ), + ) + + # Add an old config (8 days) + await _add_old_concurrency_rule(chancy, "test.old_job:user_111", 1) + + # Add an orphaned config (recent but no jobs) + async with chancy.pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + f""" + INSERT INTO {chancy.prefix}concurrency_configs + (concurrency_key, concurrency_max, created_at, updated_at) + VALUES (%s, %s, NOW(), NOW()) + """, + ("test.orphaned_recent:user_222", 2), + ) + await conn.commit() + + # Add a fresh concurrency rule with corresponding job + await chancy.declare(Queue("test_queue")) + await chancy.push( + job_with_concurrency.job.with_queue("test_queue").with_kwargs( + user_id=333 + ) + ) + + # Verify all concurrency rules exist + initial_count = await _count_concurrency_rules(chancy) + assert initial_count == 3, "Should have 3 concurrency rules before pruning" + + # Run concurrency rule pruning + async with chancy.pool.connection() as conn: + async with conn.cursor(row_factory=dict_row) as cursor: + rows_removed = await 
p.prune_concurrency_rules(chancy, cursor) + + # Verify old and orphaned concurrency rules were pruned, but fresh one with job remains + assert rows_removed == 2, ( + "Should have removed 2 concurrency rules (old + orphaned)" + ) + final_count = await _count_concurrency_rules(chancy) + assert final_count == 1, ( + "Should have 1 concurrency rule remaining (fresh with job)" + ) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000..0202327 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,346 @@ +import asyncio +import time + +import pytest + +from chancy import Chancy, Job, QueuedJob, Worker, job +from chancy.job import ConcurrencyRule + + +@job() +def simple_job(): + """A simple job for testing""" + pass + + +@job() +def slow_job(user_id: str, action: str = "default", duration: float = 0.5): + """A job that takes some time to complete.""" + time.sleep(duration) + + +async def _count_running_jobs_for_key( + chancy: Chancy, concurrency_key: str +) -> int: + """Count running jobs for a specific concurrency key.""" + async with chancy.pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + f""" + SELECT COUNT(*) FROM {chancy.prefix}jobs + WHERE concurrency_key = %s AND state = 'running' + """, + (concurrency_key,), + ) + result = await cursor.fetchone() + return result[0] if result else 0 + + +async def _sample_running_counts( + chancy: Chancy, + concurrency_key: str, + samples: int = 20, + interval: float = 0.25, +) -> list[int]: + """Sample running job counts over time.""" + counts = [] + for _ in range(samples): + count = await _count_running_jobs_for_key(chancy, concurrency_key) + counts.append(count) + await asyncio.sleep(interval) + return counts + + +async def _push_many_collect(chancy: Chancy, jobs: list) -> list: + """Push many jobs and collect all references from the async generator.""" + refs = [] + async for batch_refs in chancy.push_many(jobs): + 
refs.extend(batch_refs) + return refs + + +class TestConcurrencyKeyEvaluation: + """Test concurrency key evaluation logic""" + + def test_no_concurrency_key(self): + """Test job without concurrency constraints""" + job = Job.from_func(simple_job) + result = job.evaluate_concurrency_key() + assert result is None + + def test_max_concurrent_only(self): + """Test job with only max_concurrent (no key specified)""" + job = Job.from_func(simple_job).with_concurrency(ConcurrencyRule(max=3)) + result = job.evaluate_concurrency_key() + assert result == "test_concurrency.simple_job" + + def test_simple_field_key(self): + """Test simple field-based concurrency key""" + job = ( + Job.from_func(simple_job) + .with_concurrency(ConcurrencyRule(max=1, key="user_id")) + .with_kwargs(user_id="123", action="upload") + ) + result = job.evaluate_concurrency_key() + assert result == "test_concurrency.simple_job:123" + + def test_callable_key(self): + """Test callable concurrency key""" + + def key_func(user_id: str, action: str, **kw) -> str: + return f"{user_id}:{action}" + + job = ( + Job.from_func(simple_job) + .with_concurrency(ConcurrencyRule(max=1, key=key_func)) + .with_kwargs(user_id="123", action="upload") + ) + result = job.evaluate_concurrency_key() + assert result == "test_concurrency.simple_job:123:upload" + + def test_missing_field_raises_error(self): + """Test that missing field raises an error""" + job = ( + Job.from_func(simple_job) + .with_concurrency(ConcurrencyRule(max=1, key="missing_field")) + .with_kwargs(user_id="123") + ) + with pytest.raises( + ValueError, match="Failed to evaluate concurrency key" + ): + job.evaluate_concurrency_key() + + def test_callable_exception_raises_error(self): + """Test that callable exceptions are properly raised""" + + def failing_key(**kwargs): + raise ValueError("Test error") + + job = ( + Job.from_func(simple_job) + .with_concurrency(ConcurrencyRule(max=1, key=failing_key)) + .with_kwargs(user_id="123") + ) + with pytest.raises( 
+ ValueError, match="Failed to evaluate concurrency key" + ): + job.evaluate_concurrency_key() + + def test_none_values_from_callable(self): + """Test that None values from callables raise errors""" + + def none_key(**kwargs): + return None + + job = ( + Job.from_func(simple_job) + .with_concurrency(ConcurrencyRule(max=1, key=none_key)) + .with_kwargs(user_id="123") + ) + with pytest.raises( + ValueError, match="Failed to evaluate concurrency key" + ): + job.evaluate_concurrency_key() + + +class TestJobWithConcurrency: + """Test Job class concurrency methods""" + + def test_with_concurrency_method(self): + """Test the with_concurrency fluent method""" + # Test simple string key + job_with_concurrency = simple_job.job.with_concurrency( + ConcurrencyRule(max=3, key="user_id") + ) + assert job_with_concurrency.concurrency_rule.key == "user_id" + assert job_with_concurrency.concurrency_rule.max == 3 + + # Original job should be unchanged (immutable) + assert simple_job.job.concurrency_rule is None + + def test_with_concurrency_callable_key(self): + """Test with_concurrency with callable key""" + + def key_func(user_id: str, action: str, **kw) -> str: + return f"{user_id}:{action}" + + job_with_concurrency = simple_job.job.with_concurrency( + ConcurrencyRule(max=5, key=key_func) + ) + assert job_with_concurrency.concurrency_rule.key == key_func + assert job_with_concurrency.concurrency_rule.max == 5 + + +@pytest.mark.asyncio +class TestConcurrencyIntegration: + """Integration tests for concurrency constraints""" + + async def test_concurrency_config_storage(self, chancy: Chancy): + """Test that concurrency configurations are stored in the database""" + + # Push a job with concurrency constraints + job_with_concurrency = slow_job.job.with_concurrency( + ConcurrencyRule(max=3, key="user_id") + ).with_kwargs(user_id="user_123", action="test") + await chancy.push(job_with_concurrency) + + # Check that concurrency config was stored + async with chancy.pool.connection() as 
conn: + async with conn.cursor() as cursor: + await cursor.execute( + f"SELECT * FROM {chancy.prefix}concurrency_configs WHERE concurrency_key = %s", + ("test_concurrency.slow_job:user_123",), + ) + result = await cursor.fetchone() + + assert result is not None + assert ( + result[0] == "test_concurrency.slow_job:user_123" + ) # concurrency_key (prefixed) + assert result[1] == 3 # concurrency_max + + async def test_basic_concurrency_limiting( + self, chancy: Chancy, worker: Worker + ): + """Test basic concurrency limiting verifies limit is enforced""" + concurrency_key = "test_concurrency.slow_job:user_123" + + # Create a job with concurrency limit of 2 per user + job_with_concurrency = slow_job.job.with_concurrency( + ConcurrencyRule(max=2, key="user_id") + ) + + # Push 5 jobs for the same user - should only run 2 at a time + jobs = [ + job_with_concurrency.with_kwargs( + user_id="user_123", action=f"action_{i}", duration=0.5 + ) + for i in range(5) + ] + refs = await _push_many_collect(chancy, jobs) + + # Sample running counts while jobs execute + await asyncio.sleep(0.2) + running_counts = await _sample_running_counts( + chancy, concurrency_key, samples=15, interval=0.1 + ) + + # Wait for all jobs to complete + completed_jobs = await chancy.wait_for_jobs(refs, timeout=30) + + # All jobs should be completed + for job_result in completed_jobs: + assert job_result.state == QueuedJob.State.SUCCEEDED + + # Verify concurrency limit was respected + max_observed = max(running_counts) if running_counts else 0 + assert max_observed <= 2, ( + f"Concurrency limit violated: observed {max_observed} " + f"concurrent jobs, expected at most 2. Samples: {running_counts}" + ) + # Verify we actually observed some concurrency + assert max_observed >= 1, ( + f"Expected to observe at least 1 running job. 
Samples: {running_counts}" + ) + + async def test_concurrency_limit_enforced_across_workers( + self, chancy: Chancy + ): + """ + Test that concurrency limits are enforced across multiple workers. + + This test: + 1. Starts multiple workers + 2. Pushes many jobs with a concurrency limit of 2 for the same key + 3. Samples running job count and verifies it never exceeds the limit + """ + concurrency_key = "test_concurrency.slow_job:shared_user" + + # Create job with concurrency limit of 2 per user + job_template = slow_job.job.with_concurrency( + ConcurrencyRule(max=2, key="user_id") + ) + + # Push 8 jobs for the same user - should only run 2 at a time + jobs = [ + job_template.with_kwargs(user_id="shared_user", duration=0.5) + for _ in range(8) + ] + refs = await _push_many_collect(chancy, jobs) + + # Start 3 workers to increase parallelism pressure + workers = [Worker(chancy, shutdown_timeout=30) for _ in range(3)] + for w in workers: + await w.start() + + try: + # Sample running counts while jobs execute + # Start sampling after a brief delay to let jobs start + await asyncio.sleep(0.3) + running_counts = await _sample_running_counts( + chancy, concurrency_key, samples=30, interval=0.1 + ) + + # Wait for all jobs to complete + completed_jobs = await chancy.wait_for_jobs(refs, timeout=60) + + # Verify all jobs completed successfully + for job_result in completed_jobs: + assert job_result.state == QueuedJob.State.SUCCEEDED + + # Verify concurrency limit was respected during sampling + max_observed = max(running_counts) if running_counts else 0 + assert max_observed <= 2, ( + f"Concurrency limit violated: observed {max_observed} " + f"concurrent jobs, expected at most 2. Samples: {running_counts}" + ) + # Verify we actually observed some concurrency + assert max_observed >= 1, ( + f"Expected to observe at least 1 running job. 
Samples: {running_counts}" + ) + + finally: + # Stop all workers + for w in workers: + await w.stop() + + async def test_jobs_without_concurrency_not_blocked_by_limited_jobs( + self, chancy: Chancy, worker: Worker + ): + """ + Test that jobs without concurrency constraints are not blocked + by jobs that have concurrency limits. + """ + concurrency_key = "test_concurrency.slow_job:limited_user" + + # Push multiple jobs with strict concurrency limit (max 1) + limited_jobs = [ + slow_job.job.with_concurrency( + ConcurrencyRule(max=1, key="user_id") + ).with_kwargs(user_id="limited_user", duration=0.6) + for _ in range(3) + ] + + # Push regular jobs without concurrency constraints + regular_jobs = [simple_job.job for _ in range(3)] + + # Push all jobs - limited jobs first, then regular jobs + refs = await _push_many_collect(chancy, limited_jobs + regular_jobs) + + # Sample running counts for the limited key + await asyncio.sleep(0.2) + running_counts = await _sample_running_counts( + chancy, concurrency_key, samples=15, interval=0.1 + ) + + # Wait for all jobs and verify they succeeded + completed_jobs = await chancy.wait_for_jobs(refs, timeout=30) + for job_result in completed_jobs: + assert job_result.state == QueuedJob.State.SUCCEEDED + + # Verify the limited jobs never exceeded their concurrency limit + max_observed = max(running_counts) if running_counts else 0 + assert max_observed <= 1, ( + f"Concurrency limit violated: observed {max_observed} " + f"concurrent jobs, expected at most 1. Samples: {running_counts}" + )