Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sagemaker-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ dependencies = [
"smdebug_rulesconfig>=1.0.1",
"schema>=0.7.5",
"omegaconf>=2.1.0",
"torch>=1.9.0",
"scipy>=1.5.0",
# Remote function dependencies
"cloudpickle>=2.0.0",
Expand All @@ -52,6 +51,9 @@ classifiers = [
]

[project.optional-dependencies]
torch = [
"torch>=1.9.0",
]
codegen = [
"black>=24.3.0, <25.0.0",
"pandas>=2.0.0, <3.0.0",
Expand Down
7 changes: 5 additions & 2 deletions sagemaker-core/src/sagemaker/core/deserializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,11 @@ def __init__(self, accept="tensor/pt"):
from torch import from_numpy

self.convert_npy_to_tensor = from_numpy
except ImportError:
raise Exception("Unable to import pytorch.")
except ImportError as e:
raise ImportError(
"torch is required to use TorchTensorDeserializer. "
"Install it with: pip install sagemaker-core[torch]"
) from e

def deserialize(self, stream, content_type="tensor/pt"):
"""Deserialize streamed data to TorchTensor
Expand Down
12 changes: 9 additions & 3 deletions sagemaker-core/src/sagemaker/core/serializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,15 @@ class TorchTensorSerializer(SimpleBaseSerializer):

def __init__(self, content_type="tensor/pt"):
super(TorchTensorSerializer, self).__init__(content_type=content_type)
from torch import Tensor

self.torch_tensor = Tensor
try:
from torch import Tensor

self.torch_tensor = Tensor
except ImportError as e:
raise ImportError(
"torch is required to use TorchTensorSerializer. "
"Install it with: pip install sagemaker-core[torch]"
) from e
self.numpy_serializer = NumpySerializer()

def serialize(self, data):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

import io
import json
import sys

import numpy as np
import pandas as pd
import pytest
from unittest.mock import patch

from sagemaker.core.deserializers.base import (
StringDeserializer,
Expand All @@ -28,6 +30,7 @@
JSONDeserializer,
PandasDeserializer,
JSONLinesDeserializer,
TorchTensorDeserializer,
)


Expand Down Expand Up @@ -251,3 +254,9 @@ def test_json_lines_deserializer(json_lines_deserializer, source, expected):
content_type = "application/jsonlines"
actual = json_lines_deserializer.deserialize(stream, content_type)
assert actual == expected


def test_torch_tensor_deserializer_import_error():
with patch.dict(sys.modules, {"torch": None}):
with pytest.raises(ImportError, match="pip install sagemaker-core\\[torch\\]"):
TorchTensorDeserializer()
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
import io
import json
import os
import sys

import numpy as np
import pytest
import scipy.sparse
from unittest.mock import patch

from sagemaker.core.serializers.base import (
CSVSerializer,
Expand All @@ -29,6 +31,7 @@
JSONLinesSerializer,
LibSVMSerializer,
DataSerializer,
TorchTensorSerializer,
)

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "data")
Expand Down Expand Up @@ -358,3 +361,9 @@ def test_data_serializer_file_like(data_serializer):
with open(validation_image_file_path, "rb") as f:
validation_image_data = f.read()
assert input_image_data == validation_image_data


def test_torch_tensor_serializer_import_error():
with patch.dict(sys.modules, {"torch": None}):
with pytest.raises(ImportError, match="pip install sagemaker-core\\[torch\\]"):
TorchTensorSerializer()
Loading