Skip to content

Commit 0069493

Browse files
committed
data input tests
1 parent f22c1d2 commit 0069493

1 file changed

Lines changed: 46 additions & 0 deletions

File tree

tests/test_data_input.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,49 @@ def test_single_datapoint_identity(x: Any) -> None:
9494
assert len(x) == len(_rho)
9595
else:
9696
assert len(_rho) == 1
97+
98+
99+
@pytest.mark.parametrize("dy", [
100+
rng.rand(3, 3),
101+
[0.1, 0.2, 0.3],
102+
0.1,
103+
np.array([0.1, 0.2, 0.3]),
104+
np.array([0.1, 0.2, 0.3]).reshape(-1, 1),
105+
np.array([0.1, 0.2, 0.3]).reshape(1, -1)
106+
])
107+
def test_multiple_dy_inputs(dy: Any) -> None:
108+
x = np.array([1, 2, 3])
109+
y = np.array([1, 2, 3])
110+
data = {
111+
'x': x,
112+
'y': y,
113+
'dy': dy
114+
}
115+
116+
kernel = fp.kernels.RadialBasisFunction(0.5, 0.3)
117+
constraints = [fp.constraints.LinearEquality(fp.operators.Identity(), data)]
118+
model = fp.models.GaussianProcess(kernel, constraints)
119+
_rho, _ = model.predict(x)
120+
121+
assert len(x) == len(_rho)
122+
123+
124+
@pytest.mark.parametrize("data_input", [
125+
{'x': [1, 2, 3], 'y': [1, 2, 3], 'dy': [0.1, 0.2, 0.3]},
126+
[[1, 2, 3], [1, 2, 3], [0.1, 0.2, 0.3]],
127+
np.array([[1, 2, 3], [1, 2, 3], [0.1, 0.2, 0.3]])
128+
])
129+
def test_different_data_inputs(data_input: Any) -> None:
130+
if isinstance(data_input, dict):
131+
x = data_input['x']
132+
else:
133+
x = data_input[0]
134+
135+
data = data_input
136+
137+
kernel = fp.kernels.RadialBasisFunction(0.5, 0.3)
138+
constraints = [fp.constraints.LinearEquality(fp.operators.Identity(), data)]
139+
model = fp.models.GaussianProcess(kernel, constraints)
140+
_rho, _ = model.predict(x)
141+
142+
assert len(x) == len(_rho)

0 commit comments

Comments
 (0)