Skip to content

Commit cc36cfe

Browse files
authored
Merge pull request #72 from GPflow/recompilation_fix
Recompilation fix
2 parents 6354112 + 5ba4259 commit cc36cfe

6 files changed

Lines changed: 98 additions & 5 deletions

File tree

GPflowOpt/acquisition/acquisition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def models(self):
174174
175175
:return: list of GPflow models
176176
"""
177-
return self._models
177+
return self._models.sorted_params
178178

179179
@property
180180
def data(self):
@@ -318,7 +318,7 @@ def _optimize_models(self):
318318

319319
@Acquisition.models.getter
320320
def models(self):
321-
return ParamList([model for acq in self.operands for model in acq.models.sorted_params])
321+
return [model for acq in self.operands for model in acq.models]
322322

323323
def enable_scaling(self, domain):
324324
for oper in self.operands:

GPflowOpt/models.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,35 @@
1515
from GPflow.model import Model
1616

1717

18+
class ParentHook(object):
19+
"""
20+
Temporary solution for fixing the recompilation issues (#37, GPflow issue #442).
21+
22+
An object of this class is returned when highest_parent is called on a model, which holds references to the highest
23+
parentable, as well as the highest model class. When setting the needs recompile flag, this is intercepted and
24+
performed on the model. At the same time, kill autoflow is called on the highest parent.
25+
"""
26+
def __init__(self, highest_parent, highest_model):
27+
self._hp = highest_parent
28+
self._hm = highest_model
29+
30+
def __getattr__(self, item):
31+
if item is '_needs_recompile':
32+
return getattr(self._hm, item)
33+
return getattr(self._hp, item)
34+
35+
def __setattr__(self, key, value):
36+
if key in ['_hp', '_hm']:
37+
object.__setattr__(self, key, value)
38+
return
39+
if key is '_needs_recompile':
40+
setattr(self._hm, key, value)
41+
if value:
42+
self._hp._kill_autoflow()
43+
else:
44+
setattr(self._hp, key, value)
45+
46+
1847
class ModelWrapper(Parameterized):
1948
"""
2049
Class for fast implementation of a wrapper for models defined in GPflow.
@@ -25,6 +54,7 @@ class ModelWrapper(Parameterized):
2554
AutoFlow methods are influenced following this pattern, the original AF storage (if existing) is unaffected and a
2655
new storage is added to the subclass.
2756
"""
57+
2858
def __init__(self, model):
2959
"""
3060
:param model: model to be wrapped
@@ -45,6 +75,7 @@ def __getattr__(self, item):
4575
method = item[1:].rstrip('_AF_storage')
4676
if method in dir(self):
4777
raise AttributeError("{0} has no attribute {1}".format(self.__class__.__name__, item))
78+
4879
return getattr(self.wrapped, item)
4980

5081
def __setattr__(self, key, value):
@@ -90,3 +121,11 @@ def __eq__(self, other):
90121
def name(self):
91122
name = super(ModelWrapper, self).name
92123
return ".".join([name, str.lower(self.__class__.__name__)])
124+
125+
@Parameterized.highest_parent.getter
126+
def highest_parent(self):
127+
"""
128+
Returns an instance of the ParentHook instead of the usual reference to a Parentable.
129+
"""
130+
original_hp = super(ModelWrapper, self).highest_parent
131+
return original_hp if isinstance(original_hp, ParentHook) else ParentHook(original_hp, self)

GPflowOpt/scaling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
from GPflow.param import DataHolder, AutoFlow
16-
from GPflow.model import GPModel
1716
from GPflow import settings
1817
import numpy as np
1918
from .transforms import LinearTransform, DataTransform
@@ -53,6 +52,7 @@ class DataScaler(ModelWrapper):
5352
required, it is the responsibility of the implementation to rescale the hyperparameters. Additionally, applying
5453
hyperpriors should anticipate for the scaled data.
5554
"""
55+
5656
def __init__(self, model, domain=None, normalize_Y=False):
5757
"""
5858
:param model: model to be wrapped

testing/test_acquisition.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ def test_object_integrity(self, acquisition):
146146
for oper in acquisition.operands:
147147
self.assertTrue(isinstance(oper, GPflowOpt.acquisition.Acquisition),
148148
msg="All operands should be an acquisition object")
149-
self.assertTrue(all(isinstance(m, GPflowOpt.models.ModelWrapper) for m in acquisition.models.sorted_params))
149+
150+
self.assertTrue(all(isinstance(m, GPflowOpt.models.ModelWrapper) for m in acquisition.models))
150151

151152
@parameterized.expand(list(zip(aggregations)))
152153
def test_data(self, acquisition):
@@ -219,7 +220,7 @@ def test_marginalized_score(self, acquisition):
219220

220221
@parameterized.expand(list(zip([aggregations[2]])))
221222
def test_mcmc_acq_models(self, acquisition):
222-
self.assertListEqual(acquisition.models.sorted_params, acquisition.operands[0].models.sorted_params)
223+
self.assertListEqual(acquisition.models, acquisition.operands[0].models)
223224

224225

225226
class TestJointAcquisition(unittest.TestCase):
@@ -298,3 +299,26 @@ def test_multi_aggr(self):
298299
joint = first * second
299300
self.assertIsInstance(joint, GPflowOpt.acquisition.AcquisitionProduct)
300301
self.assertListEqual(joint.operands.sorted_params, [acq1, acq2, acq3, acq4])
302+
303+
304+
class TestRecompile(unittest.TestCase):
305+
"""
306+
Regression test for #37
307+
"""
308+
def test_vgp(self):
309+
domain = GPflowOpt.domain.UnitCube(2)
310+
X = GPflowOpt.design.RandomDesign(10, domain).generate()
311+
Y = np.sin(X[:,[0]])
312+
m = GPflow.vgp.VGP(X, Y, GPflow.kernels.RBF(2), GPflow.likelihoods.Gaussian())
313+
acq = GPflowOpt.acquisition.ExpectedImprovement(m)
314+
m._compile()
315+
self.assertFalse(m._needs_recompile)
316+
acq.evaluate(GPflowOpt.design.RandomDesign(10, domain).generate())
317+
self.assertTrue(hasattr(acq, '_evaluate_AF_storage'))
318+
319+
Xnew = GPflowOpt.design.RandomDesign(5, domain).generate()
320+
Ynew = np.sin(Xnew[:,[0]])
321+
acq.set_data(np.vstack((X, Xnew)), np.vstack((Y, Ynew)))
322+
self.assertFalse(hasattr(acq, '_needs_recompile'))
323+
self.assertFalse(hasattr(acq, '_evaluate_AF_storage'))
324+
acq.evaluate(GPflowOpt.design.RandomDesign(10, domain).generate())

testing/test_datascaler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,4 @@ def test_predict_scaling(self):
106106
fs = n.predict_density(Xt, Yt)
107107
np.testing.assert_allclose(fr, fs, rtol=1e-2)
108108

109+

testing/test_modelwrapper.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,35 @@ def test_name(self):
111111
n = MethodOverride(create_parabola_model(GPflowOpt.domain.UnitCube(2)))
112112
self.assertEqual(n.name, 'unnamed.methodoverride')
113113

114+
def test_parent_hook(self):
115+
self.m.optimize(maxiter=5)
116+
w = GPflowOpt.models.ModelWrapper(self.m)
117+
self.assertTrue(isinstance(self.m.highest_parent, GPflowOpt.models.ParentHook))
118+
self.assertEqual(self.m.highest_parent._hp, w)
119+
self.assertEqual(self.m.highest_parent._hm, w)
120+
121+
w2 = GPflowOpt.models.ModelWrapper(w)
122+
self.assertEqual(self.m.highest_parent._hp, w2)
123+
self.assertEqual(self.m.highest_parent._hm, w2)
114124

125+
p = GPflow.param.Parameterized()
126+
p.model = w2
127+
self.assertEqual(self.m.highest_parent._hp, p)
128+
self.assertEqual(self.m.highest_parent._hm, w2)
115129

130+
p.predictor = create_parabola_model(GPflowOpt.domain.UnitCube(2))
131+
p.predictor.predict_f(p.predictor.X.value)
132+
self.assertTrue(hasattr(p.predictor, '_predict_f_AF_storage'))
133+
self.assertFalse(self.m._needs_recompile)
134+
self.m.highest_parent._needs_recompile = True
135+
self.assertFalse('_needs_recompile' in p.__dict__)
136+
self.assertFalse('_needs_recompile' in w.__dict__)
137+
self.assertFalse('_needs_recompile' in w2.__dict__)
138+
self.assertTrue(self.m._needs_recompile)
139+
self.assertFalse(hasattr(p.predictor, '_predict_f_AF_storage'))
140+
141+
self.assertEqual(self.m.highest_parent.get_free_state, p.get_free_state)
142+
self.m.highest_parent._needs_setup = True
143+
self.assertTrue(hasattr(p, '_needs_setup'))
144+
self.assertTrue(p._needs_setup)
116145

0 commit comments

Comments
 (0)