Skip to content

Commit 41ccc9c

Browse files
committed
Address PR readiness
1 parent c6ba964 commit 41ccc9c

File tree

4 files changed

+140
-15
lines changed

4 files changed

+140
-15
lines changed

sagemaker-train/pyproject.toml

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,9 +43,6 @@ dependencies = [
4343
"sagemaker-mlflow>=0.0.1,<1.0.0",
4444
"mlflow>=3.0.0,<4.0.0",
4545
"nest_asyncio>=1.5.0",
46-
"ipywidgets>=8.0.0",
47-
"rich>=13.0.0",
48-
"matplotlib>=3.5.0",
4946
]
5047

5148
[project.urls]
@@ -64,6 +61,11 @@ test = [
6461
"graphene",
6562
"IPython"
6663
]
64+
notebook = [
65+
"ipywidgets>=8.0.0",
66+
"rich>=13.0.0",
67+
"matplotlib>=3.5.0",
68+
]
6769

6870
[tool.setuptools.packages.find]
6971
where = ["src/"]

sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -182,16 +182,16 @@ def get_mlflow_url(training_job) -> str:
182182
if not hasattr(training_job, 'mlflow_config') or _is_unassigned_attribute(training_job.mlflow_config):
183183
raise ValueError("Training job does not have MLflow configured")
184184

185-
import boto3
186185
import os
187186
from mlflow.tracking import MlflowClient
188187
import mlflow
189-
188+
from sagemaker.core.utils.utils import SageMakerClient
189+
190190
mlflow_arn = training_job.mlflow_config.mlflow_resource_arn
191191
exp_name = training_job.mlflow_config.mlflow_experiment_name
192-
192+
193193
# Get presigned base URL
194-
sm_client = boto3.client('sagemaker')
194+
sm_client = SageMakerClient().sagemaker_client
195195
response = sm_client.create_presigned_mlflow_app_url(Arn=mlflow_arn)
196196
base_url = response.get('AuthorizedUrl')
197197

