Skip to content

Commit 811ea4c

Browse files
authored
Merge pull request #64 from GPflow/model_wrapper_class
ModelWrapper
2 parents bbe9347 + 35cd171 commit 811ea4c

9 files changed

Lines changed: 260 additions & 54 deletions

File tree

GPflowOpt/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@
2121
from . import scaling
2222
from . import objective
2323
from . import pareto
24+
from . import models
2425

2526
from ._version import __version__

GPflowOpt/acquisition/acquisition.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
from ..scaling import DataScaler
1616
from ..domain import UnitCube
17+
from ..models import ModelWrapper
1718

1819
from GPflow.param import Parameterized, AutoFlow, ParamList
20+
from GPflow.model import Model
1921
from GPflow import settings
2022

2123
import numpy as np
@@ -48,7 +50,9 @@ def __init__(self, models=[], optimize_restarts=5):
4850
:param optimize_restarts: number of optimization restarts to use when training the models
4951
"""
5052
super(Acquisition, self).__init__()
51-
self._models = ParamList([DataScaler(m) for m in np.atleast_1d(models).tolist()])
53+
models = np.atleast_1d(models)
54+
assert all(isinstance(model, (Model, ModelWrapper)) for model in models)
55+
self._models = ParamList([DataScaler(m) for m in models])
5256

5357
assert (optimize_restarts >= 0)
5458
self.optimize_restarts = optimize_restarts

GPflowOpt/acquisition/ei.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def __init__(self, model):
5757
:param model: GPflow model (single output) representing our belief of the objective
5858
"""
5959
super(ExpectedImprovement, self).__init__(model)
60-
assert (isinstance(model, Model))
6160
self.fmin = DataHolder(np.zeros(1))
6261
self.setup()
6362

GPflowOpt/models.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2017 Joachim van der Herten
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from GPflow.param import Parameterized
15+
from GPflow.model import Model
16+
17+
18+
class ModelWrapper(Parameterized):
19+
"""
20+
Class for fast implementation of a wrapper for models defined in GPflow.
21+
22+
Once wrapped, all lookups for attributes which are not found in the wrapper class are automatically forwarded
23+
to the wrapped model. To influence the I/O of methods on the wrapped class, simply implement the method in the
24+
wrapper and call the appropriate methods on the wrapped class. Specific logic is included to make sure that if
25+
AutoFlow methods are influenced following this pattern, the original AF storage (if existing) is unaffected and a
26+
new storage is added to the subclass.
27+
"""
28+
def __init__(self, model):
29+
"""
30+
:param model: model to be wrapped
31+
"""
32+
super(ModelWrapper, self).__init__()
33+
34+
assert isinstance(model, (Model, ModelWrapper))
35+
#: Wrapped model
36+
self.wrapped = model
37+
38+
def __getattr__(self, item):
39+
"""
40+
If an attribute is not found in this class, it is searched in the wrapped model
41+
"""
42+
# Exception for AF storages, if a method with the same name exists in this class, do not find the cache
43+
# in the wrapped model.
44+
if item.endswith('_AF_storage'):
45+
method = item[1:].rstrip('_AF_storage')
46+
if method in dir(self):
47+
raise AttributeError("{0} has no attribute {1}".format(self.__class__.__name__, item))
48+
return getattr(self.wrapped, item)
49+
50+
def __setattr__(self, key, value):
51+
"""
52+
1) If setting :attr:`wrapped` attribute, point parent to this object (the ModelWrapper).
53+
2) Setting attributes in the right objects. The following rules are processed in order:
54+
(a) If attribute exists in wrapper, set in wrapper.
55+
(b) If no object has been wrapped (wrapper is None), set attribute in the wrapper.
56+
(c) If attribute is found in the wrapped object, set it there. This rule is ignored for AF storages.
57+
(d) Set attribute in wrapper.
58+
"""
59+
if key is 'wrapped':
60+
object.__setattr__(self, key, value)
61+
value.__setattr__('_parent', self)
62+
return
63+
64+
try:
65+
# If attribute is in this object, set it. Test by using getattribute instead of hasattr to avoid lookup in
66+
# wrapped object.
67+
self.__getattribute__(key)
68+
super(ModelWrapper, self).__setattr__(key, value)
69+
except AttributeError:
70+
# Attribute is not in wrapper.
71+
# In case no wrapped object is set yet (e.g. constructor), set in wrapper.
72+
if 'wrapped' not in self.__dict__:
73+
super(ModelWrapper, self).__setattr__(key, value)
74+
return
75+
76+
if hasattr(self, key):
77+
# Now use hasattr, we know getattribute already failed so if it returns true, it must be in the wrapped
78+
# object. Hasattr is called on self instead of self.wrapped to account for the different handling of
79+
# AF storages.
80+
# Prefer setting the attribute in the wrapped object if exists.
81+
setattr(self.wrapped, key, value)
82+
else:
83+
# If not, set in wrapper nonetheless.
84+
super(ModelWrapper, self).__setattr__(key, value)
85+
86+
def __eq__(self, other):
87+
return self.wrapped == other
88+
89+
@Parameterized.name.getter
90+
def name(self):
91+
name = super(ModelWrapper, self).name
92+
return ".".join([name, str.lower(self.__class__.__name__)])

