Skip to content

Commit 3d8cc14

Browse files
author
ronuchit
authored
refactor side predicate code into new file with classes (#584)
* refactor side predicate code into new file with classes * improve a comment
1 parent 2c6001a commit 3d8cc14

5 files changed

Lines changed: 308 additions & 234 deletions

File tree

scripts/run_supercloud_experiments.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ for SEED in $(seq $START_SEED $((NUM_SEEDS+START_SEED-1))); do
3131
python $FILE --experiment_id tools_invent_allexclude --env tools --approach grammar_search_invention --excluded_predicates all --seed $SEED --num_train_tasks 200
3232

3333
# repeated_nextto
34-
# requires extra flag: "--side_predicate_learner prediction_error_hillclimbing"
34+
# requires extra flag: "--side_predicate_learner prediction_error_hill_climbing"
3535
python $FILE --experiment_id repeated_nextto_oracle --env repeated_nextto --approach oracle --seed $SEED --num_train_tasks 0
36-
python $FILE --experiment_id repeated_nextto_nsrt_learning --env repeated_nextto --approach nsrt_learning --side_predicate_learner prediction_error_hillclimbing --seed $SEED --num_train_tasks 50
36+
python $FILE --experiment_id repeated_nextto_nsrt_learning --env repeated_nextto --approach nsrt_learning --side_predicate_learner prediction_error_hill_climbing --seed $SEED --num_train_tasks 50
3737

3838
# playroom
3939
python $FILE --experiment_id playroom_oracle --env playroom --approach oracle --seed $SEED --num_train_tasks 0

src/nsrt_learning/nsrt_learning_main.py

Lines changed: 21 additions & 229 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,20 @@
33
from __future__ import annotations
44

55
import logging
6-
from typing import FrozenSet, Iterator, List, Set, Tuple
6+
from typing import List, Set
77

88
from predicators.src import utils
99
from predicators.src.nsrt_learning.option_learning import create_option_learner
1010
from predicators.src.nsrt_learning.sampler_learning import learn_samplers
1111
from predicators.src.nsrt_learning.segmentation import segment_trajectory
12+
from predicators.src.nsrt_learning.side_predicate_learning import \
13+
PredictionErrorHillClimbingSidePredicateLearner, \
14+
PreserveSkeletonsHillClimbingSidePredicateLearner, SidePredicateLearner
1215
from predicators.src.nsrt_learning.strips_learning import \
1316
learn_strips_operators
14-
from predicators.src.planning import task_plan_grounding
15-
from predicators.src.predicate_search_score_functions import \
16-
_PredictionErrorScoreFunction
1717
from predicators.src.settings import CFG
18-
from predicators.src.structs import NSRT, GroundAtom, LowLevelTrajectory, \
19-
OptionSpec, PartialNSRTAndDatastore, Predicate, Segment, STRIPSOperator, \
20-
Task, _GroundNSRT
18+
from predicators.src.structs import NSRT, LowLevelTrajectory, \
19+
PartialNSRTAndDatastore, Predicate, Task
2120

2221

2322
def learn_nsrts_from_data(trajectories: List[LowLevelTrajectory],
@@ -79,8 +78,21 @@ def learn_nsrts_from_data(trajectories: List[LowLevelTrajectory],
7978
if CFG.side_predicate_learner != "no_learning":
8079
assert CFG.option_learner == "no_learning", \
8180
"Can't learn options and side predicates together."
82-
pnads = _learn_pnad_side_predicates(pnads, trajectories, train_tasks,
83-
predicates, segmented_trajs)
81+
if CFG.side_predicate_learner == "prediction_error_hill_climbing":
82+
side_pred_learner: SidePredicateLearner = \
83+
PredictionErrorHillClimbingSidePredicateLearner(
84+
pnads, trajectories, train_tasks, predicates,
85+
segmented_trajs)
86+
elif CFG.side_predicate_learner == "preserve_skeletons_hill_climbing":
87+
side_pred_learner = \
88+
PreserveSkeletonsHillClimbingSidePredicateLearner(
89+
pnads, trajectories, train_tasks, predicates,
90+
segmented_trajs)
91+
else:
92+
raise ValueError(
93+
f"side_predicate_learner {CFG.side_predicate_learner} not " +
94+
"implemented")
95+
pnads = side_pred_learner.sideline()
8496

8597
# STEP 5: Learn options (option_learning.py) and update PNADs.
8698
_learn_pnad_options(pnads) # in-place update
@@ -97,130 +109,6 @@ def learn_nsrts_from_data(trajectories: List[LowLevelTrajectory],
97109
return set(nsrts)
98110

99111

100-
def _learn_pnad_side_predicates(
101-
pnads: List[PartialNSRTAndDatastore],
102-
trajectories: List[LowLevelTrajectory], train_tasks: List[Task],
103-
predicates: Set[Predicate],
104-
segmented_trajs: List[List[Segment]]) -> List[PartialNSRTAndDatastore]:
105-
106-
def _check_goal(s: Tuple[PartialNSRTAndDatastore, ...]) -> bool:
107-
del s # unused
108-
# There are no goal states for this search; run until exhausted.
109-
return False
110-
111-
def _get_successors(
112-
s: Tuple[PartialNSRTAndDatastore, ...],
113-
) -> Iterator[Tuple[None, Tuple[PartialNSRTAndDatastore, ...], float]]:
114-
# For each PNAD/operator...
115-
for i in range(len(s)):
116-
pnad = s[i]
117-
_, option_vars = pnad.option_spec
118-
# ...consider changing each of its effects to a side predicate.
119-
for effect in pnad.op.add_effects:
120-
if len(pnad.op.add_effects) > 1:
121-
# We don't want sidelining to result in a no-op
122-
new_pnad = PartialNSRTAndDatastore(
123-
pnad.op.effect_to_side_predicate(
124-
effect, option_vars, "add"), pnad.datastore,
125-
pnad.option_spec)
126-
sprime = list(s)
127-
sprime[i] = new_pnad
128-
yield (None, tuple(sprime), 1.0)
129-
130-
# ...consider removing it.
131-
sprime = list(s)
132-
del sprime[i]
133-
yield (None, tuple(sprime), 1.0)
134-
135-
if CFG.side_predicate_learner == "prediction_error_hillclimbing":
136-
score_func = _PredictionErrorScoreFunction(predicates, [], {},
137-
train_tasks)
138-
139-
def _evaluate(s: Tuple[PartialNSRTAndDatastore, ...]) -> float:
140-
# Score function for search. Lower is better.
141-
strips_ops = [pnad.op for pnad in s]
142-
option_specs = [pnad.option_spec for pnad in s]
143-
score = score_func.evaluate_with_operators(frozenset(),
144-
trajectories,
145-
segmented_trajs,
146-
strips_ops,
147-
option_specs)
148-
return score
149-
150-
elif CFG.side_predicate_learner == "preserve_skeletons_hillclimbing":
151-
152-
def _evaluate(s: Tuple[PartialNSRTAndDatastore, ...]) -> float:
153-
# Score function for search. Lower is better.
154-
strips_ops = [pnad.op for pnad in s]
155-
option_specs = [pnad.option_spec for pnad in s]
156-
preserves_harmlessness = check_harmlessness(
157-
predicates, train_tasks, trajectories, segmented_trajs,
158-
strips_ops, option_specs)
159-
# NOTE: Arbitrary large number bigger than the total number of
160-
# operators at the start of the search.
161-
score = 10 * len(pnads)
162-
if preserves_harmlessness:
163-
score = 2 * len(strips_ops)
164-
for op in strips_ops:
165-
score -= len(op.side_predicates)
166-
return score
167-
168-
else:
169-
raise ValueError(
170-
f"side_predicate_learner {CFG.side_predicate_learner} not " +
171-
"implemented")
172-
173-
# Run the search, starting from original PNADs.
174-
path, _, _ = utils.run_hill_climbing(tuple(pnads), _check_goal,
175-
_get_successors, _evaluate)
176-
# The last state in the search holds the final PNADs.
177-
pnads = list(path[-1])
178-
# Recompute the datastores in the PNADs. We need to do this
179-
# because now that we have side predicates, each transition may be
180-
# assigned to *multiple* datastores.
181-
_recompute_datastores_from_segments(segmented_trajs, pnads)
182-
return pnads
183-
184-
185-
def _recompute_datastores_from_segments(
186-
segmented_trajs: List[List[Segment]],
187-
pnads: List[PartialNSRTAndDatastore]) -> None:
188-
for pnad in pnads:
189-
pnad.datastore = [] # reset all PNAD datastores
190-
for seg_traj in segmented_trajs:
191-
objects = set(seg_traj[0].states[0])
192-
for segment in seg_traj:
193-
assert segment.has_option()
194-
segment_option = segment.get_option()
195-
segment_param_option = segment_option.parent
196-
segment_option_objs = tuple(segment_option.objects)
197-
# Get ground operators given these objects and option objs.
198-
for pnad in pnads:
199-
param_opt, opt_vars = pnad.option_spec
200-
if param_opt != segment_param_option:
201-
continue
202-
isub = dict(zip(opt_vars, segment_option_objs))
203-
# Consider adding this segment to each datastore.
204-
for ground_op in utils.all_ground_operators_given_partial(
205-
pnad.op, objects, isub):
206-
# Check if preconditions hold.
207-
if not ground_op.preconditions.issubset(
208-
segment.init_atoms):
209-
continue
210-
# Check if effects match. Note that we're using the side
211-
# predicates semantics here!
212-
atoms = utils.apply_operator(ground_op, segment.init_atoms)
213-
if not atoms.issubset(segment.final_atoms):
214-
continue
215-
# Skip over segments that have multiple possible bindings.
216-
if len(set(ground_op.objects)) != len(ground_op.objects):
217-
continue
218-
# This segment belongs in this datastore, so add it.
219-
sub = dict(zip(pnad.op.parameters, ground_op.objects))
220-
pnad.add_to_datastore((segment, sub),
221-
check_effect_equality=False)
222-
223-
224112
def _learn_pnad_options(pnads: List[PartialNSRTAndDatastore]) -> None:
225113
logging.info("\nDoing option learning...")
226114
option_learner = create_option_learner()
@@ -263,99 +151,3 @@ def _learn_pnad_samplers(pnads: List[PartialNSRTAndDatastore],
263151
# Replace the samplers in the PNADs.
264152
for pnad, sampler in zip(pnads, samplers):
265153
pnad.sampler = sampler
266-
267-
268-
def check_harmlessness(predicates: Set[Predicate], train_tasks: List[Task],
269-
trajectories: List[LowLevelTrajectory],
270-
segmented_trajs: List[List[Segment]],
271-
strips_ops: List[STRIPSOperator],
272-
option_specs: List[OptionSpec]) -> bool:
273-
"""Function to check whether a given set of operators and predicates
274-
preserves harmlessness over demonstrations on some number of training
275-
tasks.
276-
277-
Preserving harmlessness roughly means that the set of operators and
278-
predicates supports the agent's ability to plan to achieve all of
279-
the training tasks in the same way as was demonstrated (i.e, the
280-
predicates and operators don't render any demonstrated trajectory
281-
impossible).
282-
"""
283-
284-
assert len(trajectories) == len(segmented_trajs)
285-
for ll_traj, seg_traj in zip(trajectories, segmented_trajs):
286-
if not ll_traj.is_demo:
287-
continue
288-
atoms_seq = utils.segment_trajectory_to_atoms_sequence(seg_traj)
289-
traj_goal = train_tasks[ll_traj.train_task_idx].goal
290-
demo_preserved = check_single_demo_preservation(
291-
ll_traj, atoms_seq, traj_goal, predicates, strips_ops,
292-
option_specs)
293-
if not demo_preserved:
294-
return False
295-
296-
return True
297-
298-
299-
def check_single_demo_preservation(ll_traj: LowLevelTrajectory,
300-
atoms_seq: List[Set[GroundAtom]],
301-
traj_goal: Set[GroundAtom],
302-
predicates: Set[Predicate],
303-
strips_ops: List[STRIPSOperator],
304-
option_specs: List[OptionSpec]) -> bool:
305-
"""Function to check whether a given set of operators and predicates
306-
preserves a single training trajectory."""
307-
init_atoms = utils.abstract(ll_traj.states[0], predicates)
308-
objects = set(ll_traj.states[0])
309-
ground_nsrts, _ = task_plan_grounding(init_atoms, objects, strips_ops,
310-
option_specs)
311-
heuristic = utils.create_task_planning_heuristic(
312-
CFG.sesame_task_planning_heuristic, init_atoms, traj_goal,
313-
ground_nsrts, predicates, objects)
314-
315-
def _check_goal(state: Tuple[FrozenSet[GroundAtom], int]) -> bool:
316-
return traj_goal.issubset(state[0])
317-
318-
def _get_successor_with_correct_option(
319-
searchnode_state: Tuple[FrozenSet[GroundAtom], int]
320-
) -> Iterator[Tuple[_GroundNSRT, Tuple[FrozenSet[GroundAtom], int],
321-
float]]:
322-
state = searchnode_state[0]
323-
idx_into_traj = searchnode_state[1]
324-
325-
if idx_into_traj > len(ll_traj.actions) - 1:
326-
return
327-
328-
gt_option = ll_traj.actions[idx_into_traj].get_option()
329-
expected_next_hl_state = atoms_seq[idx_into_traj + 1]
330-
331-
for applicable_nsrt in utils.get_applicable_operators(
332-
ground_nsrts, state):
333-
# NOTE: we check that the ParameterizedOptions are equal before
334-
# attempting to ground because otherwise, we might
335-
# get a parameter mismatch and trigger an AssertionError
336-
# during grounding.
337-
if applicable_nsrt.option != gt_option.parent:
338-
continue
339-
if applicable_nsrt.option_objs != gt_option.objects:
340-
continue
341-
next_hl_state = utils.apply_operator(applicable_nsrt, set(state))
342-
exp_state_matches = next_hl_state.issubset(expected_next_hl_state)
343-
if exp_state_matches:
344-
# The returned cost is uniform because we don't
345-
# actually care about finding the shortest path;
346-
# just one that matches!
347-
yield (applicable_nsrt, (frozenset(next_hl_state),
348-
idx_into_traj + 1), 1.0)
349-
350-
init_atoms_frozen = frozenset(init_atoms)
351-
init_searchnode_state = (init_atoms_frozen, 0)
352-
# NOTE: each state in the below GBFS is a tuple of
353-
# (current_atoms, idx_into_traj). The idx_into_traj is necessary because
354-
# we need to check whether the atoms that are true at this particular
355-
# index into the trajectory is what we would expect given the demo
356-
# trajectory.
357-
state_seq, _ = utils.run_gbfs(
358-
init_searchnode_state, _check_goal, _get_successor_with_correct_option,
359-
lambda searchnode_state: heuristic(searchnode_state[0]))
360-
361-
return _check_goal(state_seq[-1])

0 commit comments

Comments
 (0)