Skip to content

Commit 88c3c30

Browse files
Update training_phase1.py
1 parent 5a7a5c9 commit 88c3c30

1 file changed

Lines changed: 40 additions & 12 deletions

File tree

novel_phases/phase1/training_phase1.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -177,37 +177,65 @@ def __len__(self):
177177

178178
def collect_trajectories(env_name: str, state_dim: int, act_dim: int, act_type: str,
179179
offline_steps: int, max_episode_len: int, gamma: float, seed: int):
180-
"""Collects offline_steps env steps with a random policy."""
181180
env = gym.make(env_name)
182-
# It's good practice to seed the env for reproducibility if it supports it, though random policy diminishes this
183-
# env.seed(seed) # Deprecated, use env.reset(seed=seed)
184-
181+
185182
trajectories = []
186183
buf = TrajectoryBuffer(max_episode_len, state_dim, act_dim, act_type)
187184

188185
print(f"Collecting {offline_steps} steps from {env_name} using a random policy...")
189186

190187
current_steps = 0
191-
obs, _ = env.reset(seed=seed) # Seed on reset
192188

189+
raw_reset_output = env.reset(seed=seed)
190+
if isinstance(raw_reset_output, tuple) and len(raw_reset_output) == 2 and isinstance(raw_reset_output[1], dict):
191+
obs, info_dict = raw_reset_output
192+
else:
193+
obs = raw_reset_output
194+
info_dict = {}
195+
196+
episode_step_count = 0
193197
while current_steps < offline_steps:
194198
action = env.action_space.sample()
195-
next_obs, reward, terminated, truncated, _ = env.step(action)
196-
done = terminated or truncated
197199

200+
step_output = env.step(action)
201+
202+
if len(step_output) == 5:
203+
next_obs, reward, terminated, truncated, info_dict = step_output
204+
done = terminated or truncated
205+
elif len(step_output) == 4:
206+
next_obs, reward, done, info_dict_step = step_output # Use different var name for step info
207+
terminated = done
208+
# Check for truncation more reliably if possible
209+
truncated = done and (info_dict_step.get('TimeLimit.truncated', False) or episode_step_count + 1 >= max_episode_len if max_episode_len > 0 else False)
210+
info_dict.update(info_dict_step) # Merge step info if needed
211+
else:
212+
raise ValueError(f"Unexpected number of values from env.step(): {len(step_output)}. Output: {step_output}")
213+
198214
buf.add(obs.astype(np.float32), action, reward)
199215

200216
obs = next_obs
201217
current_steps += 1
218+
episode_step_count +=1
202219

203-
if done or len(buf) == max_episode_len:
220+
if done or (max_episode_len > 0 and episode_step_count >= max_episode_len):
204221
trajectories.append(buf.get_trajectory())
205222
buf.reset()
206-
obs, _ = env.reset() # No need to re-seed here for subsequent episodes with random policy
207-
if current_steps % (offline_steps // 10) == 0 and offline_steps > 0: # Log progress
208-
print(f" Collected {current_steps}/{offline_steps} steps...")
223+
224+
raw_reset_output_loop = env.reset()
225+
if isinstance(raw_reset_output_loop, tuple) and len(raw_reset_output_loop) == 2 and isinstance(raw_reset_output_loop[1], dict):
226+
obs, info_dict = raw_reset_output_loop
227+
else:
228+
obs = raw_reset_output_loop
229+
info_dict = {}
230+
episode_step_count = 0
231+
232+
# Log progress carefully
233+
if offline_steps > 0 : # Avoid division by zero if offline_steps is 0
234+
log_point = offline_steps // 10
235+
if log_point == 0 : log_point = 1 # Ensure it logs for small offline_steps
236+
if current_steps % log_point == 0:
237+
print(f" Collected {current_steps}/{offline_steps} steps...")
209238

210-
# After the loop, add any remaining trajectory in the buffer
211239
if len(buf) > 0:
212240
trajectories.append(buf.get_trajectory())
213241
print(f" Added final partial trajectory of length {len(buf)}.")

0 commit comments

Comments
 (0)