Skip to content

Commit 917495f

Browse files
authored
Export tag summary (#505)
* unify dev containers and docker compose up * add export task for tags * add endpoint for exporting task summary * front-end for exporting tag summaries
1 parent df1f378 commit 917495f

7 files changed

Lines changed: 309 additions & 7 deletions

File tree

bats_ai/core/tasks/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from __future__ import annotations
2+
3+
# Import task modules so Celery autodiscovery registers decorated tasks.
4+
from . import export_task, periodic, tasks # noqa: F401

bats_ai/core/tasks/export_task.py

Lines changed: 214 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import csv
44
from datetime import timedelta
5-
from io import BytesIO
5+
from io import BytesIO, StringIO
66
import json
77
import zipfile
88

9+
from django.contrib.auth.models import User
910
from django.core.files import File
1011
from django.utils.timezone import now
1112

@@ -14,6 +15,7 @@
1415
Annotations,
1516
ExportedAnnotationFile,
1617
RecordingAnnotation,
18+
RecordingTag,
1719
SequenceAnnotations,
1820
)
1921

@@ -152,3 +154,214 @@ def export_annotations_task(filters: dict, annotation_types: list, export_id: in
152154
export_record.status = "failed"
153155
export_record.save()
154156
raise
157+
158+
159+
@app.task(bind=True)
def export_tag_annotation_summary_task(self, export_id: int):
    """Build the tag annotation summary archive for an export record.

    Collects the per-tag and per-tag/per-user rows, writes them into an
    in-memory zip file, attaches that file to the ``ExportedAnnotationFile``
    identified by ``export_id`` and marks the record as complete. Any
    failure marks the record as failed and re-raises the exception.
    """
    record = ExportedAnnotationFile.objects.get(pk=export_id)
    try:
        summary_rows, per_user_rows = _collect_tag_summary_rows()

        archive = BytesIO()
        with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as bundle:
            _write_tag_exports(bundle, summary_rows, per_user_rows)
        # Rewind so the storage backend reads the archive from its start.
        archive.seek(0)

        record.file.save(
            f"tag-annotation-summary-{export_id}.zip",
            File(archive),
            save=False,
        )
        record.download_url = record.file.url
        record.status = "complete"
        record.expires_at = now() + timedelta(hours=24)
        record.save()
    except Exception:
        # Record the failure on the export row, then let the error propagate.
        record.status = "failed"
        record.save()
        raise
180+
181+
182+
def _collect_tag_summary_rows():
    """Return ``(tag_rows, tag_user_rows)`` covering every recording tag.

    ``tag_rows`` holds one summary dict per tag; ``tag_user_rows`` holds one
    dict per (tag, user) pair, with users ordered by username.
    """
    all_users = list(User.objects.order_by("username").values("id", "username"))
    # Prefetch recordings, their annotations and annotation owners so the
    # per-tag helpers can walk the relations from the prefetch cache.
    tag_queryset = RecordingTag.objects.select_related("user").prefetch_related(
        "recording_set__recordingannotation_set__owner"
    )

    summary_rows = []
    per_user_rows = []
    for tag in tag_queryset:
        summary_row, rows_for_users = _build_rows_for_tag(tag, all_users)
        summary_rows.append(summary_row)
        per_user_rows.extend(rows_for_users)
    return summary_rows, per_user_rows
196+
197+
198+
def _build_rows_for_tag(tag, users):
    """Return ``(tag_row, user_rows)`` for a single tag.

    ``tag_row`` aggregates annotation progress over all of the tag's
    recordings; ``user_rows`` contains one row per entry in ``users``.
    """
    tagged_recordings = list(tag.recording_set.all())
    recording_count = len(tagged_recordings)
    stats_by_user = _group_recording_annotations_by_user(tagged_recordings)
    annotated_ids, submitted_ids, unsubmitted_ids = _collect_total_sets(stats_by_user)
    annotated_count = len(annotated_ids)

    summary_row = dict(
        tag_id=tag.id,
        tag_text=tag.text,
        tag_owner=tag.user.username,
        total_recordings=recording_count,
        annotated_recordings=annotated_count,
        submitted_recordings=len(submitted_ids),
        unsubmitted_recordings=len(unsubmitted_ids),
        remaining_recordings=recording_count - annotated_count,
    )
    rows_for_users = _build_user_rows(tag, recording_count, stats_by_user, users)
    return summary_row, rows_for_users
221+
222+
223+
def _group_recording_annotations_by_user(recordings):
    """Group recording ids by annotation owner.

    Returns a dict keyed by ``owner_id``. Each value carries the owner's
    username plus three sets of recording ids: every recording the owner
    annotated, those with a submitted annotation, and those with an
    unsubmitted annotation.
    """
    stats_by_owner = {}
    for rec in recordings:
        for ann in rec.recordingannotation_set.all():
            bucket = stats_by_owner.get(ann.owner_id)
            if bucket is None:
                # First annotation seen for this owner: start empty buckets.
                bucket = {
                    "username": ann.owner.username,
                    "annotated_recordings": set(),
                    "submitted_recordings": set(),
                    "unsubmitted_recordings": set(),
                }
                stats_by_owner[ann.owner_id] = bucket
            bucket["annotated_recordings"].add(rec.id)
            status_key = "submitted_recordings" if ann.submitted else "unsubmitted_recordings"
            bucket[status_key].add(rec.id)
    return stats_by_owner
242+
243+
244+
def _collect_total_sets(annotations_by_user):
    """Union the per-user recording-id sets across all users.

    Returns ``(annotated, submitted, unsubmitted)`` as three new sets; the
    per-user sets in ``annotations_by_user`` are left untouched.
    """
    per_user = list(annotations_by_user.values())
    annotated = set().union(*(stats["annotated_recordings"] for stats in per_user))
    submitted = set().union(*(stats["submitted_recordings"] for stats in per_user))
    unsubmitted = set().union(*(stats["unsubmitted_recordings"] for stats in per_user))
    return annotated, submitted, unsubmitted
253+
254+
255+
def _build_user_rows(tag, total_recordings, annotations_by_user, users):
    """Build one progress row per user for the given tag.

    ``users`` is an iterable of dicts with ``id`` and ``username`` keys.
    Users with no entry in ``annotations_by_user`` get zero counts, so every
    user appears in the output. The username always comes from ``users``.
    """
    # Stand-in stats for users who have not annotated any recording.
    no_activity = {
        "annotated_recordings": (),
        "submitted_recordings": (),
        "unsubmitted_recordings": (),
    }
    rows = []
    for account in users:
        account_id = account["id"]
        account_name = account["username"]
        stats = annotations_by_user.get(account_id)
        if stats is None:
            stats = no_activity
        annotated = len(stats["annotated_recordings"])
        rows.append(
            dict(
                tag_id=tag.id,
                tag_text=tag.text,
                tag_owner=tag.user.username,
                user_id=account_id,
                username=account_name,
                total_recordings=total_recordings,
                annotated_recordings=annotated,
                submitted_recordings=len(stats["submitted_recordings"]),
                unsubmitted_recordings=len(stats["unsubmitted_recordings"]),
                remaining_recordings=total_recordings - annotated,
            )
        )
    return rows
284+
285+
286+
def _write_tag_exports(zipf, tag_rows, tag_user_rows):
    """Write the tag summary files into the open zip archive ``zipf``.

    Three members are written, in order: ``tag_summary.csv`` (one row per
    tag), ``tag_summary_by_user.csv`` (one row per tag/user pair) and
    ``tag_annotation_summary.json`` (the per-user payload derived from
    ``tag_user_rows``).
    """
    identity_columns = ("tag_id", "tag_text", "tag_owner")
    count_columns = (
        "total_recordings",
        "annotated_recordings",
        "submitted_recordings",
        "unsubmitted_recordings",
        "remaining_recordings",
    )
    summary_columns = [*identity_columns, *count_columns]
    per_user_columns = [*identity_columns, "user_id", "username", *count_columns]

    def render_csv(columns, rows):
        # Serialize the rows to CSV text with a header line.
        out = StringIO()
        writer = csv.DictWriter(out, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)
        return out.getvalue()

    zipf.writestr("tag_summary.csv", render_csv(summary_columns, tag_rows))
    zipf.writestr("tag_summary_by_user.csv", render_csv(per_user_columns, tag_user_rows))

    payload = {"users": _build_users_payload(tag_user_rows)}
    zipf.writestr("tag_annotation_summary.json", json.dumps(payload, indent=2))
334+
335+
336+
def _build_users_payload(tag_user_rows):
    """Regroup the flat tag/user rows into one entry per user.

    Each user entry has ``user_id``, ``username`` and a ``tags`` list. A tag
    entry carries the full set of counts only when the user annotated at
    least one recording for that tag; otherwise it just reports
    ``annotated_recordings`` as 0. The result is sorted by username.
    """
    count_keys = (
        "total_recordings",
        "annotated_recordings",
        "submitted_recordings",
        "unsubmitted_recordings",
        "remaining_recordings",
    )
    entries = {}
    for row in tag_user_rows:
        uid = row["user_id"]
        entry = entries.get(uid)
        if entry is None:
            entry = {"user_id": uid, "username": row["username"], "tags": []}
            entries[uid] = entry

        annotated = row["annotated_recordings"] > 0
        tag_info = {
            "tag_id": row["tag_id"],
            "tag_text": row["tag_text"],
            "tag_owner": row["tag_owner"],
            "has_annotations": annotated,
        }
        if annotated:
            for key in count_keys:
                tag_info[key] = row[key]
        else:
            tag_info["annotated_recordings"] = 0
        entry["tags"].append(tag_info)

    return sorted(entries.values(), key=lambda entry: entry["username"])

bats_ai/core/views/configuration.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import annotations
22

3+
from datetime import timedelta
34
import logging
45

56
from django.http import JsonResponse
7+
from django.utils.timezone import now
68
from ninja import Schema
79
from ninja.pagination import RouterPaginated
810

9-
from bats_ai.core.models import Configuration
11+
from bats_ai.core.models import Configuration, ExportedAnnotationFile
12+
from bats_ai.core.tasks.export_task import export_tag_annotation_summary_task
1013

1114
logger = logging.getLogger(__name__)
1215

@@ -41,9 +44,9 @@ def get_configuration(request):
4144
spectrogram_x_stretch=config.spectrogram_x_stretch,
4245
spectrogram_view=config.spectrogram_view,
4346
default_color_scheme=config.default_color_scheme,
44-
default_spectrogram_background_color=config.default_spectrogram_background_color,
47+
default_spectrogram_background_color=(config.default_spectrogram_background_color),
4548
non_admin_upload_enabled=config.non_admin_upload_enabled,
46-
mark_annotations_completed_enabled=config.mark_annotations_completed_enabled,
49+
mark_annotations_completed_enabled=(config.mark_annotations_completed_enabled),
4750
is_admin=request.user.is_authenticated and request.user.is_superuser,
4851
)
4952

