Skip to content

Commit 96fc7a6

Browse files
committed
Add env-specific process samplers for continuous option params
Replace null_sampler with proper samplers that return fixed param values for each env's EndogenousProcess (pick, place, push, pour). Place samplers read target coordinates from state or process objects.
1 parent ba0e1e5 commit 96fc7a6

5 files changed

Lines changed: 211 additions & 39 deletions

File tree

predicators/ground_truth_models/boil/processes.py

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,70 @@
11
"""Ground-truth processes for the boil environments."""
22
import logging
33
from pprint import pformat
4-
from typing import Dict, Set, cast
4+
from typing import Dict, Sequence, Set, cast
55

66
import numpy as np
77
import torch
88

9+
from predicators.envs.pybullet_boil import PyBulletBoilEnv
910
from predicators.ground_truth_models import GroundTruthProcessFactory
1011
from predicators.settings import CFG
11-
from predicators.structs import CausalProcess, DelayDistribution, \
12-
EndogenousProcess, ExogenousProcess, LiftedAtom, ParameterizedOption, \
13-
Predicate, Type, Variable
12+
from predicators.structs import Array, CausalProcess, DelayDistribution, \
13+
EndogenousProcess, ExogenousProcess, GroundAtom, LiftedAtom, Object, \
14+
ParameterizedOption, Predicate, State, Type, Variable
1415
from predicators.utils import ConstantDelay, DiscreteGaussianDelay, \
1516
null_sampler
1617

1718

19+
_BOIL_DROP_Z = 0.49 # table_height (0.4) + jug_handle_height (0.09)
20+
21+
22+
def _pick_sampler(state: State, goal: Set[GroundAtom],
23+
rng: np.random.Generator,
24+
objs: Sequence[Object]) -> Array:
25+
del state, goal, rng, objs
26+
return np.array([0.0], dtype=np.float32)
27+
28+
29+
def _push_sampler(state: State, goal: Set[GroundAtom],
30+
rng: np.random.Generator,
31+
objs: Sequence[Object]) -> Array:
32+
del state, goal, rng, objs
33+
return np.array([0.057, 0.104, 0.0, 0.25], dtype=np.float32)
34+
35+
36+
def _place_on_burner_sampler(state: State, goal: Set[GroundAtom],
37+
rng: np.random.Generator,
38+
objs: Sequence[Object]) -> Array:
39+
del goal, rng
40+
# objs = [robot, jug, burner]
41+
burner = objs[2]
42+
x = state.get(burner, "x")
43+
y = state.get(burner, "y") - PyBulletBoilEnv.jug_handle_offset
44+
return np.array([x, y, 0.0, _BOIL_DROP_Z], dtype=np.float32)
45+
46+
47+
def _place_under_faucet_sampler(state: State, goal: Set[GroundAtom],
48+
rng: np.random.Generator,
49+
objs: Sequence[Object]) -> Array:
50+
del goal, rng
51+
# objs = [robot, jug, faucet]
52+
faucet = objs[2]
53+
x = state.get(faucet, "x")
54+
y = (state.get(faucet, "y") - PyBulletBoilEnv.jug_handle_offset
55+
- PyBulletBoilEnv.faucet_x_len)
56+
return np.array([x, y, 0.0, _BOIL_DROP_Z], dtype=np.float32)
57+
58+
59+
def _place_outside_sampler(state: State, goal: Set[GroundAtom],
60+
rng: np.random.Generator,
61+
objs: Sequence[Object]) -> Array:
62+
del state, goal, rng, objs
63+
x = PyBulletBoilEnv.x_mid - 0.15
64+
y = PyBulletBoilEnv.y_mid + 0.10
65+
return np.array([x, y, 0.0, _BOIL_DROP_Z], dtype=np.float32)
66+
67+
1868
class PyBulletBoilGroundTruthProcessFactory(GroundTruthProcessFactory):
1969
"""Ground-truth processes for the boil environment."""
2070

