diff --git a/daiv/accounts/context_processors.py b/daiv/accounts/context_processors.py index be190ec15..8298ceba2 100644 --- a/daiv/accounts/context_processors.py +++ b/daiv/accounts/context_processors.py @@ -10,8 +10,8 @@ SECTION_URL_NAMES: dict[str, set[str]] = { "dashboard": {"dashboard"}, - "runs": {"agent_run_new"}, - "activity": {"activity_list", "activity_detail", "activity_stream", "activity_download_md"}, + "activity": {"activity_list", "activity_detail", "activity_stream", "activity_download_md", "agent_run_new"}, + "chat": {"chat_list", "chat_new", "chat_detail"}, "schedules": { "schedule_list", "schedule_create", diff --git a/daiv/accounts/templates/accounts/_sidebar.html b/daiv/accounts/templates/accounts/_sidebar.html index 849e84b6e..3b0cb1e41 100644 --- a/daiv/accounts/templates/accounts/_sidebar.html +++ b/daiv/accounts/templates/accounts/_sidebar.html @@ -9,14 +9,14 @@ {# Primary action — promoted out of the nav list #} - + - {% icon "bolt" "h-3.5 w-3.5" %} + {% icon "squares-plus" "h-3.5 w-3.5" %} - {% translate "Start a run" %} + {% translate "New chat" %} @@ -42,6 +42,13 @@ {% endif %} + + + {% icon "chat-bubble" "h-4 w-4" %} + {% translate "Chat" %} + + diff --git a/daiv/activity/migrations/0009_activity_thread_id.py b/daiv/activity/migrations/0009_activity_thread_id.py new file mode 100644 index 000000000..d052a05ec --- /dev/null +++ b/daiv/activity/migrations/0009_activity_thread_id.py @@ -0,0 +1,30 @@ +# Generated by Django 6.0.4 on 2026-04-23 23:18 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [("activity", "0008_activity_batch_id")] + + operations = [ + migrations.AddField( + model_name="activity", + name="thread_id", + field=models.CharField( + blank=True, + db_index=True, + help_text="LangGraph checkpoint key. Lets chat resume this run.", + max_length=64, + null=True, + unique=True, + verbose_name="thread ID", + ), + ), + migrations.AddConstraint( + model_name="activity", + constraint=models.CheckConstraint( + condition=models.Q(("thread_id__isnull", True)) | models.Q(("thread_id", ""), _negated=True), + name="activity_thread_id_nonempty", + ), + ), + ] diff --git a/daiv/activity/migrations/0010_title.py b/daiv/activity/migrations/0010_title.py new file mode 100644 index 000000000..e3e5a3c46 --- /dev/null +++ b/daiv/activity/migrations/0010_title.py @@ -0,0 +1,15 @@ +# Generated by Django 6.0.4 on 2026-04-28 20:57 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [("activity", "0009_activity_thread_id")] + + operations = [ + migrations.AddField( + model_name="activity", + name="title", + field=models.CharField(blank=True, default="", max_length=120, verbose_name="title"), + ) + ] diff --git a/daiv/activity/models.py b/daiv/activity/models.py index 25c657c10..d61d47eec 100644 --- a/daiv/activity/models.py +++ b/daiv/activity/models.py @@ -88,6 +88,8 @@ class Activity(models.Model): status = models.CharField(_("status"), max_length=10, choices=ActivityStatus.choices, default=ActivityStatus.READY) + title = models.CharField(_("title"), max_length=120, blank=True, default="") + batch_id = models.UUIDField( _("batch ID"), null=True, @@ -96,6 +98,16 @@ class Activity(models.Model): help_text=_("Shared identifier for activities from the same submission."), ) + thread_id = models.CharField( + _("thread ID"), + max_length=64, + null=True, + blank=True, + unique=True, + db_index=True, + help_text=_("LangGraph checkpoint key. Lets chat resume this run."), + ) + external_username = models.CharField( _("external username"), max_length=255, @@ -175,6 +187,14 @@ class Meta: condition=models.Q(external_username__gt=""), ), ] + constraints = [ + # ``thread_id`` is unique=True; "" would collide on the second insert + # under Postgres (which treats NULL as not-equal but "" as a real + # value). Forbid the empty-string sentinel so callers must use NULL. + models.CheckConstraint( + condition=models.Q(thread_id__isnull=True) | ~models.Q(thread_id=""), name="activity_thread_id_nonempty" + ) + ] def __str__(self) -> str: return f"{self.get_trigger_type_display()} on {self.repo_id} ({self.status})" diff --git a/daiv/activity/services.py b/daiv/activity/services.py index 80993e969..22cdad475 100644 --- a/daiv/activity/services.py +++ b/daiv/activity/services.py @@ -9,7 +9,10 @@ from asgiref.sync import async_to_sync from jobs.tasks import run_job_task -from activity.models import Activity +from activity.models import Activity, TriggerType +from automation.titling.tasks import generate_title_task + +_PROMPT_DRIVEN = {TriggerType.API_JOB, TriggerType.MCP_JOB, TriggerType.UI_JOB} if TYPE_CHECKING: from notifications.choices import NotifyOn @@ -96,6 +99,8 @@ def create_activity( external_username: str = "", notify_on: NotifyOn | None = None, batch_id: uuid.UUID | None = None, + thread_id: str | None = None, + title: str = "", ) -> Activity: """Create an Activity record linked to a DBTaskResult. @@ -116,6 +121,8 @@ def create_activity( external_username=external_username, notify_on=notify_on, batch_id=batch_id, + thread_id=thread_id, + title=title[: Activity._meta.get_field("title").max_length], ) @@ -135,6 +142,8 @@ async def acreate_activity( external_username: str = "", notify_on: NotifyOn | None = None, batch_id: uuid.UUID | None = None, + thread_id: str | None = None, + title: str = "", ) -> Activity: """Async variant of create_activity.""" return await Activity.objects.acreate( @@ -152,6 +161,8 @@ async def acreate_activity( external_username=external_username, notify_on=notify_on, batch_id=batch_id, + thread_id=thread_id, + title=title[: Activity._meta.get_field("title").max_length], ) @@ -175,16 +186,27 @@ async def asubmit_batch_runs( _validate(repos) batch_id = uuid.uuid4() - async def _submit_one(target: RepoTarget) -> Activity | BatchSubmitFailure: + schedule_run_base = 0 + if trigger_type == TriggerType.SCHEDULE and scheduled_job is not None: + schedule_run_base = await Activity.objects.filter(scheduled_job=scheduled_job).acount() + + async def _submit_one(idx: int, target: RepoTarget) -> Activity | BatchSubmitFailure: ref_for_task = target.ref or None + thread_id = str(uuid.uuid4()) try: - task = await run_job_task.aenqueue(repo_id=target.repo_id, prompt=prompt, ref=ref_for_task, use_max=use_max) + task = await run_job_task.aenqueue( + repo_id=target.repo_id, prompt=prompt, ref=ref_for_task, use_max=use_max, thread_id=thread_id + ) except Exception as err: # noqa: BLE001 logger.exception("submit_batch_runs: enqueue failed for repo_id=%s batch_id=%s", target.repo_id, batch_id) return BatchSubmitFailure(repo_id=target.repo_id, ref=target.ref, error=f"{type(err).__name__}: {err}") + activity_title = "" + if trigger_type == TriggerType.SCHEDULE and scheduled_job is not None: + activity_title = f"{scheduled_job.name} · run #{schedule_run_base + idx + 1}" + try: - return await acreate_activity( + activity = await acreate_activity( trigger_type=trigger_type, task_result_id=task.id, repo_id=target.repo_id, @@ -196,6 +218,8 @@ async def _submit_one(target: RepoTarget) -> Activity | BatchSubmitFailure: external_username=external_username, notify_on=notify_on, batch_id=batch_id, + thread_id=thread_id, + title=activity_title, ) except Exception: logger.exception( @@ -205,9 +229,22 @@ async def _submit_one(target: RepoTarget) -> Activity | BatchSubmitFailure: ) return BatchSubmitFailure(repo_id=target.repo_id, ref=target.ref, error="ActivityCreationFailed") + if trigger_type in _PROMPT_DRIVEN and prompt: + try: + await generate_title_task.aenqueue( + entity_type="activity", + pk=str(activity.pk), + prompt=prompt, + repo_id=target.repo_id, + ref=target.ref or "", + ) + except Exception: # noqa: BLE001 + logger.exception("Failed to enqueue title task for activity %s", activity.pk) + return activity + # return_exceptions=True guards against BaseException (CancelledError, etc.) aborting the # whole batch; _submit_one already catches Exception itself. - outcomes = await asyncio.gather(*[_submit_one(t) for t in repos], return_exceptions=True) + outcomes = await asyncio.gather(*[_submit_one(i, t) for i, t in enumerate(repos)], return_exceptions=True) activities: list[Activity] = [] failed: list[BatchSubmitFailure] = [] diff --git a/daiv/activity/static/activity/js/activity-stream.js b/daiv/activity/static/activity/js/activity-stream.js index 4f8b03491..f393b953c 100644 --- a/daiv/activity/static/activity/js/activity-stream.js +++ b/daiv/activity/static/activity/js/activity-stream.js @@ -2,15 +2,22 @@ * Alpine.js components for real-time activity status updates via SSE. * * activityStream (list page) — tracks multiple activities in place: - * dotClass(id, fallback) → "status-dot-{variant}" CSS class - * statusClass(id, fallback) → "status-badge-{variant}" CSS class + * dotClass(id, fallback) → object toggling status-dot-{variant} classes + * statusClass(id, fallback) → object toggling status-badge-{variant} classes * statusLabel(id, fallback) → human-readable label * + * Object class maps (rather than a single string) are required so Alpine + * removes the previously rendered variant class when the status transitions — + * otherwise the static server-rendered class lingers alongside the new one + * and the later CSS rule wins. + * * activityDetail (detail page) — subscribes to one activity and reloads the * page on any state change so server-rendered fields (started_at, finished_at, * elapsed counter, duration, timeline dots) reflect the new state. */ document.addEventListener("alpine:init", () => { + const VARIANTS = ["success", "failed", "running", "pending"]; + function statusVariantFor(status) { if (status === "SUCCESSFUL") return "success"; if (status === "FAILED") return "failed"; @@ -25,6 +32,10 @@ document.addEventListener("alpine:init", () => { return "Pending"; } + function variantClassMap(prefix, active) { + return Object.fromEntries(VARIANTS.map((v) => [prefix + v, v === active])); + } + Alpine.data("activityStream", (streamUrl, inFlightIds) => ({ updates: {}, init() { @@ -42,10 +53,10 @@ document.addEventListener("alpine:init", () => { source.onerror = () => source.close(); }, statusClass(id, fallback) { - return "status-badge-" + statusVariantFor(this.updates[id]?.status || fallback); + return variantClassMap("status-badge-", statusVariantFor(this.updates[id]?.status || fallback)); }, dotClass(id, fallback) { - return "status-dot-" + statusVariantFor(this.updates[id]?.status || fallback); + return variantClassMap("status-dot-", statusVariantFor(this.updates[id]?.status || fallback)); }, statusLabel(id, fallback) { const update = this.updates[id]; diff --git a/daiv/activity/static/activity/js/prompt-box.js b/daiv/activity/static/activity/js/prompt-box.js index d81259b3a..e134b2b77 100644 --- a/daiv/activity/static/activity/js/prompt-box.js +++ b/daiv/activity/static/activity/js/prompt-box.js @@ -15,6 +15,7 @@ document.addEventListener("alpine:init", () => { repoPickerUrl = "", branchPickerTemplate = "", conflictMessageTemplate = "Repository already in the list: __LABEL__.", + onChangeEvent = "", }) => ({ repos: (initialRepos || []).map(r => ({ slug: r.repo_id, ref: r.ref || "" })), useMax: initialUseMax, @@ -22,6 +23,7 @@ document.addEventListener("alpine:init", () => { repoPickerUrl, branchPickerTemplate, conflictMessageTemplate, + onChangeEvent, popover: null, editingIndex: null, @@ -45,6 +47,15 @@ document.addEventListener("alpine:init", () => { }); }, + _emitChange() { + if (!this.onChangeEvent) return; + window.dispatchEvent( + new CustomEvent(this.onChangeEvent, { + detail: { repos: this.repos.map((r) => ({ repo_id: r.slug, ref: r.ref || "" })) }, + }), + ); + }, + destroy() { if (this._conflictTimer) clearTimeout(this._conflictTimer); }, @@ -109,6 +120,7 @@ document.addEventListener("alpine:init", () => { if (this.editingIndex === null) this.repos.push(entry); else this.repos.splice(this.editingIndex, 1, entry); this.closePopover(); + this._emitChange(); }, setBranch(ref) { @@ -122,6 +134,7 @@ document.addEventListener("alpine:init", () => { } this.repos[this.editingIndex].ref = ref; this.closePopover(); + this._emitChange(); }, remove(index) { @@ -134,6 +147,7 @@ document.addEventListener("alpine:init", () => { if (this.editingIndex === index) this.closePopover(); else if (index < this.editingIndex) this.editingIndex -= 1; } + this._emitChange(); }, _findConflict(slug, ref, skipIndex) { diff --git a/daiv/activity/templates/activity/_agent_run_fields.html b/daiv/activity/templates/activity/_agent_run_fields.html index 5db7aba18..507c3895c 100644 --- a/daiv/activity/templates/activity/_agent_run_fields.html +++ b/daiv/activity/templates/activity/_agent_run_fields.html @@ -1,25 +1,19 @@ -{% load i18n icon_tags %} -{% url 'codebase:picker-repositories' as repo_picker_url %} -{% url 'codebase:picker-branches' slug='__SLUG__' as branch_picker_template %} -{% trans "Choose repository" as choose_repo_label %} -{% trans "Search repositories..." as search_repos_placeholder %} -{% trans "Search branches..." as search_branches_placeholder %} -{% trans "default" as default_branch_label %} +{% load i18n %} {% trans "Repository already in the list: __LABEL__." as conflict_message_template %}
- + +
+
+ + + {% trans "to send" %} +
+
+ + +
+
+ diff --git a/daiv/chat/templates/chat/_rail.html b/daiv/chat/templates/chat/_rail.html new file mode 100644 index 000000000..c7a51aeba --- /dev/null +++ b/daiv/chat/templates/chat/_rail.html @@ -0,0 +1,39 @@ +{% load i18n %} + diff --git a/daiv/chat/templates/chat/chat_detail.html b/daiv/chat/templates/chat/chat_detail.html new file mode 100644 index 000000000..b0b780ca7 --- /dev/null +++ b/daiv/chat/templates/chat/chat_detail.html @@ -0,0 +1,178 @@ +{% extends "base_app.html" %} +{% load static i18n %} + +{% block title %}{% if thread.title %}{{ thread.title }} — {% endif %}DAIV Chat{% endblock %} +{% block container_width %}max-w-6xl flex flex-col min-h-full{% endblock %} + +{% block head_extra %} + + + + +{% endblock %} + +{% block alpine_plugins %} + + + + + + + + +{% endblock %} + +{% block breadcrumb %}{% include "accounts/_breadcrumb.html" %}{% endblock %} + +{% block app_content %} + {{ turns|json_script:"chat-initial-turns" }} + {{ merge_request|json_script:"chat-initial-merge-request" }} + +
+ +
+ {% if expired %} + + {% endif %} + + {# Responsive summary strip shown below 1100px #} +
+ · + · + + + +
+ +
+ {# Empty state — pick a repo first, then the composer fades in. #} + + + {# Turns #} + + +
+ + +
+
+ + {% if not expired %} + {% include "chat/_composer.html" %} + {% endif %} +
+ + {% include "chat/_rail.html" %} +
+{% endblock %} diff --git a/daiv/chat/templates/chat/chat_list.html b/daiv/chat/templates/chat/chat_list.html new file mode 100644 index 000000000..e3ffd29be --- /dev/null +++ b/daiv/chat/templates/chat/chat_list.html @@ -0,0 +1,48 @@ +{% extends "base_app.html" %} +{% load i18n %} + +{% block breadcrumb %}{% include "accounts/_breadcrumb.html" %}{% endblock %} + +{% block app_content %} +
+

{% trans "Chat" %}

+ {% trans "New chat" %} +
+ +{% if threads %} + + {% include "accounts/_pagination.html" %} +{% else %} +
+

{% trans "No conversations yet." %}

+ {% trans "Start a chat" %} +
+{% endif %} +{% endblock %} diff --git a/daiv/chat/turns.py b/daiv/chat/turns.py new file mode 100644 index 000000000..34449520b --- /dev/null +++ b/daiv/chat/turns.py @@ -0,0 +1,185 @@ +"""Normalize LangChain messages into the chronological ``turns[].segments[]`` shape +consumed by the chat page's Alpine renderer. + +Kept as a pure helper module (no Django imports) so the transformation is trivially +unit-testable without database or view fixtures. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +logger = logging.getLogger("daiv.chat") + +_TOOL_USE_BLOCK_TYPES = frozenset({"tool_use", "tool_call"}) +_THINKING_BLOCK_TYPES = frozenset({"thinking", "reasoning"}) +_SKILL_TOOL_NAME = "skill" + + +def build_turns(messages: list[Any]) -> list[dict[str, Any]]: + """Walk a LangChain message list once, producing a list of turns. + + Each turn is ``{"id", "role", "segments": [...]}``. Segments are either + ``{"type": "text", "content"}`` or ``{"type": "tool_call", "id", "name", + "args", "result", "status"}``. Tool results from subsequent ``ToolMessage`` + entries are paired back onto their originating tool-call segment via the + ``tool_call_id``. + + The ``skill`` tool injects a synthetic ``HumanMessage`` carrying the + resolved SKILL.md body right after its ``ToolMessage``. On reload we fold + that body into the skill call's ``result`` instead of rendering it as a + user turn — it's agent scaffolding, not something the human typed. + """ + turns: list[dict[str, Any]] = [] + tool_index: dict[str, tuple[int, int]] = {} + skill_tool_ids: set[str] = set() + pending_skill_tc_id: str | None = None + + for m in messages: + mtype = (getattr(m, "type", None) or getattr(m, "role", "") or "").lower() + if mtype in ("human", "user"): + if pending_skill_tc_id is not None: + _fold_skill_body(m, turns, tool_index, pending_skill_tc_id) + pending_skill_tc_id = None + continue + turns.append(_build_user_turn(m)) + elif mtype in ("ai", "assistant"): + turn = _build_assistant_turn(m) + turn_idx = len(turns) + for seg_idx, seg in enumerate(turn["segments"]): + if seg["type"] == "tool_call" and seg["id"]: + tool_index[seg["id"]] = (turn_idx, seg_idx) + if seg["name"] == _SKILL_TOOL_NAME: + skill_tool_ids.add(seg["id"]) + turns.append(turn) + pending_skill_tc_id = None + elif mtype in ("tool", "tool_result"): + tc_id = _attach_tool_result(m, turns, tool_index) + pending_skill_tc_id = tc_id if tc_id in skill_tool_ids else None + + return turns + + +def _attach_tool_result(m: Any, turns: list[dict[str, Any]], tool_index: dict[str, tuple[int, int]]) -> str | None: + tc_id = getattr(m, "tool_call_id", None) or ((getattr(m, "additional_kwargs", None) or {}).get("tool_call_id")) + if not tc_id or tc_id not in tool_index: + logger.warning("chat: dropping orphan ToolMessage with tool_call_id=%r", tc_id) + return None + + t_idx, s_idx = tool_index[tc_id] + content = getattr(m, "content", "") + if isinstance(content, list): + content = "\n".join(block.get("text", "") if isinstance(block, dict) else str(block) for block in content) + turns[t_idx]["segments"][s_idx]["result"] = str(content or "") + return tc_id + + +def _fold_skill_body(m: Any, turns: list[dict[str, Any]], tool_index: dict[str, tuple[int, int]], tc_id: str) -> None: + """Replace the skill tool's placeholder result with the injected SKILL.md body.""" + t_idx, s_idx = tool_index[tc_id] + content = getattr(m, "content", "") + if isinstance(content, list): + content = "".join( + block.get("text", "") for block in content if isinstance(block, dict) and block.get("type") == "text" + ) + turns[t_idx]["segments"][s_idx]["result"] = str(content or "") + + +def _reasoning_from_additional_kwargs(additional_kwargs: Any) -> str: + """Extract reasoning text from an AIMessage's ``additional_kwargs``. + + Mirrors the shapes ``ag_ui_langgraph.resolve_reasoning_content`` recognizes: + OpenAI legacy ``reasoning.summary[*].text`` and DeepSeek/Qwen/xAI + ``reasoning_content`` (string). + """ + if not isinstance(additional_kwargs, dict): + return "" + rc = additional_kwargs.get("reasoning_content") + if isinstance(rc, str) and rc.strip(): + return rc + reasoning = additional_kwargs.get("reasoning") + if isinstance(reasoning, dict): + summary = reasoning.get("summary") or [] + if isinstance(summary, list): + joined = "\n\n".join( + str(item.get("text", "")) for item in summary if isinstance(item, dict) and item.get("text") + ) + if joined.strip(): + return joined + return "" + + +def _build_user_turn(m: Any) -> dict[str, Any]: + content = getattr(m, "content", "") + if isinstance(content, list): + text = "".join( + block.get("text", "") for block in content if isinstance(block, dict) and block.get("type") == "text" + ) + else: + text = str(content or "") + return {"id": getattr(m, "id", "") or "", "role": "user", "segments": [{"type": "text", "content": text}]} + + +def _build_assistant_turn(m: Any) -> dict[str, Any]: + content = getattr(m, "content", "") + tool_calls = getattr(m, "tool_calls", None) or [] + tc_by_id: dict[str, Any] = {} + for tc in tool_calls: + tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) + if tc_id: + tc_by_id[tc_id] = tc + + segments: list[dict[str, Any]] = [] + + # Some providers (DeepSeek/Qwen/xAI, certain OpenRouter routes) return + # reasoning out-of-band in ``additional_kwargs`` instead of as a content + # block. Surface it as a leading thinking segment so refresh matches the + # streamed view. + extra_thought = _reasoning_from_additional_kwargs(getattr(m, "additional_kwargs", None)) + if extra_thought: + segments.append({"type": "thinking", "content": extra_thought}) + + if isinstance(content, list): + for block in content: + btype = block.get("type") if isinstance(block, dict) else getattr(block, "type", None) + if btype == "text": + text = block.get("text") if isinstance(block, dict) else getattr(block, "text", "") + if text: + segments.append({"type": "text", "content": text}) + elif btype in _THINKING_BLOCK_TYPES: + # Anthropic uses ``thinking``; the newer LangChain standard uses ``reasoning``. + thought = ( + (block.get("thinking") or block.get("reasoning") or block.get("text") or "") + if isinstance(block, dict) + else "" + ) + if thought: + segments.append({"type": "thinking", "content": thought}) + elif btype in _TOOL_USE_BLOCK_TYPES: + tc_id = block.get("id") if isinstance(block, dict) else getattr(block, "id", None) + canonical = tc_by_id.get(tc_id, block) + segments.append(_tool_call_segment(canonical)) + else: + if isinstance(content, str) and content.strip(): + segments.append({"type": "text", "content": content}) + for tc in tool_calls: + segments.append(_tool_call_segment(tc)) + + return {"id": getattr(m, "id", "") or "", "role": "assistant", "segments": segments} + + +def _tool_call_segment(tc: Any, *, status: str = "done") -> dict[str, Any]: + if isinstance(tc, dict): + tc_id = tc.get("id") or "" + tc_name = tc.get("name") or "" + args = tc.get("args", tc.get("input", tc.get("arguments", ""))) + else: + tc_id = getattr(tc, "id", "") or "" + tc_name = getattr(tc, "name", "") or "" + args = getattr(tc, "args", None) or getattr(tc, "input", None) or "" + + args_str = json.dumps(args) if isinstance(args, (dict, list)) else str(args or "") + + return {"type": "tool_call", "id": tc_id, "name": tc_name, "args": args_str, "result": None, "status": status} diff --git a/daiv/chat/urls.py b/daiv/chat/urls.py new file mode 100644 index 000000000..43a140273 --- /dev/null +++ b/daiv/chat/urls.py @@ -0,0 +1,10 @@ +from django.urls import path + +from chat.views import ChatThreadDetailView, ChatThreadFromActivityView, ChatThreadListView + +urlpatterns = [ + path("", ChatThreadListView.as_view(), name="chat_list"), + path("new/", ChatThreadDetailView.as_view(), name="chat_new"), + path("/", ChatThreadDetailView.as_view(), name="chat_detail"), + path("from-activity//", ChatThreadFromActivityView.as_view(), name="chat_from_activity"), +] diff --git a/daiv/chat/views.py b/daiv/chat/views.py new file mode 100644 index 000000000..256482f79 --- /dev/null +++ b/daiv/chat/views.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from typing import Any + +from django.contrib.auth.mixins import LoginRequiredMixin +from django.http import Http404, HttpResponseGone +from django.shortcuts import get_object_or_404, redirect +from django.urls import reverse +from django.views.generic import DetailView, ListView, View + +from activity.models import Activity +from asgiref.sync import async_to_sync + +from accounts.mixins import BreadcrumbMixin +from chat.models import ChatThread +from chat.repo_state import aget_existing_mr_payload, mr_to_payload +from chat.turns import build_turns +from core.checkpointer import open_checkpointer + + +async def _ahydrate(thread_id: str) -> tuple[list[Any], bool, dict | None]: + """Return (messages, expired, merge_request_payload) for a thread.""" + async with open_checkpointer() as cp: + tup = await cp.aget_tuple({"configurable": {"thread_id": thread_id}}) + if tup is None: + return [], True, None + channel_values = (tup.checkpoint or {}).get("channel_values", {}) + messages = channel_values.get("messages", []) + return messages, False, mr_to_payload(channel_values.get("merge_request")) + + +class ChatThreadListView(LoginRequiredMixin, BreadcrumbMixin, ListView): + model = ChatThread + template_name = "chat/chat_list.html" + context_object_name = "threads" + paginate_by = 25 + + def get_queryset(self): + return ChatThread.objects.for_user(self.request.user) + + def get_breadcrumbs(self): + return [{"label": "Chat", "url": None}] + + +class ChatThreadDetailView(LoginRequiredMixin, BreadcrumbMixin, DetailView): + """Renders the chat page for a specific thread, or the empty state when no + ``thread_id`` URL kwarg is present (the ``chat_new`` route). + """ + + model = ChatThread + template_name = "chat/chat_detail.html" + context_object_name = "thread" + pk_url_kwarg = "thread_id" + + def get_queryset(self): + return ChatThread.objects.for_user(self.request.user) + + def get_object(self, queryset=None): + if "thread_id" not in self.kwargs: + return None + return super().get_object(queryset) + + def get_context_data(self, **kwargs: Any) -> dict[str, Any]: + ctx = super().get_context_data(**kwargs) + thread = ctx.setdefault("thread", None) + if thread is None: + ctx.update({"turns": [], "expired": False, "active_run_id": "", "merge_request": None}) + return ctx + messages_history, expired, merge_request = async_to_sync(_ahydrate)(thread.thread_id) + if merge_request is None: + merge_request = async_to_sync(aget_existing_mr_payload)(thread.repo_id, thread.ref) + ctx["turns"] = build_turns(messages_history) + ctx["expired"] = expired + ctx["active_run_id"] = thread.active_run_id + ctx["merge_request"] = merge_request + return ctx + + def get_breadcrumbs(self): + chat_url = reverse("chat_list") + thread = getattr(self, "object", None) + if thread is None: + return [{"label": "Chat", "url": chat_url}, {"label": "New", "url": None}] + return [{"label": "Chat", "url": chat_url}, {"label": thread.title or thread.thread_id[:8], "url": None}] + + +class ChatThreadFromActivityView(LoginRequiredMixin, View): + """Bridge: create (or reuse) a ChatThread for an activity and redirect to it.""" + + def post(self, request, *, activity_id): + activity = get_object_or_404(Activity, pk=activity_id, user=request.user) + if not activity.thread_id: + raise Http404 + + messages, expired, _mr = async_to_sync(_ahydrate)(activity.thread_id) + if expired: + return HttpResponseGone("This run's state has expired. Start a fresh chat from its prompt.") + + thread, _ = async_to_sync(ChatThread.aget_or_create_from_activity)(request.user, activity) + return redirect("chat_detail", thread_id=thread.thread_id) diff --git a/daiv/codebase/clients/base.py b/daiv/codebase/clients/base.py index fdd9291a1..d2708a9ab 100644 --- a/daiv/codebase/clients/base.py +++ b/daiv/codebase/clients/base.py @@ -233,6 +233,23 @@ def get_bot_commit_email(self) -> str: def get_merge_request(self, repo_id: str, merge_request_id: int) -> MergeRequest: pass + @abc.abstractmethod + def get_merge_request_by_branches( + self, repo_id: str, source_branch: str, target_branch: str + ) -> MergeRequest | None: + """ + Return the first open merge request for this source/target branch pair, or ``None``. + + Args: + repo_id: The repository ID. + source_branch: The source branch. + target_branch: The target branch. + + Returns: + The first open MR matching the branch pair, or ``None`` if none exist. + """ + pass + @abc.abstractmethod def get_merge_request_comment(self, repo_id: str, merge_request_id: int, comment_id: str) -> Discussion: pass diff --git a/daiv/codebase/clients/github/api/callbacks.py b/daiv/codebase/clients/github/api/callbacks.py index 952909480..44e655661 100644 --- a/daiv/codebase/clients/github/api/callbacks.py +++ b/daiv/codebase/clients/github/api/callbacks.py @@ -105,6 +105,7 @@ async def process_callback(self): use_max=self.issue.has_max_label(), user=daiv_user, external_username=self.sender.username, + title=self.issue.title, ) except Exception: logger.exception("Failed to create activity for issue %s#%s", self.repository.full_name, self.issue.number) @@ -172,6 +173,7 @@ async def process_callback(self): use_max=self.issue.has_max_label(), user=daiv_user, external_username=self.comment.user.username, + title=self.issue.title, ) except Exception: logger.exception( @@ -200,6 +202,7 @@ async def process_callback(self): use_max=self.issue.has_max_label(), user=daiv_user, external_username=self.comment.user.username, + title=self.issue.title, ) except Exception: logger.exception( diff --git a/daiv/codebase/clients/github/client.py b/daiv/codebase/clients/github/client.py index ec2f0bc8b..83074049e 100644 --- a/daiv/codebase/clients/github/client.py +++ b/daiv/codebase/clients/github/client.py @@ -440,6 +440,39 @@ def update_or_create_merge_request( draft=pr.draft, ) + def get_merge_request_by_branches( + self, repo_id: str, source_branch: str, target_branch: str + ) -> MergeRequest | None: + """ + Return the open pull request for this source/target branch pair, or ``None``. + + Args: + repo_id: The repository ID. + source_branch: The source branch. + target_branch: The target branch. + + Returns: + The pull request if one open PR matches, otherwise ``None``. + """ + repo = self.client.get_repo(repo_id, lazy=True) + prs = repo.get_pulls(state="open", base=target_branch, head=source_branch) + pr = next(iter(prs), None) + if pr is None: + return None + return MergeRequest( + repo_id=repo_id, + merge_request_id=pr.number, + source_branch=pr.head.ref, + target_branch=pr.base.ref, + title=pr.title, + description=pr.body or "", + labels=[label.name for label in pr.labels], + web_url=pr.html_url, + sha=pr.head.sha, + author=User(id=pr.user.id, username=pr.user.login, name=pr.user.name), + draft=pr.draft, + ) + def update_merge_request( self, repo_id: str, diff --git a/daiv/codebase/clients/gitlab/api/callbacks.py b/daiv/codebase/clients/gitlab/api/callbacks.py index c464c77e8..bf3cd5bee 100644 --- a/daiv/codebase/clients/gitlab/api/callbacks.py +++ b/daiv/codebase/clients/gitlab/api/callbacks.py @@ -123,6 +123,7 @@ async def process_callback(self): use_max=self.object_attributes.has_max_label(), user=daiv_user, external_username=self.user.username, + title=self.object_attributes.title, ) except Exception: logger.exception( @@ -200,6 +201,7 @@ async def process_callback(self): use_max=self.issue.has_max_label(), user=daiv_user, external_username=self.user.username, + title=self.issue.title, ) except Exception: logger.exception( @@ -230,6 +232,7 @@ async def process_callback(self): use_max=self.merge_request.has_max_label(), user=daiv_user, external_username=self.user.username, + title=self.merge_request.title, ) except Exception: logger.exception( diff --git a/daiv/codebase/clients/gitlab/client.py b/daiv/codebase/clients/gitlab/client.py index 414fddc2a..8c34f95e8 100644 --- a/daiv/codebase/clients/gitlab/client.py +++ b/daiv/codebase/clients/gitlab/client.py @@ -485,6 +485,29 @@ def update_or_create_merge_request( return self._serialize_merge_request(repo_id, merge_request) raise e + def get_merge_request_by_branches( + self, repo_id: str, source_branch: str, target_branch: str + ) -> MergeRequest | None: + """ + Return the open merge request for this source/target branch pair, or ``None``. + + Args: + repo_id: The repository ID. + source_branch: The source branch. + target_branch: The target branch. + + Returns: + The merge request if one open MR matches, otherwise ``None``. + """ + project = self.client.projects.get(repo_id, lazy=True) + merge_requests = project.mergerequests.list( + source_branch=source_branch, target_branch=target_branch, state="opened", iterator=True + ) + merge_request = next(merge_requests, None) + if merge_request is None: + return None + return self._serialize_merge_request(repo_id, merge_request) + def update_merge_request( self, repo_id: str, diff --git a/daiv/codebase/clients/swe.py b/daiv/codebase/clients/swe.py index 53b36b657..779b8c3e5 100644 --- a/daiv/codebase/clients/swe.py +++ b/daiv/codebase/clients/swe.py @@ -287,6 +287,12 @@ def get_merge_request(self, repo_id: str, merge_request_id: int) -> MergeRequest """Not supported for SWE client.""" raise NotImplementedError("SWERepoClient does not support merge requests") + def get_merge_request_by_branches( + self, repo_id: str, source_branch: str, target_branch: str + ) -> MergeRequest | None: + """Not supported for SWE client.""" + raise NotImplementedError("SWERepoClient does not support merge requests") + def get_merge_request_comment(self, repo_id: str, merge_request_id: int, comment_id: str) -> Discussion: """Not supported for SWE client.""" raise NotImplementedError("SWERepoClient does not support merge request comments") diff --git a/daiv/codebase/managers/issue_addressor.py b/daiv/codebase/managers/issue_addressor.py index 73900ba99..1f5b5ddaa 100644 --- a/daiv/codebase/managers/issue_addressor.py +++ b/daiv/codebase/managers/issue_addressor.py @@ -1,16 +1,15 @@ import logging from typing import TYPE_CHECKING -from django.conf import settings as django_settings from django.template.loader import render_to_string from langchain_core.messages import HumanMessage -from langgraph.checkpoint.redis.aio import AsyncRedisSaver from automation.agent.graph import create_daiv_agent from automation.agent.usage_tracking import build_usage_summary, track_usage_metadata from automation.agent.utils import build_langsmith_config, extract_text_content, get_daiv_agent_kwargs from codebase.base import GitPlatform +from core.checkpointer import open_checkpointer from core.constants import BOT_NAME from core.utils import generate_uuid @@ -90,10 +89,7 @@ async def _address_issue(self) -> AgentResult: HumanMessage(name=self.issue.author.username, id=str(self.issue.iid), content=message_content) ) - async with AsyncRedisSaver.from_conn_string( - django_settings.DJANGO_REDIS_CHECKPOINT_URL, - ttl={"default_ttl": django_settings.DJANGO_REDIS_CHECKPOINT_TTL_MINUTES}, - ) as checkpointer: + async with open_checkpointer() as checkpointer: agent_kwargs = get_daiv_agent_kwargs( model_config=self.ctx.config.models.agent, use_max=self.issue.has_max_label() ) diff --git a/daiv/codebase/managers/review_addressor.py b/daiv/codebase/managers/review_addressor.py index 36baaa148..16f50cf08 100644 --- a/daiv/codebase/managers/review_addressor.py +++ b/daiv/codebase/managers/review_addressor.py @@ -3,11 +3,9 @@ import logging from typing import TYPE_CHECKING -from django.conf import settings as django_settings from django.template.loader import render_to_string from langchain_core.messages import HumanMessage -from langgraph.checkpoint.redis.aio import AsyncRedisSaver from unidiff import LINE_TYPE_CONTEXT, Hunk, PatchedFile from unidiff.patch import Line @@ -15,6 +13,7 @@ from automation.agent.usage_tracking import build_usage_summary, track_usage_metadata from automation.agent.utils import build_langsmith_config, extract_text_content, get_daiv_agent_kwargs from codebase.base import GitPlatform, MergeRequest, Note, NoteDiffPosition, NoteDiffPositionType, NotePositionType +from core.checkpointer import open_checkpointer from core.constants import BOT_NAME from core.utils import generate_uuid @@ -229,10 +228,7 @@ async def _address_comments(self) -> AgentResult: self.ctx.repository.slug, self.merge_request.merge_request_id, self.mention_comment_id ) - async with AsyncRedisSaver.from_conn_string( - django_settings.DJANGO_REDIS_CHECKPOINT_URL, - ttl={"default_ttl": django_settings.DJANGO_REDIS_CHECKPOINT_TTL_MINUTES}, - ) as checkpointer: + async with open_checkpointer() as checkpointer: agent_kwargs = get_daiv_agent_kwargs(model_config=self.ctx.config.models.agent) daiv_agent = await create_daiv_agent( ctx=self.ctx, checkpointer=checkpointer, store=self.store, **agent_kwargs diff --git a/daiv/codebase/templates/codebase/_repo_picker.html b/daiv/codebase/templates/codebase/_repo_picker.html new file mode 100644 index 000000000..5d1f56440 --- /dev/null +++ b/daiv/codebase/templates/codebase/_repo_picker.html @@ -0,0 +1,188 @@ +{% load i18n icon_tags %} +{% url 'codebase:picker-repositories' as repo_picker_url %} +{% url 'codebase:picker-branches' slug='__SLUG__' as branch_picker_template %} +{% trans "Choose repository" as choose_repo_label %} +{% trans "Search repositories..." as search_repos_placeholder %} +{% trans "Search branches..." as search_branches_placeholder %} +{% trans "default" as default_branch_label %} +{% trans "Repository already in the list: __LABEL__." as conflict_message_template %} + +{% comment %} +Reusable repo + branch picker. Renders the chip list, both popovers, and (optionally) +the hidden form inputs that submit the selection. + +Parameters: +- max_repos (default 1) — chip-list cap +- initial_repos (default "[]") — JSON list of {repo_id, ref} to seed the picker +- field_name (default "repos") — hidden input name; set to empty to skip submission +- required (default True) — emit the sr-only guard input that blocks form submit when empty +- field_errors (default None) — Django ErrorList for the repos field +- extra_errors (default None) — Django ErrorList for sibling fields rendered inside the same box +- with_x_data (default True) — wrap in `x-data='promptBox(...)'`; set False when a parent already provides scope +- on_change_event (default "") — when set, promptBox dispatches this window event with detail={repos} on every chip mutation + +`prompt-box.js` must be loaded for the Alpine state to bind. +{% endcomment %} + +{% if with_x_data|default_if_none:True %} +
+{% endif %} + +
+ + + + + +
+ + {% comment %} + Both popovers stay in the DOM with hx-get pre-set so HTMX registers them during its + initial sweep. openBranchPicker rewrites the branch input's hx-get (via setAttribute) + before firing `refresh`, so the __SLUG__ placeholder is replaced with the real slug. + {% endcomment %} +
+ +
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ +
+
+
+
+
+
+
+
+
+
+ + {% if field_name|default:"repos" %} + + {% endif %} + + {% if required|default_if_none:True %} + {# sr-only required input: browser blocks submit when empty; name is not a form field so POST value is discarded. #} + + {% endif %} + + {% if field_errors or extra_errors or show_conflict_message|default_if_none:True %} +
    + {% for err in extra_errors %}
  • {{ err }}
  • {% endfor %} + {% for err in field_errors %}
  • {{ err }}
  • {% endfor %} +
  • +
+ {% endif %} + +{% if with_x_data|default_if_none:True %} +
+{% endif %} diff --git a/daiv/codebase/utils.py b/daiv/codebase/utils.py index 8ba824042..227b07bc9 100644 --- a/daiv/codebase/utils.py +++ b/daiv/codebase/utils.py @@ -1,5 +1,6 @@ import contextlib import fnmatch +import logging import re import tempfile from pathlib import Path @@ -9,10 +10,13 @@ from langchain_core.messages import AIMessage, AnyMessage, HumanMessage from unidiff import PatchSet from unidiff.constants import LINE_TYPE_CONTEXT +from unidiff.errors import UnidiffParseError from unidiff.patch import Line from core.constants import BOT_NAME +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from git import Repo @@ -118,6 +122,37 @@ def redact_diff_content( return str(patch_set) if not as_patch_set else patch_set +def files_changed_from_patch(patch: str | None) -> list[dict[str, str]]: + """Derive a ``{path, op[, from_path]}`` list from a unified diff. + + The sandbox reports every workspace mutation regardless of how it happened + (bash ``rm``/``mv``, scripts, ``find -delete``, …), which is what lets the + chat rail surface them alongside ``edit_file``/``write_file`` tool calls. + """ + if not patch or not patch.strip(): + return [] + try: + patch_set = PatchSet.from_string(patch) + except UnidiffParseError: + logger.warning("Failed to parse patch for files_changed", exc_info=True) + return [] + + files: list[dict[str, str]] = [] + for patched_file in patch_set: + entry: dict[str, str] = {"path": patched_file.path} + if patched_file.is_added_file: + entry["op"] = "added" + elif patched_file.is_removed_file: + entry["op"] = "deleted" + elif patched_file.is_rename: + entry["op"] = "renamed" + entry["from_path"] = patched_file.source_file.removeprefix("a/").removeprefix("b/") + else: + entry["op"] = "modified" + files.append(entry) + return files + + class GitManager: """ Manager for interacting with a Git repository. diff --git a/daiv/core/api/__init__.py b/daiv/core/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/daiv/core/api/throttling.py b/daiv/core/api/throttling.py new file mode 100644 index 000000000..3aa2429e0 --- /dev/null +++ b/daiv/core/api/throttling.py @@ -0,0 +1,18 @@ +from ninja.throttling import AuthRateThrottle + +from core.site_settings import site_settings + + +class JobsRateThrottle(AuthRateThrottle): + """Per-user throttle for endpoints that kick off agent runs. + + Rate is read from ``site_settings.jobs_throttle_rate`` at call time so the + admin can change it without a redeploy. Both the API job endpoint and the + chat completion endpoint use this — both spin up sandbox/agent work on + each call, so a single per-user budget is the right default. + """ + + THROTTLE_RATES: dict[str, str | None] = {} + + def get_rate(self) -> str | None: + return site_settings.jobs_throttle_rate diff --git a/daiv/core/checkpointer.py b/daiv/core/checkpointer.py new file mode 100644 index 000000000..3fcb065a0 --- /dev/null +++ b/daiv/core/checkpointer.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + +from django.conf import settings + +from langgraph.checkpoint.redis.aio import AsyncRedisSaver + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + +@asynccontextmanager +async def open_checkpointer() -> AsyncIterator[AsyncRedisSaver]: + """Yield a configured AsyncRedisSaver using project settings. + + Single source of truth for the Redis connection + TTL. + """ + async with AsyncRedisSaver.from_conn_string( + settings.DJANGO_REDIS_CHECKPOINT_URL, ttl={"default_ttl": settings.DJANGO_REDIS_CHECKPOINT_TTL_MINUTES} + ) as cp: + yield cp diff --git a/daiv/core/log_filters.py b/daiv/core/log_filters.py new file mode 100644 index 000000000..e55639a4b --- /dev/null +++ b/daiv/core/log_filters.py @@ -0,0 +1,32 @@ +import asyncio +import logging + +# Substrings of the asyncio/asgiref log lines that fire when an SSE client +# disconnects mid-stream. Anything outside this allow-list keeps logging so +# unrelated cancellation bugs (timeouts, deliberate cancels, shutdown races) +# remain visible. +_ALLOWED_CANCEL_FRAGMENTS = ("Task was destroyed but it is pending", "Exception in callback", "was never awaited") + + +class SuppressCancelledError(logging.Filter): + """Drop the specific CancelledError records produced by ASGI client disconnects. + + Why: Django ASGI wraps sync middleware with ``sync_to_async``; when a chat SSE + client closes the response, the cancellation propagates through every sync + middleware and asgiref/asyncio logs the resulting traceback. The cleanup + itself runs correctly in the streamer's ``finally``. + + Scoped narrowly: only matches records on the ``asyncio`` / ``concurrent.futures`` + loggers whose exc_info is a CancelledError AND whose formatted message + contains one of the known noise fragments. Anything else (including + legitimate cancellation diagnostics from our own code) passes through. + """ + + def filter(self, record: logging.LogRecord) -> bool: + exc_info = record.exc_info + if not (exc_info and exc_info[0] is not None and issubclass(exc_info[0], asyncio.CancelledError)): + return True + if record.name not in ("asyncio", "concurrent.futures"): + return True + message = record.getMessage() + return not any(fragment in message for fragment in _ALLOWED_CANCEL_FRAGMENTS) diff --git a/daiv/core/static/core/img/icons/chat-bubble.svg b/daiv/core/static/core/img/icons/chat-bubble.svg new file mode 100644 index 000000000..32682c8ce --- /dev/null +++ b/daiv/core/static/core/img/icons/chat-bubble.svg @@ -0,0 +1,3 @@ + + + diff --git a/daiv/daiv/api.py b/daiv/daiv/api.py index 2dc3fc617..70ca43cdb 100644 --- a/daiv/daiv/api.py +++ b/daiv/daiv/api.py @@ -3,7 +3,7 @@ from ninja import NinjaAPI from accounts.api.router import router as accounts_router -from chat.api.views import chat_router, models_router +from chat.api.views import chat_router from codebase.api.router import router as codebase_router from . import __version__ @@ -12,6 +12,5 @@ api.add_router("/accounts", accounts_router) api.add_router("/codebase", codebase_router) api.add_router("/chat", chat_router) -api.add_router("/models", models_router) api.add_router("/jobs", jobs_router) api.add_router("/oauth", oauth_router) diff --git a/daiv/daiv/settings/components/common.py b/daiv/daiv/settings/components/common.py index b64b23b56..1ea67e212 100644 --- a/daiv/daiv/settings/components/common.py +++ b/daiv/daiv/settings/components/common.py @@ -15,6 +15,7 @@ "accounts", "activity", "automation", + "chat", "codebase", "core", "mcp_server", diff --git a/daiv/daiv/settings/components/logs.py b/daiv/daiv/settings/components/logs.py index fb4eab23b..c64e0d7c9 100644 --- a/daiv/daiv/settings/components/logs.py +++ b/daiv/daiv/settings/components/logs.py @@ -7,6 +7,7 @@ LOGGING: dict = { "version": 1, "disable_existing_loggers": False, + "filters": {"suppress_cancelled_error": {"()": "core.log_filters.SuppressCancelledError"}}, "formatters": { "verbose": {"format": "[%(asctime)s] %(levelname)s - %(name)s - %(message)s", "datefmt": "%d-%m-%Y:%H:%M:%S %z"} }, @@ -14,5 +15,7 @@ "loggers": { "": {"level": LOGGING_LEVEL, "handlers": ["console"]}, "daiv": {"level": "DEBUG", "handlers": ["console"], "propagate": False}, + "asyncio": {"handlers": ["console"], "filters": ["suppress_cancelled_error"], "propagate": False}, + "concurrent.futures": {"handlers": ["console"], "filters": ["suppress_cancelled_error"], "propagate": False}, }, } diff --git a/daiv/daiv/urls.py b/daiv/daiv/urls.py index 9438f3d95..5760ccf99 100644 --- a/daiv/daiv/urls.py +++ b/daiv/daiv/urls.py @@ -28,6 +28,7 @@ def location(self, item): path("dashboard/", include("accounts.urls.dashboard")), path("dashboard/configuration/", include("core.urls.configuration")), path("dashboard/activity/", include("activity.urls")), + path("dashboard/chat/", include("chat.urls")), path("dashboard/runs/", include("activity.urls_runs", namespace="runs")), path("dashboard/notifications/", include("notifications.urls")), path("dashboard/schedules/", include("schedules.urls")), diff --git a/daiv/jobs/api/views.py b/daiv/jobs/api/views.py index c436f315e..a08f3b0f1 100644 --- a/daiv/jobs/api/views.py +++ b/daiv/jobs/api/views.py @@ -7,11 +7,10 @@ from activity.services import RepoTarget, asubmit_batch_runs from django_tasks_db.models import DBTaskResult from ninja import Router -from ninja.throttling import AuthRateThrottle from automation.agent.results import parse_agent_result from chat.api.security import AuthBearer -from core.site_settings import site_settings +from core.api.throttling import JobsRateThrottle from jobs.tasks import run_job_task from .schemas import JobStatusResponse, JobSubmitFailureItem, JobSubmitJobItem, JobSubmitRequest, JobSubmitResponse @@ -21,16 +20,7 @@ jobs_router = Router(auth=AuthBearer(), tags=["jobs"]) -class _LazyThrottle(AuthRateThrottle): - """Rate throttle that reads the rate from site_settings at startup (avoids import-time DB access).""" - - THROTTLE_RATES = {} - - def get_rate(self): - return site_settings.jobs_throttle_rate - - -@jobs_router.post("", response={202: JobSubmitResponse, 503: dict}, throttle=[_LazyThrottle()]) +@jobs_router.post("", response={202: JobSubmitResponse, 503: dict}, throttle=[JobsRateThrottle()]) async def submit_job(request: HttpRequest, payload: JobSubmitRequest): """Submit a batch of 1-20 agent jobs. Each repository runs as an independent job. diff --git a/daiv/jobs/tasks.py b/daiv/jobs/tasks.py index 612523781..5543cb0a8 100644 --- a/daiv/jobs/tasks.py +++ b/daiv/jobs/tasks.py @@ -1,9 +1,7 @@ import logging -import uuid from django_tasks import task from langchain_core.messages import HumanMessage -from langgraph.checkpoint.memory import InMemorySaver from automation.agent.graph import create_daiv_agent from automation.agent.results import AgentResult, build_agent_result @@ -11,39 +9,42 @@ from automation.agent.utils import build_langsmith_config, extract_text_content, get_daiv_agent_kwargs from codebase.base import Scope from codebase.context import set_runtime_ctx +from core.checkpointer import open_checkpointer logger = logging.getLogger("daiv.jobs") @task() -async def run_job_task(repo_id: str, prompt: str, ref: str | None = None, use_max: bool = False) -> AgentResult: - """ - Run the DAIV agent for a submitted job and return a standardized result. - - Args: - repo_id: The repository id. - prompt: The user prompt to send to the agent. - ref: The git reference. Defaults to the repository's default branch. - use_max: Whether to use the max model configuration. +async def run_job_task( + repo_id: str, prompt: str, thread_id: str, ref: str | None = None, use_max: bool = False +) -> AgentResult: + """Run the DAIV agent for a submitted job and return a standardized result. - Returns: - An :class:`AgentResult` dict with the agent response and code_changes flag. + The ``thread_id`` is used as the LangGraph checkpoint key. Callers MUST mint one + up-front and persist it on the corresponding ``Activity`` — chat resume is built + on the assumption that the activity row and the checkpointer share the same key. + A silent UUID fallback here would break that contract on the resume path. """ - logger.info("Starting job for repo_id=%s, ref=%s, use_max=%s", repo_id, ref, use_max) + if not thread_id: + raise ValueError("run_job_task requires a non-empty thread_id; mint one before enqueueing") + + logger.info("Starting job for repo_id=%s, ref=%s, use_max=%s, thread_id=%s", repo_id, ref, use_max, thread_id) input_data = {"messages": [HumanMessage(content=prompt)]} try: - async with set_runtime_ctx(repo_id=repo_id, scope=Scope.GLOBAL, ref=ref) as runtime_ctx: + async with ( + set_runtime_ctx(repo_id=repo_id, scope=Scope.GLOBAL, ref=ref) as runtime_ctx, + open_checkpointer() as checkpointer, + ): agent_kwargs = get_daiv_agent_kwargs(model_config=runtime_ctx.config.models.agent, use_max=use_max) - checkpointer = InMemorySaver() config = build_langsmith_config( runtime_ctx, trigger="job", model=agent_kwargs["model_names"][0], thinking_level=agent_kwargs["thinking_level"], extra_metadata={"ref": ref}, - configurable={"thread_id": str(uuid.uuid4())}, + configurable={"thread_id": thread_id}, ) daiv_agent = await create_daiv_agent(ctx=runtime_ctx, checkpointer=checkpointer, **agent_kwargs) with track_usage_metadata() as usage_handler: @@ -59,7 +60,7 @@ async def run_job_task(repo_id: str, prompt: str, ref: str | None = None, use_ma response_text = extract_text_content(messages[-1].content) - logger.info("Job completed for repo_id=%s", repo_id) + logger.info("Job completed for repo_id=%s, thread_id=%s", repo_id, thread_id) return await build_agent_result( daiv_agent, config, response=response_text, usage=build_usage_summary(usage_handler.usage_metadata).to_dict() ) diff --git a/daiv/mcp_server/server.py b/daiv/mcp_server/server.py index 99232c03f..bb9f9b52a 100644 --- a/daiv/mcp_server/server.py +++ b/daiv/mcp_server/server.py @@ -33,12 +33,11 @@ It reads, writes, and edits files, runs shell commands in a sandbox, and debugs CI/CD pipelines. **Automatic post-job steps — do NOT ask for these in `prompt`.** When a job produces code \ -changes, DAIV automatically commits, pushes to a newly created branch, and opens a \ -merge/pull request. The branch name, commit message, and MR/PR title and description are \ -generated by DAIV from the diff — do NOT supply them, and do NOT instruct the agent to \ -"commit", "push", "create a branch", or "open a PR/MR". Such instructions waste tokens and \ -can mislead the agent. Each `submit_job` call produces its own branch and its own MR/PR — \ -the response includes `merge_request_url` when one was created. +changes, DAIV automatically commits, pushes, and ensures a merge/pull request exists. The \ +branch name, commit message, and MR/PR title and description are generated by DAIV from \ +the diff — do NOT supply them, and do NOT instruct the agent to "commit", "push", "create \ +a branch", or "open a PR/MR". Such instructions waste tokens and can mislead the agent. \ +The response includes `merge_request_url` when one was created or updated. Use `list_repositories` to discover available repositories by name or topic. \ If the user already knows the `repo_id`, skip discovery and call `submit_job` directly. @@ -47,9 +46,10 @@ or error messages. Vague requests produce poor results. `ref` is the STARTING POINT the agent reads from (base branch or commit SHA), not a target \ -branch name. Omit it to start from the repository default branch. Pass an existing branch \ -name to work from that branch's state — DAIV will still publish the result to a new branch \ -and open a new MR/PR. +branch name. Omit it to start from the repository default branch. If `ref` is a branch \ +that already has an open MR/PR, DAIV pushes the new commit onto that same branch and \ +updates the existing MR/PR — useful for fixing CI or in-review branch. \ +Otherwise DAIV creates a new branch and opens a new MR/PR. Jobs are rate-limited per user. Long-running jobs may exceed the 10-minute polling window; \ continue polling with `get_job_status` if the result is not yet available.\ diff --git a/daiv/notifications/migrations/0003_alter_notificationdelivery_channel_type_and_more.py b/daiv/notifications/migrations/0003_alter_notificationdelivery_channel_type_and_more.py new file mode 100644 index 000000000..4d92cab97 --- /dev/null +++ b/daiv/notifications/migrations/0003_alter_notificationdelivery_channel_type_and_more.py @@ -0,0 +1,24 @@ +# Generated by Django 6.0.4 on 2026-04-27 22:14 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [("notifications", "0002_seed_email_bindings")] + + operations = [ + migrations.AlterField( + model_name="notificationdelivery", + name="channel_type", + field=models.CharField( + choices=[("email", "Email"), ("rocketchat", "Rocket Chat")], max_length=32, verbose_name="channel type" + ), + ), + migrations.AlterField( + model_name="userchannelbinding", + name="channel_type", + field=models.CharField( + choices=[("email", "Email"), ("rocketchat", "Rocket Chat")], max_length=32, verbose_name="channel type" + ), + ), + ] diff --git a/daiv/static_src/css/input.css b/daiv/static_src/css/input.css index de148609c..30f734c42 100644 --- a/daiv/static_src/css/input.css +++ b/daiv/static_src/css/input.css @@ -261,3 +261,1201 @@ a, button, [role="button"] { @apply px-3 py-2 text-sm; } } + +/* -- Chat UI ------------------------------------------------------------- */ + +@layer components { + /*
is the page scroll container (claude.ai / chatgpt pattern). The + wrapper is `min-h-full flex flex-col` from chat_detail.html, so the chat + fills main when content is short and grows naturally past that as turns + accumulate. `grid-template-rows: minmax(0, 1fr)` makes the column stretch + to fill the grid even when content is short; without it the column + collapses to content height and the sticky composer can't reach the + viewport bottom. */ + .chat-shell { + display: grid; + grid-template-columns: 1fr; + grid-template-rows: minmax(0, 1fr); + gap: 0; + flex: 1 1 auto; + } + + @media (min-width: 1100px) { + .chat-shell { + grid-template-columns: minmax(0, 1fr) 280px; + column-gap: 1.5rem; + } + + /* Drop the rail column entirely and re-center the chat-column at the + same width it has on the chat detail page (1fr − 280px rail − 1.5rem + gap of a 72rem container ≈ 848px). Avoids the off-center look you'd + get from just hiding the rail with its column still reserved. */ + .chat-shell--empty { + grid-template-columns: 1fr; + column-gap: 0; + } + + .chat-shell--empty .chat-rail { + display: none; + } + + .chat-shell--empty .chat-column { + max-width: 53rem; + margin-inline: auto; + width: 100%; + } + } + + .chat-column { + position: relative; + display: flex; + flex-direction: column; + min-width: 0; + } + + .chat-rail { + display: none; + } + + @media (min-width: 1100px) { + .chat-rail { + display: block; + position: sticky; + top: 0; + align-self: start; + padding: 1.5rem 1rem; + border-left: 1px solid rgba(255, 255, 255, 0.04); + font-size: 12.5px; + color: #cbd5e1; + max-height: 100vh; + overflow-y: auto; + } + } + + .chat-rail__block + .chat-rail__block { + margin-top: 20px; + } + + .chat-rail__label { + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 10px; + letter-spacing: 0.14em; + text-transform: uppercase; + color: #64748b; + margin-bottom: 6px; + } + + .chat-rail__repo { + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 12px; + color: #cbd5e1; + } + + .chat-rail__repo-ref { + color: #93c5fd; + } + + .chat-rail__status { + display: inline-flex; + align-items: center; + gap: 6px; + padding: 4px 9px; + border-radius: 999px; + font-size: 11px; + border: 1px solid transparent; + font-family: "JetBrains Mono", ui-monospace, monospace; + } + + .chat-rail__status--idle { + background: rgba(52, 211, 153, 0.08); + border-color: rgba(52, 211, 153, 0.18); + color: #34d399; + } + + .chat-rail__status--running { + background: rgba(251, 191, 36, 0.08); + border-color: rgba(251, 191, 36, 0.18); + color: #fbbf24; + } + + .chat-rail__status--thinking { + background: rgba(167, 139, 250, 0.08); + border-color: rgba(167, 139, 250, 0.18); + color: #c4b5fd; + } + + .chat-rail__status--error { + background: rgba(248, 113, 113, 0.08); + border-color: rgba(248, 113, 113, 0.18); + color: #fca5a5; + } + + .chat-rail__status-dot { + width: 5px; + height: 5px; + border-radius: 999px; + background: currentColor; + } + + .chat-rail__file { + display: flex; + align-items: baseline; + gap: 6px; + width: 100%; + text-align: left; + padding: 3px 4px; + color: #cbd5e1; + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 11.5px; + border: 1px solid transparent; + border-radius: 4px; + background: transparent; + cursor: pointer; + } + + .chat-rail__file-op { + flex: 0 0 10px; + text-align: center; + font-weight: 600; + color: #64748b; + } + + .chat-rail__file-path { + flex: 1 1 auto; + min-width: 0; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + } + + .chat-rail__file--added .chat-rail__file-op { color: #34d399; } + .chat-rail__file--deleted .chat-rail__file-op { color: #f87171; } + .chat-rail__file--deleted .chat-rail__file-path { color: #94a3b8; text-decoration: line-through; } + .chat-rail__file--renamed .chat-rail__file-op { color: #60a5fa; } + + .chat-rail__file:hover { + background: rgba(255, 255, 255, 0.03); + } + + .chat-rail__file--new { + animation: rail-file-in 300ms ease-out both, rail-file-pulse 800ms ease-out 300ms both; + border-left: 1px solid transparent; + } + + @keyframes rail-file-in { + from { + opacity: 0; + transform: translateX(-6px); + } + } + + @keyframes rail-file-pulse { + 0%, 100% { + border-left-color: transparent; + } + 40% { + border-left-color: #34d399; + } + } + + /* Transcript */ + + .chat-transcript { + flex: 1 1 auto; + padding: 1.5rem 0 4rem 0; + } + + .chat-hero { + display: flex; + flex-direction: column; + gap: 12px; + max-width: 34rem; + margin: 10vh auto 0; + padding: 2rem 1rem; + animation: fade-up 0.6s ease-out both; + } + + .chat-hero__eyebrow { + display: inline-flex; + align-items: center; + gap: 8px; + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 11px; + letter-spacing: 0.14em; + text-transform: uppercase; + color: #94a3b8; + } + + .chat-hero__eyebrow::before { + content: ""; + width: 20px; + height: 1px; + background: linear-gradient(90deg, transparent, #94a3b8); + } + + .chat-hero__title { + font-family: "Outfit", system-ui, sans-serif; + font-size: clamp(28px, 4.5vw, 44px); + font-weight: 500; + line-height: 1.1; + letter-spacing: -0.02em; + background: linear-gradient(180deg, #fff, #cbd5e1); + -webkit-background-clip: text; + background-clip: text; + -webkit-text-fill-color: transparent; + } + + .chat-hero__subtitle { + font-size: 15px; + color: #94a3b8; + } + + /* Stack the two subtitle variants in the same grid cell so the container + always sizes to the larger one — eliminates the layout jump when the + visible text swaps after a repo is picked. Crossfade via opacity. */ + .chat-hero__subtitle-stack { + display: grid; + } + + .chat-hero__subtitle-stack > .chat-hero__subtitle { + grid-area: 1 / 1; + margin: 0; + transition: opacity 0.3s ease; + } + + .chat-hero__subtitle-stack > .chat-hero__subtitle--hidden { + opacity: 0; + pointer-events: none; + } + + .chat-hero__picker { + margin-top: 18px; + } + + /* New-chat empty state: collapse the transcript so the composer docks right + under the hero picker instead of pinning to the viewport bottom. Keeps the + pick-repo → type-message flow visually sequential. */ + .chat-column--empty .chat-transcript { + flex: 0 0 auto; + padding-bottom: 0; + } + + /* Lift the hero (and its repo/branch popovers) above the composer's + backdrop-filter stacking context — otherwise the popover that anchors + `top-full` from the picker visually slides under the composer below it. */ + .chat-column--empty .chat-hero { + position: relative; + z-index: 20; + } + + .chat-column--empty .chat-composer { + position: relative; + bottom: auto; + margin-top: 12px; + } + + .chat-column--empty .chat-composer::before { + display: none; + } + + /* Repo chip — shared between the activity form picker and the chat composer. + Inside `_repo_picker.html` (interactive, with edit/remove buttons) the chip + is rendered inline with its sibling utility classes; the styles here mirror + that look so the chat composer can render a static, non-interactive chip + with the same appearance using just `.repo-chip` + `.repo-chip__btn`. */ + + .repo-chip { + display: inline-flex; + align-items: center; + gap: 4px; + padding: 4px; + border-radius: 999px; + border: 1px solid rgba(255, 255, 255, 0.08); + background: rgba(255, 255, 255, 0.05); + color: #cbd5e1; + font-size: 13px; + font-weight: 500; + } + + .repo-chip__btn { + display: inline-flex; + align-items: center; + gap: 6px; + padding: 2px 6px; + border-radius: 999px; + } + + .repo-chip__btn--branch { + color: #93c5fd; + } + + /* When the agent has committed to a working branch, lift the branch segment + so users see at a glance that it's not the original ref anymore. */ + .repo-chip__btn--accent { + background: rgba(96, 165, 250, 0.12); + color: #bfdbfe; + } + + .repo-chip__btn--accent .repo-chip__btn--branch { + color: #dbeafe; + } + + /* MR pill — companion chip rendered after the repo chip when GitMiddleware + produced a merge request. Sits in the same row, links to the MR, and + carries an optional draft badge so reviewers can spot WIP work. */ + .repo-chip--mr { + background: rgba(34, 197, 94, 0.08); + border-color: rgba(34, 197, 94, 0.22); + color: #bbf7d0; + text-decoration: none; + transition: background 0.15s ease, border-color 0.15s ease, transform 0.15s ease; + } + + .repo-chip--mr:hover { + background: rgba(34, 197, 94, 0.14); + border-color: rgba(34, 197, 94, 0.35); + transform: translateY(-1px); + } + + .repo-chip__badge { + display: inline-flex; + align-items: center; + padding: 0 6px; + height: 16px; + border-radius: 8px; + background: rgba(251, 191, 36, 0.18); + border: 1px solid rgba(251, 191, 36, 0.28); + color: #fde68a; + font-size: 10.5px; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; + } + + /* Turn */ + + .chat-turn { + padding: 0.5rem 0; + animation: fade-up 0.45s ease-out both; + } + + .chat-turn__body { + display: flex; + flex-direction: column; + gap: 0.5rem; + min-width: 0; + font-size: 14.5px; + color: #e2e8f0; + line-height: 1.65; + } + + /* Text segments breathe: more above than below, same whether alone in a + turn or followed by tool-call segments. */ + .chat-segment--text { + margin-top: 0.5rem; + margin-bottom: 0.25rem; + font-size: 15.5px; + line-height: 1.7; + } + + /* User turn: no thread line, bubble hugs the right edge */ + .chat-turn--user .chat-turn__body { + align-items: flex-end; + text-align: right; + } + + .chat-turn--user .chat-segment--text { + max-width: min(100%, 70ch); + padding: 8px 14px; + background: rgba(96, 165, 250, 0.1); + border: 1px solid rgba(96, 165, 250, 0.22); + border-radius: 14px; + border-top-right-radius: 6px; + color: #dbeafe; + text-align: left; + } + + .chat-turn--user .chat-text { + white-space: pre-wrap; + overflow-wrap: break-word; + } + + .chat-text :where(p, ul, ol, pre, blockquote, h1, h2, h3, h4, h5, h6, hr, table, dl) { + margin: 0 0 0.65rem 0; + } + + .chat-text :where(p, ul, ol, pre, blockquote, h1, h2, h3, h4, h5, h6, hr, table, dl):last-child { + margin-bottom: 0; + } + + .chat-text p { + margin-bottom: 0.85em; + } + + /* Heading hierarchy: tight to the content beneath them, breathing above so + they bind to their section. Subtle rules under h1/h2 give an editorial + scan-line without reading as decoration. */ + .chat-text :where(h1, h2, h3, h4, h5, h6) { + margin: 1.5em 0 0.45em; + color: #f1f5f9; + font-weight: 600; + line-height: 1.3; + letter-spacing: -0.005em; + } + + .chat-text :where(h1, h2, h3, h4, h5, h6):first-child { + margin-top: 0; + } + + .chat-text h1 { + font-size: 1.4em; + padding-bottom: 0.3em; + border-bottom: 1px solid rgba(255, 255, 255, 0.07); + letter-spacing: -0.01em; + } + + .chat-text h2 { + font-size: 1.22em; + padding-bottom: 0.25em; + border-bottom: 1px solid rgba(255, 255, 255, 0.05); + } + + .chat-text h3 { + font-size: 1.08em; + } + + .chat-text h4 { + font-size: 1em; + color: #e2e8f0; + } + + .chat-text :where(h5, h6) { + font-size: 0.85em; + color: #cbd5e1; + text-transform: uppercase; + letter-spacing: 0.06em; + } + + /* Lists: bullets/numbers visible in the gutter, muted so the content stays + primary. Nested lists tighten so 3-deep agendas don't blow out. */ + .chat-text :where(ul, ol) { + padding-inline-start: 1.5em; + } + + .chat-text ul { + list-style: disc; + } + + .chat-text ol { + list-style: decimal; + } + + .chat-text li { + margin: 0.2em 0; + } + + .chat-text li::marker { + color: rgba(148, 163, 184, 0.7); + } + + .chat-text li > :where(ul, ol) { + margin: 0.2em 0 0; + } + + .chat-text li > p { + margin-bottom: 0.3em; + } + + /* Task lists (GFM): drop the disc, align the checkbox flush with text. */ + .chat-text li:has(> input[type="checkbox"]) { + list-style: none; + margin-inline-start: -1.25em; + } + + .chat-text li > input[type="checkbox"] { + margin-right: 0.5em; + accent-color: #60a5fa; + transform: translateY(1px); + } + + .chat-text blockquote { + border-left: 3px solid rgba(96, 165, 250, 0.4); + padding: 0.1em 0 0.1em 1em; + color: #94a3b8; + font-style: italic; + } + + .chat-text blockquote > :last-child { + margin-bottom: 0; + } + + .chat-text hr { + border: 0; + height: 1px; + background: linear-gradient(to right, transparent, rgba(255, 255, 255, 0.14), transparent); + margin: 1.5em 0; + } + + .chat-text strong { + color: #f1f5f9; + font-weight: 600; + } + + .chat-text em { + color: #e2e8f0; + } + + .chat-text mark { + background: rgba(251, 191, 36, 0.18); + color: #fde68a; + padding: 0 3px; + border-radius: 3px; + } + + .chat-text img { + max-width: 100%; + height: auto; + border-radius: 8px; + border: 1px solid rgba(255, 255, 255, 0.06); + } + + .chat-text pre { + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 12.5px; + background: rgba(0, 0, 0, 0.35); + border: 1px solid rgba(255, 255, 255, 0.04); + border-radius: 8px; + padding: 12px 14px; + overflow-x: auto; + scrollbar-width: thin; + scrollbar-color: rgba(255, 255, 255, 0.15) transparent; + } + + /* Inline code: tinted chip in the amber accent so it reads as a single token + rather than a colour burst against body text. */ + .chat-text code:not(pre code) { + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 0.88em; + padding: 1px 6px; + border-radius: 5px; + background: rgba(251, 191, 36, 0.08); + border: 1px solid rgba(251, 191, 36, 0.16); + color: #fcd34d; + } + + /* Code chips inside headings should track the heading colour so they don't + scream over their neighbours. */ + .chat-text :where(h1, h2, h3, h4, h5, h6) code:not(pre code) { + color: #f1f5f9; + background: rgba(255, 255, 255, 0.06); + border-color: rgba(255, 255, 255, 0.08); + font-size: 0.92em; + } + + .chat-text kbd { + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 11.5px; + padding: 1px 6px; + border: 1px solid rgba(255, 255, 255, 0.12); + border-bottom-width: 2px; + border-radius: 5px; + background: rgba(255, 255, 255, 0.04); + color: #e2e8f0; + box-shadow: 0 1px 0 rgba(0, 0, 0, 0.4); + } + + .chat-text a { + color: #93c5fd; + text-decoration: underline; + text-decoration-color: rgba(147, 197, 253, 0.3); + text-underline-offset: 2px; + text-decoration-thickness: 1px; + transition: text-decoration-color 0.15s ease; + } + + .chat-text a:hover { + text-decoration-color: #93c5fd; + } + + .chat-text table { + width: 100%; + overflow-x: auto; + border-collapse: collapse; + font-size: 13px; + border: 1px solid rgba(255, 255, 255, 0.08); + border-radius: 8px; + margin-bottom: 0.65rem; + } + + .chat-text thead { + background: rgba(255, 255, 255, 0.04); + } + + .chat-text th, + .chat-text td { + padding: 7px 12px; + text-align: left; + vertical-align: top; + border-bottom: 1px solid rgba(255, 255, 255, 0.05); + border-right: 1px solid rgba(255, 255, 255, 0.04); + } + + .chat-text th { + color: #e2e8f0; + font-weight: 600; + letter-spacing: 0.01em; + } + + .chat-text td { + color: #cbd5e1; + } + + .chat-text th:last-child, + .chat-text td:last-child { + border-right: 0; + } + + .chat-text tbody tr:last-child td { + border-bottom: 0; + } + + .chat-text tbody tr:nth-child(even) { + background: rgba(255, 255, 255, 0.015); + } + + .chat-text tbody tr:hover { + background: rgba(255, 255, 255, 0.035); + } + + .chat-text td code, + .chat-text th code { + font-size: 12px; + } + + /* Tool call */ + + .chat-tool { + border: 1px solid rgba(255, 255, 255, 0.06); + background: rgba(255, 255, 255, 0.02); + border-radius: 10px; + overflow: hidden; + transition: border-color 0.15s ease, background 0.15s ease; + } + + .chat-tool:hover { + border-color: rgba(255, 255, 255, 0.1); + } + + .chat-tool[open] { + background: rgba(255, 255, 255, 0.025); + } + + .chat-tool--running { + border-color: rgba(251, 191, 36, 0.2); + background: rgba(251, 191, 36, 0.03); + } + + .chat-tool--error { + border-color: rgba(248, 113, 113, 0.25); + background: rgba(248, 113, 113, 0.03); + } + + .chat-tool__summary { + display: flex; + align-items: center; + gap: 10px; + padding: 9px 12px; + cursor: pointer; + list-style: none; + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 12px; + } + + .chat-tool__summary::-webkit-details-marker { + display: none; + } + + .chat-tool__icon { + display: inline-flex; + align-items: center; + justify-content: center; + width: 16px; + height: 16px; + border-radius: 999px; + flex-shrink: 0; + font-size: 10px; + } + + .chat-tool--running .chat-tool__icon { + color: #fbbf24; + background: rgba(251, 191, 36, 0.12); + } + + .chat-tool--done .chat-tool__icon { + color: #34d399; + background: rgba(52, 211, 153, 0.1); + } + + .chat-tool--error .chat-tool__icon { + color: #f87171; + background: rgba(248, 113, 113, 0.12); + } + + .chat-tool__spinner { + width: 10px; + height: 10px; + border-radius: 999px; + border: 1.5px solid currentColor; + border-top-color: transparent; + animation: spin 0.9s linear infinite; + } + + @keyframes spin { + to { transform: rotate(360deg); } + } + + /* Publish-phase chip — compact inline status emitted by the diff_to_metadata + subagents during commit/MR creation. Sits in the assistant turn body where + a regular tool card would, but reads as a single line so a reader scans + past it once it's done. */ + .chat-phase { + display: inline-flex; + align-items: center; + gap: 8px; + padding: 5px 12px; + border-radius: 999px; + border: 1px solid rgba(96, 165, 250, 0.22); + background: rgba(96, 165, 250, 0.08); + color: #bfdbfe; + font-size: 12.5px; + font-weight: 500; + width: fit-content; + transition: border-color 0.2s ease, background 0.2s ease, color 0.2s ease; + } + + .chat-phase--done { + border-color: rgba(34, 197, 94, 0.28); + background: rgba(34, 197, 94, 0.08); + color: #bbf7d0; + } + + .chat-phase__icon { + display: inline-flex; + align-items: center; + justify-content: center; + width: 14px; + height: 14px; + font-size: 11px; + line-height: 1; + } + + .chat-phase__spinner { + width: 10px; + height: 10px; + border-radius: 999px; + border: 1.5px solid currentColor; + border-top-color: transparent; + animation: spin 0.9s linear infinite; + } + + .chat-phase__suffix { + opacity: 0.7; + margin-left: -4px; + } + + .chat-tool__name { + color: #a5b4fc; + font-weight: 500; + } + + .chat-tool--error .chat-tool__name { + color: #fca5a5; + } + + .chat-tool__path { + color: #64748b; + flex: 1; + min-width: 0; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + } + + .chat-tool__badge { + padding: 1px 6px; + border-radius: 4px; + font-size: 10px; + font-family: "JetBrains Mono", ui-monospace, monospace; + flex-shrink: 0; + } + + .chat-tool__badge--success { background: rgba(52, 211, 153, 0.1); color: #34d399; } + .chat-tool__badge--danger { background: rgba(248, 113, 113, 0.1); color: #f87171; } + .chat-tool__badge--warn { background: rgba(251, 191, 36, 0.12); color: #fbbf24; } + .chat-tool__badge--info { background: rgba(96, 165, 250, 0.1); color: #60a5fa; } + .chat-tool__badge--violet { background: rgba(167, 139, 250, 0.12); color: #c4b5fd; } + .chat-tool__badge--neutral { background: rgba(255, 255, 255, 0.04); color: #94a3b8; } + + .chat-tool__chev { + color: #475569; + transition: transform 0.2s ease; + flex-shrink: 0; + } + + .chat-tool[open] .chat-tool__chev { + transform: rotate(180deg); + } + + .chat-tool__body { + padding: 0 12px 12px; + display: flex; + flex-direction: column; + gap: 10px; + animation: fade-up 0.2s ease-out both; + } + + .chat-tool__block { + border-top: 1px dashed rgba(255, 255, 255, 0.06); + padding-top: 10px; + } + + .chat-tool__block-label { + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 10px; + letter-spacing: 0.14em; + text-transform: uppercase; + color: #64748b; + margin-bottom: 6px; + } + + .chat-tool__code { + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 12.5px; + line-height: 1.55; + color: #cbd5e1; + background: rgba(0, 0, 0, 0.28); + border: 1px solid rgba(255, 255, 255, 0.04); + padding: 10px 12px; + border-radius: 8px; + white-space: pre-wrap; + word-break: break-word; + max-height: 360px; + overflow-y: auto; + } + + .chat-bash { + display: flex; + flex-direction: column; + gap: 8px; + } + + .chat-bash + .chat-bash { + margin-top: 6px; + padding-top: 10px; + border-top: 1px dashed rgba(255, 255, 255, 0.06); + } + + .chat-bash__cmd { + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 12.5px; + color: #e2e8f0; + word-break: break-word; + } + + .chat-bash__prompt { + color: #a78bfa; + margin-right: 6px; + } + + .chat-bash__exit--err { + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 11px; + color: #f87171; + align-self: flex-start; + padding: 2px 8px; + border-radius: 4px; + background: rgba(248, 113, 113, 0.08); + } + + .chat-tool__highlight { + animation: tool-highlight 600ms ease-out both; + } + + @keyframes tool-highlight { + from { + box-shadow: 0 0 0 2px rgba(167, 139, 250, 0.6); + } + to { + box-shadow: 0 0 0 0 rgba(167, 139, 250, 0); + } + } + + /* Thinking ticker */ + + .chat-thinking { + display: inline-flex; + align-items: center; + gap: 8px; + padding: 6px 0; + color: #94a3b8; + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 13px; + } + + .chat-thinking__dot { + width: 6px; + height: 6px; + border-radius: 999px; + background: #fbbf24; + box-shadow: 0 0 10px rgba(251, 191, 36, 0.5); + animation: pulse-dot 1.4s ease-in-out infinite; + } + + /* Per-turn reasoning disclosure */ + + .chat-thinking-seg { + margin: 0.25rem 0; + } + + .chat-thinking-seg__summary { + display: inline-flex; + align-items: center; + gap: 8px; + padding: 2px 0; + color: #94a3b8; + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 12.5px; + cursor: pointer; + list-style: none; + user-select: none; + transition: color 0.15s ease; + } + + .chat-thinking-seg__summary::-webkit-details-marker { + display: none; + } + + .chat-thinking-seg__summary::before { + content: "›"; + display: inline-block; + width: 10px; + color: #64748b; + transition: transform 0.15s ease; + } + + .chat-thinking-seg[open] > .chat-thinking-seg__summary::before { + transform: rotate(90deg); + } + + .chat-thinking-seg__summary:hover { + color: #cbd5e1; + } + + .chat-thinking-seg__dot { + width: 5px; + height: 5px; + border-radius: 999px; + background: #a78bfa; + box-shadow: 0 0 8px rgba(167, 139, 250, 0.5); + animation: pulse-dot 1.4s ease-in-out infinite; + } + + .chat-thinking-seg__body { + margin: 0.5rem 0 0.75rem 18px; + padding-left: 10px; + border-left: 2px solid rgba(148, 163, 184, 0.18); + color: #94a3b8; + font-size: 13.5px; + line-height: 1.65; + } + + /* Composer */ + + .chat-composer { + position: sticky; + bottom: 16px; + flex: 0 0 auto; + margin: 0; + padding: 12px 14px; + border: 1px solid rgba(255, 255, 255, 0.06); + border-radius: 16px; + background: linear-gradient(180deg, rgba(20, 22, 40, 0.85), rgba(12, 14, 30, 0.85)); + backdrop-filter: blur(14px); + transition: border-color 0.2s ease, box-shadow 0.2s ease; + z-index: 10; + } + + /* Soft fade behind scrolled content above the sticky composer. Stops a few + pixels short of the composer's top border so it doesn't visually erase + the (intentionally faint) 1px top border. */ + .chat-composer::before { + content: ""; + position: absolute; + left: 0; + right: 0; + bottom: calc(100% + 2px); + height: 40px; + pointer-events: none; + background: linear-gradient(180deg, rgba(3, 7, 18, 0), rgba(3, 7, 18, 0.45)); + } + + .chat-composer:focus-within { + border-color: rgba(167, 139, 250, 0.3); + box-shadow: + 0 0 0 1px rgba(167, 139, 250, 0.14), + 0 16px 48px -16px rgba(124, 58, 237, 0.35); + } + + .chat-composer--sending { + opacity: 0.72; + } + + .chat-composer__meta { + display: flex; + align-items: center; + gap: 6px; + margin-bottom: 8px; + } + + .chat-composer__textarea { + width: 100%; + background: transparent !important; + border: none !important; + outline: none; + resize: none; + padding: 4px 2px !important; + font-family: "Outfit", system-ui, sans-serif; + font-size: 15px; + color: #fff; + line-height: 1.55; + min-height: 2.8rem; + max-height: 14rem; + } + + .chat-composer__textarea::placeholder { + color: #64748b; + font-weight: 300; + } + + .chat-composer__textarea:focus { + box-shadow: none !important; + outline: none !important; + } + + .chat-composer__actions { + display: flex; + align-items: center; + justify-content: space-between; + gap: 10px; + margin-top: 6px; + } + + .chat-composer__hint { + display: inline-flex; + align-items: center; + gap: 6px; + font-size: 10.5px; + color: #64748b; + font-family: "JetBrains Mono", ui-monospace, monospace; + } + + .chat-composer__kbd { + display: inline-flex; + align-items: center; + justify-content: center; + min-width: 18px; + height: 18px; + padding: 0 4px; + border-radius: 4px; + background: rgba(255, 255, 255, 0.05); + border: 1px solid rgba(255, 255, 255, 0.08); + font-size: 10px; + color: #cbd5e1; + } + + .chat-composer__btn { + display: inline-flex; + align-items: center; + gap: 6px; + padding: 7px 14px; + border-radius: 10px; + font-weight: 600; + font-size: 13.5px; + font-family: "Outfit", system-ui, sans-serif; + letter-spacing: 0.01em; + transition: transform 0.15s ease, background 0.15s ease; + } + + .chat-composer__btn--send { + background: linear-gradient(180deg, #fff, #e2e8f0); + color: #030712; + } + + .chat-composer__btn--send:hover:not(:disabled) { + transform: translateY(-1px); + background: #fff; + } + + .chat-composer__btn--send:disabled { + opacity: 0.45; + cursor: not-allowed; + } + + .chat-composer__btn--stop { + background: rgba(248, 113, 113, 0.08); + border: 1px solid rgba(248, 113, 113, 0.25); + color: #fca5a5; + } + + .chat-composer__btn--stop::before { + content: ""; + width: 8px; + height: 8px; + background: #fca5a5; + border-radius: 2px; + } + + /* Jump-to-latest pill */ + + .chat-jump { + position: absolute; + left: 50%; + bottom: calc(100% + 12px); + transform: translateX(-50%); + padding: 4px 12px; + background: rgba(20, 22, 40, 0.85); + border: 1px solid rgba(167, 139, 250, 0.25); + border-radius: 999px; + color: #c4b5fd; + font-size: 12px; + font-family: "JetBrains Mono", ui-monospace, monospace; + cursor: pointer; + backdrop-filter: blur(8px); + animation: fade-up 0.3s ease-out both; + } + + /* Responsive summary strip */ + + .chat-summary { + display: flex; + align-items: center; + gap: 10px; + flex-wrap: wrap; + padding: 10px 0; + margin-bottom: 12px; + border-bottom: 1px solid rgba(255, 255, 255, 0.04); + font-family: "JetBrains Mono", ui-monospace, monospace; + font-size: 11.5px; + color: #94a3b8; + } + + @media (min-width: 1100px) { + .chat-summary { + display: none; + } + } + + .chat-summary__sep { + color: #334155; + } +} diff --git a/pyproject.toml b/pyproject.toml index a96db338f..aecc9b9c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,8 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ + "ag-ui-langgraph==0.0.34", + "copilotkit==0.1.86", "croniter==6.2.2", "ddgs==9.14.1", "deepagents==0.5.1", diff --git a/tests/unit_tests/activity/test_batch_submit.py b/tests/unit_tests/activity/test_batch_submit.py index 251fc9d59..15f605733 100644 --- a/tests/unit_tests/activity/test_batch_submit.py +++ b/tests/unit_tests/activity/test_batch_submit.py @@ -60,7 +60,13 @@ def test_single_repo_creates_one_activity_with_batch_id(self, member_user): assert result.activities[0].batch_id == result.batch_id assert result.activities[0].repo_id == "a/b" assert result.activities[0].trigger_type == TriggerType.UI_JOB - m_task.aenqueue.assert_awaited_once_with(repo_id="a/b", prompt="do it", ref=None, use_max=False) + m_task.aenqueue.assert_awaited_once() + enqueue_kwargs = m_task.aenqueue.await_args.kwargs + assert enqueue_kwargs["repo_id"] == "a/b" + assert enqueue_kwargs["prompt"] == "do it" + assert enqueue_kwargs["ref"] is None + assert enqueue_kwargs["use_max"] is False + assert enqueue_kwargs["thread_id"] == result.activities[0].thread_id def test_five_repos_creates_five_activities_sharing_batch_id(self, member_user): tasks_seen = [] @@ -86,6 +92,12 @@ async def _aenqueue(**kwargs): assert [t["repo_id"] for t in tasks_seen] == [f"o/r{i}" for i in range(5)] assert tasks_seen[0]["ref"] is None # empty ref threads as None assert tasks_seen[1]["ref"] == "dev" + # Each activity gets a distinct thread_id that matches the one passed to the task. + activity_thread_ids = [a.thread_id for a in result.activities] + assert all(activity_thread_ids) + assert len(set(activity_thread_ids)) == 5 + task_thread_ids = [t["thread_id"] for t in tasks_seen] + assert set(task_thread_ids) == set(activity_thread_ids) def test_empty_repos_raises_value_error(self, member_user): with pytest.raises(ValueError): @@ -148,6 +160,7 @@ async def _aenqueue(**kwargs): class _Stub: def __init__(self, task_result_id): self.task_result_id = task_result_id + self.pk = uuid.uuid4() async def _flaky_create(**kwargs): if kwargs["repo_id"] == "o/b": diff --git a/tests/unit_tests/automation/agent/middlewares/test_git.py b/tests/unit_tests/automation/agent/middlewares/test_git.py index b00d60fc7..2b9cd9924 100644 --- a/tests/unit_tests/automation/agent/middlewares/test_git.py +++ b/tests/unit_tests/automation/agent/middlewares/test_git.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -11,6 +11,10 @@ def _make_runtime(*, scope: Scope = Scope.ISSUE) -> Mock: runtime = Mock() runtime.context = Mock() runtime.context.scope = scope + runtime.context.merge_request = None + runtime.context.repository = Mock(slug="a/b") + runtime.context.config = Mock(default_branch="main") + runtime.context.gitrepo = Mock() return runtime @@ -27,3 +31,107 @@ async def test_aafter_agent_propagates_push_permission_error(self): pytest.raises(GitPushPermissionError), ): await middleware.aafter_agent(state={"merge_request": None}, runtime=runtime) + + async def test_abefore_agent_seeds_open_mr_when_state_empty(self): + """No MR in state → seed it so it streams via STATE_SNAPSHOT.""" + middleware = GitMiddleware() + runtime = _make_runtime(scope=Scope.GLOBAL) + existing_mr = MagicMock(source_branch="feature-x") + + with ( + patch("automation.agent.middlewares.git.get_repo_ref", return_value="feature-x"), + patch( + "automation.agent.middlewares.git.GitMiddleware._alookup_open_mr", + new=AsyncMock(return_value=existing_mr), + ) as lookup, + ): + result = await middleware.abefore_agent({}, runtime) + + lookup.assert_awaited_once_with(runtime.context) + assert result == {"merge_request": existing_mr, "code_changes": False} + + async def test_abefore_agent_skips_lookup_when_state_has_mr(self): + middleware = GitMiddleware() + runtime = _make_runtime(scope=Scope.GLOBAL) + state_mr = MagicMock(source_branch="feature-x") + + with ( + patch("automation.agent.middlewares.git.get_repo_ref", return_value="feature-x"), + patch("automation.agent.middlewares.git.GitMiddleware._alookup_open_mr", new=AsyncMock()) as lookup, + ): + result = await middleware.abefore_agent({"merge_request": state_mr}, runtime) + + lookup.assert_not_called() + assert result["merge_request"] is state_mr + + async def test_abefore_agent_runtime_context_overrides_state_in_mr_scope(self): + middleware = GitMiddleware() + runtime = _make_runtime(scope=Scope.MERGE_REQUEST) + runtime.context.merge_request = MagicMock(source_branch="feature-y") + stale_state_mr = MagicMock(source_branch="feature-x") + + with ( + patch("automation.agent.middlewares.git.get_repo_ref", return_value="feature-y"), + patch("automation.agent.middlewares.git.GitMiddleware._alookup_open_mr", new=AsyncMock()) as lookup, + ): + result = await middleware.abefore_agent({"merge_request": stale_state_mr}, runtime) + + lookup.assert_not_called() + assert result["merge_request"] is runtime.context.merge_request + + async def test_alookup_open_mr_returns_platform_mr(self): + runtime = _make_runtime() + existing_mr = MagicMock(source_branch="feature-x") + client = MagicMock() + client.get_merge_request_by_branches = MagicMock(return_value=existing_mr) + + with ( + patch("automation.agent.middlewares.git.get_repo_ref", return_value="feature-x"), + patch("automation.agent.middlewares.git.RepoClient.create_instance", return_value=client), + ): + mr = await GitMiddleware._alookup_open_mr(runtime.context) # noqa: SLF001 + + assert mr is existing_mr + client.get_merge_request_by_branches.assert_called_once_with("a/b", "feature-x", "main") + + async def test_alookup_open_mr_skips_default_branch(self): + runtime = _make_runtime() + + with ( + patch("automation.agent.middlewares.git.get_repo_ref", return_value="main"), + patch("automation.agent.middlewares.git.RepoClient.create_instance") as create, + ): + mr = await GitMiddleware._alookup_open_mr(runtime.context) # noqa: SLF001 + + assert mr is None + create.assert_not_called() + + async def test_alookup_open_mr_swallows_platform_errors(self): + from gitlab.exceptions import GitlabError + + runtime = _make_runtime() + client = MagicMock() + client.get_merge_request_by_branches = MagicMock(side_effect=GitlabError("gitlab down")) + + with ( + patch("automation.agent.middlewares.git.get_repo_ref", return_value="feature-x"), + patch("automation.agent.middlewares.git.RepoClient.create_instance", return_value=client), + ): + mr = await GitMiddleware._alookup_open_mr(runtime.context) # noqa: SLF001 + + assert mr is None + + async def test_alookup_open_mr_propagates_unexpected_errors(self): + """Bugs (KeyError/AttributeError) must NOT be silently caught — the + publisher would then create a duplicate MR thinking none existed. + """ + runtime = _make_runtime() + client = MagicMock() + client.get_merge_request_by_branches = MagicMock(side_effect=KeyError("missing field")) + + with ( + patch("automation.agent.middlewares.git.get_repo_ref", return_value="feature-x"), + patch("automation.agent.middlewares.git.RepoClient.create_instance", return_value=client), + pytest.raises(KeyError), + ): + await GitMiddleware._alookup_open_mr(runtime.context) # noqa: SLF001 diff --git a/tests/unit_tests/automation/agent/middlewares/test_sandbox.py b/tests/unit_tests/automation/agent/middlewares/test_sandbox.py index b5727adc6..691250a26 100644 --- a/tests/unit_tests/automation/agent/middlewares/test_sandbox.py +++ b/tests/unit_tests/automation/agent/middlewares/test_sandbox.py @@ -86,7 +86,12 @@ async def test_bash_tool_applies_patch_and_returns_results_json(self, tmp_path: output = await bash_tool.coroutine(command="echo ok", runtime=runtime) assert file_path.read_text() == "new\n" - assert output == "[]" + import json as _json + + payload = _json.loads(output) + assert payload["commands"] == [] + # The patch just edits hello.txt — nothing added/deleted/renamed. + assert payload["files_changed"] == [{"path": "hello.txt", "op": "modified"}] async def test_bash_tool_returns_error_when_sandbox_call_fails(self, tmp_path: Path): repo_dir = tmp_path / "repoX" diff --git a/tests/unit_tests/automation/agent/middlewares/test_web_fetch.py b/tests/unit_tests/automation/agent/middlewares/test_web_fetch.py index 0cfd5e466..dcdbec6d4 100644 --- a/tests/unit_tests/automation/agent/middlewares/test_web_fetch.py +++ b/tests/unit_tests/automation/agent/middlewares/test_web_fetch.py @@ -206,7 +206,7 @@ async def test_cache_key_changes_with_prompt(httpx_mock): async def test_invalid_url_returns_message(): result = await web_fetch_module.web_fetch_tool.ainvoke({"url": "not-a-url", "prompt": "x"}) - assert result == "Invalid URL. Provide a fully-formed http(s) URL (e.g., https://example.com)." + assert result == "error: Invalid URL. Provide a fully-formed http(s) URL (e.g., https://example.com)." async def test_empty_prompt_returns_contents(httpx_mock): @@ -245,7 +245,7 @@ async def test_rejects_large_content(httpx_mock): mock_site_settings.web_fetch_max_content_chars = 5 result = await web_fetch_module.web_fetch_tool.ainvoke({"url": "https://example.com", "prompt": "x"}) - assert "Page content is too large to safely analyze in one pass." in result + assert result.startswith("error: Page content is too large to safely analyze in one pass.") async def test_get_auth_headers_exact_domain_match(): diff --git a/tests/unit_tests/automation/agent/middlewares/test_web_search.py b/tests/unit_tests/automation/agent/middlewares/test_web_search.py index 6018a9c8c..f98476fcc 100644 --- a/tests/unit_tests/automation/agent/middlewares/test_web_search.py +++ b/tests/unit_tests/automation/agent/middlewares/test_web_search.py @@ -1,3 +1,4 @@ +import json from datetime import datetime from unittest.mock import AsyncMock, MagicMock, patch @@ -19,9 +20,8 @@ async def test_successful_search_duckduckgo(self, mock_wrapper_class, mock_setti result = await web_search_tool.ainvoke({"query": "test query"}) - assert "Test title" in result - assert "Test content" in result - assert "https://example.com" in result + parsed = json.loads(result) + assert parsed == [{"title": "Test title", "link": "https://example.com", "content": "Test content"}] mock_wrapper.results.assert_called_once_with("test query", max_results=5) @patch("automation.agent.middlewares.web_search.site_settings") @@ -40,10 +40,14 @@ async def test_successful_search_tavily(self, mock_wrapper_class, mock_settings) result = await web_search_tool.ainvoke({"query": "test query"}) - assert "Test tavily answer" in result - assert "Test tavily content" in result - assert "Test tavily title" in result - assert "https://example.com" in result + parsed = json.loads(result) + # Tavily's synthesized answer is prepended with link="" so the model can tell it apart from citable hits. + assert parsed[0] == {"title": "Suggested answer", "link": "", "content": "Test tavily answer"} + assert parsed[1] == { + "title": "Test tavily title", + "link": "https://example.com", + "content": "Test tavily content", + } mock_wrapper.raw_results_async.assert_called_once_with("test query", max_results=5, include_answer=True) @patch("automation.agent.middlewares.web_search.site_settings") @@ -58,7 +62,7 @@ async def test_no_results_duckduckgo(self, mock_wrapper_class, mock_settings): result = await web_search_tool.ainvoke({"query": "test query"}) - assert "No relevant results found" in result + assert json.loads(result) == [] mock_wrapper.results.assert_called_once_with("test query", max_results=5) @patch("automation.agent.middlewares.web_search.site_settings") @@ -73,7 +77,28 @@ async def test_no_results_tavily(self, mock_wrapper_class, mock_settings): result = await web_search_tool.ainvoke({"query": "test query"}) - assert "No relevant results found" in result + assert json.loads(result) == [] + + @patch("automation.agent.middlewares.web_search.site_settings") + @patch("automation.agent.middlewares.web_search.DuckDuckGoSearchAPIWrapper") + async def test_special_chars_survive_json_roundtrip(self, mock_wrapper_class, mock_settings): + # JSON encoding handles `"`, `&`, etc. natively — guard against future regressions + # if anyone reintroduces ad-hoc string formatting. + mock_settings.web_search_engine = "duckduckgo" + mock_settings.web_search_max_results = 5 + + mock_wrapper = MagicMock() + mock_wrapper.results.return_value = [ + {"title": 'Why "X" beats Y & Z', "link": "https://example.com/q?a=1&b=2", "snippet": "body with "} + ] + mock_wrapper_class.return_value = mock_wrapper + + result = await web_search_tool.ainvoke({"query": "q"}) + + parsed = json.loads(result) + assert parsed[0]["title"] == 'Why "X" beats Y & Z' + assert parsed[0]["link"] == "https://example.com/q?a=1&b=2" + assert parsed[0]["content"] == "body with " @patch("automation.agent.middlewares.web_search.site_settings") async def test_invalid_search_engine(self, mock_settings): @@ -102,17 +127,15 @@ async def test_multiple_results_formatting(self, mock_wrapper_class, mock_settin result = await web_search_tool.ainvoke({"query": "test query"}) - assert "First title" in result - assert "First result" in result - assert "https://example.com/first" in result - assert "Second title" in result - assert "Second result" in result - assert "https://example.com/second" in result - assert "Third title" in result - assert "Third result" in result - assert "https://example.com/third" in result - assert result.count("") == 3 + parsed = json.loads(result) + assert len(parsed) == 3 + assert [r["title"] for r in parsed] == ["First title", "Second title", "Third title"] + assert [r["link"] for r in parsed] == [ + "https://example.com/first", + "https://example.com/second", + "https://example.com/third", + ] + assert [r["content"] for r in parsed] == ["First result", "Second result", "Third result"] mock_wrapper.results.assert_called_once_with("test query", max_results=5) diff --git a/tests/unit_tests/automation/titling/__init__.py b/tests/unit_tests/automation/titling/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_tests/automation/titling/test_tasks.py b/tests/unit_tests/automation/titling/test_tasks.py new file mode 100644 index 000000000..c3f7ec19e --- /dev/null +++ b/tests/unit_tests/automation/titling/test_tasks.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from activity.models import Activity, TriggerType + +from automation.titling import tasks as titling_tasks +from automation.titling.tasks import GeneratedTitle, _ref_is_informative, generate_title_task + + +@pytest.mark.parametrize( + ("ref", "expected"), + [ + ("", False), + ("main", False), + ("MAIN", False), + ("master", False), + ("develop", False), + ("prod", False), + ("Production", False), + ("a1b2c3d", False), # 7-char SHA + ("DEADBEEFCAFE1234567890ABCDEF1234567890AB", False), # 40-char SHA, mixed case + ("feat/copilotkit-chat", True), + ("bugfix-123", True), + ("a1b2c3", True), # 6 chars — too short for SHA pattern + ("g1h2i3j4", True), # contains non-hex chars + ("release/2026-04", True), + ], +) +def test_ref_is_informative(ref: str, expected: bool): + assert _ref_is_informative(ref) is expected + + +def _fake_chain(title: str = "Generated test title", capture: dict | None = None): + """Build a Mock that mimics ``llm.with_structured_output(...).with_retry(...).with_fallbacks(...)``.""" + chain = MagicMock() + chain.with_structured_output.return_value = chain + chain.with_retry.return_value = chain + chain.with_fallbacks.return_value = chain + chain.with_config.return_value = chain + + def _invoke(messages): + if capture is not None: + capture["messages"] = messages + return GeneratedTitle(title=title) + + chain.invoke.side_effect = _invoke + return chain + + +@pytest.mark.django_db +class TestGenerateTitleTask: + def _make_activity(self, *, title: str = "") -> Activity: + return Activity.objects.create(trigger_type=TriggerType.API_JOB, repo_id="group/repo", title=title) + + def test_returns_silently_when_entity_missing(self): + with patch.object(titling_tasks.BaseAgent, "get_model") as get_model: + generate_title_task.func( + entity_type="activity", pk="00000000-0000-0000-0000-000000000000", prompt="any", repo_id="x/y" + ) + get_model.assert_not_called() + + def test_overwrites_existing_heuristic_title(self): + """Heuristic titles set at creation are placeholders; the LLM-generated + title must overwrite them. (No user-facing edit endpoint exists, so no + need to protect manual edits.) + """ + activity = self._make_activity(title="Heuristic placeholder") + with patch.object(titling_tasks.BaseAgent, "get_model", return_value=_fake_chain(title="LLM generated")): + generate_title_task.func(entity_type="activity", pk=str(activity.pk), prompt="any", repo_id="group/repo") + activity.refresh_from_db() + assert activity.title == "LLM generated" + + def test_returns_when_model_not_configured(self): + activity = self._make_activity() + with patch.object(titling_tasks.BaseAgent, "get_model", side_effect=RuntimeError("no key")): + generate_title_task.func(entity_type="activity", pk=str(activity.pk), prompt="any", repo_id="group/repo") + activity.refresh_from_db() + assert activity.title == "" + + def test_writes_generated_title(self): + activity = self._make_activity() + with patch.object(titling_tasks.BaseAgent, "get_model", return_value=_fake_chain(title="Add login feature")): + generate_title_task.func( + entity_type="activity", pk=str(activity.pk), prompt="add login", repo_id="group/repo" + ) + activity.refresh_from_db() + assert activity.title == "Add login feature" + + def test_user_text_includes_branch_when_informative(self): + activity = self._make_activity() + capture: dict = {} + with patch.object(titling_tasks.BaseAgent, "get_model", return_value=_fake_chain(capture=capture)): + generate_title_task.func( + entity_type="activity", + pk=str(activity.pk), + prompt="add login", + repo_id="group/repo", + ref="feat/copilotkit-chat", + ) + human_text = capture["messages"][-1].content + assert "Repository: group/repo" in human_text + assert "Branch: feat/copilotkit-chat" in human_text + assert "Task: add login" in human_text + + def test_user_text_omits_branch_for_generic_ref(self): + activity = self._make_activity() + capture: dict = {} + with patch.object(titling_tasks.BaseAgent, "get_model", return_value=_fake_chain(capture=capture)): + generate_title_task.func( + entity_type="activity", pk=str(activity.pk), prompt="add login", repo_id="group/repo", ref="main" + ) + human_text = capture["messages"][-1].content + assert "Branch:" not in human_text + + def test_prompt_truncated_to_500_chars(self): + activity = self._make_activity() + capture: dict = {} + long_prompt = "x" * 1000 + with patch.object(titling_tasks.BaseAgent, "get_model", return_value=_fake_chain(capture=capture)): + generate_title_task.func( + entity_type="activity", pk=str(activity.pk), prompt=long_prompt, repo_id="group/repo" + ) + human_text = capture["messages"][-1].content + assert human_text.endswith("x" * 500) + assert "x" * 501 not in human_text diff --git a/tests/unit_tests/chat/api/test_event_filter.py b/tests/unit_tests/chat/api/test_event_filter.py new file mode 100644 index 000000000..4c626eb67 --- /dev/null +++ b/tests/unit_tests/chat/api/test_event_filter.py @@ -0,0 +1,475 @@ +"""Tests for ``SubagentEventFilter``. + +Mirrors the live ag_ui_langgraph stream in three respects: top-level events +carry an empty/single-segment ``langgraph_checkpoint_ns``; nested subagent +events carry a ``":UUID|:UUID"`` ns; and the parent's ``task`` +TOOL_CALL_START is delivered late (via the OnToolEnd re-emit) *after* the +subagent has already streamed events. +""" + +from ag_ui.core.events import ( + EventType, + StateSnapshotEvent, + TextMessageContentEvent, + TextMessageStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallStartEvent, +) + +from chat.api.event_filter import SubagentEventFilter + + +def _ev(klass, *, ns: str = "", chunk: dict | None = None, **kwargs): + raw: dict = {"metadata": {"langgraph_checkpoint_ns": ns}} + if chunk is not None: + raw["data"] = {"chunk": chunk} + return klass(raw_event=raw, **kwargs) + + +async def _drain(stream): + return [ev async for ev in stream] + + +async def _aiter(items): + for it in items: + yield it + + +def _filter(events): + return SubagentEventFilter().apply(_aiter(events)) + + +async def test_filter_drops_nested_subagent_events(): + events = [ + _ev(TextMessageStartEvent, ns="model:abc", message_id="m1", role="assistant"), + _ev(TextMessageContentEvent, ns="model:abc", message_id="m1", delta="hi"), + _ev(TextMessageStartEvent, ns="tools:par|model:sub", message_id="m2", role="assistant"), + _ev(TextMessageContentEvent, ns="tools:par|model:sub", message_id="m2", delta="bleed"), + ] + out = await _drain(_filter(events)) + # Only the two top-level (ns without ``|``) events survive. + assert [e.type for e in out] == [EventType.TEXT_MESSAGE_START, EventType.TEXT_MESSAGE_CONTENT] + assert all("|" not in (e.raw_event or {}).get("metadata", {}).get("langgraph_checkpoint_ns", "") for e in out) + + +async def test_filter_synthesizes_task_start_before_first_nested_event(): + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={ + "messages": [ + {"type": "human", "content": "go"}, + { + "type": "ai", + "content": "Now launching subagent.", + "tool_calls": [{"id": "tc-task-1", "name": "task", "args": {"subagent_type": "explore"}}], + }, + ] + }, + ) + nested = _ev(TextMessageStartEvent, ns="tools:par|model:sub", message_id="m1", role="assistant") + out = await _drain(_filter([snapshot, nested])) + types = [e.type for e in out] + assert types == [ + EventType.STATE_SNAPSHOT, + EventType.TOOL_CALL_START, + EventType.TOOL_CALL_ARGS, + EventType.TOOL_CALL_END, + ] + start = out[1] + assert start.tool_call_id == "tc-task-1" + assert start.tool_call_name == "task" + args = out[2] + assert args.tool_call_id == "tc-task-1" + assert "explore" in args.delta + + +async def test_filter_drops_late_reemit_for_synthesized_task(): + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={ + "messages": [ + {"type": "ai", "content": "...", "tool_calls": [{"id": "tc-task-1", "name": "task", "args": {"x": 1}}]} + ] + }, + ) + nested = _ev(TextMessageStartEvent, ns="tools:par|model:sub", message_id="m1", role="assistant") + # The OnToolEnd re-emit re-issues START/ARGS/END for the *same* tool_call_id + # at the top level, after the subagent has finished. + late_start = _ev(ToolCallStartEvent, ns="tools:par", tool_call_id="tc-task-1", tool_call_name="task") + late_args = _ev(ToolCallArgsEvent, ns="tools:par", tool_call_id="tc-task-1", delta='{"x":1}') + late_end = _ev(ToolCallEndEvent, ns="tools:par", tool_call_id="tc-task-1") + result = _ev(ToolCallResultEvent, ns="", tool_call_id="tc-task-1", message_id="r1", content="done") + + out = await _drain(_filter([snapshot, nested, late_start, late_args, late_end, result])) + types = [e.type for e in out] + # Synthetic START/ARGS/END (3) + STATE_SNAPSHOT + RESULT — no late re-emit. + assert types == [ + EventType.STATE_SNAPSHOT, + EventType.TOOL_CALL_START, + EventType.TOOL_CALL_ARGS, + EventType.TOOL_CALL_END, + EventType.TOOL_CALL_RESULT, + ] + + +async def test_filter_does_not_synthesize_for_already_emitted_task(): + snap1 = _ev( + StateSnapshotEvent, + ns="", + snapshot={"messages": [{"type": "ai", "tool_calls": [{"id": "tc-1", "name": "task", "args": {}}]}]}, + ) + nested1 = _ev(TextMessageStartEvent, ns="tools:p|model:s", message_id="m", role="assistant") + # A *second* top-level snapshot rebroadcasts the same already-streamed + # tool_call. We must not re-emit synthetic events for it. + snap2 = _ev( + StateSnapshotEvent, + ns="", + snapshot={"messages": [{"type": "ai", "tool_calls": [{"id": "tc-1", "name": "task", "args": {}}]}]}, + ) + nested2 = _ev(TextMessageStartEvent, ns="tools:p|model:s", message_id="m2", role="assistant") + out = await _drain(_filter([snap1, nested1, snap2, nested2])) + starts = [e for e in out if e.type == EventType.TOOL_CALL_START] + assert len(starts) == 1 + + +async def test_filter_passes_non_task_tool_calls_through(): + # A plain ``read_file`` tool call must surface unchanged — only ``task`` + # tool_calls are intercepted by the synthesize/dedupe path. + events = [ + _ev(ToolCallStartEvent, ns="tools:p", tool_call_id="tc-rf", tool_call_name="read_file"), + _ev(ToolCallArgsEvent, ns="tools:p", tool_call_id="tc-rf", delta='{"path":"/a"}'), + _ev(ToolCallEndEvent, ns="tools:p", tool_call_id="tc-rf"), + _ev(ToolCallResultEvent, ns="", tool_call_id="tc-rf", message_id="r", content="contents"), + ] + out = await _drain(_filter(events)) + assert [e.type for e in out] == [ + EventType.TOOL_CALL_START, + EventType.TOOL_CALL_ARGS, + EventType.TOOL_CALL_END, + EventType.TOOL_CALL_RESULT, + ] + + +async def test_filter_handles_parallel_task_tool_calls(): + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={ + "messages": [ + { + "type": "ai", + "tool_calls": [ + {"id": "tc-a", "name": "task", "args": {"subagent_type": "explore"}}, + {"id": "tc-b", "name": "task", "args": {"subagent_type": "general-purpose"}}, + ], + } + ] + }, + ) + nested = _ev(TextMessageStartEvent, ns="tools:p|model:s", message_id="m", role="assistant") + out = await _drain(_filter([snapshot, nested])) + starts = [e for e in out if e.type == EventType.TOOL_CALL_START] + assert {s.tool_call_id for s in starts} == {"tc-a", "tc-b"} + + +async def test_filter_skips_args_event_when_args_are_empty(): + # An empty-dict args field should not produce a useless TOOL_CALL_ARGS + # frame carrying ``"{}"`` — the chat UI would render it as a stray empty + # delta. Tested for both ``{}`` and ``""``. + for empty_args in ({}, ""): + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={"messages": [{"type": "ai", "tool_calls": [{"id": "tc-x", "name": "task", "args": empty_args}]}]}, + ) + nested = _ev(TextMessageStartEvent, ns="tools:p|model:s", message_id="m", role="assistant") + out = await _drain(_filter([snapshot, nested])) + assert [e.type for e in out] == [EventType.STATE_SNAPSHOT, EventType.TOOL_CALL_START, EventType.TOOL_CALL_END] + + +async def test_filter_handles_langchain_message_objects_in_snapshot(): + # In the live ag_ui_langgraph stream the snapshot's ``messages`` arrive as + # LangChain BaseMessage instances, not dicts (the AGUI encoder serializes + # them later, but our filter sits in front of the encoder). + from langchain_core.messages import AIMessage, HumanMessage + + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={ + "messages": [ + HumanMessage(content="hi", id="u1"), + AIMessage( + content="launching", id="m1", tool_calls=[{"id": "tc-task-1", "name": "task", "args": {"x": 1}}] + ), + ] + }, + ) + nested = _ev(TextMessageStartEvent, ns="tools:p|model:s", message_id="m", role="assistant") + out = await _drain(_filter([snapshot, nested])) + starts = [e for e in out if e.type == EventType.TOOL_CALL_START] + assert [s.tool_call_id for s in starts] == ["tc-task-1"] + + +async def test_filter_passes_late_reemit_through_when_no_nested_event_fired(): + # If a subagent returns immediately (or upstream skips emitting nested + # events for some other reason), synthesis never fires. The late + # OnToolEnd re-emit then becomes the user-visible TOOL_CALL_START — it + # must NOT be deduped because nothing was synthesized for that tcid. + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={"messages": [{"type": "ai", "tool_calls": [{"id": "tc-fast", "name": "task", "args": {"x": 1}}]}]}, + ) + late_start = _ev(ToolCallStartEvent, ns="tools:p", tool_call_id="tc-fast", tool_call_name="task") + late_args = _ev(ToolCallArgsEvent, ns="tools:p", tool_call_id="tc-fast", delta='{"x":1}') + late_end = _ev(ToolCallEndEvent, ns="tools:p", tool_call_id="tc-fast") + result = _ev(ToolCallResultEvent, ns="", tool_call_id="tc-fast", message_id="r", content="done") + out = await _drain(_filter([snapshot, late_start, late_args, late_end, result])) + assert [e.type for e in out] == [ + EventType.STATE_SNAPSHOT, + EventType.TOOL_CALL_START, + EventType.TOOL_CALL_ARGS, + EventType.TOOL_CALL_END, + EventType.TOOL_CALL_RESULT, + ] + + +async def test_filter_passes_tool_call_result_for_synthesized_task(): + # Isolated regression test for the docstring's "RESULT flows through + # untouched" claim. Locks in that the dedup logic only targets the + # START/ARGS/END trio, never RESULT. + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={"messages": [{"type": "ai", "tool_calls": [{"id": "tc-r", "name": "task", "args": {}}]}]}, + ) + nested = _ev(TextMessageStartEvent, ns="tools:p|model:s", message_id="m", role="assistant") + result = _ev(ToolCallResultEvent, ns="", tool_call_id="tc-r", message_id="r", content="payload") + out = await _drain(_filter([snapshot, nested, result])) + [tcr] = [e for e in out if e.type == EventType.TOOL_CALL_RESULT] + assert tcr.tool_call_id == "tc-r" + assert tcr.content == "payload" + + +async def test_filter_synthesizes_args_string_passthrough_without_reencoding(): + # ``args`` arrives as a partial JSON string when the parent's chat model + # streams the tool_call mid-flight. Re-encoding via ``json.dumps`` would + # double-quote it; the filter must pass strings through verbatim. + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={"messages": [{"type": "ai", "tool_calls": [{"id": "tc-s", "name": "task", "args": '{"q":"x"}'}]}]}, + ) + nested = _ev(TextMessageStartEvent, ns="tools:p|model:s", message_id="m", role="assistant") + out = await _drain(_filter([snapshot, nested])) + [args_ev] = [e for e in out if e.type == EventType.TOOL_CALL_ARGS] + assert args_ev.delta == '{"q":"x"}' + + +async def test_filter_only_synthesizes_latest_ai_message_tool_calls(): + # Older AI messages in the snapshot have already been streamed (their + # tool_calls are in ``task_calls`` from a prior snapshot). The filter + # must only consider the latest AIMessage so a snapshot reissuing old + # history doesn't re-trigger synthesis for already-completed tasks. + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={ + "messages": [ + {"type": "ai", "tool_calls": [{"id": "tc-old", "name": "task", "args": {}}]}, + {"type": "tool", "tool_call_id": "tc-old", "content": "done"}, + {"type": "ai", "tool_calls": [{"id": "tc-new", "name": "task", "args": {}}]}, + ] + }, + ) + nested = _ev(TextMessageStartEvent, ns="tools:p|model:s", message_id="m", role="assistant") + out = await _drain(_filter([snapshot, nested])) + starts = [e for e in out if e.type == EventType.TOOL_CALL_START] + assert [s.tool_call_id for s in starts] == ["tc-new"] + + +async def test_filter_drops_misrouted_args_for_sibling_tool_calls(): + # ag_ui_langgraph misroutes streamed arg deltas: when the LLM moves on + # from the first tool_call to a sibling, subsequent ARGS chunks still + # carry the *first* tool's id but the underlying chunk's + # ``tool_call_chunks[0].index`` points at the sibling. Drop those so the + # first tool's args don't get a concatenated JSON blob; the sibling is + # recovered via STATE_SNAPSHOT synthesis. + natural_start = _ev( + ToolCallStartEvent, + ns="model:m", + chunk={"tool_call_chunks": [{"index": 0, "id": "tc1", "name": "read_file", "args": ""}]}, + tool_call_id="tc1", + tool_call_name="read_file", + ) + own_arg = _ev( + ToolCallArgsEvent, + ns="model:m", + chunk={"tool_call_chunks": [{"index": 0, "args": '{"path":"a.py"}'}]}, + tool_call_id="tc1", + delta='{"path":"a.py"}', + ) + sibling_arg = _ev( + ToolCallArgsEvent, + ns="model:m", + # chunk index 1 → belongs to the second sibling tool_call; the event's + # tool_call_id is wrong (still tc1) because ag_ui_langgraph reuses + # current_stream's id. + chunk={"tool_call_chunks": [{"index": 1, "args": '{"path":"b.py"}'}]}, + tool_call_id="tc1", + delta='{"path":"b.py"}', + ) + out = await _drain(_filter([natural_start, own_arg, sibling_arg])) + assert [e.type for e in out] == [EventType.TOOL_CALL_START, EventType.TOOL_CALL_ARGS] + assert out[1].delta == '{"path":"a.py"}' + + +async def test_filter_synthesizes_sibling_tool_calls_after_natural_first(): + # The realistic multi-tool-call sequence: tc1 streams naturally (its + # TOOL_CALL_START + correct-index ARGS pass through), tc2's args are + # misrouted-and-dropped, then STATE_SNAPSHOT arrives with both tcids. + # tc1 must NOT be re-synthesized (it's already on the wire); tc2 MUST be + # synthesized so its segment exists when its TOOL_CALL_RESULT arrives. + natural_start = _ev( + ToolCallStartEvent, + ns="model:m", + chunk={"tool_call_chunks": [{"index": 0, "id": "tc1", "name": "read_file", "args": ""}]}, + tool_call_id="tc1", + tool_call_name="read_file", + ) + own_arg = _ev( + ToolCallArgsEvent, + ns="model:m", + chunk={"tool_call_chunks": [{"index": 0, "args": '{"path":"a.py"}'}]}, + tool_call_id="tc1", + delta='{"path":"a.py"}', + ) + misrouted = _ev( + ToolCallArgsEvent, + ns="model:m", + chunk={"tool_call_chunks": [{"index": 1, "args": '{"path":"b.py"}'}]}, + tool_call_id="tc1", + delta='{"path":"b.py"}', + ) + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={ + "messages": [ + { + "type": "ai", + "tool_calls": [ + {"id": "tc1", "name": "read_file", "args": {"path": "a.py"}}, + {"id": "tc2", "name": "read_file", "args": {"path": "b.py"}}, + ], + } + ] + }, + ) + result_2 = _ev(ToolCallResultEvent, ns="", tool_call_id="tc2", message_id="r2", content="contents-b") + out = await _drain(_filter([natural_start, own_arg, misrouted, snapshot, result_2])) + + starts = [(e.tool_call_id, getattr(e, "tool_call_name", None)) for e in out if e.type == EventType.TOOL_CALL_START] + # tc1 from natural stream + tc2 synthesized — tc1 NOT duplicated. + assert starts == [("tc1", "read_file"), ("tc2", "read_file")] + args = [(e.tool_call_id, e.delta) for e in out if e.type == EventType.TOOL_CALL_ARGS] + # tc1's own (index 0) arg + synthesized tc2 args; misrouted index-1 dropped. + assert args == [("tc1", '{"path":"a.py"}'), ("tc2", '{"path": "b.py"}')] + + +async def test_filter_synthesizes_for_non_task_tool_calls_dropped_by_ag_ui(): + # When ag_ui_langgraph drops the natural TOOL_CALL_START on a + # text→tool_call transition (the same code path that drops ``task`` + # starts), we still synthesize from STATE_SNAPSHOT regardless of the + # tool name — non-``task`` calls should also get a segment so their + # RESULT can find one. + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={ + "messages": [ + { + "type": "ai", + "content": "Looking up the file.", + "tool_calls": [{"id": "tc-rf", "name": "read_file", "args": {"path": "x.py"}}], + } + ] + }, + ) + out = await _drain(_filter([snapshot])) + starts = [e for e in out if e.type == EventType.TOOL_CALL_START] + assert [(s.tool_call_id, s.tool_call_name) for s in starts] == [("tc-rf", "read_file")] + + +async def test_filter_passes_args_with_malformed_chunk_shapes_through(): + # Malformed/missing chunk shapes must not crash and must not be + # misclassified as misrouted — the event passes through unchanged so the + # natural arg streaming for the first tool_call survives upstream + # changes. + cases = [ + # No data key. + {"metadata": {"langgraph_checkpoint_ns": ""}}, + # data.chunk is None. + {"metadata": {"langgraph_checkpoint_ns": ""}, "data": {"chunk": None}}, + # tool_call_chunks missing. + {"metadata": {"langgraph_checkpoint_ns": ""}, "data": {"chunk": {}}}, + # tool_call_chunks is empty. + {"metadata": {"langgraph_checkpoint_ns": ""}, "data": {"chunk": {"tool_call_chunks": []}}}, + # First chunk lacks index. + {"metadata": {"langgraph_checkpoint_ns": ""}, "data": {"chunk": {"tool_call_chunks": [{"args": "x"}]}}}, + # First chunk has non-int index. + { + "metadata": {"langgraph_checkpoint_ns": ""}, + "data": {"chunk": {"tool_call_chunks": [{"index": "1", "args": "x"}]}}, + }, + ] + for raw in cases: + ev = ToolCallArgsEvent(raw_event=raw, tool_call_id="tc1", delta="x") + out = await _drain(_filter([ev])) + assert [e.type for e in out] == [EventType.TOOL_CALL_ARGS], f"failed for raw={raw}" + + +async def test_filter_drops_misrouted_args_for_already_synthesized_tcid(): + # If a tcid was synthesized (not naturally started) AND a misrouted ARGS + # arrives carrying it, both drop conditions hold. Lock in that the event + # is dropped exactly once — order of the two checks must not matter. + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={"messages": [{"type": "ai", "tool_calls": [{"id": "tc-s", "name": "task", "args": {"x": 1}}]}]}, + ) + misrouted = _ev( + ToolCallArgsEvent, + ns="model:m", + chunk={"tool_call_chunks": [{"index": 1, "args": "y"}]}, + tool_call_id="tc-s", + delta="y", + ) + out = await _drain(_filter([snapshot, misrouted])) + args_events = [e for e in out if e.type == EventType.TOOL_CALL_ARGS] + # Only the synthesized ARGS for tc-s — the misrouted one is dropped. + assert [a.delta for a in args_events] == ['{"x": 1}'] + + +async def test_filter_skips_synthesis_when_latest_ai_has_no_tool_calls(): + # _iter_latest_tool_calls returns after the first AI message even if its + # tool_calls is empty — older AI messages' tool_calls were already + # emitted on prior snapshots and must not re-synthesize. + snapshot = _ev( + StateSnapshotEvent, + ns="", + snapshot={ + "messages": [ + {"type": "ai", "tool_calls": [{"id": "tc-old", "name": "task", "args": {}}]}, + {"type": "tool", "tool_call_id": "tc-old", "content": "done"}, + {"type": "ai", "content": "all done", "tool_calls": []}, + ] + }, + ) + out = await _drain(_filter([snapshot])) + assert [e.type for e in out] == [EventType.STATE_SNAPSHOT] diff --git a/tests/unit_tests/chat/api/test_streaming.py b/tests/unit_tests/chat/api/test_streaming.py new file mode 100644 index 000000000..ab669fc6f --- /dev/null +++ b/tests/unit_tests/chat/api/test_streaming.py @@ -0,0 +1,255 @@ +"""Direct unit tests for ``ChatRunStreamer.events()``. + +Covers the streaming-specific behavior the HTTP-level tests in test_views.py +don't reach — most importantly the STATE_SNAPSHOT-driven ``last_mr`` capture +that keeps the composer MR pill alive across reloads, and the run-slot +lifecycle invariants. +""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from ag_ui.core.events import EventType, StateSnapshotEvent +from ag_ui.encoder import EventEncoder + +from chat.api.streaming import ChatRunStreamer + + +def _mock_ctx(*_args, **_kwargs): + """Async context manager yielding a MagicMock — stands in for ``open_checkpointer`` + / ``set_runtime_ctx`` so we don't touch Redis or clone a repo. + """ + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=MagicMock()) + ctx.__aexit__ = AsyncMock(return_value=None) + return ctx + + +def _streamer(input_data=None) -> ChatRunStreamer: + if input_data is None: + input_data = SimpleNamespace(thread_id="t-stream", run_id="r-1") + return ChatRunStreamer( + repo_id="a/b", + ref="main", + thread_id="t-stream", + run_id="r-1", + input_data=input_data, + encoder=EventEncoder(accept="text/event-stream"), + ) + + +def _mock_agent(events): + """Patch ``RuntimeContextLangGraphAGUIAgent`` so its instance's ``run()`` yields + the supplied iterable of AGUI events. + """ + + async def _run(_input): + for e in events: + yield e + + instance = MagicMock() + instance.run = _run + return instance + + +@pytest.mark.django_db(transaction=True) +async def test_events_captures_merge_request_from_state_snapshot_and_persists_ref(): + # MR carried through state survives encoder serialization as a dict; the + # capture branch in events() preserves whatever value lands in the snapshot. + mr = {"source_branch": "feature-y", "merge_request_id": 42} + snapshot = StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + raw_event={"metadata": {"langgraph_checkpoint_ns": ""}}, + snapshot={"merge_request": mr}, + ) + + persist_calls = [] + release_calls = [] + + async def _capture_persist(thread_id, original_ref, captured_mr): + persist_calls.append((thread_id, original_ref, captured_mr)) + + async def _capture_release(thread_id, run_id): + release_calls.append((thread_id, run_id)) + + with ( + patch("chat.api.streaming.open_checkpointer", _mock_ctx), + patch("chat.api.streaming.set_runtime_ctx", _mock_ctx), + patch("chat.api.streaming.create_daiv_agent", new=AsyncMock()), + patch("chat.api.streaming.RuntimeContextLangGraphAGUIAgent", return_value=_mock_agent([snapshot])), + patch("chat.api.streaming.ChatThreadService.persist_ref", side_effect=_capture_persist), + patch("chat.api.streaming.ChatThreadService.release_run", side_effect=_capture_release), + patch("chat.api.streaming.ChatThreadService.heartbeat", new=AsyncMock()), + ): + streamer = _streamer() + async for _ in streamer.events(): + pass + + assert persist_calls == [("t-stream", "main", mr)] + assert release_calls == [("t-stream", "r-1")] + + +@pytest.mark.django_db(transaction=True) +async def test_events_captures_latest_merge_request_when_multiple_snapshots(): + """Multiple snapshots arrive — last one wins. Regression for + accidentally rewriting capture as ``last_mr or ...``. + """ + mr_first = {"source_branch": "feature-x"} + mr_last = {"source_branch": "feature-final"} + snap_first = StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + raw_event={"metadata": {"langgraph_checkpoint_ns": ""}}, + snapshot={"merge_request": mr_first}, + ) + snap_last = StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + raw_event={"metadata": {"langgraph_checkpoint_ns": ""}}, + snapshot={"merge_request": mr_last}, + ) + + persist_calls = [] + + async def _capture_persist(thread_id, original_ref, captured_mr): + persist_calls.append((thread_id, original_ref, captured_mr)) + + with ( + patch("chat.api.streaming.open_checkpointer", _mock_ctx), + patch("chat.api.streaming.set_runtime_ctx", _mock_ctx), + patch("chat.api.streaming.create_daiv_agent", new=AsyncMock()), + patch("chat.api.streaming.RuntimeContextLangGraphAGUIAgent", return_value=_mock_agent([snap_first, snap_last])), + patch("chat.api.streaming.ChatThreadService.persist_ref", side_effect=_capture_persist), + patch("chat.api.streaming.ChatThreadService.release_run", new=AsyncMock()), + patch("chat.api.streaming.ChatThreadService.heartbeat", new=AsyncMock()), + ): + async for _ in _streamer().events(): + pass + + assert persist_calls == [("t-stream", "main", mr_last)] + + +@pytest.mark.django_db(transaction=True) +async def test_events_persists_none_when_no_state_snapshot_carries_merge_request(): + # A run that never emits a snapshot with ``merge_request`` should leave the + # thread's ref untouched. We assert this via persist_ref receiving None. + snapshot_no_mr = StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + raw_event={"metadata": {"langgraph_checkpoint_ns": ""}}, + snapshot={"messages": []}, + ) + + persist_calls = [] + + async def _capture_persist(thread_id, original_ref, captured_mr): + persist_calls.append((thread_id, original_ref, captured_mr)) + + with ( + patch("chat.api.streaming.open_checkpointer", _mock_ctx), + patch("chat.api.streaming.set_runtime_ctx", _mock_ctx), + patch("chat.api.streaming.create_daiv_agent", new=AsyncMock()), + patch("chat.api.streaming.RuntimeContextLangGraphAGUIAgent", return_value=_mock_agent([snapshot_no_mr])), + patch("chat.api.streaming.ChatThreadService.persist_ref", side_effect=_capture_persist), + patch("chat.api.streaming.ChatThreadService.release_run", new=AsyncMock()), + patch("chat.api.streaming.ChatThreadService.heartbeat", new=AsyncMock()), + ): + async for _ in _streamer().events(): + pass + + assert persist_calls == [("t-stream", "main", None)] + + +@pytest.mark.django_db(transaction=True) +async def test_events_skips_persist_ref_when_run_errored(): + """A partial run must not pin ``ref`` to whatever interim branch a snapshot + captured before the failure — the user would then reload onto half-built state. + """ + interim_mr = {"source_branch": "feature-half"} + snap = StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + raw_event={"metadata": {"langgraph_checkpoint_ns": ""}}, + snapshot={"merge_request": interim_mr}, + ) + + async def _events_then_boom(): + yield snap + raise RuntimeError("kaboom") + + runner = MagicMock() + runner.run = lambda _input: _events_then_boom() + + persist_calls: list = [] + release_calls: list = [] + + async def _capture_persist(*args): + persist_calls.append(args) + + async def _capture_release(*args): + release_calls.append(args) + + with ( + patch("chat.api.streaming.open_checkpointer", _mock_ctx), + patch("chat.api.streaming.set_runtime_ctx", _mock_ctx), + patch("chat.api.streaming.create_daiv_agent", new=AsyncMock()), + patch("chat.api.streaming.RuntimeContextLangGraphAGUIAgent", return_value=runner), + patch("chat.api.streaming.ChatThreadService.persist_ref", side_effect=_capture_persist), + patch("chat.api.streaming.ChatThreadService.release_run", side_effect=_capture_release), + patch("chat.api.streaming.ChatThreadService.heartbeat", new=AsyncMock()), + ): + async for _ in _streamer().events(): + pass + + assert persist_calls == [] + # Release still fires regardless of outcome — that's the slot-leak guard. + assert release_calls == [("t-stream", "r-1")] + + +@pytest.mark.django_db(transaction=True) +async def test_events_releases_run_even_when_persist_ref_raises(): + # Regression: a DB hiccup in persist_ref must not leave the per-thread + # slot permanently claimed. + release_calls = [] + + async def _persist_boom(*_a, **_kw): + raise RuntimeError("db down") + + async def _capture_release(thread_id, run_id): + release_calls.append((thread_id, run_id)) + + with ( + patch("chat.api.streaming.open_checkpointer", _mock_ctx), + patch("chat.api.streaming.set_runtime_ctx", _mock_ctx), + patch("chat.api.streaming.create_daiv_agent", new=AsyncMock()), + patch("chat.api.streaming.RuntimeContextLangGraphAGUIAgent", return_value=_mock_agent([])), + patch("chat.api.streaming.ChatThreadService.persist_ref", side_effect=_persist_boom), + patch("chat.api.streaming.ChatThreadService.release_run", side_effect=_capture_release), + patch("chat.api.streaming.ChatThreadService.heartbeat", new=AsyncMock()), + ): + async for _ in _streamer().events(): + pass + + assert release_calls == [("t-stream", "r-1")] + + +def test_streamer_post_init_rejects_thread_id_mismatch(): + """Construction-time guard: thread_id/run_id must match input_data.""" + with pytest.raises(ValueError, match="thread_id mismatch"): + ChatRunStreamer( + repo_id="a/b", + ref="main", + thread_id="t-foo", + run_id="r-1", + input_data=SimpleNamespace(thread_id="t-bar", run_id="r-1"), + encoder=EventEncoder(accept="text/event-stream"), + ) + + +def test_streamer_post_init_rejects_run_id_mismatch(): + with pytest.raises(ValueError, match="run_id mismatch"): + ChatRunStreamer( + repo_id="a/b", + ref="main", + thread_id="t", + run_id="r-1", + input_data=SimpleNamespace(thread_id="t", run_id="r-2"), + encoder=EventEncoder(accept="text/event-stream"), + ) diff --git a/tests/unit_tests/chat/api/test_threads.py b/tests/unit_tests/chat/api/test_threads.py new file mode 100644 index 000000000..f8aed8c8e --- /dev/null +++ b/tests/unit_tests/chat/api/test_threads.py @@ -0,0 +1,185 @@ +import asyncio +from datetime import timedelta +from types import SimpleNamespace +from unittest.mock import patch + +from django.utils import timezone + +import pytest + +from accounts.models import User +from chat.api.threads import STALE_RUN_MINUTES, ChatThreadService, _extract_first_user_message +from chat.models import ChatThread + + +def _fake_input(messages, *, role="user"): + return SimpleNamespace(messages=[SimpleNamespace(role=role, content=c) for c in messages]) + + +def test_extract_first_user_message_empty_returns_empty_string(): + assert _extract_first_user_message(_fake_input([])) == "" + + +def test_extract_first_user_message_skips_non_string_content(): + # AG-UI supports list-of-blocks content for multimodal messages; they shouldn't be used + # as a title. We fall through to the next string-typed message. + payload = _fake_input([[{"type": "text", "text": "hi"}], "fallback title"]) + assert _extract_first_user_message(payload) == "fallback title" + + +def test_extract_first_user_message_skips_whitespace_only(): + assert _extract_first_user_message(_fake_input([" \n\t", "actual content"])) == "actual content" + + +def test_extract_first_user_message_returns_first_non_empty_string(): + assert _extract_first_user_message(_fake_input(["first", "second"])) == "first" + + +def test_extract_first_user_message_skips_non_user_roles(): + # Title should be derived from a human/user message, never from an assistant + # bootstrap message that happened to land in input_data.messages first. + msgs = SimpleNamespace( + messages=[ + SimpleNamespace(role="assistant", content="Hi! How can I help?"), + SimpleNamespace(role="user", content="actual ask"), + ] + ) + assert _extract_first_user_message(msgs) == "actual ask" + + +@pytest.mark.django_db(transaction=True) +async def test_persist_ref_updates_when_branch_changed(): + user = await User.objects.acreate_user(username="u-ref-1", email="ref1@x.com", password="x") # noqa: S106 + await ChatThread.objects.acreate(thread_id="t-ref-1", user=user, repo_id="a/b", ref="feature-x") + + await ChatThreadService.persist_ref("t-ref-1", "feature-x", SimpleNamespace(source_branch="feature-y")) + + refreshed = await ChatThread.objects.aget(thread_id="t-ref-1") + assert refreshed.ref == "feature-y" + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_persist_ref_noop_when_branch_unchanged(): + with patch("chat.api.threads.ChatThread.objects.filter") as filter_mock: + await ChatThreadService.persist_ref("t-ref-2", "feature-x", SimpleNamespace(source_branch="feature-x")) + + filter_mock.assert_not_called() + + +@pytest.mark.django_db(transaction=True) +async def test_persist_ref_noop_when_no_mr_captured(): + with patch("chat.api.threads.ChatThread.objects.filter") as filter_mock: + await ChatThreadService.persist_ref("t-ref-3", "feature-x", None) + + filter_mock.assert_not_called() + + +@pytest.mark.django_db(transaction=True) +async def test_try_claim_run_succeeds_on_free_slot(): + user = await User.objects.acreate_user(username="u-claim-1", email="c1@x.com", password="x") # noqa: S106 + await ChatThread.objects.acreate(thread_id="t-claim-1", user=user, repo_id="a/b", ref="main") + + assert await ChatThreadService.try_claim_run("t-claim-1", "r-1") is True + + refreshed = await ChatThread.objects.aget(thread_id="t-claim-1") + assert refreshed.active_run_id == "r-1" + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_try_claim_run_fails_on_held_slot(): + user = await User.objects.acreate_user(username="u-claim-2", email="c2@x.com", password="x") # noqa: S106 + await ChatThread.objects.acreate( + thread_id="t-claim-2", user=user, repo_id="a/b", ref="main", active_run_id="r-existing" + ) + + assert await ChatThreadService.try_claim_run("t-claim-2", "r-new") is False + + refreshed = await ChatThread.objects.aget(thread_id="t-claim-2") + # Loser does not overwrite the winner's run_id. + assert refreshed.active_run_id == "r-existing" + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_try_claim_run_concurrent_calls_yield_exactly_one_winner(): + # Direct regression test for the TOCTOU fix in commit dde32f93. The whole + # point of the conditional UPDATE is that two simultaneous claims can't + # both succeed; this asserts the protocol holds when the calls overlap. + user = await User.objects.acreate_user(username="u-claim-3", email="c3@x.com", password="x") # noqa: S106 + await ChatThread.objects.acreate(thread_id="t-claim-3", user=user, repo_id="a/b", ref="main") + + results = await asyncio.gather( + ChatThreadService.try_claim_run("t-claim-3", "r-A"), ChatThreadService.try_claim_run("t-claim-3", "r-B") + ) + assert sorted(results) == [False, True] + + refreshed = await ChatThread.objects.aget(thread_id="t-claim-3") + assert refreshed.active_run_id in ("r-A", "r-B") + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_release_run_clears_slot_and_reopens_for_claim(): + user = await User.objects.acreate_user(username="u-rel", email="rel@x.com", password="x") # noqa: S106 + await ChatThread.objects.acreate(thread_id="t-rel", user=user, repo_id="a/b", ref="main", active_run_id="r-old") + + await ChatThreadService.release_run("t-rel", "r-old") + refreshed = await ChatThread.objects.aget(thread_id="t-rel") + assert refreshed.active_run_id is None + + # Next claim succeeds — the slot is genuinely free, not just blanked. + assert await ChatThreadService.try_claim_run("t-rel", "r-next") is True + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_release_run_does_not_clear_other_holders_slot(): + """Stale `finally` from a cancelled run must not stomp a freshly-claimed slot.""" + user = await User.objects.acreate_user(username="u-rel-mismatch", email="chat@example.com", password="x") # noqa: S106 + await ChatThread.objects.acreate(thread_id="t-rel-x", user=user, repo_id="a/b", ref="main", active_run_id="r-fresh") + + # Stale streamer's finally tries to release with the OLD run_id. + await ChatThreadService.release_run("t-rel-x", "r-stale") + + refreshed = await ChatThread.objects.aget(thread_id="t-rel-x") + assert refreshed.active_run_id == "r-fresh" # untouched + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_try_claim_run_takes_over_stale_slot(): + """Worker crash leaves active_run_id set; after the heartbeat window expires + a fresh claim succeeds. Without this the thread would be permanently locked. + """ + user = await User.objects.acreate_user(username="u-stale", email="owner@example.com", password="x") # noqa: S106 + stale_at = timezone.now() - timedelta(minutes=STALE_RUN_MINUTES + 1) + await ChatThread.objects.acreate(thread_id="t-stale", user=user, repo_id="a/b", ref="main", active_run_id="r-dead") + # auto_now would clobber the stale timestamp — force it via aupdate. + await ChatThread.objects.filter(thread_id="t-stale").aupdate(last_active_at=stale_at) + + assert await ChatThreadService.try_claim_run("t-stale", "r-new") is True + refreshed = await ChatThread.objects.aget(thread_id="t-stale") + assert refreshed.active_run_id == "r-new" + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_heartbeat_only_bumps_when_caller_holds_slot(): + """Delayed heartbeat from a previous run must not keep a stolen slot alive.""" + user = await User.objects.acreate_user(username="u-hb", email="i@example.com", password="x") # noqa: S106 + await ChatThread.objects.acreate(thread_id="t-hb", user=user, repo_id="a/b", ref="main", active_run_id="r-current") + old_timestamp = timezone.now() - timedelta(minutes=STALE_RUN_MINUTES + 5) + await ChatThread.objects.filter(thread_id="t-hb").aupdate(last_active_at=old_timestamp) + + # Stale run heartbeats — should be a no-op because it doesn't hold the slot. + await ChatThreadService.heartbeat("t-hb", "r-stale") + refreshed = await ChatThread.objects.aget(thread_id="t-hb") + assert (timezone.now() - refreshed.last_active_at).total_seconds() > STALE_RUN_MINUTES * 60 + + # Real holder bumps successfully. + await ChatThreadService.heartbeat("t-hb", "r-current") + refreshed = await ChatThread.objects.aget(thread_id="t-hb") + assert (timezone.now() - refreshed.last_active_at).total_seconds() < 5 + await user.adelete() diff --git a/tests/unit_tests/chat/api/test_views.py b/tests/unit_tests/chat/api/test_views.py index da66c5c6e..c90090eeb 100644 --- a/tests/unit_tests/chat/api/test_views.py +++ b/tests/unit_tests/chat/api/test_views.py @@ -1,30 +1,216 @@ +from unittest.mock import AsyncMock, MagicMock, patch + import pytest from ninja.testing import TestAsyncClient -from chat.api.views import MODEL_ID +from accounts.models import APIKey, User +from chat.models import ChatThread from daiv.api import api @pytest.fixture -def client_unauthenticated(): +def client(): return TestAsyncClient(api) -@pytest.mark.django_db -async def test_create_chat_completion(client_unauthenticated: TestAsyncClient): - response = await client_unauthenticated.post( - "/chat/completions", json={"model": MODEL_ID, "messages": [{"role": "user", "content": "Hello, how are you?"}]} +@pytest.fixture +async def authed(): + """Return (APIKey, raw_key, user) for authenticated tests.""" + user = await User.objects.acreate_user( + username="chatuser", + email="chat@example.com", + password="testpass123", # noqa: S106 ) - assert response.status_code == 401 + key_obj, raw = await APIKey.objects.create_key(user=user, name="Test") + return key_obj, raw, user + + +def _auth_headers(raw_key: str, **extra) -> dict: + return {"Authorization": f"Bearer {raw_key}", **extra} + + +def _run_agent_input(**overrides) -> dict: + return { + "threadId": "t-1", + "runId": "r-1", + "state": {}, + "messages": [{"id": "m-1", "role": "user", "content": "hello"}], + "tools": [], + "context": [], + "forwardedProps": {}, + **overrides, + } + + +def _mock_stream(*_args, **_kwargs): + """Factory that returns an async context manager yielding a MagicMock. Used to patch + open_checkpointer() and set_runtime_ctx() during tests so we exercise the ownership + path without hitting Redis or cloning a repo. + """ + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=MagicMock()) + ctx.__aexit__ = AsyncMock(return_value=None) + return ctx @pytest.mark.django_db -async def test_get_models_unauthenticated(client_unauthenticated: TestAsyncClient): - response = await client_unauthenticated.get("/models") - assert response.status_code == 401 +async def test_missing_repo_id_header_returns_404(client: TestAsyncClient, authed): + _, raw, user = authed + response = await client.post( + "/chat/completions", json=_run_agent_input(), headers=_auth_headers(raw, **{"X-Ref": "main"}) + ) + assert response.status_code == 404 + await user.adelete() @pytest.mark.django_db -async def test_get_model_detail_unauthenticated(client_unauthenticated: TestAsyncClient): - response = await client_unauthenticated.get(f"/models/{MODEL_ID}") - assert response.status_code == 401 +async def test_missing_ref_header_returns_404(client: TestAsyncClient, authed): + _, raw, user = authed + response = await client.post( + "/chat/completions", json=_run_agent_input(), headers=_auth_headers(raw, **{"X-Repo-ID": "owner/repo"}) + ) + assert response.status_code == 404 + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_cross_user_thread_id_is_rejected(client: TestAsyncClient, authed): + _, raw, user = authed + other = await User.objects.acreate_user( + username="owner", + email="owner@example.com", + password="x", # noqa: S106 + ) + await ChatThread.objects.acreate(thread_id="t-owned", user=other, repo_id="a/b", ref="main") + + response = await client.post( + "/chat/completions", + json=_run_agent_input(threadId="t-owned"), + headers=_auth_headers(raw, **{"X-Repo-ID": "a/b", "X-Ref": "main"}), + ) + assert response.status_code == 403 + await user.adelete() + await other.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_unknown_thread_id_implicit_creates_thread(client: TestAsyncClient, authed): + _, raw, user = authed + with ( + patch("chat.api.streaming.open_checkpointer", _mock_stream), + patch("chat.api.streaming.set_runtime_ctx", _mock_stream), + patch("chat.api.streaming.create_daiv_agent", new=AsyncMock()), + patch("chat.api.streaming.RuntimeContextLangGraphAGUIAgent") as m_agent_cls, + ): + m_instance = MagicMock() + + async def _empty_stream(_input): + if False: # generator that yields nothing + yield + + m_instance.run = _empty_stream + m_agent_cls.return_value = m_instance + + response = await client.post( + "/chat/completions", + json=_run_agent_input(threadId="t-new"), + headers=_auth_headers(raw, **{"X-Repo-ID": "a/b", "X-Ref": "main"}), + ) + + assert response.status_code == 200 + created = await ChatThread.objects.filter(thread_id="t-new").afirst() + assert created is not None + assert created.user_id == user.id + assert created.repo_id == "a/b" + assert created.ref == "main" + # The finally block in ChatRunStreamer.events() clears the run slot after the stream completes. + assert created.active_run_id is None + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_exception_in_stream_clears_active_run_id_and_emits_run_error(client: TestAsyncClient, authed): + _, raw, user = authed + await ChatThread.objects.acreate(thread_id="t-boom", user=user, repo_id="a/b", ref="main") + + with ( + patch("chat.api.streaming.open_checkpointer", _mock_stream), + patch("chat.api.streaming.set_runtime_ctx", _mock_stream), + patch("chat.api.streaming.create_daiv_agent", new=AsyncMock()), + patch("chat.api.streaming.RuntimeContextLangGraphAGUIAgent") as m_agent_cls, + ): + m_instance = MagicMock() + + async def _boom(_input): + if False: + yield + raise RuntimeError("kaboom") + + m_instance.run = _boom + m_agent_cls.return_value = m_instance + + response = await client.post( + "/chat/completions", + json=_run_agent_input(threadId="t-boom"), + headers=_auth_headers(raw, **{"X-Repo-ID": "a/b", "X-Ref": "main"}), + ) + + assert response.status_code == 200 + body = response.content.decode() + assert "RUN_ERROR" in body + assert "run_failed" in body + # User-facing message must not leak the raw exception class/message — that + # could expose internal paths, SQL fragments, secrets that happen to land + # in a stack trace. + assert "kaboom" not in body + assert "RuntimeError" not in body + refreshed = await ChatThread.objects.aget(thread_id="t-boom") + assert refreshed.active_run_id is None + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_thread_status_reports_active_run(client: TestAsyncClient, authed): + _, raw, user = authed + await ChatThread.objects.acreate(thread_id="t-live", user=user, repo_id="a/b", ref="main", active_run_id="r-1") + await ChatThread.objects.acreate(thread_id="t-idle", user=user, repo_id="a/b", ref="main", active_run_id=None) + + live = await client.get("/chat/threads/t-live/status", headers=_auth_headers(raw)) + idle = await client.get("/chat/threads/t-idle/status", headers=_auth_headers(raw)) + + assert live.status_code == 200 + assert live.json() == {"active": True} + assert idle.status_code == 200 + assert idle.json() == {"active": False} + await user.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_thread_status_rejects_cross_user_access(client: TestAsyncClient, authed): + _, raw, user = authed + other = await User.objects.acreate_user( + username="intruder", + email="i@example.com", + password="x", # noqa: S106 + ) + await ChatThread.objects.acreate(thread_id="t-foreign", user=other, repo_id="a/b", ref="main", active_run_id="r-9") + + response = await client.get("/chat/threads/t-foreign/status", headers=_auth_headers(raw)) + assert response.status_code == 404 + await user.adelete() + await other.adelete() + + +@pytest.mark.django_db(transaction=True) +async def test_concurrent_run_returns_409(client: TestAsyncClient, authed): + _, raw, user = authed + await ChatThread.objects.acreate( + thread_id="t-busy", user=user, repo_id="a/b", ref="main", active_run_id="r-existing" + ) + response = await client.post( + "/chat/completions", + json=_run_agent_input(threadId="t-busy"), + headers=_auth_headers(raw, **{"X-Repo-ID": "a/b", "X-Ref": "main"}), + ) + assert response.status_code == 409 + await user.adelete() diff --git a/tests/unit_tests/chat/test_models.py b/tests/unit_tests/chat/test_models.py new file mode 100644 index 000000000..171e9a7b2 --- /dev/null +++ b/tests/unit_tests/chat/test_models.py @@ -0,0 +1,33 @@ +from django.db import IntegrityError + +import pytest +from activity.models import Activity, TriggerType + +from chat.models import ChatThread + + +@pytest.mark.django_db +def test_chat_thread_thread_id_is_unique_primary_key(member_user): + ChatThread.objects.create(thread_id="t-1", user=member_user, repo_id="a/b", ref="main") + with pytest.raises(IntegrityError): + ChatThread.objects.create(thread_id="t-1", user=member_user, repo_id="a/b", ref="main") + + +@pytest.mark.django_db(transaction=True) +async def test_aget_or_create_from_activity_is_idempotent(member_user): + activity = await Activity.objects.acreate( + trigger_type=TriggerType.UI_JOB, + repo_id="a/b", + ref="main", + prompt="first message", + thread_id="t-42", + user=member_user, + ) + thread_a, created_a = await ChatThread.aget_or_create_from_activity(member_user, activity) + thread_b, created_b = await ChatThread.aget_or_create_from_activity(member_user, activity) + assert created_a is True + assert created_b is False + assert thread_a.thread_id == thread_b.thread_id == "t-42" + assert thread_a.repo_id == "a/b" + assert thread_a.ref == "main" + assert thread_a.title.startswith("first message") diff --git a/tests/unit_tests/chat/test_repo_state.py b/tests/unit_tests/chat/test_repo_state.py new file mode 100644 index 000000000..8532d186a --- /dev/null +++ b/tests/unit_tests/chat/test_repo_state.py @@ -0,0 +1,150 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from chat.repo_state import aget_existing_mr_payload, mr_to_payload +from codebase.base import MergeRequest +from codebase.base import User as CBUser + + +def _make_mr(**overrides): + base = { + "repo_id": "a/b", + "merge_request_id": 42, + "source_branch": "feature-x", + "target_branch": "main", + "title": "Add feature X", + "description": "", + "labels": [], + "web_url": "https://gitlab.example/a/b/-/merge_requests/42", + "sha": "deadbeef", + "author": CBUser(id=1, username="u", name="U"), + "draft": True, + } + base.update(overrides) + return MergeRequest(**base) + + +async def test_returns_none_when_repo_id_missing(): + with patch("chat.repo_state.RepositoryConfig.get_config") as get_config: + result = await aget_existing_mr_payload("", "feature-x") + assert result is None + get_config.assert_not_called() + + +async def test_returns_none_when_ref_missing(): + with patch("chat.repo_state.RepositoryConfig.get_config") as get_config: + result = await aget_existing_mr_payload("a/b", "") + assert result is None + get_config.assert_not_called() + + +async def test_returns_none_when_ref_is_default_branch_and_skips_lookup(): + repo_client = MagicMock() + with ( + patch("chat.repo_state.RepositoryConfig.get_config", return_value=MagicMock(default_branch="main")), + patch("chat.repo_state.RepoClient.create_instance", return_value=repo_client) as factory, + ): + result = await aget_existing_mr_payload("a/b", "main") + + assert result is None + factory.assert_not_called() + repo_client.get_merge_request_by_branches.assert_not_called() + + +async def test_returns_payload_on_happy_path(): + repo_client = MagicMock() + repo_client.get_merge_request_by_branches.return_value = _make_mr() + with ( + patch("chat.repo_state.RepositoryConfig.get_config", return_value=MagicMock(default_branch="main")), + patch("chat.repo_state.RepoClient.create_instance", return_value=repo_client), + ): + result = await aget_existing_mr_payload("a/b", "feature-x") + + assert result == mr_to_payload(_make_mr()) + assert result["id"] == 42 + assert result["draft"] is True + repo_client.get_merge_request_by_branches.assert_called_once_with("a/b", "feature-x", "main") + + +async def test_returns_none_when_lookup_returns_none(): + repo_client = MagicMock() + repo_client.get_merge_request_by_branches.return_value = None + with ( + patch("chat.repo_state.RepositoryConfig.get_config", return_value=MagicMock(default_branch="main")), + patch("chat.repo_state.RepoClient.create_instance", return_value=repo_client), + ): + result = await aget_existing_mr_payload("a/b", "feature-x") + + assert result is None + + +async def test_swallows_platform_errors_and_logs(caplog): + """Platform hiccups (HTTP/transport errors) degrade to None with a logged + exception. Programming bugs are NOT swallowed — they propagate. + """ + import httpx + + with ( + patch("chat.repo_state.RepositoryConfig.get_config", side_effect=httpx.ConnectError("platform unreachable")), + caplog.at_level("ERROR", logger="daiv.chat"), + ): + result = await aget_existing_mr_payload("a/b", "feature-x") + + assert result is None + assert any("Failed to look up existing merge request" in rec.message for rec in caplog.records) + + +async def test_swallows_errors_from_client_call(): + """SDK errors (gitlab/github/httpx) are caught.""" + from gitlab.exceptions import GitlabError + + repo_client = MagicMock() + repo_client.get_merge_request_by_branches.side_effect = GitlabError("api 500") + with ( + patch("chat.repo_state.RepositoryConfig.get_config", return_value=MagicMock(default_branch="main")), + patch("chat.repo_state.RepoClient.create_instance", return_value=repo_client), + ): + result = await aget_existing_mr_payload("a/b", "feature-x") + + assert result is None + + +async def test_propagates_unexpected_errors(): + """Bugs (KeyError/AttributeError/TypeError) must NOT be silently caught — + they should surface as 500s rather than masking as a fake 'no MR'. + """ + repo_client = MagicMock() + repo_client.get_merge_request_by_branches.side_effect = KeyError("missing field") + with ( + patch("chat.repo_state.RepositoryConfig.get_config", return_value=MagicMock(default_branch="main")), + patch("chat.repo_state.RepoClient.create_instance", return_value=repo_client), + pytest.raises(KeyError), + ): + await aget_existing_mr_payload("a/b", "feature-x") + + +@pytest.mark.parametrize(("payload_input", "expected_keys"), [(None, None), ("not-a-mr", None)]) +def test_mr_to_payload_handles_invalid_inputs(payload_input, expected_keys): + assert mr_to_payload(payload_input) is expected_keys + + +def test_mr_to_payload_accepts_dict_form(): + """Re-hydrated checkpointer state stores MRs as dicts; payload converter must accept that shape.""" + raw = { + "merge_request_id": 7, + "web_url": "https://x/7", + "title": "T", + "draft": True, + "source_branch": "f", + "target_branch": "m", + } + payload = mr_to_payload(raw) + assert payload == { + "id": 7, + "url": "https://x/7", + "title": "T", + "draft": True, + "source_branch": "f", + "target_branch": "m", + } diff --git a/tests/unit_tests/chat/test_turns.py b/tests/unit_tests/chat/test_turns.py new file mode 100644 index 000000000..c60b61cae --- /dev/null +++ b/tests/unit_tests/chat/test_turns.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + +from chat.turns import build_turns + + +def test_build_turns_empty_list_returns_empty(): + assert build_turns([]) == [] + + +def test_build_turns_human_string_content_single_text_segment(): + m = HumanMessage(content="hello world", id="h-1") + result = build_turns([m]) + assert result == [{"id": "h-1", "role": "user", "segments": [{"type": "text", "content": "hello world"}]}] + + +def test_build_turns_human_list_content_joins_text_blocks(): + m = HumanMessage(content=[{"type": "text", "text": "hi "}, {"type": "text", "text": "there"}], id="h-2") + result = build_turns([m]) + assert result[0]["segments"] == [{"type": "text", "content": "hi there"}] + + +def test_build_turns_ai_string_with_tool_calls_emits_text_then_tool_segments(): + m = AIMessage( + content="Let me check the README.", + id="a-1", + tool_calls=[ + {"id": "tc-1", "name": "read_file", "args": {"path": "README.md"}}, + {"id": "tc-2", "name": "grep", "args": {"pattern": "TODO"}}, + ], + ) + result = build_turns([m]) + assert len(result) == 1 + turn = result[0] + assert turn["role"] == "assistant" + assert turn["id"] == "a-1" + assert turn["segments"] == [ + {"type": "text", "content": "Let me check the README."}, + { + "type": "tool_call", + "id": "tc-1", + "name": "read_file", + "args": '{"path": "README.md"}', + "result": None, + "status": "done", + }, + { + "type": "tool_call", + "id": "tc-2", + "name": "grep", + "args": '{"pattern": "TODO"}', + "result": None, + "status": "done", + }, + ] + + +def test_build_turns_ai_empty_content_omits_text_segment(): + m = AIMessage(content="", id="a-2", tool_calls=[{"id": "tc-3", "name": "ls", "args": {"path": "/"}}]) + result = build_turns([m]) + assert result[0]["segments"] == [ + {"type": "tool_call", "id": "tc-3", "name": "ls", "args": '{"path": "/"}', "result": None, "status": "done"} + ] + + +def test_build_turns_ai_list_content_preserves_block_interleaving(): + m = AIMessage( + content=[ + {"type": "text", "text": "Let me look."}, + {"type": "tool_use", "id": "tc-a", "name": "read_file", "input": {"path": "a.py"}}, + {"type": "text", "text": "Now I'll search."}, + {"type": "tool_use", "id": "tc-b", "name": "grep", "input": {"pattern": "x"}}, + ], + id="a-3", + tool_calls=[ + {"id": "tc-a", "name": "read_file", "args": {"path": "a.py"}}, + {"id": "tc-b", "name": "grep", "args": {"pattern": "x"}}, + ], + ) + result = build_turns([m]) + segments = result[0]["segments"] + assert [s["type"] for s in segments] == ["text", "tool_call", "text", "tool_call"] + assert segments[0]["content"] == "Let me look." + assert segments[1]["id"] == "tc-a" + assert segments[1]["name"] == "read_file" + assert segments[2]["content"] == "Now I'll search." + assert segments[3]["id"] == "tc-b" + + +def test_build_turns_ai_list_content_tool_use_without_matching_tool_call_still_emitted(): + m = AIMessage( + content=[{"type": "tool_use", "id": "tc-orphan", "name": "custom", "input": {"k": "v"}}], + id="a-4", + tool_calls=[], + ) + result = build_turns([m]) + assert result[0]["segments"] == [ + { + "type": "tool_call", + "id": "tc-orphan", + "name": "custom", + "args": '{"k": "v"}', + "result": None, + "status": "done", + } + ] + + +def test_build_turns_tool_message_result_attaches_to_matching_tool_call(): + ai = AIMessage( + content="Let me check.", id="a-5", tool_calls=[{"id": "tc-x", "name": "read_file", "args": {"path": "x.py"}}] + ) + tool = ToolMessage(content="file contents", tool_call_id="tc-x", id="t-1") + result = build_turns([ai, tool]) + assert len(result) == 1 # ToolMessages do not create their own turn + tool_seg = result[0]["segments"][1] + assert tool_seg["result"] == "file contents" + assert tool_seg["status"] == "done" + + +def test_build_turns_tool_message_list_content_joins_text_blocks(): + ai = AIMessage(content="", id="a-6", tool_calls=[{"id": "tc-y", "name": "grep", "args": {"pattern": "x"}}]) + tool = ToolMessage( + content=[{"type": "text", "text": "line-a"}, {"type": "text", "text": "line-b"}], tool_call_id="tc-y", id="t-2" + ) + result = build_turns([ai, tool]) + assert result[0]["segments"][0]["result"] == "line-a\nline-b" + + +def test_build_turns_orphan_tool_message_is_dropped_with_warning(caplog): + tool = ToolMessage(content="orphan", tool_call_id="tc-missing", id="t-3") + with caplog.at_level("WARNING", logger="daiv.chat"): + result = build_turns([tool]) + assert result == [] + assert any("tc-missing" in rec.message for rec in caplog.records) + + +def test_build_turns_skill_injection_folds_human_body_into_tool_result(): + msgs = [ + HumanMessage(content="run /plan", id="h-1"), + AIMessage(content="", id="a-1", tool_calls=[{"id": "tc-skill", "name": "skill", "args": {"skill": "plan"}}]), + ToolMessage(content="Launching skill 'plan'...", tool_call_id="tc-skill", id="t-1"), + HumanMessage(content="# Plan skill body\n\ninstructions...", id="h-synthetic"), + AIMessage(content="Here is the plan.", id="a-2"), + ] + result = build_turns(msgs) + roles = [t["role"] for t in result] + assert roles == ["user", "assistant", "assistant"] + skill_seg = result[1]["segments"][0] + assert skill_seg["name"] == "skill" + assert skill_seg["result"] == "# Plan skill body\n\ninstructions..." + + +def test_build_turns_human_after_non_skill_tool_still_renders_as_user_turn(): + msgs = [ + AIMessage(content="", id="a-1", tool_calls=[{"id": "tc-1", "name": "read_file", "args": {"path": "a"}}]), + ToolMessage(content="contents", tool_call_id="tc-1", id="t-1"), + HumanMessage(content="thanks", id="h-1"), + ] + result = build_turns(msgs) + assert [t["role"] for t in result] == ["assistant", "user"] + assert result[1]["segments"][0]["content"] == "thanks" + + +def test_build_turns_ai_thinking_block_emits_thinking_segment(): + # Anthropic shape: content list with a leading thinking block. + m = AIMessage( + content=[ + {"type": "thinking", "thinking": "Let me consider the options…", "signature": "sig"}, + {"type": "text", "text": "Here is the answer."}, + ], + id="a-think-1", + ) + result = build_turns([m]) + assert result[0]["segments"] == [ + {"type": "thinking", "content": "Let me consider the options…"}, + {"type": "text", "content": "Here is the answer."}, + ] + + +def test_build_turns_ai_reasoning_block_emits_thinking_segment(): + # LangChain standardized shape: content list with a reasoning block. + m = AIMessage( + content=[{"type": "reasoning", "reasoning": "Step-by-step plan…"}, {"type": "text", "text": "Done."}], + id="a-think-2", + ) + result = build_turns([m]) + assert result[0]["segments"] == [ + {"type": "thinking", "content": "Step-by-step plan…"}, + {"type": "text", "content": "Done."}, + ] + + +def test_build_turns_ai_reasoning_in_additional_kwargs_string_emits_thinking_segment(): + # DeepSeek / Qwen / xAI / some OpenRouter routes: reasoning_content as a string. + m = AIMessage( + content="Final answer.", id="a-think-3", additional_kwargs={"reasoning_content": "Considered options A, B, C…"} + ) + result = build_turns([m]) + assert result[0]["segments"] == [ + {"type": "thinking", "content": "Considered options A, B, C…"}, + {"type": "text", "content": "Final answer."}, + ] + + +def test_build_turns_ai_reasoning_in_additional_kwargs_summary_emits_thinking_segment(): + # OpenAI legacy reasoning.summary[*].text shape. + m = AIMessage( + content="Final answer.", + id="a-think-4", + additional_kwargs={"reasoning": {"summary": [{"text": "step 1"}, {"text": "step 2"}]}}, + ) + result = build_turns([m]) + assert result[0]["segments"] == [ + {"type": "thinking", "content": "step 1\n\nstep 2"}, + {"type": "text", "content": "Final answer."}, + ] + + +def test_build_turns_mixed_order_human_ai_tool_ai_tool_human(): + msgs = [ + HumanMessage(content="first prompt", id="h-1"), + AIMessage(content="ok", id="a-1", tool_calls=[{"id": "tc-1", "name": "read_file", "args": {"path": "a"}}]), + ToolMessage(content="result-1", tool_call_id="tc-1", id="t-1"), + AIMessage(content="next", id="a-2", tool_calls=[{"id": "tc-2", "name": "grep", "args": {"pattern": "z"}}]), + ToolMessage(content="result-2", tool_call_id="tc-2", id="t-2"), + HumanMessage(content="second prompt", id="h-2"), + ] + result = build_turns(msgs) + assert [t["role"] for t in result] == ["user", "assistant", "assistant", "user"] + assert result[1]["segments"][1]["result"] == "result-1" + assert result[2]["segments"][1]["result"] == "result-2" diff --git a/tests/unit_tests/chat/test_views.py b/tests/unit_tests/chat/test_views.py new file mode 100644 index 000000000..c92941f2d --- /dev/null +++ b/tests/unit_tests/chat/test_views.py @@ -0,0 +1,280 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +from django.urls import reverse + +import pytest +from activity.models import Activity, TriggerType + +from accounts.models import Role, User +from chat.models import ChatThread + + +@pytest.fixture +def other_user(db): + return User.objects.create_user(username="other", email="other@test.com", password="x", role=Role.MEMBER) # noqa: S106 + + +@pytest.mark.django_db +def test_list_view_requires_login(client): + resp = client.get(reverse("chat_list")) + assert resp.status_code == 302 + assert "/accounts/login" in resp["Location"] or "login" in resp["Location"].lower() + + +@pytest.mark.django_db +def test_list_view_only_shows_users_threads(member_client, member_user, other_user): + mine = ChatThread.objects.create(thread_id="t-mine", user=member_user, repo_id="a/b", ref="main") + ChatThread.objects.create(thread_id="t-theirs", user=other_user, repo_id="a/b", ref="main") + resp = member_client.get(reverse("chat_list")) + assert resp.status_code == 200 + threads = list(resp.context["threads"]) + assert [t.thread_id for t in threads] == [mine.thread_id] + + +@pytest.mark.django_db +def test_detail_view_404s_for_other_user_thread(member_client, other_user): + thread = ChatThread.objects.create(thread_id="t-other", user=other_user, repo_id="a/b", ref="main") + resp = member_client.get(reverse("chat_detail", kwargs={"thread_id": thread.thread_id})) + assert resp.status_code == 404 + + +@pytest.mark.django_db +def test_detail_view_with_live_checkpoint_renders_transcript(member_client, member_user): + from langchain_core.messages import AIMessage + + thread = ChatThread.objects.create(thread_id="t-live", user=member_user, repo_id="a/b", ref="main") + msg = AIMessage(content="hello from agent", id="m-1") + tup = MagicMock(checkpoint={"channel_values": {"messages": [msg]}}) + with ( + patch("chat.views.open_checkpointer") as cp_ctx, + patch("chat.views.aget_existing_mr_payload", AsyncMock(return_value=None)), + ): + saver = MagicMock() + saver.aget_tuple = AsyncMock(return_value=tup) + cp_ctx.return_value.__aenter__ = AsyncMock(return_value=saver) + cp_ctx.return_value.__aexit__ = AsyncMock(return_value=None) + resp = member_client.get(reverse("chat_detail", kwargs={"thread_id": thread.thread_id})) + + assert resp.status_code == 200 + assert resp.context["expired"] is False + turns = resp.context["turns"] + assert len(turns) == 1 + assert turns[0]["role"] == "assistant" + assert turns[0]["segments"] == [{"type": "text", "content": "hello from agent"}] + + +@pytest.mark.django_db +def test_detail_view_with_missing_checkpoint_flags_expired(member_client, member_user): + thread = ChatThread.objects.create(thread_id="t-gone", user=member_user, repo_id="a/b", ref="main") + with ( + patch("chat.views.open_checkpointer") as cp_ctx, + patch("chat.views.aget_existing_mr_payload", AsyncMock(return_value=None)), + ): + saver = MagicMock() + saver.aget_tuple = AsyncMock(return_value=None) + cp_ctx.return_value.__aenter__ = AsyncMock(return_value=saver) + cp_ctx.return_value.__aexit__ = AsyncMock(return_value=None) + resp = member_client.get(reverse("chat_detail", kwargs={"thread_id": thread.thread_id})) + + assert resp.status_code == 200 + assert resp.context["expired"] is True + + +@pytest.mark.django_db +def test_detail_view_empty_state_renders_new_page(member_client): + resp = member_client.get(reverse("chat_new")) + assert resp.status_code == 200 + assert resp.context["thread"] is None + assert resp.context["expired"] is False + + +@pytest.mark.django_db +def test_detail_view_surfaces_existing_mr_when_checkpoint_has_none(member_client, member_user): + """Composer should show an MR pill when one already exists for the chat's branch.""" + from codebase.base import MergeRequest + from codebase.base import User as CBUser + + thread = ChatThread.objects.create(thread_id="t-mr", user=member_user, repo_id="a/b", ref="feature-x") + tup = MagicMock(checkpoint={"channel_values": {"messages": []}}) + existing_mr = MergeRequest( + repo_id="a/b", + merge_request_id=42, + source_branch="feature-x", + target_branch="main", + title="Add feature X", + description="", + labels=[], + web_url="https://gitlab.example/a/b/-/merge_requests/42", + sha="deadbeef", + author=CBUser(id=1, username="u", name="U"), + draft=True, + ) + repo_client = MagicMock() + repo_client.get_merge_request_by_branches.return_value = existing_mr + with ( + patch("chat.views.open_checkpointer") as cp_ctx, + patch("chat.repo_state.RepoClient.create_instance", return_value=repo_client), + patch("chat.repo_state.RepositoryConfig.get_config", return_value=MagicMock(default_branch="main")), + ): + saver = MagicMock() + saver.aget_tuple = AsyncMock(return_value=tup) + cp_ctx.return_value.__aenter__ = AsyncMock(return_value=saver) + cp_ctx.return_value.__aexit__ = AsyncMock(return_value=None) + resp = member_client.get(reverse("chat_detail", kwargs={"thread_id": thread.thread_id})) + + assert resp.status_code == 200 + mr = resp.context["merge_request"] + assert mr is not None + assert mr["id"] == 42 + assert mr["url"] == "https://gitlab.example/a/b/-/merge_requests/42" + assert mr["draft"] is True + repo_client.get_merge_request_by_branches.assert_called_once_with("a/b", "feature-x", "main") + + +@pytest.mark.django_db +def test_detail_view_skips_mr_lookup_when_checkpoint_already_has_one(member_client, member_user): + """When LangGraph state already carries an MR, don't hit the platform.""" + thread = ChatThread.objects.create(thread_id="t-cached", user=member_user, repo_id="a/b", ref="feature-y") + stored_mr = { + "merge_request_id": 7, + "web_url": "https://example/7", + "title": "Stored", + "draft": False, + "source_branch": "feature-y", + "target_branch": "main", + } + tup = MagicMock(checkpoint={"channel_values": {"messages": [], "merge_request": stored_mr}}) + repo_client = MagicMock() + with ( + patch("chat.views.open_checkpointer") as cp_ctx, + patch("chat.repo_state.RepoClient.create_instance", return_value=repo_client) as factory, + ): + saver = MagicMock() + saver.aget_tuple = AsyncMock(return_value=tup) + cp_ctx.return_value.__aenter__ = AsyncMock(return_value=saver) + cp_ctx.return_value.__aexit__ = AsyncMock(return_value=None) + resp = member_client.get(reverse("chat_detail", kwargs={"thread_id": thread.thread_id})) + + assert resp.status_code == 200 + assert resp.context["merge_request"]["id"] == 7 + factory.assert_not_called() + + +@pytest.mark.django_db +def test_detail_view_swallows_platform_errors_in_mr_lookup(member_client, member_user): + """Platform hiccups must not break the page — but only platform-typed + exceptions are swallowed. Programming bugs propagate. + """ + import httpx + + thread = ChatThread.objects.create(thread_id="t-err", user=member_user, repo_id="a/b", ref="feature-z") + tup = MagicMock(checkpoint={"channel_values": {"messages": []}}) + with ( + patch("chat.views.open_checkpointer") as cp_ctx, + patch("chat.repo_state.RepositoryConfig.get_config", side_effect=httpx.ConnectError("platform unreachable")), + ): + saver = MagicMock() + saver.aget_tuple = AsyncMock(return_value=tup) + cp_ctx.return_value.__aenter__ = AsyncMock(return_value=saver) + cp_ctx.return_value.__aexit__ = AsyncMock(return_value=None) + resp = member_client.get(reverse("chat_detail", kwargs={"thread_id": thread.thread_id})) + + assert resp.status_code == 200 + assert resp.context["merge_request"] is None + + +@pytest.mark.django_db +def test_detail_view_propagates_unexpected_errors_in_mr_lookup(member_client, member_user): + """Bugs (KeyError/AttributeError/TypeError) must NOT be silently swallowed + by the soft-fallback path — they would mask real failures behind a + plausible-looking "no MR" UI. + """ + thread = ChatThread.objects.create(thread_id="t-bug", user=member_user, repo_id="a/b", ref="feature-z") + tup = MagicMock(checkpoint={"channel_values": {"messages": []}}) + with ( + patch("chat.views.open_checkpointer") as cp_ctx, + patch("chat.repo_state.RepositoryConfig.get_config", side_effect=KeyError("config missing")), + ): + saver = MagicMock() + saver.aget_tuple = AsyncMock(return_value=tup) + cp_ctx.return_value.__aenter__ = AsyncMock(return_value=saver) + cp_ctx.return_value.__aexit__ = AsyncMock(return_value=None) + with pytest.raises(KeyError): + member_client.get(reverse("chat_detail", kwargs={"thread_id": thread.thread_id})) + + +@pytest.mark.django_db +def test_detail_view_skips_mr_lookup_when_branch_is_default(member_client, member_user): + """No MR makes sense when source == target, so don't ask the platform.""" + thread = ChatThread.objects.create(thread_id="t-main", user=member_user, repo_id="a/b", ref="main") + tup = MagicMock(checkpoint={"channel_values": {"messages": []}}) + repo_client = MagicMock() + with ( + patch("chat.views.open_checkpointer") as cp_ctx, + patch("chat.repo_state.RepoClient.create_instance", return_value=repo_client) as factory, + patch("chat.repo_state.RepositoryConfig.get_config", return_value=MagicMock(default_branch="main")), + ): + saver = MagicMock() + saver.aget_tuple = AsyncMock(return_value=tup) + cp_ctx.return_value.__aenter__ = AsyncMock(return_value=saver) + cp_ctx.return_value.__aexit__ = AsyncMock(return_value=None) + resp = member_client.get(reverse("chat_detail", kwargs={"thread_id": thread.thread_id})) + + assert resp.status_code == 200 + assert resp.context["merge_request"] is None + factory.assert_not_called() + repo_client.get_merge_request_by_branches.assert_not_called() + + +@pytest.mark.django_db +def test_from_activity_404_for_other_users_activity(member_client, other_user): + activity = Activity.objects.create( + trigger_type=TriggerType.UI_JOB, repo_id="a/b", ref="main", prompt="x", thread_id="t-x", user=other_user + ) + resp = member_client.post(reverse("chat_from_activity", kwargs={"activity_id": activity.id})) + assert resp.status_code == 404 + + +@pytest.mark.django_db +def test_from_activity_404_when_activity_has_no_thread_id(member_client, member_user): + activity = Activity.objects.create( + trigger_type=TriggerType.UI_JOB, repo_id="a/b", ref="main", prompt="x", user=member_user + ) + resp = member_client.post(reverse("chat_from_activity", kwargs={"activity_id": activity.id})) + assert resp.status_code == 404 + + +@pytest.mark.django_db +def test_from_activity_410_when_checkpoint_missing(member_client, member_user): + activity = Activity.objects.create( + trigger_type=TriggerType.UI_JOB, repo_id="a/b", ref="main", prompt="x", thread_id="t-gone", user=member_user + ) + with patch("chat.views.open_checkpointer") as cp_ctx: + saver = MagicMock() + saver.aget_tuple = AsyncMock(return_value=None) + cp_ctx.return_value.__aenter__ = AsyncMock(return_value=saver) + cp_ctx.return_value.__aexit__ = AsyncMock(return_value=None) + resp = member_client.post(reverse("chat_from_activity", kwargs={"activity_id": activity.id})) + assert resp.status_code == 410 + + +@pytest.mark.django_db +def test_from_activity_creates_thread_and_redirects(member_client, member_user): + activity = Activity.objects.create( + trigger_type=TriggerType.UI_JOB, + repo_id="a/b", + ref="main", + prompt="hello there", + thread_id="t-alive", + user=member_user, + ) + tup = MagicMock(checkpoint={"channel_values": {"messages": []}}) + with patch("chat.views.open_checkpointer") as cp_ctx: + saver = MagicMock() + saver.aget_tuple = AsyncMock(return_value=tup) + cp_ctx.return_value.__aenter__ = AsyncMock(return_value=saver) + cp_ctx.return_value.__aexit__ = AsyncMock(return_value=None) + resp = member_client.post(reverse("chat_from_activity", kwargs={"activity_id": activity.id})) + assert resp.status_code == 302 + assert ChatThread.objects.filter(thread_id="t-alive", user=member_user).exists() + assert resp["Location"] == reverse("chat_detail", kwargs={"thread_id": "t-alive"}) diff --git a/tests/unit_tests/codebase/clients/github/test_client.py b/tests/unit_tests/codebase/clients/github/test_client.py index d4471fb6e..2cbd14066 100644 --- a/tests/unit_tests/codebase/clients/github/test_client.py +++ b/tests/unit_tests/codebase/clients/github/test_client.py @@ -361,3 +361,44 @@ def test_list_branches_respects_limit_and_stops_iteration(self, github_client): result = github_client.list_branches("owner/repo", limit=2) assert result == ["a", "b"] + + def test_get_merge_request_by_branches_returns_first_open_match(self, github_client): + """When an open PR exists for the source/target pair, return a serialized MergeRequest.""" + mock_repo = Mock() + mock_pr = Mock() + mock_pr.number = 7 + mock_pr.head = Mock(ref="feat-x", sha="abc123") + mock_pr.base = Mock(ref="main") + mock_pr.title = "feat: add x" + mock_pr.body = "details" + label = Mock() + label.name = "enhancement" + mock_pr.labels = [label] + mock_pr.html_url = "https://github.com/o/r/pull/7" + mock_user = Mock(id=1, login="alice") + mock_user.name = "Alice" + mock_pr.user = mock_user + mock_pr.draft = True + mock_repo.get_pulls.return_value = iter([mock_pr]) + github_client.client.get_repo.return_value = mock_repo + + result = github_client.get_merge_request_by_branches("owner/repo", "feat-x", "main") + + assert result is not None + assert result.merge_request_id == 7 + assert result.source_branch == "feat-x" + assert result.target_branch == "main" + assert result.draft is True + assert result.web_url == "https://github.com/o/r/pull/7" + assert result.labels == ["enhancement"] + mock_repo.get_pulls.assert_called_once_with(state="open", base="main", head="feat-x") + + def test_get_merge_request_by_branches_returns_none_when_empty(self, github_client): + """No open PR matching the branch pair → ``None``.""" + mock_repo = Mock() + mock_repo.get_pulls.return_value = iter([]) + github_client.client.get_repo.return_value = mock_repo + + result = github_client.get_merge_request_by_branches("owner/repo", "feat-x", "main") + + assert result is None diff --git a/tests/unit_tests/codebase/clients/gitlab/test_client.py b/tests/unit_tests/codebase/clients/gitlab/test_client.py index dc9dee0ed..ac6ee6f9b 100644 --- a/tests/unit_tests/codebase/clients/gitlab/test_client.py +++ b/tests/unit_tests/codebase/clients/gitlab/test_client.py @@ -377,3 +377,29 @@ def test_list_branches_caps_per_page_at_100(self, gitlab_client): _, kwargs = mock_project.branches.list.call_args assert kwargs["per_page"] == 100 + + def test_get_merge_request_by_branches_returns_first_open_match(self, gitlab_client): + """When an open MR exists for the source/target pair, return the serialized MR.""" + mock_project = Mock() + mock_mr = Mock() + mock_project.mergerequests.list.return_value = iter([mock_mr]) + gitlab_client.client.projects.get.return_value = mock_project + sentinel = Mock(name="serialized") + with patch.object(gitlab_client, "_serialize_merge_request", return_value=sentinel) as serialize: + result = gitlab_client.get_merge_request_by_branches("group/repo", "feat-x", "main") + + assert result is sentinel + mock_project.mergerequests.list.assert_called_once_with( + source_branch="feat-x", target_branch="main", state="opened", iterator=True + ) + serialize.assert_called_once_with("group/repo", mock_mr) + + def test_get_merge_request_by_branches_returns_none_when_empty(self, gitlab_client): + """Empty list → ``None`` (not an exception).""" + mock_project = Mock() + mock_project.mergerequests.list.return_value = iter([]) + gitlab_client.client.projects.get.return_value = mock_project + + result = gitlab_client.get_merge_request_by_branches("group/repo", "feat-x", "main") + + assert result is None diff --git a/tests/unit_tests/codebase/test_utils.py b/tests/unit_tests/codebase/test_utils.py index 4e4f9b615..8b054613c 100644 --- a/tests/unit_tests/codebase/test_utils.py +++ b/tests/unit_tests/codebase/test_utils.py @@ -1,7 +1,7 @@ from langchain_core.messages import AIMessage, HumanMessage from codebase.base import Discussion, Note, NoteableType, User -from codebase.utils import discussion_has_daiv_mentions, note_mentions_daiv, notes_to_messages +from codebase.utils import discussion_has_daiv_mentions, files_changed_from_patch, note_mentions_daiv, notes_to_messages from core.constants import BOT_NAME @@ -254,3 +254,44 @@ def test_discussion_with_daiv_mentions_multilines(self): discussion = Discussion(id="discussion_1", notes=notes) assert discussion_has_daiv_mentions(discussion, current_user) is True + + +class TestFilesChangedFromPatch: + """Patch-parsing covers every op the rail needs to surface for bash edits.""" + + def test_empty_or_none_returns_empty_list(self): + assert files_changed_from_patch(None) == [] + assert files_changed_from_patch("") == [] + assert files_changed_from_patch(" \n") == [] + + def test_modified_file(self): + patch = ( + "diff --git a/daiv/foo.py b/daiv/foo.py\n" + "index 1111111..2222222 100644\n" + "--- a/daiv/foo.py\n" + "+++ b/daiv/foo.py\n" + "@@ -1 +1 @@\n-old\n+new\n" + ) + assert files_changed_from_patch(patch) == [{"path": "daiv/foo.py", "op": "modified"}] + + def test_added_and_deleted(self): + patch = ( + "diff --git a/new.txt b/new.txt\n" + "new file mode 100644\n" + "--- /dev/null\n" + "+++ b/new.txt\n" + "@@ -0,0 +1 @@\n+hi\n" + "diff --git a/old.txt b/old.txt\n" + "deleted file mode 100644\n" + "--- a/old.txt\n" + "+++ /dev/null\n" + "@@ -1 +0,0 @@\n-bye\n" + ) + assert files_changed_from_patch(patch) == [ + {"path": "new.txt", "op": "added"}, + {"path": "old.txt", "op": "deleted"}, + ] + + def test_rename_carries_from_path(self): + patch = "diff --git a/src/a.py b/src/b.py\nsimilarity index 100%\nrename from src/a.py\nrename to src/b.py\n" + assert files_changed_from_patch(patch) == [{"path": "src/b.py", "op": "renamed", "from_path": "src/a.py"}] diff --git a/tests/unit_tests/jobs/api/test_views.py b/tests/unit_tests/jobs/api/test_views.py index 73bc5f76f..48dbaa839 100644 --- a/tests/unit_tests/jobs/api/test_views.py +++ b/tests/unit_tests/jobs/api/test_views.py @@ -143,9 +143,13 @@ async def _aenq(**kwargs): assert len(data["jobs"]) == 1 assert data["jobs"][0]["job_id"] == str(task_id) assert data["failed"] == [] - mock_task.aenqueue.assert_called_once_with( - repo_id="group/project", prompt="List all files", ref=None, use_max=False - ) + mock_task.aenqueue.assert_called_once() + kwargs = mock_task.aenqueue.call_args.kwargs + assert kwargs["repo_id"] == "group/project" + assert kwargs["prompt"] == "List all files" + assert kwargs["ref"] is None + assert kwargs["use_max"] is False + assert kwargs["thread_id"] @pytest.mark.django_db(transaction=True) @@ -176,7 +180,13 @@ async def _aenq(**kwargs): response = await authenticated_client.post("/jobs", json=_single_repo_body(prompt="Fix the bug", use_max=True)) assert response.status_code == 202 - mock_task.aenqueue.assert_called_once_with(repo_id="group/project", prompt="Fix the bug", ref=None, use_max=True) + mock_task.aenqueue.assert_called_once() + kwargs = mock_task.aenqueue.call_args.kwargs + assert kwargs["repo_id"] == "group/project" + assert kwargs["prompt"] == "Fix the bug" + assert kwargs["ref"] is None + assert kwargs["use_max"] is True + assert kwargs["thread_id"] @pytest.mark.django_db(transaction=True) diff --git a/tests/unit_tests/jobs/test_tasks.py b/tests/unit_tests/jobs/test_tasks.py new file mode 100644 index 000000000..2070c7394 --- /dev/null +++ b/tests/unit_tests/jobs/test_tasks.py @@ -0,0 +1,49 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from jobs.tasks import run_job_task + + +@pytest.mark.django_db +async def test_run_job_task_uses_async_redis_saver_with_thread_id(): + """run_job_task must use the shared open_checkpointer (AsyncRedisSaver) and thread its + thread_id through to the langgraph config.""" + last_message = MagicMock() + last_message.content = "ok" + fake_result = {"messages": [last_message]} + + runtime_ctx = MagicMock() + runtime_ctx.config.models.agent = MagicMock() + + agent = AsyncMock() + agent.ainvoke = AsyncMock(return_value=fake_result) + + with ( + patch("jobs.tasks.open_checkpointer") as cp_ctx, + patch("jobs.tasks.set_runtime_ctx") as rc_ctx, + patch("jobs.tasks.create_daiv_agent", new=AsyncMock(return_value=agent)), + patch( + "jobs.tasks.get_daiv_agent_kwargs", + return_value={"model_names": ["claude-4-7-opus"], "thinking_level": "medium"}, + ), + patch("jobs.tasks.build_langsmith_config", return_value={"configurable": {"thread_id": "t-123"}}), + patch("jobs.tasks.build_agent_result", new=AsyncMock(return_value={"response": "ok"})), + patch("jobs.tasks.build_usage_summary", return_value=MagicMock(to_dict=lambda: {})), + patch("jobs.tasks.track_usage_metadata"), + ): + cp_ctx.return_value.__aenter__.return_value = object() + rc_ctx.return_value.__aenter__.return_value = runtime_ctx + + await run_job_task.func(repo_id="owner/repo", prompt="hi", ref="main", use_max=False, thread_id="t-123") + + cp_ctx.assert_called_once() + call_kwargs = agent.ainvoke.call_args.kwargs + assert call_kwargs["config"]["configurable"]["thread_id"] == "t-123" + + +async def test_run_job_task_rejects_missing_thread_id(): + """Chat resume relies on the activity row and the checkpointer sharing the + same thread_id. A silent UUID fallback would break the resume contract. + """ + with pytest.raises(ValueError, match="non-empty thread_id"): + await run_job_task.func(repo_id="owner/repo", prompt="hi", thread_id="") diff --git a/tests/unit_tests/mcp_server/test_server.py b/tests/unit_tests/mcp_server/test_server.py index 5eaccef66..e3e448033 100644 --- a/tests/unit_tests/mcp_server/test_server.py +++ b/tests/unit_tests/mcp_server/test_server.py @@ -16,8 +16,12 @@ def _mock_task(): class _FakeActivity: + _next_pk = 0 + def __init__(self, task_result_id): self.task_result_id = task_result_id + type(self)._next_pk += 1 + self.pk = type(self)._next_pk async def _fake_acreate_activity(**kwargs): @@ -25,7 +29,25 @@ async def _fake_acreate_activity(**kwargs): def _patch_acreate(): - return patch("activity.services.acreate_activity", new_callable=AsyncMock, side_effect=_fake_acreate_activity) + # Patch acreate_activity and silence the post-create title task enqueue so + # tests don't depend on the queue backend. + acreate_patch = patch( + "activity.services.acreate_activity", new_callable=AsyncMock, side_effect=_fake_acreate_activity + ) + title_patch = patch("activity.services.generate_title_task") + + class _Combined: + def __enter__(self): + mock_create = acreate_patch.__enter__() + mock_title = title_patch.__enter__() + mock_title.aenqueue = AsyncMock(return_value=None) + return mock_create + + def __exit__(self, exc_type, exc, tb): + title_patch.__exit__(exc_type, exc, tb) + return acreate_patch.__exit__(exc_type, exc, tb) + + return _Combined() @pytest.mark.django_db(transaction=True) @@ -87,9 +109,13 @@ async def test_submit_job_passes_ref(): with patch("activity.services.run_job_task") as mock_task, _patch_acreate(): mock_task.aenqueue = AsyncMock(return_value=_mock_task()) await submit_job(repos=[{"repo_id": "group/project", "ref": "feature-branch"}], prompt="Fix the bug") - mock_task.aenqueue.assert_called_once_with( - repo_id="group/project", prompt="Fix the bug", ref="feature-branch", use_max=False - ) + mock_task.aenqueue.assert_called_once() + kwargs = mock_task.aenqueue.call_args.kwargs + assert kwargs["repo_id"] == "group/project" + assert kwargs["prompt"] == "Fix the bug" + assert kwargs["ref"] == "feature-branch" + assert kwargs["use_max"] is False + assert kwargs["thread_id"] @pytest.mark.django_db(transaction=True) diff --git a/uv.lock b/uv.lock index 4b1e62d5f..1dcb5b03b 100644 --- a/uv.lock +++ b/uv.lock @@ -7,6 +7,39 @@ resolution-markers = [ "sys_platform != 'emscripten' and sys_platform != 'win32'", ] +[[package]] +name = "ag-ui-langgraph" +version = "0.0.34" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ag-ui-protocol" }, + { name = "langchain" }, + { name = "langchain-core" }, + { name = "langgraph" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/5e/950cdb65de973660f2634a675b0a614657a5a00258a7eede21d3a9318992/ag_ui_langgraph-0.0.34.tar.gz", hash = "sha256:755323d5256407ce62d6b9af447a9f1250554e7056c5e2115027a4174a736c41", size = 258423, upload-time = "2026-04-20T21:09:25.789Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/e5/c4eeee50262de83eb10de29b55259244d0a95edceb0d08c63ba0ae175719/ag_ui_langgraph-0.0.34-py3-none-any.whl", hash = "sha256:124dccdae48af124f857b746aa3c0424b2d332b73b06ac426b6d5dcd811ea984", size = 27020, upload-time = "2026-04-20T21:09:26.816Z" }, +] + +[package.optional-dependencies] +fastapi = [ + { name = "fastapi" }, +] + +[[package]] +name = "ag-ui-protocol" +version = "0.1.18" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/d7/5711eada86da9bd7684e58645653a1693ef20b66cc3efbb1deeafef80f8d/ag_ui_protocol-0.1.18.tar.gz", hash = "sha256:b37c672c3fd6bac12b316c39f45ad9db9f137bbb885489c79f268507029a22ff", size = 9937, upload-time = "2026-04-21T20:44:59.151Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d8/74/913c9b8fc566c6da650aecbddf25a5d8186b54138df265eb9eb546f56141/ag_ui_protocol-0.1.18-py3-none-any.whl", hash = "sha256:d151c0f0a34160647f1571163f7185746f4326b15a56d1560de5082a7a0e7a12", size = 12607, upload-time = "2026-04-21T20:45:00.097Z" }, +] + [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -316,6 +349,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "copilotkit" +version = "0.1.86" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ag-ui-langgraph", extra = ["fastapi"] }, + { name = "ag-ui-protocol" }, + { name = "fastapi" }, + { name = "langchain" }, + { name = "langgraph" }, + { name = "partialjson" }, + { name = "toml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/1b/6c3540e03238d5a6bdd1cfbd4b5a8500f13a005eb19f37e7ff31240925c8/copilotkit-0.1.86.tar.gz", hash = "sha256:ece51bb354e76840aa7194d12e96e2e7b7f5210bae42ad59437bffa8791e9f23", size = 45181, upload-time = "2026-04-08T17:56:21.289Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/7a/83dd622e9060089e0de55b39db7606a81f2514431b96aa621595e8c4c731/copilotkit-0.1.86-py3-none-any.whl", hash = "sha256:525d689ce0cb67f9c34822d10c3fa98c83c4ca7552c4a6bd5235a1bce2734671", size = 53437, upload-time = "2026-04-08T17:56:19.623Z" }, +] + [[package]] name = "coverage" version = "7.13.5" @@ -425,6 +476,8 @@ name = "daiv" version = "2.0.0" source = { virtual = "." } dependencies = [ + { name = "ag-ui-langgraph" }, + { name = "copilotkit" }, { name = "croniter" }, { name = "ddgs" }, { name = "deepagents" }, @@ -501,6 +554,8 @@ docs = [ [package.metadata] requires-dist = [ + { name = "ag-ui-langgraph", specifier = "==0.0.34" }, + { name = "copilotkit", specifier = "==0.1.86" }, { name = "croniter", specifier = "==6.2.2" }, { name = "ddgs", specifier = "==9.14.1" }, { name = "deepagents", specifier = "==0.5.1" }, @@ -848,6 +903,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] +[[package]] +name = "fastapi" +version = "0.136.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5d/45/c130091c2dfa061bbfe3150f2a5091ef1adf149f2a8d2ae769ecaf6e99a2/fastapi-0.136.1.tar.gz", hash = "sha256:7af665ad7acfa0a3baf8983d393b6b471b9da10ede59c60045f49fbc89a0fa7f", size = 397448, upload-time = "2026-04-23T16:49:44.046Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/ff/2e4eca3ade2c22fe1dea7043b8ee9dabe47753349eb1b56a202de8af6349/fastapi-0.136.1-py3-none-any.whl", hash = "sha256:a6e9d7eeada96c93a4d69cb03836b44fa34e2854accb7244a1ece36cd4781c3f", size = 117683, upload-time = "2026-04-23T16:49:42.437Z" }, +] + [[package]] name = "filelock" version = "3.25.2" @@ -1024,7 +1095,9 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/02/bde66806e8f169cf90b14d02c500c44cdbe02c8e224c9c67bafd1b8cadd1/greenlet-3.4.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:10a07aca6babdd18c16a3f4f8880acfffc2b88dfe431ad6aa5f5740759d7d75e", size = 286291, upload-time = "2026-04-08T17:09:34.307Z" }, { url = "https://files.pythonhosted.org/packages/05/1f/39da1c336a87d47c58352fb8a78541ce63d63ae57c5b9dae1fe02801bbc2/greenlet-3.4.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:076e21040b3a917d3ce4ad68fb5c3c6b32f1405616c4a57aa83120979649bd3d", size = 656749, upload-time = "2026-04-08T16:24:41.721Z" }, { url = "https://files.pythonhosted.org/packages/d3/6c/90ee29a4ee27af7aa2e2ec408799eeb69ee3fcc5abcecac6ddd07a5cd0f2/greenlet-3.4.0-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e82689eea4a237e530bb5cb41b180ef81fa2160e1f89422a67be7d90da67f615", size = 669084, upload-time = "2026-04-08T16:31:01.372Z" }, + { url = "https://files.pythonhosted.org/packages/d2/4a/74078d3936712cff6d3c91a930016f476ce4198d84e224fe6d81d3e02880/greenlet-3.4.0-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:06c2d3b89e0c62ba50bd7adf491b14f39da9e7e701647cb7b9ff4c99bee04b19", size = 673405, upload-time = "2026-04-08T16:40:42.527Z" }, { url = "https://files.pythonhosted.org/packages/07/49/d4cad6e5381a50947bb973d2f6cf6592621451b09368b8c20d9b8af49c5b/greenlet-3.4.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4df3b0b2289ec686d3c821a5fee44259c05cfe824dd5e6e12c8e5f5df23085cf", size = 665621, upload-time = "2026-04-08T15:56:35.995Z" }, + { url = "https://files.pythonhosted.org/packages/79/3e/df8a83ab894751bc31e1106fdfaa80ca9753222f106b04de93faaa55feb7/greenlet-3.4.0-cp314-cp314-manylinux_2_39_riscv64.whl", hash = "sha256:070b8bac2ff3b4d9e0ff36a0d19e42103331d9737e8504747cd1e659f76297bd", size = 471670, upload-time = "2026-04-08T16:43:08.512Z" }, { url = "https://files.pythonhosted.org/packages/37/31/d1edd54f424761b5d47718822f506b435b6aab2f3f93b465441143ea5119/greenlet-3.4.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8bff29d586ea415688f4cec96a591fcc3bf762d046a796cdadc1fdb6e7f2d5bf", size = 1622259, upload-time = "2026-04-08T16:26:23.201Z" }, { url = "https://files.pythonhosted.org/packages/b0/c6/6d3f9cdcb21c4e12a79cb332579f1c6aa1af78eb68059c5a957c7812d95e/greenlet-3.4.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8a569c2fb840c53c13a2b8967c63621fafbd1a0e015b9c82f408c33d626a2fda", size = 1686916, upload-time = "2026-04-08T15:57:34.282Z" }, { url = "https://files.pythonhosted.org/packages/63/45/c1ca4a1ad975de4727e52d3ffe641ae23e1d7a8ffaa8ff7a0477e1827b92/greenlet-3.4.0-cp314-cp314-win_amd64.whl", hash = "sha256:207ba5b97ea8b0b60eb43ffcacf26969dd83726095161d676aac03ff913ee50d", size = 239821, upload-time = "2026-04-08T17:03:48.423Z" }, @@ -1032,7 +1105,9 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d4/8f/18d72b629783f5e8d045a76f5325c1e938e659a9e4da79c7dcd10169a48d/greenlet-3.4.0-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:d70012e51df2dbbccfaf63a40aaf9b40c8bed37c3e3a38751c926301ce538ece", size = 294681, upload-time = "2026-04-08T15:52:35.778Z" }, { url = "https://files.pythonhosted.org/packages/9e/ad/5fa86ec46769c4153820d58a04062285b3b9e10ba3d461ee257b68dcbf53/greenlet-3.4.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a58bec0751f43068cd40cff31bb3ca02ad6000b3a51ca81367af4eb5abc480c8", size = 658899, upload-time = "2026-04-08T16:24:43.32Z" }, { url = "https://files.pythonhosted.org/packages/43/f0/4e8174ca0e87ae748c409f055a1ba161038c43cc0a5a6f1433a26ac2e5bf/greenlet-3.4.0-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:05fa0803561028f4b2e3b490ee41216a842eaee11aed004cc343a996d9523aa2", size = 665284, upload-time = "2026-04-08T16:31:02.833Z" }, + { url = "https://files.pythonhosted.org/packages/ef/92/466b0d9afd44b8af623139a3599d651c7564fa4152f25f117e1ee5949ffb/greenlet-3.4.0-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c4cd56a9eb7a6444edbc19062f7b6fbc8f287c663b946e3171d899693b1c19fa", size = 665872, upload-time = "2026-04-08T16:40:43.912Z" }, { url = "https://files.pythonhosted.org/packages/19/da/991cf7cd33662e2df92a1274b7eb4d61769294d38a1bba8a45f31364845e/greenlet-3.4.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e60d38719cb80b3ab5e85f9f1aed4960acfde09868af6762ccb27b260d68f4ed", size = 661861, upload-time = "2026-04-08T15:56:37.269Z" }, + { url = "https://files.pythonhosted.org/packages/0d/14/3395a7ef3e260de0325152ddfe19dffb3e49fe10873b94654352b53ad48e/greenlet-3.4.0-cp314-cp314t-manylinux_2_39_riscv64.whl", hash = "sha256:1f85f204c4d54134ae850d401fa435c89cd667d5ce9dc567571776b45941af72", size = 489237, upload-time = "2026-04-08T16:43:09.993Z" }, { url = "https://files.pythonhosted.org/packages/36/c5/6c2c708e14db3d9caea4b459d8464f58c32047451142fe2cfd90e7458f41/greenlet-3.4.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7f50c804733b43eded05ae694691c9aa68bca7d0a867d67d4a3f514742a2d53f", size = 1622182, upload-time = "2026-04-08T16:26:24.777Z" }, { url = "https://files.pythonhosted.org/packages/7a/4c/50c5fed19378e11a29fabab1f6be39ea95358f4a0a07e115a51ca93385d8/greenlet-3.4.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:2d4f0635dc4aa638cda4b2f5a07ae9a2cff9280327b581a3fcb6f317b4fbc38a", size = 1685050, upload-time = "2026-04-08T15:57:36.453Z" }, { url = "https://files.pythonhosted.org/packages/db/72/85ae954d734703ab48e622c59d4ce35d77ce840c265814af9c078cacc7aa/greenlet-3.4.0-cp314-cp314t-win_amd64.whl", hash = "sha256:1a4a48f24681300c640f143ba7c404270e1ebbbcf34331d7104a4ff40f8ea705", size = 245554, upload-time = "2026-04-08T17:03:50.044Z" }, @@ -2219,6 +2294,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b6/61/fae042894f4296ec49e3f193aff5d7c18440da9e48102c3315e1bc4519a7/parso-0.8.6-py2.py3-none-any.whl", hash = "sha256:2c549f800b70a5c4952197248825584cb00f033b29c692671d3bf08bf380baff", size = 106894, upload-time = "2026-02-09T15:45:21.391Z" }, ] +[[package]] +name = "partialjson" +version = "0.0.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5d/b2/59669fdc3ecbc724a077c598c1c9b4068549af0cd8c3b5add9337bd4d93a/partialjson-0.0.8.tar.gz", hash = "sha256:91217e19a15049332df534477f56420065ad1729cedee7d8c7433e1d2acc7dca", size = 4142, upload-time = "2024-08-03T18:03:15.798Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/fb/453af21468774dbd0954853735a4fc7841544c3022ff86e5d93252d7ea72/partialjson-0.0.8-py3-none-any.whl", hash = "sha256:22c6c60944137f931a7033fa0eeee2d74b49114f3d45c25a560b07a6ebf22b76", size = 4549, upload-time = "2024-08-03T18:03:14.447Z" }, +] + [[package]] name = "pathspec" version = "1.0.4" @@ -3240,6 +3324,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/df/c7891ef9d2712ad774777271d39fdef63941ffba0a9d59b7ad1fd2765e57/tiktoken-0.12.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f61c0aea5565ac82e2ec50a05e02a6c44734e91b51c10510b084ea1b8e633a71", size = 920667, upload-time = "2025-10-06T20:22:34.444Z" }, ] +[[package]] +name = "toml" +version = "0.10.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/be/ba/1f744cdc819428fc6b5084ec34d9b30660f6f9daaf70eead706e3203ec3c/toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f", size = 22253, upload-time = "2020-11-01T01:40:22.204Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" }, +] + [[package]] name = "toml-fmt-common" version = "1.3.2"