|
| 1 | +import json |
1 | 2 | import pytest |
2 | 3 | from unittest.mock import Mock, patch, MagicMock |
3 | 4 | from sagemaker.train.common_utils.finetune_utils import ( |
@@ -304,6 +305,7 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get |
304 | 305 | mock_s3_client.get_object.return_value = { |
305 | 306 | "Body": Mock(read=Mock(return_value=b'{"learning_rate": 0.001}')) |
306 | 307 | } |
| 308 | + mock_session.boto_session.client.return_value = mock_s3_client |
307 | 309 |
|
308 | 310 | result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session) |
309 | 311 |
|
@@ -551,3 +553,140 @@ def test__validate_s3_path_exists_with_prefix_not_exists(self, mock_boto_client) |
551 | 553 | mock_s3_client.put_object.assert_called_once_with(Bucket="test-bucket", Key="prefix/", Body=b'') |
552 | 554 |
|
553 | 555 |
|
| 556 | + |
| 557 | + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') |
| 558 | + def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_get_hub_content): |
| 559 | + """When and user is subscribed, datamix HPs are available.""" |
| 560 | + mock_session = Mock() |
| 561 | + mock_session.boto_session.region_name = "us-east-1" |
| 562 | + mock_s3 = Mock() |
| 563 | + mock_sts = Mock() |
| 564 | + mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} |
| 565 | + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts |
| 566 | + |
| 567 | + mock_get_hub_content.return_value = { |
| 568 | + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", |
| 569 | + 'hub_content_document': { |
| 570 | + "GatedBucket": False, |
| 571 | + "RecipeCollection": [ |
| 572 | + { |
| 573 | + "CustomizationTechnique": "SFT", |
| 574 | + "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", |
| 575 | + "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", |
| 576 | + "Name": "standard_sft" |
| 577 | + }, |
| 578 | + { |
| 579 | + "CustomizationTechnique": "SFT", |
| 580 | + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-123456789012/source/template.yaml", |
| 581 | + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", |
| 582 | + "Name": "datamix_sft", |
| 583 | + "IsSubscriptionModel": True |
| 584 | + } |
| 585 | + ] |
| 586 | + } |
| 587 | + } |
| 588 | + |
| 589 | + # Standard recipe returns base params |
| 590 | + standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) |
| 591 | + # Subscription recipe returns datamix params |
| 592 | + datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}}) |
| 593 | + |
| 594 | + mock_s3.get_object.side_effect = [ |
| 595 | + {"Body": Mock(read=Mock(return_value=standard_params.encode()))}, |
| 596 | + {"Body": Mock(read=Mock(return_value=datamix_params.encode()))}, |
| 597 | + ] |
| 598 | + |
| 599 | + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( |
| 600 | + "test-model", "SFT", "FULL", mock_session, |
| 601 | + ) |
| 602 | + |
| 603 | + assert "max_steps" in options._specs |
| 604 | + assert "customer_data_percent" in options._specs |
| 605 | + assert options._specs["customer_data_percent"]["default"] is None # defaults are None so they dont serialize unless explicitly set |
| 606 | + |
| 607 | + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') |
| 608 | + def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, mock_get_hub_content): |
| 609 | + """When (default), datamix HPs are NOT available.""" |
| 610 | + mock_session = Mock() |
| 611 | + mock_session.boto_session.region_name = "us-east-1" |
| 612 | + mock_s3 = Mock() |
| 613 | + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 |
| 614 | + |
| 615 | + mock_get_hub_content.return_value = { |
| 616 | + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", |
| 617 | + 'hub_content_document': { |
| 618 | + "GatedBucket": False, |
| 619 | + "RecipeCollection": [ |
| 620 | + { |
| 621 | + "CustomizationTechnique": "SFT", |
| 622 | + "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", |
| 623 | + "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", |
| 624 | + "Name": "standard_sft" |
| 625 | + }, |
| 626 | + { |
| 627 | + "CustomizationTechnique": "SFT", |
| 628 | + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml", |
| 629 | + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", |
| 630 | + "Name": "datamix_sft", |
| 631 | + "IsSubscriptionModel": True |
| 632 | + } |
| 633 | + ] |
| 634 | + } |
| 635 | + } |
| 636 | + |
| 637 | + standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) |
| 638 | + mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=standard_params.encode()))} |
| 639 | + |
| 640 | + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( |
| 641 | + "test-model", "SFT", "FULL", mock_session, |
| 642 | + ) |
| 643 | + |
| 644 | + assert "max_steps" in options._specs |
| 645 | + assert "customer_data_percent" not in options._specs |
| 646 | + |
| 647 | + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') |
| 648 | + def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, mock_get_hub_content): |
| 649 | + """When but user is NOT subscribed, falls back gracefully.""" |
| 650 | + mock_session = Mock() |
| 651 | + mock_session.boto_session.region_name = "us-east-1" |
| 652 | + mock_s3 = Mock() |
| 653 | + mock_sts = Mock() |
| 654 | + mock_sts.get_caller_identity.return_value = {"Account": "999999999999"} |
| 655 | + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts |
| 656 | + |
| 657 | + mock_get_hub_content.return_value = { |
| 658 | + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", |
| 659 | + 'hub_content_document': { |
| 660 | + "GatedBucket": False, |
| 661 | + "RecipeCollection": [ |
| 662 | + { |
| 663 | + "CustomizationTechnique": "SFT", |
| 664 | + "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", |
| 665 | + "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", |
| 666 | + "Name": "standard_sft" |
| 667 | + }, |
| 668 | + { |
| 669 | + "CustomizationTechnique": "SFT", |
| 670 | + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml", |
| 671 | + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", |
| 672 | + "Name": "datamix_sft", |
| 673 | + "IsSubscriptionModel": True |
| 674 | + } |
| 675 | + ] |
| 676 | + } |
| 677 | + } |
| 678 | + |
| 679 | + standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) |
| 680 | + # First call succeeds (standard recipe), second call fails (access denied) |
| 681 | + mock_s3.get_object.side_effect = [ |
| 682 | + {"Body": Mock(read=Mock(return_value=standard_params.encode()))}, |
| 683 | + Exception("Access Denied"), |
| 684 | + ] |
| 685 | + |
| 686 | + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( |
| 687 | + "test-model", "SFT", "FULL", mock_session, |
| 688 | + ) |
| 689 | + |
| 690 | + # Should still have standard params, just not datamix ones |
| 691 | + assert "max_steps" in options._specs |
| 692 | + assert "customer_data_percent" not in options._specs |
0 commit comments