@@ -107,7 +157,7 @@ def get_processes(
107157
pick_jug_from_faucet_process = EndogenousProcess(
108158
"PickJugFromFaucet", parameters, condition_at_start, set(),
109159
set(), add_effects, delete_effects, delay_distribution,
110-
torch.tensor(1.0), option, option_vars, null_sampler)
160+
torch.tensor(1.0), option, option_vars, _pick_sampler)
111161
processes.add(pick_jug_from_faucet_process)
112162

113163
# PickJugFromBurner
@@ -134,7 +184,7 @@ def get_processes(
134184
pick_jug_from_burner_process = EndogenousProcess(
135185
"PickJugFromBurner", parameters, condition_at_start, set(),
136186
set(), add_effects, delete_effects, delay_distribution,
137-
torch.tensor(1.0), option, option_vars, null_sampler)
187+
torch.tensor(1.0), option, option_vars, _pick_sampler)
138188
processes.add(pick_jug_from_burner_process)
139189

140190
# PickJugFromOutsideFaucetAndBurner
@@ -160,7 +210,7 @@ def get_processes(
160210
"PickJugFromOutsideFaucetAndBurner", parameters,
161211
condition_at_start, set(),
162212
set(), add_effects, delete_effects, delay_distribution,
163-
torch.tensor(1.0), option, option_vars, null_sampler)
213+
torch.tensor(1.0), option, option_vars, _pick_sampler)
164214
processes.add(pick_jug_outside_faucet_burner_process)
165215

166216
# PlaceOnBurner
@@ -187,7 +237,7 @@ def get_processes(
187237
place_on_burner_process = EndogenousProcess(
188238
"PlaceOnBurner", parameters, condition_at_start, set(),
189239
set(), add_effects, delete_effects, delay_distribution,
190-
torch.tensor(1.0), option, option_vars, null_sampler)
240+
torch.tensor(1.0), option, option_vars, _place_on_burner_sampler)
191241
processes.add(place_on_burner_process)
192242

193243
# PlaceUnderFaucet
@@ -214,7 +264,8 @@ def get_processes(
214264
place_under_faucet_process = EndogenousProcess(
215265
"PlaceUnderFaucet", parameters, condition_at_start, set(),
216266
set(), add_effects, delete_effects, delay_distribution,
217-
torch.tensor(1.0), option, option_vars, null_sampler)
267+
torch.tensor(1.0), option, option_vars,
268+
_place_under_faucet_sampler)
218269
processes.add(place_under_faucet_process)
219270

220271
# PlaceAtOutsideFaucetAndBurner
@@ -238,7 +289,7 @@ def get_processes(
238289
place_at_outside_faucet_burner_process = EndogenousProcess(
239290
"PlaceOutsideFaucetAndBurner", parameters, condition_at_start,
240291
set(), set(), add_effects, delete_effects, delay_distribution,
241-
torch.tensor(1.0), option, option_vars, null_sampler)
292+
torch.tensor(1.0), option, option_vars, _place_outside_sampler)
242293
processes.add(place_at_outside_faucet_burner_process)
243294

244295
# SwitchFaucetOn
@@ -262,7 +313,7 @@ def get_processes(
262313
switch_faucet_on_process = EndogenousProcess(
263314
"SwitchFaucetOn", parameters, condition_at_start, set(),
264315
set(), add_effects, delete_effects, delay_distribution,
265-
torch.tensor(1.0), option, option_vars, null_sampler)
316+
torch.tensor(1.0), option, option_vars, _push_sampler)
266317
processes.add(switch_faucet_on_process)
267318

268319
# SwitchFaucetOff
@@ -286,7 +337,7 @@ def get_processes(
286337
switch_faucet_off_process = EndogenousProcess(
287338
"SwitchFaucetOff", parameters, condition_at_start, set(),
288339
set(), add_effects, delete_effects, delay_distribution,
289-
torch.tensor(1.0), option, option_vars, null_sampler)
340+
torch.tensor(1.0), option, option_vars, _push_sampler)
290341
processes.add(switch_faucet_off_process)
291342

