Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/reference/task_cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,18 @@ if __name__ == "__main__":

### `ls` — list registered tasks

Prints a table of all registered tasks with their schedule, priority, timeout, and CPU-bound flag.
Prints a table of all registered tasks with their schedule, priority, timeout, CPU-bound flag, and tags.

```bash
python -m myapp ls
```

Pass `--tags`/`-t` (repeatable) to only list tasks that have at least one of the given tags:

```bash
python -m myapp ls --tags fast --tags slow
```

### `serve` — start the HTTP server

```bash
Expand Down
8 changes: 8 additions & 0 deletions docs/tutorials/task_app.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ $ python -m examples.simple_fastapi

and check the openapi UI at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs).

The `GET /tasks` endpoint lists registered tasks and accepts a repeatable `tags`
query parameter to only return tasks that have at least one of the given tags:

```
GET /tasks
GET /tasks?tags=fast&tags=slow
```


## Task App Command Line

Expand Down
17 changes: 14 additions & 3 deletions examples/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,23 @@ class Sleep(BaseModel):
abort: bool = False


@task(max_concurrency=1, timeout_seconds=2)
@task(
max_concurrency=1,
timeout_seconds=2,
tags=("test", "fast"),
)
async def fast(context: TaskRun[Sleep]) -> None:
"""A task that sleeps for a while but has a 2 seconds timeout"""
await asyncio.sleep(context.params.sleep)
if context.params.error:
raise RuntimeError("just an error")


@task(max_concurrency=1, timeout_seconds=120)
@task(
max_concurrency=1,
timeout_seconds=120,
tags=("test", "slow"),
)
async def dummy(context: TaskRun[Sleep]) -> None:
"""A task that sleeps for a while or errors"""
await asyncio.sleep(context.params.sleep)
Expand All @@ -76,7 +84,10 @@ async def dummy(context: TaskRun[Sleep]) -> None:
raise RuntimeError("just an error")


@task(schedule=every(timedelta(seconds=2)), tags=["skip_db"])
@task(
schedule=every(timedelta(seconds=2)),
tags=("skip_db",),
)
async def ping(context: TaskRun) -> None:
"""A simple scheduled task that ping the broker"""
redis_cli = cast(RedisTaskBroker, context.task_manager.broker).redis_cli
Expand Down
18 changes: 14 additions & 4 deletions fluid/scheduler/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,17 @@ def execute_task(log: bool, run_id: str, params: str, **extra: Any) -> None:


@click.command()
@click.option(
"--tags",
"-t",
multiple=True,
help="Only list tasks that have at least one of these tags",
)
@click.pass_context
def ls(ctx: click.Context) -> None:
def ls(ctx: click.Context, tags: tuple[str, ...]) -> None:
"""List all tasks with their schedules"""
task_manager = ctx_task_manager(ctx)
table = asyncio.run(tasks_table(task_manager))
table = asyncio.run(tasks_table(task_manager, tags=set(tags)))
console = Console()
console.print(table)

Expand Down Expand Up @@ -219,7 +225,7 @@ async def enable_task(task_manager: TaskManager, task: str, enable: bool) -> Non
raise click.ClickException(f"Task {task} not found") from e


async def tasks_table(task_manager: TaskManager) -> Table:
async def tasks_table(task_manager: TaskManager, tags: set[str] | None = None) -> Table:
task_info = await task_manager.broker.get_tasks_info()
dynamic = {t.name: t for t in task_info}
table = Table(title="Tasks")
Expand All @@ -229,16 +235,20 @@ async def tasks_table(task_manager: TaskManager) -> Table:
table.add_column("CPU bound", style="magenta")
table.add_column("Timeout secs", style="green")
table.add_column("Priority", style="magenta")
table.add_column("Tags", style="blue")
table.add_column("Description", style="green")
for name in sorted(task_manager.registry):
task = task_manager.registry[name]
if tags and not tags.intersection(task.tags):
continue
table.add_row(
name,
":white_check_mark:" if dynamic[name].enabled else "[red]:x:",
str(task.schedule),
str(task.schedule) if task.schedule is not None else "[red]:x:",
":white_check_mark:" if task.cpu_bound else "[red]:x:",
str(task.timeout_seconds),
str(task.priority),
", ".join(sorted(task.tags)),
task.short_description,
)
return table
16 changes: 13 additions & 3 deletions fluid/scheduler/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Callable, Sequence, cast

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Request
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Doc

Expand Down Expand Up @@ -132,8 +132,18 @@ async def queue_task(
return queue_task


async def get_tasks(task_manager: TaskManagerDep) -> list[TaskInfo]:
return await task_manager.broker.get_tasks_info()
async def get_tasks(
task_manager: TaskManagerDep,
tags: Annotated[
list[str] | None,
Query(description="Only return tasks that have at least one of these tags"),
] = None,
) -> list[TaskInfo]:
tasks = await task_manager.broker.get_tasks_info()
if tags:
wanted = set(tags)
tasks = [task for task in tasks if wanted.intersection(task.tags)]
return tasks


async def get_task(
Expand Down
2 changes: 2 additions & 0 deletions fluid/scheduler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class TaskInfoBase(BaseModel):
module: str = Field(description="Task module")
priority: TaskPriority = Field(description="Task priority")
schedule: str | None = Field(default=None, description="Task schedule")
tags: frozenset[str] = Field(default_factory=frozenset, description="Task tags")


class TaskInfoUpdate(BaseModel):
Expand Down Expand Up @@ -327,6 +328,7 @@ def info(self, **params: Any) -> TaskInfo:
module=self.module,
priority=self.priority,
schedule=str(self.schedule) if self.schedule else None,
tags=self.tags,
)
return TaskInfo(**compact_dict(params))

Expand Down
15 changes: 15 additions & 0 deletions tests/scheduler/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ def test_cli_ls():
result = runner.invoke(task_manager_cli, ["ls"])
assert result.exit_code == 0
assert result.output
assert "ping" in result.output
assert "dummy" in result.output


def test_cli_ls_tags():
runner = CliRunner()
result = runner.invoke(task_manager_cli, ["ls", "--tags", "skip_db"])
assert result.exit_code == 0
assert "ping" in result.output
assert "dummy" not in result.output
result = runner.invoke(task_manager_cli, ["ls", "-t", "slow", "-t", "fast"])
assert result.exit_code == 0
assert "dummy" in result.output
assert "fast" in result.output
assert "ping" not in result.output


def test_cli_exec_empty():
Expand Down
15 changes: 15 additions & 0 deletions tests/scheduler/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@ async def test_get_tasks(cli: TaskClient) -> None:
tasks = {task["name"]: TaskInfo(**task) for task in data}
dummy = tasks["dummy"]
assert dummy.name == "dummy"
assert set(dummy.tags) == {"test", "slow"}


async def test_get_tasks_by_tags(cli: TaskClient) -> None:
data = await cli.get(f"{cli.url}/tasks?tags=test")
names = {task["name"] for task in data}
assert names == {"fast", "dummy"}
data = await cli.get(f"{cli.url}/tasks?tags=skip_db")
names = {task["name"] for task in data}
assert names == {"ping"}
data = await cli.get(f"{cli.url}/tasks?tags=skip_db&tags=fast")
names = {task["name"] for task in data}
assert names == {"ping", "fast"}
data = await cli.get(f"{cli.url}/tasks?tags=does-not-exist")
assert data == []


async def test_get_tasks_status(cli: TaskClient) -> None:
Expand Down
Loading