forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathemr_step.py
More file actions
276 lines (232 loc) · 10.9 KB
/
emr_step.py
File metadata and controls
276 lines (232 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""
from __future__ import absolute_import
from typing import Any, Dict, List, Union, Optional
from sagemaker.core.helper.pipeline_variable import RequestType
from sagemaker.core.workflow.properties import (
Properties,
)
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum, CacheConfig
class EMRStepConfig:
    """Config for a Hadoop Jar step."""

    def __init__(
        self,
        jar,
        args: Optional[List[str]] = None,
        main_class: Optional[str] = None,
        properties: Optional[List[dict]] = None,
        output_args: Optional[Dict[str, str]] = None,
    ):
        """Create a definition for input data used by an EMR cluster(job flow) step.

        See AWS documentation for more information about the `StepConfig
        <https://docs.aws.amazon.com/emr/latest/APIReference/API_StepConfig.html>`_ API parameters.

        Args:
            jar(str): A path to a JAR file run during the step.
            args(List[str]):
                A list of command line arguments passed to
                the JAR file's main function when executed.
                This list is never modified; when ``output_args`` is given, a copy is extended.
            main_class(str): The name of the main class in the specified Java file.
            properties(List(dict)): A list of key-value pairs that are set when the step runs.
            output_args(Dict[str, str]):
                A dict of argument-value pairs (output_name: S3 URI) that extends the command line
                args and can be accessible in other steps via EMRStep.emr_outputs[output_name].
                Argument names are prepended by '--' automatically.
                Example: {"output-path": "s3://my-bucket/output/"} will result in the following
                command line args: ["--output-path", "s3://my-bucket/output/"]
                May be used with or without ``args``.
        """
        self.jar = jar
        self.args = args
        self.main_class = main_class
        self.properties = properties
        # Maps each output_args name to the index of its *value* inside self.args,
        # so EMRStep can expose that element as a property reference.
        self.output_args_index = {}
        if output_args:
            # Work on a private copy: never mutate the caller's list, and start from an
            # empty list when no positional args were supplied (None has no .extend()).
            self.args = list(args) if args is not None else []
            for output_arg_name, output_arg_value in output_args.items():
                self.args.extend([f"--{output_arg_name}", output_arg_value])
                self.output_args_index[output_arg_name] = len(self.args) - 1

    def to_request(self) -> RequestType:
        """Convert EMRStepConfig object to request dict.

        Returns:
            A dict of the form ``{"HadoopJarStep": {...}}``. It always contains ``Jar``;
            ``Args``, ``MainClass`` and ``Properties`` are included only when not None.
        """
        config = {"HadoopJarStep": {"Jar": self.jar}}
        if self.args is not None:
            config["HadoopJarStep"]["Args"] = self.args
        if self.main_class is not None:
            config["HadoopJarStep"]["MainClass"] = self.main_class
        if self.properties is not None:
            config["HadoopJarStep"]["Properties"] = self.properties
        return config
# Keys of the RunJobFlow-style ``cluster_config`` dict that EMRStep validates.
INSTANCES = "Instances"
INSTANCEGROUPS = "InstanceGroups"
INSTANCEFLEETS = "InstanceFleets"

# Validation error messages. Each is a str.format template expecting ``step_name``.
ERR_STR_WITH_NAME_AUTO_TERMINATION_OR_STEPS = (
    "In EMRStep {step_name}, cluster_config "
    "should not contain any of the Name, "
    "AutoTerminationPolicy and/or Steps."
)

ERR_STR_WITHOUT_INSTANCE = "In EMRStep {step_name}, cluster_config must contain " + INSTANCES + "."

ERR_STR_WITH_KEEPJOBFLOW_OR_TERMINATIONPROTECTED = (
    "In EMRStep {step_name}, " + INSTANCES + " should not contain "
    "KeepJobFlowAliveWhenNoSteps or "
    "TerminationProtected."
)

ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS = (
    "In EMRStep {step_name}, "
    + INSTANCES
    + " should contain either "
    + INSTANCEGROUPS
    + " or "
    + INSTANCEFLEETS
    + "."
)

# Implicitly concatenated literals need explicit separating spaces.
ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG = (
    "EMRStep {step_name} cannot have both cluster_id "
    "and cluster_config. "
    "To use EMRStep with "
    "cluster_config, cluster_id "
    "must be explicitly set to None."
)

ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID = (
    "EMRStep {step_name} cannot have execution_role_arn "
    "without cluster_id. "
    "To use EMRStep with "
    "execution_role_arn, cluster_id "
    "must not be None."
)

ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG = (
    "EMRStep {step_name} must have either cluster_id or cluster_config"
)
class EMRStep(Step):
    """EMR step for workflow."""

    def _validate_cluster_config(self, cluster_config, step_name):
        """Validates user provided cluster_config.

        Args:
            cluster_config(Dict[str, Any]): user provided cluster configuration, keyed
                like the ``RunJobFlow`` request.
            step_name: The name of the EMR step, used in error messages.

        Raises:
            ValueError: if ``cluster_config`` contains ``Name``, ``AutoTerminationPolicy``
                or ``Steps``; if it lacks ``Instances``; if ``Instances`` contains
                ``KeepJobFlowAliveWhenNoSteps`` or ``TerminationProtected``; or if
                ``Instances`` contains both or neither of ``InstanceGroups`` and
                ``InstanceFleets``.
        """
        if (
            "Name" in cluster_config
            or "AutoTerminationPolicy" in cluster_config
            or "Steps" in cluster_config
        ):
            raise ValueError(
                ERR_STR_WITH_NAME_AUTO_TERMINATION_OR_STEPS.format(step_name=step_name)
            )

        if INSTANCES not in cluster_config:
            raise ValueError(ERR_STR_WITHOUT_INSTANCE.format(step_name=step_name))

        instances = cluster_config[INSTANCES]
        if "KeepJobFlowAliveWhenNoSteps" in instances or "TerminationProtected" in instances:
            raise ValueError(
                ERR_STR_WITH_KEEPJOBFLOW_OR_TERMINATIONPROTECTED.format(step_name=step_name)
            )

        # Exactly one of InstanceGroups / InstanceFleets must be present: reject the
        # case where both membership tests agree (both present, or both absent).
        if (INSTANCEGROUPS in instances) == (INSTANCEFLEETS in instances):
            raise ValueError(
                ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS.format(step_name=step_name)
            )

    def __init__(
        self,
        name: str,
        display_name: str,
        description: str,
        cluster_id: Optional[str],
        step_config: EMRStepConfig,
        depends_on: Optional[List[Union[str, Step]]] = None,
        cache_config: Optional[CacheConfig] = None,
        cluster_config: Optional[Dict[str, Any]] = None,
        execution_role_arn: Optional[str] = None,
    ):
        """Constructs an `EMRStep`.

        Args:
            name(str): The name of the EMR step.
            display_name(str): The display name of the EMR step.
            description(str): The description of the EMR step.
            cluster_id(Optional[str]): The ID of the running EMR cluster. Must be None
                when ``cluster_config`` is given.
            step_config(EMRStepConfig): One StepConfig to be executed by the job flow.
            depends_on (List[Union[str, Step]]): A list of `Step`
                names or `Step` instances that this `EMRStep`
                depends on.
            cache_config(CacheConfig): A `sagemaker.mlops.workflow.steps.CacheConfig` instance.
            cluster_config(Dict[str, Any]): The recipe of the
                EMR cluster, passed as a dictionary.
                The elements are defined in the request syntax for `RunJobFlow`.
                However, the following elements are not recognized as part of the cluster
                configuration and you should not include them in the dictionary:

                * ``cluster_config[Name]``
                * ``cluster_config[Steps]``
                * ``cluster_config[AutoTerminationPolicy]``
                * ``cluster_config[Instances][KeepJobFlowAliveWhenNoSteps]``
                * ``cluster_config[Instances][TerminationProtected]``

                ``cluster_config[Instances]`` must contain exactly one of
                ``InstanceGroups`` or ``InstanceFleets``.

                For more information about the fields you can include in your cluster
                configuration, see
                https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html.
                Note that if you want to use ``cluster_config``, then you have to set
                ``cluster_id`` as None.
            execution_role_arn(str): The ARN of the runtime role assumed by this `EMRStep`. The
                job submitted to your EMR cluster uses this role to access AWS resources. This
                value is passed as ExecutionRoleArn to the AddJobFlowSteps request (an EMR request)
                called on the cluster specified by ``cluster_id``, so you can only include this
                field if ``cluster_id`` is not None.

        Raises:
            ValueError: if neither or both of ``cluster_id`` and ``cluster_config`` are given,
                if ``execution_role_arn`` is given without ``cluster_id``, or if
                ``cluster_config`` fails validation.
        """
        super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)

        emr_step_args = {"StepConfig": step_config.to_request()}
        root_property = Properties(step_name=name, step=self, shape_name="Step", service_name="emr")

        # Exactly one cluster target must be supplied.
        if cluster_id is None and cluster_config is None:
            raise ValueError(ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG.format(step_name=name))
        if cluster_id is not None and cluster_config is not None:
            raise ValueError(ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG.format(step_name=name))
        # The runtime role only applies when submitting to an existing cluster.
        if execution_role_arn is not None and cluster_id is None:
            raise ValueError(
                ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID.format(step_name=name)
            )

        # Record the chosen inputs both in the request args and as attributes on the
        # step's Properties object (written directly into its __dict__).
        if cluster_id is not None:
            emr_step_args["ClusterId"] = cluster_id
            root_property.__dict__["ClusterId"] = cluster_id
            if execution_role_arn is not None:
                emr_step_args["ExecutionRoleArn"] = execution_role_arn
                root_property.__dict__["ExecutionRoleArn"] = execution_role_arn
        elif cluster_config is not None:
            self._validate_cluster_config(cluster_config, name)
            emr_step_args["ClusterConfig"] = cluster_config
            root_property.__dict__["ClusterConfig"] = cluster_config

        self.args = emr_step_args
        self.cache_config = cache_config
        self._properties = root_property
        # For each declared output, reference the Config.Args element at the index where
        # EMRStepConfig placed that output's value, so later steps can consume it.
        self.emr_outputs = {
            output_name: self.properties.Config.Args[arg_index]
            for output_name, arg_index in step_config.output_args_index.items()
        }

    @property
    def arguments(self) -> RequestType:
        """The arguments dict that is used to call `AddJobFlowSteps`.

        NOTE: The AddJobFlowSteps request is not quite the args list that workflow needs.
        The Name attribute in AddJobFlowSteps cannot be passed; it will be set during runtime.
        In addition to that, we will also need to include emr job inputs and output config.
        """
        return self.args

    @property
    def properties(self) -> Properties:
        """A Properties object representing the EMR DescribeStepResponse model"""
        return self._properties

    def to_request(self) -> RequestType:
        """Gets the base step request, merged with the cache configuration when one is set."""
        request_dict = super().to_request()
        if self.cache_config:
            request_dict.update(self.cache_config.config)
        return request_dict