292343
# SwitchBurnerOn
@@ -310,7 +361,7 @@ def get_processes(
310361
switch_burner_on_process = EndogenousProcess(
311362
"SwitchBurnerOn", parameters, condition_at_start, set(),
312363
set(), add_effects, delete_effects, delay_distribution,
313-
torch.tensor(1.0), option, option_vars, null_sampler)
364+
torch.tensor(1.0), option, option_vars, _push_sampler)
314365
processes.add(switch_burner_on_process)
315366

316367
# SwitchBurnerOff
@@ -334,7 +385,7 @@ def get_processes(
334385
switch_burner_off_process = EndogenousProcess(
335386
"SwitchBurnerOff", parameters, condition_at_start, set(),
336387
set(), add_effects, delete_effects, delay_distribution,
337-
torch.tensor(1.0), option, option_vars, null_sampler)
388+
torch.tensor(1.0), option, option_vars, _push_sampler)
338389
processes.add(switch_burner_off_process)
339390

340391
# Noop

predicators/ground_truth_models/coffee/processes.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,54 @@
11
"""Ground-truth processes for the coffee environments."""
22

3-
from typing import Dict, Set, cast
3+
from typing import Dict, Sequence, Set, cast
44

55
import numpy as np
66
import torch
77

8+
from predicators.envs.pybullet_coffee import PyBulletCoffeeEnv
89
from predicators.ground_truth_models import GroundTruthProcessFactory
910
from predicators.settings import CFG
10-
from predicators.structs import CausalProcess, DelayDistribution, \
11-
EndogenousProcess, ExogenousProcess, LiftedAtom, ParameterizedOption, \
12-
Predicate, Type, Variable
11+
from predicators.structs import Array, CausalProcess, DelayDistribution, \
12+
EndogenousProcess, ExogenousProcess, GroundAtom, LiftedAtom, Object, \
13+
ParameterizedOption, Predicate, State, Type, Variable
1314
from predicators.utils import ConstantDelay, DiscreteGaussianDelay, \
1415
null_sampler
1516

17+
_COFFEE_DROP_Z = 0.5 # z_lb (0.4) + jug_handle_height (0.1)
18+
19+
20+
def _pick_sampler(state: State, goal: Set[GroundAtom],
21+
rng: np.random.Generator,
22+
objs: Sequence[Object]) -> Array:
23+
del state, goal, rng, objs
24+
return np.array([0.0], dtype=np.float32)
25+
26+
27+
def _push_sampler(state: State, goal: Set[GroundAtom],
28+
rng: np.random.Generator,
29+
objs: Sequence[Object]) -> Array:
30+
"""Push params for TurnMachineOn (button press)."""
31+
del state, goal, rng, objs
32+
return np.array([0.0675, 0.0, -np.pi, 0.0], dtype=np.float32)
33+
34+
35+
def _place_jug_in_machine_sampler(state: State, goal: Set[GroundAtom],
36+
rng: np.random.Generator,
37+
objs: Sequence[Object]) -> Array:
38+
del state, goal, rng
39+
# objs = [robot, jug, machine]
40+
return np.array([PyBulletCoffeeEnv.dispense_area_x,
41+
PyBulletCoffeeEnv.dispense_area_y,
42+
PyBulletCoffeeEnv.robot_init_wrist,
43+
_COFFEE_DROP_Z], dtype=np.float32)
44+
45+
46+
def _pour_sampler(state: State, goal: Set[GroundAtom],
47+
rng: np.random.Generator,
48+
objs: Sequence[Object]) -> Array:
49+
del state, goal, rng, objs
50+
return np.array([np.pi / 4], dtype=np.float32)
51+
1652

