@@ -47,22 +47,35 @@ def generate_random_trajectories(env_name, num_trajectories=1000, max_steps=1000
4747
4848def process_trajectories (trajectories , max_timesteps = 1000 ):
4949 # Find the maximum trajectory length
50- max_len = min (max ([t ['length' ] for t in trajectories ] ), max_timesteps )
50+ max_len = min (max ([t ['length' ] for t in trajectories if t [ 'length' ] > 0 ], default = 0 ), max_timesteps )
5151
52+ if not trajectories or max_len == 0 :
53+ # Handle case with no trajectories or all empty trajectories
54+ return {
55+ 'states' : np .array ([]), 'actions' : np .array ([]), 'returns_to_go' : np .array ([]),
56+ 'timesteps' : np .array ([]), 'mask' : np .array ([]),
57+ 'metadata' : {'state_dim' : 0 , 'act_dim' : 0 , 'max_timesteps' : max_timesteps }
58+ }
59+
5260 # Initialize arrays
5361 num_trajectories = len (trajectories )
5462 state_dim = trajectories [0 ]['states' ][0 ].shape [0 ]
55- act_dim = 1 # For discrete actions
5663
64+ # Determine action dimension from data
65+ all_actions = np .concatenate ([t ['actions' ] for t in trajectories if t ['length' ] > 0 ])
66+ act_dim = int (all_actions .max ()) + 1 if all_actions .size > 0 else 1
67+
5768 states = np .zeros ((num_trajectories , max_len + 1 , state_dim ))
58- actions = np .zeros ((num_trajectories , max_len , act_dim ))
69+ actions = np .zeros ((num_trajectories , max_len , 1 )) # Store actions as scalars
5970 returns_to_go = np .zeros ((num_trajectories , max_len + 1 , 1 ))
6071 timesteps = np .zeros ((num_trajectories , max_len + 1 ))
6172 mask = np .zeros ((num_trajectories , max_len + 1 ))
6273
6374 # Fill arrays
6475 for i , traj in enumerate (trajectories ):
6576 length = min (traj ['length' ], max_len )
77+ if length == 0 :
78+ continue
6679
6780 # States include the final state
6881 states [i , :length + 1 ] = traj ['states' ][:length + 1 ]
@@ -79,7 +92,7 @@ def process_trajectories(trajectories, max_timesteps=1000):
7992 # Create metadata
8093 metadata = {
8194 'state_dim' : state_dim ,
82- 'act_dim' : 2 , # CartPole has 2 actions
95+ 'act_dim' : act_dim ,
8396 'max_timesteps' : max_timesteps
8497 }
8598
0 commit comments