Skip to content

Commit 2c3ea88

Browse files
authored
Merge pull request #1237 from srtab/refactor/notification-token-context
2 parents d03830f + 47a15c2 commit 2c3ea88

2 files changed

Lines changed: 123 additions & 6 deletions

File tree

daiv/notifications/signals.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from django.conf import settings
66
from django.db import Error as DatabaseError
77
from django.db import IntegrityError
8-
from django.db.models import Count, Q
8+
from django.db.models import Count, Q, Sum
99
from django.db.models.signals import post_save
1010
from django.dispatch import receiver
1111
from django.urls import reverse
@@ -93,6 +93,10 @@ def _render_payload(activity: Activity) -> tuple[str, str, dict]:
9393
"trigger_owner": owner,
9494
"repo_id": repo,
9595
"duration_seconds": activity.duration,
96+
"input_tokens": activity.input_tokens,
97+
"output_tokens": activity.output_tokens,
98+
"total_tokens": activity.total_tokens,
99+
"cost_usd": float(activity.cost_usd) if activity.cost_usd is not None else None,
96100
}
97101
return subject, body, context
98102

@@ -155,6 +159,10 @@ def _handle_batch_completion(activity: Activity, siblings, total: int) -> None:
155159
agg = siblings.aggregate(
156160
terminal=Count("id", filter=Q(status__in=ActivityStatus.terminal())),
157161
successful=Count("id", filter=Q(status=ActivityStatus.SUCCESSFUL)),
162+
total_input_tokens=Sum("input_tokens"),
163+
total_output_tokens=Sum("output_tokens"),
164+
total_total_tokens=Sum("total_tokens"),
165+
total_cost_usd=Sum("cost_usd"),
158166
)
159167
if agg["terminal"] < total:
160168
return
@@ -175,12 +183,18 @@ def _handle_batch_completion(activity: Activity, siblings, total: int) -> None:
175183
failed = total - successful
176184
agg_status = ActivityStatus.SUCCESSFUL if failed == 0 else ActivityStatus.FAILED
177185

178-
rows = list(siblings.values_list("repo_id", "started_at", "finished_at"))
186+
rows = list(siblings.values_list("repo_id", "started_at", "finished_at", "status"))
179187

180188
effective = activity.effective_notify_on
181189
channels = [cls.channel_type for cls in enabled_channels()] if _status_matches(effective, agg_status) else []
182190

183-
subject, body, context = _render_batch_payload(activity, rows, total, successful, failed, agg_status)
191+
usage = {
192+
"input_tokens": agg["total_input_tokens"],
193+
"output_tokens": agg["total_output_tokens"],
194+
"total_tokens": agg["total_total_tokens"],
195+
"cost_usd": float(agg["total_cost_usd"]) if agg["total_cost_usd"] is not None else None,
196+
}
197+
subject, body, context = _render_batch_payload(activity, rows, total, successful, failed, agg_status, usage)
184198
link_url = f"{reverse('activity_list')}?batch={activity.batch_id}"
185199

186200
for recipient in recipients.values():
@@ -232,11 +246,14 @@ def _rollup_exists(recipient, batch_id) -> bool:
232246

233247

234248
def _render_batch_payload(
235-
activity: Activity, rows: list[tuple], total: int, successful: int, failed: int, agg_status: str
249+
activity: Activity, rows: list[tuple], total: int, successful: int, failed: int, agg_status: str, usage: dict
236250
) -> tuple[str, str, dict]:
237251
is_schedule = _is_schedule(activity)
238252
ok = failed == 0
239-
repo_ids = sorted({repo for repo, _start, _end in rows if repo})
253+
repo_ids = sorted({repo for repo, _start, _end, _status in rows if repo})
254+
repo_results = [
255+
{"repo": repo, "ok": status == ActivityStatus.SUCCESSFUL} for repo, _start, _end, status in rows if repo
256+
]
240257
name = activity.scheduled_job.name if is_schedule else ""
241258
owner = str(activity.scheduled_job.user) if is_schedule else ""
242259

@@ -277,11 +294,16 @@ def _render_batch_payload(
277294
"trigger_owner": owner,
278295
"repo_id": repo_ids[0] if len(repo_ids) == 1 else "",
279296
"repo_ids": repo_ids,
297+
"repo_results": repo_results,
280298
"total": total,
281299
"successful_count": successful,
282300
"failed_count": failed,
283301
"duration_seconds": _batch_duration(rows),
284302
"batch_id": str(activity.batch_id),
303+
"input_tokens": usage["input_tokens"],
304+
"output_tokens": usage["output_tokens"],
305+
"total_tokens": usage["total_tokens"],
306+
"cost_usd": usage["cost_usd"],
285307
}
286308
return subject, body, context
287309

@@ -297,7 +319,7 @@ def _summarize_repos(repo_ids: list[str], limit: int = 3) -> str:
297319

298320
def _batch_duration(rows: list[tuple]) -> float | None:
299321
"""Wall-clock span from earliest start to latest finish across the batch."""
300-
pairs = [(start, end) for _repo, start, end in rows if start and end]
322+
pairs = [(start, end) for _repo, start, end, _status in rows if start and end]
301323
if not pairs:
302324
return None
303325
earliest = min(start for start, _end in pairs)

tests/unit_tests/notifications/test_signals.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import uuid
3+
from decimal import Decimal
34
from unittest.mock import patch
45

