Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions sagemaker-train/src/sagemaker/train/evaluate/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,45 @@ def _extract_output_s3_location_from_steps(raw_steps: List[Any], session: Option
return None


def _get_mlflow_experiment_url(
    mlflow_resource_arn: str,
    mlflow_experiment_name: Optional[str] = None,
) -> Optional[str]:
    """Generate a presigned MLflow URL deep-linked to the experiment.

    The experiment lookup is best-effort: when it fails, or the experiment
    does not exist, the plain presigned URL is returned instead of a deep
    link. Failures are logged at debug level and never raised.

    Args:
        mlflow_resource_arn: MLflow tracking server or app ARN.
        mlflow_experiment_name: Optional experiment name for deep-linking.

    Returns:
        The presigned URL with an ``#/experiments/<id>`` fragment when the
        experiment is resolved, the plain presigned URL when it is not, or
        None when the presigned URL itself could not be obtained.
    """
    # NOTE(review): PR discussion raised abstracting this into a shared helper;
    # the author noted that would need decoupling from logic that takes a
    # training job -- TODO confirm which helper was meant.
    try:
        from sagemaker.core.utils.utils import SageMakerClient

        sm_client = SageMakerClient().sagemaker_client
        response = sm_client.create_presigned_mlflow_app_url(Arn=mlflow_resource_arn)
        base_url = response.get('AuthorizedUrl')
    except Exception as e:
        logger.debug("Failed to generate MLflow experiment URL: %s", e)
        return None

    if not base_url:
        return None
    if not mlflow_experiment_name:
        return base_url

    try:
        import mlflow
        from mlflow.tracking import MlflowClient

        # NOTE(review): this changes the process-wide MLflow tracking URI even
        # though the client below is given the URI explicitly -- confirm it is
        # still required before removing.
        mlflow.set_tracking_uri(mlflow_resource_arn)
        experiment = MlflowClient(tracking_uri=mlflow_resource_arn).get_experiment_by_name(
            mlflow_experiment_name
        )
        if experiment:
            return f"{base_url}#/experiments/{experiment.experiment_id}"
    except Exception as e:
        # Deep-linking is optional: record why it was skipped instead of
        # swallowing the error silently, then fall back to the base URL.
        logger.debug(
            "Could not resolve MLflow experiment %r for deep link: %s",
            mlflow_experiment_name,
            e,
        )

    return base_url


class StepDetail(BaseModel):
"""Pipeline step details for tracking execution progress.

Expand Down Expand Up @@ -920,6 +959,27 @@ def wait(
# Create console with Jupyter support
console = Console(force_jupyter=True)

# MLflow link caching (presigned URLs expire after 5 min)
mlflow_link_cache = {'url': None, 'timestamp': 0}

def get_cached_mlflow_url():
    """Get the cached MLflow URL, regenerating it at most every 4 minutes.

    The refresh decision is based only on the time of the last generation
    attempt, so a failed attempt (cached URL is None) is also held for the
    240-second window instead of being retried on every call. The cache
    timestamp starts at 0, which makes the first call always attempt a
    generation.

    Returns:
        The cached URL, or None when no MLflow configuration is assigned on
        the pipeline execution or the last generation attempt failed.
    """
    from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute

    current_time = time.time()
    if (current_time - mlflow_link_cache['timestamp']) <= 240:
        # Still inside the window of the last attempt (success or failure).
        return mlflow_link_cache['url']

    pe = self._pipeline_execution
    mlflow_cfg = getattr(pe, 'm_lflow_config', None) if pe else None
    if not mlflow_cfg or _is_unassigned_attribute(mlflow_cfg):
        # No usable config: the timestamp is left untouched, so the config
        # is checked again on the next call.
        return mlflow_link_cache['url']

    arn = getattr(mlflow_cfg, 'mlflow_resource_arn', None)
    if not arn or _is_unassigned_attribute(arn):
        return mlflow_link_cache['url']

    exp_name = getattr(mlflow_cfg, 'mlflow_experiment_name', None)
    if exp_name and _is_unassigned_attribute(exp_name):
        exp_name = None

    mlflow_link_cache['url'] = _get_mlflow_experiment_url(arn, exp_name)
    mlflow_link_cache['timestamp'] = current_time
    return mlflow_link_cache['url']

while True:
clear_output(wait=True)
self.refresh()
Expand Down Expand Up @@ -960,6 +1020,10 @@ def wait(
links.append(f"[bright_blue underline][link={pipeline_url}]🔗 Pipeline Execution (Studio)[/link][/bright_blue underline]")
except Exception:
pass
# Add MLflow experiment link if available
cached_mlflow_url = get_cached_mlflow_url()
if cached_mlflow_url:
links.append(f"[bright_blue underline][link={cached_mlflow_url}]🔗 MLflow Experiment[/link][/bright_blue underline]")
if links:
header_table.add_row("Links", " | ".join(links))

Expand Down
179 changes: 179 additions & 0 deletions sagemaker-train/tests/unit/train/evaluate/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
_start_pipeline_execution,
_create_execution_from_pipeline_execution,
_extract_output_s3_location_from_steps,
_get_mlflow_experiment_url,
)
from sagemaker.train.evaluate.constants import EvalType, _get_pipeline_name, _get_pipeline_name_prefix

Expand Down Expand Up @@ -1465,3 +1466,181 @@ def test_complete_get_workflow(self, mock_pe_class, mock_session):


# Additional tests for improved coverage - removed as they don't add significant value


# ============================================================================
# Tests for MLflow Link Functions
# ============================================================================

# ARN passed as the MLflow resource argument in the MLflow link tests below.
MOCK_MLFLOW_ARN = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test-server"
# Value the mocked create_presigned_mlflow_app_url returns as "AuthorizedUrl".
MOCK_PRESIGNED_URL = "https://mlflow.example.com/auth?authToken=abc123"


class TestGetMlflowExperimentUrl:
    """Unit tests covering _get_mlflow_experiment_url."""

    @staticmethod
    def _presign_mock(sm_client_cls):
        """Wire the patched SageMakerClient class and return its presign-call mock."""
        wrapper = MagicMock()
        sm_client_cls.return_value = wrapper
        return wrapper.sagemaker_client.create_presigned_mlflow_app_url

    @patch("sagemaker.core.utils.utils.SageMakerClient")
    def test_returns_deep_link_with_experiment(self, mock_sm_client_cls):
        """An existing experiment yields the presigned URL plus its fragment."""
        self._presign_mock(mock_sm_client_cls).return_value = {
            "AuthorizedUrl": MOCK_PRESIGNED_URL
        }
        found_experiment = MagicMock(experiment_id="42")

        with patch("mlflow.set_tracking_uri"):
            with patch("mlflow.tracking.MlflowClient") as client_cls:
                client_cls.return_value.get_experiment_by_name.return_value = found_experiment
                url = _get_mlflow_experiment_url(MOCK_MLFLOW_ARN, "my-experiment")

        assert url == f"{MOCK_PRESIGNED_URL}#/experiments/42"

    @patch("sagemaker.core.utils.utils.SageMakerClient")
    def test_returns_base_url_when_no_experiment_name(self, mock_sm_client_cls):
        """Without an experiment name the plain presigned URL comes back."""
        self._presign_mock(mock_sm_client_cls).return_value = {
            "AuthorizedUrl": MOCK_PRESIGNED_URL
        }

        url = _get_mlflow_experiment_url(MOCK_MLFLOW_ARN, None)

        assert url == MOCK_PRESIGNED_URL

    @patch("sagemaker.core.utils.utils.SageMakerClient")
    def test_returns_base_url_when_experiment_lookup_fails(self, mock_sm_client_cls):
        """An exception during the MLflow lookup falls back to the base URL."""
        self._presign_mock(mock_sm_client_cls).return_value = {
            "AuthorizedUrl": MOCK_PRESIGNED_URL
        }

        with patch("mlflow.set_tracking_uri"):
            with patch("mlflow.tracking.MlflowClient") as client_cls:
                client_cls.return_value.get_experiment_by_name.side_effect = Exception("connection error")
                url = _get_mlflow_experiment_url(MOCK_MLFLOW_ARN, "my-experiment")

        assert url == MOCK_PRESIGNED_URL

    @patch("sagemaker.core.utils.utils.SageMakerClient")
    def test_returns_base_url_when_experiment_not_found(self, mock_sm_client_cls):
        """A lookup that finds no experiment falls back to the base URL."""
        self._presign_mock(mock_sm_client_cls).return_value = {
            "AuthorizedUrl": MOCK_PRESIGNED_URL
        }

        with patch("mlflow.set_tracking_uri"):
            with patch("mlflow.tracking.MlflowClient") as client_cls:
                client_cls.return_value.get_experiment_by_name.return_value = None
                url = _get_mlflow_experiment_url(MOCK_MLFLOW_ARN, "nonexistent")

        assert url == MOCK_PRESIGNED_URL

    @patch("sagemaker.core.utils.utils.SageMakerClient")
    def test_returns_none_when_presigned_url_fails(self, mock_sm_client_cls):
        """A raising create_presigned_mlflow_app_url call results in None."""
        self._presign_mock(mock_sm_client_cls).side_effect = Exception("access denied")

        url = _get_mlflow_experiment_url(MOCK_MLFLOW_ARN, "my-experiment")

        assert url is None

    @patch("sagemaker.core.utils.utils.SageMakerClient")
    def test_returns_none_when_authorized_url_empty(self, mock_sm_client_cls):
        """A response without AuthorizedUrl results in None."""
        self._presign_mock(mock_sm_client_cls).return_value = {}

        url = _get_mlflow_experiment_url(MOCK_MLFLOW_ARN, "my-experiment")

        assert url is None


class TestGetCachedMlflowUrl:
    """Tests mirroring the logic of the get_cached_mlflow_url closure inside wait().

    The caching tests run a local replica of the 240-second cache rather than
    the closure itself; the remaining tests check the unassigned-attribute
    detection used to guard MLflow config access.
    """

    @staticmethod
    def _make_cached_getter():
        """Build a fresh 240s-TTL cache replica that takes the clock as an argument."""
        cache = {"url": None, "timestamp": 0}

        def get_cached(current_time):
            stale = (current_time - cache["timestamp"]) > 240
            if cache["url"] is None or stale:
                cache["url"] = _get_mlflow_experiment_url(MOCK_MLFLOW_ARN, None)
                cache["timestamp"] = current_time
            return cache["url"]

        return get_cached

    @patch("sagemaker.core.utils.utils.SageMakerClient")
    def test_cache_hit_within_240s(self, mock_sm_client_cls):
        """A second read inside the 240s window reuses the cached URL."""
        sm_wrapper = MagicMock()
        presign = sm_wrapper.sagemaker_client.create_presigned_mlflow_app_url
        presign.return_value = {"AuthorizedUrl": MOCK_PRESIGNED_URL}
        mock_sm_client_cls.return_value = sm_wrapper

        get_cached = self._make_cached_getter()

        first = get_cached(1000.0)
        assert first == MOCK_PRESIGNED_URL
        assert presign.call_count == 1

        # 100 seconds later: served from the cache, no additional API call
        second = get_cached(1100.0)
        assert second == first
        assert presign.call_count == 1

    @patch("sagemaker.core.utils.utils.SageMakerClient")
    def test_cache_refresh_after_240s(self, mock_sm_client_cls):
        """A read more than 240s after the last refresh regenerates the URL."""
        refreshed_url = "https://mlflow.example.com/auth?authToken=newtoken"
        sm_wrapper = MagicMock()
        presign = sm_wrapper.sagemaker_client.create_presigned_mlflow_app_url
        presign.side_effect = [
            {"AuthorizedUrl": MOCK_PRESIGNED_URL},
            {"AuthorizedUrl": refreshed_url},
        ]
        mock_sm_client_cls.return_value = sm_wrapper

        get_cached = self._make_cached_getter()

        first = get_cached(1000.0)
        assert first == MOCK_PRESIGNED_URL

        # 241 seconds later the entry is stale and must be regenerated
        second = get_cached(1241.0)
        assert second == refreshed_url
        assert presign.call_count == 2

    def test_returns_none_when_no_mlflow_config(self):
        """An Unassigned m_lflow_config is reported as unassigned."""
        from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute

        pipeline_execution = MagicMock()
        pipeline_execution.m_lflow_config = MockUnassigned()

        config = getattr(pipeline_execution, "m_lflow_config", None)
        assert _is_unassigned_attribute(config) is True

    def test_returns_none_when_arn_is_unassigned(self):
        """An Unassigned mlflow_resource_arn is reported as unassigned; the config is not."""
        from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute

        config = MagicMock()
        config.mlflow_resource_arn = MockUnassigned()
        config.__class__ = type("MLflowConfiguration", (), {})

        assert not _is_unassigned_attribute(config)
        assert _is_unassigned_attribute(config.mlflow_resource_arn)
Loading