@@ -78,3 +81,21 @@ def get_current_user(request):
7881
"id": request.user.id,
7982
}
8083
return {"email": "", "name": ""}
84+
85+
86+
# Response body of the POST /export-tag-summary endpoint.
class ExportTagSummaryResponse(Schema):
    # Primary key of the ExportedAnnotationFile record created for the export.
    exportId: int
88+
89+
90+
@router.post("/export-tag-summary", response=ExportTagSummaryResponse)
def export_tag_summary(request):
    """Start an asynchronous tag annotation summary export (admins only).

    Creates a pending ``ExportedAnnotationFile`` record, queues the Celery
    task that builds the archive, and returns the record id as ``exportId``.
    Non-superusers receive a 403 JSON response.

    Raises:
        Exception: whatever the task queue raises if the task cannot be
            enqueued; the export record is marked as failed first.
    """
    if not request.user.is_authenticated or not request.user.is_superuser:
        return JsonResponse({"error": "Permission denied"}, status=403)

    export = ExportedAnnotationFile.objects.create(
        filters_applied={"type": "tag_annotation_summary"},
        status="pending",
        expires_at=now() + timedelta(hours=24),
    )
    try:
        export_tag_annotation_summary_task.delay(export.id)
    except Exception:
        # The task never reached the queue, so nothing would ever move this
        # record out of "pending"; mark it failed before propagating.
        export.status = "failed"
        export.save()
        logger.exception("Failed to enqueue tag annotation summary export %s", export.id)
        raise
    return {"exportId": export.id}