56
from django.utils import timezone
@@ -420,6 +421,44 @@ def test_job_rendered_subject_and_event_type(self, member_user):
420421
assert n.context["trigger_label"]
421422
assert n.context["repo_id"] == "acme/app"
422423

424+
def test_job_context_carries_token_and_cost_usage(self, member_user):
425+
member_user.notify_on_jobs = NotifyOn.ALWAYS
426+
member_user.save(update_fields=["notify_on_jobs"])
427+
428+
activity = Activity.objects.create(
429+
trigger_type=TriggerType.UI_JOB,
430+
user=member_user,
431+
repo_id="acme/app",
432+
status=ActivityStatus.SUCCESSFUL,
433+
input_tokens=12345,
434+
output_tokens=6789,
435+
total_tokens=19134,
436+
cost_usd=Decimal("0.214321"),
437+
)
438+
activity_finished.send(sender=Activity, activity=activity)
439+
440+
n = Notification.objects.get(recipient=member_user, event_type="job.finished")
441+
assert n.context["input_tokens"] == 12345
442+
assert n.context["output_tokens"] == 6789
443+
assert n.context["total_tokens"] == 19134
444+
# JSONField round-trips Decimal as float; renderers receive a number, not a string.
445+
assert n.context["cost_usd"] == pytest.approx(0.214321)
446+
447+
def test_job_context_token_fields_default_to_none_when_unset(self, member_user):
448+
member_user.notify_on_jobs = NotifyOn.ALWAYS
449+
member_user.save(update_fields=["notify_on_jobs"])
450+
451+
activity = Activity.objects.create(
452+
trigger_type=TriggerType.UI_JOB, user=member_user, repo_id="acme/app", status=ActivityStatus.SUCCESSFUL
453+
)
454+
activity_finished.send(sender=Activity, activity=activity)
455+
456+
n = Notification.objects.get(recipient=member_user, event_type="job.finished")
457+
assert n.context["input_tokens"] is None
458+
assert n.context["output_tokens"] is None
459+
assert n.context["total_tokens"] is None
460+
assert n.context["cost_usd"] is None
461+
423462

424463
@pytest.mark.django_db
425464
class TestUserBindingSeeder:
@@ -765,6 +804,62 @@ def test_schedule_batch_mixed_outcomes_renders_count_in_subject(self, member_use
765804
assert schedule.name in rollup.subject
766805
assert "2/3" in rollup.subject
767806

807+
def test_batch_context_carries_repo_results_and_summed_usage(self, member_user):
808+
member_user.notify_on_jobs = NotifyOn.ALWAYS
809+
member_user.save(update_fields=["notify_on_jobs"])
810+
811+
bid = uuid.uuid4()
812+
a = Activity.objects.create(
813+
trigger_type=TriggerType.API_JOB,
814+
user=member_user,
815+
repo_id="acme/api",
816+
status=ActivityStatus.SUCCESSFUL,
817+
batch_id=bid,
818+
notify_on=NotifyOn.ALWAYS,
819+
input_tokens=100,
820+
output_tokens=200,
821+
total_tokens=300,
822+
cost_usd=Decimal("0.10"),
823+
)
824+
b = Activity.objects.create(
825+
trigger_type=TriggerType.API_JOB,
826+
user=member_user,
827+
repo_id="acme/legacy",
828+
status=ActivityStatus.FAILED,
829+
batch_id=bid,
830+
notify_on=NotifyOn.ALWAYS,
831+
input_tokens=50,
832+
output_tokens=75,
833+
total_tokens=125,
834+
cost_usd=Decimal("0.05"),
835+
)
836+
activity_finished.send(sender=Activity, activity=a)
837+
activity_finished.send(sender=Activity, activity=b)
838+
839+
rollup = Notification.objects.get(recipient=member_user, event_type="job_batch.finished")
840+
assert rollup.context["input_tokens"] == 150
841+
assert rollup.context["output_tokens"] == 275
842+
assert rollup.context["total_tokens"] == 425
843+
assert rollup.context["cost_usd"] == pytest.approx(0.15)
844+
845+
# repo_results preserves per-repo outcome so renderers can show ✓/✗ per row.
846+
results_by_repo = {r["repo"]: r["ok"] for r in rollup.context["repo_results"]}
847+
assert results_by_repo == {"acme/api": True, "acme/legacy": False}
848+
849+
def test_batch_context_usage_totals_are_none_when_no_activity_has_usage(self, member_user):
850+
member_user.notify_on_jobs = NotifyOn.ALWAYS
851+
member_user.save(update_fields=["notify_on_jobs"])
852+
853+
a, b = self._make_batch(member_user, statuses=[ActivityStatus.SUCCESSFUL, ActivityStatus.SUCCESSFUL])
854+
activity_finished.send(sender=Activity, activity=a)
855+
activity_finished.send(sender=Activity, activity=b)
856+
857+
rollup = Notification.objects.get(event_type="job_batch.finished")
858+
assert rollup.context["input_tokens"] is None
859+
assert rollup.context["output_tokens"] is None
860+
assert rollup.context["total_tokens"] is None
861+
assert rollup.context["cost_usd"] is None
862+
768863
def test_empty_recipients_on_multi_job_batch_logs_warning(self, caplog):
769864
bid = uuid.uuid4()
770865
a = Activity.objects.create(

0 commit comments

Comments
 (0)