Skip to content

Commit b34d0d8

Browse files
committed
Update link to console, conditionally display studio link, update link color to blue
1 parent 0297d0a commit b34d0d8

File tree

4 files changed

+157
-61
lines changed

4 files changed

+157
-61
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ env/
4040
sagemaker_train/src/**/container_drivers/sm_train.sh
4141
sagemaker_train/src/**/container_drivers/sourcecode.json
4242
sagemaker_train/src/**/container_drivers/distributed.json
43+
.kiro

sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py

Lines changed: 77 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,85 @@
22

33
import logging
44
from typing import Optional, List, Dict, Any
5-
import boto3
65
from sagemaker.core.resources import TrainingJob
76

87
logger = logging.getLogger(__name__)
98

109

10+
def _is_in_studio() -> bool:
11+
"""Check if running inside SageMaker Studio."""
12+
from sagemaker.train.common_utils.finetune_utils import _read_domain_id_from_metadata
13+
return _read_domain_id_from_metadata() is not None
14+
15+
16+
def _get_studio_base_url(region: str) -> str:
17+
"""Get Studio base URL, or empty string if domain not resolvable."""
18+
from sagemaker.train.common_utils.finetune_utils import _read_domain_id_from_metadata
19+
domain_id = _read_domain_id_from_metadata()
20+
if not domain_id or not region:
21+
return ""
22+
return f"https://studio-{domain_id}.studio.{region}.sagemaker.aws"
23+
24+
25+
def _parse_job_arn(job_arn: str):
26+
"""Parse a SageMaker job ARN into (region, resource) or None."""
27+
import re
28+
m = re.match(r'arn:aws(?:-[a-z]+)?:sagemaker:([a-z0-9-]+):\d+:(\S+)', job_arn)
29+
return (m.group(1), m.group(2)) if m else None
30+
31+
32+
def get_console_job_url(job_arn: str) -> str:
33+
"""Get AWS Console URL for a SageMaker job ARN.
34+
35+
Args:
36+
job_arn: Full ARN like arn:aws:sagemaker:us-east-1:123:training-job/my-job
37+
38+
Returns:
39+
Console URL or empty string.
40+
"""
41+
parsed = _parse_job_arn(job_arn)
42+
if not parsed:
43+
return ""
44+
region, resource = parsed
45+
job_type_map = {
46+
"training-job/": "#/jobs/",
47+
"processing-job/": "#/processing-jobs/",
48+
"transform-job/": "#/transform-jobs/",
49+
}
50+
for prefix, fragment in job_type_map.items():
51+
if resource.startswith(prefix):
52+
job_name = resource.split("/", 1)[1]
53+
return f"https://{region}.console.aws.amazon.com/sagemaker/home?region={region}{fragment}{job_name}"
54+
return ""
55+
56+
57+
def get_cloudwatch_logs_url(job_arn: str) -> str:
58+
"""Get CloudWatch Logs console URL for a SageMaker job ARN.
59+
60+
Returns:
61+
CloudWatch console URL or empty string.
62+
"""
63+
parsed = _parse_job_arn(job_arn)
64+
if not parsed:
65+
return ""
66+
region, resource = parsed
67+
log_group_map = {
68+
"training-job/": "/aws/sagemaker/TrainingJobs",
69+
"processing-job/": "/aws/sagemaker/ProcessingJobs",
70+
"transform-job/": "/aws/sagemaker/TransformJobs",
71+
}
72+
for prefix, log_group in log_group_map.items():
73+
if resource.startswith(prefix):
74+
job_name = resource.split("/", 1)[1]
75+
encoded_group = log_group.replace("/", "$252F")
76+
return (
77+
f"https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}"
78+
f"#logsV2:log-groups/log-group/{encoded_group}"
79+
f"$3FlogStreamNameFilter$3D{job_name}"
80+
)
81+
return ""
82+
83+
1184
def get_studio_url(training_job, domain_id: str = None) -> str:
1285
"""Get SageMaker Studio URL for training job logs.
1386
@@ -43,34 +116,10 @@ def get_studio_url(training_job, domain_id: str = None) -> str:
43116
region = training_job.region if hasattr(training_job, 'region') and training_job.region else 'us-east-1'
44117
job_name = training_job.training_job_name
45118

46-
# Auto-detect domain if not provided
47-
if not domain_id:
48-
# First try Studio metadata (when running inside Studio)
49-
try:
50-
import os, json as _json
51-
metadata_path = '/opt/ml/metadata/resource-metadata.json'
52-
if os.path.exists(metadata_path):
53-
with open(metadata_path, 'r') as f:
54-
domain_id = _json.load(f).get('DomainId')
55-
except Exception:
56-
pass
57-
58-
if not domain_id:
59-
# Fall back to list_domains, sorted by creation time for deterministic results
60-
try:
61-
sm_client = boto3.client('sagemaker', region_name=region)
62-
domains = sm_client.list_domains()['Domains']
63-
if domains:
64-
domains.sort(key=lambda d: d.get('CreationTime', ''))
65-
domain_id = domains[0]['DomainId']
66-
except Exception:
67-
pass
68-
69-
if not domain_id:
119+
base = _get_studio_base_url(region)
120+
if not base:
70121
return ""
71-
72-
# Studio URL format: https://studio-{domain_id}.studio.{region}.sagemaker.aws/jobs/train/{job_name}
73-
return f"https://studio-{domain_id}.studio.{region}.sagemaker.aws/jobs/train/{job_name}"
122+
return f"{base}/jobs/train/{job_name}"
74123

75124

76125
def display_job_links_html(rows: list, as_html: bool = False):

sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -299,34 +299,45 @@ def get_cached_mlflow_url():
299299

300300
clear_output(wait=True)
301301

302-
# Header section with training job name
302+
# Header section with training job info
303303
header_table = Table(show_header=False, box=None, padding=(0, 1))
304304
header_table.add_column("Property", style="cyan bold", width=20)
305305
header_table.add_column("Value", style="dim", overflow="fold")
306306

307-
# Add Studio job link
308-
try:
309-
from sagemaker.train.common_utils.metrics_visualizer import get_studio_url
310-
studio_url = get_studio_url(training_job)
311-
header_table.add_row("TrainingJob Name", f"[underline][link={studio_url}]🔗 {training_job.training_job_name}[/link][/underline]")
312-
except Exception:
313-
header_table.add_row("TrainingJob Name", f"[bold green]{training_job.training_job_name}[/bold green]")
314-
307+
header_table.add_row("TrainingJob Name", f"[bold green]{training_job.training_job_name}[/bold green]")
315308
header_table.add_row("TrainingJob ARN", f"[dim]{training_job.training_job_arn}[/dim]")
316309

317-
# Add MLflow link to header if available
310+
# Build links row
311+
links = []
312+
try:
313+
from sagemaker.train.common_utils.metrics_visualizer import (
314+
_is_in_studio, get_console_job_url, get_cloudwatch_logs_url, get_studio_url
315+
)
316+
if _is_in_studio():
317+
studio_url = get_studio_url(training_job)
318+
if studio_url:
319+
links.append(f"[bright_blue underline][link={studio_url}]🔗 Training Job (Studio)[/link][/bright_blue underline]")
320+
else:
321+
console_url = get_console_job_url(training_job.training_job_arn)
322+
if console_url:
323+
links.append(f"[bright_blue underline][link={console_url}]🔗 Training Job[/link][/bright_blue underline]")
324+
cw_url = get_cloudwatch_logs_url(training_job.training_job_arn)
325+
if cw_url:
326+
links.append(f"[bright_blue underline][link={cw_url}]🔗 CloudWatch Logs[/link][/bright_blue underline]")
327+
except Exception:
328+
pass
318329
if has_mlflow_config:
319330
cached_url = get_cached_mlflow_url()
320331
if cached_url:
321-
exp_name = training_job.mlflow_config.mlflow_experiment_name if hasattr(training_job, 'mlflow_config') else None
322-
if exp_name and not _is_unassigned_attribute(exp_name):
323-
link_text = exp_name
324-
else:
325-
link_text = "MLflow Experiment"
326-
327-
header_table.add_row("MLflow Experiment", f"[underline][link={cached_url}]🔗 {link_text}[/link][/underline]")
332+
links.append(f"[bright_blue underline][link={cached_url}]🔗 MLflow Experiment[/link][/bright_blue underline]")
328333
elif mlflow_link_cache['error']:
329334
header_table.add_row("MLflow Experiment", f"[red]{mlflow_link_cache['error']}[/red]")
335+
if has_mlflow_config:
336+
exp_name = training_job.mlflow_config.mlflow_experiment_name if hasattr(training_job, 'mlflow_config') else None
337+
if exp_name and not _is_unassigned_attribute(exp_name):
338+
header_table.add_row("MLflow Experiment", f"{exp_name}")
339+
if links:
340+
header_table.add_row("Links", " | ".join(links))
330341

331342
status_table = Table(show_header=False, box=None, padding=(0, 1))
332343
status_table.add_column("Property", style="cyan bold", width=20)

sagemaker-train/src/sagemaker/train/evaluate/execution.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -931,31 +931,37 @@ def wait(
931931
header_table.add_column("Property", style="cyan bold", width=20)
932932
header_table.add_column("Value", style="dim", overflow="fold")
933933

934-
# Extract pipeline name from execution ARN and build Studio link
934+
# Extract pipeline name and region from execution ARN
935935
pipeline_name = None
936936
exec_id = ''
937+
region = None
937938
if self.arn:
938939
arn_parts = self.arn.split('/')
939940
if len(arn_parts) >= 4:
940941
pipeline_name = arn_parts[-3]
941942
exec_id = arn_parts[-1]
943+
region = self.arn.split(":")[3] if len(self.arn.split(":")) > 3 else None
942944
# Use execution display name if available, fall back to self.name
943945
display_name = self.name
944946
if self._pipeline_execution:
945947
dn = getattr(self._pipeline_execution, 'pipeline_execution_display_name', None)
946948
if dn and not (hasattr(dn, '__class__') and 'Unassigned' in dn.__class__.__name__):
947949
display_name = dn
950+
header_table.add_row("Evaluation Job", str(display_name))
951+
952+
# Build links row
953+
links = []
948954
try:
949-
from sagemaker.train.common_utils.metrics_visualizer import get_studio_url
950-
dummy_url = get_studio_url(self.arn.split('/')[0].replace(':pipeline', ':training-job') + '/dummy' if self.arn else 'dummy')
951-
if dummy_url and pipeline_name:
952-
base = dummy_url.rsplit('/jobs/train/', 1)[0]
953-
pipeline_url = f"{base}/jobs/evaluation/detail?pipeline_name={pipeline_name}&execution_id={exec_id}"
954-
header_table.add_row("Evaluation Job", f"[underline][link={pipeline_url}]🔗 {display_name}[/link][/underline]")
955-
else:
956-
header_table.add_row("Evaluation Job", str(display_name))
955+
from sagemaker.train.common_utils.metrics_visualizer import _is_in_studio, _get_studio_base_url
956+
if region and pipeline_name and _is_in_studio():
957+
base = _get_studio_base_url(region)
958+
if base:
959+
pipeline_url = f"{base}/jobs/evaluation/detail?pipeline_name={pipeline_name}&execution_id={exec_id}"
960+
links.append(f"[bright_blue underline][link={pipeline_url}]🔗 Pipeline Execution (Studio)[/link][/bright_blue underline]")
957961
except Exception:
958-
header_table.add_row("Evaluation Job", str(display_name))
962+
pass
963+
if links:
964+
header_table.add_row("Links", " | ".join(links))
959965

960966
# Create main status table
961967
status_table = Table(show_header=False, box=None, padding=(0, 1))
@@ -1045,16 +1051,45 @@ def wait(
10451051
if job_arn_entries:
10461052
links_table = Table(show_header=True, header_style="bold magenta", box=None, padding=(0, 1))
10471053
links_table.add_column("Step", style="cyan", width=20)
1048-
links_table.add_column("Job Link", width=12)
1054+
links_table.add_column("Job Link", style="dim")
1055+
links_table.add_column("Logs", style="dim")
10491056
links_table.add_column("Job ARN", style="dim", overflow="fold")
1057+
from sagemaker.train.common_utils.metrics_visualizer import (
1058+
_is_in_studio, _parse_job_arn, _get_studio_base_url,
1059+
get_console_job_url, get_cloudwatch_logs_url,
1060+
)
1061+
in_studio = _is_in_studio()
1062+
studio_base = _get_studio_base_url(region) if in_studio else ""
1063+
studio_path_map = {
1064+
"training-job/": "jobs/train/",
1065+
"processing-job/": "jobs/processing/",
1066+
"transform-job/": "jobs/transform/",
1067+
}
10501068
for entry in job_arn_entries:
1069+
job_link = ""
1070+
logs_link = ""
10511071
try:
1052-
from sagemaker.train.common_utils.metrics_visualizer import get_studio_url
1053-
url = get_studio_url(entry['job_arn'])
1054-
link_col = f"[underline][link={url}]🔗 link[/link][/underline]" if url else ""
1072+
arn = entry['job_arn']
1073+
if in_studio and studio_base:
1074+
parsed = _parse_job_arn(arn)
1075+
if parsed:
1076+
_, resource = parsed
1077+
for prefix, path in studio_path_map.items():
1078+
if resource.startswith(prefix):
1079+
job_name = resource.split("/", 1)[1]
1080+
url = f"{studio_base}/{path}{job_name}"
1081+
job_link = f"[bright_blue underline][link={url}]🔗 link[/link][/bright_blue underline]"
1082+
break
1083+
else:
1084+
url = get_console_job_url(arn)
1085+
if url:
1086+
job_link = f"[bright_blue underline][link={url}]🔗 link[/link][/bright_blue underline]"
1087+
cw_url = get_cloudwatch_logs_url(arn)
1088+
if cw_url:
1089+
logs_link = f"[bright_blue underline][link={cw_url}]🔗 logs[/link][/bright_blue underline]"
10551090
except Exception:
1056-
link_col = ""
1057-
links_table.add_row(entry['step_name'], link_col, entry['job_arn'])
1091+
pass
1092+
links_table.add_row(entry['step_name'], job_link, logs_link, entry['job_arn'])
10581093
content_parts.append(Text(""))
10591094
content_parts.append(Text("Job ARNs", style="bold magenta"))
10601095
content_parts.append(links_table)

0 commit comments

Comments
 (0)