
Commit 87e670a

Add PyTorch Lightning integration (#203)
This pull request merges the new experiment `torch-lightning-experiment` for issue #195. The new class `TorchExperiment` is an experiment adapter for PyTorch Lightning and is used to run experiments with PyTorch Lightning modules. It allows for hyperparameter tuning and evaluation of the model's performance using specified metrics. The `TorchExperiment` class accepts a `LightningModule` class, a `LightningDataModule` instance, optional `Trainer` keyword arguments, and an `objective_metric` with default value `val_loss`. The `_evaluate` method internally performs a training run and returns the score of the `objective_metric`. Fixes #195
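In outline, the adapter contract looks like this (a sketch only; `SimpleLightningModule` and `RandomDataModule` are the example classes defined in full in the module docstring further down this diff):

experiment = TorchExperiment(
    datamodule=RandomDataModule(batch_size=16),  # a LightningDataModule instance
    lightning_module=SimpleLightningModule,      # the class itself, instantiated once per trial
    trainer_kwargs={"max_epochs": 3},            # forwarded to lightning.Trainer
    objective_metric="val_loss",                 # must be logged during validation
)
score, metadata = experiment._evaluate({"input_dim": 10, "hidden_dim": 16, "lr": 1e-3})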
1 parent aee4558 commit 87e670a

3 files changed

Lines changed: 378 additions & 0 deletions

File tree

pyproject.toml
src/hyperactive/experiment/integrations/__init__.py
src/hyperactive/experiment/integrations/torch_lightning_experiment.py

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ test_parallel_backends = [
 all_extras = [
     "hyperactive[integrations]",
     "optuna<5",
+    "lightning",
 ]

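With this addition, installing the `all_extras` extra (for example `pip install "hyperactive[all_extras]"`) now pulls in `lightning` alongside the other optional integration dependencies.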
src/hyperactive/experiment/integrations/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -11,10 +11,14 @@
 from hyperactive.experiment.integrations.sktime_forecasting import (
     SktimeForecastingExperiment,
 )
+from hyperactive.experiment.integrations.torch_lightning_experiment import (
+    TorchExperiment,
+)

 __all__ = [
     "SklearnCvExperiment",
     "SkproProbaRegExperiment",
     "SktimeClassificationExperiment",
     "SktimeForecastingExperiment",
+    "TorchExperiment",
 ]
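With this export in place, the adapter can be imported directly from the integrations namespace:

from hyperactive.experiment.integrations import TorchExperiment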
src/hyperactive/experiment/integrations/torch_lightning_experiment.py

Lines changed: 373 additions & 0 deletions
@@ -0,0 +1,373 @@
"""Experiment adapter for PyTorch Lightning experiments."""

# copyright: hyperactive developers, MIT License (see LICENSE file)

__author__ = ["amitsubhashchejara"]

import numpy as np

from hyperactive.base import BaseExperiment


class TorchExperiment(BaseExperiment):
    """Experiment adapter for PyTorch Lightning experiments.

    This class is used to perform experiments using PyTorch Lightning modules.
    It allows for hyperparameter tuning and evaluation of the model's performance
    using specified metrics.

    The experiment trains a Lightning module with given hyperparameters and returns
    the validation metric value for optimization.

    Parameters
    ----------
    datamodule : L.LightningDataModule
        A PyTorch Lightning DataModule that handles data loading and preparation.
    lightning_module : type
        A PyTorch Lightning Module class (not an instance) that will be instantiated
        with hyperparameters during optimization.
    trainer_kwargs : dict, optional (default=None)
        A dictionary of keyword arguments to pass to the PyTorch Lightning Trainer.
    objective_metric : str, optional (default='val_loss')
        The metric used to evaluate the model's performance. This should correspond
        to a metric logged in the LightningModule during validation.

    Examples
    --------
    >>> from hyperactive.experiment.integrations import TorchExperiment
    >>> import torch
    >>> import lightning as L
    >>> from torch import nn
    >>> from torch.utils.data import DataLoader
    >>>
    >>> # Define a simple Lightning Module
    >>> class SimpleLightningModule(L.LightningModule):
    ...     def __init__(self, input_dim=10, hidden_dim=16, lr=1e-3):
    ...         super().__init__()
    ...         self.save_hyperparameters()
    ...         self.model = nn.Sequential(
    ...             nn.Linear(input_dim, hidden_dim),
    ...             nn.ReLU(),
    ...             nn.Linear(hidden_dim, 2)
    ...         )
    ...         self.lr = lr
    ...
    ...     def forward(self, x):
    ...         return self.model(x)
    ...
    ...     def training_step(self, batch, batch_idx):
    ...         x, y = batch
    ...         y_hat = self(x)
    ...         loss = nn.functional.cross_entropy(y_hat, y)
    ...         self.log("train_loss", loss)
    ...         return loss
    ...
    ...     def validation_step(self, batch, batch_idx):
    ...         x, y = batch
    ...         y_hat = self(x)
    ...         val_loss = nn.functional.cross_entropy(y_hat, y)
    ...         self.log("val_loss", val_loss, on_epoch=True)
    ...         return val_loss
    ...
    ...     def configure_optimizers(self):
    ...         return torch.optim.Adam(self.parameters(), lr=self.lr)
    >>>
    >>> # Create DataModule
    >>> class RandomDataModule(L.LightningDataModule):
    ...     def __init__(self, batch_size=32):
    ...         super().__init__()
    ...         self.batch_size = batch_size
    ...
    ...     def setup(self, stage=None):
    ...         dataset = torch.utils.data.TensorDataset(
    ...             torch.randn(100, 10),
    ...             torch.randint(0, 2, (100,))
    ...         )
    ...         self.train, self.val = torch.utils.data.random_split(
    ...             dataset, [80, 20]
    ...         )
    ...
    ...     def train_dataloader(self):
    ...         return DataLoader(self.train, batch_size=self.batch_size)
    ...
    ...     def val_dataloader(self):
    ...         return DataLoader(self.val, batch_size=self.batch_size)
    >>>
    >>> datamodule = RandomDataModule(batch_size=16)
    >>> datamodule.setup()
    >>>
    >>> # Create Experiment
    >>> experiment = TorchExperiment(
    ...     datamodule=datamodule,
    ...     lightning_module=SimpleLightningModule,
    ...     trainer_kwargs={'max_epochs': 3},
    ...     objective_metric="val_loss"
    ... )
    >>>
    >>> params = {"input_dim": 10, "hidden_dim": 16, "lr": 1e-3}
    >>>
    >>> val_result, metadata = experiment._evaluate(params)
    """

    _tags = {
        "property:randomness": "random",
        "property:higher_or_lower_is_better": "lower",
        "authors": ["amitsubhashchejara"],
        "python_dependencies": ["torch", "lightning"],
    }

    def __init__(
        self,
        datamodule,
        lightning_module,
        trainer_kwargs=None,
        objective_metric: str = "val_loss",
    ):
        self.datamodule = datamodule
        self.lightning_module = lightning_module
        self.trainer_kwargs = trainer_kwargs or {}
        self.objective_metric = objective_metric

        super().__init__()

        self._trainer_kwargs = {
            "max_epochs": 10,
            "enable_checkpointing": False,
            "logger": False,
            "enable_progress_bar": False,
            "enable_model_summary": False,
        }
        if trainer_kwargs is not None:
            self._trainer_kwargs.update(trainer_kwargs)

    def _paramnames(self):
        """Return the parameter names of the search.

        Returns
        -------
        list of str, or None
            The parameter names of the search parameters.
            If not known or arbitrary, return None.
        """
        import inspect

        sig = inspect.signature(self.lightning_module.__init__)
        return [p for p in sig.parameters.keys() if p != "self"]

    def _evaluate(self, params):
        """Evaluate the parameters.

        Parameters
        ----------
        params : dict with string keys
            Parameters to evaluate.

        Returns
        -------
        float
            The value of the parameters as per evaluation.
        dict
            Additional metadata about the search.
        """
        import lightning as L

        try:
            model = self.lightning_module(**params)
            trainer = L.Trainer(**self._trainer_kwargs)
            trainer.fit(model, self.datamodule)

            val_result = trainer.callback_metrics.get(self.objective_metric)
            metadata = {}

            if val_result is None:
                available_metrics = list(trainer.callback_metrics.keys())
                raise ValueError(
                    f"Metric '{self.objective_metric}' not found. "
                    f"Available: {available_metrics}"
                )
            if hasattr(val_result, "item"):
                val_result = np.float64(val_result.detach().cpu().item())
            elif isinstance(val_result, (int, float)):
                val_result = np.float64(val_result)
            else:
                val_result = np.float64(float(val_result))

            return val_result, metadata

        except Exception as e:
            print(f"Training failed with params {params}: {e}")
            return np.float64(float("inf")), {}

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class.
        """
        import lightning as L
        import torch
        from torch import nn
        from torch.utils.data import DataLoader

        class SimpleLightningModule(L.LightningModule):
            def __init__(self, input_dim=10, hidden_dim=16, lr=1e-3):
                super().__init__()
                self.save_hyperparameters()
                self.model = nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, 2),
                )
                self.lr = lr

            def forward(self, x):
                return self.model(x)

            def training_step(self, batch, batch_idx):
                x, y = batch
                y_hat = self(x)
                loss = nn.functional.cross_entropy(y_hat, y)
                self.log("train_loss", loss)
                return loss

            def validation_step(self, batch, batch_idx):
                x, y = batch
                y_hat = self(x)
                val_loss = nn.functional.cross_entropy(y_hat, y)
                self.log("val_loss", val_loss, on_epoch=True)
                return val_loss

            def configure_optimizers(self):
                return torch.optim.Adam(self.parameters(), lr=self.lr)

        class RandomDataModule(L.LightningDataModule):
            def __init__(self, batch_size=32):
                super().__init__()
                self.batch_size = batch_size

            def setup(self, stage=None):
                dataset = torch.utils.data.TensorDataset(
                    torch.randn(100, 10), torch.randint(0, 2, (100,))
                )
                self.train, self.val = torch.utils.data.random_split(dataset, [80, 20])

            def train_dataloader(self):
                return DataLoader(self.train, batch_size=self.batch_size)

            def val_dataloader(self):
                return DataLoader(self.val, batch_size=self.batch_size)

        datamodule = RandomDataModule(batch_size=16)

        params = {
            "datamodule": datamodule,
            "lightning_module": SimpleLightningModule,
            "trainer_kwargs": {
                "max_epochs": 1,
                "enable_progress_bar": False,
                "enable_model_summary": False,
                "logger": False,
            },
            "objective_metric": "val_loss",
        }

        class RegressionModule(L.LightningModule):
            def __init__(self, num_layers=2, hidden_size=32, dropout=0.1):
                super().__init__()
                self.save_hyperparameters()
                layers = []
                input_size = 20
                for _ in range(num_layers):
                    layers.extend(
                        [
                            nn.Linear(input_size, hidden_size),
                            nn.ReLU(),
                            nn.Dropout(dropout),
                        ]
                    )
                    input_size = hidden_size
                layers.append(nn.Linear(hidden_size, 1))
                self.model = nn.Sequential(*layers)

            def forward(self, x):
                return self.model(x)

            def training_step(self, batch, batch_idx):
                x, y = batch
                y_hat = self(x).squeeze()
                loss = nn.functional.mse_loss(y_hat, y)
                self.log("train_loss", loss)
                return loss

            def validation_step(self, batch, batch_idx):
                x, y = batch
                y_hat = self(x).squeeze()
                val_loss = nn.functional.mse_loss(y_hat, y)
                self.log("val_loss", val_loss, on_epoch=True)
                return val_loss

            def configure_optimizers(self):
                return torch.optim.SGD(self.parameters(), lr=0.01)

        class RegressionDataModule(L.LightningDataModule):
            def __init__(self, batch_size=16, num_samples=150):
                super().__init__()
                self.batch_size = batch_size
                self.num_samples = num_samples

            def setup(self, stage=None):
                X = torch.randn(self.num_samples, 20)
                y = torch.randn(self.num_samples)
                dataset = torch.utils.data.TensorDataset(X, y)
                train_size = int(0.8 * self.num_samples)
                val_size = self.num_samples - train_size
                self.train, self.val = torch.utils.data.random_split(
                    dataset, [train_size, val_size]
                )

            def train_dataloader(self):
                return DataLoader(self.train, batch_size=self.batch_size)

            def val_dataloader(self):
                return DataLoader(self.val, batch_size=self.batch_size)

        datamodule2 = RegressionDataModule(batch_size=16, num_samples=150)

        params2 = {
            "datamodule": datamodule2,
            "lightning_module": RegressionModule,
            "trainer_kwargs": {
                "max_epochs": 1,
                "enable_progress_bar": False,
                "enable_model_summary": False,
                "logger": False,
            },
            "objective_metric": "val_loss",
        }

        return [params, params2]

    @classmethod
    def _get_score_params(cls):
        """Return settings for testing score/evaluate functions.

        Returns a list; the i-th element should be valid arguments for
        self.evaluate and self.score, of an instance constructed with
        self.get_test_params()[i].

        Returns
        -------
        list of dict
            The parameters to be used for scoring.
        """
        score_params1 = {"input_dim": 10, "hidden_dim": 20, "lr": 0.001}
        score_params2 = {"num_layers": 3, "hidden_size": 64, "dropout": 0.2}
        return [score_params1, score_params2]
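To make the optimization use case concrete, a minimal hand-rolled sweep is sketched below. It reuses `SimpleLightningModule` and `RandomDataModule` from the docstring example above and calls `_evaluate` directly, as that example does; per the class tags, lower scores are better.

candidates = [
    {"input_dim": 10, "hidden_dim": 8, "lr": 1e-2},
    {"input_dim": 10, "hidden_dim": 16, "lr": 1e-3},
    {"input_dim": 10, "hidden_dim": 32, "lr": 1e-4},
]
experiment = TorchExperiment(
    datamodule=RandomDataModule(batch_size=16),
    lightning_module=SimpleLightningModule,
    trainer_kwargs={"max_epochs": 1},
    objective_metric="val_loss",
)
# _evaluate returns (score, metadata); keep the score for each candidate
results = [(experiment._evaluate(p)[0], p) for p in candidates]
# "property:higher_or_lower_is_better" is "lower", so the best trial minimizes val_loss
best_score, best_params = min(results, key=lambda r: r[0])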
