Skip to content

Commit 763ad9c

Browse files
Create d4rl_dataset.py
1 parent d47b77a commit 763ad9c

1 file changed

Lines changed: 328 additions & 0 deletions

File tree

snn-dt/src/utils/d4rl_dataset.py

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
import os
2+
import h5py
3+
import numpy as np
4+
import torch
5+
from torch.utils.data import Dataset
6+
import logging
7+
8+
logger = logging.getLogger(__name__)
9+
10+
def load_data_from_h5(h5_path):
    """Read the flat transition arrays stored in an HDF5 file.

    Args:
        h5_path: Path to the ``.hdf5`` file.

    Returns:
        Tuple ``(observations, actions, rewards, terminals, timeouts)``,
        each read from the dataset of the same name. If the file has no
        ``timeouts`` dataset, an all-zero array shaped like ``terminals``
        is returned in its place.

    Raises:
        FileNotFoundError: If ``h5_path`` does not exist.
    """
    if os.path.exists(h5_path):
        with h5py.File(h5_path, 'r') as h5_file:
            observations = h5_file['observations'][:]
            actions = h5_file['actions'][:]
            rewards = h5_file['rewards'][:]
            terminals = h5_file['terminals'][:]
            # Fall back to "no timeouts anywhere" when the dataset is absent.
            timeouts = (
                h5_file['timeouts'][:]
                if 'timeouts' in h5_file
                else np.zeros_like(terminals)
            )
        return observations, actions, rewards, terminals, timeouts

    raise FileNotFoundError(f"Dataset not found at {h5_path}")
26+
27+
def compute_trajectories(observations, actions, rewards, terminals, timeouts):
    """Split flat transition arrays into per-episode trajectories.

    An episode ends at every index where ``terminals`` or ``timeouts`` is
    truthy; that index is the last step of the episode. Samples remaining
    after the last such index form one final trajectory.

    Args:
        observations: Array whose first axis indexes samples; its length
            defines how many samples are considered.
        actions: Per-sample actions, sliced along the first axis.
        rewards: Per-sample rewards, sliced along the first axis.
        terminals: Per-sample termination flags (one flag per sample).
        timeouts: Per-sample timeout flags (one flag per sample).

    Returns:
        List of dicts, one per episode in dataset order, with keys
        ``observations``, ``actions``, ``rewards``, ``dones`` (the matching
        slice of ``terminals``, unmodified) and ``length`` (an ``int``).
        The array values are slices of the inputs, not copies.
    """
    num_steps = observations.shape[0]

    # Find every episode-ending index in one vectorised pass rather than
    # testing each sample in a Python loop.
    episode_end = np.logical_or(
        np.asarray(terminals[:num_steps]),
        np.asarray(timeouts[:num_steps]),
    )
    ends = (np.flatnonzero(episode_end) + 1).tolist()

    # Samples after the last flagged step form an unterminated final episode.
    last_end = ends[-1] if ends else 0
    if last_end < num_steps:
        ends.append(num_steps)

    starts = [0] + ends[:-1]
    return [
        {
            'observations': observations[lo:hi],
            'actions': actions[lo:hi],
            'rewards': rewards[lo:hi],
            'dones': terminals[lo:hi],  # Keep as is (0 or 1)
            'length': hi - lo,
        }
        for lo, hi in zip(starts, ends)
    ]
61+
62+
def compute_rtg(trajectories):
    """Attach undiscounted returns-to-go to every trajectory, in place.

    For each trajectory, ``returns_to_go[t]`` is the sum of ``rewards[t:]``.

    Args:
        trajectories: Iterable of dicts that each hold a ``'rewards'`` array.

    Returns:
        The same ``trajectories`` object; each dict gains a
        ``'returns_to_go'`` array with the shape and dtype of its rewards.
    """
    for traj in trajectories:
        rewards = np.asarray(traj['rewards'])
        # Accumulate floating-point rewards in float64 so the result does not
        # depend on how NumPy promotes mixed Python/NumPy scalars.
        acc_dtype = np.float64 if np.issubdtype(rewards.dtype, np.floating) else None
        # Reverse, cumulative-sum, reverse back: position t holds sum(rewards[t:]).
        reversed_sums = np.cumsum(rewards[::-1], axis=0, dtype=acc_dtype)
        traj['returns_to_go'] = np.ascontiguousarray(
            reversed_sums[::-1], dtype=rewards.dtype
        )
    return trajectories
72+
73+
def compute_normalization(trajectories):
    """Compute per-dimension observation statistics over all trajectories.

    Args:
        trajectories: Iterable of dicts that each hold an ``'observations'``
            array; the arrays are concatenated along axis 0.

    Returns:
        Tuple ``(mean, std)`` taken over axis 0 of the concatenated
        observations. ``std`` has ``1e-6`` added so it can be used as a
        divisor without hitting zero.
    """
    stacked = np.concatenate(
        [traj['observations'] for traj in trajectories], axis=0
    )
    epsilon = 1e-6  # Avoid div by zero
    return stacked.mean(axis=0), stacked.std(axis=0) + epsilon
