Skip to content

Commit 7529cc1

Browse files
committed
extract pulses and add to recording data
1 parent 81f0e68 commit 7529cc1

3 files changed

Lines changed: 94 additions & 16 deletions

File tree

bats_ai/core/tasks/tasks.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,15 @@ def recording_compute_spectrogram(self, recording_id: int): # noqa: C901, PLR09
228228
pulse_metadata_obj.contours = []
229229
pulse_metadata_obj.save()
230230

231+
from bats_ai.core.utils.batbot_annotations import (
232+
create_pulse_annotations_from_batbot_segments,
233+
)
234+
235+
create_pulse_annotations_from_batbot_segments(
236+
recording,
237+
compressed["segments"],
238+
)
239+
231240
if processing_task:
232241
processing_task.status = ProcessingTask.Status.COMPLETE
233242
processing_task.save()
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from typing import TYPE_CHECKING
5+
6+
from bats_ai.core.models import Annotations, Configuration
7+
8+
if TYPE_CHECKING:
9+
from bats_ai.core.models import Recording
10+
from bats_ai.core.utils.batbot_metadata import BatBotMetadataCurve
11+
12+
logger = logging.getLogger(__name__)
13+
14+
BATBOT_ANNOTATION_MODEL = "batbot"
15+
16+
17+
def _segment_bounds(
18+
segment: BatBotMetadataCurve,
19+
) -> tuple[float, float, float, float] | None:
20+
curve = segment.get("curve_hz_ms") or []
21+
if not curve:
22+
return None
23+
24+
times = [pt[1] for pt in curve]
25+
freqs = [pt[0] for pt in curve]
26+
return min(times), max(times), min(freqs), max(freqs)
27+
28+
29+
def create_pulse_annotations_from_batbot_segments(
30+
recording: Recording,
31+
segments: list[BatBotMetadataCurve],
32+
) -> int:
33+
"""Create pulse annotations from BatBot segments when enabled in Configuration."""
34+
config = Configuration.objects.first()
35+
if not config or not config.create_pulse_annotations_from_batbot:
36+
return 0
37+
38+
Annotations.objects.filter(
39+
recording=recording,
40+
model=BATBOT_ANNOTATION_MODEL,
41+
).delete()
42+
43+
created = 0
44+
for segment in segments:
45+
bounds = _segment_bounds(segment)
46+
if bounds is None:
47+
segment_index = segment.get("segment_index")
48+
logger.warning(
49+
"Skipping BatBot pulse annotation for recording=%s segment_index=%s: no bbox",
50+
recording.pk,
51+
segment_index,
52+
)
53+
continue
54+
55+
t_start, t_end, f_lo, f_hi = bounds
56+
Annotations.objects.create(
57+
recording=recording,
58+
owner=recording.owner,
59+
start_time=t_start,
60+
end_time=t_end,
61+
low_freq=f_lo,
62+
high_freq=f_hi,
63+
type="pulse",
64+
model=BATBOT_ANNOTATION_MODEL,
65+
comments="",
66+
)
67+
created += 1
68+
69+
return created

bats_ai/core/views/recording.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,12 @@ class RecordingPaginatedResponse(Schema):
178178

179179

180180
class AnnotationSchema(Schema):
181-
start_time: int
182-
end_time: int
183-
low_freq: int
184-
high_freq: int
181+
start_time: float
182+
end_time: float
183+
low_freq: float
184+
high_freq: float
185185
species: list[SpeciesSchema]
186-
comments: str
186+
comments: str = ""
187187
type: str | None = None
188188
id: int | None = None
189189
owner_email: str = None
@@ -196,18 +196,18 @@ def from_orm(cls, obj: Annotations, owner_email=None):
196196
low_freq=obj.low_freq,
197197
high_freq=obj.high_freq,
198198
species=[SpeciesSchema.from_orm(species) for species in obj.species.all()],
199-
comments=obj.comments,
199+
comments=obj.comments or "",
200200
id=obj.id,
201201
type=obj.type,
202202
owner_email=owner_email, # Include owner_email in the schema
203203
)
204204

205205

206206
class UpdateAnnotationsSchema(Schema):
207-
start_time: int | None
208-
end_time: int | None
209-
low_freq: int | None
210-
high_freq: int | None
207+
start_time: float | None
208+
end_time: float | None
209+
low_freq: float | None
210+
high_freq: float | None
211211
species: list[SpeciesSchema] | None
212212
comments: str | None
213213
type: str | None
@@ -278,10 +278,10 @@ def linestring_to_list(ls):
278278

279279
class SequenceAnnotationSchema(Schema):
280280
id: int
281-
start_time: int
282-
end_time: int
281+
start_time: float
282+
end_time: float
283283
type: str | None
284-
comments: str
284+
comments: str = ""
285285
species: list[SpeciesSchema] | None
286286
owner_email: str = None
287287

@@ -292,15 +292,15 @@ def from_orm(cls, obj, owner_email=None):
292292
end_time=obj.end_time,
293293
type=obj.type,
294294
species=[SpeciesSchema.from_orm(species) for species in obj.species.all()],
295-
comments=obj.comments,
295+
comments=obj.comments or "",
296296
id=obj.id,
297297
owner_email=owner_email, # Include owner_email in the schema
298298
)
299299

300300

301301
class UpdateSequenceAnnotationSchema(Schema):
302-
start_time: int = None
303-
end_time: int = None
302+
start_time: float = None
303+
end_time: float = None
304304
type: str | None = None
305305
comments: str | None = None
306306

0 commit comments

Comments
 (0)