|
25 | 25 | DEFAULT_MAX_RUNTIME_IN_SECONDS, |
26 | 26 | ) |
27 | 27 | from sagemaker.train.configs import Compute, StoppingCondition |
| 28 | +from sagemaker.core.shapes import InstanceGroup |
28 | 29 |
|
29 | 30 |
|
30 | 31 | class TestDefaultConstants: |
@@ -435,3 +436,172 @@ def test_uses_default_volume_size_when_not_in_document( |
435 | 436 | ) |
436 | 437 |
|
437 | 438 | assert result.volume_size_in_gb == DEFAULT_VOLUME_SIZE |
| 439 | + |
| 440 | + |
| 441 | + def test_does_not_set_instance_type_when_instance_groups_configured(self): |
| 442 | + """Test instance_type is not overwritten when instance_groups are set.""" |
| 443 | + compute = Compute( |
| 444 | + instance_groups=[InstanceGroup(instance_type="ml.p3.2xlarge", instance_count=1, instance_group_name="group1")], |
| 445 | + instance_type=None, |
| 446 | + instance_count=None, |
| 447 | + volume_size_in_gb=30, |
| 448 | + ) |
| 449 | + result = TrainDefaults.get_compute(compute=compute) |
| 450 | + assert result.instance_type is None |
| 451 | + |
| 452 | + def test_does_not_set_instance_count_when_instance_groups_configured(self): |
| 453 | + """Test instance_count is not overwritten when instance_groups are set.""" |
| 454 | + compute = Compute( |
| 455 | + instance_groups=[InstanceGroup(instance_type="ml.p3.2xlarge", instance_count=1, instance_group_name="group1")], |
| 456 | + instance_type=None, |
| 457 | + instance_count=None, |
| 458 | + volume_size_in_gb=30, |
| 459 | + ) |
| 460 | + result = TrainDefaults.get_compute(compute=compute) |
| 461 | + assert result.instance_count is None |
| 462 | + |
| 463 | + def test_sets_volume_size_when_instance_groups_configured(self): |
| 464 | + """Test volume_size_in_gb is still set when instance_groups are configured.""" |
| 465 | + compute = Compute( |
| 466 | + instance_groups=[InstanceGroup(instance_type="ml.p3.2xlarge", instance_count=1, instance_group_name="group1")], |
| 467 | + instance_type=None, |
| 468 | + instance_count=None, |
| 469 | + volume_size_in_gb=None, |
| 470 | + ) |
| 471 | + result = TrainDefaults.get_compute(compute=compute) |
| 472 | + assert result.volume_size_in_gb == DEFAULT_VOLUME_SIZE |
| 473 | + |
| 474 | + def test_preserves_existing_volume_size_with_instance_groups(self): |
| 475 | + """Test existing volume_size_in_gb is preserved when instance_groups are configured.""" |
| 476 | + compute = Compute( |
| 477 | + instance_groups=[InstanceGroup(instance_type="ml.p3.2xlarge", instance_count=1, instance_group_name="group1")], |
| 478 | + instance_type=None, |
| 479 | + instance_count=None, |
| 480 | + volume_size_in_gb=100, |
| 481 | + ) |
| 482 | + result = TrainDefaults.get_compute(compute=compute) |
| 483 | + assert result.volume_size_in_gb == 100 |
| 484 | + |
| 485 | + |
| 486 | +class TestJumpStartTrainDefaultsGetComputeHeterogeneousCluster: |
| 487 | + """Test JumpStartTrainDefaults.get_compute with heterogeneous cluster (instance_groups).""" |
| 488 | + |
| 489 | + @patch("sagemaker.train.defaults.get_hub_content_and_document") |
| 490 | + @patch("sagemaker.train.defaults.TrainDefaults.get_sagemaker_session") |
| 491 | + def test_does_not_set_instance_type_when_instance_groups_configured( |
| 492 | + self, mock_get_session, mock_get_hub_content |
| 493 | + ): |
| 494 | + """Test instance_type is not overwritten when instance_groups are set.""" |
| 495 | + mock_session = MagicMock() |
| 496 | + mock_get_session.return_value = mock_session |
| 497 | + |
| 498 | + mock_document = MagicMock() |
| 499 | + mock_document.DefaultTrainingInstanceType = "ml.p3.2xlarge" |
| 500 | + mock_document.TrainingVolumeSize = 100 |
| 501 | + mock_get_hub_content.return_value = (None, mock_document) |
| 502 | + |
| 503 | + mock_config = MagicMock() |
| 504 | + mock_config.training_config_name = None |
| 505 | + |
| 506 | + compute = Compute( |
| 507 | + instance_groups=[InstanceGroup(instance_type="ml.p3.2xlarge", instance_count=1, instance_group_name="group1")], |
| 508 | + instance_type=None, |
| 509 | + instance_count=None, |
| 510 | + volume_size_in_gb=30, |
| 511 | + ) |
| 512 | + result = JumpStartTrainDefaults.get_compute( |
| 513 | + jumpstart_config=mock_config, |
| 514 | + compute=compute, |
| 515 | + sagemaker_session=mock_session, |
| 516 | + ) |
| 517 | + assert result.instance_type is None |
| 518 | + |
| 519 | + @patch("sagemaker.train.defaults.get_hub_content_and_document") |
| 520 | + @patch("sagemaker.train.defaults.TrainDefaults.get_sagemaker_session") |
| 521 | + def test_does_not_set_instance_count_when_instance_groups_configured( |
| 522 | + self, mock_get_session, mock_get_hub_content |
| 523 | + ): |
| 524 | + """Test instance_count is not overwritten when instance_groups are set.""" |
| 525 | + mock_session = MagicMock() |
| 526 | + mock_get_session.return_value = mock_session |
| 527 | + |
| 528 | + mock_document = MagicMock() |
| 529 | + mock_document.DefaultTrainingInstanceType = "ml.p3.2xlarge" |
| 530 | + mock_document.TrainingVolumeSize = 100 |
| 531 | + mock_get_hub_content.return_value = (None, mock_document) |
| 532 | + |
| 533 | + mock_config = MagicMock() |
| 534 | + mock_config.training_config_name = None |
| 535 | + |
| 536 | + compute = Compute( |
| 537 | + instance_groups=[InstanceGroup(instance_type="ml.p3.2xlarge", instance_count=1, instance_group_name="group1")], |
| 538 | + instance_type=None, |
| 539 | + instance_count=None, |
| 540 | + volume_size_in_gb=30, |
| 541 | + ) |
| 542 | + result = JumpStartTrainDefaults.get_compute( |
| 543 | + jumpstart_config=mock_config, |
| 544 | + compute=compute, |
| 545 | + sagemaker_session=mock_session, |
| 546 | + ) |
| 547 | + assert result.instance_count is None |
| 548 | + |
| 549 | + @patch("sagemaker.train.defaults.get_hub_content_and_document") |
| 550 | + @patch("sagemaker.train.defaults.TrainDefaults.get_sagemaker_session") |
| 551 | + def test_sets_volume_size_from_document_when_instance_groups_configured( |
| 552 | + self, mock_get_session, mock_get_hub_content |
| 553 | + ): |
| 554 | + """Test volume_size_in_gb is set from document even when instance_groups are configured.""" |
| 555 | + mock_session = MagicMock() |
| 556 | + mock_get_session.return_value = mock_session |
| 557 | + |
| 558 | + mock_document = MagicMock() |
| 559 | + mock_document.DefaultTrainingInstanceType = "ml.p3.2xlarge" |
| 560 | + mock_document.TrainingVolumeSize = 100 |
| 561 | + mock_get_hub_content.return_value = (None, mock_document) |
| 562 | + |
| 563 | + mock_config = MagicMock() |
| 564 | + mock_config.training_config_name = None |
| 565 | + |
| 566 | + compute = Compute( |
| 567 | + instance_groups=[InstanceGroup(instance_type="ml.p3.2xlarge", instance_count=1, instance_group_name="group1")], |
| 568 | + instance_type=None, |
| 569 | + instance_count=None, |
| 570 | + volume_size_in_gb=None, |
| 571 | + ) |
| 572 | + result = JumpStartTrainDefaults.get_compute( |
| 573 | + jumpstart_config=mock_config, |
| 574 | + compute=compute, |
| 575 | + sagemaker_session=mock_session, |
| 576 | + ) |
| 577 | + assert result.volume_size_in_gb == 100 |
| 578 | + |
| 579 | + @patch("sagemaker.train.defaults.get_hub_content_and_document") |
| 580 | + @patch("sagemaker.train.defaults.TrainDefaults.get_sagemaker_session") |
| 581 | + def test_sets_default_volume_size_when_instance_groups_and_no_document_volume( |
| 582 | + self, mock_get_session, mock_get_hub_content |
| 583 | + ): |
| 584 | + """Test DEFAULT_VOLUME_SIZE is used when instance_groups set and document has no volume.""" |
| 585 | + mock_session = MagicMock() |
| 586 | + mock_get_session.return_value = mock_session |
| 587 | + |
| 588 | + mock_document = MagicMock() |
| 589 | + mock_document.DefaultTrainingInstanceType = "ml.p3.2xlarge" |
| 590 | + mock_document.TrainingVolumeSize = None |
| 591 | + mock_get_hub_content.return_value = (None, mock_document) |
| 592 | + |
| 593 | + mock_config = MagicMock() |
| 594 | + mock_config.training_config_name = None |
| 595 | + |
| 596 | + compute = Compute( |
| 597 | + instance_groups=[InstanceGroup(instance_type="ml.p3.2xlarge", instance_count=1, instance_group_name="group1")], |
| 598 | + instance_type=None, |
| 599 | + instance_count=None, |
| 600 | + volume_size_in_gb=None, |
| 601 | + ) |
| 602 | + result = JumpStartTrainDefaults.get_compute( |
| 603 | + jumpstart_config=mock_config, |
| 604 | + compute=compute, |
| 605 | + sagemaker_session=mock_session, |
| 606 | + ) |
| 607 | + assert result.volume_size_in_gb == DEFAULT_VOLUME_SIZE |
0 commit comments