Skip to content

Commit b1eac4c

Browse files
committed
Updates.
1 parent 17ae019 commit b1eac4c

5 files changed

Lines changed: 46 additions & 14 deletions

File tree

predicators/approaches/grammar_search_invention_approach.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,11 @@ def enumerate(self) -> Iterator[Tuple[Predicate, float]]:
486486
"Whole0",
487487
"Cut0",
488488
],
489+
"pybullet_coffee": [
490+
"JugInMachine",
491+
"Holding",
492+
"HandEmpty",
493+
]
489494
}
490495
_DEBUG_VLM_PREDICATES = defaultdict(list, _DEBUG_VLM_PREDICATES)
491496

@@ -1054,15 +1059,18 @@ def learn_from_offline_dataset(self, dataset: Dataset) -> None:
10541059
self._parse_atom_dataset_from_annotated_dataset(dataset)
10551060
atom_dataset = utils.merge_ground_atom_datasets(
10561061
atom_dataset_from_grammar, atom_dataset_from_vlm)
1062+
10571063
# If grammar_search_invent_geo_predicates_only is False, then we
10581064
# want to invent both VLM and geo predicates
1059-
if not CFG.grammar_search_invent_geo_predicates_only:
1060-
candidates = candidates_from_grammar | candidates_from_vlm
1061-
# Otherwise, we only want to invent geo predicates, and directly
1062-
# select all the VLM predicates.
1063-
else:
1064-
candidates = candidates_from_grammar
1065-
self._initial_predicates |= set(candidates_from_vlm.keys())
1065+
# if not CFG.grammar_search_invent_geo_predicates_only:
1066+
# candidates = candidates_from_grammar | candidates_from_vlm
1067+
# # Otherwise, we only want to invent geo predicates, and directly
1068+
# # select all the VLM predicates.
1069+
# else:
1070+
# candidates = candidates_from_grammar
1071+
# self._initial_predicates |= set(candidates_from_vlm.keys())
1072+
candidates = candidates_from_vlm
1073+
10661074
elif not CFG.offline_data_method in [
10671075
"demo+labelled_atoms", "saved_vlm_img_demos_folder",
10681076
"demo_with_vlm_imgs"

predicators/datasets/generate_atom_trajs_with_vlm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def _generate_prompt_for_atom_proposals(
7878
for act in traj.actions)
7979
# NOTE: exact same issue as described in the above note for
8080
# naive_whole_traj.
81+
# import pdb; pdb.set_trace()
8182
ret_list.append(
8283
(prompt, [traj.imgs[i][0] for i in range(len(traj.imgs))]))
8384
else: # pragma: no cover.
@@ -144,6 +145,8 @@ def _label_single_trajectory_with_vlm_atom_values(indexed_traj: Tuple[
144145
obj_names = [o.name for o in traj.objects]
145146
filtered_atoms_list = []
146147
for a in atoms_list:
148+
# Remove whitespace from the atom string.
149+
a = a.replace(' ', '')
147150
# Get the names of the objects in this atom.
148151
atom_args = a[a.find('(') + 1:a.find(')')]
149152
atom_objs = atom_args.split(',')
@@ -412,6 +415,7 @@ def _save_img_option_trajs_in_folder(
412415
for j, img_list in enumerate(img_option_traj.imgs):
413416
curr_traj_timestep_folder = Path(curr_traj_folder, str(j))
414417
os.makedirs(curr_traj_timestep_folder, exist_ok=False)
418+
# import pdb; pdb.set_trace()
415419
for k, img in enumerate(img_list):
416420
img.save(
417421
Path(curr_traj_timestep_folder,
@@ -1086,6 +1090,7 @@ def create_ground_atom_data_from_generated_demos(
10861090
raise NotImplementedError(
10871091
f"Cropped images not implemented for {CFG.env}.")
10881092
if CFG.env in ["pybullet_coffee"]:
1093+
# import pdb; pdb.set_trace()
10891094
state_imgs.append(state.simulator_state['images'])
10901095
else:
10911096
state_imgs.append([
@@ -1116,6 +1121,7 @@ def create_ground_atom_data_from_generated_demos(
11161121
if CFG.vlm_predicate_vision_api_generate_ground_atoms:
11171122
generate_func = _generate_ground_atoms_with_vlm_oo_code_gen
11181123
else:
1124+
# import pdb; pdb.set_trace()
11191125
generate_func = _generate_ground_atoms_with_vlm_pure_visual_preds
11201126
ground_atoms_trajs = generate_func(img_option_trajs, env, train_tasks,
11211127
known_predicates, all_task_objs, vlm)

predicators/datasets/vlm_input_data_prompts/atom_labelling/img_option_diffs_label_history.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
You are a vision system for a robot provided with two images: a before image showing the state before a skill is executed, an after image showing the state after the skill is executed. You are given a list of predicates below, and you are given the values of these predicates in the image before the skill is executed. Your job is to output the values of the following predicates in the image after the skill is executed. Pay careful attention to the visual changes between the two images to figure out which predicates change and which predicates do not change. Note that some or all of the predicates don't necessary have to change. First, output a description of what changes you expect to happen based on the skill that was just run, explicitly noting the skill that was run. Second, output a description of what visual changes you see happen between the before and after images, looking specifically at the objects involved in the skill's arguments, noting what objects these are. From these two descriptions, for each predicate labeled in the previous timestep, note whether you expect its value to change or stay the same. Next, output each predicate value in the after image as a bulleted list (use '*' for the bullets) with each predicate and value on a different line. Ensure there is a period ('.') after the truth value of the predicate. For each predicate value, provide an explanation as to why you labelled this predicate as having this particular value, and note what value this predicate had in the previous timestep, which is given to you in the prompt. Use the format: `* <predicate>: <truth_value>. <explanation>`. When labeling the value of a predicate, if you don't see the objects involved in that predicate, retain its truth value from the previous timestep. 
Also, if your description of changes you expect to happen, and your description of visual changes you saw happen, have nothing to do with the predicate you are trying to label, retain its truth value from the previous timestep. For example, if in the previous timestep I paint an object, and in the current timestamp I sit on it, we don't expect its color to change after sitting on it.
1+
You are a vision system for a robot provided with two images: a before image showing the state before a skill is executed, an after image showing the state after the skill is executed. You are given a list of predicates below, and you are given the values of these predicates in the image before the skill is executed. Your job is to output the values of the following predicates in the image after the skill is executed. Pay careful attention to the visual changes between the two images to figure out which predicates change and which predicates do not change. Note that some or all of the predicates don't necessarily have to change. First, output a description of what changes you expect to happen based on the skill that was just run, explicitly noting the skill that was run. Second, output a description of what visual changes you see happen between the before and after images, looking specifically at the objects involved in the skill's arguments, noting what objects these are. From these two descriptions, for each predicate labeled in the previous timestep, note whether you expect its value to change or stay the same. Next, for each predicate given in the list of predicates to label, output each predicate value in the after image as a bulleted list (use '*' for the bullets) with each predicate and value on a different line. Ensure there is a period ('.') after the truth value of the predicate. For each predicate value, provide an explanation as to why you labelled this predicate as having this particular value, and note what value this predicate had in the previous timestep, which is given to you in the prompt. Use the format: `* <predicate>: <truth_value>. <explanation>`. When labeling the value of a predicate, if you don't see the objects involved in that predicate, retain its truth value from the previous timestep. 
Also, if your description of changes you expect to happen, and your description of visual changes you saw happen, have nothing to do with the predicate you are trying to label, retain its truth value from the previous timestep. For example, if in the previous timestep I paint an object, and in the current timestep I sit on it, we don't expect its color to change after sitting on it.
22

33
Your response should have three sections. Here is an outline of what your response should look like:
44
[START OUTLINE]
@@ -12,4 +12,4 @@ Your response should have three sections. Here is an outline of what your respon
1212
[insert your bulleted list of `* <predicate>: <truth value>. <explanation>`]
1313
[END OUTLINE]
1414

15-
Predicates:
15+
Predicates to label:

predicators/envs/pybullet_coffee.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
create_single_arm_pybullet_robot
4949
from predicators.settings import CFG
5050
from predicators.structs import Action, Array, EnvironmentTask, Object, \
51-
Predicate, State, Observation
51+
Predicate, State, Observation, Task
5252

5353
class PyBulletCoffeeEnv(PyBulletEnv, CoffeeEnv):
5454
"""PyBullet Coffee domain.
@@ -285,6 +285,22 @@ def predicates(self) -> Set[Predicate]:
285285
@property
286286
def agent_goal_predicates(self) -> Set[Predicate]:
287287
return self.goal_predicates
288+
289+
# def get_vlm_debug_atom_strs(self,
290+
# train_tasks: List[Task]) -> List[List[str]]:
291+
# # Convert the default value from List[List[str]] to List[str] to match
292+
# # the other entries we'll put into the dictionary.
293+
# default = [a[0] for a in super().get_vlm_debug_atom_strs(train_tasks)]
294+
# atom_strs_by_task_type = {
295+
# "more_stacks": ["Cooked(patty1)"],
296+
# "fatter_burger": ["Cooked(patty1)"],
297+
# "combo_burger":
298+
# ["Cooked(patty1)", "Cut(lettuce1)", "Whole(lettuce1)"]
299+
# }
300+
# atom_strs_by_task_type = defaultdict(lambda: default,
301+
# atom_strs_by_task_type)
302+
# atom_strs = atom_strs_by_task_type[CFG.burger_no_move_task_type]
303+
# return [[a] for a in atom_strs]
288304

289305
@property
290306
def oracle_proposed_predicates(self) -> Set[Predicate]:

predicators/pretrained_model_interface.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,10 +387,12 @@ def _sample_completions(
387387
images=imgs,
388388
detail="auto")
389389
responses = [
390-
self.call_openai_api(messages,
391-
model=self.model_name,
392-
max_tokens=self._max_tokens,
393-
temperature=temperature)
390+
self.call_openai_api(messages, model=self.model_name, max_tokens=self._max_tokens, temperature=temperature)
394391
for _ in range(num_completions)
395392
]
393+
while any("sorry" in response.lower() for response in responses):
394+
responses = [
395+
self.call_openai_api(messages, model=self.model_name, max_tokens=self._max_tokens, temperature=temperature)
396+
for _ in range(num_completions)
397+
]
396398
return responses

0 commit comments

Comments
 (0)