Skip to content

Commit f9cd2d1

Browse files
authored
GibbsKernel added (#2744)
Unittests added + docs updated
1 parent fe24b22 commit f9cd2d1

4 files changed

Lines changed: 171 additions & 0 deletions

File tree

docs/source/kernels.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ Standard Kernels
4343
:members:
4444

4545

46+
:hidden:`GibbsKernel`
47+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
48+
49+
.. autoclass:: GibbsKernel
50+
:members:
51+
52+
4653
:hidden:`LinearKernel`
4754
~~~~~~~~~~~~~~~~~~~~~~
4855

gpytorch/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .cylindrical_kernel import CylindricalKernel
1111
from .distributional_input_kernel import DistributionalInputKernel
1212
from .gaussian_symmetrized_kl_kernel import GaussianSymmetrizedKLKernel
13+
from .gibbs_kernel import GibbsKernel
1314
from .grid_interpolation_kernel import GridInterpolationKernel
1415
from .grid_kernel import GridKernel
1516
from .hamming_kernel import HammingIMQKernel
@@ -49,6 +50,7 @@
4950
"CosineKernel",
5051
"DistributionalInputKernel",
5152
"GaussianSymmetrizedKLKernel",
53+
"GibbsKernel",
5254
"GridKernel",
5355
"GridInterpolationKernel",
5456
"HammingIMQKernel",

gpytorch/kernels/gibbs_kernel.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
from copy import deepcopy
6+
7+
import torch
8+
from torch import nn
9+
10+
from .kernel import Kernel
11+
12+
13+
class GibbsKernel(Kernel):
14+
r"""
15+
Gibbs kernel with input-dependent lengthscale :math:`\ell(x)` (Gibbs, 1997)
16+
17+
.. math::
18+
k(x, x') = \sqrt{\frac{2\ell(x)\ell(x')}{\ell(x)^2 + \ell(x')^2}}
19+
\exp\left(-\frac{(x-x')^2}{\ell(x)^2 + \ell(x')^2}\right)
20+
21+
:param lengthscale_fn: A callable torch.nn.Module mapping inputs to
22+
positive lengthscales. Must output tensors of shape (... x N x 1)
23+
for input of shape (... x N x D)
24+
:type lengthscale_fn: torch.nn.Module
25+
26+
Example::
27+
28+
class LengthscaleMLP(torch.nn.Module):
29+
def __init__(self, in_dim=1, hidden=32):
30+
super().__init__()
31+
self.net = torch.nn.Sequential(
32+
torch.nn.Linear(in_dim, hidden),
33+
torch.nn.ReLU(),
34+
torch.nn.Linear(hidden, 1),
35+
torch.nn.Softplus(),
36+
)
37+
38+
def forward(self, x):
39+
return self.net(x)
40+
41+
kernel = GibbsKernel(lengthscale_fn=LengthscaleMLP(in_dim=1))
42+
"""
43+
44+
is_stationary = False
45+
has_lengthscale = False
46+
47+
def __init__(self, lengthscale_fn: nn.Module, **kwargs):
48+
if kwargs.get("ard_num_dims") is not None:
49+
raise NotImplementedError("GibbsKernel does not support ARD.")
50+
super().__init__(**kwargs)
51+
self.lengthscale_fn = lengthscale_fn
52+
53+
# Update batch_shape explicitly:
54+
# Base class derives new batch_shape from parameters,
55+
# but GibbsKernel has none
56+
def __getitem__(self, index):
57+
if len(self.batch_shape) == 0:
58+
return self
59+
new_kernel = deepcopy(self)
60+
index = index if isinstance(index, tuple) else (index,)
61+
new_kernel.batch_shape = torch.empty(self.batch_shape)[index].shape
62+
return new_kernel
63+
64+
def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params):
65+
x1_eq_x2 = torch.equal(x1, x2)
66+
67+
l1 = self.lengthscale_fn(x1)
68+
if l1.shape[-1] != 1:
69+
raise ValueError(f"lengthscale_fn must return shape (..., k, 1), got (..., k, {l1.shape[-1]})")
70+
l2 = l1 if x1_eq_x2 else self.lengthscale_fn(x2)
71+
72+
dist_sq = self.covar_dist(x1, x2, square_dist=True, diag=diag, **params)
73+
74+
if diag:
75+
S = (l1.pow(2) + l2.pow(2)).squeeze(-1)
76+
prod = (l1 * l2).squeeze(-1)
77+
else:
78+
S = l1.pow(2) + l2.pow(2).transpose(-2, -1)
79+
prod = l1 * l2.transpose(-2, -1)
80+
81+
prefactor = (2.0 * prod / S).sqrt()
82+
return prefactor * (-dist_sq / S).exp()

