# NOTE(review): this block is a unified-diff hunk captured from a web view. The
# leading digit runs on each line are fused old/new line numbers, and '-'/'+'
# mark removed/added lines. Comments below annotate the NEW code path.
@@ -60,31 +60,19 @@ class OfflineTransitionDataset(Dataset):
# Purpose: load a padded trajectory archive (states / actions / returns_to_go /
# mask) and flatten it into per-transition (s, a, r, s', done) tensors for
# offline RL training.
6060 def __init__ (self , dataset_path ):
# mmap_mode='r' keeps the source arrays lazily loaded while counting/copying.
6161 data = np .load (dataset_path , mmap_mode = 'r' )
6262
# Number of valid transitions per trajectory = sum of that row of 'mask';
# the grand total sizes the flat buffers allocated below.
6463 total_transitions = 0
6564 for i in range (data ['mask' ].shape [0 ]):
6665 total_transitions += int (data ['mask' ][i ].sum ())
6766
6867 state_dim = data ['states' ].shape [2 ]
# NOTE(review): if 'actions' is absent, action_dim becomes 0, yet the fill loop
# below still indexes data['actions'] whenever clip_len > 1 — confirm every
# dataset actually ships an 'actions' array.
6968 action_dim = data ['actions' ].shape [2 ] if 'actions' in data .keys () and data ['actions' ].shape [0 ] > 0 else 0
7069
# Removed by this change: temp-dir + atexit-managed np.memmap backing files.
71- # Create a temporary directory for memory-mapped files
72- self .temp_dir = tempfile .mkdtemp ()
73- atexit .register (self ._cleanup )
74-
75- self .states_mmap_path = os .path .join (self .temp_dir , 'states.mmap' )
76- self .actions_mmap_path = os .path .join (self .temp_dir , 'actions.mmap' )
77- self .rewards_mmap_path = os .path .join (self .temp_dir , 'rewards.mmap' )
78- self .next_states_mmap_path = os .path .join (self .temp_dir , 'next_states.mmap' )
79- self .dones_mmap_path = os .path .join (self .temp_dir , 'dones.mmap' )
80-
81- # Pre-allocate memory-mapped arrays
82- self .states = np .memmap (self .states_mmap_path , dtype = np .float32 , mode = 'w+' , shape = (total_transitions , state_dim ))
83- # Use int64 for discrete actions
84- self .actions = np .memmap (self .actions_mmap_path , dtype = np .int64 , mode = 'w+' , shape = (total_transitions , action_dim ))
85- self .rewards = np .memmap (self .rewards_mmap_path , dtype = np .float32 , mode = 'w+' , shape = (total_transitions , 1 ))
86- self .next_states = np .memmap (self .next_states_mmap_path , dtype = np .float32 , mode = 'w+' , shape = (total_transitions , state_dim ))
87- self .dones = np .memmap (self .dones_mmap_path , dtype = np .float32 , mode = 'w+' , shape = (total_transitions , 1 ))
# Replacement: plain RAM buffers. np.empty is safe only because every row in
# [0, total_transitions) is overwritten by the fill loop below. The trade-off
# is that the whole flattened dataset must now fit in memory — TODO confirm
# expected dataset sizes make dropping the spill-to-disk behavior acceptable.
70+ # Pre-allocate arrays in memory
71+ self .states = np .empty ((total_transitions , state_dim ), dtype = np .float32 )
72+ self .actions = np .empty ((total_transitions , action_dim ), dtype = np .int64 )
73+ self .rewards = np .empty ((total_transitions , 1 ), dtype = np .float32 )
74+ self .next_states = np .empty ((total_transitions , state_dim ), dtype = np .float32 )
75+ self .dones = np .empty ((total_transitions , 1 ), dtype = np .float32 )
8876
8977 current_idx = 0
9078 for i in range (data ['states' ].shape [0 ]):
# (hunk gap: new-file lines 79-81 are not shown — presumably clip_len is
# derived there from data['mask'][i]; verify against the full file.)
@@ -94,16 +82,13 @@ def __init__(self, dataset_path):
# Fully padded (empty) trajectories contribute nothing.
9482 if clip_len == 0 :
9583 continue
9684
97- # Trajectory data
9885 traj_states = data ['states' ][i , :clip_len ]
9986 traj_rtg = data ['returns_to_go' ][i , :clip_len ]
10087
101- # Actions
# Only the first clip_len-1 action rows are copied; the final slot keeps its
# zero initialization.
10288 traj_actions = np .zeros ((clip_len , action_dim ), dtype = np .int64 )
10389 if clip_len > 1 :
10490 traj_actions [:clip_len - 1 ] = data ['actions' ][i , :clip_len - 1 ].astype (np .int64 )
10591
106- # Rewards and next_states
10792 rewards = np .zeros ((clip_len , 1 ), dtype = np .float32 )
10893 next_states = np .zeros_like (traj_states )
10994 dones = np .zeros ((clip_len , 1 ), dtype = np .float32 )
# (hunk gap: new-file lines 95-99 are not shown — presumably the per-step
# reward/next_state fill for the non-terminal transitions lives there.)
@@ -115,7 +100,6 @@ def __init__(self, dataset_path):
# Terminal step: reward takes the final return-to-go and done is flagged.
115100 rewards [- 1 ] = traj_rtg [- 1 ]
116101 dones [- 1 ] = 1.0
117102
118- # Write to memmapped arrays
# Copy this trajectory's slice into the flat buffers.
119103 self .states [current_idx :current_idx + clip_len ] = traj_states
120104 self .actions [current_idx :current_idx + clip_len ] = traj_actions
121105 self .rewards [current_idx :current_idx + clip_len ] = rewards
# (hunk gap: new-file lines 106-107 are not shown — likely the matching
# next_states/dones slice writes.)
@@ -124,16 +108,12 @@ def __init__(self, dataset_path):
124108
125109 current_idx += clip_len
126110
127- # Convert numpy memmaps to torch tensors
# torch.from_numpy shares the underlying buffers; .float() should be a no-op
# for the arrays already allocated as float32.
# NOTE(review): .float() casts self.actions to float32, discarding the int64
# dtype deliberately chosen above for discrete actions — confirm downstream
# losses expect float-typed actions before keeping this.
128111 self .states = torch .from_numpy (self .states ).float ()
129112 self .actions = torch .from_numpy (self .actions ).float ()
130113 self .rewards = torch .from_numpy (self .rewards ).float ()
131114 self .next_states = torch .from_numpy (self .next_states ).float ()
132115 self .dones = torch .from_numpy (self .dones ).float ()
133116
# Removed along with the memmap backing: no temp dir remains to delete.
134- def _cleanup (self ):
135- shutil .rmtree (self .temp_dir )
136-
137117 def __len__ (self ):
138118 return len (self .states )
139119
# NOTE(review): unified-diff hunk from a web view; the digit runs are fused
# old/new line numbers and '+' marks added lines. Only a fragment of main()'s
# algorithm-config dict is visible here.
@@ -383,6 +363,14 @@ def main():
383363 "temperature" : cfg_raw .get ("temperature" , 3.0 ),
384364 "expectile" : cfg_raw .get ("expectile" , 0.7 ),
385365 "hidden_size" : cfg_raw .get ("hidden_size" , 256 )
# NOTE(review): the new "cql" section reads the same flat cfg_raw keys ("tau",
# "temperature", "hidden_size") as the "iql" section above, with different
# fallback defaults (e.g. temperature 3.0 vs 1.0). A single user-supplied key
# therefore changes BOTH algorithms, and the per-algorithm defaults apply only
# when the key is absent — consider namespaced keys (e.g. cfg_raw["cql"]).
366+ },
367+ "cql" : {
368+ "tau" : cfg_raw .get ("tau" , 0.005 ),
369+ "temperature" : cfg_raw .get ("temperature" , 1.0 ),
370+ "hidden_size" : cfg_raw .get ("hidden_size" , 256 ),
371+ "with_lagrange" : cfg_raw .get ("with_lagrange" , False ),
372+ "cql_weight" : cfg_raw .get ("cql_weight" , 1.0 ),
373+ "target_action_gap" : cfg_raw .get ("target_action_gap" , 10.0 )
386374 }
387375 }
388376
0 commit comments