|
4 | 4 | from datetime import timedelta |
5 | 5 | from io import BytesIO, StringIO |
6 | 6 | import json |
| 7 | +from urllib.parse import urljoin |
7 | 8 | import zipfile |
8 | 9 |
|
| 10 | +from django.conf import settings |
9 | 11 | from django.contrib.auth.models import User |
10 | 12 | from django.core.files import File |
| 13 | +from django.core.files.storage import default_storage |
| 14 | +from django.db.models import Prefetch |
11 | 15 | from django.utils.timezone import now |
12 | 16 |
|
13 | 17 | from bats_ai.celery import app |
|
18 | 22 | RecordingTag, |
19 | 23 | SequenceAnnotations, |
20 | 24 | ) |
| 25 | +from bats_ai.core.models.recording_annotation import RecordingAnnotationSpecies |
| 26 | + |
# Version stamped into the export manifest as "schema_version".
RECORDING_ANNOTATION_EXPORT_SCHEMA_VERSION = 1

# Column order of the flat CSV export. Each flat row merges the recording
# metadata dict with the per-annotation entry dict, so every name below must
# match a key produced by one of those two helpers.
RECORDING_ANNOTATION_FLAT_FIELDNAMES = (
    # Recording-level identity columns (from _recording_export_metadata).
    ["recording_id", "filename", "grts_cell_id", "sample_frame_id"]
    # Annotation-level columns (from _recording_annotation_entry_dict).
    + ["id", "owner", "comments", "created", "model"]
    + ["species", "species_codes", "confidence", "submitted", "additional_data"]
    # Recording-level link columns (from _recording_export_metadata).
    + ["spectrogram_url", "wav_download_url"]
)
21 | 47 |
|
22 | 48 |
|
23 | 49 | def build_filters(filters, *, has_confidence=False): |
@@ -179,6 +205,175 @@ def export_tag_annotation_summary_task(self, export_id: int): |
179 | 205 | raise |
180 | 206 |
|
181 | 207 |
|
@app.task(bind=True)
def export_recording_annotation_hierarchy_task(self, export_id: int):
    """Build the recording-annotation ZIP archive and attach it to an export record.

    Args:
        export_id: Primary key of the ``ExportedAnnotationFile`` to populate.

    On success the archive is saved to the record's file field, the download
    URL is copied from that file, the status becomes ``"complete"`` and the
    record expires 24 hours from now. If anything inside the build fails, the
    record is saved with status ``"failed"`` and the exception is re-raised.
    The record lookup happens before the ``try``, so a failed lookup
    propagates without any status update.
    """
    export_record = ExportedAnnotationFile.objects.get(pk=export_id)
    try:
        recordings, rows, manifest = _collect_recording_annotations_export()

        archive = BytesIO()
        with zipfile.ZipFile(archive, mode="w", compression=zipfile.ZIP_DEFLATED) as zipf:
            _write_recording_annotations_zip(zipf, recordings, rows, manifest)
        # Rewind so the storage backend reads the archive from its first byte.
        archive.seek(0)

        # save=False: the model row is written once, below, with all fields set.
        export_record.file.save(
            f"recording-annotations-{export_id}.zip",
            File(archive),
            save=False,
        )
        export_record.download_url = export_record.file.url
        export_record.status = "complete"
        export_record.expires_at = now() + timedelta(hours=24)
        export_record.save()
    except Exception:
        # Record the failure on the export row, then let the task machinery see the error.
        export_record.status = "failed"
        export_record.save()
        raise
| 234 | + |
| 235 | + |
| 236 | +def _recording_species_lists(annotation): |
| 237 | + common_names = [] |
| 238 | + species_codes = [] |
| 239 | + for species_link in annotation.ordered_species_links: |
| 240 | + species = species_link.species |
| 241 | + common_names.append(species.common_name) |
| 242 | + species_codes.append(species.species_code) |
| 243 | + return common_names, species_codes |
| 244 | + |
| 245 | + |
| 246 | +def _recording_annotation_entry_dict(annotation, species, species_codes): |
| 247 | + return { |
| 248 | + "id": annotation.id, |
| 249 | + "owner": annotation.owner.username, |
| 250 | + "comments": annotation.comments, |
| 251 | + "created": annotation.created.isoformat(), |
| 252 | + "model": annotation.model, |
| 253 | + "species": species, |
| 254 | + "species_codes": species_codes, |
| 255 | + "confidence": annotation.confidence, |
| 256 | + "additional_data": annotation.additional_data, |
| 257 | + "submitted": annotation.submitted, |
| 258 | + } |
| 259 | + |
| 260 | + |
def _recording_export_metadata(recording):
    """Return the recording-level fields shared by every exported annotation row.

    The result carries the recording's id, name, GRTS cell id and sample frame
    id, a spectrogram link built from ``settings.BATAI_WEB_URL``, and a storage
    URL for the audio file (``None`` when the recording has no audio file).
    """
    # NOTE(review): the path starts with "/", so urljoin replaces any path
    # component present on BATAI_WEB_URL - confirm the site is served from the root.
    spectrogram_url = urljoin(
        settings.BATAI_WEB_URL,
        f"/recording/{recording.id}/spectrogram",
    )

    if recording.audio_file:
        wav_download_url = default_storage.url(recording.audio_file.name)
    else:
        wav_download_url = None

    return {
        "recording_id": recording.id,
        "filename": recording.name,
        "grts_cell_id": recording.grts_cell_id,
        "sample_frame_id": recording.sample_frame_id,
        "spectrogram_url": spectrogram_url,
        "wav_download_url": wav_download_url,
    }
