Skip to content

Commit 3f99f94

Browse files
committed
Add load_model() tests for all 3 frameworks
Tests the full roundtrip: serialize model → mock API artifact download → load_model() deserializes → verify usable model object. - sklearn: pickle.dumps → load_model → LogisticRegression with correct params - PyTorch: torch.save → load_model → nn.Module with working forward pass - TensorFlow: model.save(.keras) → load_model → working predict() - Unknown framework: raises ValueError This closes the last framework-aware SDK function that wasn't tested per model type. All framework-specific code paths now have coverage.
1 parent f9ee737 commit 3f99f94

1 file changed

Lines changed: 136 additions & 1 deletion

File tree

tests/python/sdk/test_models.py

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
"""Tests for register_model(), publish_version(), and ModelHandle."""
1+
"""Tests for register_model(), publish_version(), load_model(), and ModelHandle."""
22

33
import base64
4+
import io
45
import json
56
import os
7+
import pickle
68
import pytest
79
import responses
810
from unittest.mock import patch
@@ -286,3 +288,136 @@ def test_publish_version_with_artifact(self, client, mock_api, tmp_path):
286288
# Verify the base64 encoding round-trips correctly
287289
decoded = base64.b64decode(body["artifact_data"])
288290
assert decoded == b"\x80\x04\x95\x00\x00\x00\x00"
291+
292+
293+
# ---------------------------------------------------------------------------
294+
# load_model — framework-specific deserialization
295+
# ---------------------------------------------------------------------------
296+
297+
298+
class TestLoadModelSklearn:
299+
"""load_model() deserializes sklearn models via pickle."""
300+
301+
def test_load_model_sklearn(self, client, mock_api, sklearn_model):
302+
"""Roundtrip: serialize sklearn model → mock API → load_model → usable estimator."""
303+
model_bytes = pickle.dumps(sklearn_model)
304+
305+
# Mock resolve endpoint
306+
mock_api.add(
307+
responses.GET,
308+
f"{TEST_API_URL}/sdk/models/resolve/my-clf",
309+
json={"id": "model-sk-001", "framework": "sklearn"},
310+
status=200,
311+
)
312+
# Mock artifact download
313+
mock_api.add(
314+
responses.GET,
315+
f"{TEST_API_URL}/sdk/models/model-sk-001/artifact",
316+
body=model_bytes,
317+
status=200,
318+
content_type="application/octet-stream",
319+
)
320+
321+
loaded = client.load_model("my-clf")
322+
323+
assert type(loaded).__name__ == "LogisticRegression"
324+
# Verify it has the same params as the original
325+
assert loaded.max_iter == sklearn_model.max_iter
326+
327+
328+
class TestLoadModelPytorch:
329+
"""load_model() deserializes PyTorch models via torch.load."""
330+
331+
@pytest.mark.skipif(
332+
not pytest.importorskip("torch", reason="torch required"),
333+
reason="torch not available",
334+
)
335+
def test_load_model_pytorch(self, client, mock_api, pytorch_model):
336+
"""Roundtrip: serialize pytorch model → mock API → load_model → usable nn.Module."""
337+
import torch
338+
339+
buf = io.BytesIO()
340+
torch.save(pytorch_model, buf)
341+
model_bytes = buf.getvalue()
342+
343+
mock_api.add(
344+
responses.GET,
345+
f"{TEST_API_URL}/sdk/models/resolve/my-net",
346+
json={"id": "model-pt-001", "framework": "pytorch"},
347+
status=200,
348+
)
349+
mock_api.add(
350+
responses.GET,
351+
f"{TEST_API_URL}/sdk/models/model-pt-001/artifact",
352+
body=model_bytes,
353+
status=200,
354+
content_type="application/octet-stream",
355+
)
356+
357+
loaded = client.load_model("my-net", device="cpu")
358+
359+
assert isinstance(loaded, torch.nn.Module)
360+
# Verify forward pass works
361+
x = torch.randn(1, 4)
362+
out = loaded(x)
363+
assert out.shape == (1, 2)
364+
365+
366+
class TestLoadModelTensorflow:
367+
"""load_model() deserializes TF/Keras models via keras.load_model."""
368+
369+
def test_load_model_tensorflow(self, client, mock_api, tf_model):
370+
"""Roundtrip: serialize keras model → mock API → load_model → usable model."""
371+
import tempfile, os
372+
keras = pytest.importorskip("keras")
373+
374+
# Serialize the Keras model to .keras bytes
375+
tmpfile = tempfile.mktemp(suffix=".keras")
376+
tf_model.save(tmpfile)
377+
with open(tmpfile, "rb") as f:
378+
model_bytes = f.read()
379+
os.unlink(tmpfile)
380+
381+
mock_api.add(
382+
responses.GET,
383+
f"{TEST_API_URL}/sdk/models/resolve/my-keras",
384+
json={"id": "model-tf-001", "framework": "tensorflow"},
385+
status=200,
386+
)
387+
mock_api.add(
388+
responses.GET,
389+
f"{TEST_API_URL}/sdk/models/model-tf-001/artifact",
390+
body=model_bytes,
391+
status=200,
392+
content_type="application/octet-stream",
393+
)
394+
395+
loaded = client.load_model("my-keras")
396+
397+
assert loaded is not None
398+
# Verify predict works
399+
import numpy as np
400+
out = loaded.predict(np.array([[1.0, 2.0, 3.0, 4.0]]), verbose=0)
401+
assert out.shape == (1, 2)
402+
403+
404+
class TestLoadModelUnsupported:
405+
"""load_model() raises for unknown frameworks."""
406+
407+
def test_load_model_unsupported_framework(self, client, mock_api):
408+
mock_api.add(
409+
responses.GET,
410+
f"{TEST_API_URL}/sdk/models/resolve/my-model",
411+
json={"id": "model-unk-001", "framework": "julia"},
412+
status=200,
413+
)
414+
mock_api.add(
415+
responses.GET,
416+
f"{TEST_API_URL}/sdk/models/model-unk-001/artifact",
417+
body=b"some bytes",
418+
status=200,
419+
content_type="application/octet-stream",
420+
)
421+
422+
with pytest.raises(ValueError, match="Unsupported framework"):
423+
client.load_model("my-model")

0 commit comments

Comments
 (0)