Skip to content

Commit ea06eb3

Browse files
committed
Add descriptive param names to skill factory options for LLM planner
Update continuous parameter descriptions in pick, place, push, and pour skill factories so the LLM planner sees informative names like "approach_distance (dist behind target along facing dir to start push)" instead of terse names like "offset_x". Also update docstrings and test comments to match.
1 parent 518e444 commit ea06eb3

7 files changed

Lines changed: 270 additions & 194 deletions

File tree

predicators/approaches/agent_option_learning_approach.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ class AgentOptionLearningApproach(AgentPlannerApproach):
4040

4141
def __init__(self, initial_predicates: Set[Predicate],
4242
initial_options: Set[ParameterizedOption], types: Set[Type],
43-
action_space: Box, train_tasks: List[Task],
44-
*args: Any, **kwargs: Any) -> None:
43+
action_space: Box, train_tasks: List[Task], *args: Any,
44+
**kwargs: Any) -> None:
4545
# Agent-specific state (before super().__init__)
4646
self._agent_proposed_options: Set[ParameterizedOption] = set()
4747

@@ -82,13 +82,13 @@ def _get_agent_system_prompt(self) -> str:
8282
Continuous params: `(grasp_z_offset,)`.
8383
- `create_place_skill(name, types, config)` — place a held object \
8484
(move above, descend, release, retreat). No get_target_pose_fn; \
85-
target comes from continuous params: `(x, y, yaw, drop_z)`.
85+
target comes from continuous params: `(target_x, target_y, target_yaw, release_z)`.
8686
- `create_push_skill(name, types, config, get_target_pose_fn)` — \
8787
push with standard 4-waypoint trajectory. Requires \
8888
`config.robot_home_pos` to be set. Facing direction is \
8989
`(sin(yaw), cos(yaw))` from `get_target_pose_fn`. \
90-
Continuous params: `(offset_x, offset_z, offset_rot, \
91-
push_through_frac)`.
90+
Continuous params: `(approach_distance, contact_z_offset, \
91+
ee_yaw_offset, push_through_frac)`.
9292
- `create_pour_skill(name, types, config, get_target_pose_fn, \
9393
tilt_terminal_fn=None)` — pour from a held container \
9494
(move above, descend, tilt). Continuous params: `(pour_tilt,)`.
@@ -135,21 +135,21 @@ def _get_agent_tool_names(self) -> Optional[List[str]]:
135135
def _get_sandbox_reference_files(self) -> Dict[str, str]:
136136
return {
137137
"skill_factories/base.py":
138-
"predicators/ground_truth_models/skill_factories/base.py",
138+
"predicators/ground_truth_models/skill_factories/base.py",
139139
"skill_factories/__init__.py":
140-
"predicators/ground_truth_models/skill_factories/__init__.py",
140+
"predicators/ground_truth_models/skill_factories/__init__.py",
141141
"skill_factories/pick.py":
142-
"predicators/ground_truth_models/skill_factories/pick.py",
142+
"predicators/ground_truth_models/skill_factories/pick.py",
143143
"skill_factories/move_to.py":
144-
"predicators/ground_truth_models/skill_factories/move_to.py",
144+
"predicators/ground_truth_models/skill_factories/move_to.py",
145145
"skill_factories/place.py":
146-
"predicators/ground_truth_models/skill_factories/place.py",
146+
"predicators/ground_truth_models/skill_factories/place.py",
147147
"skill_factories/push.py":
148-
"predicators/ground_truth_models/skill_factories/push.py",
148+
"predicators/ground_truth_models/skill_factories/push.py",
149149
"skill_factories/pour.py":
150-
"predicators/ground_truth_models/skill_factories/pour.py",
150+
"predicators/ground_truth_models/skill_factories/pour.py",
151151
"skill_factories/wait.py":
152-
"predicators/ground_truth_models/skill_factories/wait.py",
152+
"predicators/ground_truth_models/skill_factories/wait.py",
153153
}
154154

