Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/aind_behavior_vr_foraging/data_mappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class DataMapperCli(BaseSettings, cli_kebab_case=True):

def cli_cmd(self):
"""Generate aind-data-schema metadata for the VR Foraging dataset located at the specified path."""
from ._rig import AindInstrumentDataMapper
from ._session import AindAcquisitionDataMapper
from ._acquisition import AindAcquisitionDataMapper
from ._instrument import AindInstrumentDataMapper

session_mapper = AindAcquisitionDataMapper(
data_path=Path(self.data_path),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _map(self) -> acquisition.Acquisition:
acquisition_end_time=utcnow(),
acquisition_start_time=self.session_model.date,
experimenters=self.session_model.experimenter,
acquisition_type=self.session_model.experiment or self.task_model.name,
acquisition_type=self.task_model.name,
coordinate_system=None,
data_streams=self._get_data_streams(),
calibrations=self._get_calibrations(),
Expand Down Expand Up @@ -170,11 +170,20 @@ def _get_data_streams(self) -> List[acquisition.DataStream]:
if _device[0] is not None and self._include_device(_device[1])
]

code = [self._get_bonsai_as_code(), self._get_python_as_code()]
if (
self.curriculum_suggestion is not None
and self.curriculum_suggestion.trainer_state is not None
and self.curriculum_suggestion.trainer_state.curriculum is not None
and self.curriculum_suggestion.trainer_state.is_on_curriculum is True
):
code.append(self._get_curriculum_as_code())

