@@ -57,62 +57,73 @@ class OfflineTransitionDataset(Dataset):
     def __init__(self, dataset_path):
         data = np.load(dataset_path)
 
-        # First, calculate the total number of transitions
-        total_transitions = 0
-        for i in range(data["states"].shape[0]):
-            mask = data["mask"][i]
-            clip_len = int(mask.sum())
-            clip_len = min(clip_len, data["states"].shape[1])  # Cap clip_len
-            total_transitions += clip_len
-
-        # Pre-allocate numpy arrays
-        state_dim = data["states"].shape[2]
-        action_dim = data["actions"].shape[2]
+        states_list, actions_list, rewards_list, next_states_list, dones_list = [], [], [], [], []
 
-        self.states = np.zeros((total_transitions, state_dim), dtype=np.float32)
-        self.actions = np.zeros((total_transitions, action_dim), dtype=np.float32)
-        self.rewards = np.zeros((total_transitions, 1), dtype=np.float32)
-        self.next_states = np.zeros((total_transitions, state_dim), dtype=np.float32)
-        self.dones = np.zeros((total_transitions, 1), dtype=np.float32)
-
-        current_idx = 0
+        # Check if actions exist and get action_dim, otherwise handle gracefully
+        if "actions" not in data or data["actions"].shape[0] == 0:
+            action_dim = 0  # Placeholder, this dataset will be empty
+        else:
+            action_dim = data["actions"].shape[2]
+
         for i in range(data["states"].shape[0]):
             mask = data["mask"][i]
             clip_len = int(mask.sum())
-            clip_len = min(clip_len, data["states"].shape[1])  # Cap clip_len
-
+
             if clip_len == 0:
                 continue
 
-            # Get the trajectory data
+            # Trajectory data
             traj_states = data["states"][i, :clip_len]
-            traj_actions = data["actions"][i, :clip_len]
             traj_rtg = data["returns_to_go"][i, :clip_len]
 
-            # Populate states and actions
-            self.states[current_idx : current_idx + clip_len] = traj_states
-            self.actions[current_idx : current_idx + clip_len] = traj_actions
+            # Add all states for this trajectory to the list
+            states_list.append(traj_states)
 
-            # Populate rewards, next_states, and dones
+            # Actions: N-1 real actions, plus one dummy action for the terminal state
+            traj_actions_list = []
             if clip_len > 1:
-                self.rewards[current_idx : current_idx + clip_len - 1] = (traj_rtg[:-1] - traj_rtg[1:]).reshape(-1, 1)
-                self.next_states[current_idx : current_idx + clip_len - 1] = traj_states[1:]
-                self.dones[current_idx : current_idx + clip_len - 1] = 0.0
-
-            # Final transition
-            if clip_len > 0:
-                self.rewards[current_idx + clip_len - 1] = traj_rtg[-1]
-                self.next_states[current_idx + clip_len - 1] = np.zeros_like(traj_states[-1])
-                self.dones[current_idx + clip_len - 1] = 1.0
+                traj_actions_list.append(data["actions"][i, :clip_len - 1])
+
+            # Add a dummy action for the terminal state
+            dummy_action = np.zeros((1, action_dim), dtype=np.float32)
+            traj_actions_list.append(dummy_action)
+            actions_list.append(np.concatenate(traj_actions_list, axis=0))
+
+            # Rewards and next_states
+            rewards = np.zeros((clip_len, 1), dtype=np.float32)
+            next_states = np.zeros_like(traj_states)
+            dones = np.zeros((clip_len, 1), dtype=np.float32)
+
+            if clip_len > 1:
+                # Rewards for non-terminal states
+                rewards[:-1] = (traj_rtg[:-1] - traj_rtg[1:]).reshape(-1, 1)
+                # Next_states for non-terminal states
+                next_states[:-1] = traj_states[1:]
+
+            # Handle terminal transition
+            rewards[-1] = traj_rtg[-1]  # This is RTG, not a true reward, but matches original logic
+            # next_states[-1] is already zeros
+            dones[-1] = 1.0
 
-            current_idx += clip_len
-
-        # Convert to torch tensors
-        self.states = torch.from_numpy(self.states).float()
-        self.actions = torch.from_numpy(self.actions).float()
-        self.rewards = torch.from_numpy(self.rewards).float()
-        self.next_states = torch.from_numpy(self.next_states).float()
-        self.dones = torch.from_numpy(self.dones).float()
+            rewards_list.append(rewards)
+            next_states_list.append(next_states)
+            dones_list.append(dones)
+
+        # Handle case where no valid trajectories were found
+        if not states_list:
+            self.states = torch.empty(0, data["states"].shape[2], dtype=torch.float32)
+            self.actions = torch.empty(0, action_dim, dtype=torch.float32)
+            self.rewards = torch.empty(0, 1, dtype=torch.float32)
+            self.next_states = torch.empty(0, data["states"].shape[2], dtype=torch.float32)
+            self.dones = torch.empty(0, 1, dtype=torch.float32)
+            return
+
+        # Concatenate and convert to torch tensors
+        self.states = torch.from_numpy(np.concatenate(states_list, axis=0)).float()
+        self.actions = torch.from_numpy(np.concatenate(actions_list, axis=0)).float()
+        self.rewards = torch.from_numpy(np.concatenate(rewards_list, axis=0)).float()
+        self.next_states = torch.from_numpy(np.concatenate(next_states_list, axis=0)).float()
+        self.dones = torch.from_numpy(np.concatenate(dones_list, axis=0)).float()
 
     def __len__(self):
         return len(self.states)
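
Note: the reward reconstruction in this hunk inverts the undiscounted returns-to-go recurrence RTG_t = r_t + RTG_{t+1}, giving r_t = RTG_t - RTG_{t+1} for every non-terminal step; the terminal entry is assigned the leftover RTG, which equals the true final reward only when the clip ends at a genuine episode boundary (the inline comment flags this). A minimal numeric check of that inversion, using illustrative values:

import numpy as np

# Undiscounted rewards [2, 2, 1] produce returns-to-go [5, 3, 1].
traj_rtg = np.array([5.0, 3.0, 1.0], dtype=np.float32)

rewards = np.zeros((3, 1), dtype=np.float32)
rewards[:-1] = (traj_rtg[:-1] - traj_rtg[1:]).reshape(-1, 1)  # differences: [2, 2]
rewards[-1] = traj_rtg[-1]  # terminal entry keeps the leftover RTG: 1

print(rewards.ravel())  # [2. 2. 1.], the original rewards are recovered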
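
The hunk ends before __getitem__, so its exact form is not shown here. A minimal sketch of how this dataset would typically be indexed and batched, assuming each item is a flat (state, action, reward, next_state, done) tuple; the method body, the .npz path, and the batch size below are illustrative assumptions, not part of this commit:

from torch.utils.data import DataLoader

# Inside OfflineTransitionDataset (hypothetical; not shown in this hunk):
#     def __getitem__(self, idx):
#         return (self.states[idx], self.actions[idx], self.rewards[idx],
#                 self.next_states[idx], self.dones[idx])

dataset = OfflineTransitionDataset("offline_data.npz")  # illustrative path
loader = DataLoader(dataset, batch_size=256, shuffle=True)
for states, actions, rewards, next_states, dones in loader:
    ...  # e.g. a TD-style critic update on (s, a, r, s', done) batches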