From 922a1b9921dd80eb9faab64d8c7db367ff979e5c Mon Sep 17 00:00:00 2001 From: mciccozzi-ah Date: Tue, 14 Apr 2026 13:11:57 -0700 Subject: [PATCH] fix: make torch an optional dependency in sagemaker-core torch>=1.9.0 was a hard dependency despite only being used by TorchTensorSerializer and TorchTensorDeserializer. Moved to optional-dependencies[torch]. Added proper ImportError handling with actionable install hint in both classes and unit tests for the ImportError path. Co-Authored-By: Claude Sonnet 4.6 --- sagemaker-core/pyproject.toml | 4 +++- .../src/sagemaker/core/deserializers/base.py | 7 +++++-- .../src/sagemaker/core/serializers/base.py | 12 +++++++++--- .../unit/deserializers/test_base_deserializers.py | 9 +++++++++ .../tests/unit/serializers/test_base_serializers.py | 9 +++++++++ 5 files changed, 35 insertions(+), 6 deletions(-) diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index 9c76166594..81a01cdbc2 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ -32,7 +32,6 @@ dependencies = [ "smdebug_rulesconfig>=1.0.1", "schema>=0.7.5", "omegaconf>=2.1.0", - "torch>=1.9.0", "scipy>=1.5.0", # Remote function dependencies "cloudpickle>=2.0.0", @@ -52,6 +51,9 @@ classifiers = [ ] [project.optional-dependencies] +torch = [ + "torch>=1.9.0", +] codegen = [ "black>=24.3.0, <25.0.0", "pandas>=2.0.0, <3.0.0", diff --git a/sagemaker-core/src/sagemaker/core/deserializers/base.py b/sagemaker-core/src/sagemaker/core/deserializers/base.py index 4faae7db74..aace5bcb1a 100644 --- a/sagemaker-core/src/sagemaker/core/deserializers/base.py +++ b/sagemaker-core/src/sagemaker/core/deserializers/base.py @@ -365,8 +365,11 @@ def __init__(self, accept="tensor/pt"): from torch import from_numpy self.convert_npy_to_tensor = from_numpy - except ImportError: - raise Exception("Unable to import pytorch.") + except ImportError as e: + raise ImportError( + "torch is required to use TorchTensorDeserializer. " + "Install it with: pip install sagemaker-core[torch]" + ) from e def deserialize(self, stream, content_type="tensor/pt"): """Deserialize streamed data to TorchTensor diff --git a/sagemaker-core/src/sagemaker/core/serializers/base.py b/sagemaker-core/src/sagemaker/core/serializers/base.py index a4ecf7c1dc..b17c869a54 100644 --- a/sagemaker-core/src/sagemaker/core/serializers/base.py +++ b/sagemaker-core/src/sagemaker/core/serializers/base.py @@ -443,9 +443,15 @@ class TorchTensorSerializer(SimpleBaseSerializer): def __init__(self, content_type="tensor/pt"): super(TorchTensorSerializer, self).__init__(content_type=content_type) - from torch import Tensor - - self.torch_tensor = Tensor + try: + from torch import Tensor + + self.torch_tensor = Tensor + except ImportError as e: + raise ImportError( + "torch is required to use TorchTensorSerializer. " + "Install it with: pip install sagemaker-core[torch]" + ) from e self.numpy_serializer = NumpySerializer() def serialize(self, data): diff --git a/sagemaker-core/tests/unit/deserializers/test_base_deserializers.py b/sagemaker-core/tests/unit/deserializers/test_base_deserializers.py index 8dda0601f9..151246e680 100644 --- a/sagemaker-core/tests/unit/deserializers/test_base_deserializers.py +++ b/sagemaker-core/tests/unit/deserializers/test_base_deserializers.py @@ -14,10 +14,12 @@ import io import json +import sys import numpy as np import pandas as pd import pytest +from unittest.mock import patch from sagemaker.core.deserializers.base import ( StringDeserializer, @@ -28,6 +30,7 @@ JSONDeserializer, PandasDeserializer, JSONLinesDeserializer, + TorchTensorDeserializer, ) @@ -251,3 +254,9 @@ def test_json_lines_deserializer(json_lines_deserializer, source, expected): content_type = "application/jsonlines" actual = json_lines_deserializer.deserialize(stream, content_type) assert actual == expected + + +def test_torch_tensor_deserializer_import_error(): + with patch.dict(sys.modules, {"torch": None}): + with pytest.raises(ImportError, match="pip install sagemaker-core\\[torch\\]"): + TorchTensorDeserializer() diff --git a/sagemaker-core/tests/unit/serializers/test_base_serializers.py b/sagemaker-core/tests/unit/serializers/test_base_serializers.py index 5432ba0feb..9141f2b9d2 100644 --- a/sagemaker-core/tests/unit/serializers/test_base_serializers.py +++ b/sagemaker-core/tests/unit/serializers/test_base_serializers.py @@ -15,10 +15,12 @@ import io import json import os +import sys import numpy as np import pytest import scipy.sparse +from unittest.mock import patch from sagemaker.core.serializers.base import ( CSVSerializer, @@ -29,6 +31,7 @@ JSONLinesSerializer, LibSVMSerializer, DataSerializer, + TorchTensorSerializer, ) DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "data") @@ -358,3 +361,9 @@ def test_data_serializer_file_like(data_serializer): with open(validation_image_file_path, "rb") as f: validation_image_data = f.read() assert input_image_data == validation_image_data + + +def test_torch_tensor_serializer_import_error(): + with patch.dict(sys.modules, {"torch": None}): + with pytest.raises(ImportError, match="pip install sagemaker-core\\[torch\\]"): + TorchTensorSerializer()