Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _submit_service_job(
timeout: Set with value of timeout if specified, else default to 1 day.
share_identifier: value of shareIdentifier if specified.
tags: A dict of string to string representing Batch tags.
quota_share_name: Quota Share name for the Batch job.
quota_share_name: value of quotaShareName if specified.
preemption_config: Preemption configuration.

Returns:
Expand Down
20 changes: 18 additions & 2 deletions sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def submit(
raise ValueError(
"TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB"
)

if share_identifier != None and quota_share_name != None:
raise ValueError(
"Either share_identifier or quota_share_name can be specified, but not both"
)
training_payload = training_job._create_training_job_args(
input_data_config=inputs, boto3=True
)
Expand Down Expand Up @@ -108,6 +113,7 @@ def map(
share_identifier: Optional[str] = None,
timeout: Optional[Dict] = None,
tags: Optional[Dict] = None,
quota_share_name: Optional[str] = None,
) -> List[TrainingQueuedJob]:
"""Submit queued jobs to the provided estimator and return a list of TrainingQueuedJob objects.

Expand All @@ -120,6 +126,7 @@ def map(
share_identifier: Share identifier for the Batch jobs.
timeout: Timeout configuration for the Batch jobs.
tags: Tags apply to Batch job. These tags are for Batch job only.
quota_share_name: Quota share name for the Batch jobs.

Returns: a list of TrainingQueuedJob objects with each Batch job ARN and job name.

Expand All @@ -144,6 +151,7 @@ def map(
share_identifier,
timeout,
tags,
quota_share_name,
)
queued_batch_job_list.append(queued_batch_job)

Expand Down Expand Up @@ -171,7 +179,7 @@ def list_jobs(
for job_result in job_result_dict.get("jobSummaryList", []):
if "jobArn" in job_result and "jobName" in job_result:
jobs_to_return.append(
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None))
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None), job_result.get("quotaShareName", None))
)
else:
logging.warning("Missing JobArn or JobName in Batch ListJobs API")
Expand All @@ -182,27 +190,35 @@ def list_jobs_by_share(
self,
status: Optional[str] = JOB_STATUS_RUNNING,
share_identifier: Optional[str] = None,
quota_share_name: Optional[str] = None,
) -> List[TrainingQueuedJob]:
"""List Batch jobs according to status and share.

Args:
status: Batch job status.
share_identifier: Batch fairshare share identifier.
quota_share_name: Batch quota management share name.

Returns: A list of QueuedJob.

"""
filters = None
if share_identifier != None and quota_share_name != None:
raise ValueError(
"Either share_identifier or quota_share_name can be specified, but not both"
)
if share_identifier:
filters = [{"name": "SHARE_IDENTIFIER", "values": [share_identifier]}]
elif quota_share_name:
filters = [{"name": "QUOTA_SHARE_NAME", "values": [quota_share_name]}]

jobs_to_return = []
next_token = None
for job_result_dict in _list_service_job(self.queue_name, status, filters, next_token):
for job_result in job_result_dict.get("jobSummaryList", []):
if "jobArn" in job_result and "jobName" in job_result:
jobs_to_return.append(
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None))
TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None), job_result.get("quotaShareName", None))
)
else:
logging.warning("Missing JobArn or JobName in Batch ListJobs API")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ class TrainingQueuedJob:
With this class, customers are able to attach the latest training job to a ModelTrainer.
"""

def __init__(self, job_arn: str, job_name: str, share_identifier: Optional[str] = None):
def __init__(self, job_arn: str, job_name: str, share_identifier: Optional[str] = None, quota_share_name: Optional[str] = None):
self.job_arn = job_arn
self.job_name = job_name
self.share_identifier = share_identifier
self.quota_share_name = quota_share_name
self._no_training_job_status = {"SUBMITTED", "PENDING", "RUNNABLE"}

def get_model_trainer(self) -> ModelTrainer:
Expand Down
19 changes: 19 additions & 0 deletions sagemaker-train/tests/unit/train/aws_batch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
# Batch configuration
SCHEDULING_PRIORITY = 1
SHARE_IDENTIFIER = "test-share-id"
QUOTA_SHARE_NAME = "test-quota-share"
ATTEMPT_DURATION_IN_SECONDS = 86400
REASON = "Test termination reason"
NEXT_TOKEN = "test-next-token"
Expand Down Expand Up @@ -159,6 +160,24 @@
"nextToken": None,
}

LIST_SERVICE_JOB_BY_QUOTA_SHARE_RESP_WITH_JOBS = {
"jobSummaryList": [
{
"jobName": JOB_NAME,
"jobArn": JOB_ARN,
"jobId": JOB_ID,
"quotaShareName": QUOTA_SHARE_NAME,
},
{
"jobName": "another-job",
"jobArn": "arn:aws:batch:us-west-2:123456789012:job/another-id",
"jobId": "another-id",
"quotaShareName": "another-quota-share",
},
],
"nextToken": None,
}

LIST_SERVICE_JOB_RESP_WITH_NEXT_TOKEN = {
"jobSummaryList": [
{"jobName": JOB_NAME, "jobArn": JOB_ARN, "jobId": JOB_ID},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
TIMEOUT_CONFIG,
SCHEDULING_PRIORITY,
SHARE_IDENTIFIER,
QUOTA_SHARE_NAME,
SUBMIT_SERVICE_JOB_RESP,
DESCRIBE_SERVICE_JOB_RESP_RUNNING,
LIST_SERVICE_JOB_RESP_EMPTY,
Expand Down Expand Up @@ -97,6 +98,25 @@ def test_submit_service_job_with_all_params(self, mock_get_client):
assert call_kwargs["shareIdentifier"] == SHARE_IDENTIFIER
assert call_kwargs["timeoutConfig"] == TIMEOUT_CONFIG

@patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client")
def test_submit_service_job_with_quota_share_name(self, mock_get_client):
"""Test submit_service_job with quota_share_name parameter"""
mock_client = Mock()
mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP
mock_get_client.return_value = mock_client

result = _submit_service_job(
TRAINING_JOB_PAYLOAD,
JOB_NAME,
JOB_QUEUE,
quota_share_name=QUOTA_SHARE_NAME,
)

assert result["jobArn"] == SUBMIT_SERVICE_JOB_RESP["jobArn"]
call_kwargs = mock_client.submit_service_job.call_args[1]
assert call_kwargs["quotaShareName"] == QUOTA_SHARE_NAME
assert "shareIdentifier" not in call_kwargs

@patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client")
def test_submit_service_job_with_tags(self, mock_get_client):
"""Test submit_service_job merges batch and training tags"""
Expand Down
Loading
Loading