"""Experiment adapter for PyTorch Lightning experiments."""

# copyright: hyperactive developers, MIT License (see LICENSE file)

__author__ = ["amitsubhashchejara"]

import warnings

import numpy as np

from hyperactive.base import BaseExperiment


class TorchExperiment(BaseExperiment):
    """Experiment adapter for PyTorch Lightning experiments.

    This class runs experiments with PyTorch Lightning modules, enabling
    hyperparameter tuning and evaluation of model performance via a
    user-specified metric.

    Each evaluation instantiates the Lightning module with the given
    hyperparameters, trains it on the provided datamodule, and returns
    the validation metric value for optimization.

    Parameters
    ----------
    datamodule : L.LightningDataModule
        A PyTorch Lightning DataModule that handles data loading and
        preparation.
    lightning_module : type
        A PyTorch Lightning Module class (not an instance) that will be
        instantiated with the hyperparameters under evaluation.
    trainer_kwargs : dict, optional (default=None)
        Keyword arguments passed to the PyTorch Lightning ``Trainer``,
        merged over quiet defaults (``max_epochs=10``; checkpointing,
        logging, progress bar, and model summary disabled).
    objective_metric : str, optional (default="val_loss")
        The metric used to evaluate the model's performance. This must
        correspond to a metric logged in the LightningModule during
        validation.

    Examples
    --------
    >>> from hyperactive.experiment.integrations import TorchExperiment
    >>> import torch
    >>> import lightning as L
    >>> from torch import nn
    >>> from torch.utils.data import DataLoader
    >>>
    >>> # Define a simple Lightning Module
    >>> class SimpleLightningModule(L.LightningModule):
    ...     def __init__(self, input_dim=10, hidden_dim=16, lr=1e-3):
    ...         super().__init__()
    ...         self.save_hyperparameters()
    ...         self.model = nn.Sequential(
    ...             nn.Linear(input_dim, hidden_dim),
    ...             nn.ReLU(),
    ...             nn.Linear(hidden_dim, 2),
    ...         )
    ...         self.lr = lr
    ...
    ...     def forward(self, x):
    ...         return self.model(x)
    ...
    ...     def training_step(self, batch, batch_idx):
    ...         x, y = batch
    ...         y_hat = self(x)
    ...         loss = nn.functional.cross_entropy(y_hat, y)
    ...         self.log("train_loss", loss)
    ...         return loss
    ...
    ...     def validation_step(self, batch, batch_idx):
    ...         x, y = batch
    ...         y_hat = self(x)
    ...         val_loss = nn.functional.cross_entropy(y_hat, y)
    ...         self.log("val_loss", val_loss, on_epoch=True)
    ...         return val_loss
    ...
    ...     def configure_optimizers(self):
    ...         return torch.optim.Adam(self.parameters(), lr=self.lr)
    >>>
    >>> # Create DataModule
    >>> class RandomDataModule(L.LightningDataModule):
    ...     def __init__(self, batch_size=32):
    ...         super().__init__()
    ...         self.batch_size = batch_size
    ...
    ...     def setup(self, stage=None):
    ...         dataset = torch.utils.data.TensorDataset(
    ...             torch.randn(100, 10),
    ...             torch.randint(0, 2, (100,)),
    ...         )
    ...         self.train, self.val = torch.utils.data.random_split(
    ...             dataset, [80, 20]
    ...         )
    ...
    ...     def train_dataloader(self):
    ...         return DataLoader(self.train, batch_size=self.batch_size)
    ...
    ...     def val_dataloader(self):
    ...         return DataLoader(self.val, batch_size=self.batch_size)
    >>>
    >>> datamodule = RandomDataModule(batch_size=16)
    >>> datamodule.setup()
    >>>
    >>> # Create Experiment
    >>> experiment = TorchExperiment(
    ...     datamodule=datamodule,
    ...     lightning_module=SimpleLightningModule,
    ...     trainer_kwargs={"max_epochs": 3},
    ...     objective_metric="val_loss",
    ... )
    >>>
    >>> params = {"input_dim": 10, "hidden_dim": 16, "lr": 1e-3}
    >>>
    >>> val_result, metadata = experiment._evaluate(params)
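
    The objective value is returned as a ``numpy.float64`` scalar together
    with a metadata dict, empty in the current implementation:

    >>> isinstance(metadata, dict)
    True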
    """

    _tags = {
        "property:randomness": "random",
        "property:higher_or_lower_is_better": "lower",
        "authors": ["amitsubhashchejara"],
        "python_dependencies": ["torch", "lightning"],
    }

    def __init__(
        self,
        datamodule,
        lightning_module,
        trainer_kwargs=None,
        objective_metric: str = "val_loss",
    ):
        # store constructor args unchanged, so get_params round-trips
        self.datamodule = datamodule
        self.lightning_module = lightning_module
        self.trainer_kwargs = trainer_kwargs
        self.objective_metric = objective_metric

        super().__init__()

        # quiet defaults for tuning runs; user-supplied kwargs take precedence
        self._trainer_kwargs = {
            "max_epochs": 10,
            "enable_checkpointing": False,
            "logger": False,
            "enable_progress_bar": False,
            "enable_model_summary": False,
        }
        if trainer_kwargs is not None:
            self._trainer_kwargs.update(trainer_kwargs)

    def _paramnames(self):
        """Return the parameter names of the search.

        Returns
        -------
        list of str, or None
            The parameter names of the search parameters.
            If not known or arbitrary, return None.
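
            For the ``SimpleLightningModule`` used in the class docstring,
            this returns ``["input_dim", "hidden_dim", "lr"]``.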
        """
        import inspect

        sig = inspect.signature(self.lightning_module.__init__)
        return [p for p in sig.parameters.keys() if p != "self"]

    def _evaluate(self, params):
        """Evaluate the parameters.

        Parameters
        ----------
        params : dict with string keys
            Hyperparameters to instantiate the Lightning module with.

        Returns
        -------
        float
            The value of the objective metric obtained with the given
            parameters; ``inf`` if training fails.
        dict
            Additional metadata about the search.
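
        Examples
        --------
        A sketch, assuming the ``SimpleLightningModule`` experiment set up
        in the class docstring:

        >>> experiment._evaluate(
        ...     {"input_dim": 10, "hidden_dim": 16, "lr": 1e-3}
        ... )  # doctest: +SKIP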
        """
        import lightning as L

        try:
            # instantiate the module with the candidate hyperparameters
            model = self.lightning_module(**params)
            trainer = L.Trainer(**self._trainer_kwargs)
            trainer.fit(model, self.datamodule)

            val_result = trainer.callback_metrics.get(self.objective_metric)
            metadata = {}

            if val_result is None:
                available_metrics = list(trainer.callback_metrics.keys())
                # note: caught by the except below and reported as a failed run
                raise ValueError(
                    f"Metric '{self.objective_metric}' not found. "
                    f"Available: {available_metrics}"
                )
            # normalize the logged metric (typically a 0-dim tensor) to np.float64
            if hasattr(val_result, "item"):
                val_result = np.float64(val_result.detach().cpu().item())
            elif isinstance(val_result, (int, float)):
                val_result = np.float64(val_result)
            else:
                val_result = np.float64(float(val_result))

            return val_result, metadata

        except Exception as e:
            # failed runs score as inf, the worst value for a
            # lower-is-better objective
            warnings.warn(f"Training failed with params {params}: {e}")
            return np.float64(float("inf")), {}

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class.
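            Here, two configurations are returned: a classification setup
            (``SimpleLightningModule`` with ``RandomDataModule``) and a
            regression setup (``RegressionModule`` with
            ``RegressionDataModule``).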
        """
        import lightning as L
        import torch
        from torch import nn
        from torch.utils.data import DataLoader

        # minimal classification module, mirroring the class docstring example
        class SimpleLightningModule(L.LightningModule):
            def __init__(self, input_dim=10, hidden_dim=16, lr=1e-3):
                super().__init__()
                self.save_hyperparameters()
                self.model = nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, 2),
                )
                self.lr = lr

            def forward(self, x):
                return self.model(x)

            def training_step(self, batch, batch_idx):
                x, y = batch
                y_hat = self(x)
                loss = nn.functional.cross_entropy(y_hat, y)
                self.log("train_loss", loss)
                return loss

            def validation_step(self, batch, batch_idx):
                x, y = batch
                y_hat = self(x)
                val_loss = nn.functional.cross_entropy(y_hat, y)
                self.log("val_loss", val_loss, on_epoch=True)
                return val_loss

            def configure_optimizers(self):
                return torch.optim.Adam(self.parameters(), lr=self.lr)

        # random-data datamodule with an 80/20 train/validation split
        class RandomDataModule(L.LightningDataModule):
            def __init__(self, batch_size=32):
                super().__init__()
                self.batch_size = batch_size

            def setup(self, stage=None):
                dataset = torch.utils.data.TensorDataset(
                    torch.randn(100, 10), torch.randint(0, 2, (100,))
                )
                self.train, self.val = torch.utils.data.random_split(dataset, [80, 20])

            def train_dataloader(self):
                return DataLoader(self.train, batch_size=self.batch_size)

            def val_dataloader(self):
                return DataLoader(self.val, batch_size=self.batch_size)

        datamodule = RandomDataModule(batch_size=16)

        params = {
            "datamodule": datamodule,
            "lightning_module": SimpleLightningModule,
            "trainer_kwargs": {
                "max_epochs": 1,
                "enable_progress_bar": False,
                "enable_model_summary": False,
                "logger": False,
            },
            "objective_metric": "val_loss",
        }

        # second test case: a configurable regression MLP on random data
        class RegressionModule(L.LightningModule):
            def __init__(self, num_layers=2, hidden_size=32, dropout=0.1):
                super().__init__()
                self.save_hyperparameters()
                layers = []
                input_size = 20
                for _ in range(num_layers):
                    layers.extend(
                        [
                            nn.Linear(input_size, hidden_size),
                            nn.ReLU(),
                            nn.Dropout(dropout),
                        ]
                    )
                    input_size = hidden_size
                layers.append(nn.Linear(hidden_size, 1))
                self.model = nn.Sequential(*layers)

            def forward(self, x):
                return self.model(x)

            def training_step(self, batch, batch_idx):
                x, y = batch
                y_hat = self(x).squeeze()
                loss = nn.functional.mse_loss(y_hat, y)
                self.log("train_loss", loss)
                return loss

            def validation_step(self, batch, batch_idx):
                x, y = batch
                y_hat = self(x).squeeze()
                val_loss = nn.functional.mse_loss(y_hat, y)
                self.log("val_loss", val_loss, on_epoch=True)
                return val_loss

            def configure_optimizers(self):
                return torch.optim.SGD(self.parameters(), lr=0.01)

        class RegressionDataModule(L.LightningDataModule):
            def __init__(self, batch_size=16, num_samples=150):
                super().__init__()
                self.batch_size = batch_size
                self.num_samples = num_samples

            def setup(self, stage=None):
                X = torch.randn(self.num_samples, 20)
                y = torch.randn(self.num_samples)
                dataset = torch.utils.data.TensorDataset(X, y)
                train_size = int(0.8 * self.num_samples)
                val_size = self.num_samples - train_size
                self.train, self.val = torch.utils.data.random_split(
                    dataset, [train_size, val_size]
                )

            def train_dataloader(self):
                return DataLoader(self.train, batch_size=self.batch_size)

            def val_dataloader(self):
                return DataLoader(self.val, batch_size=self.batch_size)

        datamodule2 = RegressionDataModule(batch_size=16, num_samples=150)

        params2 = {
            "datamodule": datamodule2,
            "lightning_module": RegressionModule,
            "trainer_kwargs": {
                "max_epochs": 1,
                "enable_progress_bar": False,
                "enable_model_summary": False,
                "logger": False,
            },
            "objective_metric": "val_loss",
        }

        return [params, params2]

    @classmethod
    def _get_score_params(cls):
        """Return settings for testing score/evaluate functions.

        Returns a list whose i-th element is a valid argument to
        ``self.evaluate`` and ``self.score`` for an instance constructed
        with ``self.get_test_params()[i]``.

        Returns
        -------
        list of dict
            The parameters to be used for scoring.
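
            For example, ``score_params1`` below matches the ``__init__``
            signature of ``SimpleLightningModule`` from
            ``get_test_params()[0]``, and ``score_params2`` matches
            ``RegressionModule`` from ``get_test_params()[1]``.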
        """
        score_params1 = {"input_dim": 10, "hidden_dim": 20, "lr": 0.001}
        score_params2 = {"num_layers": 3, "hidden_size": 64, "dropout": 0.2}
        return [score_params1, score_params2]