1753
class PyBulletCoffeeGroundTruthProcessFactory(GroundTruthProcessFactory):
1854
"""Ground-truth processes for the coffee environment."""
@@ -189,7 +225,7 @@ def get_processes(
189225
pick_jug_from_table_process = EndogenousProcess(
190226
"PickJugFromTable", parameters, condition_at_start, set(),
191227
set(), add_effects, delete_effects, delay_distribution,
192-
torch.tensor(1.0), option, option_vars, null_sampler)
228+
torch.tensor(1.0), option, option_vars, _pick_sampler)
193229
processes.add(pick_jug_from_table_process)
194230

195231
# PlaceJugInMachine
@@ -215,7 +251,8 @@ def get_processes(
215251
place_jug_in_machine_process = EndogenousProcess(
216252
"PlaceJugInMachine", parameters, condition_at_start, set(),
217253
set(), add_effects, delete_effects, delay_distribution,
218-
torch.tensor(1.0), option, option_vars, null_sampler)
254+
torch.tensor(1.0), option, option_vars,
255+
_place_jug_in_machine_sampler)
219256
processes.add(place_jug_in_machine_process)
220257

221258
# TurnMachineOn
@@ -242,7 +279,7 @@ def get_processes(
242279
turn_machine_on_process = EndogenousProcess(
243280
"TurnMachineOn", parameters, condition_at_start, set(),
244281
set(), add_effects, delete_effects, delay_distribution,
245-
torch.tensor(1.0), option, option_vars, null_sampler)
282+
torch.tensor(1.0), option, option_vars, _push_sampler)
246283
processes.add(turn_machine_on_process)
247284

248285
# PickJugFromMachine
@@ -268,7 +305,7 @@ def get_processes(
268305
pick_jug_from_machine_process = EndogenousProcess(
269306
"PickJugFromMachine", parameters, condition_at_start, set(),
270307
set(), add_effects, delete_effects, delay_distribution,
271-
torch.tensor(1.0), option, option_vars, null_sampler)
308+
torch.tensor(1.0), option, option_vars, _pick_sampler)
272309
processes.add(pick_jug_from_machine_process)
273310

274311
# Pour from not-above-cup
@@ -293,7 +330,7 @@ def get_processes(
293330
pourFromNotAboveCup_process = EndogenousProcess(
294331
"PourFromNotAboveCup", parameters, condition_at_start, set(),
295332
set(), add_effects, delete_effects, delay_distribution,
296-
torch.tensor(1.0), option, option_vars, null_sampler)
333+
torch.tensor(1.0), option, option_vars, _pour_sampler)
297334
processes.add(pourFromNotAboveCup_process)
298335

