Skip to content

Commit 10e3310

Browse files
amplehh and houtampl authored
feature: extend list_jobs_by_share for quota_share_name (#5669)
Co-authored-by: houtampl <houtampl@amazon.com>
1 parent a40c856 commit 10e3310

File tree

8 files changed

+207
-8
lines changed

8 files changed

+207
-8
lines changed

sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def _submit_service_job(
4646
timeout: Set with value of timeout if specified, else default to 1 day.
4747
share_identifier: value of shareIdentifier if specified.
4848
tags: A dict of string to string representing Batch tags.
49-
quota_share_name: Quota Share name for the Batch job.
49+
quota_share_name: value of quotaShareName if specified.
5050
preemption_config: Preemption configuration.
5151
5252
Returns:

sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def submit(
7171
raise ValueError(
7272
"TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB"
7373
)
74+
75+
if share_identifier != None and quota_share_name != None:
76+
raise ValueError(
77+
"Either share_identifier or quota_share_name can be specified, but not both"
78+
)
7479
training_payload = training_job._create_training_job_args(
7580
input_data_config=inputs, boto3=True
7681
)
@@ -108,6 +113,7 @@ def map(
108113
share_identifier: Optional[str] = None,
109114
timeout: Optional[Dict] = None,
110115
tags: Optional[Dict] = None,
116+
quota_share_name: Optional[str] = None,
111117
) -> List[TrainingQueuedJob]:
112118
"""Submit queued jobs to the provided estimator and return a list of TrainingQueuedJob objects.
113119
@@ -120,6 +126,7 @@ def map(
120126
share_identifier: Share identifier for the Batch jobs.
121127
timeout: Timeout configuration for the Batch jobs.
122128
tags: Tags apply to Batch job. These tags are for Batch job only.
129+
quota_share_name: Quota share name for the Batch jobs.
123130
124131
Returns: a list of TrainingQueuedJob objects with each Batch job ARN and job name.
125132
@@ -144,6 +151,7 @@ def map(
144151
share_identifier,
145152
timeout,
146153
tags,
154+
quota_share_name,
147155
)
148156
queued_batch_job_list.append(queued_batch_job)
149157

@@ -171,7 +179,7 @@ def list_jobs(
171179
for job_result in job_result_dict.get("jobSummaryList", []):
172180
if "jobArn" in job_result and "jobName" in job_result:
173181
jobs_to_return.append(
174-
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None))
182+
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None), job_result.get("quotaShareName", None))
175183
)
176184
else:
177185
logging.warning("Missing JobArn or JobName in Batch ListJobs API")
@@ -182,27 +190,35 @@ def list_jobs_by_share(
182190
self,
183191
status: Optional[str] = JOB_STATUS_RUNNING,
184192
share_identifier: Optional[str] = None,
193+
quota_share_name: Optional[str] = None,
185194
) -> List[TrainingQueuedJob]:
186195
"""List Batch jobs according to status and share.
187196
188197
Args:
189198
status: Batch job status.
190199
share_identifier: Batch fairshare share identifier.
200+
quota_share_name: Batch quota management share name.
191201
192202
Returns: A list of QueuedJob.
193203
194204
"""
195205
filters = None
206+
if share_identifier != None and quota_share_name != None:
207+
raise ValueError(
208+
"Either share_identifier or quota_share_name can be specified, but not both"
209+
)
196210
if share_identifier:
197211
filters = [{"name": "SHARE_IDENTIFIER", "values": [share_identifier]}]
212+
elif quota_share_name:
213+
filters = [{"name": "QUOTA_SHARE_NAME", "values": [quota_share_name]}]
198214

199215
jobs_to_return = []
200216
next_token = None
201217
for job_result_dict in _list_service_job(self.queue_name, status, filters, next_token):
202218
for job_result in job_result_dict.get("jobSummaryList", []):
203219
if "jobArn" in job_result and "jobName" in job_result:
204220
jobs_to_return.append(
205-
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None))
221+
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None), job_result.get("quotaShareName", None))
206222
)
207223
else:
208224
logging.warning("Missing JobArn or JobName in Batch ListJobs API")

sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@ class TrainingQueuedJob:
4545
With this class, customers are able to attach the latest training job to a ModelTrainer.
4646
"""
4747

48-
def __init__(self, job_arn: str, job_name: str, share_identifier: Optional[str] = None):
48+
def __init__(self, job_arn: str, job_name: str, share_identifier: Optional[str] = None, quota_share_name: Optional[str] = None):
4949
self.job_arn = job_arn
5050
self.job_name = job_name
5151
self.share_identifier = share_identifier
52+
self.quota_share_name = quota_share_name
5253
self._no_training_job_status = {"SUBMITTED", "PENDING", "RUNNABLE"}
5354

5455
def get_model_trainer(self) -> ModelTrainer:

sagemaker-train/tests/unit/train/aws_batch/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
# Batch configuration
4444
SCHEDULING_PRIORITY = 1
4545
SHARE_IDENTIFIER = "test-share-id"
46+
QUOTA_SHARE_NAME = "test-quota-share"
4647
ATTEMPT_DURATION_IN_SECONDS = 86400
4748
REASON = "Test termination reason"
4849
NEXT_TOKEN = "test-next-token"
@@ -159,6 +160,24 @@
159160
"nextToken": None,
160161
}
161162

163+
LIST_SERVICE_JOB_BY_QUOTA_SHARE_RESP_WITH_JOBS = {
164+
"jobSummaryList": [
165+
{
166+
"jobName": JOB_NAME,
167+
"jobArn": JOB_ARN,
168+
"jobId": JOB_ID,
169+
"quotaShareName": QUOTA_SHARE_NAME,
170+
},
171+
{
172+
"jobName": "another-job",
173+
"jobArn": "arn:aws:batch:us-west-2:123456789012:job/another-id",
174+
"jobId": "another-id",
175+
"quotaShareName": "another-quota-share",
176+
},
177+
],
178+
"nextToken": None,
179+
}
180+
162181
LIST_SERVICE_JOB_RESP_WITH_NEXT_TOKEN = {
163182
"jobSummaryList": [
164183
{"jobName": JOB_NAME, "jobArn": JOB_ARN, "jobId": JOB_ID},

sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
TIMEOUT_CONFIG,
3838
SCHEDULING_PRIORITY,
3939
SHARE_IDENTIFIER,
40+
QUOTA_SHARE_NAME,
4041
SUBMIT_SERVICE_JOB_RESP,
4142
DESCRIBE_SERVICE_JOB_RESP_RUNNING,
4243
LIST_SERVICE_JOB_RESP_EMPTY,
@@ -97,6 +98,25 @@ def test_submit_service_job_with_all_params(self, mock_get_client):
9798
assert call_kwargs["shareIdentifier"] == SHARE_IDENTIFIER
9899
assert call_kwargs["timeoutConfig"] == TIMEOUT_CONFIG
99100

101+
@patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client")
102+
def test_submit_service_job_with_quota_share_name(self, mock_get_client):
103+
"""Test submit_service_job with quota_share_name parameter"""
104+
mock_client = Mock()
105+
mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP
106+
mock_get_client.return_value = mock_client
107+
108+
result = _submit_service_job(
109+
TRAINING_JOB_PAYLOAD,
110+
JOB_NAME,
111+
JOB_QUEUE,
112+
quota_share_name=QUOTA_SHARE_NAME,
113+
)
114+
115+
assert result["jobArn"] == SUBMIT_SERVICE_JOB_RESP["jobArn"]
116+
call_kwargs = mock_client.submit_service_job.call_args[1]
117+
assert call_kwargs["quotaShareName"] == QUOTA_SHARE_NAME
118+
assert "shareIdentifier" not in call_kwargs
119+
100120
@patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client")
101121
def test_submit_service_job_with_tags(self, mock_get_client):
102122
"""Test submit_service_job merges batch and training tags"""

0 commit comments

Comments
 (0)