bats_ai/core/views/export_annotation.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,22 @@
1515
router = Router()
1616

1717

18+
def _is_tag_annotation_summary_export(export: ExportedAnnotationFile) -> bool:
    """Return True when the export's filters mark it as a tag annotation summary."""
    filters = export.filters_applied
    if not isinstance(filters, dict):
        # Anything other than a dict cannot carry the type marker.
        return False
    return filters.get("type") == "tag_annotation_summary"
24+
25+
26+
def _can_access_export(request, export: ExportedAnnotationFile) -> bool:
    """Return whether the requesting user may see or modify ``export``."""
    if not _is_tag_annotation_summary_export(export):
        return True
    # Tag annotation summary exports include user-level aggregate stats,
    # so only admins can access them.
    user = request.user
    return user.is_authenticated and user.is_superuser
32+
33+
1834
class ExportedAnnotationFileSchema(BaseModel):
1935
id: int
2036
status: str
@@ -37,16 +53,19 @@ def list_exports(request):
3753
expiresAt=e.expires_at,
3854
)
3955
for e in exports
56+
if _can_access_export(request, e)
4057
]
4158

4259

4360
@router.get("/{export_id}", response=ExportedAnnotationFileSchema)
def get_export_status(request, export_id: int):
    # Look up the export (404 when missing) and refuse access when the
    # requester is not allowed to see it.
    record = get_object_or_404(ExportedAnnotationFile, pk=export_id)
    if not _can_access_export(request, record):
        return JsonResponse({"error": "Permission denied"}, status=403)

    # The download link is only exposed once the export has completed.
    download_url = None
    if record.status == "complete":
        download_url = record.download_url

    return ExportedAnnotationFileSchema(
        id=record.id,
        status=record.status,
        downloadUrl=download_url,
        created=record.created,
        expiresAt=record.expires_at,
    )