test/kernels/test_gibbs_kernel.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import unittest
2+
3+
import torch
4+
from torch import nn
5+
6+
from gpytorch.kernels import GibbsKernel, RBFKernel
7+
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase
8+
9+
10+
class ConstantLengthscale(nn.Module):
11+
r"""Constant :math:`\ell(x) = \exp(c)`"""
12+
13+
def __init__(self, value: float = 1.0):
14+
super().__init__()
15+
self.log_value = nn.Parameter(torch.tensor(value).log())
16+
17+
def forward(self, x: torch.Tensor) -> torch.Tensor:
18+
return self.log_value.exp().expand(*x.shape[:-1], 1)
19+
20+
21+
class MLPLengthscale(nn.Module):
22+
"""Small MLP, non-constant lengthscale function."""
23+
24+
def __init__(self, in_dim: int = 1, hidden: int = 16):
25+
super().__init__()
26+
self.net = nn.Sequential(
27+
nn.Linear(in_dim, hidden),
28+
nn.ReLU(),
29+
nn.Linear(hidden, 1),
30+
)
31+
nn.init.normal_(self.net[-1].weight, std=0.01)
32+
nn.init.zeros_(self.net[-1].bias)
33+
34+
def forward(self, x: torch.Tensor) -> torch.Tensor:
35+
return torch.exp(self.net(x))
36+
37+
38+
class TestGibbsKernel(BaseKernelTestCase, unittest.TestCase):
39+
def create_data_no_batch(self):
40+
return torch.randn(50, 10)
41+
42+
def create_kernel_no_ard(self, **kwargs):
43+
return GibbsKernel(ConstantLengthscale(), **kwargs)
44+
45+
def setUp(self):
46+
self.lfn = ConstantLengthscale(value=1.0)
47+
self.kernel = GibbsKernel(self.lfn)
48+
49+
def test_diagonal_is_one(self):
50+
r""":math:`k(x, x) = 1` for all :math:`x`."""
51+
for lfn in [ConstantLengthscale(), MLPLengthscale(in_dim=2)]:
52+
kernel = GibbsKernel(lfn)
53+
x = torch.randn(20, 2)
54+
K = kernel(x).to_dense()
55+
self.assertTrue(torch.allclose(K.diagonal(), torch.ones(20), atol=1e-5))
56+
57+
def test_reduces_to_rbf_with_constant_lengthscale(self):
58+
r"""With constant :math:`\ell(x) = \ell`, Gibbs reduces to RBF."""
59+
l = 1.5
60+
kernel_gibbs = GibbsKernel(ConstantLengthscale(value=l))
61+
kernel_rbf = RBFKernel()
62+
kernel_rbf.lengthscale = l
63+
64+
x1 = torch.randn(8, 1)
65+
x2 = torch.randn(6, 1)
66+
67+
K_gibbs = kernel_gibbs(x1, x2).to_dense()
68+
K_rbf = kernel_rbf(x1, x2).to_dense()
69+
70+
self.assertTrue(torch.allclose(K_gibbs, K_rbf, atol=1e-5))
71+
72+
def test_gradient_flows_to_lengthscale_fn(self):
73+
"""Gradients propagate through lengthscale_fn."""
74+
kernel = GibbsKernel(MLPLengthscale(in_dim=2))
75+
x = torch.randn(8, 2)
76+
kernel(x).to_dense().sum().backward()
77+
78+
for name, param in kernel.lengthscale_fn.named_parameters():
79+
self.assertIsNotNone(param.grad, f"No gradient for {name}")
80+
self.assertFalse(torch.all(param.grad == 0), f"Zero gradient for {name}")

0 commit comments

Comments
 (0)