sagemaker-train/src/sagemaker/train/evaluate/execution.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -931,16 +931,14 @@ def wait(
931931
header_table.add_column("Property", style="cyan bold", width=20)
932932
header_table.add_column("Value", style="dim", overflow="fold")
933933

934-
# Extract pipeline name and region from execution ARN
934+
# Extract pipeline name and exec_id from execution ARN
935935
pipeline_name = None
936936
exec_id = ''
937-
region = None
938937
if self.arn:
939938
arn_parts = self.arn.split('/')
940939
if len(arn_parts) >= 4:
941940
pipeline_name = arn_parts[-3]
942941
exec_id = arn_parts[-1]
943-
region = self.arn.split(":")[3] if len(self.arn.split(":")) > 3 else None
944942
# Use execution display name if available, fall back to self.name
945943
display_name = self.name
946944
if self._pipeline_execution:
@@ -952,8 +950,10 @@ def wait(
952950
# Build links row
953951
links = []
954952
try:
953+
from sagemaker.core.utils.utils import SageMakerClient
955954
from sagemaker.train.common_utils.metrics_visualizer import _is_in_studio, _get_studio_base_url
956-
if region and pipeline_name and _is_in_studio():
955+
if pipeline_name and _is_in_studio():
956+
region = SageMakerClient().region_name
957957
base = _get_studio_base_url(region)
958958
if base:
959959
pipeline_url = f"{base}/jobs/evaluation/detail?pipeline_name={pipeline_name}&execution_id={exec_id}"
@@ -1052,12 +1052,13 @@ def wait(
10521052
links_table = Table(show_header=True, header_style="bold magenta", box=None, padding=(0, 1))
10531053
links_table.add_column("Step", style="cyan", width=20)
10541054
links_table.add_column("Console", style="dim")
1055+
from sagemaker.core.utils.utils import SageMakerClient
10551056
from sagemaker.train.common_utils.metrics_visualizer import (
10561057
_is_in_studio, _parse_job_arn, _get_studio_base_url,
10571058
get_console_job_url, get_cloudwatch_logs_url,
10581059
)
10591060
in_studio = _is_in_studio()
1060-
studio_base = _get_studio_base_url(region) if in_studio else ""
1061+
studio_base = _get_studio_base_url(SageMakerClient().region_name) if in_studio else ""
10611062
if in_studio:
10621063
links_table.add_column("Studio", style="dim")
10631064
links_table.add_column("Logs", style="dim")
@@ -1296,15 +1297,16 @@ def _convert_to_subclass(self, eval_type: EvalType) -> 'EvaluationPipelineExecut
12961297
@staticmethod
12971298
def _extract_job_arn_from_metadata(step) -> Optional[str]:
12981299
"""Extract the underlying job ARN from a pipeline step's metadata."""
1300+
from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute
12991301
metadata = getattr(step, 'metadata', None)
1300-
if metadata is None or 'Unassigned' in metadata.__class__.__name__:
1302+
if metadata is None or _is_unassigned_attribute(metadata):
13011303
return None
13021304
for attr in ('training_job', 'processing_job', 'transform_job', 'tuning_job',
13031305
'auto_ml_job', 'compilation_job'):
13041306
job_meta = getattr(metadata, attr, None)
1305-
if job_meta is not None and not ('Unassigned' in job_meta.__class__.__name__):
1307+
if job_meta is not None and not _is_unassigned_attribute(job_meta):
13061308
arn = getattr(job_meta, 'arn', None)
1307-
if arn and not ('Unassigned' in arn.__class__.__name__):
1309+
if arn and not _is_unassigned_attribute(arn):
13081310
return str(arn)
13091311
return None
13101312

Lines changed: 121 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,121 @@
1+
"""Unit tests for metrics_visualizer module."""
2+
import pytest
3+
from unittest.mock import Mock, patch, MagicMock
4+
5+
6+
class TestParseJobArn:
    """Checks ``_parse_job_arn`` against well-formed and malformed inputs."""

    def test_training_job_arn(self):
        """A training-job ARN is split into its region and resource parts."""
        from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn

        arn = "arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job"
        expected = ("us-west-2", "training-job/my-job")
        assert _parse_job_arn(arn) == expected

    def test_processing_job_arn(self):
        """A processing-job ARN is split into its region and resource parts."""
        from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn

        arn = "arn:aws:sagemaker:us-east-1:123456789012:processing-job/my-job"
        expected = ("us-east-1", "processing-job/my-job")
        assert _parse_job_arn(arn) == expected

    def test_invalid_arn_returns_none(self):
        """A string that is not an ARN produces ``None``."""
        from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn

        parsed = _parse_job_arn("not-an-arn")
        assert parsed is None
20+
21+
22+
class TestGetConsoleJobUrl:
    """Checks the console link built by ``get_console_job_url``."""

    def test_training_job(self):
        """A training-job ARN maps to the exact console URL asserted below."""
        from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url

        arn = "arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job"
        expected = "https://us-west-2.console.aws.amazon.com/sagemaker/home?region=us-west-2#/jobs/my-job"
        assert get_console_job_url(arn) == expected

    def test_invalid_arn_returns_empty(self):
        """A string that is not an ARN produces an empty string."""
        from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url

        link = get_console_job_url("not-an-arn")
        assert link == ""

    def test_unknown_job_type_returns_empty(self):
        """An ARN with the ``unknown-job`` resource type produces an empty string."""
        from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url

        arn = "arn:aws:sagemaker:us-west-2:123456789012:unknown-job/my-job"
        link = get_console_job_url(arn)
        assert link == ""
35+
36+
37+
class TestGetCloudwatchLogsUrl:
    """Checks the logs link built by ``get_cloudwatch_logs_url``."""

    def test_training_job(self):
        """The link for a training job mentions the region, log group marker and job name."""
        from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url

        arn = "arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job"
        url = get_cloudwatch_logs_url(arn)
        # Only substrings are checked, in the same order as before; the full URL
        # layout is deliberately not pinned by this test.
        for fragment in ("us-west-2", "TrainingJobs", "my-job"):
            assert fragment in url

    def test_invalid_arn_returns_empty(self):
        """A string that is not an ARN produces an empty string."""
        from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url

        link = get_cloudwatch_logs_url("not-an-arn")
        assert link == ""
48+
49+
50+
class TestGetStudioUrl:
    """Checks ``get_studio_url`` for job objects, ARN strings and bare job names.

    ``_get_studio_base_url`` is always patched, so no Studio domain lookup
    happens; the tests only pin how the final link is assembled from the
    stubbed base URL.
    """

    # Values shared by the tests: the stubbed region, the stubbed Studio base
    # URL, and the job page expected underneath that base.
    _REGION = "us-west-2"
    _STUDIO_BASE = "https://studio-d-abc.studio.us-west-2.sagemaker.aws"
    _JOB_PAGE = "https://studio-d-abc.studio.us-west-2.sagemaker.aws/jobs/train/my-job"

    @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url")
    @patch("sagemaker.core.utils.utils.SageMakerClient")
    def test_with_training_job_object(self, mock_client_cls, mock_base_url):
        """A job object resolves using the region reported by the patched client."""
        from sagemaker.train.common_utils.metrics_visualizer import get_studio_url

        mock_base_url.return_value = self._STUDIO_BASE
        mock_client_cls.return_value.region_name = self._REGION

        job = Mock()
        job.training_job_name = "my-job"

        assert get_studio_url(job) == self._JOB_PAGE
        mock_base_url.assert_called_once_with("us-west-2")

    @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url")
    def test_with_arn_string(self, mock_base_url):
        """An ARN string resolves with the region looked up as ``us-west-2``."""
        from sagemaker.train.common_utils.metrics_visualizer import get_studio_url

        mock_base_url.return_value = self._STUDIO_BASE
        arn = "arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job"

        assert get_studio_url(arn) == self._JOB_PAGE
        mock_base_url.assert_called_once_with("us-west-2")

    @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url")
    @patch("sagemaker.core.utils.utils.SageMakerClient")
    @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob")
    def test_with_job_name_string(self, mock_tj_cls, mock_client_cls, mock_base_url):
        """A bare job name resolves with ``TrainingJob`` and the client patched out."""
        from sagemaker.train.common_utils.metrics_visualizer import get_studio_url

        mock_tj_cls.get.return_value.training_job_name = "my-job"
        mock_base_url.return_value = self._STUDIO_BASE
        mock_client_cls.return_value.region_name = self._REGION

        assert get_studio_url("my-job") == self._JOB_PAGE

    @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url")
    @patch("sagemaker.core.utils.utils.SageMakerClient")
    def test_returns_empty_when_no_domain(self, mock_client_cls, mock_base_url):
        """An empty base URL from the patched helper yields an empty link."""
        from sagemaker.train.common_utils.metrics_visualizer import get_studio_url

        mock_client_cls.return_value.region_name = self._REGION
        mock_base_url.return_value = ""

        job = Mock()
        job.training_job_name = "my-job"

        assert get_studio_url(job) == ""
95+
96+
97+
class TestGetAvailableMetrics:
    """Checks ``get_available_metrics`` with and without MLflow configuration."""

    @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob")
    def test_returns_empty_when_no_mlflow_config(self, _):
        """A job object lacking ``mlflow_config`` yields an empty list."""
        from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics

        # An empty spec means the mock exposes no attributes at all, so
        # ``mlflow_config`` is missing rather than auto-created.
        bare_job = Mock(spec=[])
        assert get_available_metrics(bare_job) == []

    @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob")
    def test_returns_empty_when_mlflow_config_falsy(self, _):
        """A job whose ``mlflow_config`` is ``None`` yields an empty list."""
        from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics

        unconfigured_job = Mock(mlflow_config=None)
        assert get_available_metrics(unconfigured_job) == []

    @patch("mlflow.get_run")
    @patch("mlflow.set_tracking_uri")
    def test_returns_metric_names(self, mock_set_uri, mock_get_run):
        """Metric names are taken from the keys of the patched run's metrics."""
        from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics

        # Stub the MLflow run that the patched ``mlflow.get_run`` hands back.
        mock_get_run.return_value.data.metrics = {"loss": 0.5, "accuracy": 0.9}

        job = Mock()
        job.mlflow_config.mlflow_resource_arn = "arn:aws:sagemaker:us-west-2:123:mlflow-tracking/abc"
        job.mlflow_details.mlflow_run_id = "run-123"

        names = get_available_metrics(job)
        # Compared as a set, so the order of the returned names is not pinned.
        assert set(names) == {"loss", "accuracy"}

0 commit comments

Comments
 (0)