GPflowOpt/scaling.py

Lines changed: 27 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from GPflow.param import DataHolder, AutoFlow, Parameterized
16-
from GPflow.model import Model, GPModel
15+
from GPflow.param import DataHolder, AutoFlow
16+
from GPflow.model import GPModel
1717
from GPflow import settings
1818
import numpy as np
1919
from .transforms import LinearTransform, DataTransform
2020
from .domain import UnitCube
21+
from .models import ModelWrapper
2122

2223
float_type = settings.dtypes.float_type
2324

2425

25-
class DataScaler(GPModel):
26+
class DataScaler(ModelWrapper):
2627
"""
27-
Model-wrapping class, primarily intended to assure the data in GPflow models is scaled. One DataScaler wraps one
28-
GPflow model, and can scale the input as well as the output data. By default, if any kind of object attribute
29-
is not found in the datascaler object, it is searched on the wrapped model.
28+
Model-wrapping class, primarily intended to assure the data in GPflow models is scaled.
29+
30+
One DataScaler wraps one GPflow model, and can scale the input as well as the output data. By default,
31+
if any kind of object attribute is not found in the datascaler object, it is searched on the wrapped model.
3032
3133
The datascaler supports both input as well as output scaling, although both scalings are set up differently:
3234
@@ -59,13 +61,8 @@ def __init__(self, model, domain=None, normalize_Y=False):
5961
:param normalize_Y: (default: False) enable automatic scaling of output values to zero mean and unit
6062
variance.
6163
"""
62-
# model sanity checks
63-
assert (model is not None)
64-
assert (isinstance(model, GPModel))
65-
self._parent = None
66-
67-
# Wrap model
68-
self.wrapped = model
64+
# model sanity checks, slightly stronger conditions than the wrapper
65+
super(DataScaler, self).__init__(model)
6966

7067
# Initial configuration of the datascaler
7168
n_inputs = model.X.shape[1]
@@ -74,34 +71,8 @@ def __init__(self, model, domain=None, normalize_Y=False):
7471
self._normalize_Y = normalize_Y
7572
self._output_transform = LinearTransform(np.ones(n_outputs), np.zeros(n_outputs))
7673

77-
# The assignments in the constructor of GPModel take care of initial re-scaling of model data.
78-
super(DataScaler, self).__init__(model.X.value, model.Y.value, None, None, 1, name=model.name+"_datascaler")
79-
del self.kern
80-
del self.mean_function
81-
del self.likelihood
82-
83-
def __getattr__(self, item):
84-
"""
85-
If an attribute is not found in this class, it is searched in the wrapped model
86-
"""
87-
return self.wrapped.__getattribute__(item)
88-
89-
def __setattr__(self, key, value):
90-
"""
91-
If setting :attr:`wrapped` attribute, point parent to this object (the datascaler)
92-
"""
93-
if key is 'wrapped':
94-
object.__setattr__(self, key, value)
95-
value.__setattr__('_parent', self)
96-
return
97-
98-
super(DataScaler, self).__setattr__(key, value)
99-
100-
def __eq__(self, other):
101-
return self.wrapped == other
102-
103-
def __str__(self, prepend=''):
104-
return self.wrapped.__str__(prepend)
74+
self.X = model.X.value
75+
self.Y = model.Y.value
10576

