diff --git a/docs/reference/task_cli.md b/docs/reference/task_cli.md index 3393c94..034e6f5 100644 --- a/docs/reference/task_cli.md +++ b/docs/reference/task_cli.md @@ -27,12 +27,18 @@ if __name__ == "__main__": ### `ls` — list registered tasks -Prints a table of all registered tasks with their schedule, priority, timeout, and CPU-bound flag. +Prints a table of all registered tasks with their schedule, priority, timeout, CPU-bound flag, and tags. ```bash python -m myapp ls ``` +Pass `--tags`/`-t` (repeatable) to only list tasks that have at least one of the given tags: + +```bash +python -m myapp ls --tags fast --tags slow +``` + ### `serve` — start the HTTP server ```bash diff --git a/docs/tutorials/task_app.md b/docs/tutorials/task_app.md index 88a9176..8f90d2c 100644 --- a/docs/tutorials/task_app.md +++ b/docs/tutorials/task_app.md @@ -53,6 +53,14 @@ $ python -m examples.simple_fastapi and check the openapi UI at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs). +The `GET /tasks` endpoint lists registered tasks and accepts a repeatable `tags` +query parameter to only return tasks that have at least one of the given tags: + +``` +GET /tasks +GET /tasks?tags=fast&tags=slow +``` + ## Task App Command Line diff --git a/examples/tasks/__init__.py b/examples/tasks/__init__.py index 507c331..245233d 100644 --- a/examples/tasks/__init__.py +++ b/examples/tasks/__init__.py @@ -58,7 +58,11 @@ class Sleep(BaseModel): abort: bool = False -@task(max_concurrency=1, timeout_seconds=2) +@task( + max_concurrency=1, + timeout_seconds=2, + tags=("test", "fast"), +) async def fast(context: TaskRun[Sleep]) -> None: """A task that sleeps for a while but has a 2 seconds timeout""" await asyncio.sleep(context.params.sleep) @@ -66,7 +70,11 @@ async def fast(context: TaskRun[Sleep]) -> None: raise RuntimeError("just an error") -@task(max_concurrency=1, timeout_seconds=120) +@task( + max_concurrency=1, + timeout_seconds=120, + tags=("test", "slow"), +) async def dummy(context: TaskRun[Sleep]) -> None: """A task that sleeps for a while or errors""" await asyncio.sleep(context.params.sleep) @@ -76,7 +84,10 @@ async def dummy(context: TaskRun[Sleep]) -> None: raise RuntimeError("just an error") -@task(schedule=every(timedelta(seconds=2)), tags=["skip_db"]) +@task( + schedule=every(timedelta(seconds=2)), + tags=("skip_db",), +) async def ping(context: TaskRun) -> None: """A simple scheduled task that ping the broker""" redis_cli = cast(RedisTaskBroker, context.task_manager.broker).redis_cli diff --git a/fluid/scheduler/cli.py b/fluid/scheduler/cli.py index f2e754d..94b53a4 100644 --- a/fluid/scheduler/cli.py +++ b/fluid/scheduler/cli.py @@ -129,11 +129,17 @@ def execute_task(log: bool, run_id: str, params: str, **extra: Any) -> None: @click.command() +@click.option( + "--tags", + "-t", + multiple=True, + help="Only list tasks that have at least one of these tags", +) @click.pass_context -def ls(ctx: click.Context) -> None: +def ls(ctx: click.Context, tags: tuple[str, ...]) -> None: """List all tasks with their schedules""" task_manager = ctx_task_manager(ctx) - table = asyncio.run(tasks_table(task_manager)) + table = asyncio.run(tasks_table(task_manager, tags=set(tags))) console = Console() console.print(table) @@ -219,7 +225,7 @@ async def enable_task(task_manager: TaskManager, task: str, enable: bool) -> Non raise click.ClickException(f"Task {task} not found") from e -async def tasks_table(task_manager: TaskManager) -> Table: +async def tasks_table(task_manager: TaskManager, tags: set[str] | None = None) -> Table: task_info = await task_manager.broker.get_tasks_info() dynamic = {t.name: t for t in task_info} table = Table(title="Tasks") @@ -229,16 +235,20 @@ async def tasks_table(task_manager: TaskManager) -> Table: table.add_column("CPU bound", style="magenta") table.add_column("Timeout secs", style="green") table.add_column("Priority", style="magenta") + table.add_column("Tags", style="blue") table.add_column("Description", style="green") for name in sorted(task_manager.registry): task = task_manager.registry[name] + if tags and not tags.intersection(task.tags): + continue table.add_row( name, ":white_check_mark:" if dynamic[name].enabled else "[red]:x:", - str(task.schedule), + str(task.schedule) if task.schedule is not None else "[red]:x:", ":white_check_mark:" if task.cpu_bound else "[red]:x:", str(task.timeout_seconds), str(task.priority), + ", ".join(sorted(task.tags)), task.short_description, ) return table diff --git a/fluid/scheduler/endpoints.py b/fluid/scheduler/endpoints.py index f06125d..590d7c8 100644 --- a/fluid/scheduler/endpoints.py +++ b/fluid/scheduler/endpoints.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Callable, Sequence, cast -from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Request +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Query, Request from pydantic import BaseModel, Field from typing_extensions import Annotated, Doc @@ -132,8 +132,18 @@ async def queue_task( return queue_task -async def get_tasks(task_manager: TaskManagerDep) -> list[TaskInfo]: - return await task_manager.broker.get_tasks_info() +async def get_tasks( + task_manager: TaskManagerDep, + tags: Annotated[ + list[str] | None, + Query(description="Only return tasks that have at least one of these tags"), + ] = None, +) -> list[TaskInfo]: + tasks = await task_manager.broker.get_tasks_info() + if tags: + wanted = set(tags) + tasks = [task for task in tasks if wanted.intersection(task.tags)] + return tasks async def get_task( diff --git a/fluid/scheduler/models.py b/fluid/scheduler/models.py index e65d292..ab07c09 100644 --- a/fluid/scheduler/models.py +++ b/fluid/scheduler/models.py @@ -185,6 +185,7 @@ class TaskInfoBase(BaseModel): module: str = Field(description="Task module") priority: TaskPriority = Field(description="Task priority") schedule: str | None = Field(default=None, description="Task schedule") + tags: frozenset[str] = Field(default_factory=frozenset, description="Task tags") class TaskInfoUpdate(BaseModel): @@ -327,6 +328,7 @@ def info(self, **params: Any) -> TaskInfo: module=self.module, priority=self.priority, schedule=str(self.schedule) if self.schedule else None, + tags=self.tags, ) return TaskInfo(**compact_dict(params)) diff --git a/tests/scheduler/test_cli.py b/tests/scheduler/test_cli.py index e9d6982..5eb3949 100644 --- a/tests/scheduler/test_cli.py +++ b/tests/scheduler/test_cli.py @@ -34,6 +34,21 @@ def test_cli_ls(): result = runner.invoke(task_manager_cli, ["ls"]) assert result.exit_code == 0 assert result.output + assert "ping" in result.output + assert "dummy" in result.output + + +def test_cli_ls_tags(): + runner = CliRunner() + result = runner.invoke(task_manager_cli, ["ls", "--tags", "skip_db"]) + assert result.exit_code == 0 + assert "ping" in result.output + assert "dummy" not in result.output + result = runner.invoke(task_manager_cli, ["ls", "-t", "slow", "-t", "fast"]) + assert result.exit_code == 0 + assert "dummy" in result.output + assert "fast" in result.output + assert "ping" not in result.output def test_cli_exec_empty(): diff --git a/tests/scheduler/test_endpoints.py b/tests/scheduler/test_endpoints.py index 4d0e50e..de43a08 100644 --- a/tests/scheduler/test_endpoints.py +++ b/tests/scheduler/test_endpoints.py @@ -14,6 +14,21 @@ async def test_get_tasks(cli: TaskClient) -> None: tasks = {task["name"]: TaskInfo(**task) for task in data} dummy = tasks["dummy"] assert dummy.name == "dummy" + assert set(dummy.tags) == {"test", "slow"} + + +async def test_get_tasks_by_tags(cli: TaskClient) -> None: + data = await cli.get(f"{cli.url}/tasks?tags=test") + names = {task["name"] for task in data} + assert names == {"fast", "dummy"} + data = await cli.get(f"{cli.url}/tasks?tags=skip_db") + names = {task["name"] for task in data} + assert names == {"ping"} + data = await cli.get(f"{cli.url}/tasks?tags=skip_db&tags=fast") + names = {task["name"] for task in data} + assert names == {"ping", "fast"} + data = await cli.get(f"{cli.url}/tasks?tags=does-not-exist") + assert data == [] async def test_get_tasks_status(cli: TaskClient) -> None: