Skip to content

Commit 3a9cbc7

Browse files
authored
fix: restore BatchTransformInput.destination attribute in v3 (aws#5865)
In v3, BatchTransformInput.__init__ assigns self.s3_output.s3_uri = destination, but s3_output is never created on the class or its MonitoringInput parent (it only exists on the unrelated MonitoringOutput class). Constructing a BatchTransformInput therefore raises AttributeError: 'BatchTransformInput' object has no attribute 's3_output' before the user can use it in any monitoring schedule call.

Reproducer:

    from sagemaker.core.model_monitor import (
        BatchTransformInput,
        MonitoringDatasetFormat,
    )

    BatchTransformInput(
        data_captured_destination_s3_uri="s3://bucket/captured",
        destination="/opt/ml/processing/input",
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )

This restores the v2.244.x behavior of storing destination as a plain attribute and reading it back in _to_request_dict for the LocalPath field.

Also unskips and expands the TestBatchTransformInput unit tests, which were previously skipped with the message "BatchTransformInput has initialization issues in the source code".
1 parent 8a755c5 commit 3a9cbc7

2 files changed

Lines changed: 78 additions & 4 deletions

File tree

sagemaker-core/src/sagemaker/core/model_monitor/model_monitoring.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4174,7 +4174,7 @@ def __init__(
41744174
41754175
"""
41764176
self.data_captured_destination_s3_uri = data_captured_destination_s3_uri
4177-
self.s3_output.s3_uri = destination
4177+
self.destination = destination
41784178
self.s3_input_mode = s3_input_mode
41794179
self.s3_data_distribution_type = s3_data_distribution_type
41804180
self.dataset_format = dataset_format
@@ -4193,7 +4193,7 @@ def _to_request_dict(self):
41934193
"""Generates a request dictionary using the parameters provided to the class."""
41944194
batch_transform_input_data = {
41954195
"DataCapturedDestinationS3Uri": self.data_captured_destination_s3_uri,
4196-
"LocalPath": self.s3_output.s3_uri,
4196+
"LocalPath": self.destination,
41974197
"S3InputMode": self.s3_input_mode,
41984198
"S3DataDistributionType": self.s3_data_distribution_type,
41994199
"DatasetFormat": self.dataset_format,

sagemaker-core/tests/unit/test_model_monitoring.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
CONSTRAINT_VIOLATIONS_JSON_DEFAULT_FILE_NAME,
3030
DEFAULT_REPOSITORY_NAME,
3131
)
32+
from sagemaker.core.model_monitor.dataset_format import MonitoringDatasetFormat
3233
from sagemaker.core.processing import ProcessingInput, ProcessingOutput
3334
from sagemaker.core.shapes import ProcessingS3Input, ProcessingS3Output
3435
from sagemaker.core.network import NetworkConfig
@@ -140,8 +141,81 @@ def test_to_request_dict_minimal(self):
140141

141142
class TestBatchTransformInput:
142143
def test_init_minimal(self):
143-
# Skip this test as BatchTransformInput has initialization issues in the source code
144-
pytest.skip("BatchTransformInput has initialization issues in the source code")
144+
bti = BatchTransformInput(
145+
data_captured_destination_s3_uri="s3://bucket/captured",
146+
destination="/opt/ml/processing/input",
147+
dataset_format=MonitoringDatasetFormat.csv(header=False),
148+
)
149+
assert bti.data_captured_destination_s3_uri == "s3://bucket/captured"
150+
assert bti.destination == "/opt/ml/processing/input"
151+
assert bti.s3_input_mode == "File"
152+
assert bti.s3_data_distribution_type == "FullyReplicated"
153+
154+
def test_init_with_all_parameters(self):
155+
bti = BatchTransformInput(
156+
data_captured_destination_s3_uri="s3://bucket/captured",
157+
destination="/opt/ml/processing/input",
158+
dataset_format=MonitoringDatasetFormat.csv(header=False),
159+
s3_input_mode="Pipe",
160+
s3_data_distribution_type="ShardedByS3Key",
161+
start_time_offset="-PT1H",
162+
end_time_offset="-PT0H",
163+
features_attribute="features",
164+
inference_attribute="prediction",
165+
probability_attribute="probability",
166+
probability_threshold_attribute=0.5,
167+
exclude_features_attribute="feature1,feature2",
168+
)
169+
assert bti.s3_input_mode == "Pipe"
170+
assert bti.start_time_offset == "-PT1H"
171+
assert bti.features_attribute == "features"
172+
173+
def test_to_request_dict_minimal(self):
174+
bti = BatchTransformInput(
175+
data_captured_destination_s3_uri="s3://bucket/captured",
176+
destination="/opt/ml/processing/input",
177+
dataset_format=MonitoringDatasetFormat.csv(header=False),
178+
)
179+
request_dict = bti._to_request_dict()
180+
assert "BatchTransformInput" in request_dict
181+
payload = request_dict["BatchTransformInput"]
182+
assert payload["DataCapturedDestinationS3Uri"] == "s3://bucket/captured"
183+
assert payload["LocalPath"] == "/opt/ml/processing/input"
184+
assert payload["S3InputMode"] == "File"
185+
assert payload["S3DataDistributionType"] == "FullyReplicated"
186+
187+
def test_to_request_dict_excludes_none_values(self):
188+
bti = BatchTransformInput(
189+
data_captured_destination_s3_uri="s3://bucket/captured",
190+
destination="/opt/ml/processing/input",
191+
dataset_format=MonitoringDatasetFormat.csv(header=False),
192+
)
193+
payload = bti._to_request_dict()["BatchTransformInput"]
194+
assert "StartTimeOffset" not in payload
195+
assert "EndTimeOffset" not in payload
196+
assert "FeaturesAttribute" not in payload
197+
assert "InferenceAttribute" not in payload
198+
assert "ProbabilityAttribute" not in payload
199+
assert "ProbabilityThresholdAttribute" not in payload
200+
assert "ExcludeFeaturesAttribute" not in payload
201+
202+
def test_to_request_dict_includes_optional_values(self):
203+
bti = BatchTransformInput(
204+
data_captured_destination_s3_uri="s3://bucket/captured",
205+
destination="/opt/ml/processing/input",
206+
dataset_format=MonitoringDatasetFormat.csv(header=False),
207+
start_time_offset="-PT1H",
208+
end_time_offset="-PT0H",
209+
probability_attribute="probability",
210+
probability_threshold_attribute=0.5,
211+
exclude_features_attribute="f1,f2",
212+
)
213+
payload = bti._to_request_dict()["BatchTransformInput"]
214+
assert payload["StartTimeOffset"] == "-PT1H"
215+
assert payload["EndTimeOffset"] == "-PT0H"
216+
assert payload["ProbabilityAttribute"] == "probability"
217+
assert payload["ProbabilityThresholdAttribute"] == 0.5
218+
assert payload["ExcludeFeaturesAttribute"] == "f1,f2"
145219

146220

147221
class TestBaseliningJob:

0 commit comments

Comments
 (0)