diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py b/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py index ea20ab362b..6b03995ffb 100644 --- a/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py +++ b/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py @@ -46,7 +46,7 @@ def _submit_service_job( timeout: Set with value of timeout if specified, else default to 1 day. share_identifier: value of shareIdentifier if specified. tags: A dict of string to string representing Batch tags. - quota_share_name: Quota Share name for the Batch job. + quota_share_name: value of quotaShareName if specified. preemption_config: Preemption configuration. Returns: diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py b/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py index d3b5f78940..1d5f66eacd 100644 --- a/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py +++ b/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py @@ -71,6 +71,11 @@ def submit( raise ValueError( "TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB" ) + + if share_identifier != None and quota_share_name != None: + raise ValueError( + "Either share_identifier or quota_share_name can be specified, but not both" + ) training_payload = training_job._create_training_job_args( input_data_config=inputs, boto3=True ) @@ -108,6 +113,7 @@ def map( share_identifier: Optional[str] = None, timeout: Optional[Dict] = None, tags: Optional[Dict] = None, + quota_share_name: Optional[str] = None, ) -> List[TrainingQueuedJob]: """Submit queued jobs to the provided estimator and return a list of TrainingQueuedJob objects. @@ -120,6 +126,7 @@ def map( share_identifier: Share identifier for the Batch jobs. timeout: Timeout configuration for the Batch jobs. tags: Tags apply to Batch job. These tags are for Batch job only. + quota_share_name: Quota share name for the Batch jobs. Returns: a list of TrainingQueuedJob objects with each Batch job ARN and job name. @@ -144,6 +151,7 @@ def map( share_identifier, timeout, tags, + quota_share_name, ) queued_batch_job_list.append(queued_batch_job) @@ -171,7 +179,7 @@ def list_jobs( for job_result in job_result_dict.get("jobSummaryList", []): if "jobArn" in job_result and "jobName" in job_result: jobs_to_return.append( - TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None)) + TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None), job_result.get("quotaShareName", None)) ) else: logging.warning("Missing JobArn or JobName in Batch ListJobs API") @@ -182,19 +190,27 @@ def list_jobs_by_share( self, status: Optional[str] = JOB_STATUS_RUNNING, share_identifier: Optional[str] = None, + quota_share_name: Optional[str] = None, ) -> List[TrainingQueuedJob]: """List Batch jobs according to status and share. Args: status: Batch job status. share_identifier: Batch fairshare share identifier. + quota_share_name: Batch quota management share name. Returns: A list of QueuedJob. """ filters = None + if share_identifier != None and quota_share_name != None: + raise ValueError( + "Either share_identifier or quota_share_name can be specified, but not both" + ) if share_identifier: filters = [{"name": "SHARE_IDENTIFIER", "values": [share_identifier]}] + elif quota_share_name: + filters = [{"name": "QUOTA_SHARE_NAME", "values": [quota_share_name]}] jobs_to_return = [] next_token = None @@ -202,7 +218,7 @@ def list_jobs_by_share( for job_result in job_result_dict.get("jobSummaryList", []): if "jobArn" in job_result and "jobName" in job_result: jobs_to_return.append( - TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None)) + TrainingQueuedJob(job_result["jobArn"], job_result["jobName"], job_result.get("shareIdentifier", None), job_result.get("quotaShareName", None)) ) else: logging.warning("Missing JobArn or JobName in Batch ListJobs API") diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py b/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py index 0b8d73eebc..df7816823d 100644 --- a/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py +++ b/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py @@ -45,10 +45,11 @@ class TrainingQueuedJob: With this class, customers are able to attach the latest training job to a ModelTrainer. """ - def __init__(self, job_arn: str, job_name: str, share_identifier: Optional[str] = None): + def __init__(self, job_arn: str, job_name: str, share_identifier: Optional[str] = None, quota_share_name: Optional[str] = None): self.job_arn = job_arn self.job_name = job_name self.share_identifier = share_identifier + self.quota_share_name = quota_share_name self._no_training_job_status = {"SUBMITTED", "PENDING", "RUNNABLE"} def get_model_trainer(self) -> ModelTrainer: diff --git a/sagemaker-train/tests/unit/train/aws_batch/conftest.py b/sagemaker-train/tests/unit/train/aws_batch/conftest.py index 02d883f6e4..58852dd50c 100644 --- a/sagemaker-train/tests/unit/train/aws_batch/conftest.py +++ b/sagemaker-train/tests/unit/train/aws_batch/conftest.py @@ -43,6 +43,7 @@ # Batch configuration SCHEDULING_PRIORITY = 1 SHARE_IDENTIFIER = "test-share-id" +QUOTA_SHARE_NAME = "test-quota-share" ATTEMPT_DURATION_IN_SECONDS = 86400 REASON = "Test termination reason" NEXT_TOKEN = "test-next-token" @@ -159,6 +160,24 @@ "nextToken": None, } +LIST_SERVICE_JOB_BY_QUOTA_SHARE_RESP_WITH_JOBS = { + "jobSummaryList": [ + { + "jobName": JOB_NAME, + "jobArn": JOB_ARN, + "jobId": JOB_ID, + "quotaShareName": QUOTA_SHARE_NAME, + }, + { + "jobName": "another-job", + "jobArn": "arn:aws:batch:us-west-2:123456789012:job/another-id", + "jobId": "another-id", + "quotaShareName": "another-quota-share", + }, + ], + "nextToken": None, +} + LIST_SERVICE_JOB_RESP_WITH_NEXT_TOKEN = { "jobSummaryList": [ {"jobName": JOB_NAME, "jobArn": JOB_ARN, "jobId": JOB_ID}, diff --git a/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py b/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py index 9e90edfb39..322e4d29d5 100644 --- a/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py +++ b/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py @@ -37,6 +37,7 @@ TIMEOUT_CONFIG, SCHEDULING_PRIORITY, SHARE_IDENTIFIER, + QUOTA_SHARE_NAME, SUBMIT_SERVICE_JOB_RESP, DESCRIBE_SERVICE_JOB_RESP_RUNNING, LIST_SERVICE_JOB_RESP_EMPTY, @@ -97,6 +98,25 @@ def test_submit_service_job_with_all_params(self, mock_get_client): assert call_kwargs["shareIdentifier"] == SHARE_IDENTIFIER assert call_kwargs["timeoutConfig"] == TIMEOUT_CONFIG + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_submit_service_job_with_quota_share_name(self, mock_get_client): + """Test submit_service_job with quota_share_name parameter""" + mock_client = Mock() + mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + mock_get_client.return_value = mock_client + + result = _submit_service_job( + TRAINING_JOB_PAYLOAD, + JOB_NAME, + JOB_QUEUE, + quota_share_name=QUOTA_SHARE_NAME, + ) + + assert result["jobArn"] == SUBMIT_SERVICE_JOB_RESP["jobArn"] + call_kwargs = mock_client.submit_service_job.call_args[1] + assert call_kwargs["quotaShareName"] == QUOTA_SHARE_NAME + assert "shareIdentifier" not in call_kwargs + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") def test_submit_service_job_with_tags(self, mock_get_client): """Test submit_service_job merges batch and training tags""" diff --git a/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py b/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py index 3c1084ff58..fe92fe4123 100644 --- a/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py +++ b/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py @@ -24,12 +24,14 @@ JOB_ID, SCHEDULING_PRIORITY, SHARE_IDENTIFIER, + QUOTA_SHARE_NAME, TIMEOUT_CONFIG, BATCH_TAGS, DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, SUBMIT_SERVICE_JOB_RESP, LIST_SERVICE_JOB_RESP_WITH_JOBS, LIST_SERVICE_JOB_BY_SHARE_RESP_WITH_JOBS, + LIST_SERVICE_JOB_BY_QUOTA_SHARE_RESP_WITH_JOBS, LIST_SERVICE_JOB_RESP_EMPTY, TRAINING_JOB_PAYLOAD, QUOTA_SHARE_NAME, @@ -68,6 +70,7 @@ def test_submit_model_trainer(self, mock_submit_service_job): SHARE_IDENTIFIER, TIMEOUT_CONFIG, BATCH_TAGS, + None, ) assert queued_job.job_name == JOB_NAME @@ -93,6 +96,7 @@ def test_submit_with_default_timeout(self, mock_submit_service_job): SHARE_IDENTIFIER, None, # No timeout BATCH_TAGS, + None, ) call_kwargs = mock_submit_service_job.call_args[0] @@ -118,6 +122,7 @@ def test_submit_with_generated_job_name(self, mock_submit_service_job): SHARE_IDENTIFIER, TIMEOUT_CONFIG, BATCH_TAGS, + None, ) call_kwargs = mock_submit_service_job.call_args[0] @@ -138,6 +143,7 @@ def test_submit_invalid_training_job_type(self): SHARE_IDENTIFIER, TIMEOUT_CONFIG, BATCH_TAGS, + None, ) def test_submit_invalid_training_mode(self): @@ -157,6 +163,7 @@ def test_submit_invalid_training_mode(self): SHARE_IDENTIFIER, TIMEOUT_CONFIG, BATCH_TAGS, + None, ) @patch("sagemaker.train.aws_batch.training_queue._submit_service_job") @@ -180,6 +187,57 @@ def test_submit_missing_job_arn_in_response(self, mock_submit_service_job): SHARE_IDENTIFIER, TIMEOUT_CONFIG, BATCH_TAGS, + None, + ) + + + @patch("sagemaker.train.aws_batch.training_queue._submit_service_job") + def test_submit_with_quota_share_name(self, mock_submit_service_job): + """Test submit with quota_share_name""" + mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + trainer._create_training_job_args.return_value = TRAINING_JOB_PAYLOAD + + queue = TrainingQueue(JOB_QUEUE) + queued_job = queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + None, + TIMEOUT_CONFIG, + BATCH_TAGS, + QUOTA_SHARE_NAME, + ) + + assert queued_job.job_name == JOB_NAME + assert queued_job.job_arn == JOB_ARN + mock_submit_service_job.assert_called_once() + call_args = mock_submit_service_job.call_args[0] + assert call_args[8] == QUOTA_SHARE_NAME + + def test_submit_both_share_identifier_and_quota_share_name_raises(self): + """Test submit raises error when both share_identifier and quota_share_name are specified""" + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + trainer._create_training_job_args.return_value = TRAINING_JOB_PAYLOAD + + queue = TrainingQueue(JOB_QUEUE) + + with pytest.raises(ValueError, match="Either share_identifier or quota_share_name"): + queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + QUOTA_SHARE_NAME, ) @@ -206,6 +264,7 @@ def test_map_multiple_inputs(self, mock_submit_service_job): SHARE_IDENTIFIER, TIMEOUT_CONFIG, BATCH_TAGS, + None, ) assert len(queued_jobs) == 3 @@ -232,6 +291,7 @@ def test_map_with_job_names(self, mock_submit_service_job): SHARE_IDENTIFIER, TIMEOUT_CONFIG, BATCH_TAGS, + None, ) assert len(queued_jobs) == 2 @@ -255,8 +315,35 @@ def test_map_mismatched_job_names_length(self): SHARE_IDENTIFIER, TIMEOUT_CONFIG, BATCH_TAGS, + None, ) + @patch("sagemaker.train.aws_batch.training_queue._submit_service_job") + def test_map_with_quota_share_name(self, mock_submit_service_job): + """Test map passes quota_share_name through to submit""" + mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + trainer._create_training_job_args.return_value = TRAINING_JOB_PAYLOAD + + queue = TrainingQueue(JOB_QUEUE) + inputs = ["input1", "input2"] + queued_jobs = queue.map( + trainer, + inputs, + None, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + None, + TIMEOUT_CONFIG, + BATCH_TAGS, + QUOTA_SHARE_NAME, + ) + + assert len(queued_jobs) == 2 + for call_args in mock_submit_service_job.call_args_list: + assert call_args[0][8] == QUOTA_SHARE_NAME class TestTrainingQueueList: """Tests for TrainingQueue.list_jobs method""" @@ -305,6 +392,20 @@ def test_list_jobs_empty(self, mock_list_service_job): assert len(jobs) == 0 + @patch("sagemaker.train.aws_batch.training_queue._list_service_job") + def test_list_jobs_extracts_quota_share_name(self, mock_list_service_job): + """Test list_jobs extracts quotaShareName from response into TrainingQueuedJob""" + mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_BY_QUOTA_SHARE_RESP_WITH_JOBS]) + + queue = TrainingQueue(JOB_QUEUE) + jobs = queue.list_jobs() + + assert len(jobs) == 2 + assert jobs[0].quota_share_name == QUOTA_SHARE_NAME + assert jobs[0].share_identifier is None + assert jobs[1].quota_share_name == "another-quota-share" + assert jobs[1].share_identifier is None + class TestTrainingQueueListByShare: """Tests for TrainingQueue.list_jobs_by_share method""" @@ -345,6 +446,38 @@ def test_list_jobs_by_share_with_share_filter(self, mock_list_service_job): SHARE_IDENTIFIER ], "Filter values should contain the share identifier" + @patch("sagemaker.train.aws_batch.training_queue._list_service_job") + def test_list_jobs_by_share_with_quota_share_name_filter(self, mock_list_service_job): + """Test list_jobs_by_share with quota_share_name filter""" + mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_BY_QUOTA_SHARE_RESP_WITH_JOBS]) + + queue = TrainingQueue(JOB_QUEUE) + jobs = queue.list_jobs_by_share(quota_share_name=QUOTA_SHARE_NAME) + + assert len(jobs) == 2 + assert jobs[0].quota_share_name == QUOTA_SHARE_NAME + + mock_list_service_job.assert_called_once() + + call_args = mock_list_service_job.call_args[0] + filters = call_args[2] if len(call_args) > 2 else None + + assert filters is not None, "Filters should be passed to list_service_job" + assert filters[0]["name"] == "QUOTA_SHARE_NAME", "QUOTA_SHARE_NAME filter should be present" + assert filters[0]["values"] == [ + QUOTA_SHARE_NAME + ], "Filter values should contain the quota share name" + + def test_list_jobs_by_share_both_share_and_quota_raises(self): + """Test list_jobs_by_share raises error when both share_identifier and quota_share_name are specified""" + queue = TrainingQueue(JOB_QUEUE) + + with pytest.raises(ValueError, match="Either share_identifier or quota_share_name"): + queue.list_jobs_by_share( + share_identifier=SHARE_IDENTIFIER, + quota_share_name=QUOTA_SHARE_NAME, + ) + @patch("sagemaker.train.aws_batch.training_queue._list_service_job") def test_list_jobs_by_share_empty(self, mock_list_service_job): """Test list_jobs_by_share returns empty list""" @@ -400,7 +533,7 @@ def test_submit_with_quota_share_name_and_preemption_config(self, mock_submit_se JOB_NAME, DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, SCHEDULING_PRIORITY, - SHARE_IDENTIFIER, + None, TIMEOUT_CONFIG, BATCH_TAGS, quota_share_name=QUOTA_SHARE_NAME, @@ -416,7 +549,7 @@ def test_submit_with_quota_share_name_and_preemption_config(self, mock_submit_se DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, SCHEDULING_PRIORITY, TIMEOUT_CONFIG, - SHARE_IDENTIFIER, + None, BATCH_TAGS, QUOTA_SHARE_NAME, PREEMPTION_CONFIG, diff --git a/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py b/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py index 2ba61f4471..532b436c7f 100644 --- a/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py +++ b/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py @@ -47,6 +47,14 @@ def test_training_queued_job_init(self): assert queued_job.job_arn == JOB_ARN assert queued_job.job_name == JOB_NAME + def test_training_queued_job_init_with_quota_share_name(self): + """Test TrainingQueuedJob initialization with quota_share_name""" + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME, quota_share_name="test-quota") + assert queued_job.job_arn == JOB_ARN + assert queued_job.job_name == JOB_NAME + assert queued_job.share_identifier is None + assert queued_job.quota_share_name == "test-quota" + class TestTrainingQueuedJobDescribe: """Tests for TrainingQueuedJob.describe method""" diff --git a/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py b/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py index c77b222023..4c0e8af9f5 100644 --- a/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py +++ b/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py @@ -816,8 +816,10 @@ def list_jobs_by_quota_share(training_queue: TrainingQueue, quota_share_names: L Args: training_queue (TrainingQueue): The TrainingQueue to query for jobs. """ - all_jobs = [job for status in statuses for job in training_queue.list_jobs(status=status.value)] - + all_jobs = [ + job for status in statuses for qs_name in quota_share_names + for job in training_queue.list_jobs_by_share(quota_share_name=qs_name, status=status.value) + ] jobs_by_qs = {qs_name: [] for qs_name in quota_share_names} for job in all_jobs: job_detail = job.describe()