2323from agentlab .llm .llm_utils import Discussion , ParseError , SystemMessage , retry
2424from agentlab .llm .tracking import cost_tracker_decorator
2525
26- from .generic_agent_prompt import GenericPromptFlags , MainPrompt
26+ from .generic_agent_prompt import (
27+ GenericPromptFlags ,
28+ MainPrompt ,
29+ StepWiseRetrievalPrompt ,
30+ )
2731
2832
2933@dataclass
@@ -102,6 +106,16 @@ def set_task_name(self, task_name: str):
102106 def get_action (self , obs ):
103107
104108 self .obs_history .append (obs )
109+
110+ system_prompt = SystemMessage (dp .SystemPrompt ().prompt )
111+
112+ queries , think_queries = self ._get_queries ()
113+
114+ # TODO
115+ # use those queries to retrieve from the database. e.g.:
116+ # hints = self.hint_db.get_hints(queries)
117+ # then add those hints to the main prompt
118+
105119 main_prompt = MainPrompt (
106120 action_set = self .action_set ,
107121 obs_history = self .obs_history ,
@@ -120,8 +134,6 @@ def get_action(self, obs):
120134
121135 max_prompt_tokens , max_trunc_itr = self ._get_maxes ()
122136
123- system_prompt = SystemMessage (dp .SystemPrompt ().prompt )
124-
125137 human_prompt = dp .fit_tokens (
126138 shrinkable = main_prompt ,
127139 max_prompt_tokens = max_prompt_tokens ,
@@ -168,6 +180,31 @@ def get_action(self, obs):
168180 )
169181 return ans_dict ["action" ], agent_info
170182
def _get_queries(self):
    """Ask the chat model for retrieval queries to use for hinting.

    Builds a ``StepWiseRetrievalPrompt`` from the observation history, the
    past actions and the past thoughts, then sends it together with the
    system prompt to ``self.chat_llm`` through ``retry`` (called with
    ``n_retry=self.max_retry`` and the prompt's ``_parse_answer`` as parser).

    Returns:
        tuple: ``(queries, think)``. ``queries`` is the value stored under
        the ``"queries"`` key of the parsed answer; ``think`` is the value
        stored under ``"think"``, or ``None`` when that key is absent.

    Raises:
        ValueError: if the number of returned queries differs from
            ``self.flags.n_retrieval_queries``.
    """
    expected_n_queries = self.flags.n_retrieval_queries  # TODO

    system_prompt = SystemMessage(dp.SystemPrompt().prompt)
    query_prompt = StepWiseRetrievalPrompt(
        obs_history=self.obs_history,
        actions=self.actions,
        thoughts=self.thoughts,
        obs_flags=self.flags.obs,
        n_queries=expected_n_queries,
    )

    chat_messages = Discussion([system_prompt, query_prompt.prompt])
    ans_dict = retry(
        self.chat_llm,
        chat_messages,
        n_retry=self.max_retry,
        parser=query_prompt._parse_answer,
    )

    queries = ans_dict.get("queries", [])
    # Explicit check instead of `assert`: assertions are removed when Python
    # runs with -O, which would let a wrong number of queries through.
    if len(queries) != expected_n_queries:
        raise ValueError(
            f"Expected {expected_n_queries} retrieval queries, got {len(queries)}."
        )

    # TODO: we should probably propagate these chat_messages to be able to see them in xray
    return queries, ans_dict.get("think")
207+
171208 def reset (self , seed = None ):
172209 self .seed = seed
173210 self .plan = "No plan yet"
0 commit comments