155155
# ------------------------------------------------------------------ #
@@ -163,8 +163,8 @@ def _get_all_options(self) -> Set[ParameterizedOption]:
163163
# Also include iteration_proposals.proposed_options as a fallback
164164
# in case the Docker sync to tool_context.options was incomplete.
165165
proposal_opts = self._tool_context.iteration_proposals.proposed_options
166-
result = (self._initial_options | self._agent_proposed_options |
167-
self._tool_context.options | proposal_opts)
166+
result = (self._initial_options | self._agent_proposed_options
167+
| self._tool_context.options | proposal_opts)
168168
if not result:
169169
logging.warning(
170170
"_get_all_options() returning empty set. "
@@ -195,10 +195,10 @@ def _sync_tool_context(self) -> None:
195195
def _build_skill_factory_context(self) -> Dict[str, Any]:
196196
"""Build exec context with skill factory functions for
197197
propose_options."""
198-
from predicators.ground_truth_models.skill_factories import (
199-
Phase, PhaseAction, PhaseSkill, SkillConfig, create_move_to_skill,
200-
create_pick_skill, create_place_skill, create_pour_skill,
201-
create_push_skill, create_wait_option, make_move_to_phase)
198+
from predicators.ground_truth_models.skill_factories import Phase, \
199+
PhaseAction, PhaseSkill, SkillConfig, create_move_to_skill, \
200+
create_pick_skill, create_place_skill, create_pour_skill, \
201+
create_push_skill, create_wait_option, make_move_to_phase
202202

203203
context: Dict[str, Any] = {
204204
# Skill factory functions
@@ -296,8 +296,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
296296
policy = super()._solve(task, timeout)
297297

298298
# Snapshot agent-proposed options (everything beyond initial)
299-
self._agent_proposed_options = (
300-
self._tool_context.options - self._initial_options)
299+
self._agent_proposed_options = (self._tool_context.options -
300+
self._initial_options)
301301

302302
# Record iteration summary (options only)
303303
proposals = self._tool_context.iteration_proposals

predicators/ground_truth_models/skill_factories/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,19 @@
1616
1717
Shared signature pattern
1818
------------------------
19-
All factory functions (except ``create_wait_option``) share the same first
20-
four arguments::
19+
Most factory functions share the same first three arguments::
2120
2221
create_<X>_skill(
2322
name: str, # Option name for logging/matching
2423
types: Sequence[Type],# Object types (robot first)
25-
params_space: Box, # Continuous parameter space
2624
config: SkillConfig, # Shared environment configuration
2725
... # Skill-specific arguments
2826
)
2927
28+
Each factory builds its ``params_space`` internally from canonical parameter
29+
definitions (e.g. ``_PICK_PARAMS``, ``_PLACE_PARAMS``). One exception is
30+
``create_move_to_skill``, which takes an explicit ``params_space`` argument.
31+
3032
``create_wait_option`` uses ``(name, config, robot_type)`` since it always
3133
operates on a single robot type with no parameters.
3234

predicators/ground_truth_models/skill_factories/pick.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _get_jug_pose(state, objects, params, config):
5454

5555
# Canonical continuous parameters for Pick.
5656
_PICK_PARAMS = [
57-
("grasp_z_offset", 0.0, 0.1),
57+
("grasp_z_offset (height above object origin to close gripper)", 0.0, 0.1),
5858
]
5959

6060

@@ -140,5 +140,9 @@ def _descend_pose(
140140
make_move_to_phase("Lift", _above_pose, "closed"),
141141
]
142142

143-
return PhaseSkill(name, types, params_space, config, phases,
143+
return PhaseSkill(name,
144+
types,
145+
params_space,
146+
config,
147+
phases,
144148
params_description=params_description).build()

predicators/ground_truth_models/skill_factories/place.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
``ParameterizedOption`` that places a held object by:
55
66
1. Moving above the placement target at ``config.transport_z``.
7-
2. Descending to ``drop_z`` (from params).
7+
2. Descending to ``release_z`` (from params).
88
3. Opening the gripper to release.
99
4. Retreating back up to ``config.transport_z``.
1010
11-
The placement target ``(x, y, yaw)`` and ``drop_z`` are all provided as
12-
continuous parameters -- no callback is needed.
11+
The placement target ``(target_x, target_y, target_yaw)`` and
12+
``release_z`` are all provided as continuous parameters -- no callback
13+
is needed.
1314
14-
Continuous parameters: ``(x, y, yaw, drop_z)``
15+
Continuous parameters: ``(target_x, target_y, target_yaw, release_z)``
1516
1617
Example::
1718
@@ -28,20 +29,20 @@
2829

2930
from typing import Sequence, Tuple
3031

32+
import numpy as np
33+
3134
from predicators.ground_truth_models.skill_factories.base import Phase, \
3235
PhaseAction, PhaseSkill, SkillConfig, build_params_space
3336
from predicators.ground_truth_models.skill_factories.move_to import \
3437
make_move_to_phase
3538
from predicators.structs import Array, Object, ParameterizedOption, State, Type
3639

37-
import numpy as np
38-
3940
# Canonical continuous parameters for Place.
4041
_PLACE_PARAMS = [
41-
("x", 0.4, 1.1),
42-
("y", 1.1, 1.6),
43-
("yaw", -np.pi, np.pi),
44-
("drop_z", 0.4, 0.6),
42+
("target_x (world x position for placement)", 0.4, 1.1),
43+
("target_y (world y position for placement)", 1.1, 1.6),
44+
("target_yaw (placement orientation in radians)", -np.pi, np.pi),
45+
("release_z (world z height to open gripper)", 0.4, 0.6),
4546
]
4647

4748

@@ -55,15 +56,15 @@ def create_place_skill(
5556
Phases:
5657
0. **MoveAbove** -- Move end-effector above the placement at
5758
``config.transport_z``, with fingers closed.
58-
1. **Descend** -- Lower to ``drop_z`` (from params), with fingers
59-
closed.
59+
1. **Descend** -- Lower to ``release_z`` (from params), with
60+
fingers closed.
6061
2. **OpenFingers** -- Open the gripper to release the object.
6162
3. **Retreat** -- Rise back to ``config.transport_z``, with fingers
6263
open.
6364
6465
Continuous parameters:
65-
``(x, y, yaw, drop_z)`` -- placement position, orientation, and
66-
release height.
66+
``(target_x, target_y, target_yaw, release_z)`` -- placement
67+
position, orientation, and release height.
6768
6869
Args:
6970
name: Option name used for logging and matching.
@@ -110,7 +111,7 @@ def _drop_pose(
110111
phases = [
111112
# Phase 0: Move above placement
112113
make_move_to_phase("MoveAbove", _above_pose, "closed"),
113-
# Phase 1: Descend to drop height
114+
# Phase 1: Descend to release height
114115
make_move_to_phase("Descend", _drop_pose, "closed"),
115116
# Phase 2: Open fingers to release
116117
Phase(
@@ -122,5 +123,9 @@ def _drop_pose(
122123
make_move_to_phase("Retreat", _above_pose, "open"),
123124
]
124125

125-
return PhaseSkill(name, types, params_space, config, phases,
126+
return PhaseSkill(name,
127+
types,
128+
params_space,
129+
config,
130+
phases,
126131
params_description=params_description).build()

predicators/ground_truth_models/skill_factories/pour.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _get_pour_pose(state, objects, params, config):
4747

4848
# Canonical continuous parameters for Pour.
4949
_POUR_PARAMS = [
50-
("pour_tilt", 0.5, 1.0),
50+
("pour_tilt (EE tilt angle for pouring, radians)", 0.5, 1.0),
5151
]
5252

5353

@@ -110,8 +110,8 @@ def _tilt_target(
110110
) -> Tuple[Pose, Pose, str]:
111111
pour_tilt = float(params[0])
112112
robot_obj = objects[0]
113-
current_position = (state.get(robot_obj, "x"),
114-
state.get(robot_obj, "y"),
113+
current_position = (state.get(robot_obj,
114+
"x"), state.get(robot_obj, "y"),
115115
state.get(robot_obj, "z"))
116116
current_orn = p.getQuaternionFromEuler(
117117
[0, state.get(robot_obj, "tilt"),
@@ -137,5 +137,9 @@ def _tilt_target(
137137
),
138138
]
139139

140-
return PhaseSkill(name, types, params_space, config, phases,
140+
return PhaseSkill(name,
141+
types,
142+
params_space,
143+
config,
144+
phases,
141145
params_description=params_description).build()

predicators/ground_truth_models/skill_factories/push.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
77
1. Closing the gripper.
88
2. Moving above & behind the target at ``config.transport_z``.
9-
3. Descending to contact height (target z + ``offset_z``).
10-
4. Pushing through the target (``push_through_frac * offset_x``
9+
3. Descending to contact height (target z + ``contact_z_offset``).
10+
4. Pushing through the target (``push_through_frac * approach_distance``
1111
past the target along its facing direction).
1212
5. Retreating to ``config.robot_home_pos``.
1313
6. Opening the gripper.
@@ -18,7 +18,7 @@
1818
1919
``config.robot_home_pos`` **must** be set.
2020
21-
Continuous parameters: ``(offset_x, offset_z, offset_rot, push_through_frac)``
21+
Continuous parameters: ``(approach_distance, contact_z_offset, ee_yaw_offset, push_through_frac)``
2222
2323
Example::
2424
@@ -61,10 +61,13 @@ def _get_domino_pose(state, objects, params, config):
6161

6262
# Canonical continuous parameters for Push.
6363
_PUSH_PARAMS = [
64-
("offset_x", 0.03, 0.08),
65-
("offset_z", 0.0, 0.12),
66-
("offset_rot", -np.pi, np.pi),
67-
("push_through_frac", 0.0, 0.3),
64+
("approach_distance (dist behind target along facing dir to start push)",
65+
0.03, 0.08),
66+
("contact_z_offset (height above target z for contact)", 0.0, 0.12),
67+
("ee_yaw_offset (EE rotation offset from target yaw, radians)", -np.pi,
68+
np.pi),
69+
("push_through_frac (fraction of approach_distance to push past target)",
70+
0.0, 0.3),
6871
]
6972

7073

@@ -79,17 +82,18 @@ def create_push_skill(
7982
Phases:
8083
0. **CloseFingers** -- Close the gripper before approaching.
8184
1. **Waypoint_0** -- Move above & behind the target at
82-
``config.transport_z``, offset by ``offset_x`` opposite the
83-
facing direction.
85+
``config.transport_z``, offset by ``approach_distance``
86+
opposite the facing direction.
8487
2. **Waypoint_1** -- Descend to contact height
85-
(target z + ``offset_z``) at the same behind position.
88+
(target z + ``contact_z_offset``) at the same behind position.
8689
3. **Waypoint_2** -- Push forward through the target by
87-
``offset_x * push_through_frac`` along the facing direction.
90+
``approach_distance * push_through_frac`` along the facing
91+
direction.
8892
4. **Waypoint_3** -- Retreat to ``config.robot_home_pos``.
8993
5. **OpenFingers** -- Open the gripper.
9094
9195
Continuous parameters:
92-
``(offset_x, offset_z, offset_rot, push_through_frac)``
96+
``(approach_distance, contact_z_offset, ee_yaw_offset, push_through_frac)``
9397
9498
Args:
9599
name: Option name used for logging and matching.
@@ -112,8 +116,14 @@ def create_push_skill(
112116
# -- Standard 4-waypoint trajectory ----------------------------------
113117

114118
def _waypoints(
115-
ox: float, oy: float, oz: float, oyaw: float, cfg: SkillConfig,
116-
s_offset_x: float, s_offset_z: float, s_offset_rot: float,
119+
ox: float,
120+
oy: float,
121+
oz: float,
122+
oyaw: float,
123+
cfg: SkillConfig,
124+
s_offset_x: float,
125+
s_offset_z: float,
126+
s_offset_rot: float,
117127
s_push_frac: float,
118128
) -> List[Tuple[float, float, float, float, str]]:
119129
assert cfg.robot_home_pos is not None
@@ -199,5 +209,9 @@ def _get_target(
199209
action_type=PhaseAction.CHANGE_FINGERS,
200210
target_fn=_open_fingers_target))
201211

202-
return PhaseSkill(name, types, params_space, config, phases,
212+
return PhaseSkill(name,
213+
types,
214+
params_space,
215+
config,
216+
phases,
203217
params_description=params_description).build()

0 commit comments

Comments
 (0)