Skip to content

Commit 990aff4

Browse files
authored
Fix predict() returning stale results for different w_pred grids
_maybe_prepare_inference never checked if w_pred changed between calls, so calling predict() a second time with a different grid would reuse the cached OpKer from the first call and blow up or give wrong results. Now we store w_pred alongside OpKer and compare before reusing.
1 parent 09b03af commit 990aff4

2 files changed

Lines changed: 43 additions & 6 deletions

File tree

fredipy/models.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import scipy as sp
99

1010
from .covariance import TwoSided, OneSided
11-
from .util import make_column_vector
11+
from .util import allclose, make_column_vector
1212

1313

1414
class Model:
@@ -218,11 +218,12 @@ def _maybe_prepare_inference(
218218
self,
219219
w_pred: np.ndarray,
220220
) -> None:
221-
if not self._inference_cache:
222-
OpKer = self.OpKer(
223-
self.kernel, self.constraints, w_pred)
224-
self._inference_cache = {
225-
'OpKer': OpKer}
221+
if self._inference_cache and allclose(w_pred, self._inference_cache['w_pred']):
222+
return
223+
OpKer = self.OpKer(
224+
self.kernel, self.constraints, w_pred)
225+
self._inference_cache = {
226+
'OpKer': OpKer, 'w_pred': w_pred}
226227

227228

228229
class GP(GaussianProcess):

tests/test_reconstruction.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,39 @@ def test_dressing_1D() -> None:
169169
print(devs)
170170
assert all(i > 0 for i in devs), \
171171
"Reconstructed data does not match input"
172+
173+
174+
def test_predict_different_w_pred() -> None:
175+
"""predict() called twice with different grids should give correct results both times"""
176+
w_pred_1 = np.arange(0.5, 5, 0.5)
177+
w_pred_2 = np.arange(0.1, 10, 0.1)
178+
p = np.linspace(0.1, 10, 30)
179+
180+
a = 1.6
181+
m = 1
182+
g = 0.8
183+
184+
G = get_G(p, a, m, g)
185+
err = 1e-5
186+
187+
data = {
188+
'x': p,
189+
'y': G + err * rng.randn(len(G)),
190+
'cov_y': err**2 * np.ones_like(p)}
191+
192+
kernel = fp.kernels.RadialBasisFunction(0.5, 0.3)
193+
integrator = fp.integrators.Riemann_1D(0, 10, 500)
194+
integral_op = fp.operators.Integral(kl_kernel, integrator)
195+
constraints = [fp.constraints.LinearEquality(integral_op, data)]
196+
model = fp.models.GaussianProcess(kernel, constraints)
197+
198+
rho1, err1 = model.predict(w_pred_1)
199+
rho2, err2 = model.predict(w_pred_2)
200+
201+
ref1 = get_rho(w_pred_1, a, m, g)
202+
ref2 = get_rho(w_pred_2, a, m, g)
203+
204+
devs1 = err1 - abs(rho1.flatten() - ref1)
205+
devs2 = err2 - abs(rho2.flatten() - ref2)
206+
assert all(i > 0 for i in devs1), "first predict wrong"
207+
assert all(i > 0 for i in devs2), "second predict wrong"

0 commit comments

Comments
 (0)