|
| 1 | +"""Unit tests for metrics_visualizer module.""" |
| 2 | +import pytest |
| 3 | +from unittest.mock import Mock, patch, MagicMock |
| 4 | + |
| 5 | + |
| 6 | +class TestParseJobArn: |
| 7 | + def test_training_job_arn(self): |
| 8 | + from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn |
| 9 | + result = _parse_job_arn("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job") |
| 10 | + assert result == ("us-west-2", "training-job/my-job") |
| 11 | + |
| 12 | + def test_processing_job_arn(self): |
| 13 | + from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn |
| 14 | + result = _parse_job_arn("arn:aws:sagemaker:us-east-1:123456789012:processing-job/my-job") |
| 15 | + assert result == ("us-east-1", "processing-job/my-job") |
| 16 | + |
| 17 | + def test_invalid_arn_returns_none(self): |
| 18 | + from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn |
| 19 | + assert _parse_job_arn("not-an-arn") is None |
| 20 | + |
| 21 | + |
| 22 | +class TestGetConsoleJobUrl: |
| 23 | + def test_training_job(self): |
| 24 | + from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url |
| 25 | + url = get_console_job_url("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job") |
| 26 | + assert url == "https://us-west-2.console.aws.amazon.com/sagemaker/home?region=us-west-2#/jobs/my-job" |
| 27 | + |
| 28 | + def test_invalid_arn_returns_empty(self): |
| 29 | + from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url |
| 30 | + assert get_console_job_url("not-an-arn") == "" |
| 31 | + |
| 32 | + def test_unknown_job_type_returns_empty(self): |
| 33 | + from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url |
| 34 | + assert get_console_job_url("arn:aws:sagemaker:us-west-2:123456789012:unknown-job/my-job") == "" |
| 35 | + |
| 36 | + |
| 37 | +class TestGetCloudwatchLogsUrl: |
| 38 | + def test_training_job(self): |
| 39 | + from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url |
| 40 | + url = get_cloudwatch_logs_url("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job") |
| 41 | + assert "us-west-2" in url |
| 42 | + assert "TrainingJobs" in url |
| 43 | + assert "my-job" in url |
| 44 | + |
| 45 | + def test_invalid_arn_returns_empty(self): |
| 46 | + from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url |
| 47 | + assert get_cloudwatch_logs_url("not-an-arn") == "" |
| 48 | + |
| 49 | + |
| 50 | +class TestGetStudioUrl: |
| 51 | + @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url") |
| 52 | + @patch("sagemaker.core.utils.utils.SageMakerClient") |
| 53 | + def test_with_training_job_object(self, mock_client_cls, mock_base_url): |
| 54 | + from sagemaker.train.common_utils.metrics_visualizer import get_studio_url |
| 55 | + mock_client_cls.return_value.region_name = "us-west-2" |
| 56 | + mock_base_url.return_value = "https://studio-d-abc.studio.us-west-2.sagemaker.aws" |
| 57 | + |
| 58 | + mock_job = Mock() |
| 59 | + mock_job.training_job_name = "my-job" |
| 60 | + |
| 61 | + url = get_studio_url(mock_job) |
| 62 | + assert url == "https://studio-d-abc.studio.us-west-2.sagemaker.aws/jobs/train/my-job" |
| 63 | + mock_base_url.assert_called_once_with("us-west-2") |
| 64 | + |
| 65 | + @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url") |
| 66 | + def test_with_arn_string(self, mock_base_url): |
| 67 | + from sagemaker.train.common_utils.metrics_visualizer import get_studio_url |
| 68 | + mock_base_url.return_value = "https://studio-d-abc.studio.us-west-2.sagemaker.aws" |
| 69 | + |
| 70 | + url = get_studio_url("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job") |
| 71 | + assert url == "https://studio-d-abc.studio.us-west-2.sagemaker.aws/jobs/train/my-job" |
| 72 | + mock_base_url.assert_called_once_with("us-west-2") |
| 73 | + |
| 74 | + @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url") |
| 75 | + @patch("sagemaker.core.utils.utils.SageMakerClient") |
| 76 | + @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob") |
| 77 | + def test_with_job_name_string(self, mock_tj_cls, mock_client_cls, mock_base_url): |
| 78 | + from sagemaker.train.common_utils.metrics_visualizer import get_studio_url |
| 79 | + mock_client_cls.return_value.region_name = "us-west-2" |
| 80 | + mock_base_url.return_value = "https://studio-d-abc.studio.us-west-2.sagemaker.aws" |
| 81 | + mock_tj_cls.get.return_value.training_job_name = "my-job" |
| 82 | + |
| 83 | + url = get_studio_url("my-job") |
| 84 | + assert url == "https://studio-d-abc.studio.us-west-2.sagemaker.aws/jobs/train/my-job" |
| 85 | + |
| 86 | + @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url") |
| 87 | + @patch("sagemaker.core.utils.utils.SageMakerClient") |
| 88 | + def test_returns_empty_when_no_domain(self, mock_client_cls, mock_base_url): |
| 89 | + from sagemaker.train.common_utils.metrics_visualizer import get_studio_url |
| 90 | + mock_client_cls.return_value.region_name = "us-west-2" |
| 91 | + mock_base_url.return_value = "" |
| 92 | + |
| 93 | + url = get_studio_url(Mock(training_job_name="my-job")) |
| 94 | + assert url == "" |
| 95 | + |
| 96 | + |
| 97 | +class TestGetAvailableMetrics: |
| 98 | + @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob") |
| 99 | + def test_returns_empty_when_no_mlflow_config(self, _): |
| 100 | + from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics |
| 101 | + mock_job = Mock(spec=[]) # no mlflow_config attribute |
| 102 | + assert get_available_metrics(mock_job) == [] |
| 103 | + |
| 104 | + @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob") |
| 105 | + def test_returns_empty_when_mlflow_config_falsy(self, _): |
| 106 | + from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics |
| 107 | + mock_job = Mock() |
| 108 | + mock_job.mlflow_config = None |
| 109 | + assert get_available_metrics(mock_job) == [] |
| 110 | + |
| 111 | + @patch("mlflow.get_run") |
| 112 | + @patch("mlflow.set_tracking_uri") |
| 113 | + def test_returns_metric_names(self, mock_set_uri, mock_get_run): |
| 114 | + from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics |
| 115 | + mock_job = Mock() |
| 116 | + mock_job.mlflow_config.mlflow_resource_arn = "arn:aws:sagemaker:us-west-2:123:mlflow-tracking/abc" |
| 117 | + mock_job.mlflow_details.mlflow_run_id = "run-123" |
| 118 | + mock_get_run.return_value.data.metrics = {"loss": 0.5, "accuracy": 0.9} |
| 119 | + |
| 120 | + result = get_available_metrics(mock_job) |
| 121 | + assert set(result) == {"loss", "accuracy"} |
0 commit comments