forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_emr_step.py
More file actions
99 lines (84 loc) · 3.33 KB
/
test_emr_step.py
File metadata and controls
99 lines (84 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for workflow emr_step."""
from __future__ import absolute_import
import pytest
from sagemaker.mlops.workflow.emr_step import EMRStep, EMRStepConfig
from sagemaker.mlops.workflow.steps import StepTypeEnum
from sagemaker.core.workflow.properties import Properties
def test_emr_step_config_init():
    """EMRStepConfig should expose the jar location and args it was built with."""
    emr_config = EMRStepConfig(
        jar="s3://bucket/my.jar",
        args=["arg1", "arg2"],
    )
    assert emr_config.jar == "s3://bucket/my.jar"
    assert emr_config.args == ["arg1", "arg2"]
def test_emr_step_config_to_request():
    """to_request() should nest the jar and args under the "HadoopJarStep" key."""
    emr_config = EMRStepConfig(jar="s3://bucket/my.jar", args=["arg1"])
    hadoop_jar_step = emr_config.to_request()["HadoopJarStep"]
    assert hadoop_jar_step["Jar"] == "s3://bucket/my.jar"
    assert hadoop_jar_step["Args"] == ["arg1"]
def test_emr_step_with_cluster_id():
    """An EMRStep bound to an existing cluster keeps its name and the EMR step type."""
    jar_config = EMRStepConfig(jar="s3://bucket/my.jar")
    step_kwargs = {
        "name": "emr-step",
        "display_name": "EMR Step",
        "description": "Test EMR step",
        "cluster_id": "j-123456",
        "step_config": jar_config,
    }
    emr_step = EMRStep(**step_kwargs)
    assert emr_step.name == "emr-step"
    assert emr_step.step_type == StepTypeEnum.EMR
def test_emr_step_with_cluster_config():
    """An EMRStep built from a cluster_config (no cluster_id) constructs cleanly.

    Mirrors test_emr_step_with_cluster_id: besides the name, the step type must
    still be StepTypeEnum.EMR regardless of which cluster option was used.
    """
    config = EMRStepConfig(jar="s3://bucket/my.jar")
    cluster_config = {
        "Instances": {"InstanceGroups": [{"InstanceType": "m5.xlarge", "InstanceCount": 1}]}
    }
    step = EMRStep(
        name="emr-step",
        display_name="EMR Step",
        description="Test EMR step",
        cluster_id=None,
        step_config=config,
        cluster_config=cluster_config,
    )
    assert step.name == "emr-step"
    # Previously unchecked on this path; the cluster_id test already pins it.
    assert step.step_type == StepTypeEnum.EMR
def test_emr_step_without_cluster_id_or_config_raises_error():
    """Supplying neither cluster_id nor cluster_config must raise ValueError."""
    jar_config = EMRStepConfig(jar="s3://bucket/my.jar")
    expected_message = "must have either cluster_id or cluster_config"
    with pytest.raises(ValueError, match=expected_message):
        EMRStep(
            step_config=jar_config,
            cluster_id=None,
            name="emr-step",
            display_name="EMR Step",
            description="Test EMR step",
        )
def test_emr_step_with_both_cluster_id_and_config_raises_error():
    """Supplying both cluster_id and cluster_config must raise ValueError."""
    jar_config = EMRStepConfig(jar="s3://bucket/my.jar")
    expected_message = "can not have both cluster_id"
    with pytest.raises(ValueError, match=expected_message):
        EMRStep(
            step_config=jar_config,
            cluster_id="j-123456",
            cluster_config={"Instances": {}},
            name="emr-step",
            display_name="EMR Step",
            description="Test EMR step",
        )
def test_emr_step_with_output_args():
    """Each output_args key should appear in EMRStep.emr_outputs as a Properties object."""
    emr_config = EMRStepConfig(
        jar="s3://bucket/my.jar",
        args=["arg1"],
        output_args={"output": "s3://bucket/my/output/path"},
    )
    emr_step = EMRStep(
        name="emr-step",
        display_name="EMR Step",
        description="Test EMR step",
        cluster_id="j-123456",
        step_config=emr_config,
    )
    outputs = emr_step.emr_outputs
    assert "output" in outputs
    assert isinstance(outputs["output"], Properties)