11from __future__ import annotations
22
33from collections import deque
4+ from typing import TYPE_CHECKING
45
56import numpy as np
67import torch
1415_N_TRANSFORMS = 8
1516_POLICY_INDEX_MAPS : np .ndarray | None = None
1617
18+ if TYPE_CHECKING :
19+ from data .replay_buffer import TrainingExample
20+
1721
1822def _rotate_coord_ccw (r : int , c : int , k : int , size : int ) -> tuple [int , int ]:
1923 rr , cc = r , c
@@ -98,25 +102,65 @@ def _augment_policy(policy: np.ndarray, transform_id: int) -> np.ndarray:
98102 return pi_aug
99103
100104
105+ def split_train_val_examples (
106+ * ,
107+ all_examples : list [TrainingExample ],
108+ val_split : float ,
109+ shuffle : bool ,
110+ seed : int ,
111+ ) -> tuple [list [TrainingExample ], list [TrainingExample ]]:
112+ """Split examples into disjoint train/val sets with optional seeded shuffling."""
113+ n_total = len (all_examples )
114+ if n_total == 0 :
115+ return [], []
116+ n_val = int (n_total * val_split )
117+ n_val = min (max (0 , n_val ), n_total )
118+ n_train = n_total - n_val
119+ if n_val == 0 :
120+ return list (all_examples ), []
121+ if not shuffle :
122+ return list (all_examples [:n_train ]), list (all_examples [n_train :])
123+
124+ rng = np .random .default_rng (seed = seed )
125+ val_indices = np .sort (rng .choice (n_total , size = n_val , replace = False ))
126+ val_set = {int (i ) for i in val_indices .tolist ()}
127+ # Keep train set in chronological order so "recent" remains meaningful.
128+ train_indices = [idx for idx in range (n_total ) if idx not in val_set ]
129+ train_examples = [all_examples [idx ] for idx in train_indices ]
130+ val_examples = [all_examples [int (idx )] for idx in val_indices ]
131+ return train_examples , val_examples
132+
133+
101134class AtaxxDataset (Dataset [tuple [torch .Tensor , torch .Tensor , torch .Tensor ]]):
102135 """Dataset wrapper from replay buffer examples."""
103136
104137 def __init__ (
105138 self ,
106- buffer : ReplayBuffer ,
139+ buffer : ReplayBuffer | None = None ,
107140 augment : bool = True ,
108141 reference_buffer : bool = False ,
109142 val_split : float = 0.1 ,
143+ examples : list [TrainingExample ] | None = None ,
110144 ) -> None :
111145 self .augment = augment
112146 self .examples : list [tuple [np .ndarray , np .ndarray , float ]] | deque [
113147 tuple [np .ndarray , np .ndarray , float ]
114148 ]
149+ if examples is not None :
150+ self .examples = list (examples )
151+ return
152+ if buffer is None :
153+ self .examples = []
154+ return
155+
115156 raw_examples = list (buffer .buffer ) if reference_buffer else buffer .get_all ()
116- n_val = int (len (raw_examples ) * val_split )
117- n_train = len (raw_examples ) - n_val
118- # Keep train/validation disjoint so val loss is a true hold-out metric.
119- self .examples = raw_examples [:n_train ]
157+ train_examples , _ = split_train_val_examples (
158+ all_examples = raw_examples ,
159+ val_split = val_split ,
160+ shuffle = False ,
161+ seed = 0 ,
162+ )
163+ self .examples = train_examples
120164
121165 def __len__ (self ) -> int :
122166 return len (self .examples )
@@ -140,11 +184,26 @@ def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, torch.Ten
140184class ValidationDataset (Dataset [tuple [torch .Tensor , torch .Tensor , torch .Tensor ]]):
141185 """Hold-out validation split from replay buffer."""
142186
143- def __init__ (self , buffer : ReplayBuffer , split : float = 0.1 ) -> None :
187+ def __init__ (
188+ self ,
189+ buffer : ReplayBuffer | None = None ,
190+ split : float = 0.1 ,
191+ examples : list [TrainingExample ] | None = None ,
192+ ) -> None :
193+ if examples is not None :
194+ self .examples = list (examples )
195+ return
196+ if buffer is None :
197+ self .examples = []
198+ return
144199 all_examples = buffer .get_all ()
145- n_val = int (len (all_examples ) * split )
146- n_train = len (all_examples ) - n_val
147- self .examples = all_examples [n_train :] if n_val > 0 else []
200+ _ , val_examples = split_train_val_examples (
201+ all_examples = all_examples ,
202+ val_split = split ,
203+ shuffle = False ,
204+ seed = 0 ,
205+ )
206+ self .examples = val_examples
148207
149208 def __len__ (self ) -> int :
150209 return len (self .examples )
0 commit comments