data_streams: list[acquisition.DataStream] = [
acquisition.DataStream(
stream_start_time=self.session_model.date,
stream_end_time=self.session_end_time,
code=[self._get_bonsai_as_code(), self._get_python_as_code()],
code=code,
active_devices=active_devices,
modalities=modalities,
configurations=self._get_cameras_config(),
Expand Down Expand Up @@ -257,7 +266,8 @@ def _get_stimulus_epochs(self) -> List[acquisition.StimulusEpoch]:

# Animal performance, curriculum, and metrics
performance_metrics: Optional[acquisition.PerformanceMetrics] = None
curriculum_status: str = "false"
curriculum_status: Optional[str] = None
training_protocol_name: Optional[str] = None

if self.curriculum_suggestion is not None:
logger.debug("Curriculum suggestion found. Setting performance metrics based on curriculum suggestion.")
Expand All @@ -268,7 +278,10 @@ def _get_stimulus_epochs(self) -> List[acquisition.StimulusEpoch]:
)
if self.trainer_state is not None:
logger.debug("Trainer state found. Setting curriculum status based on trainer state.")
curriculum_status = str(self.trainer_state.is_on_curriculum)
if self.trainer_state.stage is not None:
curriculum_status = str(self.trainer_state.stage.name)
if self.trainer_state.curriculum is not None:
training_protocol_name = str(self.trainer_state.curriculum.name)

stimulus_epochs: list[acquisition.StimulusEpoch] = [
acquisition.StimulusEpoch(
Expand All @@ -277,15 +290,16 @@ def _get_stimulus_epochs(self) -> List[acquisition.StimulusEpoch]:
stimulus_start_time=self.session_model.date,
stimulus_end_time=self.session_end_time,
configurations=stimulus_epoch_configurations,
stimulus_name=self.session_model.experiment or self.task_model.name,
stimulus_name=self.task_model.name,
stimulus_modalities=stimulus_modalities,
performance_metrics=performance_metrics,
curriculum_status=curriculum_status,
training_protocol_name=training_protocol_name,
)
]
return stimulus_epochs

def _get_cameras_config(self) -> List[acquisition.DetectorConfig]:
def _get_cameras_config(self) -> list[acquisition.DetectorConfig]:
def _map_camera(name: str, camera: cameras.CameraTypes) -> acquisition.DetectorConfig:
assert camera.video_writer is not None, "Camera does not have a video writer configured."
return acquisition.DetectorConfig(
Expand Down Expand Up @@ -326,6 +340,8 @@ def _get_bonsai_as_code(self) -> acquisition.Code:
url=self.repository.remote().url,
name="Aind.Behavior.VrForaging",
version=self.repository.head.commit.hexsha,
# version=__semver__, # TODO slot this in when this is solved https://github.com/AllenNeuralDynamics/aind-data-schema/issues/1789
# sha=self.repository.head.commit.hexsha,
language="Bonsai",
language_version=bonsai_version,
run_script=Path(self.bonsai_app.workflow),
Expand All @@ -345,3 +361,32 @@ def _get_python_as_code(self) -> acquisition.Code:
language="Python",
language_version=semver,
)

def _get_curriculum_as_code(self) -> acquisition.Code:
    """Describe the session's curriculum as an ``acquisition.Code`` entry.

    The URL comes from the git submodule registered at ``plugins/curricula``
    inside ``self.repository``. The package location, curriculum version and
    DSL version are read from ``self.curriculum_suggestion``.

    Returns:
        The ``acquisition.Code`` record built from the submodule and the suggestion.

    Raises:
        ValueError: If no submodule is registered at ``plugins/curricula``, if
            no curriculum suggestion is set, or if the suggestion has no
            trainer state or no curriculum.
    """
    target = Path("plugins/curricula")
    # First submodule whose path matches the target, or None when there is none.
    submodule: Optional[git.Submodule] = next(
        (candidate for candidate in self.repository.submodules if Path(candidate.path) == target),
        None,
    )
    if submodule is None:
        raise ValueError(
            f"Could not find a git submodule at '{target}' inside repository '{self.repository.working_tree_dir}'."
        )

    suggestion = self.curriculum_suggestion
    if suggestion is None:
        raise ValueError("Curriculum suggestion is not set.")

    trainer_state = suggestion.trainer_state
    curriculum = None if trainer_state is None else trainer_state.curriculum
    if curriculum is None:
        raise ValueError("Trainer state or curriculum is not set in the curriculum suggestion.")

    return acquisition.Code(
        url=submodule.url,
        # TODO: also pass sha=submodule.hexsha, see
        # https://github.com/AllenNeuralDynamics/aind-data-schema/issues/1789
        name=curriculum.pkg_location,
        version=curriculum.version,
        language="aind-behavior-curriculum",
        language_version=suggestion.dsl_version,
    )
1 change: 1 addition & 0 deletions src/aind_behavior_vr_foraging/task_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,7 @@ class OperationControl(BaseModel):
wait_to_start_duration: float = Field(default=0, ge=0, description="Duration to wait before starting the task")
wait_to_finish_duration: float = Field(default=0, ge=0, description="Duration to wait after finishing the task")


# ==================== BLOCK END CONDITIONS ====================


Expand Down
165 changes: 161 additions & 4 deletions tests/test_aind_data_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
from pathlib import Path
from unittest.mock import MagicMock, patch

import aind_behavior_curriculum
from aind_behavior_curriculum import Metrics, Stage, Trainer, create_curriculum
from aind_data_schema.core import acquisition, instrument
from aind_data_schema.utils import compatibility_check
from clabe.apps import CurriculumSuggestion

from aind_behavior_vr_foraging.data_mappers._rig import AindInstrumentDataMapper
from aind_behavior_vr_foraging.data_mappers._session import AindAcquisitionDataMapper
from aind_behavior_vr_foraging.data_mappers._acquisition import AindAcquisitionDataMapper
from aind_behavior_vr_foraging.data_mappers._instrument import AindInstrumentDataMapper

sys.path.append(".")
from aind_behavior_vr_foraging.cli import DataMapperCli
Expand Down Expand Up @@ -53,7 +56,7 @@ def setUp(self):
def tearDown(self):
self.temp_dir.cleanup()

@patch("aind_behavior_vr_foraging.data_mappers._session.AindAcquisitionDataMapper._map")
@patch("aind_behavior_vr_foraging.data_mappers._acquisition.AindAcquisitionDataMapper._map")
def test_session_mock_map(self, mock_map):
mock_map.return_value = MagicMock()
result = self.session_mapper.map()
Expand All @@ -69,7 +72,7 @@ def test_session_round_trip(self):
assert mapped is not None
acquisition.Acquisition.model_validate_json(mapped.model_dump_json())

@patch("aind_behavior_vr_foraging.data_mappers._rig.AindInstrumentDataMapper._map")
@patch("aind_behavior_vr_foraging.data_mappers._instrument.AindInstrumentDataMapper._map")
def test_rig_mock_map(self, mock_map):
mock_map.return_value = MagicMock()
result = self.rig_mapper.map()
Expand Down Expand Up @@ -107,5 +110,159 @@ def test_mapper_cli(self):
self.assertTrue(acquisition_path.exists())


def _make_curriculum_suggestion() -> CurriculumSuggestion:
    """Build a small in-memory ``CurriculumSuggestion`` for the mapper tests.

    A one-stage curriculum named ``DemoCurriculum`` is created around the
    module-level ``task_logic`` fixture, a trainer state is generated for that
    stage with ``is_on_curriculum=True``, and default-valued demo metrics are
    attached. Nothing is read from recorded session data.
    """

    class _DemoMetrics(Metrics):
        reward_rate: float = 0.75
        trials_count: int = 100

    demo_curriculum_type = create_curriculum(
        "DemoCurriculum",
        "0.0.0",
        (task_logic.__class__,),
        pkg_location="demo.curriculum",
    )
    demo_curriculum = demo_curriculum_type()
    demo_stage = Stage(name="demo_stage", task=task_logic)
    demo_curriculum.add_stage(demo_stage)

    # The trainer state pins the single demo stage, with no active policies.
    demo_trainer_state = Trainer(demo_curriculum).create_trainer_state(
        stage=demo_stage,
        is_on_curriculum=True,
        active_policies=(),
    )

    return CurriculumSuggestion(
        trainer_state=demo_trainer_state,
        metrics=_DemoMetrics(),
        version="0.0.0",
        dsl_version=aind_behavior_curriculum.__version__,
    )


class TestCurriculumIntegrationInDataMapper(unittest.TestCase):
    """Check that ``AindAcquisitionDataMapper`` carries curriculum data into the mapped acquisition."""

    def setUp(self):
        """Create a temporary dataset folder holding the input models and a trainer state."""
        self.temp_dir = tempfile.TemporaryDirectory()
        # Registered right away: cleanups run even if the rest of setUp raises,
        # whereas tearDown would be skipped in that case and leak the directory.
        self.addCleanup(self.temp_dir.cleanup)
        self.data_path = Path(self.temp_dir.name)

        logs_dir = self.data_path / "Behavior" / "Logs"
        logs_dir.mkdir(parents=True, exist_ok=True)

        # Input models the mapper reads back from the logs folder.
        input_models = {
            "session_input.json": session,
            "rig_input.json": rig,
            "tasklogic_input.json": task_logic,
        }
        for file_name, model in input_models.items():
            with open(logs_dir / file_name, "w", encoding="utf-8") as f:
                json.dump(model.model_dump(mode="json"), f, indent=2)

        self.curriculum_suggestion = _make_curriculum_suggestion()

        # Write trainer_state.json so the mapper picks up curriculum_status / training_protocol_name.
        trainer_state_path = self.data_path / "Behavior" / "trainer_state.json"
        trainer_state_path.write_text(self.curriculum_suggestion.trainer_state.model_dump_json(), encoding="utf-8")

        self.repo_path = Path("./")
        self.session_end_time = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

    def _make_mapper(self, curriculum_suggestion=None) -> AindAcquisitionDataMapper:
        """Return a mapper over the temporary dataset with the given curriculum suggestion."""
        return AindAcquisitionDataMapper(
            data_path=self.data_path,
            repo_path=self.repo_path,
            session_end_time=self.session_end_time,
            curriculum_suggestion=curriculum_suggestion,
        )

    def test_curriculum_suggestion_instance_is_accepted(self):
        """Passing a CurriculumSuggestion instance is stored on the mapper."""
        mapper = self._make_mapper(self.curriculum_suggestion)
        self.assertIs(mapper.curriculum_suggestion, self.curriculum_suggestion)

    def test_curriculum_suggestion_from_json_file(self):
        """Passing a file path loads the suggestion via JSON round-trip."""
        suggestion_path = self.data_path / "suggestion.json"
        suggestion_path.write_text(self.curriculum_suggestion.model_dump_json(), encoding="utf-8")

        mapper = self._make_mapper(suggestion_path)
        self.assertIsNotNone(mapper.curriculum_suggestion)
        self.assertEqual(
            mapper.curriculum_suggestion.trainer_state.curriculum.name,
            self.curriculum_suggestion.trainer_state.curriculum.name,
        )

    def test_stimulus_epoch_has_performance_metrics(self):
        """Performance metrics from the suggestion appear in the mapped stimulus epoch."""
        mapped = self._make_mapper(self.curriculum_suggestion).map()
        epoch = mapped.stimulus_epochs[0]
        self.assertIsNotNone(epoch.performance_metrics)
        output_params = epoch.performance_metrics.output_parameters.model_dump()
        self.assertIn("reward_rate", output_params)
        self.assertAlmostEqual(output_params["reward_rate"], 0.75)

    def test_stimulus_epoch_has_curriculum_status(self):
        """curriculum_status in the stimulus epoch matches the stage name."""
        mapped = self._make_mapper(self.curriculum_suggestion).map()
        epoch = mapped.stimulus_epochs[0]
        self.assertEqual(epoch.curriculum_status, "demo_stage")

    def test_stimulus_epoch_has_training_protocol_name(self):
        """training_protocol_name in the stimulus epoch matches the curriculum name."""
        mapped = self._make_mapper(self.curriculum_suggestion).map()
        epoch = mapped.stimulus_epochs[0]
        self.assertEqual(epoch.training_protocol_name, "DemoCurriculum")

    def test_data_stream_includes_curriculum_code(self):
        """When on curriculum, the data stream code list includes the curriculum entry."""
        mapped = self._make_mapper(self.curriculum_suggestion).map()
        stream = mapped.data_streams[0]
        code_names = [c.name for c in stream.code if c.name is not None]
        self.assertIn("demo.curriculum", code_names)

    def test_curriculum_code_metadata(self):
        """Curriculum Code entry carries the expected metadata from the submodule and suggestion."""
        import git

        # Compare as Path objects, the same way the mapper matches the submodule.
        curricula_path = Path("plugins/curricula")
        repo = git.Repo(self.repo_path)
        submodule = next((sub for sub in repo.submodules if Path(sub.path) == curricula_path), None)
        # Fail with a readable message instead of a bare StopIteration error.
        self.assertIsNotNone(submodule, f"No git submodule registered at '{curricula_path}'.")

        mapped = self._make_mapper(self.curriculum_suggestion).map()
        stream = mapped.data_streams[0]
        curriculum_code = next((c for c in stream.code if c.name == "demo.curriculum"), None)
        self.assertIsNotNone(curriculum_code, "Data stream has no code entry named 'demo.curriculum'.")

        self.assertEqual(curriculum_code.url, submodule.url)
        self.assertEqual(
            curriculum_code.version,
            self.curriculum_suggestion.trainer_state.curriculum.version,
        )
        self.assertEqual(curriculum_code.language, "aind-behavior-curriculum")
        self.assertEqual(curriculum_code.language_version, self.curriculum_suggestion.dsl_version)

    def test_no_curriculum_suggestion_omits_curriculum_fields(self):
        """Without a suggestion, performance_metrics is absent.

        Only performance_metrics is asserted here. As the original note on this
        test states, curriculum_status and training_protocol_name are sourced
        from trainer_state.json, which setUp always writes, so they do not
        depend on the curriculum_suggestion argument.
        """
        mapped = self._make_mapper(curriculum_suggestion=None).map()
        epoch = mapped.stimulus_epochs[0]
        self.assertIsNone(epoch.performance_metrics)

    def test_acquisition_round_trip_with_curriculum(self):
        """Mapped acquisition with curriculum data survives a JSON round-trip."""
        mapped = self._make_mapper(self.curriculum_suggestion).map()
        acquisition.Acquisition.model_validate_json(mapped.model_dump_json())


if __name__ == "__main__":
unittest.main()
Loading