1919from torch .nn import CrossEntropyLoss , L1Loss , MSELoss , NLLLoss
2020from torch .optim import SGD , Adam , AdamW , Optimizer
2121from torch .optim .lr_scheduler import MultiStepLR , _LRScheduler
22- from torch .utils .data import DataLoader , Subset
22+ from torch .utils .data import DataLoader , Dataset , Subset
2323from torch .utils .tensorboard import SummaryWriter
2424
2525from aviary .core import BaseModelClass , Normalizer , TaskType , sampled_softmax
@@ -45,13 +45,13 @@ def init_model(
4545 Args:
4646 model_class (type[BaseModelClass]): Which model class to initialize.
4747 model_params (dict[str, Any]): Dictionary containing model specific hyperparameters.
48- device (type[torch.device] | Literal[ "cuda", "cpu"] ): Device the model will run on.
48+ device (torch.device | "cuda" | "cpu"): Device the model will run on.
4949 resume (str, optional): Path to model checkpoint to resume. Defaults to None.
5050 fine_tune (str, optional): Path to model checkpoint to fine tune. Defaults to None.
5151 transfer (str, optional): Path to model checkpoint to transfer. Defaults to None.
5252
5353 Returns:
54- type[ BaseModelClass] : An initialised model of type model_class
54+ BaseModelClass: An initialised model of type model_class.
5555 """
5656 robust = model_params ["robust" ]
5757 n_targets = model_params ["n_targets" ]
@@ -149,11 +149,11 @@ def init_optim(
149149
150150 Args:
151151 model (BaseModelClass): Model to be optimized.
152- optim (type[Optimizer] | Literal[ "SGD", "Adam", "AdamW"] ): Which optimizer to use
153- learning_rate (float): Learning rate for optimzation
152+ optim (type[Optimizer] | "SGD" | "Adam" | "AdamW"): Which optimizer to use
153+ learning_rate (float): Learning rate for optimization
154154 weight_decay (float): Weight decay for optimizer
155155 momentum (float): Momentum for optimizer
156- device (type[torch.device] | Literal[ "cuda", "cpu"] ): Device the model will run on
156+ device (torch.device | "cuda" | "cpu"): Device the model will run on
157157 milestones (Iterable, optional): When to decay learning rate. Defaults to ().
158158 gamma (float, optional): Multiplier for learning rate decay. Defaults to 0.3.
159159 resume (str, optional): Path to model checkpoint to resume. Defaults to None.
@@ -203,7 +203,7 @@ def init_losses(
203203
204204 Args:
205205 task_dict (dict[str, TaskType]): Map of target names to "regression" or "classification".
206- loss_dict (dict[str, Literal[ "L1", "L2", "CSE"] ]): Map of target names to loss functions.
206+ loss_dict (dict[str, "L1" | "L2" | "CSE"]): Map of target names to loss functions.
207207 robust (bool, optional): Whether to use an uncertainty adjusted loss. Defaults to False.
208208
209209 Returns:
@@ -253,7 +253,7 @@ def init_normalizers(
253253
254254 Args:
255255 task_dict (dict[str, TaskType]): Map of target names to "regression" or "classification".
256- device (type[ torch.device] | Literal[ "cuda", "cpu"] ): Device the model will run on
256+ device (torch.device | "cuda" | "cpu"): Device the model will run on
257257 resume (str, optional): Path to model checkpoint to resume. Defaults to None.
258258
259259 Returns:
@@ -284,8 +284,8 @@ def train_ensemble(
284284 run_id : int ,
285285 ensemble_folds : int ,
286286 epochs : int ,
287- train_set : Subset ,
288- val_set : Subset ,
287+ train_set : Dataset | Subset ,
288+ val_set : Dataset | Subset ,
289289 log : bool ,
290290 data_params : dict [str , Any ],
291291 setup_params : dict [str , Any ],
@@ -310,12 +310,17 @@ def train_ensemble(
310310 setup_params (dict[str, Any]): Dictionary of setup parameters
311311 restart_params (dict[str, Any]): Dictionary of restart parameters
312312 model_params (dict[str, Any]): Dictionary of model parameters
313- loss_dict (dict[str, Literal[ "L1", "L2", "CSE"] ]): Map of target names
313+ loss_dict (dict[str, "L1" | "L2" | "CSE"]): Map of target names
314314 to loss functions.
315315 patience (int, optional): Maximum number of epochs without improvement
316316 when early stopping. Defaults to None.
317317 verbose (bool, optional): Whether to show progress bars for each epoch.
318318 """
319+ if isinstance (train_set , Subset ):
320+ train_set = train_set .dataset
321+ if isinstance (val_set , Subset ):
322+ val_set = val_set .dataset
323+
319324 train_generator = DataLoader (train_set , ** data_params )
320325 print (f"Training on { len (train_set ):,} samples" )
321326
@@ -350,13 +355,11 @@ def train_ensemble(
350355
351356 for target , normalizer in normalizer_dict .items ():
352357 if normalizer is not None :
353- sample_target = Tensor (
354- train_set .dataset .df [target ].iloc [train_set .indices ].values
355- )
358+ sample_target = Tensor (train_set .df [target ].values )
356359 if not restart_params ["resume" ]:
357360 normalizer .fit (sample_target )
358361 print (
359- f"Dummy MAE: { torch . mean ( torch . abs ( sample_target - normalizer .mean )):.4f} "
362+ f"Dummy MAE: { ( sample_target - normalizer .mean ). abs (). mean ( ):.4f} "
360363 )
361364
362365 if log :
@@ -415,7 +418,7 @@ def results_multitask( # noqa: C901
415418 model_name : str ,
416419 run_id : int ,
417420 ensemble_folds : int ,
418- test_set : Subset ,
421+ test_set : Dataset | Subset ,
419422 data_params : dict [str , Any ],
420423 robust : bool ,
421424 task_dict : dict [str , TaskType ],
@@ -436,7 +439,7 @@ def results_multitask( # noqa: C901
436439 loss function.
437440 task_dict (dict[str, TaskType]): Map of target names to "regression" or
438441 "classification".
439- device (type[torch.device] | Literal[ "cuda", "cpu"] ): Device the model will run on
442+ device (torch.device | "cuda" | "cpu"): Device the model will run on
440443 eval_type (str, optional): Whether to use final or early-stopping checkpoints.
441444 Defaults to "checkpoint".
442445 print_results (bool, optional): Whether to print out summary metrics.
@@ -459,6 +462,9 @@ def results_multitask( # noqa: C901
459462 "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n "
460463 )
461464
465+ if isinstance (test_set , Subset ):
466+ test_set = test_set .dataset
467+
462468 test_generator = DataLoader (test_set , ** data_params )
463469 print (f"Testing on { len (test_set ):,} samples" )
464470
0 commit comments