1+ import os
2+ import h5py
3+ import numpy as np
4+ import torch
5+ from torch .utils .data import Dataset
6+ import logging
7+
8+ logger = logging .getLogger (__name__ )
9+
10+ def load_data_from_h5 (h5_path ):
11+ if not os .path .exists (h5_path ):
12+ raise FileNotFoundError (f"Dataset not found at { h5_path } " )
13+
14+ with h5py .File (h5_path , 'r' ) as f :
15+ observations = f ['observations' ][:]
16+ actions = f ['actions' ][:]
17+ rewards = f ['rewards' ][:]
18+ terminals = f ['terminals' ][:]
19+ # Handle timeouts
20+ if 'timeouts' in f :
21+ timeouts = f ['timeouts' ][:]
22+ else :
23+ timeouts = np .zeros_like (terminals )
24+
25+ return observations , actions , rewards , terminals , timeouts
26+
27+ def compute_trajectories (observations , actions , rewards , terminals , timeouts ):
28+ trajectories = []
29+
30+ N = observations .shape [0 ]
31+ start = 0
32+
33+ for i in range (N ):
34+ done_bool = bool (terminals [i ])
35+ final_timestep = bool (timeouts [i ])
36+
37+ if done_bool or final_timestep :
38+ end = i + 1
39+ traj = {
40+ 'observations' : observations [start :end ],
41+ 'actions' : actions [start :end ],
42+ 'rewards' : rewards [start :end ],
43+ 'dones' : terminals [start :end ], # Keep as is (0 or 1)
44+ 'length' : end - start
45+ }
46+ trajectories .append (traj )
47+ start = end
48+
49+ # Handle last trajectory if not terminated explicitly
50+ if start < N :
51+ traj = {
52+ 'observations' : observations [start :],
53+ 'actions' : actions [start :],
54+ 'rewards' : rewards [start :],
55+ 'dones' : terminals [start :],
56+ 'length' : N - start
57+ }
58+ trajectories .append (traj )
59+
60+ return trajectories
61+
62+ def compute_rtg (trajectories ):
63+ for traj in trajectories :
64+ rewards = traj ['rewards' ]
65+ rtg = np .zeros_like (rewards )
66+ running_return = 0
67+ for t in reversed (range (len (rewards ))):
68+ running_return += rewards [t ]
69+ rtg [t ] = running_return
70+ traj ['returns_to_go' ] = rtg
71+ return trajectories
72+
73+ def compute_normalization (trajectories ):
74+ states = []
75+ for traj in trajectories :
76+ states .append (traj ['observations' ])
77+ states = np .concatenate (states , axis = 0 )
78+
79+ mean = np .mean (states , axis = 0 )
80+ std = np .std (states , axis = 0 ) + 1e-6 # Avoid div by zero
81+
82+ return mean , std
83+
84+ class D4RLSequenceDataset (Dataset ):
85+ def __init__ (self , env_name , data_dir = "data/d4rl_raw" , seq_len = 50 ):
86+ self .env_name = env_name
87+ self .seq_len = seq_len
88+
89+ # Construct filename. logic: hopper-medium-v2 -> hopper_medium-v2.hdf5?
90+ # Re-using logic from download/convert:
91+ # download stores as {url_filename}.
92+ # url: .../hopper_medium-v2.hdf5
93+ # So look for hopper_medium-v2.hdf5 if env is hopper-medium-v2.
94+ # But wait, env name has hyphens. Filename usually has underscores.
95+ # Try both.
96+
97+ filename = f"{ env_name } .hdf5"
98+ path = os .path .join (data_dir , filename )
99+ if not os .path .exists (path ):
100+ # Try underscore
101+ filename_us = env_name .replace ('-' , '_' ) + ".hdf5"
102+ path = os .path .join (data_dir , filename_us )
103+
104+ if not os .path .exists (path ):
105+ # Try partial underscore (d4rl style: hopper_medium-v2.hdf5)
106+ # Split by -v
107+ parts = env_name .split ('-v' )
108+ base = parts [0 ].replace ('-' , '_' )
109+ suffix = f"-v{ parts [1 ]} "
110+ filename_mixed = base + suffix + ".hdf5"
111+ path = os .path .join (data_dir , filename_mixed )
112+
113+ logger .info (f"Loading dataset from { path } " )
114+
115+ obs , act , rew , term , time = load_data_from_h5 (path )
116+ self .trajectories = compute_trajectories (obs , act , rew , term , time )
117+ self .trajectories = compute_rtg (self .trajectories )
118+ self .state_mean , self .state_std = compute_normalization (self .trajectories )
119+
120+ # Pre-compute valid indices
121+ self .indices = []
122+ for i , traj in enumerate (self .trajectories ):
123+ # For each trajectory, valid start indices
124+ # We want windows of length seq_len.
125+ # If traj_len < seq_len, we only have 1 window (0 to traj_len, padded)
126+ # If traj_len >= seq_len, we have traj_len - seq_len + 1 windows?
127+ # Standard DT: samples random t in [0, traj_len - 1].
128+ # Then takes [t, t+seq_len].
129+ # Pads if goes over.
130+ # I will follow this "sample any start point" logic to maximize data usage.
131+
132+ T = traj ['length' ]
133+ for t in range (T ):
134+ self .indices .append ((i , t ))
135+
136+ self .state_dim = self .state_mean .shape [0 ]
137+ self .act_dim = self .trajectories [0 ]['actions' ].shape [1 ]
138+
139+ # Check discrete/continuous from data (heuristic)
140+ # Actually passed from config usually, but we can guess.
141+ # D4RL MuJoCo is continuous.
142+ self .is_discrete = False
143+
144+ def __len__ (self ):
145+ return len (self .indices )
146+
147+ def __getitem__ (self , idx ):
148+ traj_idx , start_t = self .indices [idx ]
149+ traj = self .trajectories [traj_idx ]
150+ T = traj ['length' ]
151+
152+ # Determine end index
153+ end_t = start_t + self .seq_len
154+
155+ # Prepare buffers
156+ states = np .zeros ((self .seq_len , self .state_dim ), dtype = np .float32 )
157+ actions = np .zeros ((self .seq_len , self .act_dim ), dtype = np .float32 )
158+ rewards = np .zeros ((self .seq_len , 1 ), dtype = np .float32 )
159+ rtg = np .zeros ((self .seq_len , 1 ), dtype = np .float32 )
160+ timesteps = np .zeros ((self .seq_len ), dtype = np .int64 )
161+ mask = np .zeros ((self .seq_len ), dtype = np .float32 )
162+ dones = np .zeros ((self .seq_len , 1 ), dtype = np .float32 )
163+
164+ # Calculate real data range
165+ real_end_t = min (end_t , T )
166+ real_len = real_end_t - start_t
167+
168+ # Extract data
169+ s_data = traj ['observations' ][start_t :real_end_t ]
170+ a_data = traj ['actions' ][start_t :real_end_t ]
171+ r_data = traj ['rewards' ][start_t :real_end_t ]
172+ rtg_data = traj ['returns_to_go' ][start_t :real_end_t ]
173+ d_data = traj ['dones' ][start_t :real_end_t ]
174+
175+ # Normalize states
176+ s_data = (s_data - self .state_mean ) / self .state_std
177+
178+ # Fill buffers
179+ states [:real_len ] = s_data
180+ actions [:real_len ] = a_data
181+ rewards [:real_len ] = r_data .reshape (- 1 , 1 )
182+ rtg [:real_len ] = rtg_data .reshape (- 1 , 1 )
183+ timesteps [:real_len ] = np .arange (start_t , real_end_t )
184+ mask [:real_len ] = 1.0
185+ dones [:real_len ] = d_data .reshape (- 1 , 1 )
186+
187+ return {
188+ "states" : torch .from_numpy (states ),
189+ "actions" : torch .from_numpy (actions ),
190+ "rewards" : torch .from_numpy (rewards ),
191+ "returns_to_go" : torch .from_numpy (rtg ),
192+ "timesteps" : torch .from_numpy (timesteps ),
193+ "mask" : torch .from_numpy (mask ),
194+ "dones" : torch .from_numpy (dones ) # Optional, but good to have
195+ }
196+
197+ class D4RLTransitionDataset (Dataset ):
198+ def __init__ (self , env_name , data_dir = "data/d4rl_raw" ):
199+ self .env_name = env_name
200+
201+ # Similar filename logic
202+ filename = f"{ env_name } .hdf5"
203+ path = os .path .join (data_dir , filename )
204+ if not os .path .exists (path ):
205+ filename_us = env_name .replace ('-' , '_' ) + ".hdf5"
206+ path = os .path .join (data_dir , filename_us )
207+ if not os .path .exists (path ):
208+ parts = env_name .split ('-v' )
209+ base = parts [0 ].replace ('-' , '_' )
210+ suffix = f"-v{ parts [1 ]} "
211+ filename_mixed = base + suffix + ".hdf5"
212+ path = os .path .join (data_dir , filename_mixed )
213+
214+ logger .info (f"Loading transition dataset from { path } " )
215+
216+ obs , act , rew , term , time = load_data_from_h5 (path )
217+
218+ # For transitions (s, a, r, s'), we need next states.
219+ # We can reconstruct next states from observations: s[t+1]
220+ # But we need to be careful about boundaries.
221+
222+ # Vectorized transition creation
223+ # Identify terminals to mask out transitions crossing episodes
224+ # terminals[i] means step i is terminal. Next step i+1 is start of new episode (or end of data).
225+ # We want (s_i, a_i, r_i, s_{i+1}, d_i).
226+ # If d_i is True, s_{i+1} might be invalid or from next episode.
227+ # In D4RL, if d_i=True, s_{i+1} is usually the reset state of next traj.
228+ # But for offline RL, we treat s_{i+1} as terminal state if available, or just mask it.
229+ # However, many algorithms expect 'next_state' to calculate target Q.
230+ # If done=True, target Q is usually just r. So next_state doesn't matter much (but should be valid shape).
231+
232+ N = obs .shape [0 ]
233+
234+ # Create next_obs array
235+ next_obs = np .zeros_like (obs )
236+ next_obs [:- 1 ] = obs [1 :]
237+ next_obs [- 1 ] = obs [- 1 ] # Fallback
238+
239+ # Compute mean/std
240+ self .state_mean = np .mean (obs , axis = 0 )
241+ self .state_std = np .std (obs , axis = 0 ) + 1e-6
242+
243+ # Normalize current states
244+ self .states = (obs - self .state_mean ) / self .state_std
245+ # Normalize next states
246+ self .next_states = (next_obs - self .state_mean ) / self .state_std
247+
248+ self .actions = act
249+ self .rewards = rew .reshape (- 1 , 1 )
250+ self .dones = term .reshape (- 1 , 1 )
251+
252+ # Filter out invalid transitions (where step i was terminal or timeout, so i+1 is not next state)
253+ # Actually, if step i is terminal, (s_i, a_i, r_i, s_i', d_i=1) is valid.
254+ # But s_i' (next_obs[i]) corresponds to obs[i+1], which is START of next episode.
255+ # This is WRONG. s_i' should be the terminal state of current episode.
256+ # But D4RL often doesn't store the final observation after 'done'.
257+ # However, standard practice is: if done, next_state doesn't matter for Q-value (masked by 1-done).
258+ # But we must ensure we don't train on (s_T, a_T, r_T, s_{0_new}, done) where s_{0_new} belongs to next trajectory
259+ # if the algorithm relies on s'.
260+ # With done=1, term in Bellman eq zeroes out V(s'), so s' value is ignored.
261+ # BUT, if it's a TIMEOUT (truncation), done=0 but we shouldn't bootstrap from next episode start.
262+ # D4RL has 'timeouts'.
263+
264+ valid_mask = np .ones (N , dtype = bool )
265+
266+ # Mark steps where i is end of trajectory (timeout or terminal)
267+ # If timeout[i] is True, then i is last step. i+1 is new traj.
268+ # We should probably NOT use the transition (s_i, ..., s_{i+1}) if it's a timeout?
269+ # Or we treat it as done=0 but mask it?
270+ # Standard: keep it, but ensure next_state is handled?
271+ # Actually, simpler approach:
272+ # Use the computed trajectories from before to be safe.
273+
274+ self .trajectories = compute_trajectories (obs , act , rew , term , time )
275+
276+ # Rebuild flat arrays from trajectories to ensure correctness
277+ s_list , a_list , r_list , ns_list , d_list = [], [], [], [], []
278+
279+ for traj in self .trajectories :
280+ t_s = traj ['observations' ]
281+ t_a = traj ['actions' ]
282+ t_r = traj ['rewards' ]
283+ t_d = traj ['dones' ]
284+ L = len (t_s )
285+
286+ # For each step t in 0..L-1
287+ # Next state:
288+ # If t < L-1: s[t+1]
289+ # If t == L-1:
290+ # If done=True, s' is terminal (unknown/irrelevant). We can use s[t].
291+ # If done=False (timeout), s' is unknown (truncated).
292+
293+ # Normalize traj states first
294+ t_s_norm = (t_s - self .state_mean ) / self .state_std
295+
296+ # Transitions 0 to L-2
297+ if L > 1 :
298+ s_list .append (t_s_norm [:- 1 ])
299+ a_list .append (t_a [:- 1 ])
300+ r_list .append (t_r [:- 1 ])
301+ ns_list .append (t_s_norm [1 :])
302+ d_list .append (t_d [:- 1 ])
303+
304+ # Last transition L-1
305+ s_list .append (t_s_norm [- 1 ].reshape (1 , - 1 ))
306+ a_list .append (t_a [- 1 ].reshape (1 , - 1 ))
307+ r_list .append (t_r [- 1 ].reshape (1 ))
308+ # Next state for last step: duplicate current state (safe if done=1)
309+ ns_list .append (t_s_norm [- 1 ].reshape (1 , - 1 ))
310+ d_list .append (t_d [- 1 ].reshape (1 ))
311+
312+ self .states = np .concatenate (s_list , axis = 0 ).astype (np .float32 )
313+ self .actions = np .concatenate (a_list , axis = 0 ).astype (np .float32 )
314+ self .rewards = np .concatenate (r_list , axis = 0 ).astype (np .float32 ).reshape (- 1 , 1 )
315+ self .next_states = np .concatenate (ns_list , axis = 0 ).astype (np .float32 )
316+ self .dones = np .concatenate (d_list , axis = 0 ).astype (np .float32 ).reshape (- 1 , 1 )
317+
318+ def __len__ (self ):
319+ return len (self .states )
320+
321+ def __getitem__ (self , idx ):
322+ return {
323+ "states" : torch .from_numpy (self .states [idx ]),
324+ "actions" : torch .from_numpy (self .actions [idx ]),
325+ "rewards" : torch .from_numpy (self .rewards [idx ]),
326+ "next_states" : torch .from_numpy (self .next_states [idx ]),
327+ "dones" : torch .from_numpy (self .dones [idx ])
328+ }
0 commit comments