Skip to content

Commit 1f43a8d

Browse files
Add deterministic reversals curriculum (#545)
* feat: add deterministic reversal curricula Use MarkovEnvironment (renamed from EnvironmentStatistics) after PR #547 * refactor: rename reward_locked to reward_capped * Deduplicate stage definition * Add curriculum definition * fix: correct the available functions to avoid double depletion * refactor: remove cap_reward variable and utilize reward available instead * linting * Revert "linting" This reverts commit 38ab1e4. * Revert "refactor: remove cap_reward variable and utilize reward available instead" This reverts commit 130f293. * Revert "fix: correct the available functions to avoid double depletion" This reverts commit e3148eb. --------- Co-authored-by: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com>
1 parent 38cadaf commit 1f43a8d

10 files changed

Lines changed: 5029 additions & 0 deletions

File tree

schema/deterministic_reversals.json

Lines changed: 2312 additions & 0 deletions
Large diffs are not rendered by default.

schema/deterministic_reversals_reward_capped.json

Lines changed: 2352 additions & 0 deletions
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .curriculum import CURRICULUM, CURRICULUM_NAME, PKG_LOCATION, TRAINER, run_curriculum
2+
3+
__all__ = [
4+
"CURRICULUM_NAME",
5+
"CURRICULUM",
6+
"TRAINER",
7+
"run_curriculum",
8+
"PKG_LOCATION",
9+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from typing import Any, Callable, Type
2+
3+
import aind_behavior_curriculum
4+
from aind_behavior_curriculum import Stage, StageTransition, Trainer, TrainerState, create_curriculum
5+
from aind_behavior_vr_foraging.task_logic import AindVrForagingTaskLogic
6+
7+
from .. import __semver__
8+
from ..cli import CurriculumCliArgs, CurriculumSuggestion
9+
from ..depletion.curriculum import (
10+
metrics_from_dataset_path,
11+
st_s_stage_all_odors_rewarded_s_stage_graduation,
12+
st_s_stage_one_odor_no_depletion_s_stage_one_odor_w_depletion_day_0,
13+
st_s_stage_one_odor_w_depletion_day_0_s_stage_all_odors_rewarded,
14+
st_s_stage_one_odor_w_depletion_day_0_s_stage_one_odor_w_depletion_day_1,
15+
st_s_stage_one_odor_w_depletion_day_1_s_stage_all_odors_rewarded,
16+
st_s_stage_one_odor_w_depletion_day_1_s_stage_one_odor_w_depletion_day_0,
17+
trainer_state_from_file,
18+
)
19+
from ..depletion.stages import (
20+
make_s_stage_one_odor_no_depletion,
21+
make_s_stage_one_odor_w_depletion_day_0,
22+
make_s_stage_one_odor_w_depletion_day_1,
23+
)
24+
25+
26+
def build_deterministic_reversal_curriculum(
27+
curriculum_name: str,
28+
pkg_location: str,
29+
make_all_odors_rewarded: Callable[[], Stage],
30+
make_graduation: Callable[[], Stage],
31+
) -> tuple[
32+
aind_behavior_curriculum.Curriculum,
33+
Trainer,
34+
Callable[[CurriculumCliArgs], CurriculumSuggestion[TrainerState[Any], Any]],
35+
]:
36+
curriculum_class: Type[aind_behavior_curriculum.Curriculum[AindVrForagingTaskLogic]] = create_curriculum(
37+
curriculum_name, __semver__, (AindVrForagingTaskLogic,), pkg_location=pkg_location
38+
)
39+
curriculum = curriculum_class()
40+
41+
curriculum.add_stage_transition(
42+
make_s_stage_one_odor_no_depletion(),
43+
make_s_stage_one_odor_w_depletion_day_0(),
44+
StageTransition(st_s_stage_one_odor_no_depletion_s_stage_one_odor_w_depletion_day_0),
45+
)
46+
curriculum.add_stage_transition(
47+
make_s_stage_one_odor_w_depletion_day_0(),
48+
make_s_stage_one_odor_w_depletion_day_1(),
49+
StageTransition(st_s_stage_one_odor_w_depletion_day_0_s_stage_one_odor_w_depletion_day_1),
50+
)
51+
curriculum.add_stage_transition(
52+
make_s_stage_one_odor_w_depletion_day_1(),
53+
make_s_stage_one_odor_w_depletion_day_0(),
54+
StageTransition(st_s_stage_one_odor_w_depletion_day_1_s_stage_one_odor_w_depletion_day_0),
55+
)
56+
curriculum.add_stage_transition(
57+
make_s_stage_one_odor_w_depletion_day_1(),
58+
make_all_odors_rewarded(),
59+
StageTransition(st_s_stage_one_odor_w_depletion_day_1_s_stage_all_odors_rewarded),
60+
)
61+
curriculum.add_stage_transition(
62+
make_s_stage_one_odor_w_depletion_day_0(),
63+
make_all_odors_rewarded(),
64+
StageTransition(st_s_stage_one_odor_w_depletion_day_0_s_stage_all_odors_rewarded),
65+
)
66+
curriculum.add_stage_transition(
67+
make_all_odors_rewarded(),
68+
make_graduation(),
69+
StageTransition(st_s_stage_all_odors_rewarded_s_stage_graduation),
70+
)
71+
72+
trainer = Trainer(curriculum)
73+
74+
def run_curriculum(args: CurriculumCliArgs) -> CurriculumSuggestion[TrainerState[Any], Any]:
75+
metrics: aind_behavior_curriculum.Metrics
76+
trainer_state = trainer_state_from_file(args.input_trainer_state, trainer)
77+
metrics = metrics_from_dataset_path(args.data_directory, trainer_state)
78+
trainer_state = trainer.evaluate(trainer_state, metrics)
79+
return CurriculumSuggestion(trainer_state=trainer_state, metrics=metrics, version=__semver__)
80+
81+
return curriculum, trainer, run_curriculum
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
from typing import Literal, Optional
2+
3+
import numpy as np
4+
from aind_behavior_curriculum import MetricsProvider, Stage
5+
from aind_behavior_vr_foraging import task_logic
6+
from aind_behavior_vr_foraging.task_logic import AindVrForagingTaskLogic, AindVrForagingTaskParameters
7+
8+
from ..depletion import helpers
9+
from ..depletion.metrics import metrics_from_dataset
10+
11+
12+
def deterministic_curves(
13+
amount_drop: float = 5.0,
14+
option: Optional[Literal["single", "delayed"]] = "single",
15+
*,
16+
cap_delayed_rewards: bool = False,
17+
) -> list[task_logic.RewardFunction]:
18+
if option == "delayed":
19+
lut_values = [0.5, 1, 1, 1, 0]
20+
probability = task_logic.LookupTableFunction(
21+
lut_keys=list(np.arange(len(lut_values)) + 1), lut_values=lut_values
22+
)
23+
reward_function_prob = task_logic.PatchRewardFunction(
24+
probability=probability,
25+
rule=task_logic.RewardFunctionRule.ON_CHOICE_ACCUMULATED,
26+
)
27+
if cap_delayed_rewards:
28+
reward_available = amount_drop * 3
29+
available = task_logic.ClampedRateFunction(
30+
rate=task_logic.scalar_value(-amount_drop), minimum=0, maximum=reward_available
31+
)
32+
reward_function_avail = task_logic.PatchRewardFunction(
33+
available=available,
34+
rule=task_logic.RewardFunctionRule.ON_REWARD,
35+
)
36+
reset_function = task_logic.OnThisPatchEntryRewardFunction(
37+
probability=task_logic.SetValueFunction(value=task_logic.scalar_value(1)),
38+
available=task_logic.SetValueFunction(value=task_logic.scalar_value(reward_available)),
39+
)
40+
return [reward_function_prob, reward_function_avail, reset_function]
41+
else:
42+
reward_available = 100
43+
reset_function = task_logic.OnThisPatchEntryRewardFunction(
44+
probability=task_logic.SetValueFunction(value=task_logic.scalar_value(1)),
45+
available=task_logic.SetValueFunction(value=task_logic.scalar_value(reward_available)),
46+
)
47+
return [reward_function_prob, reset_function]
48+
49+
elif option == "single":
50+
lut_values = [1, 0]
51+
probability = task_logic.LookupTableFunction(lut_keys=[1, 2], lut_values=lut_values)
52+
reward_function = task_logic.PatchRewardFunction(
53+
probability=probability,
54+
rule=task_logic.RewardFunctionRule.ON_CHOICE_ACCUMULATED,
55+
)
56+
reset_function = task_logic.OnThisPatchEntryRewardFunction(
57+
probability=task_logic.SetValueFunction(value=task_logic.scalar_value(1)),
58+
available=task_logic.SetValueFunction(value=task_logic.scalar_value(100)),
59+
)
60+
return [reward_function, reset_function]
61+
62+
elif option is None:
63+
probability = task_logic.SetValueFunction(value=task_logic.scalar_value(0))
64+
reward_function = task_logic.PatchRewardFunction(
65+
probability=probability,
66+
rule=task_logic.RewardFunctionRule.ON_CHOICE,
67+
)
68+
reset_function = task_logic.OnThisPatchEntryRewardFunction(
69+
probability=task_logic.SetValueFunction(value=task_logic.scalar_value(0)),
70+
available=task_logic.SetValueFunction(value=task_logic.scalar_value(0)),
71+
)
72+
return [reward_function, reset_function]
73+
74+
else:
75+
raise ValueError(f"Option {option} not recognized. Valid options are 'single', 'delayed', and None.")
76+
77+
78+
def make_patch(
79+
label: str,
80+
state_index: int,
81+
odor_index: list[float],
82+
patch_type: Optional[Literal["single", "delayed"]],
83+
reward_amount: float = 5.0,
84+
first_p: float = 0.5,
85+
reward_available: float = 9999,
86+
stop_duration: float = 0.5,
87+
delay_mean: float = 0.5,
88+
cap_delayed_rewards: bool = False,
89+
) -> task_logic.Patch:
90+
agent = task_logic.RewardSpecification(
91+
operant_logic=helpers.make_operant_logic(stop_duration=stop_duration, is_operant=False),
92+
delay=helpers.make_exponential_distribution(rate=1 / delay_mean, minimum=0.0, maximum=1.0),
93+
amount=task_logic.scalar_value(value=reward_amount),
94+
probability=task_logic.scalar_value(first_p),
95+
available=task_logic.scalar_value(reward_available),
96+
reward_function=deterministic_curves(
97+
amount_drop=reward_amount, option=patch_type, cap_delayed_rewards=cap_delayed_rewards
98+
),
99+
)
100+
return task_logic.Patch(
101+
label=label,
102+
state_index=state_index,
103+
odor_specification=odor_index,
104+
reward_specification=agent,
105+
patch_virtual_sites_generator=helpers.make_patch_virtual_sites_generator(
106+
rewardsite=50,
107+
interpatch_min=100,
108+
interpatch_max=250,
109+
intersite_min=20,
110+
intersite_max=80,
111+
),
112+
)
113+
114+
115+
def make_s_stage_all_odors_rewarded(
116+
delayed_reward_available: float = 100,
117+
cap_delayed_rewards: bool = False,
118+
) -> Stage:
119+
return Stage(
120+
name="all_odors_rewarded",
121+
task=AindVrForagingTaskLogic(
122+
stage_name="all_odors_rewarded",
123+
task_parameters=AindVrForagingTaskParameters(
124+
operation_control=helpers.make_default_operation_control(velocity_threshold=8),
125+
environment=task_logic.BlockStructure(
126+
blocks=[
127+
task_logic.Block(
128+
environment=task_logic.MarkovEnvironment(
129+
first_state_occupancy=[0.5, 0.5],
130+
transition_matrix=[[0.5, 0.5], [0.5, 0.5]],
131+
patches=[
132+
make_patch(
133+
label="patch_single",
134+
state_index=0,
135+
odor_index=[0, 0, 1],
136+
patch_type="single",
137+
reward_amount=5.0,
138+
first_p=1,
139+
reward_available=100,
140+
cap_delayed_rewards=cap_delayed_rewards,
141+
),
142+
make_patch(
143+
label="patch_delayed",
144+
state_index=1,
145+
odor_index=[0, 1, 0],
146+
patch_type="delayed",
147+
reward_amount=5.0,
148+
first_p=0.5,
149+
reward_available=delayed_reward_available,
150+
cap_delayed_rewards=cap_delayed_rewards,
151+
),
152+
],
153+
),
154+
end_conditions=[],
155+
)
156+
],
157+
),
158+
),
159+
),
160+
metrics_provider=MetricsProvider(metrics_from_dataset),
161+
)
162+
163+
164+
def make_s_stage_graduation(
165+
delayed_reward_available: float = 100,
166+
cap_delayed_rewards: bool = False,
167+
) -> Stage:
168+
return Stage(
169+
name="graduation",
170+
task=AindVrForagingTaskLogic(
171+
stage_name="graduation",
172+
task_parameters=AindVrForagingTaskParameters(
173+
operation_control=helpers.make_default_operation_control(velocity_threshold=8),
174+
environment=task_logic.BlockStructure(
175+
blocks=[
176+
task_logic.Block(
177+
environment=task_logic.MarkovEnvironment(
178+
first_state_occupancy=[1 / 3, 1 / 3, 1 / 3],
179+
transition_matrix=[
180+
[1 / 3, 1 / 3, 1 / 3],
181+
[1 / 3, 1 / 3, 1 / 3],
182+
[1 / 3, 1 / 3, 1 / 3],
183+
],
184+
patches=[
185+
make_patch(
186+
label="patch_null",
187+
state_index=0,
188+
odor_index=[1, 0, 0],
189+
patch_type=None,
190+
reward_amount=0.0,
191+
first_p=0,
192+
reward_available=0,
193+
cap_delayed_rewards=cap_delayed_rewards,
194+
),
195+
make_patch(
196+
label="patch_delayed",
197+
state_index=1,
198+
odor_index=[0, 1, 0],
199+
patch_type="delayed",
200+
reward_amount=5.0,
201+
first_p=0.5,
202+
reward_available=delayed_reward_available,
203+
cap_delayed_rewards=cap_delayed_rewards,
204+
),
205+
make_patch(
206+
label="patch_single",
207+
state_index=2,
208+
odor_index=[0, 0, 1],
209+
patch_type="single",
210+
reward_amount=5.0,
211+
first_p=1,
212+
reward_available=100,
213+
cap_delayed_rewards=cap_delayed_rewards,
214+
),
215+
],
216+
),
217+
end_conditions=[],
218+
)
219+
],
220+
),
221+
),
222+
),
223+
metrics_provider=MetricsProvider(metrics_from_dataset),
224+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from ._curriculum_builder import build_deterministic_reversal_curriculum
2+
from .stages import make_s_stage_all_odors_rewarded, make_s_stage_graduation
3+
4+
CURRICULUM_NAME = "DeterministicReversals"
5+
PKG_LOCATION = ".".join(__name__.split(".")[:-1])
6+
7+
CURRICULUM, TRAINER, run_curriculum = build_deterministic_reversal_curriculum(
8+
CURRICULUM_NAME,
9+
PKG_LOCATION,
10+
make_s_stage_all_odors_rewarded,
11+
make_s_stage_graduation,
12+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ._stages_shared import make_s_stage_all_odors_rewarded, make_s_stage_graduation
2+
3+
__all__ = ["make_s_stage_all_odors_rewarded", "make_s_stage_graduation"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .curriculum import CURRICULUM, CURRICULUM_NAME, PKG_LOCATION, TRAINER, run_curriculum
2+
3+
__all__ = [
4+
"CURRICULUM_NAME",
5+
"CURRICULUM",
6+
"TRAINER",
7+
"run_curriculum",
8+
"PKG_LOCATION",
9+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from ..deterministic_reversals._curriculum_builder import build_deterministic_reversal_curriculum
2+
from .stages import make_s_stage_all_odors_rewarded, make_s_stage_graduation
3+
4+
CURRICULUM_NAME = "DeterministicReversalsRewardCapped"
5+
PKG_LOCATION = ".".join(__name__.split(".")[:-1])
6+
7+
CURRICULUM, TRAINER, run_curriculum = build_deterministic_reversal_curriculum(
8+
CURRICULUM_NAME,
9+
PKG_LOCATION,
10+
make_s_stage_all_odors_rewarded,
11+
make_s_stage_graduation,
12+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from functools import partial
2+
3+
from ..deterministic_reversals._stages_shared import (
4+
make_s_stage_all_odors_rewarded as _make_s_stage_all_odors_rewarded,
5+
)
6+
from ..deterministic_reversals._stages_shared import (
7+
make_s_stage_graduation as _make_s_stage_graduation,
8+
)
9+
10+
make_s_stage_all_odors_rewarded = partial(
11+
_make_s_stage_all_odors_rewarded, delayed_reward_available=15, cap_delayed_rewards=True
12+
)
13+
make_s_stage_graduation = partial(_make_s_stage_graduation, delayed_reward_available=15, cap_delayed_rewards=True)
14+
15+
__all__ = ["make_s_stage_all_odors_rewarded", "make_s_stage_graduation"]

0 commit comments

Comments
 (0)