1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from GPflow .param import DataHolder , AutoFlow , Parameterized
16- from GPflow .model import Model , GPModel
15+ from GPflow .param import DataHolder , AutoFlow
16+ from GPflow .model import GPModel
1717from GPflow import settings
1818import numpy as np
1919from .transforms import LinearTransform , DataTransform
2020from .domain import UnitCube
21+ from .models import ModelWrapper
2122
2223float_type = settings .dtypes .float_type
2324
2425
25- class DataScaler (GPModel ):
26+ class DataScaler (ModelWrapper ):
2627 """
27- Model-wrapping class, primarily intended to assure the data in GPflow models is scaled. One DataScaler wraps one
28- GPflow model, and can scale the input as well as the output data. By default, if any kind of object attribute
29- is not found in the datascaler object, it is searched on the wrapped model.
28+ Model-wrapping class, primarily intended to assure the data in GPflow models is scaled.
29+
30+ One DataScaler wraps one GPflow model, and can scale the input as well as the output data. By default,
31+ if any kind of object attribute is not found in the datascaler object, it is searched on the wrapped model.
3032
3133 The datascaler supports both input as well as output scaling, although both scalings are set up differently:
3234
@@ -59,13 +61,8 @@ def __init__(self, model, domain=None, normalize_Y=False):
5961 :param normalize_Y: (default: False) enable automatic scaling of output values to zero mean and unit
6062 variance.
6163 """
62- # model sanity checks
63- assert (model is not None )
64- assert (isinstance (model , GPModel ))
65- self ._parent = None
66-
67- # Wrap model
68- self .wrapped = model
64+ # model sanity checks, slightly stronger conditions than the wrapper
65+ super (DataScaler , self ).__init__ (model )
6966
7067 # Initial configuration of the datascaler
7168 n_inputs = model .X .shape [1 ]
@@ -74,34 +71,8 @@ def __init__(self, model, domain=None, normalize_Y=False):
7471 self ._normalize_Y = normalize_Y
7572 self ._output_transform = LinearTransform (np .ones (n_outputs ), np .zeros (n_outputs ))
7673
77- # The assignments in the constructor of GPModel take care of initial re-scaling of model data.
78- super (DataScaler , self ).__init__ (model .X .value , model .Y .value , None , None , 1 , name = model .name + "_datascaler" )
79- del self .kern
80- del self .mean_function
81- del self .likelihood
82-
83- def __getattr__ (self , item ):
84- """
85- If an attribute is not found in this class, it is searched in the wrapped model
86- """
87- return self .wrapped .__getattribute__ (item )
88-
89- def __setattr__ (self , key , value ):
90- """
91- If setting :attr:`wrapped` attribute, point parent to this object (the datascaler)
92- """
93- if key is 'wrapped' :
94- object .__setattr__ (self , key , value )
95- value .__setattr__ ('_parent' , self )
96- return
97-
98- super (DataScaler , self ).__setattr__ (key , value )
99-
100- def __eq__ (self , other ):
101- return self .wrapped == other
102-
103- def __str__ (self , prepend = '' ):
104- return self .wrapped .__str__ (prepend )
74+ self .X = model .X .value
75+ self .Y = model .Y .value
10576
10677 @property
10778 def input_transform (self ):
@@ -216,6 +187,20 @@ def build_predict(self, Xnew, full_cov=False):
216187 f , var = self .wrapped .build_predict (self .input_transform .build_forward (Xnew ), full_cov = full_cov )
217188 return self .output_transform .build_backward (f ), self .output_transform .build_backward_variance (var )
218189
190+ @AutoFlow ((float_type , [None , None ]))
191+ def predict_f (self , Xnew ):
192+ """
193+ Compute the mean and variance of held-out data at the points Xnew
194+ """
195+ return self .build_predict (Xnew )
196+
197+ @AutoFlow ((float_type , [None , None ]))
198+ def predict_f_full_cov (self , Xnew ):
199+ """
200+ Compute the mean and variance of held-out data at the points Xnew
201+ """
202+ return self .build_predict (Xnew , full_cov = True )
203+
219204 @AutoFlow ((float_type , [None , None ]))
220205 def predict_y (self , Xnew ):
221206 """
@@ -230,6 +215,6 @@ def predict_density(self, Xnew, Ynew):
230215 """
231216 Compute the (log) density of the data Ynew at the points Xnew
232217 """
233- mu , var = self .build_predict (Xnew )
218+ mu , var = self .wrapped . build_predict (self . input_transform . build_forward ( Xnew ) )
234219 Ys = self .output_transform .build_forward (Ynew )
235220 return self .likelihood .predict_density (mu , var , Ys )
0 commit comments