10677
@property
10778
def input_transform(self):
@@ -216,6 +187,20 @@ def build_predict(self, Xnew, full_cov=False):
216187
f, var = self.wrapped.build_predict(self.input_transform.build_forward(Xnew), full_cov=full_cov)
217188
return self.output_transform.build_backward(f), self.output_transform.build_backward_variance(var)
218189

190+
@AutoFlow((float_type, [None, None]))
191+
def predict_f(self, Xnew):
192+
"""
193+
Compute the mean and variance of held-out data at the points Xnew
194+
"""
195+
return self.build_predict(Xnew)
196+
197+
@AutoFlow((float_type, [None, None]))
198+
def predict_f_full_cov(self, Xnew):
199+
"""
200+
Compute the mean and variance of held-out data at the points Xnew
201+
"""
202+
return self.build_predict(Xnew, full_cov=True)
203+
219204
@AutoFlow((float_type, [None, None]))
220205
def predict_y(self, Xnew):
221206
"""
@@ -230,6 +215,6 @@ def predict_density(self, Xnew, Ynew):
230215
"""
231216
Compute the (log) density of the data Ynew at the points Xnew
232217
"""
233-
mu, var = self.build_predict(Xnew)
218+
mu, var = self.wrapped.build_predict(self.input_transform.build_forward(Xnew))
234219
Ys = self.output_transform.build_forward(Ynew)
235220
return self.likelihood.predict_density(mu, var, Ys)

GPflowOpt/transforms.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ def __invert__(self):
6161
"""
6262
raise NotImplementedError
6363

64-
def __str__(self):
65-
raise NotImplementedError
66-
6764

6865
class LinearTransform(DataTransform):
6966
"""
@@ -155,5 +152,3 @@ def __invert__(self):
155152
A_inv = np.linalg.inv(self.A.value.T)
156153
return LinearTransform(A_inv, -np.dot(self.b.value, A_inv))
157154

158-
def __str__(self):
159-
return 'XA + b'

doc/source/interfaces.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,11 @@ Transform
3636
:special-members:
3737
.. autoclass:: GPflowOpt.transforms.DataTransform
3838
:special-members:
39+
40+
ModelWrapper
41+
------------
42+
.. automodule:: GPflowOpt.models
43+
:special-members:
44+
.. autoclass:: GPflowOpt.models.ModelWrapper
45+
:members:
46+
:special-members:

testing/test_datascaler.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ def test_object_integrity(self):
2626
Xs, Ys = m.X.value, m.Y.value
2727
n = DataScaler(m, self.domain)
2828

29-
self.assertEqual(n.wrapped, m)
30-
self.assertEqual(m._parent, n)
3129
self.assertTrue(np.allclose(Xs, n.X.value))
3230
self.assertTrue(np.allclose(Ys, n.Y.value))
3331

@@ -80,7 +78,7 @@ def test_enabling_transforms(self):
8078

8179
def test_predict_scaling(self):
8280
m = self.create_parabola_model()
83-
n = DataScaler(self.create_parabola_model(), self.domain)
81+
n = DataScaler(self.create_parabola_model(), self.domain, normalize_Y=True)
8482
m.optimize()
8583
n.optimize()
8684

@@ -100,7 +98,8 @@ def test_predict_scaling(self):
10098
self.assertTrue(np.allclose(fr, fs, atol=1e-3))
10199
self.assertTrue(np.allclose(vr, vs, atol=1e-3))
102100

103-
Yt = parabola2d(Xt) #+ np.random.rand(20, 1) * 0.05
101+
Yt = parabola2d(Xt)
104102
fr = m.predict_density(Xt, Yt)
105103
fs = n.predict_density(Xt, Yt)
106-
np.testing.assert_allclose(fr, fs, rtol=1e-3)
104+
np.testing.assert_allclose(fr, fs, rtol=1e-2)
105+

0 commit comments

Comments (0)