Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sagemaker-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ dependencies = [
"smdebug_rulesconfig>=1.0.1",
"schema>=0.7.5",
"omegaconf>=2.1.0",
"torch>=1.9.0",
"scipy>=1.5.0",
# Remote function dependencies
"cloudpickle>=2.0.0",
"paramiko>=2.11.0",
Expand All @@ -52,6 +50,13 @@ classifiers = [
]

[project.optional-dependencies]
torch = [
"torch>=1.9.0",
"scipy>=1.5.0",
]
all = [
"sagemaker-core[torch]",
]
codegen = [
"black>=24.3.0, <25.0.0",
"pandas>=2.0.0, <3.0.0",
Expand Down
5 changes: 4 additions & 1 deletion sagemaker-core/src/sagemaker/core/deserializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,10 @@ def __init__(self, accept="tensor/pt"):

self.convert_npy_to_tensor = from_numpy
except ImportError:
raise Exception("Unable to import pytorch.")
raise ImportError(
"torch is required for TorchTensorDeserializer. "
"Install it with: pip install 'sagemaker-core[torch]'"
)

def deserialize(self, stream, content_type="tensor/pt"):
"""Deserialize streamed data to TorchTensor
Expand Down
8 changes: 7 additions & 1 deletion sagemaker-core/src/sagemaker/core/serializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,13 @@ class TorchTensorSerializer(SimpleBaseSerializer):

def __init__(self, content_type="tensor/pt"):
super(TorchTensorSerializer, self).__init__(content_type=content_type)
from torch import Tensor
try:
from torch import Tensor
except ImportError:
raise ImportError(
"torch is required for TorchTensorSerializer. "
"Install it with: pip install 'sagemaker-core[torch]'"
)

self.torch_tensor = Tensor
self.numpy_serializer = NumpySerializer()
Expand Down
108 changes: 108 additions & 0 deletions sagemaker-core/tests/unit/serializers/test_torch_optional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import sys
from unittest import mock

import pytest


def test_torch_tensor_serializer_raises_import_error_when_torch_missing():
"""Verify TorchTensorSerializer raises ImportError with helpful message when torch is missing."""
with mock.patch.dict(sys.modules, {"torch": None}):
# Need to reload the module to pick up the mocked import
from sagemaker.core.serializers.base import TorchTensorSerializer

with pytest.raises(ImportError, match="pip install"):
TorchTensorSerializer()


def test_torch_tensor_deserializer_raises_import_error_when_torch_missing():
"""Verify TorchTensorDeserializer raises ImportError with helpful message when torch is missing."""
with mock.patch.dict(sys.modules, {"torch": None}):
from sagemaker.core.deserializers.base import TorchTensorDeserializer

with pytest.raises(ImportError, match="pip install"):
TorchTensorDeserializer()


def test_torch_tensor_serializer_works_when_torch_available():
"""Verify TorchTensorSerializer can be instantiated when torch is available."""
torch = pytest.importorskip("torch")
from sagemaker.core.serializers.base import TorchTensorSerializer

serializer = TorchTensorSerializer()
assert serializer.CONTENT_TYPE == "tensor/pt"

# Test serialization of a simple tensor
tensor = torch.tensor([1.0, 2.0, 3.0])
result = serializer.serialize(tensor)
assert result is not None


def test_torch_tensor_deserializer_works_when_torch_available():
"""Verify TorchTensorDeserializer can be instantiated when torch is available."""
pytest.importorskip("torch")
from sagemaker.core.deserializers.base import TorchTensorDeserializer

deserializer = TorchTensorDeserializer()
assert deserializer.ACCEPT == ("tensor/pt",)


def test_base_serializers_importable_without_torch():
"""Verify non-torch serializers can be imported and used without torch."""
from sagemaker.core.serializers.base import (
CSVSerializer,
NumpySerializer,
JSONSerializer,
IdentitySerializer,
JSONLinesSerializer,
LibSVMSerializer,
DataSerializer,
StringSerializer,
)

# Verify they can be instantiated
assert CSVSerializer() is not None
assert NumpySerializer() is not None
assert JSONSerializer() is not None
assert IdentitySerializer() is not None
assert JSONLinesSerializer() is not None
assert LibSVMSerializer() is not None
assert DataSerializer() is not None
assert StringSerializer() is not None


def test_base_deserializers_importable_without_torch():
"""Verify non-torch deserializers can be imported and used without torch."""
from sagemaker.core.deserializers.base import (
StringDeserializer,
BytesDeserializer,
CSVDeserializer,
StreamDeserializer,
NumpyDeserializer,
JSONDeserializer,
PandasDeserializer,
JSONLinesDeserializer,
)

# Verify they can be instantiated
assert StringDeserializer() is not None
assert BytesDeserializer() is not None
assert CSVDeserializer() is not None
assert StreamDeserializer() is not None
assert NumpyDeserializer() is not None
assert JSONDeserializer() is not None
assert PandasDeserializer() is not None
assert JSONLinesDeserializer() is not None
2 changes: 1 addition & 1 deletion sagemaker-core/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ commands =
pytest {posargs}
deps =
-r ../requirements/extras/test_requirements.txt
../sagemaker-core
../sagemaker-core[torch]
.[test]
mock
depends =
Expand Down
Loading