Skip to content

Commit c7ef66f

Browse files
davlind-amznmnganesh-amznaviruthen
authored
feature: Add Support for AWS Batch Quota Management Job Submission and Job Priority Update (#5659)
* feature: [SDKv3]Add Support for QM Job Submission and Job Priority Update (#1970) * Trigger checks in changed modules and dependent modules (#1958) * Update pr workflow (#1963) * Trigger checks in changed modules and dependent modules * Removing github token dependency * Add back GH_PAT token to detect changes (#1965) * feature: Add Support for QM Job Submission and Job Priority Update --------- Co-authored-by: aviruthen <91846056+aviruthen@users.noreply.github.com> * feature: Updating aws_batch TrainingQueue integration test to support quota management. (#1978) * feature: Added an example notebook for QuotaManagement job submission on AWS Batch TrainingQueues. (#1980) * fix: aws_batch/test_training_queue QM unit test fix --------- Co-authored-by: mnganesh-amzn <mnganesh@amazon.com> Co-authored-by: aviruthen <91846056+aviruthen@users.noreply.github.com>
1 parent 302b973 commit c7ef66f

File tree

11 files changed

+1909
-73
lines changed

11 files changed

+1909
-73
lines changed

sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def _submit_service_job(
3232
timeout: Optional[Dict] = None,
3333
share_identifier: Optional[str] = None,
3434
tags: Optional[Dict] = None,
35+
quota_share_name: Optional[str] = None,
36+
preemption_config: Optional[Dict] = None,
3537
) -> Dict:
3638
"""Batch submit_service_job API helper function.
3739
@@ -44,6 +46,8 @@ def _submit_service_job(
4446
timeout: Set with value of timeout if specified, else default to 1 day.
4547
share_identifier: value of shareIdentifier if specified.
4648
tags: A dict of string to string representing Batch tags.
49+
quota_share_name: Quota Share name for the Batch job.
50+
preemption_config: Preemption configuration.
4751
4852
Returns:
4953
A dict containing jobArn, jobName and jobId.
@@ -68,6 +72,10 @@ def _submit_service_job(
6872
payload["shareIdentifier"] = share_identifier
6973
if tags or training_payload_tags:
7074
payload["tags"] = __merge_tags(tags, training_payload_tags)
75+
if quota_share_name:
76+
payload["quotaShareName"] = quota_share_name
77+
if preemption_config:
78+
payload["preemptionConfiguration"] = preemption_config
7179
return client.submit_service_job(**payload)
7280

7381

@@ -96,21 +104,45 @@ def _describe_service_job(job_id: str) -> Dict:
96104
'jobId': 'string',
97105
'jobName': 'string',
98106
'jobQueue': 'string',
107+
'latestAttempt': {
108+
'serviceResourceId': {
109+
'name': 'string',
110+
'value': 'string'
111+
}
112+
},
113+
'preemptionSummary': {
114+
'preemptedAttemptCount': 123,
115+
'recentPreemptedAttempts': [
116+
{
117+
'serviceResourceId': {
118+
'name': 'string',
119+
'value': 'string'
120+
},
121+
'startedAt': 123,
122+
'stoppedAt': 123,
123+
'statusReason': 'string'
124+
},
125+
]
126+
},
99127
'retryStrategy': {
100128
'attempts': 123
101129
},
102130
'schedulingPriority': 123,
103131
'serviceRequestPayload': 'string',
104-
'serviceJobType': 'EKS'|'ECS'|'ECS_FARGATE'|'SAGEMAKER_TRAINING',
132+
'serviceJobType': 'SAGEMAKER_TRAINING',
105133
'shareIdentifier': 'string',
134+
'quotaShareName': 'string',
135+
'preemptionConfiguration': {
136+
'preemptionRetriesBeforeTermination': 123
137+
},
106138
'startedAt': 123,
107-
'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED',
139+
'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'SCHEDULED'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED',
108140
'statusReason': 'string',
109141
'stoppedAt': 123,
110142
'tags': {
111143
'string': 'string'
112144
},
113-
'timeout': {
145+
'timeoutConfig': {
114146
'attemptDurationSeconds': 123
115147
}
116148
}
@@ -132,6 +164,19 @@ def _terminate_service_job(job_id: str, reason: Optional[str] = "default termina
132164
return client.terminate_service_job(jobId=job_id, reason=reason)
133165

134166

167+
def _update_service_job(job_id: str, scheduling_priority: int) -> Dict:
168+
"""Batch update_service_job API helper function.
169+
170+
Args:
171+
job_id: Job ID or Job Arn
172+
scheduling_priority: An integer representing scheduling priority.
173+
174+
Returns: a dict containing jobArn, jobId and jobName.
175+
"""
176+
client = get_batch_boto_client()
177+
return client.update_service_job(jobId=job_id, schedulingPriority=scheduling_priority)
178+
179+
135180
def _list_service_job(
136181
job_queue: str,
137182
job_status: Optional[str] = None,

sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def submit(
4141
share_identifier: Optional[str] = None,
4242
timeout: Optional[Dict] = None,
4343
tags: Optional[Dict] = None,
44+
quota_share_name: Optional[str] = None,
45+
preemption_config: Optional[Dict] = None,
4446
) -> TrainingQueuedJob:
4547
"""Submit a queued job and return a QueuedJob object.
4648
@@ -53,6 +55,8 @@ def submit(
5355
share_identifier: Share identifier for Batch job.
5456
timeout: Timeout configuration for Batch job.
5557
tags: Tags apply to Batch job. These tags are for Batch job only.
58+
quota_share_name: Quota Share name for the Batch job.
59+
preemption_config: Preemption configuration.
5660
5761
Returns: a TrainingQueuedJob object with Batch job ARN and job name.
5862
@@ -85,6 +89,8 @@ def submit(
8589
timeout,
8690
share_identifier,
8791
tags,
92+
quota_share_name,
93+
preemption_config,
8894
)
8995
if "jobArn" not in resp or "jobName" not in resp:
9096
raise MissingRequiredArgument(

sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
SourceCode,
3030
TrainingImageConfig,
3131
)
32-
from .batch_api_helper import _terminate_service_job, _describe_service_job
32+
from .batch_api_helper import _terminate_service_job, _describe_service_job, _update_service_job
3333
from .exception import NoTrainingJob, MissingRequiredArgument
3434
from ..utils import _get_training_job_name_from_training_job_arn
3535
from .constants import JOB_STATUS_COMPLETED, JOB_STATUS_FAILED, POLL_IN_SECONDS
@@ -85,6 +85,17 @@ def terminate(self, reason: Optional[str] = "Default terminate reason") -> None:
8585
"""
8686
_terminate_service_job(self.job_arn, reason)
8787

88+
def update(self, scheduling_priority: int) -> Dict:
89+
"""Update Batch job.
90+
91+
Args:
92+
scheduling_priority: An integer representing scheduling priority.
93+
94+
Returns: A dict which includes jobArn, jobName and jobId.
95+
96+
"""
97+
return _update_service_job(self.job_arn, scheduling_priority)
98+
8899
def describe(self) -> Dict:
89100
"""Describe Batch job.
90101

0 commit comments

Comments
 (0)