83+
84+
class D4RLSequenceDataset(Dataset):
    """Fixed-length, zero-padded trajectory windows read from an HDF5 file.

    Every (trajectory, start step) pair is one sample. A sample covers up to
    ``seq_len`` consecutive steps starting at the start step; when the
    trajectory ends earlier, the remaining positions are left as zeros and
    ``mask`` marks which positions hold real data.

    Attributes set by ``__init__``: ``env_name``, ``seq_len``,
    ``trajectories``, ``state_mean``, ``state_std``, ``indices``,
    ``state_dim``, ``act_dim`` and ``is_discrete``.
    """

    def __init__(self, env_name, data_dir="data/d4rl_raw", seq_len=50):
        """Load the dataset for ``env_name`` and index all window starts.

        Args:
            env_name: Environment name, e.g. ``hopper-medium-v2``.
            data_dir: Directory that holds the ``.hdf5`` files.
            seq_len: Number of steps in every returned window.

        Raises:
            FileNotFoundError: If no candidate file exists (raised by
                ``load_data_from_h5`` for the last candidate path tried).
        """
        self.env_name = env_name
        self.seq_len = seq_len

        path = self._resolve_path(env_name, data_dir)
        logger.info(f"Loading dataset from {path}")

        obs, act, rew, term, timeouts = load_data_from_h5(path)
        self.trajectories = compute_trajectories(obs, act, rew, term, timeouts)
        self.trajectories = compute_rtg(self.trajectories)
        self.state_mean, self.state_std = compute_normalization(self.trajectories)

        # Every step of every trajectory is a valid window start; windows
        # that run past the end of the trajectory are padded in __getitem__.
        self.indices = [
            (traj_idx, start_t)
            for traj_idx, traj in enumerate(self.trajectories)
            for start_t in range(traj['length'])
        ]

        self.state_dim = self.state_mean.shape[0]
        self.act_dim = self.trajectories[0]['actions'].shape[1]

        # Not inferred from the data: D4RL MuJoCo is continuous.
        self.is_discrete = False

    @staticmethod
    def _resolve_path(env_name, data_dir):
        """Return the first existing candidate path for ``env_name``.

        Candidates, in order: ``<env_name>.hdf5``; the name with every hyphen
        replaced by an underscore; and, when the name contains ``-v``, the
        name with hyphens replaced only before the last ``-v``
        (``hopper-medium-v2`` -> ``hopper_medium-v2.hdf5``). If none exists,
        the last candidate is returned so the loader reports it as missing.
        """
        candidates = [f"{env_name}.hdf5", env_name.replace('-', '_') + ".hdf5"]

        # rpartition never raises when '-v' is absent and always splits on the
        # final '-v', so names such as 'a-very-v2' keep their version suffix.
        base, sep, version = env_name.rpartition('-v')
        if sep:
            candidates.append(base.replace('-', '_') + sep + version + ".hdf5")

        path = None
        for filename in candidates:
            path = os.path.join(data_dir, filename)
            if os.path.exists(path):
                break
        return path

    def __len__(self):
        """Return the number of (trajectory, start step) samples."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the padded window for sample ``idx`` as a dict of tensors.

        Keys and shapes: ``states`` (seq_len, state_dim), ``actions``
        (seq_len, act_dim), ``rewards`` / ``returns_to_go`` / ``dones``
        (seq_len, 1), ``timesteps`` (seq_len,) int64 and ``mask`` (seq_len,).
        All but ``timesteps`` are float32. States are normalised with
        ``state_mean`` / ``state_std``; padded positions are zero everywhere,
        including ``mask``.
        """
        traj_idx, start_t = self.indices[idx]
        traj = self.trajectories[traj_idx]

        # Clip the window at the end of the trajectory; the rest is padding.
        stop_t = min(start_t + self.seq_len, traj['length'])
        real_len = stop_t - start_t
        window = slice(start_t, stop_t)

        states = np.zeros((self.seq_len, self.state_dim), dtype=np.float32)
        actions = np.zeros((self.seq_len, self.act_dim), dtype=np.float32)
        rewards = np.zeros((self.seq_len, 1), dtype=np.float32)
        rtg = np.zeros((self.seq_len, 1), dtype=np.float32)
        timesteps = np.zeros((self.seq_len), dtype=np.int64)
        mask = np.zeros((self.seq_len), dtype=np.float32)
        dones = np.zeros((self.seq_len, 1), dtype=np.float32)

        states[:real_len] = (
            traj['observations'][window] - self.state_mean
        ) / self.state_std
        actions[:real_len] = traj['actions'][window]
        rewards[:real_len] = traj['rewards'][window].reshape(-1, 1)
        rtg[:real_len] = traj['returns_to_go'][window].reshape(-1, 1)
        # Timesteps are positions within the trajectory, not within the window.
        timesteps[:real_len] = np.arange(start_t, stop_t)
        mask[:real_len] = 1.0
        dones[:real_len] = traj['dones'][window].reshape(-1, 1)

        return {
            "states": torch.from_numpy(states),
            "actions": torch.from_numpy(actions),
            "rewards": torch.from_numpy(rewards),
            "returns_to_go": torch.from_numpy(rtg),
            "timesteps": torch.from_numpy(timesteps),
            "mask": torch.from_numpy(mask),
            "dones": torch.from_numpy(dones),  # Optional, but good to have
        }