299336
# Pour from above-cup
@@ -321,7 +358,7 @@ def get_processes(
321358
pourFromNotAboveCup_process = EndogenousProcess(
322359
"PourFromCup", parameters, condition_at_start, set(),
323360
set(), add_effects, delete_effects, delay_distribution,
324-
torch.tensor(1.0), option, option_vars, null_sampler,
361+
torch.tensor(1.0), option, option_vars, _pour_sampler,
325362
ignore_effects)
326363
processes.add(pourFromNotAboveCup_process)
327364

predicators/ground_truth_models/domino/processes.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,59 @@
11
"""Ground-truth processes for the domino environment."""
22

3-
from typing import Dict, Set
3+
from typing import Dict, Sequence, Set
44

5+
import numpy as np
56
import torch
67

78
from predicators.ground_truth_models import GroundTruthProcessFactory
89
from predicators.settings import CFG
9-
from predicators.structs import CausalProcess, EndogenousProcess, \
10-
ExogenousProcess, LiftedAtom, ParameterizedOption, Predicate, Type, \
11-
Variable
10+
from predicators.structs import Array, CausalProcess, EndogenousProcess, \
11+
ExogenousProcess, GroundAtom, LiftedAtom, Object, ParameterizedOption, \
12+
Predicate, State, Type, Variable
1213
from predicators.utils import ConstantDelay, DiscreteGaussianDelay, \
1314
null_sampler
1415

16+
# Fixed parameter values for domino environment.
17+
_DOMINO_GRASP_Z_OFFSET = 0.0825 # domino_height * 0.55
18+
_DOMINO_DROP_Z = 0.5695 # table_height + domino_height * 1.13
19+
_DOMINO_OFFSET_X = 0.045 # domino_depth * 3
20+
_DOMINO_OFFSET_Z = 0.0825 # domino_height * 0.55
21+
_DOMINO_OFFSET_ROT = np.pi / 2
22+
_DOMINO_PUSH_THROUGH_FRAC = 0.25
23+
24+
25+
def _pick_sampler(state: State, goal: Set[GroundAtom],
26+
rng: np.random.Generator,
27+
objs: Sequence[Object]) -> Array:
28+
"""Return fixed grasp_z_offset for domino pick."""
29+
del state, goal, rng, objs
30+
return np.array([_DOMINO_GRASP_Z_OFFSET], dtype=np.float32)
31+
32+
33+
def _push_sampler(state: State, goal: Set[GroundAtom],
34+
rng: np.random.Generator,
35+
objs: Sequence[Object]) -> Array:
36+
"""Return fixed push params for domino push."""
37+
del state, goal, rng, objs
38+
return np.array([_DOMINO_OFFSET_X, _DOMINO_OFFSET_Z,
39+
_DOMINO_OFFSET_ROT, _DOMINO_PUSH_THROUGH_FRAC],
40+
dtype=np.float32)
41+
42+
43+
def _place_sampler(state: State, goal: Set[GroundAtom],
44+
rng: np.random.Generator,
45+
objs: Sequence[Object]) -> Array:
46+
"""Return placement params from process objects."""
47+
del state, goal, rng
48+
# objs = [robot, domino1, domino2, target_pos, rotation]
49+
target_pos = objs[3]
50+
rotation = objs[4]
51+
x = float(target_pos.name.split("_")[1])
52+
y = float(target_pos.name.split("_")[2])
53+
angle_deg = float(rotation.name.split("_")[-1])
54+
yaw = np.radians(angle_deg)
55+
return np.array([x, y, yaw, _DOMINO_DROP_Z], dtype=np.float32)
56+
1557

1658
class PyBulletDominoGroundTruthProcessFactory(GroundTruthProcessFactory):
1759
"""Ground-truth processes for the domino grid environment."""
@@ -86,7 +128,7 @@ def get_processes(
86128
push_start_block_process = EndogenousProcess(
87129
"PushStartBlock", parameters, condition_at_start, set(),
88130
set(), add_effects, delete_effects, delay_distribution,
89-
torch.tensor(1.0), option, option_vars, null_sampler,
131+
torch.tensor(1.0), option, option_vars, _push_sampler,
90132
ignore_effects)
91133
processes.add(push_start_block_process)
92134

@@ -125,7 +167,7 @@ def get_processes(
125167
delete_effects,
126168
delay_distribution,
127169
torch.tensor(1.0), option,
128-
option_vars, null_sampler,
170+
option_vars, _pick_sampler,
129171
ignore_effects)
130172
processes.add(pick_domino_process)
131173

@@ -137,7 +179,7 @@ def get_processes(
137179
target_pos = Variable("?pos1", position_type)
138180
rotation = Variable("?rot", rotation_type)
139181
parameters = [robot, domino1, domino2, target_pos, rotation]
140-
option_vars = [robot, domino1, domino2, target_pos, rotation]
182+
option_vars = [robot]
141183
option = Place
142184
condition_at_start = {
143185
LiftedAtom(Holding, [robot, domino1]),
@@ -165,7 +207,7 @@ def get_processes(
165207
delete_effects,
166208
delay_distribution,
167209
torch.tensor(1.0), option,
168-
option_vars, null_sampler,
210+
option_vars, _place_sampler,
169211
ignore_effects)
170212
processes.add(place_domino_process)
171213

0 commit comments

Comments
 (0)