Skip to content

Commit d4f392f

Browse files
authored
feat: add MLFlow experiment link to eval output (#5783)
* feat: add MLFlow experiment link to eval output
* refactor: unify MLflow tracking URL generation logic
1 parent 27cb74f commit d4f392f

5 files changed

Lines changed: 336 additions & 25 deletions

File tree

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Shared MLflow presigned URL utilities."""
14+
15+
import logging
16+
from typing import Optional
17+
18+
logger = logging.getLogger(__name__)


def get_presigned_mlflow_experiment_url(
    mlflow_resource_arn: str,
    mlflow_experiment_name: Optional[str] = None,
) -> Optional[str]:
    """Generate a presigned MLflow URL, optionally deep-linked to an experiment.

    This helper is best-effort: it never raises. Every failure is logged at
    DEBUG level and reported to the caller through the return value.

    Args:
        mlflow_resource_arn: MLflow tracking server or app ARN.
        mlflow_experiment_name: Optional experiment name for deep-linking.

    Returns:
        The presigned URL with a ``#/experiments/<id>`` fragment when the
        experiment can be resolved, the plain presigned URL when no experiment
        name is given or the lookup fails, or None when no ARN is supplied or
        the presigned URL itself cannot be generated.
    """
    if not mlflow_resource_arn:
        # Without an ARN there is nothing to presign; avoid building a client
        # and making a request that can only fail.
        logger.debug("No MLflow resource ARN provided; skipping presigned URL generation")
        return None

    try:
        # Imported lazily so importing this module does not require the
        # SageMaker client stack.
        from sagemaker.core.utils.utils import SageMakerClient

        sm_client = SageMakerClient().sagemaker_client
        response = sm_client.create_presigned_mlflow_app_url(Arn=mlflow_resource_arn)
        base_url = response.get("AuthorizedUrl")
    except Exception as e:
        logger.debug("Failed to generate MLflow experiment URL: %s", e)
        return None

    if not base_url:
        return None
    if not mlflow_experiment_name:
        return base_url

    try:
        # mlflow is optional here: if it is missing or the lookup fails, the
        # caller still gets a usable (non deep-linked) URL.
        import mlflow
        from mlflow.tracking import MlflowClient

        # NOTE(review): this changes the process-wide MLflow tracking URI as a
        # side effect. It is kept for parity with the previous implementation;
        # confirm whether the explicit client below makes it unnecessary.
        mlflow.set_tracking_uri(mlflow_resource_arn)
        client = MlflowClient(tracking_uri=mlflow_resource_arn)
        experiment = client.get_experiment_by_name(mlflow_experiment_name)
        if experiment:
            return f"{base_url}#/experiments/{experiment.experiment_id}"
    except Exception as e:
        logger.debug(
            "Failed to resolve MLflow experiment '%s': %s", mlflow_experiment_name, e
        )

    return base_url

sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -182,35 +182,17 @@ def get_mlflow_url(training_job) -> str:
182182
if not hasattr(training_job, 'mlflow_config') or _is_unassigned_attribute(training_job.mlflow_config):
183183
raise ValueError("Training job does not have MLflow configured")
184184

185-
import os
186-
from mlflow.tracking import MlflowClient
187-
import mlflow
188-
from sagemaker.core.utils.utils import SageMakerClient
185+
from sagemaker.train.common_utils.mlflow_url_utils import get_presigned_mlflow_experiment_url
189186

190187
mlflow_arn = training_job.mlflow_config.mlflow_resource_arn
191188
exp_name = training_job.mlflow_config.mlflow_experiment_name
189+
if _is_unassigned_attribute(exp_name):
190+
exp_name = None
192191

193-
# Get presigned base URL
194-
sm_client = SageMakerClient().sagemaker_client
195-
response = sm_client.create_presigned_mlflow_app_url(Arn=mlflow_arn)
196-
base_url = response.get('AuthorizedUrl')
197-
198-
# Try to get experiment ID and append to URL
199-
try:
200-
os.environ['MLFLOW_TRACKING_URI'] = mlflow_arn
201-
mlflow.set_tracking_uri(mlflow_arn)
202-
203-
mlflow_client = MlflowClient(tracking_uri=mlflow_arn)
204-
experiment = mlflow_client.get_experiment_by_name(exp_name)
205-
206-
if experiment:
207-
# Format: base_url#/experiments/{id}
208-
# The base_url already has /auth?authToken=...
209-
return f"{base_url}#/experiments/{experiment.experiment_id}"
210-
except Exception:
211-
pass
212-
213-
return base_url
192+
url = get_presigned_mlflow_experiment_url(mlflow_arn, exp_name)
193+
if url is None:
194+
raise ValueError("Failed to generate presigned MLflow URL")
195+
return url
214196

215197

216198

sagemaker-train/src/sagemaker/train/evaluate/execution.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,28 @@ def wait(
920920
# Create console with Jupyter support
921921
console = Console(force_jupyter=True)
922922

923+
# MLflow link caching (presigned URLs expire after 5 min)
mlflow_link_cache = {'url': None, 'timestamp': 0}

def get_cached_mlflow_url():
    """Return the MLflow link for this execution, regenerating it when stale.

    A generated URL is reused for 4 minutes. A failed generation attempt
    (the helper returned None) is remembered for 1 minute, so a persistent
    failure does not trigger new presign and experiment-lookup calls on
    every refresh of the wait loop.

    Returns:
        The cached or freshly generated URL, or None when no MLflow config
        is available or URL generation failed.
    """
    from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute
    from sagemaker.train.common_utils.mlflow_url_utils import get_presigned_mlflow_experiment_url

    current_time = time.time()
    # The initial timestamp of 0 makes the first call always look stale.
    max_age = 240 if mlflow_link_cache['url'] else 60
    if (current_time - mlflow_link_cache['timestamp']) <= max_age:
        return mlflow_link_cache['url']

    pe = self._pipeline_execution
    # NOTE(review): the attribute is spelled 'm_lflow_config' here; presumably
    # this matches the name on the pipeline execution resource - verify.
    mlflow_cfg = getattr(pe, 'm_lflow_config', None) if pe else None
    if not mlflow_cfg or _is_unassigned_attribute(mlflow_cfg):
        # No attempt was made, so the timestamp is left untouched and the
        # config is re-checked on the next call.
        return mlflow_link_cache['url']

    arn = getattr(mlflow_cfg, 'mlflow_resource_arn', None)
    if not arn or _is_unassigned_attribute(arn):
        return mlflow_link_cache['url']

    exp_name = getattr(mlflow_cfg, 'mlflow_experiment_name', None)
    if exp_name and _is_unassigned_attribute(exp_name):
        exp_name = None

    # Stamp the attempt whether or not it succeeded, so failures back off too.
    mlflow_link_cache['url'] = get_presigned_mlflow_experiment_url(arn, exp_name)
    mlflow_link_cache['timestamp'] = current_time
    return mlflow_link_cache['url']
944+
923945
while True:
924946
clear_output(wait=True)
925947
self.refresh()
@@ -960,6 +982,10 @@ def wait(
960982
links.append(f"[bright_blue underline][link={pipeline_url}]🔗 Pipeline Execution (Studio)[/link][/bright_blue underline]")
961983
except Exception:
962984
pass
985+
# Add MLflow experiment link if available
986+
cached_mlflow_url = get_cached_mlflow_url()
987+
if cached_mlflow_url:
988+
links.append(f"[bright_blue underline][link={cached_mlflow_url}]🔗 MLflow Experiment[/link][/bright_blue underline]")
963989
if links:
964990
header_table.add_row("Links", " | ".join(links))
965991

sagemaker-train/tests/unit/train/common_utils/test_trainer_wait.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_is_unassigned_attribute,
2626
_calculate_training_progress,
2727
_calculate_transition_duration,
28+
get_mlflow_url,
2829
wait
2930
)
3031

@@ -489,3 +490,66 @@ def test_wait_metrics_exception_non_jupyter(self, mock_is_jupyter, mock_setup_ml
489490

490491
# Should complete successfully despite metrics exception
491492
training_job.refresh.assert_called()
493+
494+
495+
class TestGetMlflowUrl:
    """Unit tests covering get_mlflow_url and its delegation to the shared URL helper."""

    # ARN shared by every test that needs a configured MLflow resource.
    _ARN = "arn:aws:sagemaker:us-west-2:123:mlflow-app/test"

    @staticmethod
    def _job_with_mlflow(arn, experiment_name):
        """Build a mock training job whose mlflow_config carries the given values."""
        job = MagicMock()
        job.mlflow_config.mlflow_resource_arn = arn
        job.mlflow_config.mlflow_experiment_name = experiment_name
        return job

    @patch("sagemaker.train.common_utils.mlflow_url_utils.get_presigned_mlflow_experiment_url")
    def test_delegates_to_shared_helper(self, mock_helper):
        """The ARN and experiment name are read from the job and passed to the helper."""
        expected_url = "https://mlflow.example.com/auth?token=abc#/experiments/42"
        mock_helper.return_value = expected_url
        job = self._job_with_mlflow(self._ARN, "my-experiment")

        actual_url = get_mlflow_url(job)

        mock_helper.assert_called_once_with(
            "arn:aws:sagemaker:us-west-2:123:mlflow-app/test",
            "my-experiment",
        )
        assert actual_url == expected_url

    @patch("sagemaker.train.common_utils.trainer_wait.TrainingJob")
    @patch("sagemaker.train.common_utils.mlflow_url_utils.get_presigned_mlflow_experiment_url")
    def test_accepts_job_name_string(self, mock_helper, mock_tj_class):
        """A plain job-name string is looked up through TrainingJob.get()."""
        expected_url = "https://mlflow.example.com/auth"
        mock_helper.return_value = expected_url
        mock_tj_class.get.return_value = self._job_with_mlflow(self._ARN, None)

        actual_url = get_mlflow_url("my-training-job")

        mock_tj_class.get.assert_called_once_with(training_job_name="my-training-job")
        assert actual_url == expected_url

    def test_raises_when_no_mlflow_config(self):
        """An unassigned mlflow_config results in a ValueError."""
        job = MagicMock()
        job.mlflow_config = MockUnassignedAttribute()

        with pytest.raises(ValueError, match="does not have MLflow configured"):
            get_mlflow_url(job)

    def test_raises_when_mlflow_config_missing(self):
        """A job object without any mlflow_config attribute results in a ValueError."""
        job = MagicMock(spec=[])  # empty spec: the mock exposes no attributes

        with pytest.raises(ValueError, match="does not have MLflow configured"):
            get_mlflow_url(job)

    @patch("sagemaker.train.common_utils.mlflow_url_utils.get_presigned_mlflow_experiment_url")
    def test_raises_when_helper_returns_none(self, mock_helper):
        """A None result from the helper is surfaced as a ValueError."""
        mock_helper.return_value = None
        job = self._job_with_mlflow(self._ARN, "exp")

        with pytest.raises(ValueError, match="Failed to generate presigned MLflow URL"):
            get_mlflow_url(job)

0 commit comments

Comments
 (0)