196+
197+
class D4RLTransitionDataset(Dataset):
    """Flat ``(s, a, r, s', done)`` transitions read from an HDF5 file.

    One sample is produced per stored step. Next states never cross an
    episode boundary: for the last step of every episode the next state is
    a copy of the current state.

    NOTE(review): ``dones`` mirrors the stored ``terminals`` only. An episode
    that ends by timeout (or at the end of the file) therefore yields a final
    transition with ``done == 0`` whose next state is the duplicated current
    state. Algorithms that bootstrap from ``next_states`` will bootstrap from
    that duplicate; filter those transitions downstream if that matters.
    """

    def __init__(self, env_name, data_dir="data/d4rl_raw"):
        """Load the dataset for ``env_name`` and build the transition arrays.

        Args:
            env_name: Environment name, e.g. ``hopper-medium-v2``.
            data_dir: Directory that holds the ``.hdf5`` files.

        Raises:
            FileNotFoundError: If no candidate file exists (raised by
                ``load_data_from_h5`` for the last candidate path tried).
        """
        self.env_name = env_name

        path = self._resolve_path(env_name, data_dir)
        logger.info(f"Loading transition dataset from {path}")

        obs, act, rew, term, timeouts = load_data_from_h5(path)

        # Normalisation statistics over every stored observation.
        self.state_mean = np.mean(obs, axis=0)
        self.state_std = np.std(obs, axis=0) + 1e-6

        # Episode boundaries come from the shared trajectory splitter so that
        # terminals and timeouts are both respected.
        self.trajectories = compute_trajectories(obs, act, rew, term, timeouts)

        # The trajectories partition the flat arrays in order, so the
        # per-step arrays can be taken from the flat inputs directly; only
        # next_states needs the episode boundaries.
        self.states = ((obs - self.state_mean) / self.state_std).astype(np.float32)
        self.next_states = self._next_states_within_episodes(
            self.states, [traj['length'] for traj in self.trajectories]
        )
        self.actions = act.astype(np.float32)
        self.rewards = rew.astype(np.float32).reshape(-1, 1)
        self.dones = term.astype(np.float32).reshape(-1, 1)

    @staticmethod
    def _resolve_path(env_name, data_dir):
        """Return the first existing candidate path for ``env_name``.

        Candidates, in order: ``<env_name>.hdf5``; the name with every hyphen
        replaced by an underscore; and, when the name contains ``-v``, the
        name with hyphens replaced only before the last ``-v``
        (``hopper-medium-v2`` -> ``hopper_medium-v2.hdf5``). If none exists,
        the last candidate is returned so the loader reports it as missing.
        """
        candidates = [f"{env_name}.hdf5", env_name.replace('-', '_') + ".hdf5"]

        # rpartition never raises when '-v' is absent and always splits on the
        # final '-v', so names such as 'a-very-v2' keep their version suffix.
        base, sep, version = env_name.rpartition('-v')
        if sep:
            candidates.append(base.replace('-', '_') + sep + version + ".hdf5")

        path = None
        for filename in candidates:
            path = os.path.join(data_dir, filename)
            if os.path.exists(path):
                break
        return path

    @staticmethod
    def _next_states_within_episodes(states, lengths):
        """Return ``states`` shifted by one step without crossing episodes.

        Args:
            states: Array of per-step states, episodes stored back to back.
            lengths: Length of each episode, in storage order.

        Returns:
            Array like ``states`` where row ``i`` is ``states[i + 1]``, except
            at the last step of each episode, where it is ``states[i]``.
        """
        next_states = np.empty_like(states)
        next_states[:-1] = states[1:]
        # The successor of an episode's last step is not stored, so reuse the
        # current state there instead of the next episode's first state.
        last_steps = np.cumsum(lengths, dtype=np.int64) - 1
        next_states[last_steps] = states[last_steps]
        return next_states

    def __len__(self):
        """Return the number of transitions."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return transition ``idx`` as a dict of tensors.

        The tensors are created with ``torch.from_numpy`` and therefore share
        memory with the dataset's arrays.
        """
        return {
            "states": torch.from_numpy(self.states[idx]),
            "actions": torch.from_numpy(self.actions[idx]),
            "rewards": torch.from_numpy(self.rewards[idx]),
            "next_states": torch.from_numpy(self.next_states[idx]),
            "dones": torch.from_numpy(self.dones[idx]),
        }

0 commit comments

Comments
 (0)