Skip to content

Commit 23340af

Browse files
authored
Merge pull request #232 from njzjz-bothub/pr-5181-production-tests
test(oom): return floating mock outputs
2 parents a1ba195 + bd871c3 commit 23340af

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

source/tests/common/test_oom_retry.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def _make_backend(self, backend: str, method_name: str) -> tuple[Any, MagicMock]
8989
DeepEval.__abstractmethods__ = abstract_methods
9090

9191
model = MagicMock()
92-
model.eval_descriptor.return_value = np.array([1, 2, 3])
93-
model.eval_fitting_last_layer.return_value = np.array([4, 5, 6])
92+
model.eval_descriptor.return_value = np.array([1.0, 2.0, 3.0])
93+
model.eval_fitting_last_layer.return_value = np.array([4.0, 5.0, 6.0])
9494

9595
if backend == "pd" and method_name == "eval_descriptor":
9696
# Paddle eval_descriptor accepts either a ModelWrapper or a direct model.
@@ -154,7 +154,7 @@ def test_pt_eval_descriptor_retry_clears_hook_between_attempts(self) -> None:
154154
"pt",
155155
"eval_descriptor",
156156
"set_eval_descriptor_hook",
157-
np.array([1, 2, 3]),
157+
np.array([1.0, 2.0, 3.0]),
158158
)
159159

160160
def test_pt_eval_fitting_last_layer_retry_clears_hook_between_attempts(
@@ -164,15 +164,15 @@ def test_pt_eval_fitting_last_layer_retry_clears_hook_between_attempts(
164164
"pt",
165165
"eval_fitting_last_layer",
166166
"set_eval_fitting_last_layer_hook",
167-
np.array([4, 5, 6]),
167+
np.array([4.0, 5.0, 6.0]),
168168
)
169169

170170
def test_pd_eval_descriptor_retry_clears_hook_between_attempts(self) -> None:
171171
self._assert_retry_clears_hook_between_attempts(
172172
"pd",
173173
"eval_descriptor",
174174
"set_eval_descriptor_hook",
175-
np.array([1, 2, 3]),
175+
np.array([1.0, 2.0, 3.0]),
176176
)
177177

178178
def test_pd_eval_fitting_last_layer_retry_clears_hook_between_attempts(
@@ -182,7 +182,7 @@ def test_pd_eval_fitting_last_layer_retry_clears_hook_between_attempts(
182182
"pd",
183183
"eval_fitting_last_layer",
184184
"set_eval_fitting_last_layer_hook",
185-
np.array([4, 5, 6]),
185+
np.array([4.0, 5.0, 6.0]),
186186
)
187187

188188
def test_pt_eval_descriptor_runtime_error_clears_state(self) -> None:

0 commit comments

Comments
 (0)