Skip to content

Commit 6bc0e5d

Browse files
committed
Simplify lock selection workflow
1 parent 468fdc6 commit 6bc0e5d

2 files changed

Lines changed: 225 additions & 124 deletions

File tree

src/_pytask/lock.py

Lines changed: 22 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from _pytask.click import ColoredGroup
1616
from _pytask.console import console
1717
from _pytask.dag import create_dag
18-
from _pytask.dag_utils import task_and_descending_tasks
1918
from _pytask.dag_utils import task_and_preceding_tasks
2019
from _pytask.exceptions import CollectionError
2120
from _pytask.exceptions import ConfigurationError
@@ -25,6 +24,10 @@
2524
from _pytask.lockfile import _build_task_entry
2625
from _pytask.lockfile import _TaskEntry
2726
from _pytask.lockfile import build_portable_task_id
27+
from _pytask.mark import Expression
28+
from _pytask.mark import KeywordMatcher
29+
from _pytask.mark import MarkMatcher
30+
from _pytask.mark import ParseError
2831
from _pytask.node_protocols import PNode
2932
from _pytask.node_protocols import PProvisionalNode
3033
from _pytask.node_protocols import PTask
@@ -37,7 +40,6 @@
3740
if TYPE_CHECKING:
3841
from collections.abc import Callable
3942

40-
from _pytask.dag_graph import DAG
4143
from _pytask.lockfile import LockfileState
4244

4345

@@ -55,10 +57,6 @@ def _validate_confirmation_options(raw_config: dict[str, Any]) -> None:
5557

5658

5759
def _keyword_filter(tasks: list[PTask], expression: str) -> set[str]:
58-
from _pytask.mark import Expression # noqa: PLC0415
59-
from _pytask.mark import KeywordMatcher # noqa: PLC0415
60-
from _pytask.mark import ParseError # noqa: PLC0415
61-
6260
try:
6361
compiled = Expression.compile_(expression)
6462
except ParseError as e:
@@ -73,10 +71,6 @@ def _keyword_filter(tasks: list[PTask], expression: str) -> set[str]:
7371

7472

7573
def _marker_filter(tasks: list[PTask], expression: str) -> set[str]:
76-
from _pytask.mark import Expression # noqa: PLC0415
77-
from _pytask.mark import MarkMatcher # noqa: PLC0415
78-
from _pytask.mark import ParseError # noqa: PLC0415
79-
8074
try:
8175
compiled = Expression.compile_(expression)
8276
except ParseError as e:
@@ -90,32 +84,7 @@ def _marker_filter(tasks: list[PTask], expression: str) -> set[str]:
9084
}
9185

9286

93-
def _expand_task_selection(
94-
task_signatures: set[str],
95-
dag: DAG,
96-
*,
97-
with_ancestors: bool,
98-
with_descendants: bool,
99-
) -> set[str]:
100-
selected = set(task_signatures)
101-
if with_ancestors:
102-
selected |= set(
103-
chain.from_iterable(
104-
task_and_preceding_tasks(signature, dag)
105-
for signature in task_signatures
106-
)
107-
)
108-
if with_descendants:
109-
selected |= set(
110-
chain.from_iterable(
111-
task_and_descending_tasks(signature, dag)
112-
for signature in task_signatures
113-
)
114-
)
115-
return selected
116-
117-
118-
def _select_tasks(session: Session) -> list[PTask]:
87+
def _select_tasks_exact(session: Session) -> list[PTask]:
11988
selected = {task.signature for task in session.tasks}
12089

12190
expression = session.config.get("expression")
@@ -126,11 +95,15 @@ def _select_tasks(session: Session) -> list[PTask]:
12695
if marker_expression:
12796
selected &= _marker_filter(session.tasks, marker_expression)
12897

129-
selected = _expand_task_selection(
130-
selected,
131-
session.dag,
132-
with_ancestors=session.config.get("with_ancestors", False),
133-
with_descendants=session.config.get("with_descendants", False),
98+
return [task for task in session.tasks if task.signature in selected]
99+
100+
101+
def _select_tasks_with_ancestors(session: Session) -> list[PTask]:
102+
selected = {task.signature for task in _select_tasks_exact(session)}
103+
selected |= set(
104+
chain.from_iterable(
105+
task_and_preceding_tasks(signature, session.dag) for signature in selected
106+
)
134107
)
135108
return [task for task in session.tasks if task.signature in selected]
136109

@@ -286,7 +259,11 @@ def _run_lock_command(
286259
session.dag = create_dag(session=session)
287260

288261
if planner_with_tasks is not None:
289-
tasks = _select_tasks(session)
262+
tasks = (
263+
_select_tasks_with_ancestors(session)
264+
if raw_config["subcommand"] == "accept"
265+
else _select_tasks_exact(session)
266+
)
290267
planned_changes = planner_with_tasks(session, tasks)
291268
else:
292269
assert planner is not None
@@ -328,18 +305,6 @@ def lock() -> None:
328305

329306

330307
@lock.command(cls=ColoredCommand)
331-
@click.option(
332-
"--with-ancestors",
333-
is_flag=True,
334-
default=False,
335-
help="Also include preceding tasks of the selected tasks.",
336-
)
337-
@click.option(
338-
"--with-descendants",
339-
is_flag=True,
340-
default=False,
341-
help="Also include descending tasks of the selected tasks.",
342-
)
343308
@click.option(
344309
"--dry-run",
345310
is_flag=True,
@@ -354,7 +319,8 @@ def lock() -> None:
354319
help="Apply the changes without prompting for confirmation.",
355320
)
356321
def accept(**raw_config: Any) -> None:
357-
"""Accept the current state for selected tasks without executing them."""
322+
"""Accept the current state for selected tasks and their ancestors."""
323+
raw_config["subcommand"] = "accept"
358324
sys.exit(
359325
_run_lock_command(
360326
raw_config,
@@ -365,18 +331,6 @@ def accept(**raw_config: Any) -> None:
365331

366332

367333
@lock.command(cls=ColoredCommand)
368-
@click.option(
369-
"--with-ancestors",
370-
is_flag=True,
371-
default=False,
372-
help="Also include preceding tasks of the selected tasks.",
373-
)
374-
@click.option(
375-
"--with-descendants",
376-
is_flag=True,
377-
default=False,
378-
help="Also include descending tasks of the selected tasks.",
379-
)
380334
@click.option(
381335
"--dry-run",
382336
is_flag=True,
@@ -392,6 +346,7 @@ def accept(**raw_config: Any) -> None:
392346
)
393347
def reset(**raw_config: Any) -> None:
394348
"""Remove recorded state for selected tasks."""
349+
raw_config["subcommand"] = "reset"
395350
sys.exit(
396351
_run_lock_command(
397352
raw_config,

0 commit comments

Comments
 (0)