3737from src .utils .seed import seed_everything
3838from src .utils .evaluation import create_env
3939from scripts .eval import evaluate_policy
40+ from src .utils .d4rl_dataset import D4RLSequenceDataset , D4RLTransitionDataset
4041
4142# Lazy Imports for Models to avoid norse/tensorflow crash on Python 3.13
4243def get_model_class (name ):
@@ -225,24 +226,47 @@ def train(cfg, logger):
225226 logger .info (f"--- Checkpoint: Save directory created at { save_dir } ---" )
226227
227228 # Load data and metadata
228- if cfg .model .name in ['iql' , 'cql' ]:
229- dataset = OfflineTransitionDataset (cfg .dataset .path )
229+ state_mean = None
230+ state_std = None
231+
232+ if cfg .dataset .mode == 'd4rl_direct' :
233+ if os .path .isfile (cfg .dataset .path ):
234+ data_dir = os .path .dirname (cfg .dataset .path )
235+ else :
236+ data_dir = cfg .dataset .path
237+
238+ if cfg .model .name in ['iql' , 'cql' ]:
239+ dataset = D4RLTransitionDataset (cfg .env , data_dir = data_dir )
240+ else :
241+ dataset = D4RLSequenceDataset (cfg .env , data_dir = data_dir , seq_len = cfg .model .seq_len )
242+
243+ state_mean = dataset .state_mean
244+ state_std = dataset .state_std
245+
246+ cfg .dataset .state_dim = dataset .states .shape [1 ] if hasattr (dataset , 'states' ) else dataset .state_dim
247+ cfg .dataset .act_dim = dataset .act_dim if hasattr (dataset , 'act_dim' ) else dataset .actions .shape [1 ]
248+ cfg .dataset .max_timesteps = 1000 # Standard MuJoCo
249+
230250 else :
231- dataset = OfflineDataset (cfg .dataset .path )
251+ # Legacy mode: load the pre-built .npz dataset; its dimensions are read from the embedded metadata below
252+ if cfg .model .name in ['iql' , 'cql' ]:
253+ dataset = OfflineTransitionDataset (cfg .dataset .path )
254+ else :
255+ dataset = OfflineDataset (cfg .dataset .path )
256+
257+ with np .load (cfg .dataset .path , allow_pickle = True ) as data :
258+ metadata = data ["metadata" ].item ()
259+ if isinstance (metadata , str ):
260+ metadata = yaml .safe_load (metadata )
261+
262+ cfg .dataset .state_dim = metadata ["state_dim" ]
263+ cfg .dataset .act_dim = metadata ["act_dim" ]
264+ cfg .dataset .max_timesteps = metadata ["max_timesteps" ]
265+
232266 if len (dataset ) == 0 :
233267 logger .error (f"Dataset at { cfg .dataset .path } is empty! Aborting training." )
234268 sys .exit (1 )
235- logger .info (f"Dataset size: { len (dataset )} clips" )
236-
237- with np .load (cfg .dataset .path , allow_pickle = True ) as data :
238- metadata = data ["metadata" ].item ()
239- if isinstance (metadata , str ):
240- metadata = yaml .safe_load (metadata )
241-
242- # Update config with dataset metadata
243- cfg .dataset .state_dim = metadata ["state_dim" ]
244- cfg .dataset .act_dim = metadata ["act_dim" ]
245- cfg .dataset .max_timesteps = metadata ["max_timesteps" ]
269+ logger .info (f"Dataset size: { len (dataset )} items" )
246270
247271 # Lazily import gymnasium to avoid potential C-extension conflicts at startup
248272 import gymnasium as gym
@@ -349,7 +373,14 @@ def train(cfg, logger):
349373 if env is None :
350374 env = create_env (cfg .env , simulator_available = cfg .training .simulator_available , dataset_path = cfg .dataset .path )
351375
352- eval_results = evaluate_policy (model , env , cfg , episodes = cfg .hyperparameters .eval_episodes )
376+ eval_results = evaluate_policy (
377+ model ,
378+ env ,
379+ cfg ,
380+ episodes = cfg .hyperparameters .eval_episodes ,
381+ state_mean = state_mean ,
382+ state_std = state_std
383+ )
353384 epoch_time = time .time () - start_time
354385 avg_loss = np .mean (epoch_losses ) if epoch_losses else 0.0
355386
@@ -447,8 +478,9 @@ def handle_exception(exc_type, exc_value, exc_traceback):
447478 parser .add_argument ("--env" , type = str , required = True , help = "Environment name (e.g., CartPole-v1)." )
448479 parser .add_argument ("--save-dir" , type = str , default = "results/run" , help = "Directory to save results." )
449480 parser .add_argument ("--seed" , type = int , default = 42 , help = "Random seed." )
450- parser .add_argument ("--dataset-path" , type = str , default = None , help = "Explicit path to dataset file." )
481+ parser .add_argument ("--dataset-path" , type = str , default = None , help = "Explicit path to dataset file or directory ." )
451482 parser .add_argument ("--simulator-available" , action = "store_true" , help = "Set if a real simulator is available for eval." )
483+ parser .add_argument ("--dataset-mode" , type = str , default = "d4rl_direct" , help = "Dataset mode: 'legacy' (npz) or 'd4rl_direct' (hdf5)." )
452484 args = parser .parse_args ()
453485
454486 # Configure logging
@@ -479,12 +511,12 @@ def handle_exception(exc_type, exc_value, exc_traceback):
479511
480512 if args .model in model_abbr and args .env in env_abbr :
481513 config_name = f"{ model_abbr [args .model ]} _{ env_abbr [args .env ]} .yaml"
482- args .config = str (project_root / "configs" / config_name )
514+ args .config = str (snn_dt_root / "configs" / config_name )
483515 else :
484516 # Just try a generic name if above fails
485517 config_name = f"{ args .model } _{ args .env } .yaml"
486- if (project_root / "configs" / config_name ).exists ():
487- args .config = str (project_root / "configs" / config_name )
518+ if (snn_dt_root / "configs" / config_name ).exists ():
519+ args .config = str (snn_dt_root / "configs" / config_name )
488520 else :
489521 # Last resort: use a default?
490522 pass
@@ -522,7 +554,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
522554 "simulator_available" : args .simulator_available ,
523555 },
524556 "dataset" : {
525- "path" : cfg_raw .get ("dataset" , None ),
557+ "path" : cfg_raw .get ("dataset" , {}). get ( "path" ) if isinstance ( cfg_raw . get ( "dataset" ), dict ) else cfg_raw . get ( "dataset" , None ),
526558 "state_dim" : None , # Will be set from metadata
527559 "act_dim" : None , # Will be set from metadata
528560 "max_timesteps" : None # Will be set from metadata
@@ -563,12 +595,18 @@ def handle_exception(exc_type, exc_value, exc_traceback):
563595 logger .info (f"SNN Config: { cfg .snn } " )
564596
565597 # Dataset path priority: Args > Config > Default
598+ # Record the dataset mode chosen on the command line ('legacy' or 'd4rl_direct') in the config
599+ cfg .dataset .mode = args .dataset_mode
600+
566601 if args .dataset_path :
567602 cfg .dataset .path = args .dataset_path
568603 elif cfg .dataset .path is None :
569- cfg .dataset .path = str (project_root / f"data/{ args .env } /dataset.npz" )
604+ if args .dataset_mode == 'd4rl_direct' :
605+ cfg .dataset .path = str (snn_dt_root / "data/d4rl_raw" )
606+ else :
607+ cfg .dataset .path = str (snn_dt_root / f"data/{ args .env } /dataset.npz" )
570608
571- # Check if dataset exists
609+ # Check if dataset exists (folder or file)
572610 if not os .path .exists (cfg .dataset .path ):
573611 logger .warning (f"Dataset not found at { cfg .dataset .path } . Training will likely fail if data isn't generated." )
574612 else :
@@ -583,4 +621,4 @@ def handle_exception(exc_type, exc_value, exc_traceback):
583621
584622
585623if __name__ == "__main__" :
586- main ()
624+ main ()
0 commit comments