| 275 | + |
| 276 | + |
def _collect_recording_annotations_export():
    """Gather every recording annotation into hierarchical and flat export forms.

    Annotations are read ordered by ``(recording_id, id)`` with their recording,
    owner and ordered species links loaded up front.

    Returns:
        A ``(recordings_payload, flat_rows, manifest)`` tuple:

        * ``recordings_payload``: one dict per recording that has at least one
          annotation, sorted by ``recording_id``. Each holds the recording
          metadata, submitted/unsubmitted counters and its ``annotations`` list.
        * ``flat_rows``: one dict per annotation, merging the recording metadata
          with the annotation entry.
        * ``manifest``: export type, schema version, export timestamp and the
          overall recording/annotation counts.
    """
    species_links_prefetch = Prefetch(
        "recordingannotationspecies_set",
        queryset=RecordingAnnotationSpecies.objects.select_related("species").order_by("order"),
        to_attr="ordered_species_links",
    )
    annotations = (
        RecordingAnnotation.objects.select_related("recording", "owner")
        .prefetch_related(species_links_prefetch)
        .order_by("recording_id", "id")
    )

    recordings_by_id = {}
    # Metadata depends only on the recording, so it is built once per recording
    # and reused for every flat row instead of being recomputed per annotation
    # (each computation includes a storage URL lookup). Reuse also guarantees
    # the flat and hierarchical views carry identical URLs.
    metadata_by_recording_id = {}
    flat_rows = []
    submitted_annotation_count = 0
    unsubmitted_annotation_count = 0

    for annotation in annotations:
        recording = annotation.recording
        recording_entry = recordings_by_id.get(recording.id)
        if recording_entry is None:
            recording_metadata = _recording_export_metadata(recording)
            metadata_by_recording_id[recording.id] = recording_metadata
            recording_entry = {
                **recording_metadata,
                "submitted_annotations": 0,
                "unsubmitted_annotations": 0,
                "annotations": [],
            }
            recordings_by_id[recording.id] = recording_entry
        else:
            recording_metadata = metadata_by_recording_id[recording.id]

        if annotation.submitted:
            recording_entry["submitted_annotations"] += 1
            submitted_annotation_count += 1
        else:
            recording_entry["unsubmitted_annotations"] += 1
            unsubmitted_annotation_count += 1

        species, species_codes = _recording_species_lists(annotation)
        annotation_entry = _recording_annotation_entry_dict(
            annotation,
            species,
            species_codes,
        )
        recording_entry["annotations"].append(annotation_entry)
        # A fresh dict per row: the shared metadata dict is never mutated.
        flat_rows.append({**recording_metadata, **annotation_entry})

    recordings_payload = sorted(
        recordings_by_id.values(),
        key=lambda recording: recording["recording_id"],
    )
    annotation_count = submitted_annotation_count + unsubmitted_annotation_count
    manifest = {
        "export_type": "recording_annotation_hierarchy",
        "schema_version": RECORDING_ANNOTATION_EXPORT_SCHEMA_VERSION,
        "exported_at": now().isoformat(),
        "recording_count": len(recordings_by_id),
        "annotation_count": annotation_count,
        "submitted_annotation_count": submitted_annotation_count,
        "unsubmitted_annotation_count": unsubmitted_annotation_count,
    }
    return recordings_payload, flat_rows, manifest
| 343 | + |
| 344 | + |
def _flat_row_for_csv(row):
    """Project a flat export row onto the CSV columns with CSV-friendly values.

    Only keys listed in ``RECORDING_ANNOTATION_FLAT_FIELDNAMES`` are kept; keys
    missing from ``row`` become ``None``. List values under ``species`` and
    ``species_codes`` are joined with ``"|"`` (``None`` items become empty
    strings), and a non-``None`` ``additional_data`` is JSON-encoded. The input
    row is not modified.
    """
    pipe_joined_columns = ("species", "species_codes")
    csv_row = {}
    for column in RECORDING_ANNOTATION_FLAT_FIELDNAMES:
        cell = row.get(column)
        if column in pipe_joined_columns and isinstance(cell, list):
            cell = "|".join("" if item is None else str(item) for item in cell)
        elif column == "additional_data" and cell is not None:
            cell = json.dumps(cell)
        csv_row[column] = cell
    return csv_row
| 354 | + |
| 355 | + |
| 356 | +def _write_recording_annotations_zip(zipf, recordings_payload, flat_rows, manifest): |
| 357 | + zipf.writestr("export_manifest.json", json.dumps(manifest)) |
| 358 | + zipf.writestr( |
| 359 | + "recording_annotations.json", |
| 360 | + json.dumps({"recordings": recordings_payload}), |
| 361 | + ) |
| 362 | + if not flat_rows: |
| 363 | + return |
| 364 | + |
| 365 | + zipf.writestr( |
| 366 | + "recording_annotations_flat.json", |
| 367 | + json.dumps(flat_rows), |
| 368 | + ) |
| 369 | + csv_buf = StringIO() |
| 370 | + writer = csv.DictWriter(csv_buf, fieldnames=RECORDING_ANNOTATION_FLAT_FIELDNAMES) |
| 371 | + writer.writeheader() |
| 372 | + for row in flat_rows: |
| 373 | + writer.writerow(_flat_row_for_csv(row)) |
| 374 | + zipf.writestr("recording_annotations_flat.csv", csv_buf.getvalue()) |
| 375 | + |
| 376 | + |
182 | 377 | def _collect_tag_summary_rows(): |
183 | 378 | tag_rows = [] |
184 | 379 | tag_user_rows = [] |
|
0 commit comments