Skip to content

Commit d47b77a

Browse files
Update train.py
1 parent e5abc8e commit d47b77a

1 file changed

Lines changed: 61 additions & 23 deletions

File tree

scripts/train.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from src.utils.seed import seed_everything
3838
from src.utils.evaluation import create_env
3939
from scripts.eval import evaluate_policy
40+
from src.utils.d4rl_dataset import D4RLSequenceDataset, D4RLTransitionDataset
4041

4142
# Lazy Imports for Models to avoid norse/tensorflow crash on Python 3.13
4243
def get_model_class(name):
@@ -225,24 +226,47 @@ def train(cfg, logger):
225226
logger.info(f"--- Checkpoint: Save directory created at {save_dir} ---")
226227

227228
# Load data and metadata
228-
if cfg.model.name in ['iql', 'cql']:
229-
dataset = OfflineTransitionDataset(cfg.dataset.path)
229+
state_mean = None
230+
state_std = None
231+
232+
if cfg.dataset.mode == 'd4rl_direct':
233+
if os.path.isfile(cfg.dataset.path):
234+
data_dir = os.path.dirname(cfg.dataset.path)
235+
else:
236+
data_dir = cfg.dataset.path
237+
238+
if cfg.model.name in ['iql', 'cql']:
239+
dataset = D4RLTransitionDataset(cfg.env, data_dir=data_dir)
240+
else:
241+
dataset = D4RLSequenceDataset(cfg.env, data_dir=data_dir, seq_len=cfg.model.seq_len)
242+
243+
state_mean = dataset.state_mean
244+
state_std = dataset.state_std
245+
246+
cfg.dataset.state_dim = dataset.states.shape[1] if hasattr(dataset, 'states') else dataset.state_dim
247+
cfg.dataset.act_dim = dataset.act_dim if hasattr(dataset, 'act_dim') else dataset.actions.shape[1]
248+
cfg.dataset.max_timesteps = 1000 # Standard MuJoCo
249+
230250
else:
231-
dataset = OfflineDataset(cfg.dataset.path)
251+
# Legacy
252+
if cfg.model.name in ['iql', 'cql']:
253+
dataset = OfflineTransitionDataset(cfg.dataset.path)
254+
else:
255+
dataset = OfflineDataset(cfg.dataset.path)
256+
257+
with np.load(cfg.dataset.path, allow_pickle=True) as data:
258+
metadata = data["metadata"].item()
259+
if isinstance(metadata, str):
260+
metadata = yaml.safe_load(metadata)
261+
262+
cfg.dataset.state_dim = metadata["state_dim"]
263+
cfg.dataset.act_dim = metadata["act_dim"]
264+
cfg.dataset.max_timesteps = metadata["max_timesteps"]
265+
232266
if len(dataset) == 0:
233267
logger.error(f"Dataset at {cfg.dataset.path} is empty! Aborting training.")
234268
sys.exit(1)
235-
logger.info(f"Dataset size: {len(dataset)} clips")
236-
237-
with np.load(cfg.dataset.path, allow_pickle=True) as data:
238-
metadata = data["metadata"].item()
239-
if isinstance(metadata, str):
240-
metadata = yaml.safe_load(metadata)
241-
242-
# Update config with dataset metadata
243-
cfg.dataset.state_dim = metadata["state_dim"]
244-
cfg.dataset.act_dim = metadata["act_dim"]
245-
cfg.dataset.max_timesteps = metadata["max_timesteps"]
269+
logger.info(f"Dataset size: {len(dataset)} items")
246270

247271
# Lazily import gymnasium to avoid potential C-extension conflicts at startup
248272
import gymnasium as gym
@@ -349,7 +373,14 @@ def train(cfg, logger):
349373
if env is None:
350374
env = create_env(cfg.env, simulator_available=cfg.training.simulator_available, dataset_path=cfg.dataset.path)
351375

