Skip to content

Commit 2396310

Browse files
amplehh and houtampl authored
Feature: Add support for listing Batch jobs by share identifier (#5585)
Co-authored-by: houtampl <houtampl@amazon.com>
1 parent d3770cc commit 2396310

File tree

4 files changed

+106
-5
lines changed

4 files changed

+106
-5
lines changed

sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,38 @@ def list_jobs(
165165
for job_result in job_result_dict.get("jobSummaryList", []):
166166
if "jobArn" in job_result and "jobName" in job_result:
167167
jobs_to_return.append(
168-
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"])
168+
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None))
169+
)
170+
else:
171+
logging.warning("Missing JobArn or JobName in Batch ListJobs API")
172+
continue
173+
return jobs_to_return
174+
175+
def list_jobs_by_share(
176+
self,
177+
status: Optional[str] = JOB_STATUS_RUNNING,
178+
share_identifier: Optional[str] = None,
179+
) -> List[TrainingQueuedJob]:
180+
"""List Batch jobs according to status and share.
181+
182+
Args:
183+
status: Batch job status.
184+
share_identifier: Batch fairshare share identifier.
185+
186+
Returns: A list of TrainingQueuedJob.
187+
188+
"""
189+
filters = None
190+
if share_identifier:
191+
filters = [{"name": "SHARE_IDENTIFIER", "values": [share_identifier]}]
192+
193+
jobs_to_return = []
194+
next_token = None
195+
for job_result_dict in _list_service_job(self.queue_name, status, filters, next_token):
196+
for job_result in job_result_dict.get("jobSummaryList", []):
197+
if "jobArn" in job_result and "jobName" in job_result:
198+
jobs_to_return.append(
199+
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None))
169200
)
170201
else:
171202
logging.warning("Missing JobArn or JobName in Batch ListJobs API")

sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@ class TrainingQueuedJob:
4545
With this class, customers are able to attach the latest training job to a ModelTrainer.
4646
"""
4747

48-
def __init__(self, job_arn: str, job_name: str):
48+
def __init__(self, job_arn: str, job_name: str, share_identifier: Optional[str] = None):
4949
self.job_arn = job_arn
5050
self.job_name = job_name
51+
self.share_identifier = share_identifier
5152
self._no_training_job_status = {"SUBMITTED", "PENDING", "RUNNABLE"}
5253

5354
def get_model_trainer(self) -> ModelTrainer:

sagemaker-train/tests/unit/train/aws_batch/conftest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,24 @@
141141
"nextToken": None,
142142
}
143143

144+
LIST_SERVICE_JOB_BY_SHARE_RESP_WITH_JOBS = {
145+
"jobSummaryList": [
146+
{
147+
"jobName": JOB_NAME,
148+
"jobArn": JOB_ARN,
149+
"jobId": JOB_ID,
150+
"shareIdentifier": SHARE_IDENTIFIER,
151+
},
152+
{
153+
"jobName": "another-job",
154+
"jobArn": "arn:aws:batch:us-west-2:123456789012:job/another-id",
155+
"jobId": "another-id",
156+
"shareIdentifier": "another-share-identifier",
157+
},
158+
],
159+
"nextToken": None,
160+
}
161+
144162
LIST_SERVICE_JOB_RESP_WITH_NEXT_TOKEN = {
145163
"jobSummaryList": [
146164
{"jobName": JOB_NAME, "jobArn": JOB_ARN, "jobId": JOB_ID},

sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG,
3030
SUBMIT_SERVICE_JOB_RESP,
3131
LIST_SERVICE_JOB_RESP_WITH_JOBS,
32+
LIST_SERVICE_JOB_BY_SHARE_RESP_WITH_JOBS,
3233
LIST_SERVICE_JOB_RESP_EMPTY,
3334
TRAINING_JOB_PAYLOAD,
3435
)
@@ -279,14 +280,14 @@ def test_list_jobs_with_name_filter(self, mock_list_service_job):
279280

280281
# Verify list_service_job was called
281282
mock_list_service_job.assert_called_once()
282-
283+
283284
# Get the call arguments - list_service_job is called with positional args:
284285
# list_service_job(queue_name, status, filters, next_token)
285286
call_args = mock_list_service_job.call_args[0]
286-
287+
287288
# The 3rd positional argument (index 2) is filters
288289
filters = call_args[2] if len(call_args) > 2 else None
289-
290+
290291
# Verify filters contain the job name
291292
assert filters is not None, "Filters should be passed to list_service_job"
292293
assert filters[0]["name"] == "JOB_NAME", "JOB_NAME filter should be present"
@@ -303,6 +304,56 @@ def test_list_jobs_empty(self, mock_list_service_job):
303304
assert len(jobs) == 0
304305

305306

307+
class TestTrainingQueueListByShare:
308+
"""Tests for TrainingQueue.list_jobs_by_share method"""
309+
310+
@patch("sagemaker.train.aws_batch.training_queue._list_service_job")
311+
def test_list_jobs_by_share_default(self, mock_list_service_job):
312+
"""Test list_jobs_by_share with default parameters"""
313+
mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_BY_SHARE_RESP_WITH_JOBS])
314+
315+
queue = TrainingQueue(JOB_QUEUE)
316+
jobs = queue.list_jobs_by_share()
317+
318+
assert len(jobs) == 2
319+
assert jobs[0].share_identifier == SHARE_IDENTIFIER
320+
321+
@patch("sagemaker.train.aws_batch.training_queue._list_service_job")
322+
def test_list_jobs_by_share_with_share_filter(self, mock_list_service_job):
323+
"""Test list_jobs_by_share with share identifier filter"""
324+
mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_BY_SHARE_RESP_WITH_JOBS])
325+
326+
queue = TrainingQueue(JOB_QUEUE)
327+
jobs = queue.list_jobs_by_share(share_identifier=SHARE_IDENTIFIER)
328+
329+
# Verify list_service_job was called
330+
mock_list_service_job.assert_called_once()
331+
332+
# Get the call arguments - list_service_job is called with positional args:
333+
# list_service_job(queue_name, status, filters, next_token)
334+
call_args = mock_list_service_job.call_args[0]
335+
336+
# The 3rd positional argument (index 2) is filters
337+
filters = call_args[2] if len(call_args) > 2 else None
338+
339+
# Verify filters contain the share identifier
340+
assert filters is not None, "Filters should be passed to list_service_job"
341+
assert filters[0]["name"] == "SHARE_IDENTIFIER", "SHARE_IDENTIFIER filter should be present"
342+
assert filters[0]["values"] == [
343+
SHARE_IDENTIFIER
344+
], "Filter values should contain the share identifier"
345+
346+
@patch("sagemaker.train.aws_batch.training_queue._list_service_job")
347+
def test_list_jobs_by_share_empty(self, mock_list_service_job):
348+
"""Test list_jobs_by_share returns empty list"""
349+
mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_EMPTY])
350+
351+
queue = TrainingQueue(JOB_QUEUE)
352+
jobs = queue.list_jobs_by_share()
353+
354+
assert len(jobs) == 0
355+
356+
306357
class TestTrainingQueueGet:
307358
"""Tests for TrainingQueue.get_job method"""
308359

0 commit comments

Comments
 (0)