Skip to content

Commit 57d128a

Browse files
committed
Add tags filtering
1 parent a95da66 commit 57d128a

8 files changed

Lines changed: 88 additions & 11 deletions

File tree

docs/reference/task_cli.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,18 @@ if __name__ == "__main__":
2727

2828
### `ls` — list registered tasks
2929

30-
Prints a table of all registered tasks with their schedule, priority, timeout, and CPU-bound flag.
30+
Prints a table of all registered tasks with their schedule, priority, timeout, CPU-bound flag, and tags.
3131

3232
```bash
3333
python -m myapp ls
3434
```
3535

36+
Pass `--tags`/`-t` (repeatable) to only list tasks that have at least one of the given tags:
37+
38+
```bash
39+
python -m myapp ls --tags fast --tags slow
40+
```
41+
3642
### `serve` — start the HTTP server
3743

3844
```bash

docs/tutorials/task_app.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ $ python -m examples.simple_fastapi
5353

5454
and check the openapi UI at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs).
5555

56+
The `GET /tasks` endpoint lists registered tasks and accepts a repeatable `tags`
57+
query parameter to only return tasks that have at least one of the given tags:
58+
59+
```
60+
GET /tasks
61+
GET /tasks?tags=fast&tags=slow
62+
```
63+
5664

5765
## Task App Command Line
5866

examples/tasks/__init__.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,23 @@ class Sleep(BaseModel):
5858
abort: bool = False
5959

6060

61-
@task(max_concurrency=1, timeout_seconds=2)
61+
@task(
62+
max_concurrency=1,
63+
timeout_seconds=2,
64+
tags=("test", "fast"),
65+
)
6266
async def fast(context: TaskRun[Sleep]) -> None:
6367
"""A task that sleeps for a while but has a 2 seconds timeout"""
6468
await asyncio.sleep(context.params.sleep)
6569
if context.params.error:
6670
raise RuntimeError("just an error")
6771

6872

69-
@task(max_concurrency=1, timeout_seconds=120)
73+
@task(
74+
max_concurrency=1,
75+
timeout_seconds=120,
76+
tags=("test", "slow"),
77+
)
7078
async def dummy(context: TaskRun[Sleep]) -> None:
7179
"""A task that sleeps for a while or errors"""
7280
await asyncio.sleep(context.params.sleep)
@@ -76,7 +84,10 @@ async def dummy(context: TaskRun[Sleep]) -> None:
7684
raise RuntimeError("just an error")
7785

7886

79-
@task(schedule=every(timedelta(seconds=2)), tags=["skip_db"])
87+
@task(
88+
schedule=every(timedelta(seconds=2)),
89+
tags=("skip_db",),
90+
)
8091
async def ping(context: TaskRun) -> None:
8192
"""A simple scheduled task that ping the broker"""
8293
redis_cli = cast(RedisTaskBroker, context.task_manager.broker).redis_cli

fluid/scheduler/cli.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,17 @@ def execute_task(log: bool, run_id: str, params: str, **extra: Any) -> None:
129129

130130

131131
@click.command()
132+
@click.option(
133+
"--tags",
134+
"-t",
135+
multiple=True,
136+
help="Only list tasks that have at least one of these tags",
137+
)
132138
@click.pass_context
133-
def ls(ctx: click.Context) -> None:
139+
def ls(ctx: click.Context, tags: tuple[str, ...]) -> None:
134140
"""List all tasks with their schedules"""
135141
task_manager = ctx_task_manager(ctx)
136-
table = asyncio.run(tasks_table(task_manager))
142+
table = asyncio.run(tasks_table(task_manager, tags=set(tags)))
137143
console = Console()
138144
console.print(table)
139145

@@ -219,7 +225,7 @@ async def enable_task(task_manager: TaskManager, task: str, enable: bool) -> Non
219225
raise click.ClickException(f"Task {task} not found") from e
220226

221227

222-
async def tasks_table(task_manager: TaskManager) -> Table:
228+
async def tasks_table(task_manager: TaskManager, tags: set[str] | None = None) -> Table:
223229
task_info = await task_manager.broker.get_tasks_info()
224230
dynamic = {t.name: t for t in task_info}
225231
table = Table(title="Tasks")
@@ -229,16 +235,20 @@ async def tasks_table(task_manager: TaskManager) -> Table:
229235
table.add_column("CPU bound", style="magenta")
230236
table.add_column("Timeout secs", style="green")
231237
table.add_column("Priority", style="magenta")
238+
table.add_column("Tags", style="blue")
232239
table.add_column("Description", style="green")
233240
for name in sorted(task_manager.registry):
234241
task = task_manager.registry[name]
242+
if tags and not tags.intersection(task.tags):
243+
continue
235244
table.add_row(
236245
name,
237246
":white_check_mark:" if dynamic[name].enabled else "[red]:x:",
238-
str(task.schedule),
247+
str(task.schedule) if task.schedule is not None else "[red]:x:",
239248
":white_check_mark:" if task.cpu_bound else "[red]:x:",
240249
str(task.timeout_seconds),
241250
str(task.priority),
251+
", ".join(sorted(task.tags)),
242252
task.short_description,
243253
)
244254
return table

fluid/scheduler/endpoints.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
22
from typing import Any, Callable, Sequence, cast
33

4-
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Request
4+
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Query, Request
55
from pydantic import BaseModel, Field
66
from typing_extensions import Annotated, Doc
77

@@ -132,8 +132,18 @@ async def queue_task(
132132
return queue_task
133133

134134

135-
async def get_tasks(task_manager: TaskManagerDep) -> list[TaskInfo]:
136-
return await task_manager.broker.get_tasks_info()
135+
async def get_tasks(
136+
task_manager: TaskManagerDep,
137+
tags: Annotated[
138+
list[str] | None,
139+
Query(description="Only return tasks that have at least one of these tags"),
140+
] = None,
141+
) -> list[TaskInfo]:
142+
tasks = await task_manager.broker.get_tasks_info()
143+
if tags:
144+
wanted = set(tags)
145+
tasks = [task for task in tasks if wanted.intersection(task.tags)]
146+
return tasks
137147

138148

139149
async def get_task(

fluid/scheduler/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class TaskInfoBase(BaseModel):
185185
module: str = Field(description="Task module")
186186
priority: TaskPriority = Field(description="Task priority")
187187
schedule: str | None = Field(default=None, description="Task schedule")
188+
tags: frozenset[str] = Field(default_factory=frozenset, description="Task tags")
188189

189190

190191
class TaskInfoUpdate(BaseModel):
@@ -327,6 +328,7 @@ def info(self, **params: Any) -> TaskInfo:
327328
module=self.module,
328329
priority=self.priority,
329330
schedule=str(self.schedule) if self.schedule else None,
331+
tags=self.tags,
330332
)
331333
return TaskInfo(**compact_dict(params))
332334

tests/scheduler/test_cli.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@ def test_cli_ls():
3434
result = runner.invoke(task_manager_cli, ["ls"])
3535
assert result.exit_code == 0
3636
assert result.output
37+
assert "ping" in result.output
38+
assert "dummy" in result.output
39+
40+
41+
def test_cli_ls_tags():
42+
runner = CliRunner()
43+
result = runner.invoke(task_manager_cli, ["ls", "--tags", "skip_db"])
44+
assert result.exit_code == 0
45+
assert "ping" in result.output
46+
assert "dummy" not in result.output
47+
result = runner.invoke(task_manager_cli, ["ls", "-t", "slow", "-t", "fast"])
48+
assert result.exit_code == 0
49+
assert "dummy" in result.output
50+
assert "fast" in result.output
51+
assert "ping" not in result.output
3752

3853

3954
def test_cli_exec_empty():

tests/scheduler/test_endpoints.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,21 @@ async def test_get_tasks(cli: TaskClient) -> None:
1414
tasks = {task["name"]: TaskInfo(**task) for task in data}
1515
dummy = tasks["dummy"]
1616
assert dummy.name == "dummy"
17+
assert set(dummy.tags) == {"test", "slow"}
18+
19+
20+
async def test_get_tasks_by_tags(cli: TaskClient) -> None:
21+
data = await cli.get(f"{cli.url}/tasks?tags=test")
22+
names = {task["name"] for task in data}
23+
assert names == {"fast", "dummy"}
24+
data = await cli.get(f"{cli.url}/tasks?tags=skip_db")
25+
names = {task["name"] for task in data}
26+
assert names == {"ping"}
27+
data = await cli.get(f"{cli.url}/tasks?tags=skip_db&tags=fast")
28+
names = {task["name"] for task in data}
29+
assert names == {"ping", "fast"}
30+
data = await cli.get(f"{cli.url}/tasks?tags=does-not-exist")
31+
assert data == []
1732

1833

1934
async def test_get_tasks_status(cli: TaskClient) -> None:

0 commit comments

Comments
 (0)