Skip to content

Commit 71b8c2b

Browse files
committed
Fix kernel factory return types
1 parent dc14a6f commit 71b8c2b

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

  • baybe/surrogates/gaussian_process/components

baybe/surrogates/gaussian_process/components/kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class _MetaKernelFactory(KernelFactoryProtocol, ABC):
123123
@abstractmethod
124124
def __call__(
125125
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
126-
) -> Kernel: ...
126+
) -> Kernel | GPyTorchKernel: ...
127127

128128

129129
@define
@@ -166,7 +166,7 @@ def _default_task_kernel_factory(self) -> KernelFactoryProtocol:
166166
@override
167167
def __call__(
168168
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
169-
) -> Kernel:
169+
) -> Kernel | GPyTorchKernel:
170170
if searchspace.task_idx is None:
171171
raise IncompatibleSearchSpaceError(
172172
f"'{type(self).__name__}' can only be used with a searchspace that "

0 commit comments

Comments
 (0)