@@ -177,37 +177,65 @@ def __len__(self):
177177
178178def collect_trajectories (env_name : str , state_dim : int , act_dim : int , act_type : str ,
179179 offline_steps : int , max_episode_len : int , gamma : float , seed : int ):
180- """Collects offline_steps env steps with a random policy."""
181180 env = gym .make (env_name )
182- # It's good practice to seed the env for reproducibility if it supports it, though random policy diminishes this
183- # env.seed(seed) # Deprecated, use env.reset(seed=seed)
184-
181+
185182 trajectories = []
186183 buf = TrajectoryBuffer (max_episode_len , state_dim , act_dim , act_type )
187184
188185 print (f"Collecting { offline_steps } steps from { env_name } using a random policy..." )
189186
190187 current_steps = 0
191- obs , _ = env .reset (seed = seed ) # Seed on reset
192188
189+ raw_reset_output = env .reset (seed = seed )
190+ if isinstance (raw_reset_output , tuple ) and len (raw_reset_output ) == 2 and isinstance (raw_reset_output [1 ], dict ):
191+ obs , info_dict = raw_reset_output
192+ else :
193+ obs = raw_reset_output
194+ info_dict = {}
195+
196+ episode_step_count = 0
193197 while current_steps < offline_steps :
194198 action = env .action_space .sample ()
195- next_obs , reward , terminated , truncated , _ = env .step (action )
196- done = terminated or truncated
197199
200+ step_output = env .step (action )
201+
202+ if len (step_output ) == 5 :
203+ next_obs , reward , terminated , truncated , info_dict = step_output
204+ done = terminated or truncated
205+ elif len (step_output ) == 4 :
206+ next_obs , reward , done , info_dict_step = step_output # Use different var name for step info
207+ terminated = done
208+ # Check for truncation more reliably if possible
209+ truncated = done and (info_dict_step .get ('TimeLimit.truncated' , False ) or episode_step_count + 1 >= max_episode_len if max_episode_len > 0 else False )
210+ info_dict .update (info_dict_step ) # Merge step info if needed
211+ else :
212+ raise ValueError (f"Unexpected number of values from env.step(): { len (step_output )} . Output: { step_output } " )
213+
198214 buf .add (obs .astype (np .float32 ), action , reward )
199215
200216 obs = next_obs
201217 current_steps += 1
218+ episode_step_count += 1
202219
203- if done or len ( buf ) == max_episode_len :
220+ if done or ( max_episode_len > 0 and episode_step_count >= max_episode_len ) :
204221 trajectories .append (buf .get_trajectory ())
205222 buf .reset ()
206- obs , _ = env .reset () # No need to re-seed here for subsequent episodes with random policy
207- if current_steps % (offline_steps // 10 ) == 0 and offline_steps > 0 : # Log progress
208- print (f" Collected { current_steps } /{ offline_steps } steps..." )
223+
224+ raw_reset_output_loop = env .reset ()
225+ if isinstance (raw_reset_output_loop , tuple ) and len (raw_reset_output_loop ) == 2 and isinstance (raw_reset_output_loop [1 ], dict ):
226+ obs , info_dict = raw_reset_output_loop
227+ else :
228+ obs = raw_reset_output_loop
229+ info_dict = {}
230+ episode_step_count = 0
231+
232+ # Log progress carefully
233+ if offline_steps > 0 : # Avoid division by zero if offline_steps is 0
234+ log_point = offline_steps // 10
235+ if log_point == 0 : log_point = 1 # Ensure it logs for small offline_steps
236+ if current_steps % log_point == 0 :
237+ print (f" Collected { current_steps } /{ offline_steps } steps..." )
209238
210- # After the loop, add any remaining trajectory in the buffer
211239 if len (buf ) > 0 :
212240 trajectories .append (buf .get_trajectory ())
213241 print (f" Added final partial trajectory of length { len (buf )} ." )
0 commit comments