@@ -55,6 +74,8 @@ def get_export_status(request, export_id: int):
5574
@router.delete("/{export_id}")
5675
def delete_export(request, export_id: int):
5776
export = get_object_or_404(ExportedAnnotationFile, pk=export_id)
77+
if not _can_access_export(request, export):
78+
return JsonResponse({"error": "Permission denied"}, status=403)
5879

5980
# Optional: block deleting exports still in progress
6081
if export.status not in ("complete", "failed", "expired"):

client/src/api/api.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,13 +769,24 @@ export interface ExportStatus {
769769
expiresAt: string;
770770
}
771771

772+
/** Response body of the POST `/configuration/export-tag-summary` request. */
export interface ExportTagSummaryResponse {
  /** Id of the export record; pass it to `getExportStatus` to check progress. */
  exportId: number;
}
775+
772776
/** Fetch the current status of the export identified by `exportId`. */
async function getExportStatus(exportId: number) {
  const { data } = await axiosInstance.get<ExportStatus>(
    `/export-annotation/${exportId}`,
  );
  return data;
}
778782

783+
/** Request a new tag summary export and return the server's response body. */
async function exportTagSummary(): Promise<ExportTagSummaryResponse> {
  const { data } = await axiosInstance.post<ExportTagSummaryResponse>(
    "/configuration/export-tag-summary",
  );
  return data;
}
789+
779790
export interface VettingDetails {
780791
id: number;
781792
user_id: number;
@@ -899,6 +910,7 @@ export {
899910
getFilteredProcessingTasks,
900911
getFileAnnotationDetails,
901912
getExportStatus,
913+
exportTagSummary,
902914
getRecordingTags,
903915
getUnsubmittedNeighbors,
904916
getComputedPulseContour,

0 commit comments

Comments
 (0)