352-
eval_results = evaluate_policy(model, env, cfg, episodes=cfg.hyperparameters.eval_episodes)
376+
eval_results = evaluate_policy(
377+
model,
378+
env,
379+
cfg,
380+
episodes=cfg.hyperparameters.eval_episodes,
381+
state_mean=state_mean,
382+
state_std=state_std
383+
)
353384
epoch_time = time.time() - start_time
354385
avg_loss = np.mean(epoch_losses) if epoch_losses else 0.0
355386

@@ -447,8 +478,9 @@ def handle_exception(exc_type, exc_value, exc_traceback):
447478
parser.add_argument("--env", type=str, required=True, help="Environment name (e.g., CartPole-v1).")
448479
parser.add_argument("--save-dir", type=str, default="results/run", help="Directory to save results.")
449480
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
450-
parser.add_argument("--dataset-path", type=str, default=None, help="Explicit path to dataset file.")
481+
parser.add_argument("--dataset-path", type=str, default=None, help="Explicit path to dataset file or directory.")
451482
parser.add_argument("--simulator-available", action="store_true", help="Set if a real simulator is available for eval.")
483+
parser.add_argument("--dataset-mode", type=str, default="d4rl_direct", help="Dataset mode: 'legacy' (npz) or 'd4rl_direct' (hdf5).")
452484
args = parser.parse_args()
453485

454486
# Configure logging
@@ -479,12 +511,12 @@ def handle_exception(exc_type, exc_value, exc_traceback):
479511

480512
if args.model in model_abbr and args.env in env_abbr:
481513
config_name = f"{model_abbr[args.model]}_{env_abbr[args.env]}.yaml"
482-
args.config = str(project_root / "configs" / config_name)
514+
args.config = str(snn_dt_root / "configs" / config_name)
483515
else:
484516
# Just try a generic name if above fails
485517
config_name = f"{args.model}_{args.env}.yaml"
486-
if (project_root / "configs" / config_name).exists():
487-
args.config = str(project_root / "configs" / config_name)
518+
if (snn_dt_root / "configs" / config_name).exists():
519+
args.config = str(snn_dt_root / "configs" / config_name)
488520
else:
489521
# Last resort: use a default?
490522
pass
@@ -522,7 +554,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
522554
"simulator_available": args.simulator_available,
523555
},
524556
"dataset": {
525-
"path": cfg_raw.get("dataset", None),
557+
"path": cfg_raw.get("dataset", {}).get("path") if isinstance(cfg_raw.get("dataset"), dict) else cfg_raw.get("dataset", None),
526558
"state_dim": None, # Will be set from metadata
527559
"act_dim": None, # Will be set from metadata
528560
"max_timesteps": None # Will be set from metadata
@@ -563,12 +595,18 @@ def handle_exception(exc_type, exc_value, exc_traceback):
563595
logger.info(f"SNN Config: {cfg.snn}")
564596

565597
# Dataset path priority: Args > Config > Default
598+
# Config for dataset mode
599+
cfg.dataset.mode = args.dataset_mode
600+
566601
if args.dataset_path:
567602
cfg.dataset.path = args.dataset_path
568603
elif cfg.dataset.path is None:
569-
cfg.dataset.path = str(project_root / f"data/{args.env}/dataset.npz")
604+
if args.dataset_mode == 'd4rl_direct':
605+
cfg.dataset.path = str(snn_dt_root / "data/d4rl_raw")
606+
else:
607+
cfg.dataset.path = str(snn_dt_root / f"data/{args.env}/dataset.npz")
570608

571-
# Check if dataset exists
609+
# Check if dataset exists (folder or file)
572610
if not os.path.exists(cfg.dataset.path):
573611
logger.warning(f"Dataset not found at {cfg.dataset.path}. Training will likely fail if data isn't generated.")
574612
else:
@@ -583,4 +621,4 @@ def handle_exception(exc_type, exc_value, exc_traceback):
583621

584622

585623
if __name__ == "__main__":
586-
main()
624+
main()

0 commit comments

Comments
 (0)