diff --git a/codeframe/core/tasks.py b/codeframe/core/tasks.py index 6d451539..b59aa363 100644 --- a/codeframe/core/tasks.py +++ b/codeframe/core/tasks.py @@ -169,18 +169,20 @@ def get(workspace: Workspace, task_id: str) -> Optional[Task]: Task if found, None otherwise """ conn = get_db_connection(workspace) - cursor = conn.cursor() - - cursor.execute( - """ - SELECT id, workspace_id, prd_id, title, description, status, priority, depends_on, estimated_hours, complexity_score, uncertainty_level, created_at, updated_at, github_issue_number, parent_id, lineage, is_leaf, hierarchical_id, requirement_ids, external_url, auto_close_github_issue - FROM tasks - WHERE workspace_id = ? AND id = ? - """, - (workspace.id, task_id), - ) - row = cursor.fetchone() - conn.close() + try: + cursor = conn.cursor() + + cursor.execute( + """ + SELECT id, workspace_id, prd_id, title, description, status, priority, depends_on, estimated_hours, complexity_score, uncertainty_level, created_at, updated_at, github_issue_number, parent_id, lineage, is_leaf, hierarchical_id, requirement_ids, external_url, auto_close_github_issue + FROM tasks + WHERE workspace_id = ? AND id = ? + """, + (workspace.id, task_id), + ) + row = cursor.fetchone() + finally: + conn.close() if not row: return None @@ -288,33 +290,35 @@ def list_tasks( List of Tasks """ conn = get_db_connection(workspace) - cursor = conn.cursor() + try: + cursor = conn.cursor() - if status: - cursor.execute( - """ - SELECT id, workspace_id, prd_id, title, description, status, priority, depends_on, estimated_hours, complexity_score, uncertainty_level, created_at, updated_at, github_issue_number, parent_id, lineage, is_leaf, hierarchical_id, requirement_ids, external_url, auto_close_github_issue - FROM tasks - WHERE workspace_id = ? AND status = ? - ORDER BY priority ASC, created_at ASC - LIMIT ? - """, - (workspace.id, status.value, limit), - ) - else: - cursor.execute( - """ - SELECT id, workspace_id, prd_id, title, description, status, priority, depends_on, estimated_hours, complexity_score, uncertainty_level, created_at, updated_at, github_issue_number, parent_id, lineage, is_leaf, hierarchical_id, requirement_ids, external_url, auto_close_github_issue - FROM tasks - WHERE workspace_id = ? - ORDER BY priority ASC, created_at ASC - LIMIT ? - """, - (workspace.id, limit), - ) + if status: + cursor.execute( + """ + SELECT id, workspace_id, prd_id, title, description, status, priority, depends_on, estimated_hours, complexity_score, uncertainty_level, created_at, updated_at, github_issue_number, parent_id, lineage, is_leaf, hierarchical_id, requirement_ids, external_url, auto_close_github_issue + FROM tasks + WHERE workspace_id = ? AND status = ? + ORDER BY priority ASC, created_at ASC + LIMIT ? + """, + (workspace.id, status.value, limit), + ) + else: + cursor.execute( + """ + SELECT id, workspace_id, prd_id, title, description, status, priority, depends_on, estimated_hours, complexity_score, uncertainty_level, created_at, updated_at, github_issue_number, parent_id, lineage, is_leaf, hierarchical_id, requirement_ids, external_url, auto_close_github_issue + FROM tasks + WHERE workspace_id = ? + ORDER BY priority ASC, created_at ASC + LIMIT ? + """, + (workspace.id, limit), + ) - rows = cursor.fetchall() - conn.close() + rows = cursor.fetchall() + finally: + conn.close() return [_row_to_task(row) for row in rows] @@ -468,6 +472,14 @@ def _dispatch_github_autoclose(workspace: Workspace, task: Task) -> None: # a short-lived CLI process for long at exit. _AUTOCLOSE_TIMEOUT = 10.0 +# Strong references to in-flight auto-close tasks scheduled on a running event +# loop. asyncio only keeps a *weak* reference to a task created with +# ``loop.create_task``, so without this set a task could be garbage-collected +# mid-flight (leaving the issue open and emitting a "Task was destroyed but it +# is pending" warning). Tasks remove themselves via the done-callback once they +# complete. Mirrors ``WebhookNotificationService.send_event_background``. +_background_tasks: set[asyncio.Task] = set() + async def _safe_close_issue(pat: str, repo: str, issue_number: int) -> None: """Close the issue, swallowing every error (best-effort, off the hot path).""" @@ -512,7 +524,13 @@ def _close_issue_background(pat: str, repo: str, issue_number: int) -> None: name=f"gh-autoclose-{issue_number}", ).start() else: - loop.create_task(_safe_close_issue(pat, repo, issue_number)) + # Retain a strong reference until the task finishes so asyncio cannot + # garbage-collect it mid-flight; the done-callback consumes any + # exception (``_safe_close_issue`` already swallows them) and drops the + # reference. + task = loop.create_task(_safe_close_issue(pat, repo, issue_number)) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) def update( @@ -626,18 +644,20 @@ def update_depends_on( now = _utc_now().isoformat() conn = get_db_connection(workspace) - cursor = conn.cursor() - - cursor.execute( - """ - UPDATE tasks - SET depends_on = ?, updated_at = ? - WHERE workspace_id = ? AND id = ? - """, - (json.dumps(depends_on), now, workspace.id, task_id), - ) - conn.commit() - conn.close() + try: + cursor = conn.cursor() + + cursor.execute( + """ + UPDATE tasks + SET depends_on = ?, updated_at = ? + WHERE workspace_id = ? AND id = ? + """, + (json.dumps(depends_on), now, workspace.id, task_id), + ) + conn.commit() + finally: + conn.close() task.depends_on = depends_on task.updated_at = datetime.fromisoformat(now) @@ -670,17 +690,19 @@ def update_requirement_ids( now = _utc_now().isoformat() conn = get_db_connection(workspace) - cursor = conn.cursor() - cursor.execute( - """ - UPDATE tasks - SET requirement_ids = ?, updated_at = ? - WHERE workspace_id = ? AND id = ? - """, - (json.dumps(requirement_ids), now, workspace.id, task_id), - ) - conn.commit() - conn.close() + try: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE tasks + SET requirement_ids = ?, updated_at = ? + WHERE workspace_id = ? AND id = ? + """, + (json.dumps(requirement_ids), now, workspace.id, task_id), + ) + conn.commit() + finally: + conn.close() task.requirement_ids = requirement_ids task.updated_at = datetime.fromisoformat(now) @@ -717,18 +739,20 @@ def delete(workspace: Workspace, task_id: str) -> bool: Use delete_cascade() if you need to clean up dependencies. """ conn = get_db_connection(workspace) - cursor = conn.cursor() - - cursor.execute( - """ - DELETE FROM tasks - WHERE workspace_id = ? AND id = ? - """, - (workspace.id, task_id), - ) - deleted = cursor.rowcount > 0 - conn.commit() - conn.close() + try: + cursor = conn.cursor() + + cursor.execute( + """ + DELETE FROM tasks + WHERE workspace_id = ? AND id = ? + """, + (workspace.id, task_id), + ) + deleted = cursor.rowcount > 0 + conn.commit() + finally: + conn.close() return deleted @@ -743,18 +767,20 @@ def delete_all(workspace: Workspace) -> int: Number of tasks deleted """ conn = get_db_connection(workspace) - cursor = conn.cursor() - - cursor.execute( - """ - DELETE FROM tasks - WHERE workspace_id = ? - """, - (workspace.id,), - ) - deleted_count = cursor.rowcount - conn.commit() - conn.close() + try: + cursor = conn.cursor() + + cursor.execute( + """ + DELETE FROM tasks + WHERE workspace_id = ? + """, + (workspace.id,), + ) + deleted_count = cursor.rowcount + conn.commit() + finally: + conn.close() return deleted_count @@ -769,19 +795,21 @@ def count_by_status(workspace: Workspace) -> dict[str, int]: Dict mapping status string to count """ conn = get_db_connection(workspace) - cursor = conn.cursor() - - cursor.execute( - """ - SELECT status, COUNT(*) as count - FROM tasks - WHERE workspace_id = ? - GROUP BY status - """, - (workspace.id,), - ) - rows = cursor.fetchall() - conn.close() + try: + cursor = conn.cursor() + + cursor.execute( + """ + SELECT status, COUNT(*) as count + FROM tasks + WHERE workspace_id = ? + GROUP BY status + """, + (workspace.id,), + ) + rows = cursor.fetchall() + finally: + conn.close() return {row[0]: row[1] for row in rows} diff --git a/codeframe/notifications/webhook.py b/codeframe/notifications/webhook.py index 166d9e7d..6aadc1f9 100644 --- a/codeframe/notifications/webhook.py +++ b/codeframe/notifications/webhook.py @@ -345,36 +345,3 @@ def _run_send_event_sync(self, payload: dict, url: Optional[str]) -> None: exc_info=True, ) - def send_blocker_notification_background( - self, - blocker_id: int, - question: str, - agent_id: str, - task_id: int, - blocker_type: BlockerType, - created_at: datetime, - ) -> None: - """Fire-and-forget wrapper for send_blocker_notification. - - Launches notification task in background without awaiting result. - Use this method to avoid blocking blocker creation. - - Args: - blocker_id: Blocker database ID - question: Blocker question text - agent_id: Agent that created the blocker - task_id: Associated task ID - blocker_type: SYNC or ASYNC - created_at: Blocker creation timestamp - """ - # Create background task - asyncio.create_task( - self.send_blocker_notification( - blocker_id=blocker_id, - question=question, - agent_id=agent_id, - task_id=task_id, - blocker_type=blocker_type, - created_at=created_at, - ) - ) diff --git a/tests/core/test_tasks_connection_safety.py b/tests/core/test_tasks_connection_safety.py new file mode 100644 index 00000000..af966773 --- /dev/null +++ b/tests/core/test_tasks_connection_safety.py @@ -0,0 +1,186 @@ +"""Robustness tests for codeframe.core.tasks (issue #650). + +Covers two low-severity robustness fixes: + +1. Connection leak-on-exception: every DB path in ``tasks.py`` must release its + SQLite connection even when an error is raised between opening it and the + normal ``conn.close()`` call (i.e. the path is wrapped in ``try/finally``). +2. Fire-and-forget retention: the GitHub auto-close task scheduled by + ``_close_issue_background`` in an async context must be retained by a strong + reference until it completes (asyncio only keeps a weak reference, so an + un-retained task can be garbage-collected mid-flight). +""" + +import asyncio + +import pytest + +from codeframe.core import tasks +from codeframe.core.state_machine import TaskStatus +from codeframe.core.workspace import create_or_load_workspace, get_db_connection + +pytestmark = pytest.mark.v2 + + +@pytest.fixture +def workspace(tmp_path): + return create_or_load_workspace(tmp_path) + + +class _TrackingCursor: + """Wraps a real cursor, optionally raising on execute().""" + + def __init__(self, real, fail_on): + self._real = real + self._fail_on = fail_on + + def execute(self, *args, **kwargs): + if self._fail_on == "execute": + raise RuntimeError("boom-execute") + return self._real.execute(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._real, name) + + +class _TrackingConn: + """Wraps a real SQLite connection, recording close() and optionally + raising on a chosen operation to simulate a mid-operation failure. + + ``fail_on`` may be ``"execute"`` (raise inside the SQL call, so the + function's body — including branch selection — is reached) or ``"commit"`` + (raise after the write executes, exercising the write path's ``finally``).""" + + def __init__(self, real, fail_on, registry): + self._real = real + self._fail_on = fail_on + self.closed = False + registry.append(self) + + def cursor(self): + return _TrackingCursor(self._real.cursor(), self._fail_on) + + def commit(self): + if self._fail_on == "commit": + raise RuntimeError("boom-commit") + return self._real.commit() + + def close(self): + self.closed = True + self._real.close() + + def __getattr__(self, name): + return getattr(self._real, name) + + +def _patch_connections(monkeypatch, fail_on): + """Make tasks.get_db_connection hand out tracking connections that fail + on ``fail_on``. Returns the registry of every connection handed out.""" + registry: list[_TrackingConn] = [] + + def fake(ws): + return _TrackingConn(get_db_connection(ws), fail_on, registry) + + monkeypatch.setattr(tasks, "get_db_connection", fake) + return registry + + +class TestConnectionReleasedOnException: + """All tasks.py DB paths release the connection on exception (#650).""" + + @pytest.mark.parametrize( + "operation", + [ + pytest.param(lambda ws, tid: tasks.get(ws, tid), id="get"), + pytest.param(lambda ws, tid: tasks.list_tasks(ws), id="list_tasks"), + pytest.param( + lambda ws, tid: tasks.list_tasks(ws, status=TaskStatus.READY), + id="list_tasks_status_filter", + ), + pytest.param( + lambda ws, tid: tasks.count_by_status(ws), id="count_by_status" + ), + pytest.param(lambda ws, tid: tasks.delete(ws, tid), id="delete"), + pytest.param(lambda ws, tid: tasks.delete_all(ws), id="delete_all"), + ], + ) + def test_paths_release_on_execute_error( + self, workspace, monkeypatch, operation + ): + # Failing inside execute() reaches the function body (and, for + # list_tasks, both SQL branches), so the leak guard is exercised in + # situ rather than at connection setup. + task = tasks.create(workspace, title="leak-test") + registry = _patch_connections(monkeypatch, fail_on="execute") + + with pytest.raises(RuntimeError, match="boom-execute"): + operation(workspace, task.id) + + assert registry, "expected at least one connection to be opened" + assert all(c.closed for c in registry), "connection leaked on exception" + + @pytest.mark.parametrize( + "operation", + [ + pytest.param( + lambda ws, tid: tasks.update_depends_on(ws, tid, []), + id="update_depends_on", + ), + pytest.param( + lambda ws, tid: tasks.update_requirement_ids(ws, tid, []), + id="update_requirement_ids", + ), + pytest.param(lambda ws, tid: tasks.delete(ws, tid), id="delete"), + pytest.param(lambda ws, tid: tasks.delete_all(ws), id="delete_all"), + ], + ) + def test_write_paths_release_on_commit_error( + self, workspace, monkeypatch, operation + ): + # Every write path commits; failing on commit (after the SQL executes) + # exercises each write path's finally specifically. update_* also call + # get() first (read, no commit) — that connection must close too. + task = tasks.create(workspace, title="leak-test") + registry = _patch_connections(monkeypatch, fail_on="commit") + + with pytest.raises(RuntimeError, match="boom-commit"): + operation(workspace, task.id) + + assert registry, "expected at least one connection to be opened" + assert all(c.closed for c in registry), "connection leaked on exception" + + +class TestFireAndForgetRetention: + """The auto-close async task is retained until completion (#650).""" + + @pytest.mark.asyncio + async def test_async_close_task_is_retained_then_discarded(self, monkeypatch): + gate = asyncio.Event() + ran = asyncio.Event() + + async def fake_close(pat, repo, issue_number): + ran.set() + await gate.wait() + + monkeypatch.setattr(tasks, "_safe_close_issue", fake_close) + # Snapshot + restore so this test never leaves the module set dirty for + # subsequent tests, even if an assertion fails before the task drains. + saved = set(tasks._background_tasks) + tasks._background_tasks.clear() + try: + tasks._close_issue_background("pat", "owner/repo", 7) + + # Retained synchronously so asyncio cannot GC it mid-flight. + assert len(tasks._background_tasks) == 1 + task = next(iter(tasks._background_tasks)) + + await ran.wait() + gate.set() + await task + await asyncio.sleep(0) # let the done-callback (call_soon) run + + # Discarded after completion so the set does not grow unbounded. + assert task not in tasks._background_tasks + finally: + tasks._background_tasks.clear() + tasks._background_tasks.update(saved) diff --git a/tests/notifications/test_webhook_notifications.py b/tests/notifications/test_webhook_notifications.py index 16fe261f..b23f2b85 100644 --- a/tests/notifications/test_webhook_notifications.py +++ b/tests/notifications/test_webhook_notifications.py @@ -238,23 +238,6 @@ async def test_send_blocker_notification_http_error_status(self, webhook_service assert result is False - def test_send_blocker_notification_background(self, webhook_service): - """Test fire-and-forget background notification.""" - created_at = datetime(2025, 11, 8, 14, 30, 0) - - with patch("asyncio.create_task") as mock_create_task: - webhook_service.send_blocker_notification_background( - blocker_id=123, - question="Critical blocker", - agent_id="backend-worker-1", - task_id=456, - blocker_type=BlockerType.SYNC, - created_at=created_at, - ) - - # Verify background task was created - mock_create_task.assert_called_once() - @pytest.mark.asyncio async def test_send_blocker_notification_correct_payload(self, webhook_service): """Test webhook notification sends correct JSON payload."""