forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathemr_serverless_step.py
More file actions
162 lines (139 loc) · 6.96 KB
/
emr_serverless_step.py
File metadata and controls
162 lines (139 loc) · 6.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for EMR Serverless workflow."""
from __future__ import absolute_import
from typing import Any, Dict, List, Union, Optional
from sagemaker.core.helper.pipeline_variable import RequestType
from sagemaker.core.workflow.properties import Properties
from sagemaker.mlops.workflow.retry import StepRetryPolicy
from sagemaker.mlops.workflow.step_collections import StepCollection
from sagemaker.mlops.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum, CacheConfig
class EMRServerlessJobConfig:
    """Config for EMR Serverless job.

    Stores the settings of a single EMR Serverless job run and renders them
    into the camelCase request dictionary used by the step definition.
    """

    def __init__(
        self,
        job_driver: Dict,
        execution_role_arn: str,
        configuration_overrides: Optional[Dict] = None,
        execution_timeout_minutes: Optional[int] = None,
        name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ):  # pylint: disable=too-many-positional-arguments
        """Create a definition for EMR Serverless job configuration.

        Args:
            job_driver (Dict): The job driver for the job run.
            execution_role_arn (str): The execution role ARN for the job run.
            configuration_overrides (Dict, optional): Configuration overrides for the job run.
            execution_timeout_minutes (int, optional): The maximum duration for the job run.
            name (str, optional): The optional job run name.
            tags (Dict[str, str], optional): The tags assigned to the job run.
        """
        # Required settings.
        self.job_driver = job_driver
        self.execution_role_arn = execution_role_arn
        # Optional settings; None means "omit from the request".
        self.configuration_overrides = configuration_overrides
        self.execution_timeout_minutes = execution_timeout_minutes
        self.name = name
        self.tags = tags

    def to_request(self, application_id: Optional[str] = None) -> RequestType:
        """Convert EMRServerlessJobConfig object to request dict.

        Args:
            application_id (str, optional): When not None, emitted as ``applicationId``.

        Returns:
            RequestType: A dict that always holds ``executionRoleArn`` and
            ``jobDriver``. ``applicationId``, ``configurationOverrides``,
            ``executionTimeoutMinutes``, ``name`` and ``tags`` follow in that
            order, each only when its value is not None. Falsy values such as
            ``0``, ``""`` or ``{}`` are still included.
        """
        # Ordered (key, value) pairs; the tuple order fixes the insertion order
        # of the optional keys in the resulting dict.
        optional_fields = (
            ("applicationId", application_id),
            ("configurationOverrides", self.configuration_overrides),
            ("executionTimeoutMinutes", self.execution_timeout_minutes),
            ("name", self.name),
            ("tags", self.tags),
        )
        request = {
            "executionRoleArn": self.execution_role_arn,
            "jobDriver": self.job_driver,
        }
        request.update(
            (key, value) for key, value in optional_fields if value is not None
        )
        return request
# Error template used when a step receives both an existing application ID
# and a configuration for creating a new application; formatted with
# ``step_name``.
ERR_STR_WITH_BOTH_APP_ID_AND_APP_CONFIG = (
    "EMRServerlessStep {step_name} cannot have both "
    "application_id and application_config. To use EMRServerlessStep "
    "with application_config, application_id must be explicitly set to None."
)

# Error template used when a step receives neither an application ID nor an
# application configuration; formatted with ``step_name``.
ERR_STR_WITHOUT_APP_ID_AND_APP_CONFIG = (
    "EMRServerlessStep {step_name} must have either "
    "application_id or application_config"
)
class EMRServerlessStep(ConfigurableRetryStep):
    """EMR Serverless step for workflow with configurable retry policies."""

    def __init__(
        self,
        name: str,
        display_name: str,
        description: str,
        job_config: EMRServerlessJobConfig,
        application_id: Optional[str] = None,
        application_config: Optional[Dict[str, Any]] = None,
        depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
        cache_config: Optional[CacheConfig] = None,
        retry_policies: Optional[List[StepRetryPolicy]] = None,
    ):  # pylint: disable=too-many-positional-arguments
        """Constructs an `EMRServerlessStep`.

        Exactly one of ``application_id`` (run on an existing application) or
        ``application_config`` (create a new application) must be provided.

        Args:
            name (str): The name of the EMR Serverless step.
            display_name (str): The display name of the EMR Serverless step.
            description (str): The description of the EMR Serverless step.
            job_config (EMRServerlessJobConfig): Job configuration for the EMR Serverless job.
            application_id (str, optional): The ID of the existing EMR Serverless application.
            application_config (Dict[str, Any], optional): Configuration for creating a new
                EMR Serverless application.
            depends_on (List[Union[str, Step, StepCollection]], optional): A list of
                `Step`/`StepCollection` names or `Step` instances or `StepCollection` instances
                that this `EMRServerlessStep` depends on.
            cache_config (CacheConfig, optional): A `sagemaker.workflow.steps.CacheConfig` instance.
            retry_policies (List[StepRetryPolicy], optional): A list of retry policies.

        Raises:
            ValueError: If neither or both of ``application_id`` and
                ``application_config`` are provided.
        """
        # Validate the mutually exclusive application arguments up front so an
        # invalid step fails fast, before any base-class initialization runs.
        if application_id is None and application_config is None:
            raise ValueError(ERR_STR_WITHOUT_APP_ID_AND_APP_CONFIG.format(step_name=name))
        if application_id is not None and application_config is not None:
            raise ValueError(ERR_STR_WITH_BOTH_APP_ID_AND_APP_CONFIG.format(step_name=name))

        super().__init__(
            name=name,
            step_type=StepTypeEnum.EMR_SERVERLESS,
            display_name=display_name,
            description=description,
            depends_on=depends_on,
            retry_policies=retry_policies,
        )

        emr_serverless_args = {
            # The execution role is sent at the top level (used by the backend)
            # and is also repeated inside JobConfig, because
            # EMRServerlessJobConfig.to_request always includes executionRoleArn
            # (required by the JobConfig structure).
            "ExecutionRoleArn": job_config.execution_role_arn,
            # applicationId is only embedded in JobConfig when it is not None.
            "JobConfig": job_config.to_request(application_id),
        }
        # Validation above guarantees exactly one of the two is set.
        if application_id is not None:
            emr_serverless_args["ApplicationId"] = application_id
        else:
            emr_serverless_args["ApplicationConfig"] = application_config

        self.args = emr_serverless_args
        self.cache_config = cache_config

        # Step outputs are exposed through the EMR Serverless GetJobRunResponse shape.
        root_property = Properties(
            step_name=name, step=self, shape_name="GetJobRunResponse", service_name="emr-serverless"
        )
        self._properties = root_property

    @property
    def arguments(self) -> RequestType:
        """The arguments dict that is used to call EMR Serverless APIs."""
        return self.args

    @property
    def properties(self) -> Properties:
        """A Properties object representing the EMR Serverless GetJobRunResponse model."""
        return self._properties

    def to_request(self) -> RequestType:
        """Build the step request dict, merging in the cache configuration when set.

        Retry policies are passed to the parent via ``super().__init__``. They
        are presumably serialized by ``ConfigurableRetryStep.to_request``;
        NOTE(review): confirm against the base class.
        """
        request_dict = super().to_request()
        if self.cache_config:
            request_dict.update(self.cache_config.config)
        return request_dict