From 228240a09fe53e5ce9630a1bab301f6fcfab1348 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:09:46 -0400 Subject: [PATCH 1/5] fix: bug: ModelBuilder overwrites user-provided HF_MODEL_ID for DJL Serving, preventi (5529) --- .../sagemaker/serve/model_builder_servers.py | 14 +- .../test_model_builder_servers_hf_model_id.py | 275 ++++++++++++++++++ 2 files changed, 282 insertions(+), 7 deletions(-) create mode 100644 sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index 43af8b4f7a..48b8e0b307 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -136,7 +136,7 @@ def _build_for_torchserve(self) -> Model: if isinstance(self.model, str): # Configure HuggingFace model support if not self._is_jumpstart_model_id(): - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) # Add HuggingFace token if available if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): @@ -212,7 +212,7 @@ def _build_for_tgi(self) -> Model: if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for TGI - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") @@ -320,7 +320,7 @@ def _build_for_djl(self) -> Model: if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for DJL - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) # Get model configuration for DJL optimization self.hf_model_config = _get_model_config_properties_from_hf( @@ -426,7 +426,7 @@ def _build_for_triton(self) -> Model: self.env_vars.update({"HF_TASK": model_task}) # Configure HuggingFace authentication - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") @@ -532,7 +532,7 @@ def _build_for_tei(self) -> Model: if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for TEI - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") @@ -676,7 +676,7 @@ def _build_for_transformers(self) -> Model: if self.inference_spec is not None: hf_model_id = self.inference_spec.get_model() if isinstance(hf_model_id, str): # Only if it's a valid HF model ID - self.env_vars.update({"HF_MODEL_ID": hf_model_id}) + self.env_vars.setdefault("HF_MODEL_ID", hf_model_id) # Get HF config only for string model IDs if hasattr(self.env_vars, "HF_API_TOKEN"): self.hf_model_config = _get_model_config_properties_from_hf( @@ -687,7 +687,7 @@ def _build_for_transformers(self) -> Model: hf_model_id, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) elif isinstance(self.model, str): # Only set HF_MODEL_ID if model is a string - self.env_vars.update({"HF_MODEL_ID": self.model}) + self.env_vars.setdefault("HF_MODEL_ID", self.model) # Get HF config for string model IDs if hasattr(self.env_vars, "HF_API_TOKEN"): self.hf_model_config = _get_model_config_properties_from_hf( diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py new file mode 100644 index 0000000000..0b06dd9877 --- /dev/null +++ b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py @@ -0,0 +1,275 @@ +"""Unit tests to verify HF_MODEL_ID is not overwritten when user provides it.""" +import unittest +from unittest.mock import Mock, patch, MagicMock, PropertyMock + +from sagemaker.serve.model_builder_servers import _ModelBuilderServers +from sagemaker.serve.utils.types import ModelServer +from sagemaker.serve.mode.function_pointers import Mode + + +def _create_mock_builder(env_vars=None, model="Qwen/Qwen3-VL-4B-Instruct"): + """Create a mock builder with common attributes set.""" + builder = MagicMock(spec=_ModelBuilderServers) + builder.model = model + builder.env_vars = env_vars if env_vars is not None else {} + builder.model_path = "/tmp/test_model_path" + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder.model_server = ModelServer.DJL_SERVING + builder.secret_key = "" + builder.s3_upload_path = None + builder.s3_model_data_url = None + builder.shared_libs = [] + builder.dependencies = {} + builder.image_uri = "test-image-uri" + builder.instance_type = "ml.g5.2xlarge" + builder.sagemaker_session = Mock() + builder.schema_builder = MagicMock() + builder.schema_builder.sample_input = {"inputs": "Hello", "parameters": {}} + builder.inference_spec = None + builder.hf_model_config = {} + builder.model_data_download_timeout = None + builder._user_provided_instance_type = True + builder._is_jumpstart_model_id = Mock(return_value=False) + builder._auto_detect_image_uri = Mock() + builder._prepare_for_mode = Mock(return_value=("s3://model-data", None)) + builder._create_model = Mock(return_value=Mock()) + builder._optimizing = False + builder._validate_djl_serving_sample_data = Mock() + builder._validate_tgi_serving_sample_data = Mock() + builder._validate_for_triton = Mock() + builder.get_huggingface_model_metadata = Mock(return_value={"pipeline_tag": "text-generation"}) + builder.role_arn = "arn:aws:iam::123456789012:role/SageMakerRole" + return builder + + +class TestDjlPreservesHfModelId(unittest.TestCase): + """Test that _build_for_djl preserves user-provided HF_MODEL_ID.""" + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) + def test_preserves_user_provided_s3_uri(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" + mock_hf_config.return_value = {} + mock_djl_config.return_value = ({}, 256) + + s3_path = "s3://my-bucket/models/Qwen/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + + with patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_djl(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) + def test_sets_hf_model_id_when_not_provided(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + mock_hf_config.return_value = {} + mock_djl_config.return_value = ({}, 256) + + builder = _create_mock_builder(env_vars={}) + + with patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_djl(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + +class TestTgiPreservesHfModelId(unittest.TestCase): + """Test that _build_for_tgi preserves user-provided HF_MODEL_ID.""" + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) + def test_preserves_user_provided_s3_uri(self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" + mock_hf_config.return_value = {} + mock_tgi_config.return_value = ({}, 256) + + s3_path = "s3://my-bucket/models/Qwen/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.TGI + + with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_tgi(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) + def test_sets_hf_model_id_when_not_provided(self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + mock_hf_config.return_value = {} + mock_tgi_config.return_value = ({}, 256) + + builder = _create_mock_builder(env_vars={}) + builder.model_server = ModelServer.TGI + + with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_tgi(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + +class TestTeiPreservesHfModelId(unittest.TestCase): + """Test that _build_for_tei preserves user-provided HF_MODEL_ID.""" + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + def test_preserves_user_provided_s3_uri(self, mock_nb, mock_hf_config): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" + mock_hf_config.return_value = {} + + s3_path = "s3://my-bucket/models/embedding-model/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.TEI + + with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_tei(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + def test_sets_hf_model_id_when_not_provided(self, mock_nb, mock_hf_config): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + mock_hf_config.return_value = {} + + builder = _create_mock_builder(env_vars={}) + builder.model_server = ModelServer.TEI + + with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_tei(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + +class TestTorchservePreservesHfModelId(unittest.TestCase): + """Test that _build_for_torchserve preserves user-provided HF_MODEL_ID.""" + + def test_preserves_user_provided_s3_uri(self): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" + s3_path = "s3://my-bucket/models/my-model/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.TORCHSERVE + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder._save_model_inference_spec = Mock() + + _ModelBuilderServers._build_for_torchserve(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + def test_sets_hf_model_id_when_not_provided(self): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + builder = _create_mock_builder(env_vars={}) + builder.model_server = ModelServer.TORCHSERVE + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder._save_model_inference_spec = Mock() + + _ModelBuilderServers._build_for_torchserve(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + +class TestTritonPreservesHfModelId(unittest.TestCase): + """Test that _build_for_triton preserves user-provided HF_MODEL_ID.""" + + def test_preserves_user_provided_s3_uri(self): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" + s3_path = "s3://my-bucket/models/my-model/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.TRITON + builder._save_inference_spec = Mock() + builder._prepare_for_triton = Mock() + builder._auto_detect_image_for_triton = Mock() + + _ModelBuilderServers._build_for_triton(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + def test_sets_hf_model_id_when_not_provided(self): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + builder = _create_mock_builder(env_vars={}) + builder.model_server = ModelServer.TRITON + builder._save_inference_spec = Mock() + builder._prepare_for_triton = Mock() + builder._auto_detect_image_for_triton = Mock() + + _ModelBuilderServers._build_for_triton(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + +class TestTransformersPreservesHfModelId(unittest.TestCase): + """Test that _build_for_transformers preserves user-provided HF_MODEL_ID.""" + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + def test_preserves_user_provided_s3_uri_with_model_string(self, mock_nb, mock_hf_config): + """User-provided S3 URI for HF_MODEL_ID should not be overwritten when model is a string.""" + mock_hf_config.return_value = {} + + s3_path = "s3://my-bucket/models/my-model/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.MMS + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder.model_data_download_timeout = None + + with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_transformers(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + def test_sets_hf_model_id_when_not_provided_with_model_string(self, mock_nb, mock_hf_config): + """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" + mock_hf_config.return_value = {} + + builder = _create_mock_builder(env_vars={}) + builder.model_server = ModelServer.MMS + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder.model_data_download_timeout = None + + with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): + _ModelBuilderServers._build_for_transformers(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) + @patch("sagemaker.serve.model_builder_servers.save_pkl") + def test_preserves_user_provided_hf_model_id_with_inference_spec(self, mock_pkl, mock_nb, mock_hf_config): + """User-provided HF_MODEL_ID should not be overwritten when inference_spec provides a model ID.""" + mock_hf_config.return_value = {} + + s3_path = "s3://my-bucket/models/my-model/" + builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + builder.model_server = ModelServer.MMS + builder.mode = Mode.SAGEMAKER_ENDPOINT + builder.model_data_download_timeout = None + builder.model = None # No model string, using inference_spec + builder.inference_spec = Mock() + builder.inference_spec.get_model.return_value = "some-hf-model-id" + builder._is_jumpstart_model_id = Mock(return_value=False) + + with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): + with patch("os.makedirs"): + _ModelBuilderServers._build_for_transformers(builder) + + self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) + + +if __name__ == "__main__": + unittest.main() From 4badbdd00bf3e8bbf76eec0aa96d71b7626d2221 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:49:41 -0400 Subject: [PATCH 2/5] fix: address review comments (iteration #1) --- .../test_model_builder_servers_hf_model_id.py | 521 +++++++++++------- 1 file changed, 325 insertions(+), 196 deletions(-) diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py index 0b06dd9877..1af9891cc5 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py +++ b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py @@ -1,13 +1,20 @@ """Unit tests to verify HF_MODEL_ID is not overwritten when user provides it.""" -import unittest -from unittest.mock import Mock, patch, MagicMock, PropertyMock +from __future__ import annotations + +from typing import Optional +from unittest.mock import Mock, patch, MagicMock + +import pytest from sagemaker.serve.model_builder_servers import _ModelBuilderServers from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode -def _create_mock_builder(env_vars=None, model="Qwen/Qwen3-VL-4B-Instruct"): +def _create_mock_builder( + env_vars: Optional[dict[str, str]] = None, + model: str = "Qwen/Qwen3-VL-4B-Instruct", +) -> MagicMock: """Create a mock builder with common attributes set.""" builder = MagicMock(spec=_ModelBuilderServers) builder.model = model @@ -24,252 +31,374 @@ def _create_mock_builder(env_vars=None, model="Qwen/Qwen3-VL-4B-Instruct"): builder.instance_type = "ml.g5.2xlarge" builder.sagemaker_session = Mock() builder.schema_builder = MagicMock() - builder.schema_builder.sample_input = {"inputs": "Hello", "parameters": {}} + builder.schema_builder.sample_input = { + "inputs": "Hello", + "parameters": {}, + } builder.inference_spec = None builder.hf_model_config = {} builder.model_data_download_timeout = None builder._user_provided_instance_type = True builder._is_jumpstart_model_id = Mock(return_value=False) builder._auto_detect_image_uri = Mock() - builder._prepare_for_mode = Mock(return_value=("s3://model-data", None)) + builder._prepare_for_mode = Mock( + return_value=("s3://model-data", None) + ) builder._create_model = Mock(return_value=Mock()) builder._optimizing = False builder._validate_djl_serving_sample_data = Mock() builder._validate_tgi_serving_sample_data = Mock() builder._validate_for_triton = Mock() - builder.get_huggingface_model_metadata = Mock(return_value={"pipeline_tag": "text-generation"}) - builder.role_arn = "arn:aws:iam::123456789012:role/SageMakerRole" + builder.get_huggingface_model_metadata = Mock( + return_value={"pipeline_tag": "text-generation"} + ) + builder.role_arn = ( + "arn:aws:iam::123456789012:role/SageMakerRole" + ) return builder -class TestDjlPreservesHfModelId(unittest.TestCase): - """Test that _build_for_djl preserves user-provided HF_MODEL_ID.""" - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) - @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) - def test_preserves_user_provided_s3_uri(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" - mock_hf_config.return_value = {} - mock_djl_config.return_value = ({}, 256) - - s3_path = "s3://my-bucket/models/Qwen/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) - - with patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"): +@pytest.fixture +def mock_builder() -> MagicMock: + """Create a mock builder with default (empty) env_vars.""" + return _create_mock_builder(env_vars={}) + + +@pytest.fixture +def mock_builder_with_s3() -> MagicMock: + """Create a mock builder with user-provided S3 HF_MODEL_ID.""" + return _create_mock_builder( + env_vars={"HF_MODEL_ID": "s3://my-bucket/models/Qwen/"} + ) + + +S3_PATH = "s3://my-bucket/models/Qwen/" +DEFAULT_MODEL = "Qwen/Qwen3-VL-4B-Instruct" + + +# --------------------------------------------------------------------------- +# DJL Serving +# --------------------------------------------------------------------------- +class TestBuildForDjlHfModelId: + """Test _build_for_djl preserves user-provided HF_MODEL_ID.""" + + _patches = [ + patch( + "sagemaker.serve.model_builder_servers" + "._get_default_tensor_parallel_degree", + return_value=1, + ), + patch( + "sagemaker.serve.model_builder_servers._get_gpu_info", + return_value=1, + ), + patch( + "sagemaker.serve.model_builder_servers._get_nb_instance", + return_value=None, + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_default_djl_configurations", + return_value=({}, 256), + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + return_value={}, + ), + patch( + "sagemaker.serve.model_server.djl_serving" + ".prepare._create_dir_structure", + ), + ] + + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + """User-provided S3 URI should not be overwritten.""" + builder = mock_builder_with_s3 + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_djl(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) - @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) - def test_sets_hf_model_id_when_not_provided(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - mock_hf_config.return_value = {} - mock_djl_config.return_value = ({}, 256) - - builder = _create_mock_builder(env_vars={}) - - with patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"): + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + + def test_sets_default_when_not_provided( + self, mock_builder + ): + """HF_MODEL_ID should default to self.model.""" + builder = mock_builder + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_djl(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") - - -class TestTgiPreservesHfModelId(unittest.TestCase): - """Test that _build_for_tgi preserves user-provided HF_MODEL_ID.""" - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) - @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) - def test_preserves_user_provided_s3_uri(self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" - mock_hf_config.return_value = {} - mock_tgi_config.return_value = ({}, 256) - - s3_path = "s3://my-bucket/models/Qwen/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + + +# --------------------------------------------------------------------------- +# TGI +# --------------------------------------------------------------------------- +class TestBuildForTgiHfModelId: + """Test _build_for_tgi preserves user-provided HF_MODEL_ID.""" + + _patches = [ + patch( + "sagemaker.serve.model_builder_servers" + "._get_default_tensor_parallel_degree", + return_value=1, + ), + patch( + "sagemaker.serve.model_builder_servers._get_gpu_info", + return_value=1, + ), + patch( + "sagemaker.serve.model_builder_servers._get_nb_instance", + return_value=None, + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_default_tgi_configurations", + return_value=({}, 256), + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + return_value={}, + ), + patch( + "sagemaker.serve.model_server.tgi" + ".prepare._create_dir_structure", + ), + ] + + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + builder = mock_builder_with_s3 builder.model_server = ModelServer.TGI - - with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_tgi(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) - @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) - def test_sets_hf_model_id_when_not_provided(self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - mock_hf_config.return_value = {} - mock_tgi_config.return_value = ({}, 256) - - builder = _create_mock_builder(env_vars={}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + + def test_sets_default_when_not_provided( + self, mock_builder + ): + builder = mock_builder builder.model_server = ModelServer.TGI - - with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_tgi(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") - - -class TestTeiPreservesHfModelId(unittest.TestCase): - """Test that _build_for_tei preserves user-provided HF_MODEL_ID.""" - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - def test_preserves_user_provided_s3_uri(self, mock_nb, mock_hf_config): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" - mock_hf_config.return_value = {} - - s3_path = "s3://my-bucket/models/embedding-model/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + + +# --------------------------------------------------------------------------- +# TEI +# --------------------------------------------------------------------------- +class TestBuildForTeiHfModelId: + """Test _build_for_tei preserves user-provided HF_MODEL_ID.""" + + _patches = [ + patch( + "sagemaker.serve.model_builder_servers._get_nb_instance", + return_value=None, + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + return_value={}, + ), + patch( + "sagemaker.serve.model_server.tgi" + ".prepare._create_dir_structure", + ), + ] + + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + builder = mock_builder_with_s3 builder.model_server = ModelServer.TEI - - with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_tei(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - def test_sets_hf_model_id_when_not_provided(self, mock_nb, mock_hf_config): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - mock_hf_config.return_value = {} - - builder = _create_mock_builder(env_vars={}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + + def test_sets_default_when_not_provided( + self, mock_builder + ): + builder = mock_builder builder.model_server = ModelServer.TEI - - with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_tei(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") - - -class TestTorchservePreservesHfModelId(unittest.TestCase): - """Test that _build_for_torchserve preserves user-provided HF_MODEL_ID.""" - - def test_preserves_user_provided_s3_uri(self): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" - s3_path = "s3://my-bucket/models/my-model/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + + +# --------------------------------------------------------------------------- +# TorchServe +# --------------------------------------------------------------------------- +class TestBuildForTorchserveHfModelId: + """Test _build_for_torchserve preserves user-provided HF_MODEL_ID.""" + + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + builder = mock_builder_with_s3 builder.model_server = ModelServer.TORCHSERVE - builder.mode = Mode.SAGEMAKER_ENDPOINT builder._save_model_inference_spec = Mock() - _ModelBuilderServers._build_for_torchserve(builder) + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - def test_sets_hf_model_id_when_not_provided(self): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - builder = _create_mock_builder(env_vars={}) + def test_sets_default_when_not_provided( + self, mock_builder + ): + builder = mock_builder builder.model_server = ModelServer.TORCHSERVE - builder.mode = Mode.SAGEMAKER_ENDPOINT builder._save_model_inference_spec = Mock() - _ModelBuilderServers._build_for_torchserve(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL -class TestTritonPreservesHfModelId(unittest.TestCase): - """Test that _build_for_triton preserves user-provided HF_MODEL_ID.""" +# --------------------------------------------------------------------------- +# Triton +# --------------------------------------------------------------------------- +class TestBuildForTritonHfModelId: + """Test _build_for_triton preserves user-provided HF_MODEL_ID.""" - def test_preserves_user_provided_s3_uri(self): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" - s3_path = "s3://my-bucket/models/my-model/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + builder = mock_builder_with_s3 builder.model_server = ModelServer.TRITON builder._save_inference_spec = Mock() builder._prepare_for_triton = Mock() builder._auto_detect_image_for_triton = Mock() - _ModelBuilderServers._build_for_triton(builder) + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - def test_sets_hf_model_id_when_not_provided(self): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - builder = _create_mock_builder(env_vars={}) + def test_sets_default_when_not_provided( + self, mock_builder + ): + builder = mock_builder builder.model_server = ModelServer.TRITON builder._save_inference_spec = Mock() builder._prepare_for_triton = Mock() builder._auto_detect_image_for_triton = Mock() - _ModelBuilderServers._build_for_triton(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") - - -class TestTransformersPreservesHfModelId(unittest.TestCase): - """Test that _build_for_transformers preserves user-provided HF_MODEL_ID.""" - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - def test_preserves_user_provided_s3_uri_with_model_string(self, mock_nb, mock_hf_config): - """User-provided S3 URI for HF_MODEL_ID should not be overwritten when model is a string.""" - mock_hf_config.return_value = {} - - s3_path = "s3://my-bucket/models/my-model/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + + +# --------------------------------------------------------------------------- +# Transformers (MMS) +# --------------------------------------------------------------------------- +class TestBuildForTransformersHfModelId: + """Test _build_for_transformers preserves user-provided HF_MODEL_ID.""" + + _patches = [ + patch( + "sagemaker.serve.model_builder_servers._get_nb_instance", + return_value=None, + ), + patch( + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + return_value={}, + ), + patch( + "sagemaker.serve.model_server.multi_model_server" + ".prepare._create_dir_structure", + ), + ] + + def test_preserves_user_provided_s3_uri( + self, mock_builder_with_s3 + ): + builder = mock_builder_with_s3 builder.model_server = ModelServer.MMS - builder.mode = Mode.SAGEMAKER_ENDPOINT builder.model_data_download_timeout = None - - with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_transformers(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) - def test_sets_hf_model_id_when_not_provided_with_model_string(self, mock_nb, mock_hf_config): - """HF_MODEL_ID should be set from self.model when user doesn't provide it.""" - mock_hf_config.return_value = {} - - builder = _create_mock_builder(env_vars={}) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + + def test_sets_default_when_not_provided( + self, mock_builder + ): + builder = mock_builder builder.model_server = ModelServer.MMS - builder.mode = Mode.SAGEMAKER_ENDPOINT builder.model_data_download_timeout = None - - with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): + for p in self._patches: + p.start() + try: _ModelBuilderServers._build_for_transformers(builder) + finally: + for p in self._patches: + p.stop() + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL - self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct") - - @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") - @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) @patch("sagemaker.serve.model_builder_servers.save_pkl") - def test_preserves_user_provided_hf_model_id_with_inference_spec(self, mock_pkl, mock_nb, mock_hf_config): - """User-provided HF_MODEL_ID should not be overwritten when inference_spec provides a model ID.""" - mock_hf_config.return_value = {} - - s3_path = "s3://my-bucket/models/my-model/" - builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path}) + @patch( + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + return_value={}, + ) + @patch( + "sagemaker.serve.model_builder_servers._get_nb_instance", + return_value=None, + ) + @patch( + "sagemaker.serve.model_server.multi_model_server" + ".prepare._create_dir_structure", + ) + @patch("os.makedirs") + def test_preserves_with_inference_spec( + self, + _mock_makedirs, + _mock_dir, + _mock_nb, + _mock_hf, + _mock_pkl, + ): + """User-provided HF_MODEL_ID preserved with inference_spec.""" + builder = _create_mock_builder( + env_vars={"HF_MODEL_ID": S3_PATH} + ) builder.model_server = ModelServer.MMS - builder.mode = Mode.SAGEMAKER_ENDPOINT builder.model_data_download_timeout = None - builder.model = None # No model string, using inference_spec + builder.model = None builder.inference_spec = Mock() - builder.inference_spec.get_model.return_value = "some-hf-model-id" - builder._is_jumpstart_model_id = Mock(return_value=False) - - with patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"): - with patch("os.makedirs"): - _ModelBuilderServers._build_for_transformers(builder) - - self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path) - - -if __name__ == "__main__": - unittest.main() + builder.inference_spec.get_model.return_value = ( + "some-hf-model-id" + ) + builder._is_jumpstart_model_id = Mock( + return_value=False + ) + _ModelBuilderServers._build_for_transformers(builder) + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH From 22e4363e90941377ddba119c5e65c6a27c80cd46 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:42:55 -0700 Subject: [PATCH 3/5] fix: address review comments (iteration #1) --- .../test_model_builder_servers_hf_model_id.py | 527 +++++++++--------- 1 file changed, 250 insertions(+), 277 deletions(-) diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py index 1af9891cc5..66ae83a0f8 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py +++ b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py @@ -1,7 +1,7 @@ -"""Unit tests to verify HF_MODEL_ID is not overwritten when user provides it.""" +"""Unit tests: HF_MODEL_ID is not overwritten when user provides it.""" from __future__ import annotations -from typing import Optional +from typing import Dict, List, Optional from unittest.mock import Mock, patch, MagicMock import pytest @@ -11,9 +11,13 @@ from sagemaker.serve.mode.function_pointers import Mode +S3_PATH = "s3://my-bucket/models/Qwen/" +DEFAULT_MODEL = "Qwen/Qwen3-VL-4B-Instruct" + + def _create_mock_builder( - env_vars: Optional[dict[str, str]] = None, - model: str = "Qwen/Qwen3-VL-4B-Instruct", + env_vars: Optional[Dict[str, str]] = None, + model: str = DEFAULT_MODEL, ) -> MagicMock: """Create a mock builder with common attributes set.""" builder = MagicMock(spec=_ModelBuilderServers) @@ -49,6 +53,10 @@ def _create_mock_builder( builder._validate_djl_serving_sample_data = Mock() builder._validate_tgi_serving_sample_data = Mock() builder._validate_for_triton = Mock() + builder._save_model_inference_spec = Mock() + builder._save_inference_spec = Mock() + builder._prepare_for_triton = Mock() + builder._auto_detect_image_for_triton = Mock() builder.get_huggingface_model_metadata = Mock( return_value={"pipeline_tag": "text-generation"} ) @@ -68,299 +76,263 @@ def mock_builder() -> MagicMock: def mock_builder_with_s3() -> MagicMock: """Create a mock builder with user-provided S3 HF_MODEL_ID.""" return _create_mock_builder( - env_vars={"HF_MODEL_ID": "s3://my-bucket/models/Qwen/"} + env_vars={"HF_MODEL_ID": S3_PATH} ) -S3_PATH = "s3://my-bucket/models/Qwen/" -DEFAULT_MODEL = "Qwen/Qwen3-VL-4B-Instruct" - - -# --------------------------------------------------------------------------- -# DJL Serving -# --------------------------------------------------------------------------- -class TestBuildForDjlHfModelId: - """Test _build_for_djl preserves user-provided HF_MODEL_ID.""" - - _patches = [ - patch( - "sagemaker.serve.model_builder_servers" - "._get_default_tensor_parallel_degree", - return_value=1, - ), - patch( - "sagemaker.serve.model_builder_servers._get_gpu_info", - return_value=1, - ), - patch( - "sagemaker.serve.model_builder_servers._get_nb_instance", - return_value=None, +# -- Patch sets for each server type ---------------------------------- + +_DJL_PATCHES: List[str] = [ + "sagemaker.serve.model_builder_servers" + "._get_default_tensor_parallel_degree", + "sagemaker.serve.model_builder_servers._get_gpu_info", + "sagemaker.serve.model_builder_servers._get_nb_instance", + "sagemaker.serve.model_builder_servers" + "._get_default_djl_configurations", + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + "sagemaker.serve.model_server.djl_serving" + ".prepare._create_dir_structure", +] + +_DJL_RETURN_VALUES = [ + 1, # tensor_parallel_degree + 1, # gpu_info + None, # nb_instance + ({}, 256), # djl_configurations + {}, # hf_model_config + None, # _create_dir_structure +] + +_TGI_PATCHES: List[str] = [ + "sagemaker.serve.model_builder_servers" + "._get_default_tensor_parallel_degree", + "sagemaker.serve.model_builder_servers._get_gpu_info", + "sagemaker.serve.model_builder_servers._get_nb_instance", + "sagemaker.serve.model_builder_servers" + "._get_default_tgi_configurations", + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + "sagemaker.serve.model_server.tgi" + ".prepare._create_dir_structure", +] + +_TGI_RETURN_VALUES = [ + 1, # tensor_parallel_degree + 1, # gpu_info + None, # nb_instance + ({}, 256), # tgi_configurations + {}, # hf_model_config + None, # _create_dir_structure +] + +_TEI_PATCHES: List[str] = [ + "sagemaker.serve.model_builder_servers._get_nb_instance", + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + "sagemaker.serve.model_server.tgi" + ".prepare._create_dir_structure", +] + +_TEI_RETURN_VALUES = [ + None, # nb_instance + {}, # hf_model_config + None, # _create_dir_structure +] + +_MMS_PATCHES: List[str] = [ + "sagemaker.serve.model_builder_servers._get_nb_instance", + "sagemaker.serve.model_builder_servers" + "._get_model_config_properties_from_hf", + "sagemaker.serve.model_server.multi_model_server" + ".prepare._create_dir_structure", +] + +_MMS_RETURN_VALUES = [ + None, # nb_instance + {}, # hf_model_config + None, # _create_dir_structure +] + + +def _apply_patches( + targets: List[str], + return_values: list, +) -> List: + """Start patches and return the list of patchers.""" + patchers = [] + for target, rv in zip(targets, return_values): + p = patch(target, return_value=rv) + p.start() + patchers.append(p) + return patchers + + +def _stop_patches(patchers: List) -> None: + """Stop all patchers.""" + for p in patchers: + p.stop() + + +# --------------------------------------------------------------- +# Parametrised tests: preserve user-provided HF_MODEL_ID +# --------------------------------------------------------------- +@pytest.mark.parametrize( + "build_method, server_type, patch_targets, patch_rvs", + [ + ( + "_build_for_djl", + ModelServer.DJL_SERVING, + _DJL_PATCHES, + _DJL_RETURN_VALUES, ), - patch( - "sagemaker.serve.model_builder_servers" - "._get_default_djl_configurations", - return_value=({}, 256), - ), - patch( - "sagemaker.serve.model_builder_servers" - "._get_model_config_properties_from_hf", - return_value={}, - ), - patch( - "sagemaker.serve.model_server.djl_serving" - ".prepare._create_dir_structure", - ), - ] - - def test_preserves_user_provided_s3_uri( - self, mock_builder_with_s3 - ): - """User-provided S3 URI should not be overwritten.""" - builder = mock_builder_with_s3 - for p in self._patches: - p.start() - try: - _ModelBuilderServers._build_for_djl(builder) - finally: - for p in self._patches: - p.stop() - assert builder.env_vars["HF_MODEL_ID"] == S3_PATH - - def test_sets_default_when_not_provided( - self, mock_builder - ): - """HF_MODEL_ID should default to self.model.""" - builder = mock_builder - for p in self._patches: - p.start() - try: - _ModelBuilderServers._build_for_djl(builder) - finally: - for p in self._patches: - p.stop() - assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL - - -# --------------------------------------------------------------------------- -# TGI -# --------------------------------------------------------------------------- -class TestBuildForTgiHfModelId: - """Test _build_for_tgi preserves user-provided HF_MODEL_ID.""" - - _patches = [ - patch( - "sagemaker.serve.model_builder_servers" - "._get_default_tensor_parallel_degree", - return_value=1, + ( + "_build_for_tgi", + ModelServer.TGI, + _TGI_PATCHES, + _TGI_RETURN_VALUES, ), - patch( - "sagemaker.serve.model_builder_servers._get_gpu_info", - return_value=1, + ( + "_build_for_tei", + ModelServer.TEI, + _TEI_PATCHES, + _TEI_RETURN_VALUES, ), - patch( - "sagemaker.serve.model_builder_servers._get_nb_instance", - return_value=None, + ( + "_build_for_torchserve", + ModelServer.TORCHSERVE, + [], + [], ), - patch( - "sagemaker.serve.model_builder_servers" - "._get_default_tgi_configurations", - return_value=({}, 256), + ( + "_build_for_triton", + ModelServer.TRITON, + [], + [], ), - patch( - "sagemaker.serve.model_builder_servers" - "._get_model_config_properties_from_hf", - return_value={}, + ], + ids=[ + "djl", + "tgi", + "tei", + "torchserve", + "triton", + ], +) +def test_preserves_user_provided_hf_model_id( + build_method: str, + server_type: ModelServer, + patch_targets: List[str], + patch_rvs: list, + mock_builder_with_s3: MagicMock, +) -> None: + """User-provided HF_MODEL_ID must not be overwritten.""" + builder = mock_builder_with_s3 + builder.model_server = server_type + patchers = _apply_patches(patch_targets, patch_rvs) + try: + getattr(_ModelBuilderServers, build_method)(builder) + finally: + _stop_patches(patchers) + assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + + +@pytest.mark.parametrize( + "build_method, server_type, patch_targets, patch_rvs", + [ + ( + "_build_for_djl", + ModelServer.DJL_SERVING, + _DJL_PATCHES, + _DJL_RETURN_VALUES, ), - patch( - "sagemaker.serve.model_server.tgi" - ".prepare._create_dir_structure", + ( + "_build_for_tgi", + ModelServer.TGI, + _TGI_PATCHES, + _TGI_RETURN_VALUES, ), - ] - - def test_preserves_user_provided_s3_uri( - self, mock_builder_with_s3 - ): - builder = mock_builder_with_s3 - builder.model_server = ModelServer.TGI - for p in self._patches: - p.start() - try: - _ModelBuilderServers._build_for_tgi(builder) - finally: - for p in self._patches: - p.stop() - assert builder.env_vars["HF_MODEL_ID"] == S3_PATH - - def test_sets_default_when_not_provided( - self, mock_builder - ): - builder = mock_builder - builder.model_server = ModelServer.TGI - for p in self._patches: - p.start() - try: - _ModelBuilderServers._build_for_tgi(builder) - finally: - for p in self._patches: - p.stop() - assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL - - -# --------------------------------------------------------------------------- -# TEI -# --------------------------------------------------------------------------- -class TestBuildForTeiHfModelId: - """Test _build_for_tei preserves user-provided HF_MODEL_ID.""" - - _patches = [ - patch( - "sagemaker.serve.model_builder_servers._get_nb_instance", - return_value=None, + ( + "_build_for_tei", + ModelServer.TEI, + _TEI_PATCHES, + _TEI_RETURN_VALUES, ), - patch( - "sagemaker.serve.model_builder_servers" - "._get_model_config_properties_from_hf", - return_value={}, + ( + "_build_for_torchserve", + ModelServer.TORCHSERVE, + [], + [], ), - patch( - "sagemaker.serve.model_server.tgi" - ".prepare._create_dir_structure", + ( + "_build_for_triton", + ModelServer.TRITON, + [], + [], ), - ] - - def test_preserves_user_provided_s3_uri( - self, mock_builder_with_s3 - ): - builder = mock_builder_with_s3 - builder.model_server = ModelServer.TEI - for p in self._patches: - p.start() - try: - _ModelBuilderServers._build_for_tei(builder) - finally: - for p in self._patches: - p.stop() - assert builder.env_vars["HF_MODEL_ID"] == S3_PATH - - def test_sets_default_when_not_provided( - self, mock_builder - ): - builder = mock_builder - builder.model_server = ModelServer.TEI - for p in self._patches: - p.start() - try: - _ModelBuilderServers._build_for_tei(builder) - finally: - for p in self._patches: - p.stop() - assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL - - -# --------------------------------------------------------------------------- -# TorchServe -# --------------------------------------------------------------------------- -class TestBuildForTorchserveHfModelId: - """Test _build_for_torchserve preserves user-provided HF_MODEL_ID.""" - - def test_preserves_user_provided_s3_uri( - self, mock_builder_with_s3 - ): - builder = mock_builder_with_s3 - builder.model_server = ModelServer.TORCHSERVE - builder._save_model_inference_spec = Mock() - _ModelBuilderServers._build_for_torchserve(builder) - assert builder.env_vars["HF_MODEL_ID"] == S3_PATH - - def test_sets_default_when_not_provided( - self, mock_builder - ): - builder = mock_builder - builder.model_server = ModelServer.TORCHSERVE - builder._save_model_inference_spec = Mock() - _ModelBuilderServers._build_for_torchserve(builder) - assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL - - -# --------------------------------------------------------------------------- -# Triton -# --------------------------------------------------------------------------- -class TestBuildForTritonHfModelId: - """Test _build_for_triton preserves user-provided HF_MODEL_ID.""" - - def test_preserves_user_provided_s3_uri( - self, mock_builder_with_s3 - ): - builder = mock_builder_with_s3 - builder.model_server = ModelServer.TRITON - builder._save_inference_spec = Mock() - builder._prepare_for_triton = Mock() - builder._auto_detect_image_for_triton = Mock() - _ModelBuilderServers._build_for_triton(builder) - assert builder.env_vars["HF_MODEL_ID"] == S3_PATH - - def test_sets_default_when_not_provided( - self, mock_builder - ): - builder = mock_builder - builder.model_server = ModelServer.TRITON - builder._save_inference_spec = Mock() - builder._prepare_for_triton = Mock() - builder._auto_detect_image_for_triton = Mock() - _ModelBuilderServers._build_for_triton(builder) - assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL - - -# --------------------------------------------------------------------------- -# Transformers (MMS) -# --------------------------------------------------------------------------- + ], + ids=[ + "djl", + "tgi", + "tei", + "torchserve", + "triton", + ], +) +def test_sets_default_hf_model_id_when_not_provided( + build_method: str, + server_type: ModelServer, + patch_targets: List[str], + patch_rvs: list, + mock_builder: MagicMock, +) -> None: + """HF_MODEL_ID should default to self.model.""" + builder = mock_builder + builder.model_server = server_type + patchers = _apply_patches(patch_targets, patch_rvs) + try: + getattr(_ModelBuilderServers, build_method)(builder) + finally: + _stop_patches(patchers) + assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + + +# --------------------------------------------------------------- +# Transformers (MMS) — needs extra patches for _create_dir_structure +# --------------------------------------------------------------- class TestBuildForTransformersHfModelId: - """Test _build_for_transformers preserves user-provided HF_MODEL_ID.""" - - _patches = [ - patch( - "sagemaker.serve.model_builder_servers._get_nb_instance", - return_value=None, - ), - patch( - "sagemaker.serve.model_builder_servers" - "._get_model_config_properties_from_hf", - return_value={}, - ), - patch( - "sagemaker.serve.model_server.multi_model_server" - ".prepare._create_dir_structure", - ), - ] + """_build_for_transformers preserves user-provided HF_MODEL_ID.""" def test_preserves_user_provided_s3_uri( - self, mock_builder_with_s3 - ): + self, mock_builder_with_s3: MagicMock + ) -> None: builder = mock_builder_with_s3 builder.model_server = ModelServer.MMS - builder.model_data_download_timeout = None - for p in self._patches: - p.start() + patchers = _apply_patches( + _MMS_PATCHES, _MMS_RETURN_VALUES + ) try: - _ModelBuilderServers._build_for_transformers(builder) + _ModelBuilderServers._build_for_transformers( + builder + ) finally: - for p in self._patches: - p.stop() + _stop_patches(patchers) assert builder.env_vars["HF_MODEL_ID"] == S3_PATH def test_sets_default_when_not_provided( - self, mock_builder - ): + self, mock_builder: MagicMock + ) -> None: builder = mock_builder builder.model_server = ModelServer.MMS - builder.model_data_download_timeout = None - for p in self._patches: - p.start() + patchers = _apply_patches( + _MMS_PATCHES, _MMS_RETURN_VALUES + ) try: - _ModelBuilderServers._build_for_transformers(builder) + _ModelBuilderServers._build_for_transformers( + builder + ) finally: - for p in self._patches: - p.stop() + _stop_patches(patchers) assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL @patch("sagemaker.serve.model_builder_servers.save_pkl") @@ -370,7 +342,8 @@ def test_sets_default_when_not_provided( return_value={}, ) @patch( - "sagemaker.serve.model_builder_servers._get_nb_instance", + "sagemaker.serve.model_builder_servers" + "._get_nb_instance", return_value=None, ) @patch( @@ -380,12 +353,12 @@ def test_sets_default_when_not_provided( @patch("os.makedirs") def test_preserves_with_inference_spec( self, - _mock_makedirs, - _mock_dir, - _mock_nb, - _mock_hf, - _mock_pkl, - ): + _mock_makedirs: Mock, + _mock_dir: Mock, + _mock_nb: Mock, + _mock_hf: Mock, + _mock_pkl: Mock, + ) -> None: """User-provided HF_MODEL_ID preserved with inference_spec.""" builder = _create_mock_builder( env_vars={"HF_MODEL_ID": S3_PATH} From ab7b3c8ae4a7e1c7aaf149b24b9f894beef81879 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:14:29 -0700 Subject: [PATCH 4/5] fix: address review comments (iteration #1) --- .../test_model_builder_servers_hf_model_id.py | 121 ++++++++++++------ 1 file changed, 80 insertions(+), 41 deletions(-) diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py index 66ae83a0f8..3de44e7ea1 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py +++ b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py @@ -6,7 +6,9 @@ import pytest -from sagemaker.serve.model_builder_servers import _ModelBuilderServers +from sagemaker.serve.model_builder_servers import ( + _ModelBuilderServers, +) from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode @@ -22,7 +24,9 @@ def _create_mock_builder( """Create a mock builder with common attributes set.""" builder = MagicMock(spec=_ModelBuilderServers) builder.model = model - builder.env_vars = env_vars if env_vars is not None else {} + builder.env_vars = ( + env_vars if env_vars is not None else {} + ) builder.model_path = "/tmp/test_model_path" builder.mode = Mode.SAGEMAKER_ENDPOINT builder.model_server = ModelServer.DJL_SERVING @@ -43,7 +47,9 @@ def _create_mock_builder( builder.hf_model_config = {} builder.model_data_download_timeout = None builder._user_provided_instance_type = True - builder._is_jumpstart_model_id = Mock(return_value=False) + builder._is_jumpstart_model_id = Mock( + return_value=False + ) builder._auto_detect_image_uri = Mock() builder._prepare_for_mode = Mock( return_value=("s3://model-data", None) @@ -68,25 +74,27 @@ def _create_mock_builder( @pytest.fixture def mock_builder() -> MagicMock: - """Create a mock builder with default (empty) env_vars.""" + """Create a mock builder with default env_vars.""" return _create_mock_builder(env_vars={}) @pytest.fixture def mock_builder_with_s3() -> MagicMock: - """Create a mock builder with user-provided S3 HF_MODEL_ID.""" + """Mock builder with user-provided S3 HF_MODEL_ID.""" return _create_mock_builder( env_vars={"HF_MODEL_ID": S3_PATH} ) -# -- Patch sets for each server type ---------------------------------- +# -- Patch targets for each server type -------------------------- _DJL_PATCHES: List[str] = [ "sagemaker.serve.model_builder_servers" "._get_default_tensor_parallel_degree", - "sagemaker.serve.model_builder_servers._get_gpu_info", - "sagemaker.serve.model_builder_servers._get_nb_instance", + "sagemaker.serve.model_builder_servers" + "._get_gpu_info", + "sagemaker.serve.model_builder_servers" + "._get_nb_instance", "sagemaker.serve.model_builder_servers" "._get_default_djl_configurations", "sagemaker.serve.model_builder_servers" @@ -96,19 +104,21 @@ def mock_builder_with_s3() -> MagicMock: ] _DJL_RETURN_VALUES = [ - 1, # tensor_parallel_degree - 1, # gpu_info - None, # nb_instance + 1, # tensor_parallel_degree + 1, # gpu_info + None, # nb_instance ({}, 256), # djl_configurations - {}, # hf_model_config - None, # _create_dir_structure + {}, # hf_model_config + None, # _create_dir_structure ] _TGI_PATCHES: List[str] = [ "sagemaker.serve.model_builder_servers" "._get_default_tensor_parallel_degree", - "sagemaker.serve.model_builder_servers._get_gpu_info", - "sagemaker.serve.model_builder_servers._get_nb_instance", + "sagemaker.serve.model_builder_servers" + "._get_gpu_info", + "sagemaker.serve.model_builder_servers" + "._get_nb_instance", "sagemaker.serve.model_builder_servers" "._get_default_tgi_configurations", "sagemaker.serve.model_builder_servers" @@ -118,16 +128,17 @@ def mock_builder_with_s3() -> MagicMock: ] _TGI_RETURN_VALUES = [ - 1, # tensor_parallel_degree - 1, # gpu_info - None, # nb_instance + 1, # tensor_parallel_degree + 1, # gpu_info + None, # nb_instance ({}, 256), # tgi_configurations - {}, # hf_model_config - None, # _create_dir_structure + {}, # hf_model_config + None, # _create_dir_structure ] _TEI_PATCHES: List[str] = [ - "sagemaker.serve.model_builder_servers._get_nb_instance", + "sagemaker.serve.model_builder_servers" + "._get_nb_instance", "sagemaker.serve.model_builder_servers" "._get_model_config_properties_from_hf", "sagemaker.serve.model_server.tgi" @@ -140,8 +151,21 @@ def mock_builder_with_s3() -> MagicMock: None, # _create_dir_structure ] +_TORCHSERVE_PATCHES: List[str] = [ + "sagemaker.serve.model_builder_servers" + ".prepare_for_torchserve", +] + +_TORCHSERVE_RETURN_VALUES = [ + "mock-secret-key", # prepare_for_torchserve +] + +_TRITON_PATCHES: List[str] = [] +_TRITON_RETURN_VALUES: list = [] + _MMS_PATCHES: List[str] = [ - "sagemaker.serve.model_builder_servers._get_nb_instance", + "sagemaker.serve.model_builder_servers" + "._get_nb_instance", "sagemaker.serve.model_builder_servers" "._get_model_config_properties_from_hf", "sagemaker.serve.model_server.multi_model_server" @@ -201,14 +225,14 @@ def _stop_patches(patchers: List) -> None: ( "_build_for_torchserve", ModelServer.TORCHSERVE, - [], - [], + _TORCHSERVE_PATCHES, + _TORCHSERVE_RETURN_VALUES, ), ( "_build_for_triton", ModelServer.TRITON, - [], - [], + _TRITON_PATCHES, + _TRITON_RETURN_VALUES, ), ], ids=[ @@ -231,7 +255,9 @@ def test_preserves_user_provided_hf_model_id( builder.model_server = server_type patchers = _apply_patches(patch_targets, patch_rvs) try: - getattr(_ModelBuilderServers, build_method)(builder) + getattr( + _ModelBuilderServers, build_method + )(builder) finally: _stop_patches(patchers) assert builder.env_vars["HF_MODEL_ID"] == S3_PATH @@ -261,14 +287,14 @@ def test_preserves_user_provided_hf_model_id( ( "_build_for_torchserve", ModelServer.TORCHSERVE, - [], - [], + _TORCHSERVE_PATCHES, + _TORCHSERVE_RETURN_VALUES, ), ( "_build_for_triton", ModelServer.TRITON, - [], - [], + _TRITON_PATCHES, + _TRITON_RETURN_VALUES, ), ], ids=[ @@ -291,21 +317,25 @@ def test_sets_default_hf_model_id_when_not_provided( builder.model_server = server_type patchers = _apply_patches(patch_targets, patch_rvs) try: - getattr(_ModelBuilderServers, build_method)(builder) + getattr( + _ModelBuilderServers, build_method + )(builder) finally: _stop_patches(patchers) assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL # --------------------------------------------------------------- -# Transformers (MMS) — needs extra patches for _create_dir_structure +# Transformers (MMS) — needs extra patches # --------------------------------------------------------------- class TestBuildForTransformersHfModelId: - """_build_for_transformers preserves user-provided HF_MODEL_ID.""" + """_build_for_transformers preserves HF_MODEL_ID.""" def test_preserves_user_provided_s3_uri( - self, mock_builder_with_s3: MagicMock + self, + mock_builder_with_s3: MagicMock, ) -> None: + """User S3 URI is preserved.""" builder = mock_builder_with_s3 builder.model_server = ModelServer.MMS patchers = _apply_patches( @@ -320,8 +350,10 @@ def test_preserves_user_provided_s3_uri( assert builder.env_vars["HF_MODEL_ID"] == S3_PATH def test_sets_default_when_not_provided( - self, mock_builder: MagicMock + self, + mock_builder: MagicMock, ) -> None: + """HF_MODEL_ID defaults to self.model.""" builder = mock_builder builder.model_server = ModelServer.MMS patchers = _apply_patches( @@ -333,9 +365,13 @@ def test_sets_default_when_not_provided( ) finally: _stop_patches(patchers) - assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + assert ( + builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + ) - @patch("sagemaker.serve.model_builder_servers.save_pkl") + @patch( + "sagemaker.serve.model_builder_servers.save_pkl" + ) @patch( "sagemaker.serve.model_builder_servers" "._get_model_config_properties_from_hf", @@ -347,7 +383,8 @@ def test_sets_default_when_not_provided( return_value=None, ) @patch( - "sagemaker.serve.model_server.multi_model_server" + "sagemaker.serve.model_server" + ".multi_model_server" ".prepare._create_dir_structure", ) @patch("os.makedirs") @@ -359,7 +396,7 @@ def test_preserves_with_inference_spec( _mock_hf: Mock, _mock_pkl: Mock, ) -> None: - """User-provided HF_MODEL_ID preserved with inference_spec.""" + """User HF_MODEL_ID preserved with inference_spec.""" builder = _create_mock_builder( env_vars={"HF_MODEL_ID": S3_PATH} ) @@ -373,5 +410,7 @@ def test_preserves_with_inference_spec( builder._is_jumpstart_model_id = Mock( return_value=False ) - _ModelBuilderServers._build_for_transformers(builder) + _ModelBuilderServers._build_for_transformers( + builder + ) assert builder.env_vars["HF_MODEL_ID"] == S3_PATH From 30503876aa98141b37e2d1ffdca1a248f70ece91 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 16 Apr 2026 14:09:21 -0700 Subject: [PATCH 5/5] fix: address review comments (iteration #2) --- .../test_model_builder_servers_hf_model_id.py | 244 ++++++++---------- 1 file changed, 105 insertions(+), 139 deletions(-) diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py index 3de44e7ea1..ac7974399c 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py +++ b/sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py @@ -16,6 +16,20 @@ S3_PATH = "s3://my-bucket/models/Qwen/" DEFAULT_MODEL = "Qwen/Qwen3-VL-4B-Instruct" +_MOD = "sagemaker.serve.model_builder_servers" +_DJL_PREP = ( + "sagemaker.serve.model_server" + ".djl_serving.prepare._create_dir_structure" +) +_TGI_PREP = ( + "sagemaker.serve.model_server" + ".tgi.prepare._create_dir_structure" +) +_MMS_PREP = ( + "sagemaker.serve.model_server" + ".multi_model_server.prepare._create_dir_structure" +) + def _create_mock_builder( env_vars: Optional[Dict[str, str]] = None, @@ -54,7 +68,9 @@ def _create_mock_builder( builder._prepare_for_mode = Mock( return_value=("s3://model-data", None) ) - builder._create_model = Mock(return_value=Mock()) + builder._create_model = Mock( + return_value=Mock() + ) builder._optimizing = False builder._validate_djl_serving_sample_data = Mock() builder._validate_tgi_serving_sample_data = Mock() @@ -86,21 +102,15 @@ def mock_builder_with_s3() -> MagicMock: ) -# -- Patch targets for each server type -------------------------- +# -- Patch targets for each server type ---------------------- _DJL_PATCHES: List[str] = [ - "sagemaker.serve.model_builder_servers" - "._get_default_tensor_parallel_degree", - "sagemaker.serve.model_builder_servers" - "._get_gpu_info", - "sagemaker.serve.model_builder_servers" - "._get_nb_instance", - "sagemaker.serve.model_builder_servers" - "._get_default_djl_configurations", - "sagemaker.serve.model_builder_servers" - "._get_model_config_properties_from_hf", - "sagemaker.serve.model_server.djl_serving" - ".prepare._create_dir_structure", + f"{_MOD}._get_default_tensor_parallel_degree", + f"{_MOD}._get_gpu_info", + f"{_MOD}._get_nb_instance", + f"{_MOD}._get_default_djl_configurations", + f"{_MOD}._get_model_config_properties_from_hf", + _DJL_PREP, ] _DJL_RETURN_VALUES = [ @@ -113,18 +123,12 @@ def mock_builder_with_s3() -> MagicMock: ] _TGI_PATCHES: List[str] = [ - "sagemaker.serve.model_builder_servers" - "._get_default_tensor_parallel_degree", - "sagemaker.serve.model_builder_servers" - "._get_gpu_info", - "sagemaker.serve.model_builder_servers" - "._get_nb_instance", - "sagemaker.serve.model_builder_servers" - "._get_default_tgi_configurations", - "sagemaker.serve.model_builder_servers" - "._get_model_config_properties_from_hf", - "sagemaker.serve.model_server.tgi" - ".prepare._create_dir_structure", + f"{_MOD}._get_default_tensor_parallel_degree", + f"{_MOD}._get_gpu_info", + f"{_MOD}._get_nb_instance", + f"{_MOD}._get_default_tgi_configurations", + f"{_MOD}._get_model_config_properties_from_hf", + _TGI_PREP, ] _TGI_RETURN_VALUES = [ @@ -137,12 +141,9 @@ def mock_builder_with_s3() -> MagicMock: ] _TEI_PATCHES: List[str] = [ - "sagemaker.serve.model_builder_servers" - "._get_nb_instance", - "sagemaker.serve.model_builder_servers" - "._get_model_config_properties_from_hf", - "sagemaker.serve.model_server.tgi" - ".prepare._create_dir_structure", + f"{_MOD}._get_nb_instance", + f"{_MOD}._get_model_config_properties_from_hf", + _TGI_PREP, ] _TEI_RETURN_VALUES = [ @@ -152,24 +153,20 @@ def mock_builder_with_s3() -> MagicMock: ] _TORCHSERVE_PATCHES: List[str] = [ - "sagemaker.serve.model_builder_servers" - ".prepare_for_torchserve", + f"{_MOD}.prepare_for_torchserve", ] _TORCHSERVE_RETURN_VALUES = [ - "mock-secret-key", # prepare_for_torchserve + "mock-secret-key", ] _TRITON_PATCHES: List[str] = [] _TRITON_RETURN_VALUES: list = [] _MMS_PATCHES: List[str] = [ - "sagemaker.serve.model_builder_servers" - "._get_nb_instance", - "sagemaker.serve.model_builder_servers" - "._get_model_config_properties_from_hf", - "sagemaker.serve.model_server.multi_model_server" - ".prepare._create_dir_structure", + f"{_MOD}._get_nb_instance", + f"{_MOD}._get_model_config_properties_from_hf", + _MMS_PREP, ] _MMS_RETURN_VALUES = [ @@ -198,50 +195,55 @@ def _stop_patches(patchers: List) -> None: p.stop() -# --------------------------------------------------------------- +# ----------------------------------------------------------- # Parametrised tests: preserve user-provided HF_MODEL_ID -# --------------------------------------------------------------- +# ----------------------------------------------------------- +_SERVER_PARAMS = [ + ( + "_build_for_djl", + ModelServer.DJL_SERVING, + _DJL_PATCHES, + _DJL_RETURN_VALUES, + ), + ( + "_build_for_tgi", + ModelServer.TGI, + _TGI_PATCHES, + _TGI_RETURN_VALUES, + ), + ( + "_build_for_tei", + ModelServer.TEI, + _TEI_PATCHES, + _TEI_RETURN_VALUES, + ), + ( + "_build_for_torchserve", + ModelServer.TORCHSERVE, + _TORCHSERVE_PATCHES, + _TORCHSERVE_RETURN_VALUES, + ), + ( + "_build_for_triton", + ModelServer.TRITON, + _TRITON_PATCHES, + _TRITON_RETURN_VALUES, + ), +] + +_SERVER_IDS = [ + "djl", + "tgi", + "tei", + "torchserve", + "triton", +] + + @pytest.mark.parametrize( "build_method, server_type, patch_targets, patch_rvs", - [ - ( - "_build_for_djl", - ModelServer.DJL_SERVING, - _DJL_PATCHES, - _DJL_RETURN_VALUES, - ), - ( - "_build_for_tgi", - ModelServer.TGI, - _TGI_PATCHES, - _TGI_RETURN_VALUES, - ), - ( - "_build_for_tei", - ModelServer.TEI, - _TEI_PATCHES, - _TEI_RETURN_VALUES, - ), - ( - "_build_for_torchserve", - ModelServer.TORCHSERVE, - _TORCHSERVE_PATCHES, - _TORCHSERVE_RETURN_VALUES, - ), - ( - "_build_for_triton", - ModelServer.TRITON, - _TRITON_PATCHES, - _TRITON_RETURN_VALUES, - ), - ], - ids=[ - "djl", - "tgi", - "tei", - "torchserve", - "triton", - ], + _SERVER_PARAMS, + ids=_SERVER_IDS, ) def test_preserves_user_provided_hf_model_id( build_method: str, @@ -265,45 +267,8 @@ def test_preserves_user_provided_hf_model_id( @pytest.mark.parametrize( "build_method, server_type, patch_targets, patch_rvs", - [ - ( - "_build_for_djl", - ModelServer.DJL_SERVING, - _DJL_PATCHES, - _DJL_RETURN_VALUES, - ), - ( - "_build_for_tgi", - ModelServer.TGI, - _TGI_PATCHES, - _TGI_RETURN_VALUES, - ), - ( - "_build_for_tei", - ModelServer.TEI, - _TEI_PATCHES, - _TEI_RETURN_VALUES, - ), - ( - "_build_for_torchserve", - ModelServer.TORCHSERVE, - _TORCHSERVE_PATCHES, - _TORCHSERVE_RETURN_VALUES, - ), - ( - "_build_for_triton", - ModelServer.TRITON, - _TRITON_PATCHES, - _TRITON_RETURN_VALUES, - ), - ], - ids=[ - "djl", - "tgi", - "tei", - "torchserve", - "triton", - ], + _SERVER_PARAMS, + ids=_SERVER_IDS, ) def test_sets_default_hf_model_id_when_not_provided( build_method: str, @@ -322,12 +287,14 @@ def test_sets_default_hf_model_id_when_not_provided( )(builder) finally: _stop_patches(patchers) - assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + assert ( + builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + ) -# --------------------------------------------------------------- +# ----------------------------------------------------------- # Transformers (MMS) — needs extra patches -# --------------------------------------------------------------- +# ----------------------------------------------------------- class TestBuildForTransformersHfModelId: """_build_for_transformers preserves HF_MODEL_ID.""" @@ -347,7 +314,9 @@ def test_preserves_user_provided_s3_uri( ) finally: _stop_patches(patchers) - assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + assert ( + builder.env_vars["HF_MODEL_ID"] == S3_PATH + ) def test_sets_default_when_not_provided( self, @@ -366,27 +335,21 @@ def test_sets_default_when_not_provided( finally: _stop_patches(patchers) assert ( - builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL + builder.env_vars["HF_MODEL_ID"] + == DEFAULT_MODEL ) + @patch(f"{_MOD}.prepare_for_mms") + @patch(f"{_MOD}.save_pkl") @patch( - "sagemaker.serve.model_builder_servers.save_pkl" - ) - @patch( - "sagemaker.serve.model_builder_servers" - "._get_model_config_properties_from_hf", + f"{_MOD}._get_model_config_properties_from_hf", return_value={}, ) @patch( - "sagemaker.serve.model_builder_servers" - "._get_nb_instance", + f"{_MOD}._get_nb_instance", return_value=None, ) - @patch( - "sagemaker.serve.model_server" - ".multi_model_server" - ".prepare._create_dir_structure", - ) + @patch(_MMS_PREP) @patch("os.makedirs") def test_preserves_with_inference_spec( self, @@ -395,6 +358,7 @@ def test_preserves_with_inference_spec( _mock_nb: Mock, _mock_hf: Mock, _mock_pkl: Mock, + _mock_mms: Mock, ) -> None: """User HF_MODEL_ID preserved with inference_spec.""" builder = _create_mock_builder( @@ -413,4 +377,6 @@ def test_preserves_with_inference_spec( _ModelBuilderServers._build_for_transformers( builder ) - assert builder.env_vars["HF_MODEL_ID"] == S3_PATH + assert ( + builder.env_vars["HF_MODEL_ID"] == S3_PATH + )