|
| 1 | +import unittest |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import nn |
| 5 | + |
| 6 | +from gpytorch.kernels import GibbsKernel, RBFKernel |
| 7 | +from gpytorch.test.base_kernel_test_case import BaseKernelTestCase |
| 8 | + |
| 9 | + |
class ConstantLengthscale(nn.Module):
    r"""Input-independent lengthscale function :math:`\ell(x) = \exp(c)`.

    The single learnable parameter ``log_value`` holds :math:`c`, the natural
    log of the lengthscale, so the value returned by :meth:`forward` is always
    the exponential of the stored parameter.

    Args:
        value: Initial lengthscale; its logarithm is stored as the parameter.
    """

    def __init__(self, value: float = 1.0):
        super().__init__()
        initial_log = torch.log(torch.tensor(value))
        self.log_value = nn.Parameter(initial_log)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the lengthscale broadcast to ``x.shape[:-1] + (1,)``.

        Only the shape of ``x`` is used; its values do not affect the result.
        """
        leading_dims = x.shape[:-1]
        lengthscale = torch.exp(self.log_value)
        return lengthscale.expand(*leading_dims, 1)
| 19 | + |
| 20 | + |
class MLPLengthscale(nn.Module):
    """Input-dependent lengthscale: a one-hidden-layer ReLU MLP followed by ``exp``.

    The output layer is re-initialised with small weights (std 0.01) and a
    zero bias, so the value fed to ``exp`` starts near zero for moderately
    sized inputs.

    Args:
        in_dim: Size of the last dimension of the inputs.
        hidden: Width of the hidden layer.
    """

    def __init__(self, in_dim: int = 1, hidden: int = 16):
        super().__init__()
        # Build the layers in network order so their default initialisation
        # draws random numbers in the same sequence as inline construction.
        hidden_layer = nn.Linear(in_dim, hidden)
        activation = nn.ReLU()
        output_layer = nn.Linear(hidden, 1)
        nn.init.normal_(output_layer.weight, std=0.01)
        nn.init.zeros_(output_layer.bias)
        self.net = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``exp(net(x))``, a strictly positive value per input point."""
        log_lengthscale = self.net(x)
        return log_lengthscale.exp()
| 36 | + |
| 37 | + |
class TestGibbsKernel(BaseKernelTestCase, unittest.TestCase):
    """Tests for ``GibbsKernel``.

    ``create_data_no_batch`` and ``create_kernel_no_ard`` are presumably the
    hooks consumed by ``BaseKernelTestCase`` (confirm against the base class);
    the ``test_*`` methods check properties specific to the Gibbs kernel.
    """

    def create_data_no_batch(self):
        """Return random inputs of shape ``(50, 10)``."""
        data = torch.randn(50, 10)
        return data

    def create_kernel_no_ard(self, **kwargs):
        """Build a ``GibbsKernel`` over a default ``ConstantLengthscale``."""
        lengthscale_fn = ConstantLengthscale()
        return GibbsKernel(lengthscale_fn, **kwargs)

    def setUp(self):
        """Create a unit-lengthscale kernel fixture before each test."""
        self.lfn = ConstantLengthscale(value=1.0)
        self.kernel = GibbsKernel(self.lfn)

    def test_diagonal_is_one(self):
        r""":math:`k(x, x) = 1` for all :math:`x`.

        Checked for both a constant and an MLP lengthscale function.
        """
        lengthscale_fns = [ConstantLengthscale(), MLPLengthscale(in_dim=2)]
        for lengthscale_fn in lengthscale_fns:
            gibbs = GibbsKernel(lengthscale_fn)
            inputs = torch.randn(20, 2)
            covar = gibbs(inputs).to_dense()
            diag_matches = torch.allclose(covar.diagonal(), torch.ones(20), atol=1e-5)
            self.assertTrue(diag_matches)

    def test_reduces_to_rbf_with_constant_lengthscale(self):
        r"""With constant :math:`\ell(x) = \ell`, Gibbs reduces to RBF.

        Compares the cross-covariance of two random 1-D input sets under
        both kernels.
        """
        lengthscale = 1.5
        kernel_gibbs = GibbsKernel(ConstantLengthscale(value=lengthscale))
        kernel_rbf = RBFKernel()
        kernel_rbf.lengthscale = lengthscale

        rows = torch.randn(8, 1)
        cols = torch.randn(6, 1)

        gibbs_covar = kernel_gibbs(rows, cols).to_dense()
        rbf_covar = kernel_rbf(rows, cols).to_dense()

        kernels_agree = torch.allclose(gibbs_covar, rbf_covar, atol=1e-5)
        self.assertTrue(kernels_agree)

    def test_gradient_flows_to_lengthscale_fn(self):
        """Gradients propagate through lengthscale_fn.

        Every parameter of the MLP lengthscale must receive a gradient that
        is not identically zero.
        """
        gibbs = GibbsKernel(MLPLengthscale(in_dim=2))
        inputs = torch.randn(8, 2)
        total = gibbs(inputs).to_dense().sum()
        total.backward()

        for name, param in gibbs.lengthscale_fn.named_parameters():
            self.assertIsNotNone(param.grad, f"No gradient for {name}")
            grad_all_zero = torch.all(param.grad == 0)
            self.assertFalse(grad_all_zero, f"Zero gradient for {name}")
0 commit comments