33from __future__ import annotations
44
55import logging
6- from typing import FrozenSet , Iterator , List , Set , Tuple
6+ from typing import List , Set
77
88from predicators .src import utils
99from predicators .src .nsrt_learning .option_learning import create_option_learner
1010from predicators .src .nsrt_learning .sampler_learning import learn_samplers
1111from predicators .src .nsrt_learning .segmentation import segment_trajectory
12+ from predicators .src .nsrt_learning .side_predicate_learning import \
13+ PredictionErrorHillClimbingSidePredicateLearner , \
14+ PreserveSkeletonsHillClimbingSidePredicateLearner , SidePredicateLearner
1215from predicators .src .nsrt_learning .strips_learning import \
1316 learn_strips_operators
14- from predicators .src .planning import task_plan_grounding
15- from predicators .src .predicate_search_score_functions import \
16- _PredictionErrorScoreFunction
1717from predicators .src .settings import CFG
18- from predicators .src .structs import NSRT , GroundAtom , LowLevelTrajectory , \
19- OptionSpec , PartialNSRTAndDatastore , Predicate , Segment , STRIPSOperator , \
20- Task , _GroundNSRT
18+ from predicators .src .structs import NSRT , LowLevelTrajectory , \
19+ PartialNSRTAndDatastore , Predicate , Task
2120
2221
2322def learn_nsrts_from_data (trajectories : List [LowLevelTrajectory ],
@@ -79,8 +78,21 @@ def learn_nsrts_from_data(trajectories: List[LowLevelTrajectory],
7978 if CFG .side_predicate_learner != "no_learning" :
8079 assert CFG .option_learner == "no_learning" , \
8180 "Can't learn options and side predicates together."
82- pnads = _learn_pnad_side_predicates (pnads , trajectories , train_tasks ,
83- predicates , segmented_trajs )
81+ if CFG .side_predicate_learner == "prediction_error_hill_climbing" :
82+ side_pred_learner : SidePredicateLearner = \
83+ PredictionErrorHillClimbingSidePredicateLearner (
84+ pnads , trajectories , train_tasks , predicates ,
85+ segmented_trajs )
86+ elif CFG .side_predicate_learner == "preserve_skeletons_hill_climbing" :
87+ side_pred_learner = \
88+ PreserveSkeletonsHillClimbingSidePredicateLearner (
89+ pnads , trajectories , train_tasks , predicates ,
90+ segmented_trajs )
91+ else :
92+ raise ValueError (
93+ f"side_predicate_learner { CFG .side_predicate_learner } not " +
94+ "implemented" )
95+ pnads = side_pred_learner .sideline ()
8496
8597 # STEP 5: Learn options (option_learning.py) and update PNADs.
8698 _learn_pnad_options (pnads ) # in-place update
@@ -97,130 +109,6 @@ def learn_nsrts_from_data(trajectories: List[LowLevelTrajectory],
97109 return set (nsrts )
98110
99111
100- def _learn_pnad_side_predicates (
101- pnads : List [PartialNSRTAndDatastore ],
102- trajectories : List [LowLevelTrajectory ], train_tasks : List [Task ],
103- predicates : Set [Predicate ],
104- segmented_trajs : List [List [Segment ]]) -> List [PartialNSRTAndDatastore ]:
105-
106- def _check_goal (s : Tuple [PartialNSRTAndDatastore , ...]) -> bool :
107- del s # unused
108- # There are no goal states for this search; run until exhausted.
109- return False
110-
111- def _get_successors (
112- s : Tuple [PartialNSRTAndDatastore , ...],
113- ) -> Iterator [Tuple [None , Tuple [PartialNSRTAndDatastore , ...], float ]]:
114- # For each PNAD/operator...
115- for i in range (len (s )):
116- pnad = s [i ]
117- _ , option_vars = pnad .option_spec
118- # ...consider changing each of its effects to a side predicate.
119- for effect in pnad .op .add_effects :
120- if len (pnad .op .add_effects ) > 1 :
121- # We don't want sidelining to result in a no-op
122- new_pnad = PartialNSRTAndDatastore (
123- pnad .op .effect_to_side_predicate (
124- effect , option_vars , "add" ), pnad .datastore ,
125- pnad .option_spec )
126- sprime = list (s )
127- sprime [i ] = new_pnad
128- yield (None , tuple (sprime ), 1.0 )
129-
130- # ...consider removing it.
131- sprime = list (s )
132- del sprime [i ]
133- yield (None , tuple (sprime ), 1.0 )
134-
135- if CFG .side_predicate_learner == "prediction_error_hillclimbing" :
136- score_func = _PredictionErrorScoreFunction (predicates , [], {},
137- train_tasks )
138-
139- def _evaluate (s : Tuple [PartialNSRTAndDatastore , ...]) -> float :
140- # Score function for search. Lower is better.
141- strips_ops = [pnad .op for pnad in s ]
142- option_specs = [pnad .option_spec for pnad in s ]
143- score = score_func .evaluate_with_operators (frozenset (),
144- trajectories ,
145- segmented_trajs ,
146- strips_ops ,
147- option_specs )
148- return score
149-
150- elif CFG .side_predicate_learner == "preserve_skeletons_hillclimbing" :
151-
152- def _evaluate (s : Tuple [PartialNSRTAndDatastore , ...]) -> float :
153- # Score function for search. Lower is better.
154- strips_ops = [pnad .op for pnad in s ]
155- option_specs = [pnad .option_spec for pnad in s ]
156- preserves_harmlessness = check_harmlessness (
157- predicates , train_tasks , trajectories , segmented_trajs ,
158- strips_ops , option_specs )
159- # NOTE: Arbitrary large number bigger than the total number of
160- # operators at the start of the search.
161- score = 10 * len (pnads )
162- if preserves_harmlessness :
163- score = 2 * len (strips_ops )
164- for op in strips_ops :
165- score -= len (op .side_predicates )
166- return score
167-
168- else :
169- raise ValueError (
170- f"side_predicate_learner { CFG .side_predicate_learner } not " +
171- "implemented" )
172-
173- # Run the search, starting from original PNADs.
174- path , _ , _ = utils .run_hill_climbing (tuple (pnads ), _check_goal ,
175- _get_successors , _evaluate )
176- # The last state in the search holds the final PNADs.
177- pnads = list (path [- 1 ])
178- # Recompute the datastores in the PNADs. We need to do this
179- # because now that we have side predicates, each transition may be
180- # assigned to *multiple* datastores.
181- _recompute_datastores_from_segments (segmented_trajs , pnads )
182- return pnads
183-
184-
185- def _recompute_datastores_from_segments (
186- segmented_trajs : List [List [Segment ]],
187- pnads : List [PartialNSRTAndDatastore ]) -> None :
188- for pnad in pnads :
189- pnad .datastore = [] # reset all PNAD datastores
190- for seg_traj in segmented_trajs :
191- objects = set (seg_traj [0 ].states [0 ])
192- for segment in seg_traj :
193- assert segment .has_option ()
194- segment_option = segment .get_option ()
195- segment_param_option = segment_option .parent
196- segment_option_objs = tuple (segment_option .objects )
197- # Get ground operators given these objects and option objs.
198- for pnad in pnads :
199- param_opt , opt_vars = pnad .option_spec
200- if param_opt != segment_param_option :
201- continue
202- isub = dict (zip (opt_vars , segment_option_objs ))
203- # Consider adding this segment to each datastore.
204- for ground_op in utils .all_ground_operators_given_partial (
205- pnad .op , objects , isub ):
206- # Check if preconditions hold.
207- if not ground_op .preconditions .issubset (
208- segment .init_atoms ):
209- continue
210- # Check if effects match. Note that we're using the side
211- # predicates semantics here!
212- atoms = utils .apply_operator (ground_op , segment .init_atoms )
213- if not atoms .issubset (segment .final_atoms ):
214- continue
215- # Skip over segments that have multiple possible bindings.
216- if len (set (ground_op .objects )) != len (ground_op .objects ):
217- continue
218- # This segment belongs in this datastore, so add it.
219- sub = dict (zip (pnad .op .parameters , ground_op .objects ))
220- pnad .add_to_datastore ((segment , sub ),
221- check_effect_equality = False )
222-
223-
224112def _learn_pnad_options (pnads : List [PartialNSRTAndDatastore ]) -> None :
225113 logging .info ("\n Doing option learning..." )
226114 option_learner = create_option_learner ()
@@ -263,99 +151,3 @@ def _learn_pnad_samplers(pnads: List[PartialNSRTAndDatastore],
263151 # Replace the samplers in the PNADs.
264152 for pnad , sampler in zip (pnads , samplers ):
265153 pnad .sampler = sampler
266-
267-
268- def check_harmlessness (predicates : Set [Predicate ], train_tasks : List [Task ],
269- trajectories : List [LowLevelTrajectory ],
270- segmented_trajs : List [List [Segment ]],
271- strips_ops : List [STRIPSOperator ],
272- option_specs : List [OptionSpec ]) -> bool :
273- """Function to check whether a given set of operators and predicates
274- preserves harmlessness over demonstrations on some number of training
275- tasks.
276-
277- Preserving harmlessness roughly means that the set of operators and
278- predicates supports the agent's ability to plan to achieve all of
279- the training tasks in the same way as was demonstrated (i.e, the
280- predicates and operators don't render any demonstrated trajectory
281- impossible).
282- """
283-
284- assert len (trajectories ) == len (segmented_trajs )
285- for ll_traj , seg_traj in zip (trajectories , segmented_trajs ):
286- if not ll_traj .is_demo :
287- continue
288- atoms_seq = utils .segment_trajectory_to_atoms_sequence (seg_traj )
289- traj_goal = train_tasks [ll_traj .train_task_idx ].goal
290- demo_preserved = check_single_demo_preservation (
291- ll_traj , atoms_seq , traj_goal , predicates , strips_ops ,
292- option_specs )
293- if not demo_preserved :
294- return False
295-
296- return True
297-
298-
299- def check_single_demo_preservation (ll_traj : LowLevelTrajectory ,
300- atoms_seq : List [Set [GroundAtom ]],
301- traj_goal : Set [GroundAtom ],
302- predicates : Set [Predicate ],
303- strips_ops : List [STRIPSOperator ],
304- option_specs : List [OptionSpec ]) -> bool :
305- """Function to check whether a given set of operators and predicates
306- preserves a single training trajectory."""
307- init_atoms = utils .abstract (ll_traj .states [0 ], predicates )
308- objects = set (ll_traj .states [0 ])
309- ground_nsrts , _ = task_plan_grounding (init_atoms , objects , strips_ops ,
310- option_specs )
311- heuristic = utils .create_task_planning_heuristic (
312- CFG .sesame_task_planning_heuristic , init_atoms , traj_goal ,
313- ground_nsrts , predicates , objects )
314-
315- def _check_goal (state : Tuple [FrozenSet [GroundAtom ], int ]) -> bool :
316- return traj_goal .issubset (state [0 ])
317-
318- def _get_successor_with_correct_option (
319- searchnode_state : Tuple [FrozenSet [GroundAtom ], int ]
320- ) -> Iterator [Tuple [_GroundNSRT , Tuple [FrozenSet [GroundAtom ], int ],
321- float ]]:
322- state = searchnode_state [0 ]
323- idx_into_traj = searchnode_state [1 ]
324-
325- if idx_into_traj > len (ll_traj .actions ) - 1 :
326- return
327-
328- gt_option = ll_traj .actions [idx_into_traj ].get_option ()
329- expected_next_hl_state = atoms_seq [idx_into_traj + 1 ]
330-
331- for applicable_nsrt in utils .get_applicable_operators (
332- ground_nsrts , state ):
333- # NOTE: we check that the ParameterizedOptions are equal before
334- # attempting to ground because otherwise, we might
335- # get a parameter mismatch and trigger an AssertionError
336- # during grounding.
337- if applicable_nsrt .option != gt_option .parent :
338- continue
339- if applicable_nsrt .option_objs != gt_option .objects :
340- continue
341- next_hl_state = utils .apply_operator (applicable_nsrt , set (state ))
342- exp_state_matches = next_hl_state .issubset (expected_next_hl_state )
343- if exp_state_matches :
344- # The returned cost is uniform because we don't
345- # actually care about finding the shortest path;
346- # just one that matches!
347- yield (applicable_nsrt , (frozenset (next_hl_state ),
348- idx_into_traj + 1 ), 1.0 )
349-
350- init_atoms_frozen = frozenset (init_atoms )
351- init_searchnode_state = (init_atoms_frozen , 0 )
352- # NOTE: each state in the below GBFS is a tuple of
353- # (current_atoms, idx_into_traj). The idx_into_traj is necessary because
354- # we need to check whether the atoms that are true at this particular
355- # index into the trajectory is what we would expect given the demo
356- # trajectory.
357- state_seq , _ = utils .run_gbfs (
358- init_searchnode_state , _check_goal , _get_successor_with_correct_option ,
359- lambda searchnode_state : heuristic (searchnode_state [0 ]))
360-
361- return _check_goal (state_seq [- 1 ])
0 commit comments