|
1 | | -"""Tests for register_model(), publish_version(), and ModelHandle.""" |
| 1 | +"""Tests for register_model(), publish_version(), load_model(), and ModelHandle.""" |
2 | 2 |
|
3 | 3 | import base64 |
| 4 | +import io |
4 | 5 | import json |
5 | 6 | import os |
| 7 | +import pickle |
6 | 8 | import pytest |
7 | 9 | import responses |
8 | 10 | from unittest.mock import patch |
@@ -286,3 +288,136 @@ def test_publish_version_with_artifact(self, client, mock_api, tmp_path): |
286 | 288 | # Verify the base64 encoding round-trips correctly |
287 | 289 | decoded = base64.b64decode(body["artifact_data"]) |
288 | 290 | assert decoded == b"\x80\x04\x95\x00\x00\x00\x00" |
| 291 | + |
| 292 | + |
| 293 | +# --------------------------------------------------------------------------- |
| 294 | +# load_model — framework-specific deserialization |
| 295 | +# --------------------------------------------------------------------------- |
| 296 | + |
| 297 | + |
| 298 | +class TestLoadModelSklearn: |
| 299 | + """load_model() deserializes sklearn models via pickle.""" |
| 300 | + |
| 301 | + def test_load_model_sklearn(self, client, mock_api, sklearn_model): |
| 302 | + """Roundtrip: serialize sklearn model → mock API → load_model → usable estimator.""" |
| 303 | + model_bytes = pickle.dumps(sklearn_model) |
| 304 | + |
| 305 | + # Mock resolve endpoint |
| 306 | + mock_api.add( |
| 307 | + responses.GET, |
| 308 | + f"{TEST_API_URL}/sdk/models/resolve/my-clf", |
| 309 | + json={"id": "model-sk-001", "framework": "sklearn"}, |
| 310 | + status=200, |
| 311 | + ) |
| 312 | + # Mock artifact download |
| 313 | + mock_api.add( |
| 314 | + responses.GET, |
| 315 | + f"{TEST_API_URL}/sdk/models/model-sk-001/artifact", |
| 316 | + body=model_bytes, |
| 317 | + status=200, |
| 318 | + content_type="application/octet-stream", |
| 319 | + ) |
| 320 | + |
| 321 | + loaded = client.load_model("my-clf") |
| 322 | + |
| 323 | + assert type(loaded).__name__ == "LogisticRegression" |
| 324 | + # Verify it has the same params as the original |
| 325 | + assert loaded.max_iter == sklearn_model.max_iter |
| 326 | + |
| 327 | + |
| 328 | +class TestLoadModelPytorch: |
| 329 | + """load_model() deserializes PyTorch models via torch.load.""" |
| 330 | + |
| 331 | + @pytest.mark.skipif( |
| 332 | + not pytest.importorskip("torch", reason="torch required"), |
| 333 | + reason="torch not available", |
| 334 | + ) |
| 335 | + def test_load_model_pytorch(self, client, mock_api, pytorch_model): |
| 336 | + """Roundtrip: serialize pytorch model → mock API → load_model → usable nn.Module.""" |
| 337 | + import torch |
| 338 | + |
| 339 | + buf = io.BytesIO() |
| 340 | + torch.save(pytorch_model, buf) |
| 341 | + model_bytes = buf.getvalue() |
| 342 | + |
| 343 | + mock_api.add( |
| 344 | + responses.GET, |
| 345 | + f"{TEST_API_URL}/sdk/models/resolve/my-net", |
| 346 | + json={"id": "model-pt-001", "framework": "pytorch"}, |
| 347 | + status=200, |
| 348 | + ) |
| 349 | + mock_api.add( |
| 350 | + responses.GET, |
| 351 | + f"{TEST_API_URL}/sdk/models/model-pt-001/artifact", |
| 352 | + body=model_bytes, |
| 353 | + status=200, |
| 354 | + content_type="application/octet-stream", |
| 355 | + ) |
| 356 | + |
| 357 | + loaded = client.load_model("my-net", device="cpu") |
| 358 | + |
| 359 | + assert isinstance(loaded, torch.nn.Module) |
| 360 | + # Verify forward pass works |
| 361 | + x = torch.randn(1, 4) |
| 362 | + out = loaded(x) |
| 363 | + assert out.shape == (1, 2) |
| 364 | + |
| 365 | + |
| 366 | +class TestLoadModelTensorflow: |
| 367 | + """load_model() deserializes TF/Keras models via keras.load_model.""" |
| 368 | + |
| 369 | + def test_load_model_tensorflow(self, client, mock_api, tf_model): |
| 370 | + """Roundtrip: serialize keras model → mock API → load_model → usable model.""" |
| 371 | + import tempfile, os |
| 372 | + keras = pytest.importorskip("keras") |
| 373 | + |
| 374 | + # Serialize the Keras model to .keras bytes |
| 375 | + tmpfile = tempfile.mktemp(suffix=".keras") |
| 376 | + tf_model.save(tmpfile) |
| 377 | + with open(tmpfile, "rb") as f: |
| 378 | + model_bytes = f.read() |
| 379 | + os.unlink(tmpfile) |
| 380 | + |
| 381 | + mock_api.add( |
| 382 | + responses.GET, |
| 383 | + f"{TEST_API_URL}/sdk/models/resolve/my-keras", |
| 384 | + json={"id": "model-tf-001", "framework": "tensorflow"}, |
| 385 | + status=200, |
| 386 | + ) |
| 387 | + mock_api.add( |
| 388 | + responses.GET, |
| 389 | + f"{TEST_API_URL}/sdk/models/model-tf-001/artifact", |
| 390 | + body=model_bytes, |
| 391 | + status=200, |
| 392 | + content_type="application/octet-stream", |
| 393 | + ) |
| 394 | + |
| 395 | + loaded = client.load_model("my-keras") |
| 396 | + |
| 397 | + assert loaded is not None |
| 398 | + # Verify predict works |
| 399 | + import numpy as np |
| 400 | + out = loaded.predict(np.array([[1.0, 2.0, 3.0, 4.0]]), verbose=0) |
| 401 | + assert out.shape == (1, 2) |
| 402 | + |
| 403 | + |
| 404 | +class TestLoadModelUnsupported: |
| 405 | + """load_model() raises for unknown frameworks.""" |
| 406 | + |
| 407 | + def test_load_model_unsupported_framework(self, client, mock_api): |
| 408 | + mock_api.add( |
| 409 | + responses.GET, |
| 410 | + f"{TEST_API_URL}/sdk/models/resolve/my-model", |
| 411 | + json={"id": "model-unk-001", "framework": "julia"}, |
| 412 | + status=200, |
| 413 | + ) |
| 414 | + mock_api.add( |
| 415 | + responses.GET, |
| 416 | + f"{TEST_API_URL}/sdk/models/model-unk-001/artifact", |
| 417 | + body=b"some bytes", |
| 418 | + status=200, |
| 419 | + content_type="application/octet-stream", |
| 420 | + ) |
| 421 | + |
| 422 | + with pytest.raises(ValueError, match="Unsupported framework"): |
| 423 | + client.load_model("my-model") |
0 commit comments