Skip to content

Commit e1a7c3d

Browse files
committed
wip
1 parent 2052749 commit e1a7c3d

3 files changed

Lines changed: 68 additions & 29 deletions

File tree

predicators/approaches/pp_online_predicate_invention_approach.py

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def _get_successors(
284284

285285
def get_false_positive_states_from_seg_trajs(
286286
segmented_trajs: List[List[Segment]],
287+
trajectories: List[LowLevelTrajectory],
287288
exogenous_processes: List[ExogenousProcess],
288289
) -> Dict[_GroundExogenousProcess, List[State]]:
289290

@@ -295,10 +296,12 @@ def get_false_positive_states_from_seg_trajs(
295296
# Cache for ground_exogenous_processes to avoid recomputation
296297
objects_to_ground_processes = {}
297298

298-
for segmented_traj in segmented_trajs:
299-
# Checking each segmented trajectory
300-
objects = frozenset(segmented_traj[0].trajectory.states[0])
299+
for traj, segmented_traj in zip(trajectories, segmented_trajs):
300+
scheduled_events: Dict[int, List[Tuple[_GroundExogenousProcess,
301+
int]]] = {}
302+
t = 0 # Time counter in the trajectory
301303
# Only recompute if objects are different
304+
objects = frozenset(segmented_traj[0].trajectory.states[0])
302305
if objects not in objects_to_ground_processes:
303306
ground_exogenous_processes, _ = task_plan_grounding(
304307
set(),
@@ -307,34 +310,67 @@ def get_false_positive_states_from_seg_trajs(
307310
allow_noops=True,
308311
compute_reachable_atoms=False)
309312
objects_to_ground_processes[objects] = ground_exogenous_processes
313+
assert all(
314+
isinstance(g_exo_process, _GroundExogenousProcess)
315+
for g_exo_process in ground_exogenous_processes), \
316+
"Expected all processes to be ground exogenous processes."
310317
else:
311318
ground_exogenous_processes = objects_to_ground_processes[objects]
312319

313320
# Pre-compute segment init_atoms for efficiency
314321
segment_init_atoms = [segment.init_atoms for segment in segmented_traj]
315322

316-
for g_exo_process in ground_exogenous_processes:
317-
condition = g_exo_process.condition_at_start # Cache reference
318-
add_effects = g_exo_process.add_effects
319-
delete_effects = g_exo_process.delete_effects
320-
321-
for i, segment in enumerate(segmented_traj):
322-
satisfy_condition = condition.issubset(segment_init_atoms[i])
323-
first_state_or_prev_state_doesnt_satisfy = i == 0 or \
324-
not condition.issubset(segment_init_atoms[i - 1])
325-
326-
if satisfy_condition and first_state_or_prev_state_doesnt_satisfy:
327-
false_positive_process_state[g_exo_process].append(
328-
# segment.trajectory.states[0])
329-
segment.init_atoms)
330-
331-
# Check for removal condition
332-
if (add_effects.issubset(segment.add_effects)
333-
and delete_effects.issubset(segment.delete_effects)):
334-
if false_positive_process_state[g_exo_process]:
335-
# TODO: we don't really know which one to remove, pop
336-
# the first one is a bias.
337-
false_positive_process_state[g_exo_process].pop(0)
323+
for i, segment in enumerate(segmented_traj):
324+
# 1. Process effects scheduled for this step or earlier
325+
relevant_events = [
326+
(scheduled_time, proc_n_times)
327+
for scheduled_time, proc_n_times in scheduled_events.items()
328+
if scheduled_time <= t
329+
]
330+
for scheduled_time, proc_n_times in relevant_events:
331+
for g_exo_process, start_time in proc_n_times:
332+
condition_overall = g_exo_process.condition_overall
333+
condition_at_end = g_exo_process.condition_at_end
334+
add_effects = g_exo_process.add_effects
335+
delete_effects = g_exo_process.delete_effects
336+
337+
if (all(condition_overall.issubset(s) for s in
338+
traj.states[start_time + 1:]) and
339+
condition_at_end.issubset(traj.states[scheduled_time])):
340+
341+
false_positive_process_state[g_exo_process].append(
342+
segment.init_atoms)
343+
344+
# Check for effects scheduled for this step.
345+
if (add_effects.issubset(traj.states[scheduled_time].add_effects) and
346+
delete_effects.issubset(traj.states[scheduled_time].delete_effects)):
347+
if false_positive_process_state[g_exo_process]:
348+
false_positive_process_state[g_exo_process].pop(0)
349+
350+
# Delete the scheduled events that are no longer relevant
351+
del scheduled_events[scheduled_time]
352+
353+
# 2. Schedule an effect to be checked later
354+
for g_exo_process in ground_exogenous_processes:
355+
condition_at_start = g_exo_process.condition_at_start
356+
condition_overall = g_exo_process.condition_overall
357+
add_effects = g_exo_process.add_effects
358+
delete_effects = g_exo_process.delete_effects
359+
360+
satisfy_condition = condition_at_start.issubset(
361+
segment_init_atoms[i])
362+
first_state_to_satisfy = i == 0 or \
363+
not condition_at_start.issubset(segment_init_atoms[i - 1])
364+
365+
if satisfy_condition and first_state_to_satisfy:
366+
delay = g_exo_process.delay_distribution.sample()
367+
schedued_time = t + delay
368+
if schedued_time not in scheduled_events:
369+
scheduled_events[schedued_time] = []
370+
scheduled_events[schedued_time].append((g_exo_process, i))
371+
# e.g. if current 0, len is 5, the next timestep is 5.
372+
t += len(segment.states)
373+
338374
return false_positive_process_state
339375

340376

predicators/nsrt_learning/strips_learning/clustering_learner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ def _learn_pnad_preconditions_sequential(self,
477477
parameters=new_params), pnad.datastore,
478478
pnad.option_spec)
479479
final_pnads.append(new_pnad)
480+
breakpoint()
480481

481482
return final_pnads
482483

@@ -537,7 +538,8 @@ def _score_preconditions(self, exogenous_process: ExogenousProcess,
537538
exogenous_process.condition_overall = set(preconditions)
538539
false_positive_process_state =\
539540
self._get_false_positive_states_from_seg_trajs(
540-
self._atom_change_segmented_trajs, [exogenous_process])
541+
self._atom_change_segmented_trajs,
542+
self._trajectories, [exogenous_process])
541543
num_false_positives = 0
542544
for _, states in false_positive_process_state.items():
543545
num_false_positives += len(states)
@@ -695,7 +697,8 @@ def _get_top_consistent_conditions(self, initial_atom: Set[LiftedAtom],
695697

696698
false_positive_process_state = \
697699
self._get_fp_states_from_seg_trajs(
698-
self._atom_change_segmented_trajs, [exogenous_process])
700+
self._atom_change_segmented_trajs,
701+
self._trajectories, [exogenous_process])
699702
num_false_positives = sum(
700703
len(states)
701704
for states in false_positive_process_state.values())

scripts/configs/mara_bench.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ APPROACHES:
7676
strips_learner: "llm"
7777
find_best_matching_pnad_skip_if_effect_not_subset: False
7878
# exogenous_process_learner: "cluster_and_llm_select"
79-
# exogenous_process_learner: "cluster_and_search_process_learner"
80-
exogenous_process_learner: "cluster_and_inverse_planning"
79+
exogenous_process_learner: "cluster_and_search_process_learner"
80+
# exogenous_process_learner: "cluster_and_inverse_planning"
8181
process_learner_check_false_positives: False
8282
# To have demos to stop when option terminates.
8383
terminate_on_goal_reached: False

0 commit comments

Comments
 (0)